release v0.1.0

Former-commit-id: 63c8d3a17cb18f0d8a8e37bfa147daf5bdd28ea9
This commit is contained in:
hiyouga
2023-07-18 00:18:25 +08:00
parent c08ff734a7
commit eac7f97337
30 changed files with 1513 additions and 309 deletions

View File

@@ -0,0 +1 @@
from llmtuner.webui.interface import create_ui

View File

@@ -0,0 +1,79 @@
import os
from typing import List, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.tuner import get_infer_args
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
    """Lazily-loaded chat model for the web UI.

    Unlike ``ChatModel``, ``__init__`` does not load any weights; the UI must
    call :meth:`load_model` before chatting.
    """

    def __init__(self):
        # Deliberately skip ChatModel.__init__: weights are loaded on demand.
        self.model = None
        self.tokenizer = None
        self.generating_args = GeneratingArguments()

    def load_model(
        self, lang: str, model_name: str, checkpoints: list,
        finetuning_type: str, template: str, quantization_bit: str
    ):
        """Load the selected model; yields localized status strings for the info box."""
        if self.model is not None:
            yield ALERTS["err_exists"][lang]
            return

        if not model_name:
            yield ALERTS["err_no_model"][lang]
            return

        model_name_or_path = get_model_path(model_name)
        if not model_name_or_path:
            yield ALERTS["err_no_path"][lang]
            return

        if checkpoints:
            # The tuner accepts multiple checkpoints as a comma-separated list.
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
        else:
            checkpoint_dir = None

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_name_or_path,
            finetuning_type=finetuning_type,
            prompt_template=template,
            checkpoint_dir=checkpoint_dir,
            # Empty string / None disables quantization.
            quantization_bit=int(quantization_bit) if quantization_bit else None
        )
        # ChatModel.__init__ performs the actual weight loading.
        super().__init__(*get_infer_args(args))

        yield ALERTS["info_loaded"][lang]

    def unload_model(self, lang: str):
        """Drop model/tokenizer references and free GPU memory; yields status strings."""
        yield ALERTS["info_unloading"][lang]
        self.model = None
        self.tokenizer = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    def predict(
        self,
        chatbot: List[Tuple[str, str]],
        query: str,
        history: List[Tuple[str, str]],
        max_new_tokens: int,
        top_p: float,
        temperature: float
    ):
        """Stream a response; yields (chatbot, history) after each generated chunk."""
        chatbot.append([query, ""])
        response = ""
        for new_text in self.stream_chat(
            query, history, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            new_history = history + [(query, response)]
            chatbot[-1] = [query, response]  # update the streaming bubble in place
            yield chatbot, new_history

View File

@@ -0,0 +1,75 @@
import json
import os
from typing import Any, Dict, Optional
import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import SUPPORTED_MODELS
# Directory holding the persisted user settings file.
DEFAULT_CACHE_DIR = "cache"
# Default directory containing datasets and the dataset registry.
DEFAULT_DATA_DIR = "data"
# Root directory under which fine-tuned checkpoints are saved.
DEFAULT_SAVE_DIR = "saves"
# File name of the persisted UI configuration (inside DEFAULT_CACHE_DIR).
USER_CONFIG = "user.config"
# File name of the dataset registry inside a data directory.
DATA_CONFIG = "dataset_info.json"
def get_save_dir(model_name: str) -> str:
    """Return the checkpoint save directory for *model_name* (basename only)."""
    return os.path.join(DEFAULT_SAVE_DIR, os.path.basename(model_name))
def get_config_path() -> os.PathLike:
    """Location of the persisted user config file (cache/user.config)."""
    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Any]:
    """Read the persisted user config, falling back to defaults.

    Returns a dict with keys ``"last_model"`` (str) and ``"path_dict"``
    (mapping of model name to path).
    """
    try:
        with open(get_config_path(), "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        # Narrowed from a bare ``except`` (which also swallowed e.g.
        # KeyboardInterrupt); only a missing/unreadable or malformed file
        # means "no config yet".
        return {"last_model": "", "path_dict": {}}
def save_config(model_name: str, model_path: str) -> None:
    """Persist the last-used model name and its resolved path."""
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    config = load_config()
    config["last_model"] = model_name
    config["path_dict"][model_name] = model_path
    with open(get_config_path(), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
def get_model_path(model_name: str) -> str:
    """Resolve a model name to a local path or hub identifier ("" if unknown).

    A user-configured path takes precedence over the built-in model table.
    """
    path_dict = load_config()["path_dict"]
    if model_name in path_dict:
        return path_dict[model_name]
    return SUPPORTED_MODELS.get(model_name, "")
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
    """Scan the save dir and return a ``gr.update`` listing checkpoint folders.

    A subfolder counts as a checkpoint when it contains full weights, a
    sharded-weight index, or PEFT adapter weights.
    """
    checkpoints = []
    save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
    if save_dir and os.path.isdir(save_dir):
        for checkpoint in os.listdir(save_dir):
            if (
                os.path.isdir(os.path.join(save_dir, checkpoint))
                and any([
                    os.path.isfile(os.path.join(save_dir, checkpoint, name))
                    for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
                ])
            ):
                checkpoints.append(checkpoint)
    # Clear the current selection while refreshing the choices.
    return gr.update(value=[], choices=checkpoints)
def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
    """Load ``dataset_info.json`` from *dataset_dir*; return {} if missing/corrupt."""
    try:
        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        # Narrowed from a bare ``except`` so unrelated errors are not hidden.
        return {}
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
    """Return a ``gr.update`` listing datasets registered in *dataset_dir*.

    Falls back to DEFAULT_DATA_DIR when no directory is given; the current
    selection is cleared.
    """
    dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
    return gr.update(value=[], choices=list(dataset_info.keys()))

View File

@@ -0,0 +1,4 @@
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab

View File

@@ -0,0 +1,54 @@
from typing import Dict, Tuple

import gradio as gr
from gradio.blocks import Block
from gradio.components import Component

from llmtuner.webui.chat import WebChatModel


def create_chat_box(
    chat_model: WebChatModel
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
    """Build the (initially hidden) chat panel.

    Returns the enclosing box, the chatbot component, the history state, and a
    dict of the remaining components for localization.
    """
    with gr.Box(visible=False) as chat_box:
        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    query = gr.Textbox(show_label=False, lines=8)
                with gr.Column(min_width=32, scale=1):
                    submit_btn = gr.Button(variant="primary")

            with gr.Column(scale=1):
                clear_btn = gr.Button()
                # Sampling controls default to the model's generating arguments.
                max_new_tokens = gr.Slider(
                    10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
                )
                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
                temperature = gr.Slider(
                    0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
                )

    # Conversation history lives in per-session state, not in a component.
    history = gr.State([])

    submit_btn.click(
        chat_model.predict,
        [chatbot, query, history, max_new_tokens, top_p, temperature],
        [chatbot, history],
        show_progress=True
    ).then(
        # Clear the input box once the prediction stream has finished.
        lambda: gr.update(value=""), outputs=[query]
    )

    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

    return chat_box, chatbot, history, dict(
        query=query,
        submit_btn=submit_btn,
        clear_btn=clear_btn,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        temperature=temperature
    )

View File

@@ -0,0 +1,19 @@
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from typing import Tuple
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
    """Build the hidden modal used to preview dataset samples."""
    with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
        with gr.Row():
            preview_count = gr.Number(interactive=False)

        with gr.Row():
            preview_samples = gr.JSON(interactive=False)

        close_btn = gr.Button()

    # Hide the modal again when the close button is pressed.
    close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])

    return preview_box, preview_count, preview_samples, close_btn

View File

@@ -0,0 +1,60 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
    """Build the "Evaluate" tab and wire it to the runner."""
    with gr.Row():
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
        preview_btn = gr.Button(interactive=False, scale=1)

    preview_box, preview_count, preview_samples, close_btn = create_preview_box()

    # Keep the dataset list / preview button in sync with the chosen directory.
    dataset_dir.change(list_dataset, [dataset_dir], [dataset])
    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
    preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

    with gr.Row():
        max_samples = gr.Textbox(value="100000", interactive=True)
        batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
        quantization_bit = gr.Dropdown([8, 4])
        predict = gr.Checkbox(value=True)

    with gr.Row():
        start_btn = gr.Button()
        stop_btn = gr.Button()

    output_box = gr.Markdown()

    start_btn.click(
        runner.run_eval,
        [
            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
            top_elems["finetuning_type"], top_elems["template"],
            dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
        ],
        [output_box]
    )
    stop_btn.click(runner.set_abort, queue=False)

    return dict(
        dataset_dir=dataset_dir,
        dataset=dataset,
        preview_btn=preview_btn,
        preview_count=preview_count,
        preview_samples=preview_samples,
        close_btn=close_btn,
        max_samples=max_samples,
        batch_size=batch_size,
        quantization_bit=quantization_bit,
        predict=predict,
        start_btn=start_btn,
        stop_btn=stop_btn,
        output_box=output_box
    )

View File

@@ -0,0 +1,47 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
    """Build the "Inference" (chat) tab backed by a lazily-loaded WebChatModel."""
    with gr.Row():
        load_btn = gr.Button()
        unload_btn = gr.Button()
        quantization_bit = gr.Dropdown([8, 4])

    info_box = gr.Markdown()

    chat_model = WebChatModel()
    chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)

    # Loading streams status messages, then reveals the chat box on success.
    load_btn.click(
        chat_model.load_model,
        [
            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
            top_elems["finetuning_type"], top_elems["template"],
            quantization_bit
        ],
        [info_box]
    ).then(
        lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
    )

    # Unloading clears the conversation and hides the chat box.
    unload_btn.click(
        chat_model.unload_model, [top_elems["lang"]], [info_box]
    ).then(
        lambda: ([], []), outputs=[chatbot, history]
    ).then(
        lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
    )

    return dict(
        quantization_bit=quantization_bit,
        info_box=info_box,
        load_btn=load_btn,
        unload_btn=unload_btn,
        **chat_elems
    )

