implement webui resuming training

Former-commit-id: 2d41672ef52414c56c50c8b4fdc442797ba682e9
This commit is contained in:
hiyouga
2023-10-15 04:52:19 +08:00
parent ef248dbe15
commit 31e3805fb8
7 changed files with 42 additions and 26 deletions

View File

@@ -24,9 +24,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
@@ -128,6 +128,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox()
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
@@ -139,15 +140,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems.add(output_dir)
output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
output_dir=output_dir, output_box=output_box, loss_viewer=loss_viewer
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, output_box=output_box, loss_viewer=loss_viewer, process_bar=process_bar
))
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems)
process_bar.change(
output_box.change(
gen_plot,
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
loss_viewer,