better llamaboard
* easily resume from checkpoint
* support full and freeze checkpoints
* faster UI

Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, list_dataset
|
||||
from ..common import DEFAULT_DATA_DIR, list_datasets
|
||||
from .data import create_preview_box
|
||||
|
||||
|
||||
@@ -74,6 +74,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
stop_btn.click(engine.runner.set_abort)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
|
||||
dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)
|
||||
|
||||
return elem_dict
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Union
|
||||
|
||||
from ...extras.constants import PEFT_METHODS
|
||||
from ...extras.misc import torch_gc
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ...train.tuner import export_model
|
||||
@@ -24,8 +25,8 @@ def save_model(
|
||||
lang: str,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
adapter_path: List[str],
|
||||
finetuning_type: str,
|
||||
checkpoint_path: Union[str, List[str]],
|
||||
template: str,
|
||||
visual_inputs: bool,
|
||||
export_size: int,
|
||||
@@ -45,9 +46,9 @@ def save_model(
|
||||
error = ALERTS["err_no_export_dir"][lang]
|
||||
elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
|
||||
error = ALERTS["err_no_dataset"][lang]
|
||||
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
|
||||
elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
|
||||
error = ALERTS["err_no_adapter"][lang]
|
||||
elif export_quantization_bit in GPTQ_BITS and adapter_path:
|
||||
elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list):
|
||||
error = ALERTS["err_gptq_lora"][lang]
|
||||
|
||||
if error:
|
||||
@@ -55,16 +56,8 @@ def save_model(
|
||||
yield error
|
||||
return
|
||||
|
||||
if adapter_path:
|
||||
adapter_name_or_path = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
|
||||
)
|
||||
else:
|
||||
adapter_name_or_path = None
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_path,
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
visual_inputs=visual_inputs,
|
||||
@@ -77,6 +70,14 @@ def save_model(
|
||||
export_legacy_format=export_legacy_format,
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
if finetuning_type in PEFT_METHODS: # list
|
||||
args["adapter_name_or_path"] = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
|
||||
)
|
||||
else: # str
|
||||
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args)
|
||||
torch_gc()
|
||||
@@ -86,7 +87,7 @@ def save_model(
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
|
||||
export_legacy_format = gr.Checkbox()
|
||||
@@ -104,8 +105,8 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
engine.manager.get_elem_by_id("top.lang"),
|
||||
engine.manager.get_elem_by_id("top.model_name"),
|
||||
engine.manager.get_elem_by_id("top.model_path"),
|
||||
engine.manager.get_elem_by_id("top.adapter_path"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.checkpoint_path"),
|
||||
engine.manager.get_elem_by_id("top.template"),
|
||||
engine.manager.get_elem_by_id("top.visual_inputs"),
|
||||
export_size,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
|
||||
from ...data import templates
|
||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
|
||||
from ..common import get_model_info, list_checkpoints, save_config
|
||||
from ..utils import can_quantize
|
||||
|
||||
|
||||
@@ -25,8 +25,7 @@ def create_top() -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
|
||||
adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
|
||||
|
||||
with gr.Accordion(open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
@@ -36,27 +35,17 @@ def create_top() -> Dict[str, "Component"]:
|
||||
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
|
||||
visual_inputs = gr.Checkbox(scale=1)
|
||||
|
||||
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(get_template, [model_name], [template], queue=False).then(
|
||||
get_visual, [model_name], [visual_inputs], queue=False
|
||||
) # do not save config since the below line will save
|
||||
|
||||
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
||||
)
|
||||
|
||||
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
|
||||
finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
|
||||
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
|
||||
|
||||
return dict(
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
finetuning_type=finetuning_type,
|
||||
adapter_path=adapter_path,
|
||||
refresh_btn=refresh_btn,
|
||||
checkpoint_path=checkpoint_path,
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
|
||||
@@ -5,8 +5,9 @@ from transformers.trainer_utils import SchedulerType
|
||||
from ...extras.constants import TRAINING_STAGES
|
||||
from ...extras.misc import get_device_count
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
|
||||
from ..components.data import create_preview_box
|
||||
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
|
||||
from ..utils import change_stage, check_output_dir, list_output_dirs
|
||||
from .data import create_preview_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
@@ -256,11 +257,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox()
|
||||
initial_dir = gr.Textbox(visible=False, interactive=False)
|
||||
output_dir = gr.Dropdown(allow_custom_value=True)
|
||||
config_path = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
device_count = gr.Textbox(value=str(get_device_count()), interactive=False)
|
||||
device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
|
||||
ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
|
||||
ds_offload = gr.Checkbox()
|
||||
|
||||
@@ -282,6 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
arg_load_btn=arg_load_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
initial_dir=initial_dir,
|
||||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
device_count=device_count,
|
||||
@@ -295,24 +298,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
output_elems = [output_box, progress_bar, loss_viewer]
|
||||
|
||||
lang = engine.manager.get_elem_by_id("top.lang")
|
||||
model_name = engine.manager.get_elem_by_id("top.model_name")
|
||||
finetuning_type = engine.manager.get_elem_by_id("top.finetuning_type")
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_load_btn.click(
|
||||
engine.runner.load_args,
|
||||
[engine.manager.get_elem_by_id("top.lang"), config_path],
|
||||
list(input_elems) + [output_box],
|
||||
concurrency_limit=None,
|
||||
engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None
|
||||
)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
|
||||
dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
|
||||
output_dir.change(
|
||||
list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
|
||||
).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
|
||||
|
||||
return elem_dict
|
||||
|
||||
Reference in New Issue
Block a user