fix gemma2 attention

Former-commit-id: aeafc68e169ae0ea5939cc81cb0cf89f0ca044b6
This commit is contained in:
hiyouga
2024-07-13 23:33:45 +08:00
parent 6e7048831b
commit 5ab997d484
7 changed files with 53 additions and 26 deletions

@@ -28,11 +28,10 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
     e.g.
-    ```
+    ```python
     # input
     [[1, 1, 2, 2, 2, 0]]
-    ```
-    ->
-    ```
     # output
     [
       [
         [
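The docstring describes expanding a packing-index mask of shape (batch_size, seq_len) into a 4D mask where each position may attend only to earlier positions of the same packed sequence, and padding (index 0) is fully masked. A minimal pure-Python sketch of that logic (not the repository's tensor implementation; `build_packed_causal_mask` and `NEG_INF` are illustrative names, with `NEG_INF` standing in for the `min_dtype` value used in the real code):

```python
NEG_INF = float("-inf")  # placeholder for min_dtype in the tensor version

def build_packed_causal_mask(indices):
    """Sketch: expand a (seq_len,) list of packing indices into a
    (seq_len, seq_len) mask, causal within each packed sequence.
    0.0 = attend, NEG_INF = blocked; index 0 marks padding."""
    n = len(indices)
    mask = [[NEG_INF] * n for _ in range(n)]
    for q in range(n):
        for k in range(q + 1):  # lower triangular: no future peeking
            # attend only within the same (non-padding) packed sequence
            if indices[q] != 0 and indices[q] == indices[k]:
                mask[q][k] = 0.0
    return mask
```

For the input `[1, 1, 2, 2, 2, 0]` this reproduces the pattern shown in the docstring: two causal blocks for sequences 1 and 2, and a fully masked final row for the padding position.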