Compare commits

8 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 35e76879f5 |  |
|  | 8e4ae0aaac |  |
|  | 5ed2a97056 |  |
|  | 03eba6f041 |  |
|  | ec166e736a |  |
|  | c85a6b83b3 |  |
|  | a864a7b395 |  |
|  | fd8c2d4aac |  |
@@ -291,6 +291,14 @@ python src/cli_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
+### Web Demo
+
+```bash
+python src/web_demo.py \
+    --model_name_or_path path_to_your_model \
+    --checkpoint_dir path_to_checkpoint
+```
+
 ### Export model
 
 ```bash
@@ -10,6 +10,8 @@ from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.api.protocol import (
+    Role,
+    Finish,
     ModelCard,
     ModelList,
     ChatMessage,
@@ -49,12 +51,12 @@ def create_app():
 
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != "user":
+        if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
         query = request.messages[-1].content
 
         prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content
         else:
             prefix = None
@@ -62,7 +64,7 @@ def create_app():
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
 
         if request.stream:
@@ -81,19 +83,19 @@ def create_app():
 
         choice_data = ChatCompletionResponseChoice(
             index=0,
-            message=ChatMessage(role="assistant", content=response),
-            finish_reason="stop"
+            message=ChatMessage(role=Role.ASSISTANT, content=response),
+            finish_reason=Finish.STOP
         )
 
-        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
+        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
 
     async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
-            delta=DeltaMessage(role="assistant"),
+            delta=DeltaMessage(role=Role.ASSISTANT),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
 
         for new_text in chat_model.stream_chat(
@@ -107,15 +109,15 @@ def create_app():
                 delta=DeltaMessage(content=new_text),
                 finish_reason=None
             )
-            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield json.dumps(chunk, ensure_ascii=False)
 
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(),
-            finish_reason="stop"
+            finish_reason=Finish.STOP
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
         yield "[DONE]"
 
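A note on the handler changes above: because `Role` and `Finish` subclass `str` (see the protocol diff below), the comparisons keep working even though clients still send plain strings in their JSON. A minimal client sketch for this endpoint, assuming the API demo is served at `http://localhost:8000` (host and port are assumptions, not taken from the diff):

```python
import requests  # pip install requests

payload = {
    "model": "default",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```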
@@ -1,6 +1,18 @@
 import time
+from enum import Enum
 from pydantic import BaseModel, Field
-from typing import List, Literal, Optional
+from typing import List, Optional
 
 
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
+
+
 class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
 
 
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
+    role: Role
     content: str
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Role] = None
     content: Optional[str] = None
 
 
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Finish] = None
 
 
 class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion"]
+    object: Optional[str] = "chat.completion"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
 
 class ChatCompletionStreamResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
+    object: Optional[str] = "chat.completion.chunk"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
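The switch from `Literal[...]` strings to `(str, Enum)` classes keeps the wire format identical while giving the code named constants. A short self-contained sketch of the property the handlers rely on:

```python
from enum import Enum

class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"

# Members compare equal to their raw strings, so JSON payloads carrying
# "user"/"assistant"/"system" still validate and compare correctly.
assert Role.USER == "user"
assert Role("assistant") is Role.ASSISTANT
print(Role.SYSTEM.value)  # -> system
```

Relaxing `object` from `Literal[...]` to `Optional[str]` with a default also means response constructors no longer have to pass it explicitly, which is exactly the keyword the `ChatCompletionStreamResponse(...)` call sites above drop.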
@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
         r"""
         Event called after logging the last logs.
         """
+        if not state.is_world_process_zero:
+            return
+
         cur_time = time.time()
         cur_steps = state.log_history[-1].get("step")
         elapsed_time = cur_time - self.start_time
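The early return guards against duplicated log entries under distributed training, where every process fires the callback but only rank zero should record progress. A minimal sketch of the same pattern in a standalone callback, assuming the Hugging Face `TrainerCallback` API:

```python
from transformers import TrainerCallback

class RankZeroLogger(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return  # non-main processes skip logging entirely
        print(f"step {state.global_step}: {logs}")
```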
@@ -202,7 +202,20 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 register_template(
     name="baichuan",
     prefix="",
-    prompt="<reserved_102>{query}<reserved_103>",
-    sep="",
+    prompt=" <reserved_102> {query} <reserved_103> ",
+    sep="</s>",
+    use_history=True
+)
+
+
+r"""
+Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
+          https://huggingface.co/HuggingFaceH4/starchat-beta
+"""
+register_template(
+    name="starchat",
+    prefix="<|system|>\n",
+    prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
+    sep="<|end|>\n",
     use_history=True
 )
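To illustrate what the new starchat template produces, here is a hedged sketch of one rendered turn, assuming the template engine simply emits `prefix` and then substitutes `{query}` into `prompt` (the actual assembly logic lives in the repo's template module):

```python
prefix = "<|system|>\n"
prompt = "<|user|>\n{query}<|end|>\n<|assistant|>\n"

print(prefix + prompt.format(query="How do I sort a list in Python?"))
# <|system|>
# <|user|>
# How do I sort a list in Python?<|end|>
# <|assistant|>
```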
@@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         replace_model(unwrapped_model, target="reward")
         with torch.no_grad():
             _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
-        rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
+        rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
         replace_model(unwrapped_model, target="default")
 
         # Run PPO step
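This small-looking change is a real bugfix: the value head's output is indexed per token, roughly `(batch, seq_len)`, so `values[-1]` selected every token of the *last sample*, while `values[:, -1]` selects the last-token value of *every* sample. A toy reproduction:

```python
import torch

values = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # (batch=3, seq_len=4)
print(values[-1])     # tensor([ 8.,  9., 10., 11.]) -- one sample, all tokens (wrong)
print(values[:, -1])  # tensor([ 3.,  7., 11.])      -- all samples, last token (right)
```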
@@ -23,7 +23,7 @@ class ComputeMetrics:
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
-        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+        score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -47,5 +47,6 @@ class ComputeMetrics:
 
             bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
             score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+            score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
 
         return {k: float(np.mean(v)) for k, v in score_dict.items()}
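The new `accuracy` entry is an exact-prefix match: a prediction scores 1.0 only when it reproduces the whole (non-empty) reference at its start. A reduced sketch of the criterion, assuming `pred` and `label` are decoded strings at this point in the loop:

```python
def prefix_accuracy(pred: str, label: str) -> float:
    return float(len(label) != 0 and pred[:len(label)] == label)

print(prefix_accuracy("Paris is the capital.", "Paris is the capital."))  # 1.0
print(prefix_accuracy("Paris, France.", "Paris is the capital."))        # 0.0
```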
@@ -11,14 +11,22 @@ from llmtuner.webui.locales import ALERTS
 
 class WebChatModel(ChatModel):
 
-    def __init__(self):
+    def __init__(self, *args):
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
+        if len(args) != 0:
+            super().__init__(*args)
 
     def load_model(
-        self, lang: str, model_name: str, checkpoints: list,
-        finetuning_type: str, template: str, quantization_bit: str
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]
@@ -43,10 +51,11 @@ class WebChatModel(ChatModel):
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_name_or_path,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
             checkpoint_dir=checkpoint_dir,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix
         )
         super().__init__(*get_infer_args(args))
 
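The `*args` constructor gives `WebChatModel` two construction paths: the full Web UI builds it empty and loads weights later through `load_model`, while the new `src/web_demo.py` (added at the bottom of this diff) passes inference arguments straight through to `ChatModel` at startup. A sketch of both paths, assuming the repo is on `PYTHONPATH`:

```python
from llmtuner import get_infer_args
from llmtuner.webui.chat import WebChatModel

lazy = WebChatModel()                    # Web UI: model/tokenizer stay None until load_model
eager = WebChatModel(*get_infer_args())  # web_demo.py: loads immediately via ChatModel
```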
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
 
 import gradio as gr
 from gradio.blocks import Block
@@ -8,9 +8,10 @@ from llmtuner.webui.chat import WebChatModel
 
 
 def create_chat_box(
-    chat_model: WebChatModel
+    chat_model: WebChatModel,
+    visible: Optional[bool] = False
 ) -> Tuple[Block, Component, Component, Dict[str, Component]]:
-    with gr.Box(visible=False) as chat_box:
+    with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
 
         with gr.Row():
@@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview
 
 def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
@@ -21,9 +21,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
 
     with gr.Row():
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        max_samples = gr.Textbox(value="100000")
+        batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
         predict = gr.Checkbox(value=True)
 
     with gr.Row():
@@ -35,9 +36,20 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
     start_btn.click(
         runner.run_eval,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            max_source_length,
+            max_target_length,
+            max_samples,
+            batch_size,
+            predict
         ],
         [output_box]
     )
@@ -50,9 +62,10 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
+        max_source_length=max_source_length,
+        max_target_length=max_target_length,
         max_samples=max_samples,
         batch_size=batch_size,
-        quantization_bit=quantization_bit,
         predict=predict,
         start_btn=start_btn,
         stop_btn=stop_btn,
@@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
-        quantization_bit = gr.Dropdown([8, 4])
 
     info_box = gr.Markdown()
 
@@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     load_btn.click(
         chat_model.load_model,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            quantization_bit
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"]
         ],
         [info_box]
     ).then(
@@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     )
 
     return dict(
-        quantization_bit=quantization_bit,
         info_box=info_box,
         load_btn=load_btn,
         unload_btn=unload_btn,
@@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
 
 def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
-        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
@@ -23,22 +23,24 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
 
     with gr.Row():
-        learning_rate = gr.Textbox(value="5e-5", interactive=True)
-        num_train_epochs = gr.Textbox(value="3.0", interactive=True)
-        max_samples = gr.Textbox(value="100000", interactive=True)
-        quantization_bit = gr.Dropdown([8, 4])
+        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
+        learning_rate = gr.Textbox(value="5e-5")
+        num_train_epochs = gr.Textbox(value="3.0")
+        max_samples = gr.Textbox(value="100000")
 
     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
+        batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
+        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
         lr_scheduler_type = gr.Dropdown(
-            value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
+            value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
         )
+        dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
         fp16 = gr.Checkbox(value=True)
 
     with gr.Row():
-        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
-        save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
+        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+        save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
 
     with gr.Row():
         start_btn = gr.Button()
@@ -55,11 +57,28 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     start_btn.click(
         runner.run_train,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
-            top_elems["finetuning_type"], top_elems["template"],
-            dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-            fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-            lr_scheduler_type, logging_steps, save_steps, output_dir
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            max_source_length,
+            max_target_length,
+            learning_rate,
+            num_train_epochs,
+            max_samples,
+            batch_size,
+            gradient_accumulation_steps,
+            lr_scheduler_type,
+            dev_ratio,
+            fp16,
+            logging_steps,
+            save_steps,
+            output_dir
         ],
         [output_box]
     )
@@ -76,13 +95,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
+        max_source_length=max_source_length,
+        max_target_length=max_target_length,
         learning_rate=learning_rate,
         num_train_epochs=num_train_epochs,
         max_samples=max_samples,
-        quantization_bit=quantization_bit,
         batch_size=batch_size,
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,
+        dev_ratio=dev_ratio,
         fp16=fp16,
         logging_steps=logging_steps,
        save_steps=save_steps,
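One thing worth spelling out about the expanded input lists in this and the evaluation tab: Gradio passes the *values* of the listed components to the handler positionally, so the list order must match `runner.run_train`'s parameter order exactly, which is why both were reordered in lockstep. A minimal sketch of that contract:

```python
import gradio as gr

def concat(a: str, b: str) -> str:
    return a + b

with gr.Blocks() as demo:
    x = gr.Textbox(value="1")
    y = gr.Textbox(value="2")
    out = gr.Textbox()
    btn = gr.Button("go")
    btn.click(concat, [x, y], [out])  # concat receives (value of x, value of y)
```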
@@ -6,29 +6,40 @@ from gradio.components import Component
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
 from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.utils import can_quantize
 
 
 def create_top() -> Dict[str, Component]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
 
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
+        lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)
 
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
-        template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
-        checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
+        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
+    with gr.Row():
+        quantization_bit = gr.Dropdown([8, 4], scale=1)
+        template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
+        source_prefix = gr.Textbox(scale=4)
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
     ) # do not save config since the below line will save
     model_path.change(save_config, [model_name, model_path])
-    finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+
+    finetuning_type.change(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints]
+    ).then(
+        can_quantize, [finetuning_type], [quantization_bit]
+    )
+
     refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
 
     return dict(
@@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]:
         finetuning_type=finetuning_type,
         template=template,
         checkpoints=checkpoints,
-        refresh_btn=refresh_btn
+        refresh_btn=refresh_btn,
+        quantization_bit=quantization_bit,
+        source_prefix=source_prefix
     )
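For readers unfamiliar with Gradio event chaining: `.then()` queues a second handler that runs after the first completes, which is how changing the finetuning method first refreshes the checkpoint list and then enables or disables the quantization dropdown. A stripped-down sketch of the pattern:

```python
import gradio as gr

def upper(x: str) -> str:
    return x.upper()

def wrap(x: str) -> str:
    return f"<{x}>"

with gr.Blocks() as demo:
    box = gr.Textbox()
    out1 = gr.Textbox()
    out2 = gr.Textbox()
    box.change(upper, [box], [out1]).then(wrap, [out1], [out2])
```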
@@ -25,6 +25,14 @@ LOCALES = {
             "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
         }
     },
+    "finetuning_type": {
+        "en": {
+            "label": "Finetuning method"
+        },
+        "zh": {
+            "label": "微调方法"
+        }
+    },
     "checkpoints": {
         "en": {
             "label": "Checkpoints"
@@ -33,14 +41,6 @@ LOCALES = {
             "label": "模型断点"
         }
     },
-    "template": {
-        "en": {
-            "label": "Prompt template"
-        },
-        "zh": {
-            "label": "提示模板"
-        }
-    },
     "refresh_btn": {
         "en": {
             "value": "Refresh checkpoints"
@@ -49,6 +49,36 @@ LOCALES = {
             "value": "刷新断点"
         }
     },
+    "quantization_bit": {
+        "en": {
+            "label": "Quantization bit (optional)",
+            "info": "Enable 4/8-bit model quantization."
+        },
+        "zh": {
+            "label": "量化等级(非必填)",
+            "info": "启用 4/8 比特模型量化。"
+        }
+    },
+    "template": {
+        "en": {
+            "label": "Prompt template",
+            "info": "The template used in constructing prompts."
+        },
+        "zh": {
+            "label": "提示模板",
+            "info": "构建提示词时使用的模板"
+        }
+    },
+    "source_prefix": {
+        "en": {
+            "label": "Source prefix (optional)",
+            "info": "A sequence used as the prefix of each samples."
+        },
+        "zh": {
+            "label": "前缀序列(非必填)",
+            "info": "作为每个输入样本前缀的序列"
+        }
+    },
     "dataset_dir": {
         "en": {
             "label": "Data dir",
@@ -99,66 +129,24 @@ LOCALES = {
             "value": "关闭"
         }
     },
-    "max_samples": {
+    "max_source_length": {
         "en": {
-            "label": "Max samples",
-            "info": "Maximum samples per dataset."
+            "label": "Max source length",
+            "info": "Max tokens in source sequence."
         },
         "zh": {
-            "label": "最大样本数",
-            "info": "每个数据集最多使用的样本数。"
+            "label": "输入序列最大长度",
+            "info": "输入序列分词后的最大长度。"
         }
     },
-    "batch_size": {
+    "max_target_length": {
         "en": {
-            "label": "Batch size",
-            "info": "Number of samples to process per GPU."
-        },
-        "zh":{
-            "label": "批处理大小",
-            "info": "每块 GPU 上处理的样本数量。"
-        }
-    },
-    "quantization_bit": {
-        "en": {
-            "label": "Quantization bit",
-            "info": "Enable 4/8-bit model quantization."
+            "label": "Max target length",
+            "info": "Max tokens in target sequence."
         },
         "zh": {
-            "label": "量化",
-            "info": "启用 4/8 比特模型量化。"
-        }
-    },
-    "start_btn": {
-        "en": {
-            "value": "Start"
-        },
-        "zh": {
-            "value": "开始"
-        }
-    },
-    "stop_btn": {
-        "en": {
-            "value": "Abort"
-        },
-        "zh": {
-            "value": "中断"
-        }
-    },
-    "output_box": {
-        "en": {
-            "value": "Ready."
-        },
-        "zh": {
-            "value": "准备就绪。"
-        }
-    },
-    "finetuning_type": {
-        "en": {
-            "label": "Finetuning method"
-        },
-        "zh": {
-            "label": "微调方法"
+            "label": "输出序列最大长度",
+            "info": "输出序列分词后的最大长度。"
         }
     },
     "learning_rate": {
@@ -181,6 +169,26 @@ LOCALES = {
             "info": "需要执行的训练总轮数。"
         }
     },
+    "max_samples": {
+        "en": {
+            "label": "Max samples",
+            "info": "Maximum samples per dataset."
+        },
+        "zh": {
+            "label": "最大样本数",
+            "info": "每个数据集最多使用的样本数。"
+        }
+    },
+    "batch_size": {
+        "en": {
+            "label": "Batch size",
+            "info": "Number of samples to process per GPU."
+        },
+        "zh":{
+            "label": "批处理大小",
+            "info": "每块 GPU 上处理的样本数量。"
+        }
+    },
     "gradient_accumulation_steps": {
         "en": {
             "label": "Gradient accumulation",
@@ -201,6 +209,16 @@ LOCALES = {
             "info": "采用的学习率调节器名称。"
         }
     },
+    "dev_ratio": {
+        "en": {
+            "label": "Dev ratio",
+            "info": "Proportion of data in the dev set."
+        },
+        "zh": {
+            "label": "验证集比例",
+            "info": "验证集占全部样本的百分比。"
+        }
+    },
     "fp16": {
         "en": {
             "label": "fp16",
@@ -231,6 +249,22 @@ LOCALES = {
             "info": "每两次断点保存间的更新步数。"
         }
     },
+    "start_btn": {
+        "en": {
+            "value": "Start"
+        },
+        "zh": {
+            "value": "开始"
+        }
+    },
+    "stop_btn": {
+        "en": {
+            "value": "Abort"
+        },
+        "zh": {
+            "value": "中断"
+        }
+    },
     "output_dir": {
         "en": {
             "label": "Checkpoint name",
@@ -241,6 +275,14 @@ LOCALES = {
             "info": "保存模型断点的文件夹名称。"
         }
     },
+    "output_box": {
+        "en": {
+            "value": "Ready."
+        },
+        "zh": {
+            "value": "准备就绪。"
+        }
+    },
     "loss_viewer": {
         "en": {
             "label": "Loss"
@@ -257,14 +299,6 @@ LOCALES = {
             "label": "保存预测结果"
         }
     },
-    "info_box": {
-        "en": {
-            "value": "Model unloaded, please load a model first."
-        },
-        "zh": {
-            "value": "模型未加载,请先加载模型。"
-        }
-    },
     "load_btn": {
         "en": {
             "value": "Load model"
@@ -281,6 +315,14 @@ LOCALES = {
             "value": "卸载模型"
         }
     },
+    "info_box": {
+        "en": {
+            "value": "Model unloaded, please load a model first."
+        },
+        "zh": {
+            "value": "模型未加载,请先加载模型。"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."
@@ -305,6 +347,14 @@ LOCALES = {
             "value": "清空历史"
         }
     },
+    "max_length": {
+        "en": {
+            "label": "Maximum length"
+        },
+        "zh": {
+            "label": "最大长度"
+        }
+    },
     "max_new_tokens": {
         "en": {
             "label": "Maximum new tokens"
@@ -3,7 +3,7 @@ import os
 import threading
 import time
 import transformers
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
@@ -59,10 +59,29 @@ class Runner:
         return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
 
     def run_train(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-        fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-        lr_scheduler_type, logging_steps, save_steps, output_dir
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_source_length: int,
+        max_target_length: int,
+        learning_rate: str,
+        num_train_epochs: str,
+        max_samples: str,
+        batch_size: int,
+        gradient_accumulation_steps: int,
+        lr_scheduler_type: str,
+        dev_ratio: float,
+        fp16: bool,
+        logging_steps: int,
+        save_steps: int,
+        output_dir: str
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
@@ -79,25 +98,35 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_train=True,
-            finetuning_type=finetuning_type,
-            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
+            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            learning_rate=float(learning_rate),
+            num_train_epochs=float(num_train_epochs),
+            max_samples=int(max_samples),
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             lr_scheduler_type=lr_scheduler_type,
+            fp16=fp16,
             logging_steps=logging_steps,
             save_steps=save_steps,
-            learning_rate=float(learning_rate),
-            num_train_epochs=float(num_train_epochs),
-            fp16=fp16,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
         )
+
+        if dev_ratio > 1e-6:
+            args["dev_ratio"] = dev_ratio
+            args["evaluation_strategy"] = "steps"
+            args["eval_steps"] = save_steps
+            args["load_best_model_at_end"] = True
+
         model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
 
         run_args = dict(
@@ -120,8 +149,21 @@ class Runner:
         yield self.finalize(lang)
 
     def run_eval(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_source_length: int,
+        max_target_length: int,
+        max_samples: str,
+        batch_size: int,
+        predict: bool
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
@@ -140,17 +182,20 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
            do_eval=True,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=output_dir,
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
             predict_with_generate=True,
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            max_samples=int(max_samples),
             per_device_eval_batch_size=batch_size,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=output_dir
        )
 
         if predict:
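The `dev_ratio > 1e-6` guard deserves a note: the value comes from a slider with step 0.001, and the epsilon comparison treats anything indistinguishable from zero as "no dev split" rather than relying on exact float equality. When the guard passes, the run flips to step-based evaluation synchronized with checkpointing; illustratively (values assumed):

```python
args = {}
dev_ratio, save_steps = 0.1, 100

if dev_ratio > 1e-6:  # safer than `dev_ratio != 0` for slider-produced floats
    args["dev_ratio"] = dev_ratio
    args["evaluation_strategy"] = "steps"
    args["eval_steps"] = save_steps          # evaluate whenever a checkpoint is saved
    args["load_best_model_at_end"] = True    # keep the best checkpoint by eval metric

print(args)
```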
@@ -3,7 +3,7 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from datetime import datetime
 
 from llmtuner.extras.ploting import smooth
@@ -23,7 +23,7 @@ def get_time() -> str:
     return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
 
 
-def can_preview(dataset_dir: str, dataset: list) -> dict:
+def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     if (
@@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
         return gr.update(interactive=False)
 
 
-def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
+def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
     with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
         dataset_info = json.load(f)
     data_file = dataset_info[dataset[0]]["file_name"]
@@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
     return len(data), data[:2], gr.update(visible=True)
 
 
+def can_quantize(finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type != "lora":
+        return gr.update(value="", interactive=False)
+    else:
+        return gr.update(interactive=True)
+
+
 def get_eval_results(path: os.PathLike) -> str:
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
@@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
         if log_info.get("loss", None):
             steps.append(log_info["current_steps"])
             losses.append(log_info["loss"])
+
+    if len(losses) == 0:
+        return None
+
     ax.plot(steps, losses, alpha=0.4, label="original")
     ax.plot(steps, smooth(losses), label="smoothed")
     ax.legend()
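Two behavioral notes on the additions above. `can_quantize` returns a `gr.update(...)` patch, so wiring it into the `.then()` chain in the top panel clears and locks the quantization dropdown for non-LoRA methods without rebuilding the component; and the empty-`losses` early return keeps `gen_plot` from calling `smooth` on an empty series before any loss has been logged. A sketch of the expected update payloads (shapes illustrative for Gradio 3.x):

```python
import gradio as gr

def can_quantize(finetuning_type: str):
    if finetuning_type != "lora":
        return gr.update(value="", interactive=False)  # clear and lock
    return gr.update(interactive=True)                 # re-enable for LoRA

print(can_quantize("lora"))  # -> {'interactive': True, '__type__': 'update'}
print(can_quantize("full"))  # -> {'value': '', 'interactive': False, '__type__': 'update'}
```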
src/web_demo.py (new file, 44 lines)

@@ -0,0 +1,44 @@
+# coding=utf-8
+# Implements user interface in browser for fine-tuned models.
+# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
+
+import gradio as gr
+from transformers.utils.versions import require_version
+
+from llmtuner import get_infer_args
+from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.components.chatbot import create_chat_box
+from llmtuner.webui.manager import Manager
+
+
+require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
+
+
+def main():
+    chat_model = WebChatModel(*get_infer_args())
+
+    with gr.Blocks(title="Web Demo") as demo:
+        lang = gr.Dropdown(choices=["en", "zh"], value="en")
+
+        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
+
+        manager = Manager([{"lang": lang}, chat_elems])
+
+        demo.load(
+            manager.gen_label,
+            [lang],
+            [lang] + [elem for elem in chat_elems.values()],
+        )
+
+        lang.change(
+            manager.gen_label,
+            [lang],
+            [lang] + [elem for elem in chat_elems.values()],
+        )
+
+    demo.queue()
+    demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
+
+
+if __name__ == "__main__":
+    main()