[webui] display swanlab exp link (#7089)

* webui add swanlab link

* change callback name

* update

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 27a4b93871c63b839c92940766bd7e0177972c9b
This commit is contained in:
Ze-Yi LIN
2025-02-27 19:40:54 +08:00
committed by GitHub
parent b9f84900ee
commit 11672f760d
6 changed files with 63 additions and 17 deletions

View File

@@ -17,6 +17,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -31,7 +33,7 @@ from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -51,7 +53,7 @@ if is_ray_available():
if TYPE_CHECKING:
from transformers import PreTrainedModel, TrainerCallback
from transformers import PreTrainedModel, TrainerCallback, TrainerState
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, RayArguments, TrainingArguments
@@ -592,7 +594,17 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
swanlab_callback = SwanLabCallback(
class SwanLabCallbackExtension(SwanLabCallback):
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
if not state.is_world_process_zero:
return
super().setup(args, state, model, **kwargs)
swanlab_public_config = self._experiment.get_run().public.json()
with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f:
f.write(json.dumps(swanlab_public_config, indent=2))
swanlab_callback = SwanLabCallbackExtension(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_run_name,