tiny fix
Former-commit-id: 9b211861eba19ae9fc360bc96eeb8ad67ba40c49
This commit is contained in:
@@ -74,13 +74,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
|
||||
"""
|
||||
bsz = attention_mask.size(0)
|
||||
dtype, device = attention_mask.dtype, attention_mask.device
|
||||
max_num = torch.max(attention_mask)
|
||||
max_num = torch.max(attention_mask).item()
|
||||
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
|
||||
for i in range(max_num):
|
||||
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
|
||||
|
||||
counts = counts.flatten()
|
||||
seqlens = counts[counts.nonzero().squeeze()]
|
||||
seqlens = counts[counts.nonzero().squeeze(dim=-1)]
|
||||
return seqlens
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user