[trainer] fix vlm loss for transformers 4.49 (#7448)

hoshi-hiyouga
2025-03-24 10:24:05 +08:00
committed by GitHub
parent 3612946dd9
commit 7203365b80
5 changed files with 21 additions and 4 deletions


@@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer):
         return super()._get_train_sampler()
 
     @override
-    def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
+    def get_batch_samples(self, *args, **kwargs):
         r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
+        return Trainer.get_batch_samples(self, *args, **kwargs)
 
     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""