Former-commit-id: e6120a937ddb4f3c0b9bcb2466742f5cf4f77f8c
This commit is contained in:
hiyouga
2023-08-23 20:21:15 +08:00
parent 4606340f0f
commit eb9ac9ee1f
4 changed files with 19 additions and 13 deletions

View File

@@ -31,13 +31,14 @@ def run_dpo(
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
training_args.remove_unused_columns = False # important for pairwise dataset
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
ref_model=ref_model,
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
model=model,
args=training_args,
tokenizer=tokenizer,