View File

@@ -0,0 +1,94 @@
from typing import Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
    """Build the "SFT" training tab and wire it to the runner."""
    with gr.Row():
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
        preview_btn = gr.Button(interactive=False, scale=1)

    preview_box, preview_count, preview_samples, close_btn = create_preview_box()

    # Keep the dataset list / preview button in sync with the chosen directory.
    dataset_dir.change(list_dataset, [dataset_dir], [dataset])
    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
    preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

    with gr.Row():
        learning_rate = gr.Textbox(value="5e-5", interactive=True)
        num_train_epochs = gr.Textbox(value="3.0", interactive=True)
        max_samples = gr.Textbox(value="100000", interactive=True)
        quantization_bit = gr.Dropdown([8, 4])

    with gr.Row():
        batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
        # Offer every scheduler supported by transformers.
        lr_scheduler_type = gr.Dropdown(
            value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
        )
        fp16 = gr.Checkbox(value=True)

    with gr.Row():
        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
        save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)

    with gr.Row():
        start_btn = gr.Button()
        stop_btn = gr.Button()

    with gr.Row():
        with gr.Column(scale=4):
            output_dir = gr.Textbox(interactive=True)
            output_box = gr.Markdown()

        with gr.Column(scale=1):
            loss_viewer = gr.Plot()

    start_btn.click(
        runner.run_train,
        [
            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
            top_elems["finetuning_type"], top_elems["template"],
            dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
            fp16, quantization_bit, batch_size, gradient_accumulation_steps,
            lr_scheduler_type, logging_steps, save_steps, output_dir
        ],
        [output_box]
    )
    stop_btn.click(runner.set_abort, queue=False)

    # Redraw the loss curve whenever new log text arrives in the output box.
    output_box.change(
        gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
    )

    return dict(
        dataset_dir=dataset_dir,
        dataset=dataset,
        preview_btn=preview_btn,
        preview_count=preview_count,
        preview_samples=preview_samples,
        close_btn=close_btn,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        max_samples=max_samples,
        quantization_bit=quantization_bit,
        batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        lr_scheduler_type=lr_scheduler_type,
        fp16=fp16,
        logging_steps=logging_steps,
        save_steps=save_steps,
        start_btn=start_btn,
        stop_btn=stop_btn,
        output_dir=output_dir,
        output_box=output_box,
        loss_viewer=loss_viewer
    )

