update scripts

Former-commit-id: 51d087cbc14bf3c7dfa06b8b66052cd80a6081be
This commit is contained in:
hiyouga
2024-09-08 14:17:41 +08:00
parent eb5af3d90b
commit 3259397f89
6 changed files with 24 additions and 11 deletions

View File

@@ -60,7 +60,7 @@ def calculate_ppl(
save_name: str,
batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
@@ -69,7 +69,7 @@ def calculate_ppl(
):
r"""
Calculates the ppl on the dataset of the pre-trained models.
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(