update scripts

Former-commit-id: 05aa52adde8905ca892f1ed5847d6f90b1992848
This commit is contained in:
hiyouga
2025-01-03 10:50:32 +00:00
parent d1a8cd67d2
commit 8516054e4d
5 changed files with 19 additions and 13 deletions

View File

@@ -63,7 +63,7 @@ def calculate_ppl(
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
-cutoff_len: int = 1024,
+cutoff_len: int = 2048,
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
@@ -82,6 +82,7 @@ def calculate_ppl(
cutoff_len=cutoff_len,
max_samples=max_samples,
train_on_prompt=train_on_prompt,
+preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@@ -111,7 +112,7 @@ def calculate_ppl(
perplexities = []
batch: Dict[str, "torch.Tensor"]
with torch.no_grad():
-for batch in tqdm(dataloader):
+for batch in tqdm(dataloader, desc="Computing perplexities"):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]