View File

@@ -0,0 +1,42 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
def create_top() -> Dict[str, Component]:
    """Build the top bar: language, model selection, and finetuning settings."""
    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

    with gr.Row():
        lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
        model_name = gr.Dropdown(choices=available_models, scale=3)
        model_path = gr.Textbox(scale=3)

    with gr.Row():
        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
        template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
        checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
        refresh_btn = gr.Button(scale=1)

    # Selecting a model refreshes its checkpoints and fills in its path.
    model_name.change(
        list_checkpoint, [model_name, finetuning_type], [checkpoints]
    ).then(
        get_model_path, [model_name], [model_path]
    ) # do not save config since the below line will save

    model_path.change(save_config, [model_name, model_path])

    finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
    refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])

    return dict(
        lang=lang,
        model_name=model_name,
        model_path=model_path,
        finetuning_type=finetuning_type,
        template=template,
        checkpoints=checkpoints,
        refresh_btn=refresh_btn
    )

18
src/llmtuner/webui/css.py Normal file
View File

@@ -0,0 +1,18 @@
# Custom CSS injected into the gradio app: renders elements carrying the
# "modal-box" class as a centered, scrollable modal overlay, with a
# white-border variant for dark mode.
CSS = r"""
.modal-box {
position: fixed !important;
top: 50%;
left: 50%;
transform: translate(-50%, -50%); /* center horizontally */
max-width: 1000px;
max-height: 750px;
overflow-y: scroll !important;
background-color: var(--input-background-fill);
border: 2px solid black !important;
z-index: 1000;
}
.dark .modal-box {
border: 2px solid white !important;
}
"""

