8 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| hiyouga | 35e76879f5 | support dev set in web ui (Former-commit-id: fe1370561a9b027d9ebdef52733344f1e3683081) | 2023-07-18 20:40:49 +08:00 |
| hiyouga | 8e4ae0aaac | add web demo (Former-commit-id: 25ea647e5ac36b497b8e176b123fdee39be3fd30) | 2023-07-18 17:21:16 +08:00 |
| hiyouga | 5ed2a97056 | update baichuan template (Former-commit-id: 03520588c39986c98a0515a64993af8c2468b9d0) | 2023-07-18 16:43:51 +08:00 |
| hiyouga | 03eba6f041 | fix template (Former-commit-id: 729053c9cea6254165ae9c8fd7809479b12f735c) | 2023-07-18 16:37:23 +08:00 |
| hiyouga | ec166e736a | fix #176 (Former-commit-id: 2ae3445b0d28b4ed22ddbb2cfe09089ae0c23fe1) | 2023-07-18 16:36:24 +08:00 |
| hiyouga | c85a6b83b3 | fix webUI, fix #171 #177 (Former-commit-id: 3459bb2d35162dbbef79cda05da08a56921aa276) | 2023-07-18 15:51:48 +08:00 |
| hiyouga | a864a7b395 | update webUI, fix #179 (Former-commit-id: f9074fed5e22585679661588befcf266a79009f2) | 2023-07-18 15:35:17 +08:00 |
| hiyouga | fd8c2d4aac | tiny fix (Former-commit-id: bcdf5bb55651d639e9f57fd915268137156af9cd) | 2023-07-18 00:52:31 +08:00 |
17 changed files with 414 additions and 166 deletions

View File

@@ -291,6 +291,14 @@ python src/cli_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
+### Web Demo
+```bash
+python src/web_demo.py \
+    --model_name_or_path path_to_your_model \
+    --checkpoint_dir path_to_checkpoint
+```
+
 ### Export model
 ```bash

View File

@@ -10,6 +10,8 @@ from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.api.protocol import (
+    Role,
+    Finish,
     ModelCard,
     ModelList,
     ChatMessage,
@@ -49,12 +51,12 @@ def create_app():
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != "user":
+        if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content
         else:
             prefix = None
@@ -62,7 +64,7 @@ def create_app():
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
         if request.stream:
@@ -81,19 +83,19 @@ def create_app():
         choice_data = ChatCompletionResponseChoice(
             index=0,
-            message=ChatMessage(role="assistant", content=response),
-            finish_reason="stop"
+            message=ChatMessage(role=Role.ASSISTANT, content=response),
+            finish_reason=Finish.STOP
         )
-        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
+        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
     async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
-            delta=DeltaMessage(role="assistant"),
+            delta=DeltaMessage(role=Role.ASSISTANT),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
         for new_text in chat_model.stream_chat(
@@ -107,15 +109,15 @@ def create_app():
                 delta=DeltaMessage(content=new_text),
                 finish_reason=None
             )
-            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield json.dumps(chunk, ensure_ascii=False)
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(),
-            finish_reason="stop"
+            finish_reason=Finish.STOP
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
         yield "[DONE]"

View File

@@ -1,6 +1,18 @@
 import time
+from enum import Enum
 from pydantic import BaseModel, Field
-from typing import List, Literal, Optional
+from typing import List, Optional
 
 
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
+
+
 class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
+    role: Role
     content: str
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Role] = None
     content: Optional[str] = None
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Finish] = None
 
 
 class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
 class ChatCompletionResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion"]
+    object: Optional[str] = "chat.completion"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
 class ChatCompletionStreamResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
+    object: Optional[str] = "chat.completion.chunk"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
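
The new `Role` and `Finish` enums subclass both `str` and `Enum`, so they compare equal to the raw strings used by the OpenAI-style API and serialize cleanly through the pydantic models above. A minimal standalone sketch (not part of the diff) of that behavior:

```python
from enum import Enum

class Role(str, Enum):  # same pattern as in protocol.py
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"

assert Role.USER == "user"                   # str subclass: equal to the plain string
assert Role("assistant") is Role.ASSISTANT   # parsing incoming request payloads
print(Role.SYSTEM.value)                     # -> "system", what ends up in the JSON
```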

View File

@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
         r"""
         Event called after logging the last logs.
         """
+        if not state.is_world_process_zero:
+            return
+
         cur_time = time.time()
         cur_steps = state.log_history[-1].get("step")
         elapsed_time = cur_time - self.start_time
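
The added guard skips logging on every process except the main one, so multi-GPU runs write a single progress log. A small sketch of the same idea (assuming the usual Hugging Face `TrainerCallback` interface):

```python
from transformers import TrainerCallback

class RankZeroLogCallback(TrainerCallback):
    def on_log(self, args, state, control, **kwargs):
        # state.is_world_process_zero is True only on the main process
        if not state.is_world_process_zero:
            return
        print(state.log_history[-1])  # safe to write or log exactly once here
```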

View File

@@ -202,7 +202,20 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 register_template(
     name="baichuan",
     prefix="",
-    prompt="<reserved_102>{query}<reserved_103>",
-    sep="",
+    prompt=" <reserved_102> {query} <reserved_103> ",
+    sep="</s>",
+    use_history=True
+)
+
+
+r"""
+Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
+          https://huggingface.co/HuggingFaceH4/starchat-beta
+"""
+register_template(
+    name="starchat",
+    prefix="<|system|>\n",
+    prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
+    sep="<|end|>\n",
     use_history=True
 )
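
For reference, `prefix`, `prompt`, and `sep` are the pieces a registered template stitches into the model input. A standalone sketch (simplified logic, not the library's actual Template class) using the starchat fields registered above:

```python
from typing import List, Tuple

def build_prompt(prefix: str, prompt: str, sep: str,
                 history: List[Tuple[str, str]], query: str) -> str:
    # prefix once, then each (query, response) turn joined by sep, then the new query
    text = prefix
    for old_query, old_response in history:
        text += prompt.format(query=old_query) + old_response + sep
    return text + prompt.format(query=query)

print(build_prompt(
    prefix="<|system|>\n",
    prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
    sep="<|end|>\n",
    history=[("Hi", "Hello!")],
    query="Who are you?"
))
```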

View File

@@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         replace_model(unwrapped_model, target="reward")
         with torch.no_grad():
             _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
-        rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
+        rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
         replace_model(unwrapped_model, target="default")
 
         # Run PPO step
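
The indexing change matters because the reward model returns one value per token for every sample in the batch. A sketch (assuming `values` has shape `(batch_size, seq_len)`; not code from the repo):

```python
import torch

values = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # 2 samples, 3 tokens each
print(values[-1])     # tensor([3., 4., 5.])  per-token values of the last sample (the old, buggy pick)
print(values[:, -1])  # tensor([2., 5.])      last-token value of every sample (one reward per sample)
```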

View File

@@ -23,7 +23,7 @@ class ComputeMetrics:
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
-        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+        score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -47,5 +47,6 @@ class ComputeMetrics:
             bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
             score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+            score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
 
         return {k: float(np.mean(v)) for k, v in score_dict.items()}
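
The new "accuracy" entry is an exact-match style score: 1.0 when the prediction reproduces the reference from the start, else 0.0. A tiny illustration (assuming `pred` and `label` are token sequences, as elsewhere in this metric):

```python
label = ["The", "cat", "sat", "."]

pred = ["The", "cat", "sat", "."]
print(float(len(label) != 0 and pred[:len(label)] == label))  # 1.0, exact prefix match

pred = ["The", "dog", "sat", "."]
print(float(len(label) != 0 and pred[:len(label)] == label))  # 0.0
```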

View File

@@ -11,14 +11,22 @@ from llmtuner.webui.locales import ALERTS
 class WebChatModel(ChatModel):
 
-    def __init__(self):
+    def __init__(self, *args):
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
+        if len(args) != 0:
+            super().__init__(*args)
 
     def load_model(
-        self, lang: str, model_name: str, checkpoints: list,
-        finetuning_type: str, template: str, quantization_bit: str
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]
@@ -43,10 +51,11 @@ class WebChatModel(ChatModel):
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_name_or_path,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
             checkpoint_dir=checkpoint_dir,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix
         )
         super().__init__(*get_infer_args(args))

View File

@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
 
 import gradio as gr
 from gradio.blocks import Block
@@ -8,9 +8,10 @@ from llmtuner.webui.chat import WebChatModel
 def create_chat_box(
-    chat_model: WebChatModel
+    chat_model: WebChatModel,
+    visible: Optional[bool] = False
 ) -> Tuple[Block, Component, Component, Dict[str, Component]]:
-    with gr.Box(visible=False) as chat_box:
+    with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
         with gr.Row():

View File

@@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview
 def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
@@ -21,9 +21,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
 
     with gr.Row():
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        max_samples = gr.Textbox(value="100000")
+        batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
         predict = gr.Checkbox(value=True)
 
     with gr.Row():
@@ -35,9 +36,20 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
     start_btn.click(
         runner.run_eval,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            max_source_length,
+            max_target_length,
+            max_samples,
+            batch_size,
+            predict
         ],
         [output_box]
     )
@@ -50,9 +62,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
+        max_source_length=max_source_length,
+        max_target_length=max_target_length,
         max_samples=max_samples,
         batch_size=batch_size,
-        quantization_bit=quantization_bit,
         predict=predict,
         start_btn=start_btn,
         stop_btn=stop_btn,

View File

@@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
-        quantization_bit = gr.Dropdown([8, 4])
 
     info_box = gr.Markdown()
@@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     load_btn.click(
         chat_model.load_model,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            quantization_bit
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"]
         ],
         [info_box]
     ).then(
@@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     )
 
     return dict(
-        quantization_bit=quantization_bit,
         info_box=info_box,
         load_btn=load_btn,
         unload_btn=unload_btn,

View File

@@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
 def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
@@ -23,22 +23,24 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
 
     with gr.Row():
-        learning_rate = gr.Textbox(value="5e-5", interactive=True)
-        num_train_epochs = gr.Textbox(value="3.0", interactive=True)
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        learning_rate = gr.Textbox(value="5e-5")
+        num_train_epochs = gr.Textbox(value="3.0")
+        max_samples = gr.Textbox(value="100000")
 
     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
+        batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
+        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
         lr_scheduler_type = gr.Dropdown(
-            value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
+            value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
         )
+        dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
         fp16 = gr.Checkbox(value=True)
 
     with gr.Row():
-        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
-        save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
+        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+        save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
 
     with gr.Row():
         start_btn = gr.Button()
@@ -55,11 +57,28 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     start_btn.click(
         runner.run_train,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-            fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-            lr_scheduler_type, logging_steps, save_steps, output_dir
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            max_source_length,
+            max_target_length,
+            learning_rate,
+            num_train_epochs,
+            max_samples,
+            batch_size,
+            gradient_accumulation_steps,
+            lr_scheduler_type,
+            dev_ratio,
+            fp16,
+            logging_steps,
+            save_steps,
+            output_dir
         ],
         [output_box]
     )
@@ -76,13 +95,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
+        max_source_length=max_source_length,
+        max_target_length=max_target_length,
         learning_rate=learning_rate,
         num_train_epochs=num_train_epochs,
         max_samples=max_samples,
-        quantization_bit=quantization_bit,
         batch_size=batch_size,
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,
+        dev_ratio=dev_ratio,
         fp16=fp16,
         logging_steps=logging_steps,
         save_steps=save_steps,

View File

@@ -6,29 +6,40 @@ from gradio.components import Component
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
 from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.utils import can_quantize
 
 
 def create_top() -> Dict[str, Component]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
 
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
+        lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)
 
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
-        template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
-        checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
+    with gr.Row():
+        quantization_bit = gr.Dropdown([8, 4], scale=1)
+        template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
+        source_prefix = gr.Textbox(scale=4)
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
     ) # do not save config since the below line will save
 
     model_path.change(save_config, [model_name, model_path])
 
-    finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+    finetuning_type.change(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints]
+    ).then(
+        can_quantize, [finetuning_type], [quantization_bit]
+    )
 
     refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
 
     return dict(
@@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]:
         finetuning_type=finetuning_type,
         template=template,
         checkpoints=checkpoints,
-        refresh_btn=refresh_btn
+        refresh_btn=refresh_btn,
+        quantization_bit=quantization_bit,
+        source_prefix=source_prefix
     )

View File

@@ -25,6 +25,14 @@ LOCALES = {
             "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
         }
     },
+    "finetuning_type": {
+        "en": {
+            "label": "Finetuning method"
+        },
+        "zh": {
+            "label": "微调方法"
+        }
+    },
     "checkpoints": {
         "en": {
             "label": "Checkpoints"
@@ -33,14 +41,6 @@ LOCALES = {
             "label": "模型断点"
         }
     },
-    "template": {
-        "en": {
-            "label": "Prompt template"
-        },
-        "zh": {
-            "label": "提示模板"
-        }
-    },
     "refresh_btn": {
         "en": {
             "value": "Refresh checkpoints"
@@ -49,6 +49,36 @@ LOCALES = {
             "value": "刷新断点"
         }
     },
+    "quantization_bit": {
+        "en": {
+            "label": "Quantization bit (optional)",
+            "info": "Enable 4/8-bit model quantization."
+        },
+        "zh": {
+            "label": "量化等级(非必填)",
+            "info": "启用 4/8 比特模型量化。"
+        }
+    },
+    "template": {
+        "en": {
+            "label": "Prompt template",
+            "info": "The template used in constructing prompts."
+        },
+        "zh": {
+            "label": "提示模板",
+            "info": "构建提示词时使用的模板"
+        }
+    },
+    "source_prefix": {
+        "en": {
+            "label": "Source prefix (optional)",
+            "info": "A sequence used as the prefix of each samples."
+        },
+        "zh": {
+            "label": "前缀序列(非必填)",
+            "info": "作为每个输入样本前缀的序列"
+        }
+    },
     "dataset_dir": {
         "en": {
             "label": "Data dir",
@@ -99,66 +129,24 @@ LOCALES = {
             "value": "关闭"
         }
     },
-    "max_samples": {
-        "en": {
-            "label": "Max samples",
-            "info": "Maximum samples per dataset."
-        },
-        "zh": {
-            "label": "最大样本数",
-            "info": "每个数据集最多使用的样本数"
-        }
-    },
-    "batch_size": {
-        "en": {
-            "label": "Batch size",
-            "info": "Number of samples to process per GPU."
-        },
-        "zh":{
-            "label": "批处理大小",
-            "info": "每块 GPU 上处理的样本数量。"
-        }
-    },
-    "quantization_bit": {
-        "en": {
-            "label": "Quantization bit",
-            "info": "Enable 4/8-bit model quantization."
-        },
-        "zh": {
-            "label": "量化",
-            "info": "启用 4/8 比特模型量化"
-        }
-    },
-    "start_btn": {
-        "en": {
-            "value": "Start"
-        },
-        "zh": {
-            "value": "开始"
-        }
-    },
-    "stop_btn": {
-        "en": {
-            "value": "Abort"
-        },
-        "zh": {
-            "value": "中断"
-        }
-    },
-    "output_box": {
-        "en": {
-            "value": "Ready."
-        },
-        "zh": {
-            "value": "准备就绪。"
-        }
-    },
-    "finetuning_type": {
-        "en": {
-            "label": "Finetuning method"
-        },
-        "zh": {
-            "label": "微调方法"
-        }
-    },
+    "max_source_length": {
+        "en": {
+            "label": "Max source length",
+            "info": "Max tokens in source sequence."
+        },
+        "zh": {
+            "label": "输入序列最大长度",
+            "info": "输入序列分词后的最大长度"
+        }
+    },
+    "max_target_length": {
+        "en": {
+            "label": "Max target length",
+            "info": "Max tokens in target sequence."
+        },
+        "zh": {
+            "label": "输出序列最大长度",
+            "info": "输出序列分词后的最大长度"
+        }
+    },
     "learning_rate": {
@@ -181,6 +169,26 @@ LOCALES = {
             "info": "需要执行的训练总轮数。"
         }
     },
+    "max_samples": {
+        "en": {
+            "label": "Max samples",
+            "info": "Maximum samples per dataset."
+        },
+        "zh": {
+            "label": "最大样本数",
+            "info": "每个数据集最多使用的样本数。"
+        }
+    },
+    "batch_size": {
+        "en": {
+            "label": "Batch size",
+            "info": "Number of samples to process per GPU."
+        },
+        "zh":{
+            "label": "批处理大小",
+            "info": "每块 GPU 上处理的样本数量。"
+        }
+    },
     "gradient_accumulation_steps": {
         "en": {
             "label": "Gradient accumulation",
@@ -201,6 +209,16 @@ LOCALES = {
             "info": "采用的学习率调节器名称。"
         }
     },
+    "dev_ratio": {
+        "en": {
+            "label": "Dev ratio",
+            "info": "Proportion of data in the dev set."
+        },
+        "zh": {
+            "label": "验证集比例",
+            "info": "验证集占全部样本的百分比。"
+        }
+    },
     "fp16": {
         "en": {
             "label": "fp16",
@@ -231,6 +249,22 @@ LOCALES = {
             "info": "每两次断点保存间的更新步数。"
        }
    },
+    "start_btn": {
+        "en": {
+            "value": "Start"
+        },
+        "zh": {
+            "value": "开始"
+        }
+    },
+    "stop_btn": {
+        "en": {
+            "value": "Abort"
+        },
+        "zh": {
+            "value": "中断"
+        }
+    },
     "output_dir": {
         "en": {
             "label": "Checkpoint name",
@@ -241,6 +275,14 @@ LOCALES = {
             "info": "保存模型断点的文件夹名称。"
         }
     },
+    "output_box": {
+        "en": {
+            "value": "Ready."
+        },
+        "zh": {
+            "value": "准备就绪。"
+        }
+    },
     "loss_viewer": {
         "en": {
             "label": "Loss"
@@ -257,14 +299,6 @@ LOCALES = {
             "label": "保存预测结果"
         }
     },
-    "info_box": {
-        "en": {
-            "value": "Model unloaded, please load a model first."
-        },
-        "zh": {
-            "value": "模型未加载,请先加载模型。"
-        }
-    },
     "load_btn": {
         "en": {
             "value": "Load model"
@@ -281,6 +315,14 @@ LOCALES = {
             "value": "卸载模型"
         }
     },
+    "info_box": {
+        "en": {
+            "value": "Model unloaded, please load a model first."
+        },
+        "zh": {
+            "value": "模型未加载,请先加载模型。"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."
@@ -305,6 +347,14 @@ LOCALES = {
             "value": "清空历史"
         }
     },
+    "max_length": {
+        "en": {
+            "label": "Maximum length"
+        },
+        "zh": {
+            "label": "最大长度"
+        }
+    },
     "max_new_tokens": {
         "en": {
             "label": "Maximum new tokens"

View File

@@ -3,7 +3,7 @@ import os
 import threading
 import time
 import transformers
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
@@ -59,10 +59,29 @@ class Runner:
         return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
 
     def run_train(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-        fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-        lr_scheduler_type, logging_steps, save_steps, output_dir
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_source_length: int,
+        max_target_length: int,
+        learning_rate: str,
+        num_train_epochs: str,
+        max_samples: str,
+        batch_size: int,
+        gradient_accumulation_steps: int,
+        lr_scheduler_type: str,
+        dev_ratio: float,
+        fp16: bool,
+        logging_steps: int,
+        save_steps: int,
+        output_dir: str
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
@@ -79,25 +98,35 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_train=True,
-            finetuning_type=finetuning_type,
-            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
+            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            learning_rate=float(learning_rate),
+            num_train_epochs=float(num_train_epochs),
+            max_samples=int(max_samples),
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             lr_scheduler_type=lr_scheduler_type,
+            fp16=fp16,
             logging_steps=logging_steps,
             save_steps=save_steps,
-            learning_rate=float(learning_rate),
-            num_train_epochs=float(num_train_epochs),
-            fp16=fp16,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
         )
+
+        if dev_ratio > 1e-6:
+            args["dev_ratio"] = dev_ratio
+            args["evaluation_strategy"] = "steps"
+            args["eval_steps"] = save_steps
+            args["load_best_model_at_end"] = True
+
         model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
 
         run_args = dict(
@@ -120,8 +149,21 @@ class Runner:
         yield self.finalize(lang)
 
     def run_eval(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_source_length: int,
+        max_target_length: int,
+        max_samples: str,
+        batch_size: int,
+        predict: bool
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
@@ -140,17 +182,20 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_eval=True,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=output_dir,
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
             predict_with_generate=True,
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            max_samples=int(max_samples),
             per_device_eval_batch_size=batch_size,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=output_dir
         )
 
         if predict:
View File

@@ -3,7 +3,7 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from datetime import datetime
 
 from llmtuner.extras.ploting import smooth
@@ -23,7 +23,7 @@ def get_time() -> str:
     return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
 
 
-def can_preview(dataset_dir: str, dataset: list) -> dict:
+def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     if (
@@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
     return gr.update(interactive=False)
 
 
-def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
+def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     data_file = dataset_info[dataset[0]]["file_name"]
@@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
     return len(data), data[:2], gr.update(visible=True)
 
 
+def can_quantize(finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type != "lora":
+        return gr.update(value="", interactive=False)
+    else:
+        return gr.update(interactive=True)
+
+
 def get_eval_results(path: os.PathLike) -> str:
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
@@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
         if log_info.get("loss", None):
             steps.append(log_info["current_steps"])
             losses.append(log_info["loss"])
+
+    if len(losses) == 0:
+        return None
+
     ax.plot(steps, losses, alpha=0.4, label="original")
     ax.plot(steps, smooth(losses), label="smoothed")
     ax.legend()
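
`can_quantize` returns a `gr.update(...)` payload that the `finetuning_type.change` handler in `top.py` applies to the quantization dropdown: cleared and disabled unless LoRA is selected. A self-contained sketch of the wiring (Gradio 3.x assumed; the dropdown choices here are only illustrative):

```python
from typing import Any, Dict
import gradio as gr

def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    # same logic as the helper added in the diff above
    if finetuning_type != "lora":
        return gr.update(value="", interactive=False)
    return gr.update(interactive=True)

with gr.Blocks() as demo:
    finetuning_type = gr.Dropdown(["lora", "full", "freeze"], value="lora")
    quantization_bit = gr.Dropdown([8, 4])
    # greys out the quantization dropdown whenever a non-LoRA method is picked
    finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit])
```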

src/web_demo.py (new file, 44 additions)
View File

@@ -0,0 +1,44 @@
+# coding=utf-8
+# Implements user interface in browser for fine-tuned models.
+# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
+
+import gradio as gr
+from transformers.utils.versions import require_version
+
+from llmtuner import get_infer_args
+from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.components.chatbot import create_chat_box
+from llmtuner.webui.manager import Manager
+
+
+require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
+
+
+def main():
+    chat_model = WebChatModel(*get_infer_args())
+
+    with gr.Blocks(title="Web Demo") as demo:
+        lang = gr.Dropdown(choices=["en", "zh"], value="en")
+
+        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
+
+        manager = Manager([{"lang": lang}, chat_elems])
+
+        demo.load(
+            manager.gen_label,
+            [lang],
+            [lang] + [elem for elem in chat_elems.values()],
+        )
+
+        lang.change(
+            manager.gen_label,
+            [lang],
+            [lang] + [elem for elem in chat_elems.values()],
+        )
+
+    demo.queue()
+    demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
+
+
+if __name__ == "__main__":
+    main()