[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Optional, Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@@ -30,16 +31,12 @@ from llamafactory.model import load_model, load_tokenizer
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
r"""Data collator for pairwise data."""
|
||||
|
||||
train_on_prompt: bool = False
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
"""
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]:
|
||||
r"""Pad batched data to the longest sequence in the batch."""
|
||||
chosen_features = []
|
||||
for feature in features:
|
||||
chosen_features.append(
|
||||
@@ -68,8 +65,8 @@ def calculate_ppl(
|
||||
max_samples: Optional[int] = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
r"""
|
||||
Calculates the ppl on the dataset of the pre-trained models.
|
||||
r"""Calculate the ppl on the dataset of the pre-trained models.
|
||||
|
||||
Usage: export CUDA_VISIBLE_DEVICES=0
|
||||
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
|
||||
"""
|
||||
@@ -111,17 +108,17 @@ def calculate_ppl(
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
total_ppl = 0
|
||||
perplexities = []
|
||||
batch: Dict[str, "torch.Tensor"]
|
||||
batch: dict[str, torch.Tensor]
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Computing perplexities"):
|
||||
batch = batch.to(model.device)
|
||||
outputs = model(**batch)
|
||||
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
|
||||
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
|
||||
shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
|
||||
shift_labels: torch.Tensor = batch["labels"][..., 1:]
|
||||
loss_mask = shift_labels != IGNORE_INDEX
|
||||
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
|
||||
flatten_labels = shift_labels.contiguous().view(-1)
|
||||
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
|
||||
token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
|
||||
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
|
||||
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
total_ppl += sentence_logps.exp().sum().item()
|
||||
|
||||
Reference in New Issue
Block a user