View File

@@ -0,0 +1,54 @@
import gradio as gr
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
create_top,
create_sft_tab,
create_eval_tab,
create_infer_tab
)
from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
# The web UI relies on gradio features introduced in 3.36.0.
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def create_ui() -> gr.Blocks:
    """Assemble the full web UI: top bar plus SFT / Evaluate / Inference tabs."""
    runner = Runner()

    with gr.Blocks(title="Web Tuner", css=CSS) as demo:
        top_elems = create_top()

        with gr.Tab("SFT"):
            sft_elems = create_sft_tab(top_elems, runner)

        with gr.Tab("Evaluate"):
            eval_elems = create_eval_tab(top_elems, runner)

        with gr.Tab("Inference"):
            infer_elems = create_infer_tab(top_elems)

        elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
        manager = Manager(elem_list)

        # Localize every component on page load and whenever the language changes.
        demo.load(
            manager.gen_label,
            [top_elems["lang"]],
            [elem for elems in elem_list for elem in elems.values()],
        )
        top_elems["lang"].change(
            manager.gen_label,
            [top_elems["lang"]],
            [elem for elems in elem_list for elem in elems.values()],
        )

    return demo
if __name__ == "__main__":
    demo = create_ui()
    demo.queue()  # queuing is required so generator callbacks can stream updates
    demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)

View File

