@@ -31,13 +31,14 @@ def run_dpo(
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = DPOPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
ref_model=ref_model,
|
||||
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user