support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
@@ -43,8 +43,8 @@ import torch
 import torch.nn.functional as F
 from transformers.utils.versions import require_version
 
+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43
 
 
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
         _patch_for_block_diag_attn(model_type)
-        logger.info("Using block diagonal attention for sequence packing without cross-attention.")
+        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
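Note: this diff only shows the call sites; the logger itself lives in the project's extras logging module, which is not part of this hunk. Below is a minimal sketch of how an info_rank0 helper could be wired up, assuming the process rank is read from the RANK / LOCAL_RANK environment variables set by torchrun. The class and function names here are illustrative, not the project's actual implementation.

# Sketch only: a logger subclass whose *_rank0 helpers emit solely on
# distributed rank 0. Assumed design, not the repository's real code.
import logging
import os


class Rank0Logger(logging.Logger):
    """Logger with helpers that only log on distributed rank 0."""

    @staticmethod
    def _is_rank0() -> bool:
        # torchrun sets RANK / LOCAL_RANK; single-process runs default to rank 0.
        return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))) == 0

    def info_rank0(self, msg: str, *args, **kwargs) -> None:
        if self._is_rank0():
            self.info(msg, *args, **kwargs)

    def warning_rank0(self, msg: str, *args, **kwargs) -> None:
        if self._is_rank0():
            self.warning(msg, *args, **kwargs)


def get_logger(name: str) -> "Rank0Logger":
    # Register the subclass so subsequent getLogger calls return rank0-aware loggers.
    logging.setLoggerClass(Rank0Logger)
    return logging.getLogger(name)  # type: ignore[return-value]

With a wrapper like this, single-process runs log normally, while under a torch.distributed launch only the rank-0 process prints, which is the point of replacing logger.info with logger.info_rank0 in the hunk above.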