support save args in webui #2807 #3046

some ideas are borrowed from @marko1616


Former-commit-id: b5a062aa2d4a37670007e8b3dae5b6f5b7ffb15c

commit 6198121923
parent b0efebf853
Author: hiyouga
Date:   2024-03-30 23:09:12 +08:00

9 changed files with 219 additions and 80 deletions
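Note on the feature: the train tab gains an arg_save_btn, an arg_load_btn, and a config_path textbox, bound to engine.runner.save_args and engine.runner.load_args. Those two Runner methods live in one of the other changed files and do not appear in the hunks below. A minimal sketch of the round-trip they imply (the helper names, the JSON format, and the "config" directory are illustrative assumptions, not this commit's code):

    import json
    import os
    from typing import Any, Dict

    DEFAULT_CONFIG_DIR = "config"  # assumed location for saved argument files

    def save_arguments(config_path: str, args: Dict[str, Any]) -> str:
        """Serialize the gathered component values so a run can be restored later."""
        os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
        save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(args, f, indent=2, ensure_ascii=False)
        return save_path

    def load_arguments(config_path: str) -> Dict[str, Any]:
        """Read a saved argument file back; the caller maps it onto the components."""
        with open(os.path.join(DEFAULT_CONFIG_DIR, config_path), encoding="utf-8") as f:
            return json.load(f)

Judging by the wiring in the last hunk, save_args receives the current value of every component in input_elems, while load_args returns one value per element of list(input_elems) so Gradio can write them back into the components.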


@@ -27,8 +27,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
 
-    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
-
     input_elems.update({training_stage, dataset_dir, dataset})
     elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
@@ -127,19 +125,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Accordion(open=False) as lora_tab:
         with gr.Row():
-            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
-            lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
-            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            lora_target = gr.Textbox(scale=2)
+            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
+            lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
+            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
+            create_new_adapter = gr.Checkbox()
 
         with gr.Row():
             use_rslora = gr.Checkbox(scale=1)
             use_dora = gr.Checkbox(scale=1)
-            create_new_adapter = gr.Checkbox(scale=1)
+            lora_target = gr.Textbox(scale=2)
             additional_target = gr.Textbox(scale=2)
 
     input_elems.update(
-        {lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
+        {
+            lora_rank,
+            lora_alpha,
+            lora_dropout,
+            loraplus_lr_ratio,
+            create_new_adapter,
+            use_rslora,
+            use_dora,
+            lora_target,
+            additional_target,
+        }
     )
     elem_dict.update(
         dict(
@@ -147,10 +156,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
-            lora_target=lora_target,
+            loraplus_lr_ratio=loraplus_lr_ratio,
+            create_new_adapter=create_new_adapter,
             use_rslora=use_rslora,
             use_dora=use_dora,
-            create_new_adapter=create_new_adapter,
+            lora_target=lora_target,
             additional_target=additional_target,
         )
     )
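Besides regrouping the LoRA widgets, these two hunks add a loraplus_lr_ratio slider. LoRA+ trains the adapter's B matrices with a learning rate equal to this ratio times the one used for the A matrices; the optimizer-side handling sits outside this file. A rough sketch of the usual parameter-group split, assuming PEFT's lora_A/lora_B naming and plain PyTorch rather than this repo's implementation:

    import torch

    def make_loraplus_groups(model: torch.nn.Module, lr: float, ratio: float):
        """Split trainable LoRA weights into two groups, boosting lr for B (LoRA+)."""
        group_a, group_b = [], []
        for name, param in model.named_parameters():
            if param.requires_grad:
                (group_b if "lora_B" in name else group_a).append(param)
        return [
            {"params": group_a, "lr": lr},
            {"params": group_b, "lr": lr * ratio},  # LoRA+ raises the B-side lr
        ]

    # optimizer = torch.optim.AdamW(make_loraplus_groups(model, lr=5e-5, ratio=16.0))

The slider default of 0 is presumably treated as "LoRA+ disabled" rather than as a literal zero learning rate.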
@@ -161,13 +171,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
             reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
 
-    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
-        list_adapters,
-        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
-        [reward_model],
-        queue=False,
-    ).then(autoset_packing, [training_stage], [packing], queue=False)
-
     input_elems.update({dpo_beta, dpo_ftx, reward_model})
     elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
@@ -177,7 +180,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
             galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
             galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
-            galore_target = gr.Textbox(value="mlp,attn", scale=3)
+            galore_target = gr.Textbox(value="all", scale=3)
 
     input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
     elem_dict.update(
@@ -193,13 +196,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         cmd_preview_btn = gr.Button()
-        start_btn = gr.Button()
-        stop_btn = gr.Button()
+        arg_save_btn = gr.Button()
+        arg_load_btn = gr.Button()
+        start_btn = gr.Button(variant="primary")
+        stop_btn = gr.Button(variant="stop")
 
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Row():
                 output_dir = gr.Textbox()
+                config_path = gr.Textbox()
 
             with gr.Row():
                 resume_btn = gr.Checkbox(visible=False, interactive=False)
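The widgets created here are inert until the next hunk binds them. For reference, a stripped-down, self-contained example of the same Gradio click-wiring pattern (every callback and label is a placeholder, not code from the repo):

    import gradio as gr

    def fake_save(path: str) -> str:  # stands in for engine.runner.save_args
        return "saved to " + path

    def fake_load() -> str:  # stands in for engine.runner.load_args
        return "restored/path.json"

    with gr.Blocks() as demo:
        config_path = gr.Textbox(label="config path")
        arg_save_btn = gr.Button("Save arguments")
        arg_load_btn = gr.Button("Load arguments")
        status_box = gr.Markdown()

        # click(fn, inputs, outputs): fn receives the inputs' current values,
        # and its return value is written into the outputs.
        arg_save_btn.click(fake_save, [config_path], [status_box])
        arg_load_btn.click(fake_load, [], [config_path])

    demo.launch()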
@@ -211,20 +217,38 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()
 
-    input_elems.add(output_dir)
+    input_elems.update({output_dir, config_path})
     output_elems = [output_box, process_bar]
 
     cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
+    arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
+    arg_load_btn.click(
+        engine.runner.load_args,
+        [engine.manager.get_elem_by_id("top.lang"), config_path],
+        list(input_elems),
+        concurrency_limit=None,
+    )
     start_btn.click(engine.runner.run_train, input_elems, output_elems)
     stop_btn.click(engine.runner.set_abort, queue=False)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
 
+    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
+        list_adapters,
+        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
+        [reward_model],
+        queue=False,
+    ).then(autoset_packing, [training_stage], [packing], queue=False)
+
     elem_dict.update(
         dict(
             cmd_preview_btn=cmd_preview_btn,
+            arg_save_btn=arg_save_btn,
+            arg_load_btn=arg_load_btn,
             start_btn=start_btn,
             stop_btn=stop_btn,
             output_dir=output_dir,
+            config_path=config_path,
             resume_btn=resume_btn,
             process_bar=process_bar,
             output_box=output_box,