fix mrope

Former-commit-id: 55bee1d333549ca19858b3f5c1b7b86926e5fb09
This commit is contained in:
hiyouga
2024-12-12 15:08:17 +00:00
parent cfff136b2a
commit fb22651faf
11 changed files with 32 additions and 9 deletions

View File

@@ -121,6 +121,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
features["position_ids"], _ = self.model.get_rope_index(
input_ids=features["input_ids"],
image_grid_thw=mm_inputs.get("image_grid_thw", None),
video_grid_thw=mm_inputs.get("video_grid_thw", None),
attention_mask=features["attention_mask"],
)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
seq_len = features["input_ids"].size(1)