[webui] display swanlab exp link (#7089)

* webui add swanlab link

* change callback name

* update

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 27a4b93871c63b839c92940766bd7e0177972c9b
This commit is contained in:
Ze-Yi LIN
2025-02-27 19:40:54 +08:00
committed by GitHub
parent b9f84900ee
commit 11672f760d
6 changed files with 63 additions and 17 deletions

View File

@@ -23,6 +23,7 @@ from ..extras.constants import (
PEFT_METHODS,
RUNNING_LOG,
STAGES_USE_PAIR_DATA,
SWANLAB_CONFIG,
TRAINER_LOG,
TRAINING_STAGES,
)
@@ -30,6 +31,7 @@ from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
from .locales import ALERTS
if is_gradio_available():
@@ -86,20 +88,20 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
return get_model_path(model_name), get_template(model_name)
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
r"""
Gets training infomation for monitor.
If do_train is True:
Inputs: train.output_path
Outputs: train.output_box, train.progress_bar, train.loss_viewer
Inputs: top.lang, train.output_path
Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
If do_train is False:
Inputs: eval.output_path
Outputs: eval.output_box, eval.progress_bar, None
Inputs: top.lang, eval.output_path
Outputs: eval.output_box, eval.progress_bar, None, None
"""
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
running_info = {}
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
@@ -125,9 +127,19 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
if os.path.isfile(swanlab_config_path):
with open(swanlab_config_path, encoding="utf-8") as f:
swanlab_public_config = json.load(f)
swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
if swanlab_link is not None:
running_info["swanlab_link"] = gr.Markdown(
ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True
)
return running_log, running_progress, running_info
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":