[misc] upgrade format to py39 (#7256)

hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions


@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
import torch.nn.functional as F
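The typing change above is the heart of the py39 upgrade: PEP 585 makes the builtin containers generic, so Dict, List and Tuple can be dropped from the typing import while TYPE_CHECKING, Literal, Optional and Union stay. A minimal sketch of the two spellings (function names are illustrative, not from the diff):

from typing import Dict, Optional, Tuple  # only needed for the legacy spellings

# Pre-3.9 spelling, the style this commit removes:
def old_style(batch: Dict[str, float]) -> Tuple[Optional[float], float]:
    return None, 0.0

# Python 3.9+ spelling (PEP 585): builtin containers are subscriptable directly.
def new_style(batch: dict[str, float]) -> tuple[Optional[float], float]:
    return None, 0.0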
@@ -129,15 +129,11 @@ class CustomDPOTrainer(DPOTrainer):
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
"""
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
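The log_odds expression above is the log of the odds ratio between the chosen and rejected responses: writing p_w and p_l for the (length-averaged) probabilities exp(chosen_logps) and exp(rejected_logps), the odds of a response are p / (1 - p), so

\log \mathrm{OR} = (\log p_w - \log p_l) - \bigl(\log(1 - p_w) - \log(1 - p_l)\bigr), \qquad \log(1 - p) = \operatorname{log1p}(-e^{\log p}),

which matches the torch.log1p(-torch.exp(...)) terms used for numerical stability. The length averaging itself happens later in concatenated_forward, where all_logps is divided by valid_length for IPO, ORPO and SimPO.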
@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
return orpo_loss
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes SimPO loss for batched log probabilities of the policy model.
"""
r"""Compute SimPO loss for batched log probabilities of the policy model."""
pi_logratios = chosen_logps - rejected_logps
gamma_logratios = self.simpo_gamma / self.beta
logits = pi_logratios - gamma_logratios
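Dividing the target margin gamma by beta here means that once the log-ratio is scaled back by beta (that step sits below the lines shown in this hunk), the objective takes the usual SimPO form on length-averaged log-probabilities; a sketch of the intended loss, not a line-by-line transcription of the method:

\mathcal{L}_{\mathrm{SimPO}} = -\log \sigma\bigl(\beta(\bar{\ell}_w - \bar{\ell}_l) - \gamma\bigr), \qquad \bar{\ell} = \tfrac{1}{|y|} \log \pi_\theta(y \mid x).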
@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
policy_rejected_logps: "torch.Tensor",
reference_chosen_logps: Optional["torch.Tensor"],
reference_rejected_logps: Optional["torch.Tensor"],
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes loss for preference learning.
"""
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute loss for preference learning."""
if not self.finetuning_args.use_ref_model:
if self.loss_type == "orpo":
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
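get_batch_logps itself is not part of this diff; the following is a hypothetical, self-contained sketch of what such a helper typically returns (per-sequence summed log-probs plus the count of non-masked label tokens, so the division above turns the sum into an average):

import torch

def batch_logps_sketch(logits: "torch.Tensor", labels: "torch.Tensor", ignore_index: int = -100):
    # Illustrative only, not the repository's implementation.
    # Labels at position t are predicted from logits at position t - 1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    mask = labels != ignore_index
    labels = labels.masked_fill(~mask, 0)  # placeholder index, masked out below
    per_token_logps = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return (per_token_logps * mask).sum(-1), mask.sum(-1)  # summed log-probs, valid lengths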
@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Computes log probabilities of the reference model.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Compute log probabilities of the reference model."""
if not self.finetuning_args.use_ref_model:
return None, None
@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
-batch: Dict[str, "torch.Tensor"],
+batch: dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Subclass and override to accept extra kwargs.
"""
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Subclass and override to accept extra kwargs."""
return super().compute_loss(model, inputs, return_outputs)
@override
-def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-r"""
-Log `logs` on the various objects watching training, including stored metrics.
-"""
+def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+r"""Log `logs` on the various objects watching training, including stored metrics."""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
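The averaging referred to by the comment above is outside the lines shown here; as an illustrative, standalone sketch (names are hypothetical, not the trainer's attributes), it amounts to collapsing each stored list of per-step values into its mean before merging it into logs:

import torch

def average_stored_metrics(stored_metrics: dict[str, list[float]]) -> dict[str, float]:
    # Hypothetical helper: reduce each list of per-batch values to a single mean.
    return {key: torch.tensor(values).mean().item() for key, values in stored_metrics.items()}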