[deps] update to transformers 4.52 (#8125)
@@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
 
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
 
     @override
     def get_batch_samples(self, *args, **kwargs):
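The same signature change recurs in every trainer below. A minimal sketch of the forwarding pattern, assuming a transformers release (4.52 or later) in which the base Trainer may call _get_train_sampler with extra arguments such as the dataset; the class name and disable_shuffling flag here are illustrative, not the project's actual code:

# Minimal sketch of the *args/**kwargs forwarding pattern (illustrative names,
# not the project's actual classes). Assumes the base Trainer may pass extra
# positional arguments to this hook in newer transformers releases.
from typing import Optional

import torch.utils.data
from transformers import Trainer
from typing_extensions import override


class ShuffleTogglingTrainer(Trainer):
    def __init__(self, *args, disable_shuffling: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_shuffling = disable_shuffling  # hypothetical flag

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        # Deterministic order when shuffling is disabled.
        if self.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        # Forward whatever the installed transformers version passes in,
        # so the override works before and after the signature change.
        return super()._get_train_sampler(*args, **kwargs)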
@@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge
 
 
 if TYPE_CHECKING:
     import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin
 
     from ...hparams import FinetuningArguments
@@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
 
-        return Trainer._get_train_sampler(self)
+        return Trainer._get_train_sampler(self, *args, **kwargs)
 
     @override
     def get_batch_samples(self, *args, **kwargs):
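Note that this override calls Trainer._get_train_sampler directly instead of super(), which skips KTOTrainer's own sampler. A tiny, self-contained sketch of that dispatch trick, using toy classes rather than the real trainers:

# Toy illustration of bypassing an intermediate class's override by calling an
# ancestor's method explicitly, as the KTO trainer above does to avoid the
# sequential sampler forced by trl.
class Base:
    def sampler(self, *args, **kwargs):
        return "random"


class Middle(Base):
    def sampler(self, *args, **kwargs):
        return "sequential"  # behavior being bypassed


class Leaf(Middle):
    def sampler(self, *args, **kwargs):
        # super() would dispatch to Middle; naming Base explicitly skips it.
        return Base.sampler(self, *args, **kwargs)


assert Leaf().sampler() == "random"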
@@ -70,11 +70,11 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
 
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
 
     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
@@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
 
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
 
     @override
     def compute_loss(
@@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
 
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
 
     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
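For reference, a quick way to confirm which signature the installed transformers version actually uses for this hook (a throwaway check, not part of the commit):

# Throwaway check (not part of the commit): print the base hook's signature so
# an override like the ones above can be kept in sync with the installed version.
import inspect

from transformers import Trainer

print(inspect.signature(Trainer._get_train_sampler))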