@@ -44,13 +44,14 @@ import torch.nn.functional as F
 from transformers.utils.versions import require_version
 
 from ...extras import logging
-from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.packages import is_transformers_version_greater_than
 
 
-if TYPE_CHECKING:
-    from transformers import PretrainedConfig
+if is_transformers_version_greater_than("4.43.0"):
+    import transformers.modeling_flash_attention_utils
+
 
+if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
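The hunk above hoists the FlashAttention-utils import to module level, gated on transformers>=4.43.0, where unpadding for every supported model routes through the single _get_unpad_data hook in transformers.modeling_flash_attention_utils. A minimal sketch of the monkey-patch pattern the new code relies on; the _original handle and patched wrapper are illustrative, not part of this patch:

    import transformers.modeling_flash_attention_utils as fa_utils

    _original = fa_utils._get_unpad_data  # keep a reference in case it must be restored

    def patched(attention_mask):
        # return (indices, cu_seqlens, max_seqlen_in_batch) computed from a packed
        # attention mask; this is what get_unpad_data in this file provides
        return get_unpad_data(attention_mask)

    fa_utils._get_unpad_data = patched  # every FlashAttention-2 forward now uses it

Patching this one hook is what lets the second hunk delete the old per-model patch table.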
@@ -113,45 +114,10 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
     return indices, cu_seqlens, max_seqlen_in_batch
 
 
-def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than("4.43.0"):
-        import transformers.modeling_flash_attention_utils
-
-        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
-        return
-
-    import transformers.models
-
-    if model_type == "cohere":
-        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
-    elif model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
-    elif model_type == "mistral":
-        transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
-    elif model_type == "phi3":
-        transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
-
-
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
 
-    model_type = getattr(config, "model_type", None)
-    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        _patch_for_block_diag_attn(model_type)
-        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
-    else:
-        raise ValueError("Current model does not support block diagonal attention.")
+    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
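For context on what the replacement hook computes: with block-diagonal packing, the attention mask labels each packed subsequence with its 1-based index instead of all ones, and get_unpad_data converts that into FlashAttention varlen metadata (flattened token indices, cumulative sequence lengths, longest subsequence), so attention stays within each packed sample. A self-contained sketch under that mask convention; get_unpad_data_sketch is an illustrative re-implementation, not the exact code in this file:

    import torch
    import torch.nn.functional as F

    def get_unpad_data_sketch(attention_mask: torch.Tensor):
        # [[1, 1, 2, 2, 2, 0]] packs a length-2 and a length-3 sequence plus padding
        seqlens = []
        for row in attention_mask:
            for k in range(1, int(row.max().item()) + 1):
                seqlens.append(int((row == k).sum().item()))
        seqlens = torch.tensor(seqlens, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
        max_seqlen_in_batch = int(seqlens.max().item())
        return indices, cu_seqlens, max_seqlen_in_batch

    mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
    print(get_unpad_data_sketch(mask))
    # indices=[0, 1, 2, 3, 4], cu_seqlens=[0, 2, 5], max_seqlen=3; the varlen
    # kernel then attends only within each [start, end) block, which is exactly
    # block-diagonal attention without cross-contamination between samples.

configure_packing remains the single entry point: it is a no-op unless training with model_args.block_diag_attn, and after this patch it installs the hook globally rather than per model type.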