mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PairwiseTrainer(Trainer):
|
||||
r"""
|
||||
Inherits Trainer to compute pairwise loss.
|
||||
"""
|
||||
r"""Inherits Trainer to compute pairwise loss."""
|
||||
|
||||
def __init__(
|
||||
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
|
||||
@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
|
||||
r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
|
||||
return loss
|
||||
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
r"""Save model predictions to `output_dir`.
|
||||
|
||||
A custom behavior that not contained in Seq2SeqTrainer.
|
||||
"""
|
||||
@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
|
||||
chosen_scores, rejected_scores = predict_results.predictions
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
res: list[str] = []
|
||||
for c_score, r_score in zip(chosen_scores, rejected_scores):
|
||||
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user