support pagination in webui preview

Former-commit-id: f2307e26b9c2ce5d60917cce5a9638466ea676c8
This commit is contained in:
hiyouga
2023-11-02 21:21:45 +08:00
parent ac74639b32
commit 7d13501b94
10 changed files with 201 additions and 154 deletions

View File

@@ -5,7 +5,7 @@ from transformers.trainer_utils import SchedulerType
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
from llmtuner.webui.utils import gen_plot
if TYPE_CHECKING:
from gradio.components import Component
@@ -22,28 +22,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
preview_elems = create_preview_box(dataset_dir, dataset)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
))
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
data_preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
elem_dict.update(dict(
preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
))
with gr.Row():
@@ -143,16 +129,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems.add(output_dir)
output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
))
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems)
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
))
output_box.change(
gen_plot,
[