support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent ceb701c2d4
commit 093eda2ad6
42 changed files with 316 additions and 252 deletions

View File

@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
from ...extras import logging
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
except Exception as err:
err_text = str(err)
logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
return None