update web UI, support rm predict #210

Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
hiyouga
2023-07-21 13:27:27 +08:00
parent c4e9694c6e
commit 0f7cdac207
13 changed files with 192 additions and 27 deletions

View File

@@ -1,4 +1,5 @@
-from llmtuner.webui.components.eval import create_eval_tab
-from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.top import create_top
 from llmtuner.webui.components.sft import create_sft_tab
+from llmtuner.webui.components.eval import create_eval_tab
+from llmtuner.webui.components.infer import create_infer_tab
+from llmtuner.webui.components.export import create_export_tab

View File

@@ -22,13 +22,9 @@ def create_chat_box(
             with gr.Column(scale=1):
                 clear_btn = gr.Button()
-                max_new_tokens = gr.Slider(
-                    10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
-                )
-                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
-                temperature = gr.Slider(
-                    0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
-                )
+                max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
+                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
+                temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)

     history = gr.State([])

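Side note on the slider change above: Gradio components default to interactive=None and become editable once they are wired up as event inputs, so dropping the explicit interactive=True (here and in the SFT tab below) should be behavior-preserving. A minimal sketch of that default, with illustrative names not taken from this commit:

    import gradio as gr

    def echo(temperature: float) -> str:
        # the slider's current value arrives as a plain float
        return "temperature = {}".format(temperature)

    with gr.Blocks() as demo:
        # no interactive=True needed: being used as a click input makes it editable
        temperature = gr.Slider(0.01, 1.5, value=0.7, step=0.01)
        out = gr.Textbox()
        gr.Button("Run").click(echo, [temperature], [out])

    demo.launch()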
View File

@@ -0,0 +1,34 @@
+from typing import Dict
+
+import gradio as gr
+from gradio.components import Component
+
+from llmtuner.webui.utils import export_model
+
+
+def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
+    with gr.Row():
+        save_dir = gr.Textbox()
+        max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
+
+    export_btn = gr.Button()
+    info_box = gr.Textbox(show_label=False, interactive=False)
+
+    export_btn.click(
+        export_model,
+        [
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            max_shard_size,
+            save_dir
+        ],
+        [info_box]
+    )
+
+    return dict(
+        save_dir=save_dir,
+        max_shard_size=max_shard_size,
+        export_btn=export_btn,
+        info_box=info_box
+    )

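The new tab follows the existing component pattern: it takes the shared top-level components and returns its own, keyed by name so they can be localized through LOCALES (see locales.py below). A sketch of mounting it standalone, assuming create_top() takes no arguments (that call is not shown in this diff):

    import gradio as gr

    from llmtuner.webui.components.top import create_top
    from llmtuner.webui.components.export import create_export_tab

    with gr.Blocks() as demo:
        top_elems = create_top()  # shared language/model/checkpoint selectors
        with gr.Tab("Export"):
            export_elems = create_export_tab(top_elems)

    # queue() because export_model streams status strings from a generator
    demo.queue().launch()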
View File

@@ -57,7 +57,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
         with gr.Row():
             with gr.Column(scale=4):
-                output_dir = gr.Textbox(interactive=True)
+                output_dir = gr.Textbox()

                 with gr.Box():
                     output_box = gr.Markdown()

View File

@@ -5,7 +5,8 @@ from llmtuner.webui.components import (
     create_top,
     create_sft_tab,
     create_eval_tab,
-    create_infer_tab
+    create_infer_tab,
+    create_export_tab
 )
 from llmtuner.webui.css import CSS
 from llmtuner.webui.manager import Manager
@@ -30,7 +31,10 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Chat"):
             infer_elems = create_infer_tab(top_elems)

-        elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
+        with gr.Tab("Export"):
+            export_elems = create_export_tab(top_elems)
+
+        elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
         manager = Manager(elem_list)
         demo.load(

View File

@@ -452,6 +452,34 @@ LOCALES = {
         "zh": {
             "label": "温度系数"
         }
     },
+    "save_dir": {
+        "en": {
+            "label": "Export dir",
+            "info": "Directory to save exported model."
+        },
+        "zh": {
+            "label": "导出目录",
+            "info": "保存导出模型的文件夹路径。"
+        }
+    },
+    "max_shard_size": {
+        "en": {
+            "label": "Max shard size (GB)",
+            "info": "The maximum size for a model file."
+        },
+        "zh": {
+            "label": "最大分块大小（GB）",
+            "info": "模型文件的最大大小。"
+        }
+    },
+    "export_btn": {
+        "en": {
+            "value": "Export"
+        },
+        "zh": {
+            "value": "开始导出"
+        }
+    }
 }
@@ -477,6 +505,14 @@ ALERTS = {
         "en": "Please choose a dataset.",
         "zh": "请选择数据集。"
     },
+    "err_no_checkpoint": {
+        "en": "Please select a checkpoint.",
+        "zh": "请选择断点。"
+    },
+    "err_no_save_dir": {
+        "en": "Please provide export dir.",
+        "zh": "请填写导出目录。"
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "zh": "训练中断，正在等待线程结束……"
@@ -504,5 +540,13 @@ ALERTS = {
     "info_unloaded": {
         "en": "Model unloaded.",
         "zh": "模型已卸载。"
-    }
+    },
+    "info_exporting": {
+        "en": "Exporting model...",
+        "zh": "正在导出模型……"
+    },
+    "info_exported": {
+        "en": "Model exported.",
+        "zh": "模型导出完成。"
+    }
 }

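For reference, both tables are two-level dicts: LOCALES maps component name, then language code, to attribute overrides, while ALERTS maps message key, then language code, to a string. The new entries are looked up like so (illustrative):

    from llmtuner.webui.locales import ALERTS, LOCALES

    lang = "en"
    print(LOCALES["max_shard_size"][lang]["label"])  # "Max shard size (GB)"
    print(ALERTS["err_no_save_dir"][lang])           # "Please provide export dir."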
View File

@@ -3,7 +3,7 @@ import os
 import threading
 import time
 import transformers
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
@@ -25,7 +25,9 @@ class Runner:
         self.aborted = True
         self.running = False

-    def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
+    def initialize(
+        self, lang: str, model_name: str, dataset: list
+    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
         if self.running:
             return None, ALERTS["err_conflict"][lang], None, None
@@ -50,7 +52,9 @@ class Runner:
         return model_name_or_path, "", logger_handler, trainer_callback

-    def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
+    def finalize(
+        self, lang: str, finish_info: Optional[str] = None
+    ) -> str:
         self.running = False
         torch_gc()
         if self.aborted:
@@ -87,7 +91,7 @@ class Runner:
         lora_dropout: float,
         lora_target: str,
         output_dir: str
-    ):
+    ) -> Generator[str, None, None]:
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
             yield error
@@ -174,7 +178,7 @@ class Runner:
         max_samples: str,
         batch_size: int,
         predict: bool
-    ):
+    ) -> Generator[str, None, None]:
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
             yield error

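The Generator[str, None, None] annotations document what the yield statements already imply: the two patched Runner methods are generator callbacks, and Gradio streams each yielded string into the bound output box. A self-contained analogue of the pattern (demo only, not from this commit):

    import time
    from typing import Generator

    import gradio as gr

    def run_job() -> Generator[str, None, None]:
        # each yield replaces the textbox content, giving live progress
        for step in range(3):
            yield "step {} running...".format(step)
            time.sleep(1)
        yield "done"

    with gr.Blocks() as demo:
        start_btn = gr.Button("Start")
        log_box = gr.Textbox()
        start_btn.click(run_job, [], [log_box])

    # generator callbacks require the queue
    demo.queue().launch()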
View File

@@ -3,11 +3,13 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.webui.common import get_save_dir, DATA_CONFIG
+from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS


 def format_info(log: str, tracker: dict) -> str:
@@ -83,3 +85,41 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
     ax.set_xlabel("step")
     ax.set_ylabel("loss")
     return fig
+
+
+def export_model(
+    lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
+) -> Generator[str, None, None]:
+    if not model_name:
+        yield ALERTS["err_no_model"][lang]
+        return
+
+    model_name_or_path = get_model_path(model_name)
+    if not model_name_or_path:
+        yield ALERTS["err_no_path"][lang]
+        return
+
+    if not checkpoints:
+        yield ALERTS["err_no_checkpoint"][lang]
+        return
+
+    checkpoint_dir = ",".join(
+        [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+    )
+
+    if not save_dir:
+        yield ALERTS["err_no_save_dir"][lang]
+        return
+
+    args = dict(
+        model_name_or_path=model_name_or_path,
+        checkpoint_dir=checkpoint_dir,
+        finetuning_type=finetuning_type
+    )
+
+    yield ALERTS["info_exporting"][lang]
+    model_args, _, finetuning_args, _ = get_infer_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
+    tokenizer.save_pretrained(save_dir)
+    yield ALERTS["info_exported"][lang]
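Since export_model is itself a generator, a caller outside the UI would iterate it to drive the export and observe the status messages; the model name, checkpoint, and paths below are placeholders:

    from llmtuner.webui.utils import export_model

    for status in export_model(
        lang="en",
        model_name="LLaMA-7B",            # placeholder: any key known to get_model_path
        checkpoints=["checkpoint-1000"],  # placeholder checkpoint folder name
        finetuning_type="lora",
        max_shard_size=10,                # GB per shard file
        save_dir="exported_model",        # placeholder output directory
    ):
        print(status)  # "Exporting model..." then "Model exported."

Because the model is written with save_pretrained, the resulting directory should load back with transformers' standard from_pretrained call.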