Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
35e76879f5 | ||
|
|
8e4ae0aaac | ||
|
|
5ed2a97056 | ||
|
|
03eba6f041 | ||
|
|
ec166e736a | ||
|
|
c85a6b83b3 | ||
|
|
a864a7b395 | ||
|
|
fd8c2d4aac |
@@ -291,6 +291,14 @@ python src/cli_demo.py \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Export model
|
||||
|
||||
```bash
|
||||
|
||||
@@ -10,6 +10,8 @@ from llmtuner.tuner import get_infer_args
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.api.protocol import (
|
||||
Role,
|
||||
Finish,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ChatMessage,
|
||||
@@ -49,12 +51,12 @@ def create_app():
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if request.messages[-1].role != "user":
|
||||
if request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
query = request.messages[-1].content
|
||||
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == "system":
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
|
||||
prefix = prev_messages.pop(0).content
|
||||
else:
|
||||
prefix = None
|
||||
@@ -62,7 +64,7 @@ def create_app():
|
||||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
|
||||
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
|
||||
if request.stream:
|
||||
@@ -81,19 +83,19 @@ def create_app():
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content=response),
|
||||
finish_reason="stop"
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=response),
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
delta=DeltaMessage(role=Role.ASSISTANT),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
@@ -107,15 +109,15 @@ def create_app():
|
||||
delta=DeltaMessage(content=new_text),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason="stop"
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
yield "[DONE]"
|
||||
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class Finish(str, Enum):
|
||||
STOP = "stop"
|
||||
LENGTH = "length"
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["user", "assistant", "system"]
|
||||
role: Role
|
||||
content: str
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
role: Optional[Role] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Literal["stop", "length"]
|
||||
finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||
finish_reason: Optional[Finish] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseUsage(BaseModel):
|
||||
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion"]
|
||||
object: Optional[str] = "chat.completion"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion.chunk"]
|
||||
object: Optional[str] = "chat.completion.chunk"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
||||
@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
|
||||
@@ -202,7 +202,20 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
register_template(
|
||||
name="baichuan",
|
||||
prefix="",
|
||||
prompt="<reserved_102>{query}<reserved_103>",
|
||||
sep="",
|
||||
prompt=" <reserved_102> {query} <reserved_103> ",
|
||||
sep="</s>",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
|
||||
https://huggingface.co/HuggingFaceH4/starchat-beta
|
||||
"""
|
||||
register_template(
|
||||
name="starchat",
|
||||
prefix="<|system|>\n",
|
||||
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
|
||||
sep="<|end|>\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
@@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
with torch.no_grad():
|
||||
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
|
||||
rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
|
||||
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
# Run PPO step
|
||||
|
||||
@@ -23,7 +23,7 @@ class ComputeMetrics:
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
@@ -47,5 +47,6 @@ class ComputeMetrics:
|
||||
|
||||
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
|
||||
|
||||
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
||||
|
||||
@@ -11,14 +11,22 @@ from llmtuner.webui.locales import ALERTS
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, *args):
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
if len(args) != 0:
|
||||
super().__init__(*args)
|
||||
|
||||
def load_model(
|
||||
self, lang: str, model_name: str, checkpoints: list,
|
||||
finetuning_type: str, template: str, quantization_bit: str
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str
|
||||
):
|
||||
if self.model is not None:
|
||||
yield ALERTS["err_exists"][lang]
|
||||
@@ -43,10 +51,11 @@ class WebChatModel(ChatModel):
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
finetuning_type=finetuning_type,
|
||||
prompt_template=template,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
prompt_template=template,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
super().__init__(*get_infer_args(args))
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.blocks import Block
|
||||
@@ -8,9 +8,10 @@ from llmtuner.webui.chat import WebChatModel
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
chat_model: WebChatModel
|
||||
chat_model: WebChatModel,
|
||||
visible: Optional[bool] = False
|
||||
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
|
||||
with gr.Box(visible=False) as chat_box:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
|
||||
with gr.Row():
|
||||
|
||||
@@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview
|
||||
|
||||
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
@@ -21,9 +21,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
|
||||
with gr.Row():
|
||||
max_samples = gr.Textbox(value="100000", interactive=True)
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
|
||||
quantization_bit = gr.Dropdown([8, 4])
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
with gr.Row():
|
||||
@@ -35,9 +36,20 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
|
||||
start_btn.click(
|
||||
runner.run_eval,
|
||||
[
|
||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"], top_elems["template"],
|
||||
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
max_samples,
|
||||
batch_size,
|
||||
predict
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
@@ -50,9 +62,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
quantization_bit=quantization_bit,
|
||||
predict=predict,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
|
||||
@@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
quantization_bit = gr.Dropdown([8, 4])
|
||||
|
||||
info_box = gr.Markdown()
|
||||
|
||||
@@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||
load_btn.click(
|
||||
chat_model.load_model,
|
||||
[
|
||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"], top_elems["template"],
|
||||
quantization_bit
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"]
|
||||
],
|
||||
[info_box]
|
||||
).then(
|
||||
@@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||
)
|
||||
|
||||
return dict(
|
||||
quantization_bit=quantization_bit,
|
||||
info_box=info_box,
|
||||
load_btn=load_btn,
|
||||
unload_btn=unload_btn,
|
||||
|
||||
@@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||
|
||||
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
@@ -23,22 +23,24 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
|
||||
with gr.Row():
|
||||
learning_rate = gr.Textbox(value="5e-5", interactive=True)
|
||||
num_train_epochs = gr.Textbox(value="3.0", interactive=True)
|
||||
max_samples = gr.Textbox(value="100000", interactive=True)
|
||||
quantization_bit = gr.Dropdown([8, 4])
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
learning_rate = gr.Textbox(value="5e-5")
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
lr_scheduler_type = gr.Dropdown(
|
||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
|
||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
|
||||
)
|
||||
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
fp16 = gr.Checkbox(value=True)
|
||||
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
@@ -55,11 +57,28 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||
start_btn.click(
|
||||
runner.run_train,
|
||||
[
|
||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"], top_elems["template"],
|
||||
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
|
||||
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
|
||||
lr_scheduler_type, logging_steps, save_steps, output_dir
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
learning_rate,
|
||||
num_train_epochs,
|
||||
max_samples,
|
||||
batch_size,
|
||||
gradient_accumulation_steps,
|
||||
lr_scheduler_type,
|
||||
dev_ratio,
|
||||
fp16,
|
||||
logging_steps,
|
||||
save_steps,
|
||||
output_dir
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
@@ -76,13 +95,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
learning_rate=learning_rate,
|
||||
num_train_epochs=num_train_epochs,
|
||||
max_samples=max_samples,
|
||||
quantization_bit=quantization_bit,
|
||||
batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
dev_ratio=dev_ratio,
|
||||
fp16=fp16,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
|
||||
@@ -6,29 +6,40 @@ from gradio.components import Component
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.extras.template import templates
|
||||
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
|
||||
from llmtuner.webui.utils import can_quantize
|
||||
|
||||
|
||||
def create_top() -> Dict[str, Component]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
|
||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
|
||||
checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
||||
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
|
||||
checkpoints = gr.Dropdown(multiselect=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown([8, 4], scale=1)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
|
||||
source_prefix = gr.Textbox(scale=4)
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
).then(
|
||||
get_model_path, [model_name], [model_path]
|
||||
) # do not save config since the below line will save
|
||||
model_path.change(save_config, [model_name, model_path])
|
||||
finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||
|
||||
finetuning_type.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit]
|
||||
)
|
||||
|
||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||
|
||||
return dict(
|
||||
@@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]:
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
checkpoints=checkpoints,
|
||||
refresh_btn=refresh_btn
|
||||
refresh_btn=refresh_btn,
|
||||
quantization_bit=quantization_bit,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
|
||||
@@ -25,6 +25,14 @@ LOCALES = {
|
||||
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
|
||||
}
|
||||
},
|
||||
"finetuning_type": {
|
||||
"en": {
|
||||
"label": "Finetuning method"
|
||||
},
|
||||
"zh": {
|
||||
"label": "微调方法"
|
||||
}
|
||||
},
|
||||
"checkpoints": {
|
||||
"en": {
|
||||
"label": "Checkpoints"
|
||||
@@ -33,14 +41,6 @@ LOCALES = {
|
||||
"label": "模型断点"
|
||||
}
|
||||
},
|
||||
"template": {
|
||||
"en": {
|
||||
"label": "Prompt template"
|
||||
},
|
||||
"zh": {
|
||||
"label": "提示模板"
|
||||
}
|
||||
},
|
||||
"refresh_btn": {
|
||||
"en": {
|
||||
"value": "Refresh checkpoints"
|
||||
@@ -49,6 +49,36 @@ LOCALES = {
|
||||
"value": "刷新断点"
|
||||
}
|
||||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit (optional)",
|
||||
"info": "Enable 4/8-bit model quantization."
|
||||
},
|
||||
"zh": {
|
||||
"label": "量化等级(非必填)",
|
||||
"info": "启用 4/8 比特模型量化。"
|
||||
}
|
||||
},
|
||||
"template": {
|
||||
"en": {
|
||||
"label": "Prompt template",
|
||||
"info": "The template used in constructing prompts."
|
||||
},
|
||||
"zh": {
|
||||
"label": "提示模板",
|
||||
"info": "构建提示词时使用的模板"
|
||||
}
|
||||
},
|
||||
"source_prefix": {
|
||||
"en": {
|
||||
"label": "Source prefix (optional)",
|
||||
"info": "A sequence used as the prefix of each samples."
|
||||
},
|
||||
"zh": {
|
||||
"label": "前缀序列(非必填)",
|
||||
"info": "作为每个输入样本前缀的序列"
|
||||
}
|
||||
},
|
||||
"dataset_dir": {
|
||||
"en": {
|
||||
"label": "Data dir",
|
||||
@@ -99,66 +129,24 @@ LOCALES = {
|
||||
"value": "关闭"
|
||||
}
|
||||
},
|
||||
"max_samples": {
|
||||
"max_source_length": {
|
||||
"en": {
|
||||
"label": "Max samples",
|
||||
"info": "Maximum samples per dataset."
|
||||
"label": "Max source length",
|
||||
"info": "Max tokens in source sequence."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大样本数",
|
||||
"info": "每个数据集最多使用的样本数。"
|
||||
"label": "输入序列最大长度",
|
||||
"info": "输入序列分词后的最大长度。"
|
||||
}
|
||||
},
|
||||
"batch_size": {
|
||||
"max_target_length": {
|
||||
"en": {
|
||||
"label": "Batch size",
|
||||
"info": "Number of samples to process per GPU."
|
||||
},
|
||||
"zh":{
|
||||
"label": "批处理大小",
|
||||
"info": "每块 GPU 上处理的样本数量。"
|
||||
}
|
||||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit",
|
||||
"info": "Enable 4/8-bit model quantization."
|
||||
"label": "Max target length",
|
||||
"info": "Max tokens in target sequence."
|
||||
},
|
||||
"zh": {
|
||||
"label": "量化",
|
||||
"info": "启用 4/8 比特模型量化。"
|
||||
}
|
||||
},
|
||||
"start_btn": {
|
||||
"en": {
|
||||
"value": "Start"
|
||||
},
|
||||
"zh": {
|
||||
"value": "开始"
|
||||
}
|
||||
},
|
||||
"stop_btn": {
|
||||
"en": {
|
||||
"value": "Abort"
|
||||
},
|
||||
"zh": {
|
||||
"value": "中断"
|
||||
}
|
||||
},
|
||||
"output_box": {
|
||||
"en": {
|
||||
"value": "Ready."
|
||||
},
|
||||
"zh": {
|
||||
"value": "准备就绪。"
|
||||
}
|
||||
},
|
||||
"finetuning_type": {
|
||||
"en": {
|
||||
"label": "Finetuning method"
|
||||
},
|
||||
"zh": {
|
||||
"label": "微调方法"
|
||||
"label": "输出序列最大长度",
|
||||
"info": "输出序列分词后的最大长度。"
|
||||
}
|
||||
},
|
||||
"learning_rate": {
|
||||
@@ -181,6 +169,26 @@ LOCALES = {
|
||||
"info": "需要执行的训练总轮数。"
|
||||
}
|
||||
},
|
||||
"max_samples": {
|
||||
"en": {
|
||||
"label": "Max samples",
|
||||
"info": "Maximum samples per dataset."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大样本数",
|
||||
"info": "每个数据集最多使用的样本数。"
|
||||
}
|
||||
},
|
||||
"batch_size": {
|
||||
"en": {
|
||||
"label": "Batch size",
|
||||
"info": "Number of samples to process per GPU."
|
||||
},
|
||||
"zh":{
|
||||
"label": "批处理大小",
|
||||
"info": "每块 GPU 上处理的样本数量。"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": {
|
||||
"en": {
|
||||
"label": "Gradient accumulation",
|
||||
@@ -201,6 +209,16 @@ LOCALES = {
|
||||
"info": "采用的学习率调节器名称。"
|
||||
}
|
||||
},
|
||||
"dev_ratio": {
|
||||
"en": {
|
||||
"label": "Dev ratio",
|
||||
"info": "Proportion of data in the dev set."
|
||||
},
|
||||
"zh": {
|
||||
"label": "验证集比例",
|
||||
"info": "验证集占全部样本的百分比。"
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"en": {
|
||||
"label": "fp16",
|
||||
@@ -231,6 +249,22 @@ LOCALES = {
|
||||
"info": "每两次断点保存间的更新步数。"
|
||||
}
|
||||
},
|
||||
"start_btn": {
|
||||
"en": {
|
||||
"value": "Start"
|
||||
},
|
||||
"zh": {
|
||||
"value": "开始"
|
||||
}
|
||||
},
|
||||
"stop_btn": {
|
||||
"en": {
|
||||
"value": "Abort"
|
||||
},
|
||||
"zh": {
|
||||
"value": "中断"
|
||||
}
|
||||
},
|
||||
"output_dir": {
|
||||
"en": {
|
||||
"label": "Checkpoint name",
|
||||
@@ -241,6 +275,14 @@ LOCALES = {
|
||||
"info": "保存模型断点的文件夹名称。"
|
||||
}
|
||||
},
|
||||
"output_box": {
|
||||
"en": {
|
||||
"value": "Ready."
|
||||
},
|
||||
"zh": {
|
||||
"value": "准备就绪。"
|
||||
}
|
||||
},
|
||||
"loss_viewer": {
|
||||
"en": {
|
||||
"label": "Loss"
|
||||
@@ -257,14 +299,6 @@ LOCALES = {
|
||||
"label": "保存预测结果"
|
||||
}
|
||||
},
|
||||
"info_box": {
|
||||
"en": {
|
||||
"value": "Model unloaded, please load a model first."
|
||||
},
|
||||
"zh": {
|
||||
"value": "模型未加载,请先加载模型。"
|
||||
}
|
||||
},
|
||||
"load_btn": {
|
||||
"en": {
|
||||
"value": "Load model"
|
||||
@@ -281,6 +315,14 @@ LOCALES = {
|
||||
"value": "卸载模型"
|
||||
}
|
||||
},
|
||||
"info_box": {
|
||||
"en": {
|
||||
"value": "Model unloaded, please load a model first."
|
||||
},
|
||||
"zh": {
|
||||
"value": "模型未加载,请先加载模型。"
|
||||
}
|
||||
},
|
||||
"query": {
|
||||
"en": {
|
||||
"placeholder": "Input..."
|
||||
@@ -305,6 +347,14 @@ LOCALES = {
|
||||
"value": "清空历史"
|
||||
}
|
||||
},
|
||||
"max_length": {
|
||||
"en": {
|
||||
"label": "Maximum length"
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大长度"
|
||||
}
|
||||
},
|
||||
"max_new_tokens": {
|
||||
"en": {
|
||||
"label": "Maximum new tokens"
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import transformers
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
|
||||
@@ -59,10 +59,29 @@ class Runner:
|
||||
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
|
||||
|
||||
def run_train(
|
||||
self, lang, model_name, checkpoints, finetuning_type, template,
|
||||
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
|
||||
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
|
||||
lr_scheduler_type, logging_steps, save_steps, output_dir
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
max_target_length: int,
|
||||
learning_rate: str,
|
||||
num_train_epochs: str,
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
gradient_accumulation_steps: int,
|
||||
lr_scheduler_type: str,
|
||||
dev_ratio: float,
|
||||
fp16: bool,
|
||||
logging_steps: int,
|
||||
save_steps: int,
|
||||
output_dir: str
|
||||
):
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
@@ -79,25 +98,35 @@ class Runner:
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
do_train=True,
|
||||
finetuning_type=finetuning_type,
|
||||
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
|
||||
prompt_template=template,
|
||||
dataset=",".join(dataset),
|
||||
dataset_dir=dataset_dir,
|
||||
max_samples=int(max_samples),
|
||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
overwrite_cache=True,
|
||||
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
prompt_template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
learning_rate=float(learning_rate),
|
||||
num_train_epochs=float(num_train_epochs),
|
||||
max_samples=int(max_samples),
|
||||
per_device_train_batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
fp16=fp16,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
learning_rate=float(learning_rate),
|
||||
num_train_epochs=float(num_train_epochs),
|
||||
fp16=fp16,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||
)
|
||||
|
||||
if dev_ratio > 1e-6:
|
||||
args["dev_ratio"] = dev_ratio
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = save_steps
|
||||
args["load_best_model_at_end"] = True
|
||||
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||
|
||||
run_args = dict(
|
||||
@@ -120,8 +149,21 @@ class Runner:
|
||||
yield self.finalize(lang)
|
||||
|
||||
def run_eval(
|
||||
self, lang, model_name, checkpoints, finetuning_type, template,
|
||||
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
max_target_length: int,
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
predict: bool
|
||||
):
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
@@ -140,17 +182,20 @@ class Runner:
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
do_eval=True,
|
||||
finetuning_type=finetuning_type,
|
||||
prompt_template=template,
|
||||
dataset=",".join(dataset),
|
||||
dataset_dir=dataset_dir,
|
||||
max_samples=int(max_samples),
|
||||
output_dir=output_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
overwrite_cache=True,
|
||||
predict_with_generate=True,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
prompt_template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
max_samples=int(max_samples),
|
||||
per_device_eval_batch_size=batch_size,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
||||
output_dir=output_dir
|
||||
)
|
||||
|
||||
if predict:
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.ploting import smooth
|
||||
@@ -23,7 +23,7 @@ def get_time() -> str:
|
||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> dict:
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
if (
|
||||
@@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
|
||||
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
data_file = dataset_info[dataset[0]]["file_name"]
|
||||
@@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
|
||||
return len(data), data[:2], gr.update(visible=True)
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
if finetuning_type != "lora":
|
||||
return gr.update(value="", interactive=False)
|
||||
else:
|
||||
return gr.update(interactive=True)
|
||||
|
||||
|
||||
def get_eval_results(path: os.PathLike) -> str:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
result = json.dumps(json.load(f), indent=4)
|
||||
@@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||
if log_info.get("loss", None):
|
||||
steps.append(log_info["current_steps"])
|
||||
losses.append(log_info["loss"])
|
||||
|
||||
if len(losses) == 0:
|
||||
return None
|
||||
|
||||
ax.plot(steps, losses, alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), label="smoothed")
|
||||
ax.legend()
|
||||
|
||||
44
src/web_demo.py
Normal file
44
src/web_demo.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# coding=utf-8
|
||||
# Implements user interface in browser for fine-tuned models.
|
||||
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
|
||||
import gradio as gr
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner import get_infer_args
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
from llmtuner.webui.manager import Manager
|
||||
|
||||
|
||||
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = WebChatModel(*get_infer_args())
|
||||
|
||||
with gr.Blocks(title="Web Demo") as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en")
|
||||
|
||||
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
|
||||
|
||||
manager = Manager([{"lang": lang}, chat_elems])
|
||||
|
||||
demo.load(
|
||||
manager.gen_label,
|
||||
[lang],
|
||||
[lang] + [elem for elem in chat_elems.values()],
|
||||
)
|
||||
|
||||
lang.change(
|
||||
manager.gen_label,
|
||||
[lang],
|
||||
[lang] + [elem for elem in chat_elems.values()],
|
||||
)
|
||||
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user