add kto to webui
Former-commit-id: 6c866f4dbd45e868860be8351d1a65c4e1a4e02b
This commit is contained in:
@@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
||||
with gr.Column():
|
||||
ppo_score_norm = gr.Checkbox()
|
||||
ppo_whiten_rewards = gr.Checkbox()
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
|
||||
input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
|
||||
elem_dict.update(
|
||||
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
|
||||
dict(
|
||||
rlhf_tab=rlhf_tab,
|
||||
pref_beta=pref_beta,
|
||||
pref_ftx=pref_ftx,
|
||||
pref_loss=pref_loss,
|
||||
reward_model=reward_model,
|
||||
ppo_score_norm=ppo_score_norm,
|
||||
ppo_whiten_rewards=ppo_whiten_rewards,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as galore_tab:
|
||||
|
||||
Reference in New Issue
Block a user