[data] fix qwen2vl pos ids (#8387)
This commit is contained in:
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from peft import PeftModel
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
@@ -94,6 +95,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if self.template is None:
|
||||
raise ValueError("Template is required for MultiModalDataCollator.")
|
||||
|
||||
if isinstance(self.model, PeftModel):
|
||||
self.model = self.model.base_model.model
|
||||
|
||||
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
|
||||
self.get_rope_func = self.model.get_rope_index # transformers < 4.52.0 or qwen2.5 omni
|
||||
elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
|
||||
self.get_rope_func = self.model.model.get_rope_index # transformers >= 4.52.0
|
||||
else:
|
||||
self.get_rope_func = None
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
@@ -171,7 +182,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
|
||||
if self.get_rope_func is not None:
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
@@ -180,27 +191,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
}
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
if "video_second_per_grid" in mm_inputs: # for qwen2omni
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
|
||||
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
|
||||
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2.5 omni
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(
|
||||
feature_attention_mask, dim=1
|
||||
) # FIXME need to get video image lengths
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
|
||||
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
|
||||
# avoid conflict
|
||||
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
|
||||
features["position_ids"], features["rope_deltas"] = (
|
||||
new_position_ids.clone(),
|
||||
rope_deltas - delta0,
|
||||
) # avoid inplace operation FIXME
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||
dim=-1
|
||||
).unsqueeze(-1)
|
||||
else: # for qwen2vl
|
||||
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
|
||||
if (
|
||||
self.model is not None
|
||||
and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
|
||||
and ("position_ids" not in features or features["position_ids"].dim() != 3)
|
||||
):
|
||||
raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
|
||||
|
||||
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
|
||||
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
|
||||
|
||||
Reference in New Issue
Block a user