support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent ceb701c2d4
commit 093eda2ad6
42 changed files with 316 additions and 252 deletions

View File

@@ -18,11 +18,11 @@
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch
import transformers
import transformers.models
from transformers.activations import ACT2FN
from transformers.utils import logging
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
@@ -31,8 +31,8 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
transformers_logger = logging.get_logger(__name__)
logger = logging.get_logger(__name__)
transformers_logger = transformers.utils.logging.get_logger(__name__)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
else:
return
logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
@@ -119,7 +119,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL