better llamaboard

* easily resume from checkpoint
* support full and freeze checkpoints
* faster ui


Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
This commit is contained in:
hiyouga
2024-05-29 23:55:38 +08:00
parent f90c4ca672
commit 87aa332583
14 changed files with 303 additions and 193 deletions

View File

@@ -5,8 +5,9 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.misc import get_device_count
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
from ..utils import change_stage, check_output_dir, list_output_dirs
from .data import create_preview_box
if is_gradio_available():
@@ -256,11 +257,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
output_dir = gr.Textbox()
initial_dir = gr.Textbox(visible=False, interactive=False)
output_dir = gr.Dropdown(allow_custom_value=True)
config_path = gr.Textbox()
with gr.Row():
device_count = gr.Textbox(value=str(get_device_count()), interactive=False)
device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
ds_offload = gr.Checkbox()
@@ -282,6 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
arg_load_btn=arg_load_btn,
start_btn=start_btn,
stop_btn=stop_btn,
initial_dir=initial_dir,
output_dir=output_dir,
config_path=config_path,
device_count=device_count,
@@ -295,24 +298,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
output_elems = [output_box, progress_bar, loss_viewer]
lang = engine.manager.get_elem_by_id("top.lang")
model_name = engine.manager.get_elem_by_id("top.model_name")
finetuning_type = engine.manager.get_elem_by_id("top.finetuning_type")
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems) + [output_box],
concurrency_limit=None,
engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
output_dir.change(
list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
return elem_dict