Fix DPO metrics

Former-commit-id: 57029280da825a39fbf5a05097921b861f126669
This commit is contained in:
hiyouga
2024-11-02 19:22:11 +08:00
parent b28b74c71e
commit 2bb3255e74
7 changed files with 143 additions and 58 deletions

View File

@@ -87,7 +87,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
loss /= self.args.gradient_accumulation_steps # other model should not scale the loss
# other model should not scale the loss
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss