support SimPO #3900
Former-commit-id: 6b954ce60155cf8334150b795cfc4bb63ca74c8b
This commit is contained in:
@@ -186,7 +186,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
|
||||
pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"], value="sigmoid")
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
||||
with gr.Column():
|
||||
ppo_score_norm = gr.Checkbox()
|
||||
|
||||
@@ -179,15 +179,10 @@ class Runner:
|
||||
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
|
||||
args["top_k"] = 0
|
||||
args["top_p"] = 0.9
|
||||
elif args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.pref_beta")
|
||||
args["dpo_ftx"] = get("train.pref_ftx")
|
||||
args["dpo_loss"] = get("train.pref_loss")
|
||||
elif args["stage"] == "kto":
|
||||
args["kto_beta"] = get("train.pref_beta")
|
||||
args["kto_ftx"] = get("train.pref_ftx")
|
||||
elif args["stage"] == "orpo":
|
||||
args["orpo_beta"] = get("train.pref_beta")
|
||||
elif args["stage"] in ["dpo", "kto"]:
|
||||
args["pref_beta"] = get("train.pref_beta")
|
||||
args["pref_ftx"] = get("train.pref_ftx")
|
||||
args["pref_loss"] = get("train.pref_loss")
|
||||
|
||||
# galore config
|
||||
if args["use_galore"]:
|
||||
|
||||
Reference in New Issue
Block a user