support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -18,11 +18,11 @@
|
||||
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.models
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.utils import logging
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -31,8 +31,8 @@ if TYPE_CHECKING:
|
||||
from ...hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
transformers_logger = logging.get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
transformers_logger = transformers.utils.logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
|
||||
@@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
|
||||
else:
|
||||
return
|
||||
|
||||
logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
|
||||
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
|
||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
|
||||
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
|
||||
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user