[webui] display swanlab exp link (#7089)
* webui add swanlab link * change callback name * update --------- Co-authored-by: hiyouga <hiyouga@buaa.edu.cn> Former-commit-id: 27a4b93871c63b839c92940766bd7e0177972c9b
This commit is contained in:
@@ -299,9 +299,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
swanlab_workspace = gr.Textbox()
|
||||
swanlab_api_key = gr.Textbox()
|
||||
swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
|
||||
swanlab_link = gr.Markdown(visible=False, container=True)
|
||||
|
||||
input_elems.update(
|
||||
{use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
|
||||
{
|
||||
use_swanlab,
|
||||
swanlab_project,
|
||||
swanlab_run_name,
|
||||
swanlab_workspace,
|
||||
swanlab_api_key,
|
||||
swanlab_mode,
|
||||
swanlab_link,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
dict(
|
||||
@@ -312,6 +321,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
swanlab_workspace=swanlab_workspace,
|
||||
swanlab_api_key=swanlab_api_key,
|
||||
swanlab_mode=swanlab_mode,
|
||||
swanlab_link=swanlab_link,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -364,7 +374,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
loss_viewer=loss_viewer,
|
||||
)
|
||||
)
|
||||
output_elems = [output_box, progress_bar, loss_viewer]
|
||||
output_elems = [output_box, progress_bar, loss_viewer, swanlab_link]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
|
||||
@@ -23,6 +23,7 @@ from ..extras.constants import (
|
||||
PEFT_METHODS,
|
||||
RUNNING_LOG,
|
||||
STAGES_USE_PAIR_DATA,
|
||||
SWANLAB_CONFIG,
|
||||
TRAINER_LOG,
|
||||
TRAINING_STAGES,
|
||||
)
|
||||
@@ -30,6 +31,7 @@ from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import gen_loss_plot
|
||||
from ..model import QuantizationMethod
|
||||
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
@@ -86,20 +88,20 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
|
||||
return get_model_path(model_name), get_template(model_name)
|
||||
|
||||
|
||||
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
|
||||
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
|
||||
r"""
|
||||
Gets training infomation for monitor.
|
||||
|
||||
If do_train is True:
|
||||
Inputs: train.output_path
|
||||
Outputs: train.output_box, train.progress_bar, train.loss_viewer
|
||||
Inputs: top.lang, train.output_path
|
||||
Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
|
||||
If do_train is False:
|
||||
Inputs: eval.output_path
|
||||
Outputs: eval.output_box, eval.progress_bar, None
|
||||
Inputs: top.lang, eval.output_path
|
||||
Outputs: eval.output_box, eval.progress_bar, None, None
|
||||
"""
|
||||
running_log = ""
|
||||
running_progress = gr.Slider(visible=False)
|
||||
running_loss = None
|
||||
running_info = {}
|
||||
|
||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||
if os.path.isfile(running_log_path):
|
||||
@@ -125,9 +127,19 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
|
||||
running_progress = gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
if do_train and is_matplotlib_available():
|
||||
running_loss = gr.Plot(gen_loss_plot(trainer_log))
|
||||
running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))
|
||||
|
||||
return running_log, running_progress, running_loss
|
||||
swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
|
||||
if os.path.isfile(swanlab_config_path):
|
||||
with open(swanlab_config_path, encoding="utf-8") as f:
|
||||
swanlab_public_config = json.load(f)
|
||||
swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
|
||||
if swanlab_link is not None:
|
||||
running_info["swanlab_link"] = gr.Markdown(
|
||||
ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True
|
||||
)
|
||||
|
||||
return running_log, running_progress, running_info
|
||||
|
||||
|
||||
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||
|
||||
@@ -2814,4 +2814,11 @@ ALERTS = {
|
||||
"ko": "모델이 내보내졌습니다.",
|
||||
"ja": "モデルのエクスポートが完了しました。",
|
||||
},
|
||||
"info_swanlab_link": {
|
||||
"en": "### SwanLab Link\n",
|
||||
"ru": "### SwanLab ссылка\n",
|
||||
"zh": "### SwanLab 链接\n",
|
||||
"ko": "### SwanLab 링크\n",
|
||||
"ja": "### SwanLab リンク\n",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -423,6 +423,7 @@ class Runner:
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
|
||||
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
|
||||
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
|
||||
swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None
|
||||
|
||||
running_log = ""
|
||||
while self.trainer is not None:
|
||||
@@ -432,16 +433,18 @@ class Runner:
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
else:
|
||||
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
|
||||
running_log, running_progress, running_info = get_trainer_info(lang, output_path, self.do_train)
|
||||
return_dict = {
|
||||
output_box: running_log,
|
||||
progress_bar: running_progress,
|
||||
}
|
||||
if running_loss is not None:
|
||||
return_dict[loss_viewer] = running_loss
|
||||
if "loss_viewer" in running_info:
|
||||
return_dict[loss_viewer] = running_info["loss_viewer"]
|
||||
|
||||
if "swanlab_link" in running_info:
|
||||
return_dict[swanlab_link] = running_info["swanlab_link"]
|
||||
|
||||
yield return_dict
|
||||
|
||||
try:
|
||||
self.trainer.wait(2)
|
||||
self.trainer = None
|
||||
|
||||
Reference in New Issue
Block a user