fix mllama cross_mask

Former-commit-id: c33967308bebd99489d28bd5a879525cf304c1f9
This commit is contained in:
hiyouga
2024-11-26 15:54:44 +00:00
parent 5a52e41399
commit 62eeafaba6
2 changed files with 9 additions and 2 deletions

View File

@@ -241,7 +241,7 @@ class BasePlugin:
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
batch_ids: input ids of samples, shape (batch_size, seq_len)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(images, videos)
@@ -760,7 +760,7 @@ class MllamaPlugin(BasePlugin):
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
)
) # shape: (batch_size, length, max_num_images, max_num_tiles)
return mm_inputs