Former-commit-id: a7bdaf1c92c7d798caf8438dc42a8972632ec584
This commit is contained in:
hiyouga
2023-08-21 18:16:11 +08:00
parent 4d128acc17
commit d6be98cda6
2 changed files with 2 additions and 6 deletions

View File

@@ -10,7 +10,7 @@ from llmtuner.tuner.core.trainer import PeftModelMixin
if TYPE_CHECKING:
from transformers import PreTrainedModel
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
from llmtuner.hparams import FinetuningArguments
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
@@ -18,12 +18,10 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
**kwargs
):
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.label_pad_token_id = IGNORE_INDEX