support SimPO #3900
Former-commit-id: 6b954ce60155cf8334150b795cfc4bb63ca74c8b
This commit is contained in:
@@ -36,10 +36,13 @@ def run_dpo(
|
||||
)
|
||||
|
||||
# Create reference model
|
||||
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
||||
ref_model = model
|
||||
if finetuning_args.use_ref_model:
|
||||
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
||||
ref_model = model
|
||||
else:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
else:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
ref_model = None
|
||||
|
||||
# Update arguments
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
@@ -69,7 +72,7 @@ def run_dpo(
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||
if id(model) == id(ref_model): # unable to compute rewards if reference model is the model itself
|
||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||
for key in remove_keys:
|
||||
metrics.pop(key)
|
||||
|
||||
Reference in New Issue
Block a user