@@ -106,9 +106,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
|
||||
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
|
||||
fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
|
||||
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
|
||||
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
|
||||
if self.tokenizer.padding_side == "right":
|
||||
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
|
||||
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
|
||||
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
|
||||
else:
|
||||
features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
|
||||
features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
|
||||
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
|
||||
|
||||
batch_images = fake_images
|
||||
batch_input_ids[0] = features[0]["input_ids"]
|
||||
|
||||
@@ -123,7 +129,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||
|
||||
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
|
||||
features["position_ids"], _ = self.model.get_rope_index(
|
||||
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
|
||||
input_ids=features["input_ids"],
|
||||
image_grid_thw=mm_inputs.get("image_grid_thw", None),
|
||||
video_grid_thw=mm_inputs.get("video_grid_thw", None),
|
||||
|
||||
Reference in New Issue
Block a user