fix incorrect loss value for vlms
Former-commit-id: 0aa29a71ce958343a2086090d647eb63b8f5f5be
This commit is contained in:
@@ -27,6 +27,7 @@ from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
@@ -78,6 +79,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
@override
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
r"""
|
||||
Fixes the loss value for transformers 4.46.0.
|
||||
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
|
||||
"""
|
||||
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
||||
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
|
||||
loss /= self.args.gradient_accumulation_steps # other model should not scale the loss
|
||||
|
||||
return loss
|
||||
|
||||
@override
|
||||
def prediction_step(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user