8 Commits

Author SHA1 Message Date
hiyouga
35e76879f5 support dev set in web ui
Former-commit-id: fe1370561a9b027d9ebdef52733344f1e3683081
2023-07-18 20:40:49 +08:00
hiyouga
8e4ae0aaac add web demo
Former-commit-id: 25ea647e5ac36b497b8e176b123fdee39be3fd30
2023-07-18 17:21:16 +08:00
hiyouga
5ed2a97056 update baichuan template
Former-commit-id: 03520588c39986c98a0515a64993af8c2468b9d0
2023-07-18 16:43:51 +08:00
hiyouga
03eba6f041 fix template
Former-commit-id: 729053c9cea6254165ae9c8fd7809479b12f735c
2023-07-18 16:37:23 +08:00
hiyouga
ec166e736a fix #176
Former-commit-id: 2ae3445b0d28b4ed22ddbb2cfe09089ae0c23fe1
2023-07-18 16:36:24 +08:00
hiyouga
c85a6b83b3 fix webUI, fix #171 #177
Former-commit-id: 3459bb2d35162dbbef79cda05da08a56921aa276
2023-07-18 15:51:48 +08:00
hiyouga
a864a7b395 update webUI, fix #179
Former-commit-id: f9074fed5e22585679661588befcf266a79009f2
2023-07-18 15:35:17 +08:00
hiyouga
fd8c2d4aac tiny fix
Former-commit-id: bcdf5bb55651d639e9f57fd915268137156af9cd
2023-07-18 00:52:31 +08:00
17 changed files with 414 additions and 166 deletions

View File

