Former-commit-id: 26d07de349c98b547cd6a6166ea20616d08ba343
commit e2748fa967 (parent 248d5daaff)
Author: hiyouga
Date: 2024-10-29 10:47:04 +00:00

8 changed files with 58 additions and 6 deletions


@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
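
The new import pulls in a version gate from ...extras.packages. For reference, a minimal sketch of how such a check could be implemented, assuming it only compares the installed transformers release against 4.46.0 (a reconstruction; the actual helper may differ):

# Hypothetical sketch of a version gate like is_transformers_version_equal_to_4_46.
from functools import lru_cache
from importlib.metadata import version as installed_version

from packaging import version


@lru_cache
def is_transformers_version_equal_to_4_46() -> bool:
    # True only when the installed transformers release is exactly 4.46.0,
    # the release whose Trainer loss handling this commit works around.
    return version.parse(installed_version("transformers")) == version.parse("4.46.0")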
@@ -120,6 +121,13 @@ class CustomKTOTrainer(KTOTrainer):
         """
         return Trainer._get_train_sampler(self)
 
+    @override
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        r"""
+        Replaces the method of KTO Trainer with the one of the standard Trainer.
+        """
+        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+
     @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
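
The get_batch_samples override added here skips KTOTrainer's own batch sampling and reuses the plain Trainer implementation by calling it through the class rather than via super(). A toy illustration of that pattern, with made-up class names:

# Minimal sketch: calling a grandparent's method explicitly bypasses the parent's
# override, mirroring `return Trainer.get_batch_samples(self, epoch_iterator, num_batches)`.
class Base:
    def sample(self):
        return "base sampling"


class Middle(Base):
    def sample(self):
        return "specialized sampling"


class Custom(Middle):
    def sample(self):
        return Base.sample(self)  # skip Middle.sample, reuse Base.sample


assert Custom().sample() == "base sampling"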
@@ -231,3 +239,15 @@ class CustomKTOTrainer(KTOTrainer):
         metrics["kl"] = kl.item()
 
         return losses, metrics
+
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs)
+        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+            loss /= self.args.gradient_accumulation_steps
+
+        return loss
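
The docstring above points at Trainer.compute_loss in transformers 4.46.0, where the loss is not divided by the number of gradient accumulation steps once num_items_in_batch is forwarded; because the loss returned here is a mean, the override re-applies that division manually. A small arithmetic illustration of the correction (the step count and loss value are placeholders):

# Illustrative only: the scaling that compute_loss re-applies on transformers 4.46.0.
grad_accum_steps = 4   # placeholder for self.args.gradient_accumulation_steps
raw_loss = 2.0         # placeholder for the mean loss from super().compute_loss(...)

corrected_loss = raw_loss / grad_accum_steps  # what the backward pass should actually see
assert abs(corrected_loss - 0.5) < 1e-9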