support SimPO #3900

Former-commit-id: 6b954ce60155cf8334150b795cfc4bb63ca74c8b
This commit is contained in:
hiyouga
2024-05-26 23:46:33 +08:00
parent 26f293d587
commit b0d9966663
19 changed files with 145 additions and 339 deletions

View File

@@ -36,10 +36,13 @@ def run_dpo(
     )
     # Create reference model
-    if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
-        ref_model = model
+    if finetuning_args.use_ref_model:
+        if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
+            ref_model = model
+        else:
+            ref_model = create_ref_model(model_args, finetuning_args)
     else:
-        ref_model = create_ref_model(model_args, finetuning_args)
+        ref_model = None
     # Update arguments
     training_args.remove_unused_columns = False  # important for pairwise dataset
@@ -69,7 +72,7 @@ def run_dpo(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
-        if id(model) == id(ref_model):  # unable to compute rewards without a reference model
+        if id(model) == id(ref_model):  # unable to compute rewards if reference model is the model itself
             remove_keys = [key for key in metrics.keys() if "rewards" in key]
             for key in remove_keys:
                 metrics.pop(key)