add cal_ppl script

Former-commit-id: 947068c11c0be00db2cecddb2c5842a0d6e2c321
This commit is contained in:
hiyouga
2024-05-04 22:02:25 +08:00
parent 6eda42eb7c
commit 342d7da8d7
4 changed files with 95 additions and 22 deletions

View File

@@ -3,7 +3,6 @@
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
from collections import defaultdict
-from typing import Optional
import fire
from tqdm import tqdm
@@ -15,10 +14,10 @@ from llmtuner.model import load_tokenizer
def length_cdf(
model_name_or_path: str,
-dataset: Optional[str] = "alpaca_en",
-dataset_dir: Optional[str] = "data",
-template: Optional[str] = "default",
-interval: Optional[int] = 1000,
+dataset: str = "alpaca_en",
+dataset_dir: str = "data",
+template: str = "default",
+interval: int = 1000,
):
model_args, data_args, training_args, _, _ = get_train_args(
dict(