Former-commit-id: 9b211861eba19ae9fc360bc96eeb8ad67ba40c49
commit 8567dab167
parent 0517d7bee5
Author: hiyouga
Date: 2024-07-04 03:47:05 +08:00

2 changed files with 13 additions and 2 deletions


@@ -28,6 +28,11 @@ def test_get_seqlens_in_batch():
     assert list(seqlens_in_batch.size()) == [5]
     assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3]))
+
+    attention_mask_with_indices = torch.tensor([[1, 1, 1]])
+    seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
+    assert list(seqlens_in_batch.size()) == [1]
+    assert torch.all(seqlens_in_batch == torch.tensor([3]))
 
 
 def test_get_unpad_data():
     attention_mask_with_indices = torch.tensor(
@@ -40,3 +45,9 @@ def test_get_unpad_data():
     assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]))
     assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32))
     assert max_seqlen_in_batch == 3
+
+    attention_mask_with_indices = torch.tensor([[1, 1, 1]])
+    indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
+    assert torch.all(indices == torch.tensor([0, 1, 2]))
+    assert torch.all(cu_seqlens == torch.tensor([0, 3], dtype=torch.int32))
+    assert max_seqlen_in_batch == 3
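
The hunks above only show the new assertions. For context, below is a minimal sketch of helpers that would satisfy every assertion in this file; it is an illustration, not the repository's actual implementation, and it assumes the packed-attention-mask convention the tests imply (each row carries 1-based sequence indices for the packed samples, with 0 marking padding).

import torch
import torch.nn.functional as F


def get_seqlens_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    # Count how many tokens belong to each packed sequence index (1, 2, ...)
    # in every row, then drop the empty slots so only real sequences remain.
    bsz = attention_mask.size(0)
    max_num = int(torch.max(attention_mask).item())
    counts = torch.zeros((bsz, max_num), dtype=torch.int32)
    for num in range(max_num):
        counts[:, num] = torch.sum(attention_mask == num + 1, dim=-1)

    counts = counts.flatten()
    return counts[counts.nonzero().squeeze(dim=-1)]


def get_unpad_data(attention_mask: torch.Tensor):
    seqlens_in_batch = get_seqlens_in_batch(attention_mask)
    # Flattened positions of all non-padding tokens.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # Zero-prefixed cumulative sequence lengths, e.g. [0, 2, 5, 6, 8, 11].
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

With the single-sequence mask [[1, 1, 1]] added by this commit, the sketch yields seqlens_in_batch == [3], indices == [0, 1, 2], cu_seqlens == [0, 3] and max_seqlen_in_batch == 3, matching the new assertions; the (indices, cu_seqlens, max_seqlen) triple is the form of input that FlashAttention's variable-length kernels consume.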