Former-commit-id: 5dbc9b355e85b203cb43ff72589374f0e04be391
This commit is contained in:
hiyouga
2023-10-15 18:28:45 +08:00
parent a003d1fa1e
commit a6f800b741
9 changed files with 40 additions and 57 deletions

View File

@@ -9,12 +9,12 @@ from llmtuner.webui.components import (
create_export_tab,
create_chat_box
)
from llmtuner.webui.common import load_config, save_config
from llmtuner.webui.common import save_config
from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
def create_ui() -> gr.Blocks:
@@ -23,9 +23,6 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
config = engine.manager.get_elem("top.config")
model_name = engine.manager.get_elem("top.model_name")
model_path = engine.manager.get_elem("top.model_path")
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
@@ -39,13 +36,9 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)
demo.load(engine.resume, [config], engine.manager.list_elems())
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang, model_name, model_path]
)
demo.load(engine.resume, outputs=engine.manager.list_elems())
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
lang.input(save_config, inputs=[lang], queue=False)
return demo
@@ -54,21 +47,15 @@ def create_web_demo() -> gr.Blocks:
engine = Engine(pure_chat=True)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
config = gr.State(value=load_config())
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.all_elems["top"] = dict(config=config, lang=lang)
engine.manager.all_elems["top"] = dict(lang=lang)
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
demo.load(engine.resume, [config], engine.manager.list_elems())
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang]
)
demo.load(engine.resume, outputs=engine.manager.list_elems())
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
lang.input(save_config, inputs=[lang], queue=False)
return demo