@@ -0,0 +1,384 @@
# Localization table: component name -> language code -> the gradio kwargs
# (label / info / value / placeholder) applied by Manager.gen_label.
LOCALES = {
    "lang": {
        "en": {"label": "Lang"},
        "zh": {"label": "语言"}
    },
    "model_name": {
        "en": {"label": "Model name"},
        "zh": {"label": "模型名称"}
    },
    "model_path": {
        "en": {"label": "Model path", "info": "Path to pretrained model or model identifier from Hugging Face."},
        "zh": {"label": "模型路径", "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"}
    },
    "checkpoints": {
        "en": {"label": "Checkpoints"},
        "zh": {"label": "模型断点"}
    },
    "template": {
        "en": {"label": "Prompt template"},
        "zh": {"label": "提示模板"}
    },
    "refresh_btn": {
        "en": {"value": "Refresh checkpoints"},
        "zh": {"value": "刷新断点"}
    },
    "dataset_dir": {
        "en": {"label": "Data dir", "info": "Path of the data directory."},
        "zh": {"label": "数据路径", "info": "数据文件夹的路径。"}
    },
    "dataset": {
        "en": {"label": "Dataset"},
        "zh": {"label": "数据集"}
    },
    "preview_btn": {
        "en": {"value": "Preview"},
        "zh": {"value": "预览"}
    },
    "preview_count": {
        "en": {"label": "Count"},
        "zh": {"label": "数量"}
    },
    "preview_samples": {
        "en": {"label": "Samples"},
        "zh": {"label": "样例"}
    },
    "close_btn": {
        "en": {"value": "Close"},
        "zh": {"value": "关闭"}
    },
    "max_samples": {
        "en": {"label": "Max samples", "info": "Maximum samples per dataset."},
        "zh": {"label": "最大样本数", "info": "每个数据集最多使用的样本数。"}
    },
    "batch_size": {
        "en": {"label": "Batch size", "info": "Number of samples to process per GPU."},
        "zh": {"label": "批处理大小", "info": "每块 GPU 上处理的样本数量。"}
    },
    "quantization_bit": {
        "en": {"label": "Quantization bit", "info": "Enable 4/8-bit model quantization."},
        "zh": {"label": "量化", "info": "启用 4/8 比特模型量化。"}
    },
    "start_btn": {
        "en": {"value": "Start"},
        "zh": {"value": "开始"}
    },
    "stop_btn": {
        "en": {"value": "Abort"},
        "zh": {"value": "中断"}
    },
    "output_box": {
        "en": {"value": "Ready."},
        "zh": {"value": "准备就绪。"}
    },
    "finetuning_type": {
        "en": {"label": "Finetuning method"},
        "zh": {"label": "微调方法"}
    },
    "learning_rate": {
        "en": {"label": "Learning rate", "info": "Initial learning rate for AdamW."},
        "zh": {"label": "学习率", "info": "AdamW 优化器的初始学习率。"}
    },
    "num_train_epochs": {
        "en": {"label": "Epochs", "info": "Total number of training epochs to perform."},
        "zh": {"label": "训练轮数", "info": "需要执行的训练总轮数。"}
    },
    "gradient_accumulation_steps": {
        "en": {"label": "Gradient accumulation", "info": "Number of gradient accumulation steps."},
        "zh": {"label": "梯度累积", "info": "梯度累积的步数。"}
    },
    "lr_scheduler_type": {
        "en": {"label": "LR Scheduler", "info": "Name of learning rate scheduler."},
        "zh": {"label": "学习率调节器", "info": "采用的学习率调节器名称。"}
    },
    "fp16": {
        "en": {"label": "fp16", "info": "Whether to use fp16 mixed precision training."},
        "zh": {"label": "fp16", "info": "是否启用 FP16 混合精度训练。"}
    },
    "logging_steps": {
        "en": {"label": "Logging steps", "info": "Number of update steps between two logs."},
        "zh": {"label": "日志间隔", "info": "每两次日志输出间的更新步数。"}
    },
    "save_steps": {
        "en": {"label": "Save steps", "info": "Number of updates steps between two checkpoints."},
        "zh": {"label": "保存间隔", "info": "每两次断点保存间的更新步数。"}
    },
    "output_dir": {
        "en": {"label": "Checkpoint name", "info": "Directory to save checkpoint."},
        "zh": {"label": "断点名称", "info": "保存模型断点的文件夹名称。"}
    },
    "loss_viewer": {
        "en": {"label": "Loss"},
        "zh": {"label": "损失"}
    },
    "predict": {
        "en": {"label": "Save predictions"},
        "zh": {"label": "保存预测结果"}
    },
    "info_box": {
        "en": {"value": "Model unloaded, please load a model first."},
        "zh": {"value": "模型未加载,请先加载模型。"}
    },
    "load_btn": {
        "en": {"value": "Load model"},
        "zh": {"value": "加载模型"}
    },
    "unload_btn": {
        "en": {"value": "Unload model"},
        "zh": {"value": "卸载模型"}
    },
    "query": {
        "en": {"placeholder": "Input..."},
        "zh": {"placeholder": "输入..."}
    },
    "submit_btn": {
        "en": {"value": "Submit"},
        "zh": {"value": "提交"}
    },
    "clear_btn": {
        "en": {"value": "Clear history"},
        "zh": {"value": "清空历史"}
    },
    "max_new_tokens": {
        "en": {"label": "Maximum new tokens"},
        "zh": {"label": "最大生成长度"}
    },
    "top_p": {
        "en": {"label": "Top-p"},
        "zh": {"label": "Top-p 采样值"}
    },
    "temperature": {
        "en": {"label": "Temperature"},
        "zh": {"label": "温度系数"}
    }
}
# Localized one-shot status / error messages yielded into the UI output boxes.
ALERTS = {
    "err_conflict": {
        # Fixed the ungrammatical English message
        # ("A process is in running, please abort it firstly.").
        "en": "A process is already running, please abort it first.",
        "zh": "任务已存在,请先中断训练。"
    },
    "err_exists": {
        "en": "You have loaded a model, please unload it first.",
        "zh": "模型已存在,请先卸载模型。"
    },
    "err_no_model": {
        "en": "Please select a model.",
        "zh": "请选择模型。"
    },
    "err_no_path": {
        "en": "Model not found.",
        "zh": "模型未找到。"
    },
    "err_no_dataset": {
        "en": "Please choose a dataset.",
        "zh": "请选择数据集。"
    },
    "info_aborting": {
        "en": "Aborted, wait for terminating...",
        "zh": "训练中断,正在等待线程结束……"
    },
    "info_aborted": {
        "en": "Ready.",
        "zh": "准备就绪。"
    },
    "info_finished": {
        "en": "Finished.",
        "zh": "训练完毕。"
    },
    "info_loading": {
        "en": "Loading model...",
        "zh": "加载中……"
    },
    "info_unloading": {
        "en": "Unloading model...",
        "zh": "卸载中……"
    },
    "info_loaded": {
        "en": "Model loaded, now you can chat with your model!",
        "zh": "模型已加载,可以开始聊天了!"
    },
    "info_unloaded": {
        "en": "Model unloaded.",
        "zh": "模型已卸载。"
    }
}

