[data] fix qwen omni plugin (#9204)

Co-authored-by: kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
Yaowei Zheng
2025-09-28 01:02:29 +08:00
committed by GitHub
parent 0761a4448f
commit 6ffebe5ff7
15 changed files with 292 additions and 210 deletions

View File

@@ -194,7 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2.5 omni
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None: # FIXME: need to get video image lengths
@@ -205,13 +205,22 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
else: # for qwen2vl
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
if (
self.model is not None
and getattr(self.model.config, "model_type", None)
in ["glm4v", "Keye", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
in [
"glm4v",
"Keye",
"qwen2_vl",
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_vl",
"qwen3_vl_moe",
]
and ("position_ids" not in features or features["position_ids"].dim() != 3)
):
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")

View File

@@ -1397,8 +1397,8 @@ class Qwen2AudioPlugin(BasePlugin):
@dataclass
class Qwen2VLPlugin(BasePlugin):
start_token: str = "<|vision_start|>"
end_token: str = "<|vision_end|>"
vision_bos_token: str = "<|vision_start|>"
vision_eos_token: str = "<|vision_end|>"
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@@ -1515,14 +1515,18 @@ class Qwen2VLPlugin(BasePlugin):
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}", 1
VIDEO_PLACEHOLDER,
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1,
)
num_video_tokens += 1
@@ -1611,7 +1615,9 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
)
content = content.replace(
IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
@@ -1630,11 +1636,14 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
else 1
)
timestamp_sec = timestamps[frame_index]
frame_structure = f"<{timestamp_sec:.1f} seconds>{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
frame_structure = (
f"<{timestamp_sec:.1f} seconds>"
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
)
video_structure += frame_structure
if not self.expand_mm_tokens:
video_structure = f"{self.start_token}{self.video_token}{self.end_token}"
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
num_video_tokens += 1
@@ -1774,7 +1783,11 @@ class GLM4VPlugin(Qwen2VLPlugin):
return mm_inputs
@dataclass
class Qwen2OmniPlugin(Qwen2VLPlugin):
audio_bos_token: str = "<|audio_start|>"
audio_eos_token: str = "<|audio_end|>"
@override
def _get_mm_inputs(
self,
@@ -1861,7 +1874,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
@@ -1898,7 +1913,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
placeholder_string += self.vision_bos_token + self.audio_bos_token
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
@@ -1908,7 +1923,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
placeholder_string += self.audio_eos_token + self.vision_eos_token
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
num_audio_tokens += 1
@@ -1917,7 +1932,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
while AUDIO_PLACEHOLDER in content:
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
content = content.replace(
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
AUDIO_PLACEHOLDER,
f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
1,
)
num_audio_tokens += 1
@@ -1926,7 +1943,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
)
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
VIDEO_PLACEHOLDER,
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1,
)
num_video_tokens += 1

View File

@@ -922,8 +922,8 @@ register_template(
name="qwen2_vl",
image_token="<|imgpad|>",
video_token="<|vidpad|>",
start_token="<|img|>",
end_token="<|endofimg|>",
vision_bos_token="<|img|>",
vision_eos_token="<|endofimg|>",
),
)
@@ -1862,7 +1862,14 @@ register_template(
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
name="qwen2_omni",
image_token="<|IMAGE|>",
video_token="<|VIDEO|>",
audio_token="<|AUDIO|>",
vision_bos_token="<|vision_bos|>",
vision_eos_token="<|vision_eos|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
),
)
@@ -1880,7 +1887,7 @@ register_template(
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
),
template_class=ReasoningTemplate,
)
@@ -1899,7 +1906,7 @@ register_template(
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
),
)