add ds config to webui
Former-commit-id: 66d72b263d36dc81de9f6152077663b613035977
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
from ...extras.constants import TRAINING_STAGES
|
||||
from ...extras.misc import get_device_count
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
|
||||
from ..components.data import create_preview_box
|
||||
@@ -258,6 +259,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
output_dir = gr.Textbox()
|
||||
config_path = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
device_count = gr.Textbox(value=str(get_device_count()), interactive=False)
|
||||
ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
|
||||
ds_offload = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
@@ -268,6 +274,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
@@ -277,14 +284,15 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
device_count=device_count,
|
||||
ds_stage=ds_stage,
|
||||
ds_offload=ds_offload,
|
||||
resume_btn=resume_btn,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer,
|
||||
)
|
||||
)
|
||||
|
||||
input_elems.update({output_dir, config_path})
|
||||
output_elems = [output_box, progress_bar, loss_viewer]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
|
||||
Reference in New Issue
Block a user