fix gemma2 attention
Former-commit-id: aeafc68e169ae0ea5939cc81cb0cf89f0ca044b6
@@ -42,6 +42,7 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version

from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger

@@ -61,14 +62,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
    Gets the sequence lengths in the current batch.

    e.g.
    ```
    ```python
    # input
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
    ```
    ->
    ```
    # output
    [2, 3, 1, 2, 3]
    ```
    """
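
For reference, below is a minimal sketch of how per-sequence lengths can be recovered from a packed attention mask of this form, where each position carries the 1-based index of the packed sequence it belongs to and 0 marks padding. It only illustrates the docstring example, it is not necessarily the repository's exact implementation, and the name `get_seqlens_in_batch_sketch` is made up here.

```python
import torch


def get_seqlens_in_batch_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
    # Count how many tokens carry each sequence label (1..max) per row,
    # then flatten row by row and drop empty slots so the result lists
    # the packed sequences in order.
    max_num_seqs = int(attention_mask.max())
    counts = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, max_num_seqs + 1)],
        dim=-1,
    ).flatten()
    return counts[counts > 0]


mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
print(get_seqlens_in_batch_sketch(mask).tolist())  # [2, 3, 1, 2, 3]
```
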
@@ -94,14 +94,13 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
    max_seqlen_in_batch: the largest seqlen in the current batch.

    e.g.
    ```
    ```python
    # input
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
    ```
    ->
    ```
    # output
    [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
    [0, 2, 5, 6, 8, 11]
    3
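
The three outputs shown above are the flat indices of the non-padding tokens, the cumulative sequence lengths with a leading zero (the `cu_seqlens` layout that flash-attention's varlen kernels expect), and the longest packed sequence. A self-contained sketch that reproduces the docstring example follows; it is an illustration, not the repository's code, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # Per-sequence lengths, computed as in get_seqlens_in_batch above.
    max_num_seqs = int(attention_mask.max())
    counts = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, max_num_seqs + 1)],
        dim=-1,
    ).flatten()
    seqlens_in_batch = counts[counts > 0]

    # Flat positions of every non-padding token in the (batch, seq_len) mask.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative lengths prefixed with 0.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
indices, cu_seqlens, max_len = get_unpad_data_sketch(mask)
print(indices.tolist())     # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
print(cu_seqlens.tolist())  # [0, 2, 5, 6, 8, 11]
print(max_len)              # 3
```
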
@@ -114,7 +113,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
    return indices, cu_seqlens, max_seqlen_in_batch


def patch_for_block_diag_attn(model_type: str) -> None:
def _patch_for_block_diag_attn(model_type: str) -> None:
    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
    if model_type == "cohere":
        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
    elif model_type == "falcon":
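
Only the cohere and falcon branches of the per-model monkey patch are visible in this hunk; the gemma2-specific change the commit title refers to is not shown. As a hedged sketch of the general idea: in recent transformers releases the flash-attention unpadding helper lives in the shared `transformers.modeling_flash_attention_utils` module, which gemma2 routes through, so a packing patch for such models would replace the helper there instead of in a per-model module. Whether this is the exact fix made by this commit is an assumption, and the function name below is hypothetical.

```python
# Hedged sketch only: assumes transformers >= 4.42, which ships
# transformers.modeling_flash_attention_utils with a module-level
# _get_unpad_data helper. Not necessarily the change made by this commit.
from typing import Callable

import transformers.modeling_flash_attention_utils


def patch_shared_unpad_helper_sketch(packing_unpad_fn: Callable) -> None:
    # Replace the shared unpadding helper with a packing-aware one,
    # e.g. the get_unpad_data function this file defines.
    transformers.modeling_flash_attention_utils._get_unpad_data = packing_unpad_fn
```
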
@@ -143,7 +143,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",

    model_type = getattr(config, "model_type", None)
    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
        patch_for_block_diag_attn(model_type)
        _patch_for_block_diag_attn(model_type)
        logger.info("Using block diagonal attention for sequence packing without cross-attention.")
    else:
        raise ValueError("Current model does not support block diagonal attention.")
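
For context on the log message: "block diagonal attention" means that once several sequences are packed into one row, a token may only attend to tokens carrying the same sequence label, so there is no cross-attention between packed sequences. A small illustration of the mask this implies, using one packed row from the docstring examples (not code from the repository; causality within each block is handled by the attention kernel and omitted here):

```python
import torch

# One packed row: sequence 1 (2 tokens), sequence 2 (3 tokens), 1 pad slot.
labels = torch.tensor([1, 1, 2, 2, 2, 0])
not_pad = labels != 0
# Token i may attend to token j only if both are real tokens with the same label.
allowed = (labels.unsqueeze(0) == labels.unsqueeze(1)) & not_pad.unsqueeze(0) & not_pad.unsqueeze(1)
print(allowed.long())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]])
```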