Former-commit-id: 71a6861667ae68c1fd6a69acf68e1359b858cf1b
This commit is contained in:
hiyouga
2024-08-05 23:48:19 +08:00
parent 2e477e7458
commit 13093963b1
13 changed files with 111 additions and 69 deletions

View File

@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
return
import transformers.models
if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon":