[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
from transformers import Trainer
@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer):
r"""
Inherits Trainer to compute pairwise loss.
"""
r"""Inherits Trainer to compute pairwise loss."""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
return loss
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
res: list[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))