[deps] upgrade transformers to 4.50.0 (#7437)

* upgrade transformers

* fix hf cache

* fix dpo trainer
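A quick way to confirm an environment matches the new requirement (the exact pin in the project's dependency files is not shown here, so 4.50.0 is taken only from the commit title) is a minimal version check:

    from importlib.metadata import version

    from packaging.version import Version

    # 4.50.0 comes from the commit title; the project's actual pin may be a range.
    required = Version("4.50.0")
    installed = Version(version("transformers"))
    if installed < required:
        raise RuntimeError(f"transformers>={required} is required, found {installed}")
    print(f"transformers {installed} OK")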
Author: hoshi-hiyouga
Date: 2025-03-23 17:44:27 +08:00
Committed by: GitHub
Parent: 919415dba9
Commit: 05b19d6952
6 changed files with 10 additions and 10 deletions


@@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer):
         return super()._get_train_sampler()
 
     @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
+    def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
         r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+        return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
 
     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""


@@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer):
         return Trainer._get_train_sampler(self)
 
     @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
+    def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
         r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+        return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
 
     @override
     def forward(
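For context: both overrides delegate to the plain Trainer implementation, and forwarding *args/**kwargs keeps them compatible with newer transformers releases where Trainer.get_batch_samples may receive additional arguments. A minimal, self-contained sketch of the forwarding pattern (BaseTrainer and CustomTrainer below are illustrative stand-ins, not the repository's classes or the transformers API):

    from typing import Any, Iterator


    class BaseTrainer:
        # Stand-in for the upstream method; newer releases may pass extra
        # positional arguments, hence *args/**kwargs on the override.
        def get_batch_samples(self, epoch_iterator: Iterator, num_batches: int, *args: Any, **kwargs: Any):
            batch_samples = [next(epoch_iterator) for _ in range(num_batches)]
            return batch_samples, None


    class CustomTrainer(BaseTrainer):
        def get_batch_samples(self, epoch_iterator: Iterator, num_batches: int, *args: Any, **kwargs: Any):
            # Forward everything unchanged so parameters added upstream do not
            # break the subclass with a TypeError.
            return BaseTrainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)


    if __name__ == "__main__":
        samples, _ = CustomTrainer().get_batch_samples(iter(range(10)), 3, "some-new-upstream-arg")
        print(samples)  # [0, 1, 2]

The same pattern covers both the DPO and KTO trainers above: accepting and passing through the extra arguments means the fixed two-argument signature no longer has to track upstream signature changes.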