Former-commit-id: 26d07de349c98b547cd6a6166ea20616d08ba343
This commit is contained in:
hiyouga
2024-10-29 10:47:04 +00:00
parent 248d5daaff
commit e2748fa967
8 changed files with 58 additions and 6 deletions

View File

@@ -25,6 +25,7 @@ from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
else: