[data] fix qwen omni plugin (#9204)

Co-authored-by: kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
Yaowei Zheng
2025-09-28 01:02:29 +08:00
committed by GitHub
parent 0761a4448f
commit 6ffebe5ff7
15 changed files with 292 additions and 210 deletions

View File

@@ -194,7 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2.5 omni
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None: # FIXME: need to get video image lengths
@@ -205,13 +205,22 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
else: # for qwen2vl
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
if (
self.model is not None
and getattr(self.model.config, "model_type", None)
in ["glm4v", "Keye", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
in [
"glm4v",
"Keye",
"qwen2_vl",
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_vl",
"qwen3_vl_moe",
]
and ("position_ids" not in features or features["position_ids"].dim() != 3)
):
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")

View File

@@ -1397,8 +1397,8 @@ class Qwen2AudioPlugin(BasePlugin):
@dataclass
class Qwen2VLPlugin(BasePlugin):
start_token: str = "<|vision_start|>"
end_token: str = "<|vision_end|>"
vision_bos_token: str = "<|vision_start|>"
vision_eos_token: str = "<|vision_end|>"
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@@ -1515,14 +1515,18 @@ class Qwen2VLPlugin(BasePlugin):
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}", 1
VIDEO_PLACEHOLDER,
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1,
)
num_video_tokens += 1
@@ -1611,7 +1615,9 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
)
content = content.replace(
IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
@@ -1630,11 +1636,14 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
else 1
)
timestamp_sec = timestamps[frame_index]
frame_structure = f"<{timestamp_sec:.1f} seconds>{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
frame_structure = (
f"<{timestamp_sec:.1f} seconds>"
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
)
video_structure += frame_structure
if not self.expand_mm_tokens:
video_structure = f"{self.start_token}{self.video_token}{self.end_token}"
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
num_video_tokens += 1
@@ -1774,7 +1783,11 @@ class GLM4VPlugin(Qwen2VLPlugin):
return mm_inputs
@dataclass
class Qwen2OmniPlugin(Qwen2VLPlugin):
audio_bos_token: str = "<|audio_start|>"
audio_eos_token: str = "<|audio_end|>"
@override
def _get_mm_inputs(
self,
@@ -1861,7 +1874,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
@@ -1898,7 +1913,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
placeholder_string += self.vision_bos_token + self.audio_bos_token
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
@@ -1908,7 +1923,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
placeholder_string += self.audio_eos_token + self.vision_eos_token
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
num_audio_tokens += 1
@@ -1917,7 +1932,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
while AUDIO_PLACEHOLDER in content:
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
content = content.replace(
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
AUDIO_PLACEHOLDER,
f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
1,
)
num_audio_tokens += 1
@@ -1926,7 +1943,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
)
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
VIDEO_PLACEHOLDER,
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1,
)
num_video_tokens += 1

View File

@@ -922,8 +922,8 @@ register_template(
name="qwen2_vl",
image_token="<|imgpad|>",
video_token="<|vidpad|>",
start_token="<|img|>",
end_token="<|endofimg|>",
vision_bos_token="<|img|>",
vision_eos_token="<|endofimg|>",
),
)
@@ -1862,7 +1862,14 @@ register_template(
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
name="qwen2_omni",
image_token="<|IMAGE|>",
video_token="<|VIDEO|>",
audio_token="<|AUDIO|>",
vision_bos_token="<|vision_bos|>",
vision_eos_token="<|vision_eos|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
),
)
@@ -1880,7 +1887,7 @@ register_template(
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
),
template_class=ReasoningTemplate,
)
@@ -1899,7 +1906,7 @@ register_template(
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
),
)

View File

@@ -3060,13 +3060,14 @@ register_model_group(
multimodal=True,
)
register_model_group(
models={
"Qwen/Qwen3-Omni-30B-A3B-Captioner": {
"Qwen3-Omni-30B-A3B-Captioner": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
},
"Qwen/Qwen3-Omni-30B-A3B-Instruct": {
"Qwen3-Omni-30B-A3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
},
@@ -3075,9 +3076,10 @@ register_model_group(
multimodal=True,
)
register_model_group(
models={
"Qwen/Qwen3-Omni-30B-A3B-Thinking": {
"Qwen3-Omni-30B-A3B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
},
@@ -3086,6 +3088,7 @@ register_model_group(
multimodal=True,
)
register_model_group(
models={
"Qwen2-VL-2B": {
@@ -3190,24 +3193,24 @@ register_model_group(
register_model_group(
models={
"Qwen/Qwen3-VL-235B-A22B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
"Qwen3-VL-235B-A22B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
},
},
template="qwen3_vl",
template="qwen3_vl_nothink",
multimodal=True,
)
register_model_group(
models={
"Qwen/Qwen3-VL-235B-A22B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
"Qwen3-VL-235B-A22B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
},
},
template="qwen3_vl_nothink",
template="qwen3_vl",
multimodal=True,
)

View File

@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.49.0,<=4.56.1")
check_version("transformers>=4.49.0,<=4.56.2")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.10.1")
check_version("peft>=0.14.0,<=0.17.1")

View File

@@ -147,7 +147,7 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.10.0")
check_version("vllm>=0.4.3,<=0.10.2")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.5")
@@ -174,7 +174,8 @@ def _check_extra_dependencies(
if training_args is not None:
if training_args.deepspeed:
# pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
check_version("deepspeed", mandatory=True)
check_version("deepspeed>=0.10.0,<=0.16.9")
if training_args.predict_with_generate:
check_version("jieba", mandatory=True)

View File

@@ -162,7 +162,7 @@ def load_model(
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen omni
load_class = AutoModelForTextToWaveform
else:
load_class = AutoModelForCausalLM
@@ -171,8 +171,8 @@ def load_model(
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
model = model.thinker # use part of Omni model
if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
model = getattr(model, "thinker")
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)

View File

@@ -298,6 +298,7 @@ _register_composite_model(
lora_conflict_keys=["audio_projection_layer"],
)
_register_composite_model(
model_type="mistral3",
)
@@ -351,6 +352,33 @@ _register_composite_model(
)
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="video_llava",
)