refactor webui

Former-commit-id: 813ecd8e51949c21ab6fbaa51cc2b1a84ee07952
This commit is contained in:
hiyouga
2023-10-15 03:06:21 +08:00
parent 4b1473502f
commit 6a61b4b638
14 changed files with 440 additions and 501 deletions

View File

@@ -9,65 +9,54 @@ from llmtuner.webui.components import (
create_export_tab,
create_chat_box
)
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.common import load_config, save_config
from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
from llmtuner.webui.engine import Engine
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def create_ui() -> gr.Blocks:
runner = Runner()
engine = Engine(init_chat=False)
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top()
engine.manager.all_elems["top"] = create_top(engine)
with gr.Tab("Train"):
train_elems = create_train_tab(top_elems, runner)
engine.manager.all_elems["train"] = create_train_tab(engine)
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)
engine.manager.all_elems["eval"] = create_eval_tab(engine)
with gr.Tab("Chat"):
infer_elems = create_infer_tab(top_elems)
engine.manager.all_elems["infer"] = create_infer_tab(engine)
with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)
engine.manager.all_elems["export"] = create_export_tab(engine)
elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list)
demo.load(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
)
top_elems["lang"].change(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
queue=False
)
demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems())
return demo
def create_web_demo() -> gr.Blocks:
chat_model = WebChatModel(lazy_init=False)
engine = Engine(init_chat=True)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
lang = gr.Dropdown(choices=["en", "zh"])
config = gr.State(value=load_config())
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang]
)
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
engine.manager.all_elems["top"] = dict(lang=lang)
manager = Manager([{"lang": lang}, chat_elems])
_, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True)
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
demo.load(engine.resume, [config], engine.manager.list_elems())
return demo