[misc] upgrade format to py39 (#7256)

hoshi-hiyouga authored on 2025-03-12 00:08:41 +08:00, committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions
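
For context, "upgrade format to py39" refers to PEP 585: on Python 3.9+ the builtin containers are subscriptable in annotations, so the `typing.Dict`/`List`/`Tuple` aliases can be dropped, which is what every hunk below does. A minimal before/after sketch (the `count_tokens` function is hypothetical, not from this commit):

# Before (pre-3.9 spelling):
#   from typing import Dict, List, Tuple
#   def count_tokens(batch: Dict[str, List[int]]) -> Tuple[int, int]: ...

# After (3.9+ builtin generics, as applied throughout this commit):
def count_tokens(batch: dict[str, list[int]]) -> tuple[int, int]:
    """Return the number of sequences and the total token count."""
    return len(batch), sum(len(seq) for seq in batch.values())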


@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import torch
 from transformers import Trainer
@@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
-        r"""
-        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
-        """
+        r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
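
The docstring rewrites in this and the following hunks all follow one pattern: collapse the multi-line summary onto the opening line and switch the first verb to the imperative mood (this matches pydocstyle conventions D200 and D401, presumably what the upgraded format check enforces). A hypothetical example of the pattern, not taken from the commit:

# Before: multi-line summary, verb in the third person.
def save_predictions_old() -> None:
    r"""
    Saves the predictions to the output directory.
    """

# After: one-line summary in the imperative mood, as in this commit.
def save_predictions() -> None:
    r"""Save the predictions to the output directory."""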
@@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
+        r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
         return Trainer.get_batch_samples(self, epoch_iterator, num_batches)

     @override
     def forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Runs forward pass and computes the log probabilities.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Run forward pass and compute the log probabilities."""
         batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
             _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def compute_reference_log_probs(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes log probabilities of the reference model.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute log probabilities of the reference model."""
         if self.ref_model is None:
             ref_model = model
             ref_context = self.accelerator.unwrap_model(model).disable_adapter()
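
The context lines above show a memory-saving pattern worth noting: when no separate reference model is configured, the policy model itself is reused with its adapters turned off, which recovers the frozen base model without keeping a second copy in GPU memory. A schematic sketch of the idea, assuming a PEFT-style model whose `disable_adapter()` is a context manager:

from contextlib import nullcontext

import torch

def reference_forward(trainer, model, batch):
    """Schematic: run the forward pass under the reference model."""
    if trainer.ref_model is None:
        # With LoRA, the base weights are frozen; disabling the adapter
        # temporarily yields the reference model "for free".
        ref_model = model
        ref_context = trainer.accelerator.unwrap_model(model).disable_adapter()
    else:
        ref_model = trainer.ref_model
        ref_context = nullcontext()

    with torch.no_grad(), ref_context:
        return trainer.forward(ref_model, batch)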
@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, "torch.Tensor"],
-    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
-        r"""
-        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
-        """
+        batch: dict[str, "torch.Tensor"],
+    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
+        r"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
         (
             policy_chosen_logps,
@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Subclass and override to accept extra kwargs.
-        """
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Subclass and override to accept extra kwargs."""
         return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-        r"""
-        Log `logs` on the various objects watching training, including stored metrics.
-        """
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Log `logs` on the various objects watching training, including stored metrics."""
         # logs either has "loss" or "eval_loss"
         train_eval = "train" if "loss" in logs else "eval"
         prefix = "eval_" if train_eval == "eval" else ""
@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
         metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
         metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
-        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
         for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
             if f"count/{split}" in metric_dict:
                 for key in ("rewards", "logps", "logits"):
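
The tail of this last hunk shows how the logged values are formed: each process contributes per-split sums plus a sample count, the "sum" reduction aggregates them, and the division happens only afterwards. The ordering matters because a mean of per-process means is wrong whenever processes hold unequal numbers of samples. A self-contained sketch with synthetic numbers (no accelerate required):

# Two processes with different numbers of chosen samples.
per_process = [
    {"rewards/chosen": 6.0, "count/chosen": 2.0},  # local mean 3.0
    {"rewards/chosen": 2.0, "count/chosen": 1.0},  # local mean 2.0
]

# Stand-in for accelerator.reduce(metric_list, "sum") across processes.
reduced = {key: sum(proc[key] for proc in per_process) for key in per_process[0]}

global_mean = reduced["rewards/chosen"] / reduced["count/chosen"]
assert abs(global_mean - 8.0 / 3.0) < 1e-9  # not (3.0 + 2.0) / 2 = 2.5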