add demo mode for web UI
Former-commit-id: 5ad34f08b4e1505d7933b973497347f126b2e818
This commit is contained in:
@@ -14,8 +14,14 @@ if TYPE_CHECKING:
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
manager: "Manager",
|
||||
demo_mode: Optional[bool] = False,
|
||||
lazy_init: Optional[bool] = True
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
@@ -36,6 +42,8 @@ class WebChatModel(ChatModel):
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not get("top.model_path"):
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
elif self.demo_mode:
|
||||
error = ALERTS["err_demo"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
|
||||
@@ -70,7 +70,7 @@ def get_module(model_name: str) -> str:
|
||||
|
||||
|
||||
def get_template(model_name: str) -> str:
|
||||
if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
|
||||
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
|
||||
return DEFAULT_TEMPLATE[get_prefix(model_name)]
|
||||
return "default"
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
CSS = r"""
|
||||
.duplicate-button {
|
||||
margin: auto;
|
||||
color: white;
|
||||
background: black;
|
||||
border-radius: 100vh;
|
||||
}
|
||||
|
||||
.modal-box {
|
||||
position: fixed !important;
|
||||
top: 50%;
|
||||
|
||||
@@ -12,11 +12,11 @@ from llmtuner.webui.utils import get_time
|
||||
|
||||
class Engine:
|
||||
|
||||
def __init__(self, pure_chat: Optional[bool] = False) -> None:
|
||||
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
|
||||
self.pure_chat = pure_chat
|
||||
self.manager: "Manager" = Manager()
|
||||
self.runner: "Runner" = Runner(self.manager)
|
||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||
self.manager = Manager()
|
||||
self.runner = Runner(self.manager, demo_mode=demo_mode)
|
||||
self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
|
||||
|
||||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import gradio as gr
|
||||
from typing import Optional
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner.webui.components import (
|
||||
@@ -17,10 +18,20 @@ from llmtuner.webui.engine import Engine
|
||||
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
|
||||
|
||||
|
||||
def create_ui() -> gr.Blocks:
|
||||
engine = Engine(pure_chat=False)
|
||||
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
|
||||
engine = Engine(demo_mode=demo_mode, pure_chat=False)
|
||||
|
||||
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
||||
if demo_mode:
|
||||
gr.HTML(
|
||||
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
|
||||
)
|
||||
gr.HTML(
|
||||
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
|
||||
"LLaMA Factory</a> for details.</center></h3>"
|
||||
)
|
||||
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
|
||||
|
||||
engine.manager.all_elems["top"] = create_top()
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
|
||||
|
||||
@@ -33,8 +44,9 @@ def create_ui() -> gr.Blocks:
|
||||
with gr.Tab("Chat"):
|
||||
engine.manager.all_elems["infer"] = create_infer_tab(engine)
|
||||
|
||||
with gr.Tab("Export"):
|
||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||
if not demo_mode:
|
||||
with gr.Tab("Export"):
|
||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
|
||||
@@ -659,6 +659,10 @@ ALERTS = {
|
||||
"en": "Failed.",
|
||||
"zh": "训练出错。"
|
||||
},
|
||||
"err_demo": {
|
||||
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
|
||||
"zh": "展示模式不支持训练,请先复制到私人空间。"
|
||||
},
|
||||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"zh": "训练中断,正在等待线程结束……"
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import gradio as gr
|
||||
from threading import Thread
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
@@ -24,8 +24,9 @@ if TYPE_CHECKING:
|
||||
|
||||
class Runner:
|
||||
|
||||
def __init__(self, manager: "Manager") -> None:
|
||||
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.do_train = True
|
||||
@@ -46,9 +47,8 @@ class Runner:
|
||||
|
||||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
self.running = False
|
||||
|
||||
def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
|
||||
def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
|
||||
dataset = get("train.dataset") if do_train else get("eval.dataset")
|
||||
@@ -65,6 +65,9 @@ class Runner:
|
||||
if len(dataset) == 0:
|
||||
return ALERTS["err_no_dataset"][lang]
|
||||
|
||||
if self.demo_mode and (not from_preview):
|
||||
return ALERTS["err_demo"][lang]
|
||||
|
||||
self.aborted = False
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
@@ -196,7 +199,7 @@ class Runner:
|
||||
return args
|
||||
|
||||
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
error = self._initialize(data, do_train)
|
||||
error = self._initialize(data, do_train, from_preview=True)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
@@ -205,14 +208,13 @@ class Runner:
|
||||
yield gen_cmd(args), gr.update(visible=False)
|
||||
|
||||
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
error = self._initialize(data, do_train)
|
||||
error = self._initialize(data, do_train, from_preview=False)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.running = True
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
@@ -232,6 +234,7 @@ class Runner:
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
self.running = True
|
||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
|
||||
Reference in New Issue
Block a user