View File

@@ -0,0 +1,35 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.utils import get_time
class Manager:
    """Applies localization and dynamic refresh values to all UI components."""

    def __init__(self, elem_list: List[Dict[str, Component]]):
        # One name->component dict per UI section (top bar, sft, eval, infer).
        self.elem_list = elem_list

    def gen_refresh(self) -> Dict[str, Any]:
        """Collect per-load dynamic values keyed by component name."""
        refresh_dict = {
            "dataset": {"choices": list_dataset()["choices"]},
            "output_dir": {"value": get_time()}  # fresh default checkpoint name
        }

        user_config = load_config()
        if user_config["last_model"]:
            # Restore the previously selected model and its resolved path.
            refresh_dict["model_name"] = {"value": user_config["last_model"]}
            refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}

        return refresh_dict

    def gen_label(self, lang: str) -> Dict[Component, dict]:
        """Build the component -> gr.update mapping for language *lang*.

        Merges the static LOCALES entry with any dynamic refresh value.
        """
        update_dict = {}
        refresh_dict = self.gen_refresh()

        for elems in self.elem_list:
            for name, component in elems.items():
                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))

        return update_dict

View File

@@ -0,0 +1,177 @@
import logging
import os
import threading
import time
import transformers
from typing import Optional, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
class Runner:
    """Launches SFT training / evaluation in a worker thread and streams log
    output back to the web UI via generator yields."""

    def __init__(self):
        self.aborted = False   # set by the UI's stop button
        self.running = False   # guards against two concurrent jobs

    def set_abort(self):
        # Called from the UI; LogCallback observes ``aborted`` to stop the job.
        self.aborted = True
        self.running = False

    def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
        """Validate inputs and set up log capturing.

        Returns (model_name_or_path, error_message, logger_handler,
        trainer_callback); on error only error_message is populated.
        """
        if self.running:
            return None, ALERTS["err_conflict"][lang], None, None

        if not model_name:
            return None, ALERTS["err_no_model"][lang], None, None

        model_name_or_path = get_model_path(model_name)
        if not model_name_or_path:
            return None, ALERTS["err_no_path"][lang], None, None

        if len(dataset) == 0:
            return None, ALERTS["err_no_dataset"][lang], None, None

        self.aborted = False
        self.running = True

        # Capture both root-logger and transformers records for the UI.
        logger_handler = LoggerHandler()
        logger_handler.setLevel(logging.INFO)
        logging.root.addHandler(logger_handler)
        transformers.logging.add_handler(logger_handler)
        trainer_callback = LogCallback(self)

        return model_name_or_path, "", logger_handler, trainer_callback

    def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
        """Mark the job finished, free GPU memory, and return the final status."""
        self.running = False
        torch_gc()
        if self.aborted:
            return ALERTS["info_aborted"][lang]
        else:
            return finish_info if finish_info is not None else ALERTS["info_finished"][lang]

    def run_train(
        self, lang, model_name, checkpoints, finetuning_type, template,
        dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
        fp16, quantization_bit, batch_size, gradient_accumulation_steps,
        lr_scheduler_type, logging_steps, save_steps, output_dir
    ):
        """Run SFT training in a background thread; yields progress text."""
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
            return

        if checkpoints:
            # Resume from one or more checkpoints (comma-separated for the tuner).
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
        else:
            checkpoint_dir = None

        args = dict(
            model_name_or_path=model_name_or_path,
            do_train=True,
            finetuning_type=finetuning_type,
            prompt_template=template,
            dataset=",".join(dataset),
            dataset_dir=dataset_dir,
            max_samples=int(max_samples),
            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
            checkpoint_dir=checkpoint_dir,
            overwrite_cache=True,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            lr_scheduler_type=lr_scheduler_type,
            logging_steps=logging_steps,
            save_steps=save_steps,
            learning_rate=float(learning_rate),
            num_train_epochs=float(num_train_epochs),
            fp16=fp16,
            quantization_bit=int(quantization_bit) if quantization_bit else None
        )
        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)

        run_args = dict(
            model_args=model_args,
            data_args=data_args,
            training_args=training_args,
            finetuning_args=finetuning_args,
            callbacks=[trainer_callback]
        )
        # Train in a worker thread so this generator can keep streaming
        # log/progress updates roughly once per second.
        thread = threading.Thread(target=run_sft, kwargs=run_args)
        thread.start()

        while thread.is_alive():
            time.sleep(1)
            if self.aborted:
                yield ALERTS["info_aborting"][lang]
            else:
                yield format_info(logger_handler.log, trainer_callback.tracker)

        yield self.finalize(lang)

    def run_eval(
        self, lang, model_name, checkpoints, finetuning_type, template,
        dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
    ):
        """Run evaluation (or prediction) in a background thread; yields progress."""
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
            return

        if checkpoints:
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
        else:
            checkpoint_dir = None
            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

        args = dict(
            model_name_or_path=model_name_or_path,
            do_eval=True,
            finetuning_type=finetuning_type,
            prompt_template=template,
            dataset=",".join(dataset),
            dataset_dir=dataset_dir,
            max_samples=int(max_samples),
            output_dir=output_dir,
            checkpoint_dir=checkpoint_dir,
            overwrite_cache=True,
            predict_with_generate=True,
            per_device_eval_batch_size=batch_size,
            quantization_bit=int(quantization_bit) if quantization_bit else None
        )

        if predict:
            # "Save predictions" switches the job from eval mode to predict mode.
            args.pop("do_eval", None)
            args["do_predict"] = True

        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)

        run_args = dict(
            model_args=model_args,
            data_args=data_args,
            training_args=training_args,
            finetuning_args=finetuning_args,
            callbacks=[trainer_callback]
        )
        thread = threading.Thread(target=run_sft, kwargs=run_args)
        thread.start()

        while thread.is_alive():
            time.sleep(1)
            if self.aborted:
                yield ALERTS["info_aborting"][lang]
            else:
                yield format_info(logger_handler.log, trainer_callback.tracker)

        # Render the metrics file written by the trainer as the final message.
        yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))

