support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
|
||||
import torch
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||
except Exception as err:
|
||||
err_text = str(err)
|
||||
|
||||
logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
|
||||
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
|
||||
logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
|
||||
logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user