support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
@@ -18,7 +18,7 @@ import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
 
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class TokenizerModule(TypedDict):
@@ -90,10 +90,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
         if num_added_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
 
     patch_tokenizer(tokenizer)
     try:
@@ -180,7 +180,7 @@ def load_model(
         vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
            model.load_state_dict(vhead_params, strict=False)
-            logger.info(f"Loaded valuehead from checkpoint: {vhead_path}")
+            logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
 
     if not is_trainable:
         model.requires_grad_(False)
@@ -200,7 +200,7 @@ def load_model(
     else:
        param_stats = f"all params: {all_param:,}"
 
-    logger.info(param_stats)
+    logger.info_rank0(param_stats)
 
     if model_args.print_param_status:
         for name, param in model.named_parameters():
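Note: the info_rank0 and warning_rank0 helpers used above come from the new ..extras.logging module, whose implementation is not part of this diff. The following is only a minimal sketch of how such rank-0-only logging methods could work, assuming the RANK environment variable convention used by torchrun/accelerate; the class and function names below are illustrative, not the actual module code.

import logging
import os


class _Rank0Logger(logging.Logger):
    """Hypothetical sketch: a Logger whose *_rank0 methods emit only on the main process."""

    def _is_rank0(self) -> bool:
        # Distributed launchers such as torchrun set RANK; default to 0 so
        # single-process runs still log.
        return int(os.environ.get("RANK", "0")) == 0

    def info_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_rank0():
            self.info(msg, *args, **kwargs)

    def warning_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_rank0():
            self.warning(msg, *args, **kwargs)


def get_logger(name: str) -> "_Rank0Logger":
    # setLoggerClass only affects loggers created after this call, so it must
    # run before the first getLogger() for the given name.
    logging.setLoggerClass(_Rank0Logger)
    return logging.getLogger(name)  # type: ignore[return-value]


logger = get_logger(__name__)
logger.info_rank0("printed once per job, not once per GPU process")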