release v0.1.0
Former-commit-id: 63c8d3a17cb18f0d8a8e37bfa147daf5bdd28ea9
This commit is contained in:
4
src/llmtuner/webui/components/__init__.py
Normal file
4
src/llmtuner/webui/components/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from llmtuner.webui.components.eval import create_eval_tab
|
||||
from llmtuner.webui.components.infer import create_infer_tab
|
||||
from llmtuner.webui.components.top import create_top
|
||||
from llmtuner.webui.components.sft import create_sft_tab
|
||||
54
src/llmtuner/webui/components/chatbot.py
Normal file
54
src/llmtuner/webui/components/chatbot.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
chat_model: WebChatModel
|
||||
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
|
||||
with gr.Box(visible=False) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
with gr.Column(scale=12):
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
|
||||
with gr.Column(min_width=32, scale=1):
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
clear_btn = gr.Button()
|
||||
max_new_tokens = gr.Slider(
|
||||
10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
|
||||
)
|
||||
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
|
||||
temperature = gr.Slider(
|
||||
0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
|
||||
)
|
||||
|
||||
history = gr.State([])
|
||||
|
||||
submit_btn.click(
|
||||
chat_model.predict,
|
||||
[chatbot, query, history, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
show_progress=True
|
||||
).then(
|
||||
lambda: gr.update(value=""), outputs=[query]
|
||||
)
|
||||
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||
|
||||
return chat_box, chatbot, history, dict(
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
clear_btn=clear_btn,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature
|
||||
)
|
||||
19
src/llmtuner/webui/components/data.py
Normal file
19
src/llmtuner/webui/components/data.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import gradio as gr
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
|
||||
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
preview_count = gr.Number(interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON(interactive=False)
|
||||
|
||||
close_btn = gr.Button()
|
||||
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
|
||||
|
||||
return preview_box, preview_count, preview_samples, close_btn
|
||||
60
src/llmtuner/webui/components/eval.py
Normal file
60
src/llmtuner/webui/components/eval.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Dict
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import can_preview, get_preview
|
||||
|
||||
|
||||
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
|
||||
with gr.Row():
|
||||
max_samples = gr.Textbox(value="100000", interactive=True)
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
|
||||
quantization_bit = gr.Dropdown([8, 4])
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
output_box = gr.Markdown()
|
||||
|
||||
start_btn.click(
|
||||
runner.run_eval,
|
||||
[
|
||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"], top_elems["template"],
|
||||
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
return dict(
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
preview_btn=preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
quantization_bit=quantization_bit,
|
||||
predict=predict,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_box=output_box
|
||||
)
|
||||
47
src/llmtuner/webui/components/infer.py
Normal file
47
src/llmtuner/webui/components/infer.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Dict
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
|
||||
|
||||
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
quantization_bit = gr.Dropdown([8, 4])
|
||||
|
||||
info_box = gr.Markdown()
|
||||
|
||||
chat_model = WebChatModel()
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
|
||||
|
||||
load_btn.click(
|
||||
chat_model.load_model,
|
||||
[
|
||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"], top_elems["template"],
|
||||
quantization_bit
|
||||
],
|
||||
[info_box]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
)
|
||||
|
||||
unload_btn.click(
|
||||
chat_model.unload_model, [top_elems["lang"]], [info_box]
|
||||
).then(
|
||||
lambda: ([], []), outputs=[chatbot, history]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
quantization_bit=quantization_bit,
|
||||
info_box=info_box,
|
||||
load_btn=load_btn,
|
||||
unload_btn=unload_btn,
|
||||
**chat_elems
|
||||
)
|
||||
94
src/llmtuner/webui/components/sft.py
Normal file
94
src/llmtuner/webui/components/sft.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from typing import Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||
|
||||
|
||||
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
|
||||
with gr.Row():
|
||||
learning_rate = gr.Textbox(value="5e-5", interactive=True)
|
||||
num_train_epochs = gr.Textbox(value="3.0", interactive=True)
|
||||
max_samples = gr.Textbox(value="100000", interactive=True)
|
||||
quantization_bit = gr.Dropdown([8, 4])
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
|
||||
lr_scheduler_type = gr.Dropdown(
|
||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
|
||||
)
|
||||
fp16 = gr.Checkbox(value=True)
|
||||
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
output_dir = gr.Textbox(interactive=True)
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
start_btn.click(
|
||||
runner.run_train,
|
||||
[
|
||||
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"], top_elems["template"],
|
||||
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
|
||||
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
|
||||
lr_scheduler_type, logging_steps, save_steps, output_dir
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
output_box.change(
|
||||
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
|
||||
)
|
||||
|
||||
return dict(
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
preview_btn=preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
learning_rate=learning_rate,
|
||||
num_train_epochs=num_train_epochs,
|
||||
max_samples=max_samples,
|
||||
quantization_bit=quantization_bit,
|
||||
batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
fp16=fp16,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer
|
||||
)
|
||||
42
src/llmtuner/webui/components/top.py
Normal file
42
src/llmtuner/webui/components/top.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Dict
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.extras.template import templates
|
||||
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
|
||||
|
||||
|
||||
def create_top() -> Dict[str, Component]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
|
||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
|
||||
checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
).then(
|
||||
get_model_path, [model_name], [model_path]
|
||||
) # do not save config since the below line will save
|
||||
model_path.change(save_config, [model_name, model_path])
|
||||
finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||
|
||||
return dict(
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
checkpoints=checkpoints,
|
||||
refresh_btn=refresh_btn
|
||||
)
|
||||
Reference in New Issue
Block a user