fix inputs

Former-commit-id: 7d535bb8cdf7e81edda81152e63c8cfe6c9dcc9f
This commit is contained in:
hiyouga
2024-11-23 18:25:45 +00:00
parent cd2485f28d
commit 5003820a6a
14 changed files with 148 additions and 95 deletions

View File

@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_videos.extend(videos)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_seqlens.append(len(feature["input_ids"]))
batch_input_ids.append(feature["input_ids"])
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")