refactor webui
Former-commit-id: 813ecd8e51949c21ab6fbaa51cc2b1a84ee07952
This commit is contained in:
@@ -1,53 +1,42 @@
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
chat_model = WebChatModel(lazy_init=True)
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
|
||||
elem_dict.update(dict(
|
||||
info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
|
||||
))
|
||||
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
|
||||
load_btn.click(
|
||||
chat_model.load_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["system_prompt"],
|
||||
top_elems["flash_attn"],
|
||||
top_elems["shift_attn"],
|
||||
top_elems["rope_scaling"]
|
||||
],
|
||||
[info_box]
|
||||
engine.chatter.load_model, input_elems, [info_box]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
unload_btn.click(
|
||||
chat_model.unload_model, [top_elems["lang"]], [info_box]
|
||||
engine.chatter.unload_model, input_elems, [info_box]
|
||||
).then(
|
||||
lambda: ([], []), outputs=[chatbot, history]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
info_box=info_box,
|
||||
load_btn=load_btn,
|
||||
unload_btn=unload_btn,
|
||||
**chat_elems
|
||||
)
|
||||
return elem_dict
|
||||
|
||||
Reference in New Issue
Block a user