Former-commit-id: 26d07de349c98b547cd6a6166ea20616d08ba343
This commit is contained in:
hiyouga
2024-10-29 10:47:04 +00:00
parent 248d5daaff
commit e2748fa967
8 changed files with 58 additions and 6 deletions

View File

@@ -44,7 +44,6 @@ INFER_ARGS = {
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
"export_dir": "llama3_export",
}
OS_NAME = os.environ.get("OS_NAME", "")
@@ -61,11 +60,12 @@ OS_NAME = os.environ.get("OS_NAME", "")
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = f"train_{stage}"
output_dir = os.path.join("output", f"train_{stage}")
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)
def test_export():
export_model(INFER_ARGS)
assert os.path.exists("llama3_export")
export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})
assert os.path.exists(export_dir)