support longlora for main branch

Former-commit-id: f869501ad4c368df26534c41f62c6d63c6be17dd
This commit is contained in:
hiyouga
2024-01-20 19:25:22 +08:00
parent 8efc055511
commit 80637fc06d
7 changed files with 168 additions and 204 deletions

View File

@@ -95,7 +95,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=3, allow_custom_value=True)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
refresh_btn = gr.Button(scale=1)
refresh_btn.click(
@@ -105,8 +106,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
queue=False
)
input_elems.update({dpo_beta, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
))
with gr.Row():
cmd_preview_btn = gr.Button()