Former-commit-id: 25d7bbd0a5142f001bd2ff498df07b24137050a9
This commit is contained in:
hiyouga
2023-11-07 19:42:01 +08:00
parent f23e5b602a
commit 14a38b5069
5 changed files with 21 additions and 17 deletions

View File

@@ -33,6 +33,12 @@ def run_dpo(
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
+# Create reference model
+ref_model = None
+if not isinstance(model, PeftModel):
+ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
@@ -41,7 +47,7 @@ def run_dpo(
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
model=model,
-ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+ref_model=ref_model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,