Former-commit-id: 8577f52b4152efe6cc7a8b5f6d37b4f9ba6684e7
This commit is contained in:
hiyouga
2024-12-30 05:55:15 +00:00
parent 5f473e2696
commit f8f05a883b
7 changed files with 26 additions and 11 deletions

View File

@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@@ -50,6 +50,9 @@ class CustomDPOTrainer(DPOTrainer):
disable_dropout: bool = True,
**kwargs,
):
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None: