Former-commit-id: 067ba6e6cb4d8a1d95bba0a108f73008416a2865
This commit is contained in:
hiyouga
2024-12-19 12:16:30 +00:00
parent 0a465fc3ca
commit 0385c60177
6 changed files with 22 additions and 16 deletions

View File

@@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "")
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
output_dir = os.path.join("output", f"train_{stage}")
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)
def test_export():
export_dir = os.path.join("output", "dummy_dir/llama3_export")
export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})
assert os.path.exists(export_dir)

View File

@@ -58,7 +58,11 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
{"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
{
"output_dir": os.path.join("output", f"shuffle{str(disable_shuffling).lower()}"),
"disable_shuffling": disable_shuffling,
**TRAIN_ARGS,
}
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]