@@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "")
|
||||
],
|
||||
)
|
||||
def test_run_exp(stage: str, dataset: str):
|
||||
output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
|
||||
output_dir = os.path.join("output", f"train_{stage}")
|
||||
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
||||
assert os.path.exists(output_dir)
|
||||
|
||||
|
||||
def test_export():
|
||||
export_dir = os.path.join("output", "dummy_dir/llama3_export")
|
||||
export_dir = os.path.join("output", "llama3_export")
|
||||
export_model({"export_dir": export_dir, **INFER_ARGS})
|
||||
assert os.path.exists(export_dir)
|
||||
|
||||
@@ -58,7 +58,11 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
|
||||
@pytest.mark.parametrize("disable_shuffling", [False, True])
|
||||
def test_shuffle(disable_shuffling: bool):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
{"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
|
||||
{
|
||||
"output_dir": os.path.join("output", f"shuffle{str(disable_shuffling).lower()}"),
|
||||
"disable_shuffling": disable_shuffling,
|
||||
**TRAIN_ARGS,
|
||||
}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
||||
Reference in New Issue
Block a user