some ideas are borrowed from @marko1616 Former-commit-id: b5a062aa2d4a37670007e8b3dae5b6f5b7ffb15c
This commit is contained in:
@@ -27,8 +27,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
@@ -127,19 +125,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
use_rslora = gr.Checkbox(scale=1)
|
||||
use_dora = gr.Checkbox(scale=1)
|
||||
create_new_adapter = gr.Checkbox(scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
additional_target = gr.Textbox(scale=2)
|
||||
|
||||
input_elems.update(
|
||||
{lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
|
||||
{
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
lora_dropout,
|
||||
loraplus_lr_ratio,
|
||||
create_new_adapter,
|
||||
use_rslora,
|
||||
use_dora,
|
||||
lora_target,
|
||||
additional_target,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
dict(
|
||||
@@ -147,10 +156,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
lora_rank=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
create_new_adapter=create_new_adapter,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
create_new_adapter=create_new_adapter,
|
||||
lora_target=lora_target,
|
||||
additional_target=additional_target,
|
||||
)
|
||||
)
|
||||
@@ -161,13 +171,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
|
||||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
|
||||
|
||||
@@ -177,7 +180,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
|
||||
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
|
||||
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
|
||||
galore_target = gr.Textbox(value="mlp,attn", scale=3)
|
||||
galore_target = gr.Textbox(value="all", scale=3)
|
||||
|
||||
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
|
||||
elem_dict.update(
|
||||
@@ -193,13 +196,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
arg_save_btn = gr.Button()
|
||||
arg_load_btn = gr.Button()
|
||||
start_btn = gr.Button(variant="primary")
|
||||
stop_btn = gr.Button(variant="stop")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox()
|
||||
config_path = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
@@ -211,20 +217,38 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
input_elems.add(output_dir)
|
||||
input_elems.update({output_dir, config_path})
|
||||
output_elems = [output_box, process_bar]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_load_btn.click(
|
||||
engine.runner.load_args,
|
||||
[engine.manager.get_elem_by_id("top.lang"), config_path],
|
||||
list(input_elems),
|
||||
concurrency_limit=None,
|
||||
)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
arg_save_btn=arg_save_btn,
|
||||
arg_load_btn=arg_load_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
resume_btn=resume_btn,
|
||||
process_bar=process_bar,
|
||||
output_box=output_box,
|
||||
|
||||
Reference in New Issue
Block a user