video datasets
Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
This commit is contained in:
@@ -79,14 +79,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
processor: Optional["ProcessorMixin"] = None
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_imglens, batch_seqlens = [], [], []
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
|
||||
for feature in features:
|
||||
images = feature.pop("images") or [] # avoid NoneType
|
||||
videos = feature.pop("videos") or []
|
||||
batch_images.extend(images)
|
||||
batch_videos.extend(videos)
|
||||
batch_imglens.append(len(images))
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_seqlens.append(len(feature["input_ids"]))
|
||||
|
||||
mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor)
|
||||
mm_inputs = self.template.mm_plugin.get_mm_inputs(
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
|
||||
)
|
||||
if "token_type_ids" in mm_inputs:
|
||||
token_type_ids = mm_inputs.pop("token_type_ids")
|
||||
for i, feature in enumerate(features):
|
||||
@@ -136,6 +141,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
concatenated_features.append(target_feature)
|
||||
|
||||
@@ -158,12 +164,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
kl_feature = {
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
target_features.append(target_feature)
|
||||
kl_features.append(kl_feature)
|
||||
|
||||
Reference in New Issue
Block a user