[webui] display swanlab exp link (#7089)
* webui add swanlab link * change callback name * update --------- Co-authored-by: hiyouga <hiyouga@buaa.edu.cn> Former-commit-id: 27a4b93871c63b839c92940766bd7e0177972c9b
This commit is contained in:
@@ -17,6 +17,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
@@ -31,7 +33,7 @@ from transformers.trainer_pt_utils import get_parameter_names
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
|
||||
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
||||
@@ -51,7 +53,7 @@ if is_ray_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, TrainerCallback
|
||||
from transformers import PreTrainedModel, TrainerCallback, TrainerState
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import DataArguments, RayArguments, TrainingArguments
|
||||
@@ -592,7 +594,17 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
|
||||
if finetuning_args.swanlab_api_key is not None:
|
||||
swanlab.login(api_key=finetuning_args.swanlab_api_key)
|
||||
|
||||
swanlab_callback = SwanLabCallback(
|
||||
class SwanLabCallbackExtension(SwanLabCallback):
|
||||
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
super().setup(args, state, model, **kwargs)
|
||||
swanlab_public_config = self._experiment.get_run().public.json()
|
||||
with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f:
|
||||
f.write(json.dumps(swanlab_public_config, indent=2))
|
||||
|
||||
swanlab_callback = SwanLabCallbackExtension(
|
||||
project=finetuning_args.swanlab_project,
|
||||
workspace=finetuning_args.swanlab_workspace,
|
||||
experiment_name=finetuning_args.swanlab_run_name,
|
||||
|
||||
Reference in New Issue
Block a user