Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: 438302edfdb66b6397266b8b17ac66f60a89300c
This commit is contained in:
@@ -44,7 +44,6 @@ INFER_ARGS = {
|
||||
"finetuning_type": "lora",
|
||||
"template": "llama3",
|
||||
"infer_dtype": "float16",
|
||||
"export_dir": "llama3_export",
|
||||
}
|
||||
|
||||
OS_NAME = os.environ.get("OS_NAME", "")
|
||||
@@ -61,11 +60,12 @@ OS_NAME = os.environ.get("OS_NAME", "")
|
||||
],
|
||||
)
|
||||
def test_run_exp(stage: str, dataset: str):
|
||||
output_dir = "train_{}".format(stage)
|
||||
output_dir = os.path.join("output", f"train_{stage}")
|
||||
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
||||
assert os.path.exists(output_dir)
|
||||
|
||||
|
||||
def test_export():
|
||||
export_model(INFER_ARGS)
|
||||
assert os.path.exists("llama3_export")
|
||||
export_dir = os.path.join("output", "llama3_export")
|
||||
export_model({"export_dir": export_dir, **INFER_ARGS})
|
||||
assert os.path.exists(export_dir)
|
||||
|
||||
@@ -52,7 +52,7 @@ INFER_ARGS = {
|
||||
OS_NAME = os.environ.get("OS_NAME", "")
|
||||
|
||||
|
||||
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
|
||||
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
|
||||
def test_pissa_train():
|
||||
model = load_train_model(**TRAIN_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
|
||||
|
||||
Reference in New Issue
Block a user