[misc] upgrade format to py39 (#7256)
@@ -37,7 +37,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING
 
 import torch
 import torch.nn.functional as F
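Context for the import change: from Python 3.9, PEP 585 makes the builtin containers generic, so `typing.Tuple` can be replaced by plain `tuple` in annotations (mirrored in the `get_unpad_data` signature later in this diff). An illustrative before/after sketch, with made-up function names:

```python
# Pre-3.9 style: generic aliases must come from typing.
from typing import Tuple

def legacy_pair() -> Tuple[int, str]:
    return 1, "a"

# Python 3.9+ (PEP 585): builtin containers are subscriptable directly,
# so the typing import can be dropped.
def modern_pair() -> tuple[int, str]:
    return 1, "a"
```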
@@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
 
 
 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
-    r"""
-    Gets the sequence lengths in the current batch.
+    r"""Get the sequence lengths in the current batch.
 
     e.g.
     ```python
@@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     bsz = attention_mask.size(0)
     dtype, device = attention_mask.dtype, attention_mask.device
     max_num = torch.max(attention_mask).item()
-    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
+    counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
     for i in range(max_num):
         counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
 
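A note for readers skimming this hunk: in sequence-packed batches the attention mask labels each token with the index of the packed sequence it belongs to (1, 2, ..., with 0 for padding), and the `counts` loop tallies tokens per label. A minimal runnable sketch follows; the example mask is an assumption chosen to match the cumulative lengths `[0, 2, 5, 6, 8, 11]` shown later in the diff, and the flatten-and-drop-zeros tail is inferred, since it falls outside the hunk.

```python
import torch

# Assumed packed attention mask: each value labels the packed sequence a
# token belongs to; 0 marks padding.
attention_mask = torch.tensor([
    [1, 1, 2, 2, 2, 0],
    [1, 2, 2, 3, 3, 3],
])

bsz = attention_mask.size(0)
max_num = torch.max(attention_mask).item()  # highest sequence label = 3
counts = torch.zeros((bsz, max_num), dtype=attention_mask.dtype)
for i in range(max_num):
    # how many tokens carry label (i + 1) in each row of the batch
    counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

print(counts)  # tensor([[2, 3, 0], [1, 2, 3]])
# Flattening and dropping the zero entries yields the per-sequence
# lengths [2, 3, 1, 2, 3], which is what get_seqlens_in_batch returns.
```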
@@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     return seqlens
 
 
-def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
-    r"""
-    Prepares the indices and seqlens for flash attn varlen function.
+def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
+    r"""Prepare the indices and seqlens for flash attn varlen function.
 
     Returns:
         indices: indices of non-masked tokens from the flattened sequence.
@@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
     [0, 2, 5, 6, 8, 11]
     3
     ```
+
     """
     seqlens_in_batch = get_seqlens_in_batch(attention_mask)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
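To connect the visible tail of `get_unpad_data` with the docstring's expected output: flash-attention varlen kernels consume cumulative sequence lengths with a leading zero, plus the maximum length. A sketch of that construction, assuming the conventional `torch.cumsum` + `F.pad` recipe; the remainder of the function body lies outside this diff.

```python
import torch
import torch.nn.functional as F

# Per-sequence lengths from the packed batch above (see get_seqlens_in_batch).
seqlens_in_batch = torch.tensor([2, 3, 1, 2, 3], dtype=torch.int32)

# Cumulative sequence lengths, left-padded with a single 0, in the shape
# flash_attn_varlen-style kernels expect.
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                      # tensor([ 0,  2,  5,  6,  8, 11], dtype=torch.int32)
print(seqlens_in_batch.max().item())   # 3, the longest packed sequence
```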