@@ -43,11 +43,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
|
||||
|
||||
if is_transformers_version_greater_than("4.43.0"):
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -116,5 +111,7 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.block_diag_attn:
|
||||
return
|
||||
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
|
||||
|
||||
Reference in New Issue
Block a user