@@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "")
|
||||
],
|
||||
)
|
||||
def test_run_exp(stage: str, dataset: str):
|
||||
output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
|
||||
output_dir = os.path.join("output", f"train_{stage}")
|
||||
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
||||
assert os.path.exists(output_dir)
|
||||
|
||||
|
||||
def test_export():
|
||||
export_dir = os.path.join("output", "dummy_dir/llama3_export")
|
||||
export_dir = os.path.join("output", "llama3_export")
|
||||
export_model({"export_dir": export_dir, **INFER_ARGS})
|
||||
assert os.path.exists(export_dir)
|
||||
|
||||
Reference in New Issue
Block a user