update packing
Former-commit-id: f3d9c31efa0e64317bdd5b4ed6f78653cf3b5ba4
This commit is contained in:
@@ -29,20 +29,22 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
|
||||
|
||||
e.g.
|
||||
```
|
||||
[1, 1, 2, 2, 2, 0]
|
||||
[[1, 1, 2, 2, 2, 0]]
|
||||
```
|
||||
->
|
||||
```
|
||||
[[
|
||||
[
|
||||
[o, x, x, x, x, x],
|
||||
[o, o, x, x, x, x],
|
||||
[x, x, o, x, x, x],
|
||||
[x, x, o, o, x, x],
|
||||
[x, x, o, o, o, x],
|
||||
[x, x, o, x, x, x],
|
||||
]
|
||||
]]
|
||||
[
|
||||
[
|
||||
[
|
||||
[o, x, x, x, x, x],
|
||||
[o, o, x, x, x, x],
|
||||
[x, x, o, x, x, x],
|
||||
[x, x, o, o, x, x],
|
||||
[x, x, o, o, o, x],
|
||||
[x, x, o, x, x, x],
|
||||
]
|
||||
]
|
||||
]
|
||||
```
|
||||
where `o` equals to `0.0`, `x` equals to `min_dtype`.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user