fix inputs
Former-commit-id: 7d535bb8cdf7e81edda81152e63c8cfe6c9dcc9f
This commit is contained in:
@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
|
||||
from ...extras.packages import is_transformers_version_greater_than_4_43
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -115,7 +115,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
|
||||
|
||||
def _patch_for_block_diag_attn(model_type: str) -> None:
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
if is_transformers_version_greater_than_4_43():
|
||||
if is_transformers_version_greater_than("4.43.0"):
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||
|
||||
Reference in New Issue
Block a user