update scripts
Former-commit-id: 05aa52adde8905ca892f1ed5847d6f90b1992848
This commit is contained in:
@@ -63,7 +63,7 @@ def calculate_ppl(
|
||||
dataset: str = "alpaca_en_demo",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024,
|
||||
cutoff_len: int = 2048,
|
||||
max_samples: Optional[int] = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
@@ -82,6 +82,7 @@ def calculate_ppl(
|
||||
cutoff_len=cutoff_len,
|
||||
max_samples=max_samples,
|
||||
train_on_prompt=train_on_prompt,
|
||||
preprocessing_num_workers=16,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
do_train=True,
|
||||
@@ -111,7 +112,7 @@ def calculate_ppl(
|
||||
perplexities = []
|
||||
batch: Dict[str, "torch.Tensor"]
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader):
|
||||
for batch in tqdm(dataloader, desc="Computing perplexities"):
|
||||
batch = batch.to(model.device)
|
||||
outputs = model(**batch)
|
||||
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
|
||||
|
||||
Reference in New Issue
Block a user