@@ -291,6 +291,14 @@ python src/cli_demo.py \
--checkpoint_dir path_to_checkpoint
```
### Web Demo
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```
### Export model
```bash

View File

@@ -10,6 +10,8 @@ from llmtuner.tuner import get_infer_args
from llmtuner.extras.misc import torch_gc
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
ModelCard,
ModelList,
ChatMessage,
@@ -49,12 +51,12 @@ def create_app():
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != "user":
if request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
else:
prefix = None
@@ -62,7 +64,7 @@ def create_app():
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
@@ -81,19 +83,19 @@ def create_app():
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
message=ChatMessage(role=Role.ASSISTANT, content=response),
finish_reason=Finish.STOP
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
delta=DeltaMessage(role=Role.ASSISTANT),
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
for new_text in chat_model.stream_chat(
@@ -107,15 +109,15 @@ def create_app():
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield json.dumps(chunk, ensure_ascii=False)
yield "[DONE]"

View File

@@ -1,6 +1,18 @@
import time
from enum import Enum
from pydantic import BaseModel, Field
from typing import List, Literal, Optional
from typing import List, Optional
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
role: Role
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
role: Optional[Role] = None
content: Optional[str] = None
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
finish_reason: Optional[Finish] = None
class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Literal["chat.completion"]
object: Optional[str] = "chat.completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
class ChatCompletionStreamResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Literal["chat.completion.chunk"]
object: Optional[str] = "chat.completion.chunk"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]

View File

@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
r"""
Event called after logging the last logs.
"""
if not state.is_world_process_zero:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time

View File

@@ -202,7 +202,20 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
prompt=" <reserved_102> {query} <reserved_103> ",
sep="</s>",
use_history=True
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True
)

View File

@@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
replace_model(unwrapped_model, target="reward")
with torch.no_grad():
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default")
# Run PPO step

View File

@@ -23,7 +23,7 @@ class ComputeMetrics:
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -47,5 +47,6 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()}

View File

@@ -11,14 +11,22 @@ from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
def __init__(self):
def __init__(self, *args):
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if len(args) != 0:
super().__init__(*args)
def load_model(
self, lang: str, model_name: str, checkpoints: list,
finetuning_type: str, template: str, quantization_bit: str
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
@@ -43,10 +51,11 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type,
prompt_template=template,
checkpoint_dir=checkpoint_dir,
quantization_bit=int(quantization_bit) if quantization_bit else None
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix
)
super().__init__(*get_infer_args(args))

View File

@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
import gradio as gr
from gradio.blocks import Block
@@ -8,9 +8,10 @@ from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel
chat_model: WebChatModel,
visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
with gr.Box(visible=False) as chat_box:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
with gr.Row():

View File

@@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
@@ -21,9 +21,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
max_samples = gr.Textbox(value="100000", interactive=True)
batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
quantization_bit = gr.Dropdown([8, 4])
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
predict = gr.Checkbox(value=True)
with gr.Row():
@@ -35,9 +36,20 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
start_btn.click(
runner.run_eval,
[
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
top_elems["finetuning_type"], top_elems["template"],
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
dataset_dir,
dataset,
max_source_length,
max_target_length,
max_samples,
batch_size,
predict
],
[output_box]
)
@@ -50,9 +62,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
max_samples=max_samples,
batch_size=batch_size,
quantization_bit=quantization_bit,
predict=predict,
start_btn=start_btn,
stop_btn=stop_btn,

View File

@@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
quantization_bit = gr.Dropdown([8, 4])
info_box = gr.Markdown()
@@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
load_btn.click(
chat_model.load_model,
[
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
top_elems["finetuning_type"], top_elems["template"],
quantization_bit
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"]
],
[info_box]
).then(
@@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
)
return dict(
quantization_bit=quantization_bit,
info_box=info_box,
load_btn=load_btn,
unload_btn=unload_btn,

View File

@@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
@@ -23,22 +23,24 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
learning_rate = gr.Textbox(value="5e-5", interactive=True)
num_train_epochs = gr.Textbox(value="3.0", interactive=True)
max_samples = gr.Textbox(value="100000", interactive=True)
quantization_bit = gr.Dropdown([8, 4])
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_samples = gr.Textbox(value="100000")
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
)
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
fp16 = gr.Checkbox(value=True)
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
with gr.Row():
start_btn = gr.Button()
@@ -55,11 +57,28 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
start_btn.click(
runner.run_train,
[
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
top_elems["finetuning_type"], top_elems["template"],
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
lr_scheduler_type, logging_steps, save_steps, output_dir
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
dataset_dir,
dataset,
max_source_length,
max_target_length,
learning_rate,
num_train_epochs,
max_samples,
batch_size,
gradient_accumulation_steps,
lr_scheduler_type,
dev_ratio,
fp16,
logging_steps,
save_steps,
output_dir
],
[output_box]
)
@@ -76,13 +95,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
quantization_bit=quantization_bit,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
dev_ratio=dev_ratio,
fp16=fp16,
logging_steps=logging_steps,
save_steps=save_steps,

View File

@@ -6,29 +6,40 @@ from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
from llmtuner.webui.utils import can_quantize
def create_top() -> Dict[str, Component]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Row():
quantization_bit = gr.Dropdown([8, 4], scale=1)
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
source_prefix = gr.Textbox(scale=4)
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
) # do not save config since the below line will save
model_path.change(save_config, [model_name, model_path])
finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
can_quantize, [finetuning_type], [quantization_bit]
)
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
return dict(
@@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]:
finetuning_type=finetuning_type,
template=template,
checkpoints=checkpoints,
refresh_btn=refresh_btn
refresh_btn=refresh_btn,
quantization_bit=quantization_bit,
source_prefix=source_prefix
)

View File

@@ -25,6 +25,14 @@ LOCALES = {
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
}
},
"finetuning_type": {
"en": {
"label": "Finetuning method"
},
"zh": {
"label": "微调方法"
}
},
"checkpoints": {
"en": {
"label": "Checkpoints"
@@ -33,14 +41,6 @@ LOCALES = {
"label": "模型断点"
}
},
"template": {
"en": {
"label": "Prompt template"
},
"zh": {
"label": "提示模板"
}
},
"refresh_btn": {
"en": {
"value": "Refresh checkpoints"
@@ -49,6 +49,36 @@ LOCALES = {
"value": "刷新断点"
}
},
"quantization_bit": {
"en": {
"label": "Quantization bit (optional)",
"info": "Enable 4/8-bit model quantization."
},
"zh": {
"label": "量化等级(非必填)",
"info": "启用 4/8 比特模型量化。"
}
},
"template": {
"en": {
"label": "Prompt template",
"info": "The template used in constructing prompts."
},
"zh": {
"label": "提示模板",
"info": "构建提示词时使用的模板"
}
},
"source_prefix": {
"en": {
"label": "Source prefix (optional)",
"info": "A sequence used as the prefix of each samples."
},
"zh": {
"label": "前缀序列(非必填)",
"info": "作为每个输入样本前缀的序列"
}
},
"dataset_dir": {
"en": {
"label": "Data dir",
@@ -99,66 +129,24 @@ LOCALES = {
"value": "关闭"
}
},
"max_samples": {
"max_source_length": {
"en": {
"label": "Max samples",
"info": "Maximum samples per dataset."
"label": "Max source length",
"info": "Max tokens in source sequence."
},
"zh": {
"label": "最大样本数",
"info": "每个数据集最多使用的样本数"
"label": "输入序列最大长度",
"info": "输入序列分词后的最大长度"
}
},
"batch_size": {
"max_target_length": {
"en": {
"label": "Batch size",
"info": "Number of samples to process per GPU."
},
"zh":{
"label": "批处理大小",
"info": "每块 GPU 上处理的样本数量。"
}
},
"quantization_bit": {
"en": {
"label": "Quantization bit",
"info": "Enable 4/8-bit model quantization."
"label": "Max target length",
"info": "Max tokens in target sequence."
},
"zh": {
"label": "量化",
"info": "启用 4/8 比特模型量化"
}
},
"start_btn": {
"en": {
"value": "Start"
},
"zh": {
"value": "开始"
}
},
"stop_btn": {
"en": {
"value": "Abort"
},
"zh": {
"value": "中断"
}
},
"output_box": {
"en": {
"value": "Ready."
},
"zh": {
"value": "准备就绪。"
}
},
"finetuning_type": {
"en": {
"label": "Finetuning method"
},
"zh": {
"label": "微调方法"
"label": "输出序列最大长度",
"info": "输出序列分词后的最大长度"
}
},
"learning_rate": {
@@ -181,6 +169,26 @@ LOCALES = {
"info": "需要执行的训练总轮数。"
}
},
"max_samples": {
"en": {
"label": "Max samples",
"info": "Maximum samples per dataset."
},
"zh": {
"label": "最大样本数",
"info": "每个数据集最多使用的样本数。"
}
},
"batch_size": {
"en": {
"label": "Batch size",
"info": "Number of samples to process per GPU."
},
"zh":{
"label": "批处理大小",
"info": "每块 GPU 上处理的样本数量。"
}
},
"gradient_accumulation_steps": {
"en": {
"label": "Gradient accumulation",
@@ -201,6 +209,16 @@ LOCALES = {
"info": "采用的学习率调节器名称。"
}
},
"dev_ratio": {
"en": {
"label": "Dev ratio",
"info": "Proportion of data in the dev set."
},
"zh": {
"label": "验证集比例",
"info": "验证集占全部样本的百分比。"
}
},
"fp16": {
"en": {
"label": "fp16",
@@ -231,6 +249,22 @@ LOCALES = {
"info": "每两次断点保存间的更新步数。"
}
},
"start_btn": {
"en": {
"value": "Start"
},
"zh": {
"value": "开始"
}
},
"stop_btn": {
"en": {
"value": "Abort"
},
"zh": {
"value": "中断"
}
},
"output_dir": {
"en": {
"label": "Checkpoint name",
@@ -241,6 +275,14 @@ LOCALES = {
"info": "保存模型断点的文件夹名称。"
}
},
"output_box": {
"en": {
"value": "Ready."
},
"zh": {
"value": "准备就绪。"
}
},
"loss_viewer": {
"en": {
"label": "Loss"
@@ -257,14 +299,6 @@ LOCALES = {
"label": "保存预测结果"
}
},
"info_box": {
"en": {
"value": "Model unloaded, please load a model first."
},
"zh": {
"value": "模型未加载,请先加载模型。"
}
},
"load_btn": {
"en": {
"value": "Load model"
@@ -281,6 +315,14 @@ LOCALES = {
"value": "卸载模型"
}
},
"info_box": {
"en": {
"value": "Model unloaded, please load a model first."
},
"zh": {
"value": "模型未加载,请先加载模型。"
}
},
"query": {
"en": {
"placeholder": "Input..."
@@ -305,6 +347,14 @@ LOCALES = {
"value": "清空历史"
}
},
"max_length": {
"en": {
"label": "Maximum length"
},
"zh": {
"label": "最大长度"
}
},
"max_new_tokens": {
"en": {
"label": "Maximum new tokens"

View File

@@ -3,7 +3,7 @@ import os
import threading
import time
import transformers
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
@@ -59,10 +59,29 @@ class Runner:
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
def run_train(
self, lang, model_name, checkpoints, finetuning_type, template,
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
lr_scheduler_type, logging_steps, save_steps, output_dir
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
learning_rate: str,
num_train_epochs: str,
max_samples: str,
batch_size: int,
gradient_accumulation_steps: int,
lr_scheduler_type: str,
dev_ratio: float,
fp16: bool,
logging_steps: int,
save_steps: int,
output_dir: str
):
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
@@ -79,25 +98,35 @@ class Runner:
args = dict(
model_name_or_path=model_name_or_path,
do_train=True,
finetuning_type=finetuning_type,
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
prompt_template=template,
dataset=",".join(dataset),
dataset_dir=dataset_dir,
max_samples=int(max_samples),
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
checkpoint_dir=checkpoint_dir,
overwrite_cache=True,
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
learning_rate=float(learning_rate),
num_train_epochs=float(num_train_epochs),
max_samples=int(max_samples),
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
fp16=fp16,
logging_steps=logging_steps,
save_steps=save_steps,
learning_rate=float(learning_rate),
num_train_epochs=float(num_train_epochs),
fp16=fp16,
quantization_bit=int(quantization_bit) if quantization_bit else None
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
)
if dev_ratio > 1e-6:
args["dev_ratio"] = dev_ratio
args["evaluation_strategy"] = "steps"
args["eval_steps"] = save_steps
args["load_best_model_at_end"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
@@ -120,8 +149,21 @@ class Runner:
yield self.finalize(lang)
def run_eval(
self, lang, model_name, checkpoints, finetuning_type, template,
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
max_samples: str,
batch_size: int,
predict: bool
):
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
@@ -140,17 +182,20 @@ class Runner:
args = dict(
model_name_or_path=model_name_or_path,
do_eval=True,
finetuning_type=finetuning_type,
prompt_template=template,
dataset=",".join(dataset),
dataset_dir=dataset_dir,
max_samples=int(max_samples),
output_dir=output_dir,
checkpoint_dir=checkpoint_dir,
overwrite_cache=True,
predict_with_generate=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
max_samples=int(max_samples),
per_device_eval_batch_size=batch_size,
quantization_bit=int(quantization_bit) if quantization_bit else None
output_dir=output_dir
)
if predict:

View File

@@ -3,7 +3,7 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Tuple
from typing import Any, Dict, Tuple
from datetime import datetime
from llmtuner.extras.ploting import smooth
@@ -23,7 +23,7 @@ def get_time() -> str:
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
def can_preview(dataset_dir: str, dataset: list) -> dict:
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
if (
@@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file = dataset_info[dataset[0]]["file_name"]
@@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
return len(data), data[:2], gr.update(visible=True)
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
if finetuning_type != "lora":
return gr.update(value="", interactive=False)
else:
return gr.update(interactive=True)
def get_eval_results(path: os.PathLike) -> str:
with open(path, "r", encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
@@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0:
return None
ax.plot(steps, losses, alpha=0.4, label="original")
ax.plot(steps, smooth(losses), label="smoothed")
ax.legend()

44
src/web_demo.py Normal file
View File

@@ -0,0 +1,44 @@
# coding=utf-8
# Implements user interface in browser for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
import gradio as gr
from transformers.utils.versions import require_version
from llmtuner import get_infer_args
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def main():
chat_model = WebChatModel(*get_infer_args())
with gr.Blocks(title="Web Demo") as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
manager = Manager([{"lang": lang}, chat_elems])
demo.load(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
lang.change(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
if __name__ == "__main__":
main()