add ds config to webui

Former-commit-id: 66d72b263d36dc81de9f6152077663b613035977
This commit is contained in:
hiyouga
2024-05-29 01:13:17 +08:00
parent 351b4efc6c
commit cb1a49aa02
5 changed files with 123 additions and 4 deletions

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.misc import get_device_count
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
@@ -258,6 +259,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox()
config_path = gr.Textbox()
with gr.Row():
device_count = gr.Textbox(value=str(get_device_count()), interactive=False)
ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
ds_offload = gr.Checkbox()
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
@@ -268,6 +274,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload})
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
@@ -277,14 +284,15 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn=stop_btn,
output_dir=output_dir,
config_path=config_path,
device_count=device_count,
ds_stage=ds_stage,
ds_offload=ds_offload,
resume_btn=resume_btn,
progress_bar=progress_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)
input_elems.update({output_dir, config_path})
output_elems = [output_box, progress_bar, loss_viewer]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)