refactor webui
Former-commit-id: 813ecd8e51949c21ab6fbaa51cc2b1a84ee07952
This commit is contained in:
@@ -5,9 +5,12 @@ from llmtuner.webui.utils import save_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
save_dir = gr.Textbox()
|
||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||
@@ -18,20 +21,23 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component
|
||||
export_btn.click(
|
||||
save_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["template"],
|
||||
engine.manager.get_elem("top.lang"),
|
||||
engine.manager.get_elem("top.model_name"),
|
||||
engine.manager.get_elem("top.model_path"),
|
||||
engine.manager.get_elem("top.checkpoints"),
|
||||
engine.manager.get_elem("top.finetuning_type"),
|
||||
engine.manager.get_elem("top.template"),
|
||||
max_shard_size,
|
||||
save_dir
|
||||
],
|
||||
[info_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
elem_dict.update(dict(
|
||||
save_dir=save_dir,
|
||||
max_shard_size=max_shard_size,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box
|
||||
)
|
||||
))
|
||||
|
||||
return elem_dict
|
||||
|
||||
Reference in New Issue
Block a user