Merge branch 'hiyouga:main' into pixtral-patch

Former-commit-id: 438302edfdb66b6397266b8b17ac66f60a89300c
This commit is contained in:
Kingsley
2024-10-29 21:01:25 +08:00
committed by GitHub
81 changed files with 1145 additions and 1072 deletions

View File

@@ -44,7 +44,6 @@ INFER_ARGS = {
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
"export_dir": "llama3_export",
}
OS_NAME = os.environ.get("OS_NAME", "")
@@ -61,11 +60,12 @@ OS_NAME = os.environ.get("OS_NAME", "")
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = "train_{}".format(stage)
output_dir = os.path.join("output", f"train_{stage}")
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)
def test_export():
export_model(INFER_ARGS)
assert os.path.exists("llama3_export")
export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})
assert os.path.exists(export_dir)

View File

@@ -52,7 +52,7 @@ INFER_ARGS = {
OS_NAME = os.environ.get("OS_NAME", "")
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train():
model = load_train_model(**TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)