fix gemma2 attention

Author: hiyouga
Date: 2024-07-13 23:33:45 +08:00
Commit: 5ab997d484
Parent: 6e7048831b
Former-commit-id: aeafc68e169ae0ea5939cc81cb0cf89f0ca044b6

7 changed files with 53 additions and 26 deletions


@@ -42,6 +42,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import torch.nn.functional as F
 import transformers.models
+from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.logging import get_logger
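For context, the newly imported helper is the standard transformers version gate: it raises an `ImportError` carrying the given hint when the installed version falls outside the requested range, and does nothing otherwise. A minimal sketch of how it is typically used (the version range is the one that appears later in this diff):

```python
# Minimal illustration of the helper imported above.
from transformers.utils.versions import require_version

# Raises ImportError("... To fix: pip install ...") if the installed
# transformers version is outside the supported window; otherwise a no-op.
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
```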
@@ -61,14 +62,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     Gets the sequence lengths in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [2, 3, 1, 2, 3]
     ```
     """
@@ -94,14 +94,13 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
         max_seqlen_in_batch: the largest seqlen in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
     [0, 2, 5, 6, 8, 11]
     3
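Likewise, a sketch of `get_unpad_data` consistent with this docstring, reusing the `get_seqlens_in_batch` sketch above; the real function body may differ slightly:

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    seqlens_in_batch = get_seqlens_in_batch(attention_mask)  # tensor([2, 3, 1, 2, 3]) for the example input
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # non-padding positions
    max_seqlen_in_batch = seqlens_in_batch.max().item()  # 3 in the example
    # Cumulative sequence lengths prefixed with 0, the layout flash-attn's varlen kernels expect.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch  # matches the docstring output above
```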
@@ -114,7 +113,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     return indices, cu_seqlens, max_seqlen_in_batch
 
 
-def patch_for_block_diag_attn(model_type: str) -> None:
+def _patch_for_block_diag_attn(model_type: str) -> None:
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":
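The hunk cuts off after the falcon branch, so the gemma2 branch the commit title refers to is not visible here. Purely as a hedged sketch, the same pattern applied to gemma2 would look like the following; the helper name and the `modeling_gemma2` attribute path are inferred from the cohere/falcon branches, not shown in the diff:

```python
import transformers.models
from transformers.utils.versions import require_version


def patch_gemma2_for_packing() -> None:  # hypothetical helper, not part of this diff
    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
    # Swap the module-level unpad helper so FlashAttention sees each packed
    # sequence separately (block-diagonal mask) instead of one long sequence.
    # The exact attribute path for gemma2 is an assumption.
    transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
```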
@@ -143,7 +143,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
 
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        patch_for_block_diag_attn(model_type)
+        _patch_for_block_diag_attn(model_type)
         logger.info("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
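Finally, a small hedged check of the overall effect: after `configure_packing` runs for a supported `model_type`, the library-level helper should point at the custom `get_unpad_data`. Shown here for the cohere branch that is visible in the hunk above:

```python
import transformers.models

# Assumes configure_packing(...) has already been called for a model whose
# model_type is "cohere"; the equivalent check holds for the other branches.
assert transformers.models.cohere.modeling_cohere._get_unpad_data is get_unpad_data
```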