@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
|
||||
|
||||
@@ -50,6 +50,9 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
disable_dropout: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
|
||||
if disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if ref_model is not None:
|
||||
|
||||
Reference in New Issue
Block a user