Release v0.1.6

Former-commit-id: 43c8b3c3c8bfb2e32d17fb3e8b194938e37d54bd
This commit is contained in:
hiyouga
2023-08-11 23:25:57 +08:00
parent 2144bb0e27
commit d5f1b99ac4
18 changed files with 127 additions and 41 deletions

View File

@@ -16,6 +16,6 @@ def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"
close_btn = gr.Button()
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return preview_box, preview_count, preview_samples, close_btn

View File

@@ -20,7 +20,12 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -33,6 +38,9 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
output_box = gr.Markdown()
@@ -54,7 +62,10 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
batch_size,
predict
],
[output_box]
[
output_box,
process_bar
]
)
stop_btn.click(runner.set_abort, queue=False)

View File

@@ -22,7 +22,12 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -46,12 +51,14 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
padding_side = gr.Radio(choices=["left", "right"], value="left")
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)
resume_lora_training = gr.Checkbox(value=True, scale=1)
with gr.Row():
start_btn = gr.Button()
@@ -59,7 +66,11 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
with gr.Row():
with gr.Column(scale=3):
output_dir = gr.Textbox()
with gr.Row():
output_dir = gr.Textbox()
with gr.Row():
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
output_box = gr.Markdown()
@@ -93,16 +104,21 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
save_steps,
warmup_steps,
compute_type,
padding_side,
lora_rank,
lora_dropout,
lora_target,
resume_lora_training,
output_dir
],
[output_box]
[
output_box,
process_bar
]
)
stop_btn.click(runner.set_abort, queue=False)
output_box.change(
process_bar.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
)
@@ -128,10 +144,12 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
save_steps=save_steps,
warmup_steps=warmup_steps,
compute_type=compute_type,
padding_side=padding_side,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,

View File

@@ -43,7 +43,7 @@ def create_top() -> Dict[str, "Component"]:
can_quantize, [finetuning_type], [quantization_bit]
)
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
return dict(
lang=lang,

View File

@@ -67,7 +67,7 @@ def create_web_demo() -> gr.Blocks:
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
return demo

View File

@@ -277,6 +277,16 @@ LOCALES = {
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
},
"padding_side": {
"en": {
"label": "Padding side",
"info": "The side on which the model should have padding applied."
},
"zh": {
"label": "填充位置",
"info": "使用左填充或右填充。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
@@ -315,6 +325,16 @@ LOCALES = {
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
}
},
"resume_lora_training": {
"en": {
"label": "Resume LoRA training",
"info": "Whether to resume training from the last LoRA weights or create new lora weights."
},
"zh": {
"label": "继续上次的训练",
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
}
},
"start_btn": {
"en": {
"value": "Start"

View File

@@ -1,3 +1,4 @@
import gradio as gr
import logging
import os
import threading
@@ -13,7 +14,7 @@ from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
from llmtuner.webui.utils import get_eval_results, update_process_bar
class Runner:
@@ -88,14 +89,16 @@ class Runner:
save_steps: int,
warmup_steps: int,
compute_type: str,
padding_side: str,
lora_rank: int,
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
output_dir: str
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
yield error, gr.update(visible=False)
return
if checkpoints:
@@ -133,9 +136,11 @@ class Runner:
warmup_steps=warmup_steps,
fp16=(compute_type == "fp16"),
bf16=(compute_type == "bf16"),
padding_side=padding_side,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
resume_lora_training=resume_lora_training,
output_dir=output_dir
)
@@ -150,18 +155,18 @@ class Runner:
thread.start()
while thread.is_alive():
time.sleep(1)
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang]
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield format_info(logger_handler.log, trainer_callback)
yield logger_handler.log, update_process_bar(trainer_callback)
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info)
yield self.finalize(lang, finish_info), gr.update(visible=False)
def run_eval(
self,
@@ -182,7 +187,7 @@ class Runner:
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
yield error, gr.update(visible=False)
return
if checkpoints:
@@ -223,15 +228,15 @@ class Runner:
thread.start()
while thread.is_alive():
time.sleep(1)
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang]
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield format_info(logger_handler.log, trainer_callback)
yield logger_handler.log, update_process_bar(trainer_callback)
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info)
yield self.finalize(lang, finish_info), gr.update(visible=False)

View File

@@ -15,13 +15,18 @@ if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
def format_info(log: str, callback: "LogCallback") -> str:
info = log
if callback.max_steps:
info += "Running **{:d}/{:d}**: {} < {}\n".format(
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
)
return info
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
if not callback.max_steps:
return gr.update(visible=False)
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps,
callback.max_steps,
callback.elapsed_time,
callback.remaining_time
)
return gr.update(label=label, value=percentage, visible=True)
def get_time() -> str: