[trainer] fix vlm loss for transformers 4.49 (#7448)
This commit is contained in:
@@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
return super()._get_train_sampler()
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
|
||||
def get_batch_samples(self, *args, **kwargs):
|
||||
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
|
||||
return Trainer.get_batch_samples(self, *args, **kwargs)
|
||||
|
||||
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
|
||||
|
||||
@@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
return Trainer._get_train_sampler(self)
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
|
||||
def get_batch_samples(self, *args, **kwargs):
|
||||
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
|
||||
return Trainer.get_batch_samples(self, *args, **kwargs)
|
||||
|
||||
@override
|
||||
def forward(
|
||||
|
||||
@@ -70,3 +70,7 @@ class CustomTrainer(Trainer):
|
||||
return torch.utils.data.SequentialSampler(self.train_dataset)
|
||||
|
||||
return super()._get_train_sampler()
|
||||
|
||||
@override
|
||||
def compute_loss(self, model, inputs, *args, **kwargs):
|
||||
return super().compute_loss(model, inputs, *args, **kwargs)
|
||||
|
||||
@@ -59,6 +59,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
|
||||
|
||||
super().__init__(**kwargs)
|
||||
if processor is not None:
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
self.finetuning_args = finetuning_args
|
||||
if gen_kwargs is not None:
|
||||
# https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
|
||||
@@ -93,6 +96,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
return super()._get_train_sampler()
|
||||
|
||||
@override
|
||||
def compute_loss(self, model, inputs, *args, **kwargs):
|
||||
return super().compute_loss(model, inputs, *args, **kwargs)
|
||||
|
||||
@override
|
||||
def prediction_step(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user