Former-commit-id: 627d1c91e675f1d9ebf47bad123cbbf29821da4d
This commit is contained in:
@@ -27,7 +27,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
@@ -160,10 +159,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
|
||||
|
||||
refresh_btn.click(
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||
[reward_model],
|
||||
@@ -171,9 +169,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||
elem_dict.update(
|
||||
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
|
||||
)
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
|
||||
|
||||
with gr.Accordion(label="GaLore config", open=False) as galore_tab:
|
||||
with gr.Row():
|
||||
|
||||
Reference in New Issue
Block a user