@@ -10,7 +10,7 @@ from llmtuner.tuner.core.trainer import PeftModelMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
||||
@@ -18,12 +18,10 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
**kwargs
|
||||
):
|
||||
self.finetuning_args = finetuning_args
|
||||
self.generating_args = generating_args
|
||||
self.ref_model = ref_model
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
|
||||
Reference in New Issue
Block a user