mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 08:53:38 +00:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -19,7 +19,7 @@ import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
|
||||
r"""
|
||||
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
|
||||
"""
|
||||
r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
|
||||
if self.finetuning_args.disable_shuffling:
|
||||
return torch.utils.data.SequentialSampler(self.train_dataset)
|
||||
|
||||
@@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||
r"""
|
||||
Replaces the method of KTO Trainer with the one of the standard Trainer.
|
||||
"""
|
||||
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
||||
|
||||
@override
|
||||
def forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Runs forward pass and computes the log probabilities.
|
||||
"""
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Run forward pass and computes the log probabilities."""
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
model_inputs = {
|
||||
"input_ids": batch[f"{prefix}input_ids"],
|
||||
@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
target_logits, target_logps, target_logps_avg = self.forward(model, batch)
|
||||
with torch.no_grad():
|
||||
_, kl_logps, _ = self.forward(model, batch, prefix="kl_")
|
||||
@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def compute_reference_log_probs(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Computes log probabilities of the reference model.
|
||||
"""
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute log probabilities of the reference model."""
|
||||
if self.ref_model is None:
|
||||
ref_model = model
|
||||
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
|
||||
@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, "torch.Tensor"],
|
||||
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
batch: dict[str, "torch.Tensor"],
|
||||
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
|
||||
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
(
|
||||
policy_chosen_logps,
|
||||
@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Subclass and override to accept extra kwargs.
|
||||
"""
|
||||
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
|
||||
r"""Subclass and override to accept extra kwargs."""
|
||||
return super().compute_loss(model, inputs, return_outputs)
|
||||
|
||||
@override
|
||||
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
|
||||
r"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
"""
|
||||
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
|
||||
r"""Log `logs` on the various objects watching training, including stored metrics."""
|
||||
# logs either has "loss" or "eval_loss"
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
|
||||
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
|
||||
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
|
||||
metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
|
||||
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
|
||||
if f"count/{split}" in metric_dict:
|
||||
for key in ("rewards", "logps", "logits"):
|
||||
|
||||
Reference in New Issue
Block a user