fix inputs
Former-commit-id: 7d535bb8cdf7e81edda81152e63c8cfe6c9dcc9f
This commit is contained in:
@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
processor: Optional["ProcessorMixin"] = None
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
|
||||
for feature in features:
|
||||
images = feature.pop("images", None) or []
|
||||
videos = feature.pop("videos", None) or []
|
||||
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_videos.extend(videos)
|
||||
batch_imglens.append(len(images))
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_seqlens.append(len(feature["input_ids"]))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
|
||||
mm_inputs = self.template.mm_plugin.get_mm_inputs(
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
|
||||
)
|
||||
if "token_type_ids" in mm_inputs:
|
||||
token_type_ids = mm_inputs.pop("token_type_ids")
|
||||
|
||||
Reference in New Issue
Block a user