@@ -40,6 +40,11 @@ class CustomTrainer(Trainer):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args

        if processor is not None:

@@ -60,6 +60,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
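
Both hunks apply the same pattern, sketched below in condensed form. The simplified signature and defaults here are assumptions for illustration; the classes in the diff take typed finetuning_args and processor arguments and do more setup than shown.

    from transformers import Seq2SeqTrainer


    class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        # Condensed sketch of the change above, not the full class.
        def __init__(self, finetuning_args=None, processor=None, **kwargs):
            super().__init__(**kwargs)
            if processor is not None:
                # Per the diff's comment and the linked transformers PR discussion:
                # opt out of the Trainer's loss-kwargs path so the loss is not
                # scaled incorrectly under gradient accumulation.
                self.model_accepts_loss_kwargs = False

            self.finetuning_args = finetuning_args

The flag is only flipped when a multimodal processor is passed in; the plain-tokenizer path keeps the Trainer's default behavior.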