mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 21:03:10 +00:00
[trainer] fix vlm loss for transformers 4.49 (#7448)
This commit is contained in:
@@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
return Trainer._get_train_sampler(self)
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
|
||||
def get_batch_samples(self, *args, **kwargs):
|
||||
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
|
||||
return Trainer.get_batch_samples(self, *args, **kwargs)
|
||||
|
||||
@override
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user