[webui] display swanlab exp link (#7089)
* webui add swanlab link * change callback name * update --------- Co-authored-by: hiyouga <hiyouga@buaa.edu.cn> Former-commit-id: 27a4b93871c63b839c92940766bd7e0177972c9b
This commit is contained in:
@@ -23,6 +23,7 @@ from ..extras.constants import (
|
||||
PEFT_METHODS,
|
||||
RUNNING_LOG,
|
||||
STAGES_USE_PAIR_DATA,
|
||||
SWANLAB_CONFIG,
|
||||
TRAINER_LOG,
|
||||
TRAINING_STAGES,
|
||||
)
|
||||
@@ -30,6 +31,7 @@ from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import gen_loss_plot
|
||||
from ..model import QuantizationMethod
|
||||
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
@@ -86,20 +88,20 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
|
||||
return get_model_path(model_name), get_template(model_name)
|
||||
|
||||
|
||||
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
|
||||
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
|
||||
r"""
|
||||
Gets training infomation for monitor.
|
||||
|
||||
If do_train is True:
|
||||
Inputs: train.output_path
|
||||
Outputs: train.output_box, train.progress_bar, train.loss_viewer
|
||||
Inputs: top.lang, train.output_path
|
||||
Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
|
||||
If do_train is False:
|
||||
Inputs: eval.output_path
|
||||
Outputs: eval.output_box, eval.progress_bar, None
|
||||
Inputs: top.lang, eval.output_path
|
||||
Outputs: eval.output_box, eval.progress_bar, None, None
|
||||
"""
|
||||
running_log = ""
|
||||
running_progress = gr.Slider(visible=False)
|
||||
running_loss = None
|
||||
running_info = {}
|
||||
|
||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||
if os.path.isfile(running_log_path):
|
||||
@@ -125,9 +127,19 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
|
||||
running_progress = gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
if do_train and is_matplotlib_available():
|
||||
running_loss = gr.Plot(gen_loss_plot(trainer_log))
|
||||
running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))
|
||||
|
||||
return running_log, running_progress, running_loss
|
||||
swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
|
||||
if os.path.isfile(swanlab_config_path):
|
||||
with open(swanlab_config_path, encoding="utf-8") as f:
|
||||
swanlab_public_config = json.load(f)
|
||||
swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
|
||||
if swanlab_link is not None:
|
||||
running_info["swanlab_link"] = gr.Markdown(
|
||||
ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True
|
||||
)
|
||||
|
||||
return running_log, running_progress, running_info
|
||||
|
||||
|
||||
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||
|
||||
Reference in New Issue
Block a user