View File

@@ -0,0 +1,74 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Tuple
from datetime import datetime
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
def format_info(log: str, tracker: dict) -> str:
    """Append a markdown progress line to *log* when the tracker has step info."""
    if "current_steps" not in tracker:
        return log
    progress = "Running **{:d}/{:d}**: {} < {}\n".format(
        tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
    )
    return log + progress
def get_time() -> str:
    """Timestamp usable as a default output-dir name, e.g. ``2023-07-18-00-18-25``."""
    return f"{datetime.now():%Y-%m-%d-%H-%M-%S}"
def can_preview(dataset_dir: str, dataset: list) -> dict:
    """Return a ``gr.update`` enabling the preview button only for local files.

    Preview requires the first selected dataset to have a "file_name" entry in
    the dataset registry pointing at an existing file under *dataset_dir*.
    """
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    if (
        len(dataset) > 0
        and "file_name" in dataset_info[dataset[0]]
        and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
    ):
        return gr.update(interactive=True)
    else:
        return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
    """Load the first selected dataset file.

    Returns (total sample count, first two samples, gr.update revealing the
    preview modal).
    """
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    data_file = dataset_info[dataset[0]]["file_name"]
    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
        data = json.load(f)
    return len(data), data[:2], gr.update(visible=True)
def get_eval_results(path: os.PathLike) -> str:
    """Read a JSON results file and render it as a fenced markdown code block."""
    with open(path, "r", encoding="utf-8") as f:
        content = json.dumps(json.load(f), indent=4)
    return f"```json\n{content}\n```\n"
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
    """Plot the training loss curve from ``trainer_log.jsonl``.

    Returns None when no log file exists yet for the given run.
    """
    log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
    if not os.path.isfile(log_file):
        return None

    plt.close("all")  # drop any figure left over from a previous call
    fig = plt.figure()
    ax = fig.add_subplot(111)
    steps, losses = [], []
    with open(log_file, "r", encoding="utf-8") as f:
        for line in f:
            log_info = json.loads(line)
            # Only training entries carry a (non-zero) loss value.
            if log_info.get("loss", None):
                steps.append(log_info["current_steps"])
                losses.append(log_info["loss"])

    ax.plot(steps, losses, alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig