[data] fix qwen omni plugin (#9204)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
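Summary of changes:

- collator: accept qwen3_omni_moe_thinker alongside qwen2_5_omni_thinker for second_per_grids / use_audio_in_video handling, and extend the 3D mrope position-id check to qwen3_omni_moe_thinker, qwen3_vl, and qwen3_vl_moe.
- mm plugins: rename Qwen2VLPlugin.start_token / end_token to vision_bos_token / vision_eos_token, add audio_bos_token / audio_eos_token defaults to Qwen2OmniPlugin, and replace the hard-coded <|vision_bos|> / <|audio_bos|> literals with these attributes.
- templates: pass the full token set to get_mm_plugin for qwen2_omni; the qwen3 omni templates use the *_pad tokens.
- model registry: drop the "Qwen/" prefix from the Qwen3-Omni display names, and pair Qwen3-VL-235B-A22B-Instruct with qwen3_vl_nothink and the Thinking variant with qwen3_vl.
- dependencies: allow transformers <= 4.56.2 and vllm <= 0.10.2; keep deepspeed presence mandatory but make its version pin advisory.
- loader: unwrap the thinker for qwen3_omni_moe as well as qwen2_5_omni, and register composite-model entries for qwen3_vl, qwen3_vl_moe, and qwen3_omni_moe_thinker.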
@@ -194,7 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

-            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2.5 omni
+            if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
                 rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                 if feature_attention_mask is not None:  # FIXME: need to get video image lengths
@@ -205,13 +205,22 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                     features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
                         dim=-1
                     ).unsqueeze(-1)
-            else:  # for qwen2vl
+            else:  # for qwen vl
                 features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)

         if (
             self.model is not None
             and getattr(self.model.config, "model_type", None)
-            in ["glm4v", "Keye", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+            in [
+                "glm4v",
+                "Keye",
+                "qwen2_vl",
+                "qwen2_5_vl",
+                "qwen2_5_omni_thinker",
+                "qwen3_omni_moe_thinker",
+                "qwen3_vl",
+                "qwen3_vl_moe",
+            ]
             and ("position_ids" not in features or features["position_ids"].dim() != 3)
         ):
             raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
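Note (not part of the diff): the check above demands 3D position ids because mrope indexes every token along temporal, height, and width axes, so get_rope_index-style functions return a (3, batch, seq) tensor rather than the usual (batch, seq). A minimal shape sketch, with assumed sizes:

```python
import torch

# Assumed sizes, for illustration only: mrope position ids carry one
# plane per axis (temporal, height, width), hence three dimensions.
batch_size, seq_len = 2, 8
plain_ids = torch.arange(seq_len).expand(batch_size, -1)      # (batch, seq): would raise ValueError above
mrope_ids = torch.arange(seq_len).expand(3, batch_size, -1)   # (3, batch, seq): passes the dim() == 3 check

for position_ids in (plain_ids, mrope_ids):
    print(tuple(position_ids.shape), "-> dim() == 3:", position_ids.dim() == 3)
```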
@@ -1397,8 +1397,8 @@ class Qwen2AudioPlugin(BasePlugin):

 @dataclass
 class Qwen2VLPlugin(BasePlugin):
-    start_token: str = "<|vision_start|>"
-    end_token: str = "<|vision_end|>"
+    vision_bos_token: str = "<|vision_start|>"
+    vision_eos_token: str = "<|vision_end|>"

     @override
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@@ -1515,14 +1515,18 @@ class Qwen2VLPlugin(BasePlugin):
         while IMAGE_PLACEHOLDER in content:
             image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
             content = content.replace(
-                IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
+                IMAGE_PLACEHOLDER,
+                f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
+                1,
             )
             num_image_tokens += 1

         while VIDEO_PLACEHOLDER in content:
             video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
             content = content.replace(
-                VIDEO_PLACEHOLDER, f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}", 1
+                VIDEO_PLACEHOLDER,
+                f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
+                1,
             )
             num_video_tokens += 1
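Note (not part of the diff): the loop above expands one placeholder per iteration, sizing each image by grid_thw.prod() // merge_length. A toy re-enactment with made-up tokens and grid values:

```python
# Toy re-enactment of the expansion loop above; all values are made up.
IMAGE_PLACEHOLDER = "<image>"
vision_bos_token = "<|vision_start|>"
vision_eos_token = "<|vision_end|>"
image_token = "<|image_pad|>"
merge_length = 4  # assumed merge_size ** 2

image_grid_thw = [(1, 4, 4), (1, 2, 4)]  # (t, h, w) per image, made up
content = f"compare {IMAGE_PLACEHOLDER} with {IMAGE_PLACEHOLDER}"

num_image_tokens = 0
while IMAGE_PLACEHOLDER in content:
    t, h, w = image_grid_thw[num_image_tokens]
    image_seqlen = (t * h * w) // merge_length
    content = content.replace(
        IMAGE_PLACEHOLDER,
        f"{vision_bos_token}{image_token * image_seqlen}{vision_eos_token}",
        1,
    )
    num_image_tokens += 1

print(content)  # 4 pad tokens for the first image, 2 for the second
```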
@@ -1611,7 +1615,9 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
             )
             content = content.replace(
-                IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
+                IMAGE_PLACEHOLDER,
+                f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
+                1,
             )
             num_image_tokens += 1
@@ -1630,11 +1636,14 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                     else 1
                 )
                 timestamp_sec = timestamps[frame_index]
-                frame_structure = f"<{timestamp_sec:.1f} seconds>{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
+                frame_structure = (
+                    f"<{timestamp_sec:.1f} seconds>"
+                    f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
+                )
                 video_structure += frame_structure

             if not self.expand_mm_tokens:
-                video_structure = f"{self.start_token}{self.video_token}{self.end_token}"
+                video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"

             content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
             num_video_tokens += 1
@@ -1774,7 +1783,11 @@ class GLM4VPlugin(Qwen2VLPlugin):
         return mm_inputs


+@dataclass
 class Qwen2OmniPlugin(Qwen2VLPlugin):
+    audio_bos_token: str = "<|audio_start|>"
+    audio_eos_token: str = "<|audio_end|>"
+
     @override
     def _get_mm_inputs(
         self,
@@ -1861,7 +1874,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         while IMAGE_PLACEHOLDER in content:
             image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
             content = content.replace(
-                IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
+                IMAGE_PLACEHOLDER,
+                f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
+                1,
             )
             num_image_tokens += 1
@@ -1898,7 +1913,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
                 audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                 placeholder_string = ""
-                placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
+                placeholder_string += self.vision_bos_token + self.audio_bos_token
                 for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                     video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                     audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
@@ -1908,7 +1923,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     if audio_chunk_index is not None:
                         placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])

-                placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+                placeholder_string += self.audio_eos_token + self.vision_eos_token
                 content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                 content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                 num_audio_tokens += 1
@@ -1917,7 +1932,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             while AUDIO_PLACEHOLDER in content:
                 audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
                 content = content.replace(
-                    AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
+                    AUDIO_PLACEHOLDER,
+                    f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
+                    1,
                 )
                 num_audio_tokens += 1
@@ -1926,7 +1943,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 )
                 content = content.replace(
-                    VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
+                    VIDEO_PLACEHOLDER,
+                    f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
+                    1,
                 )
                 num_video_tokens += 1
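Note (not part of the diff): the four Qwen2OmniPlugin hunks above only swap hard-coded <|vision_bos|> / <|audio_bos|> literals for the new attributes; the audio-in-video interleaving itself is unchanged. A rough standalone sketch of that interleaving, with made-up chunk indices (the real ones come from processor.get_chunked_index):

```python
# Rough sketch of the audio-in-video interleaving; all values are made up.
video_token, audio_token = "v", "a"
vision_bos, vision_eos = "<|vision_bos|>", "<|vision_eos|>"
audio_bos, audio_eos = "<|audio_bos|>", "<|audio_eos|>"

# (start, end) pairs per time chunk; real values come from processor.get_chunked_index.
video_chunk_indices = [(0, 4), (4, 8)]
audio_chunk_indices = [(0, 3), (3, 6), (6, 7)]

placeholder_string = vision_bos + audio_bos
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
    if j < len(video_chunk_indices):
        start, end = video_chunk_indices[j]
        placeholder_string += video_token * (end - start)
    if j < len(audio_chunk_indices):
        start, end = audio_chunk_indices[j]
        placeholder_string += audio_token * (end - start)

placeholder_string += audio_eos + vision_eos
print(placeholder_string)  # video and audio tokens alternate chunk by chunk
```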
@@ -922,8 +922,8 @@ register_template(
         name="qwen2_vl",
         image_token="<|imgpad|>",
         video_token="<|vidpad|>",
-        start_token="<|img|>",
-        end_token="<|endofimg|>",
+        vision_bos_token="<|img|>",
+        vision_eos_token="<|endofimg|>",
     ),
 )
@@ -1862,7 +1862,14 @@ register_template(
     stop_words=["<|im_end|>"],
     replace_eos=True,
     mm_plugin=get_mm_plugin(
-        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
+        name="qwen2_omni",
+        image_token="<|IMAGE|>",
+        video_token="<|VIDEO|>",
+        audio_token="<|AUDIO|>",
+        vision_bos_token="<|vision_bos|>",
+        vision_eos_token="<|vision_eos|>",
+        audio_bos_token="<|audio_bos|>",
+        audio_eos_token="<|audio_eos|>",
     ),
 )
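Note (not part of the diff): get_mm_plugin forwards these keywords to the plugin dataclass, so any field left out keeps the default declared on Qwen2VLPlugin / Qwen2OmniPlugin. A simplified sketch of that override pattern using stand-in classes, not the repo's real ones:

```python
from dataclasses import dataclass


# Simplified stand-ins for the plugin dataclasses; not the repo's real classes.
@dataclass
class VLPlugin:
    image_token: str = "<|image_pad|>"
    vision_bos_token: str = "<|vision_start|>"
    vision_eos_token: str = "<|vision_end|>"


@dataclass
class OmniPlugin(VLPlugin):
    audio_token: str = "<|audio_pad|>"
    audio_bos_token: str = "<|audio_start|>"
    audio_eos_token: str = "<|audio_end|>"


# A trimmed version of the registration above: every keyword replaces a
# dataclass default; anything omitted (audio_token here) keeps its default.
plugin = OmniPlugin(
    image_token="<|IMAGE|>",
    vision_bos_token="<|vision_bos|>",
    vision_eos_token="<|vision_eos|>",
    audio_bos_token="<|audio_bos|>",
    audio_eos_token="<|audio_eos|>",
)
print(plugin.audio_token)  # still the default: <|audio_pad|>
```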
@@ -1880,7 +1887,7 @@ register_template(
     stop_words=["<|im_end|>"],
     replace_eos=True,
     mm_plugin=get_mm_plugin(
-        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
+        name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
     ),
     template_class=ReasoningTemplate,
 )
@@ -1899,7 +1906,7 @@ register_template(
     stop_words=["<|im_end|>"],
     replace_eos=True,
     mm_plugin=get_mm_plugin(
-        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
+        name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
     ),
 )
@@ -3060,13 +3060,14 @@ register_model_group(
     multimodal=True,
 )

+
 register_model_group(
     models={
-        "Qwen/Qwen3-Omni-30B-A3B-Captioner": {
+        "Qwen3-Omni-30B-A3B-Captioner": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
         },
-        "Qwen/Qwen3-Omni-30B-A3B-Instruct": {
+        "Qwen3-Omni-30B-A3B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
         },
@@ -3075,9 +3076,10 @@ register_model_group(
     multimodal=True,
 )

+
 register_model_group(
     models={
-        "Qwen/Qwen3-Omni-30B-A3B-Thinking": {
+        "Qwen3-Omni-30B-A3B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
         },
@@ -3086,6 +3088,7 @@ register_model_group(
     multimodal=True,
 )

+
 register_model_group(
     models={
         "Qwen2-VL-2B": {
@@ -3190,24 +3193,24 @@ register_model_group(

 register_model_group(
     models={
-        "Qwen/Qwen3-VL-235B-A22B-Thinking": {
-            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
-            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
+        "Qwen3-VL-235B-A22B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
         },
     },
-    template="qwen3_vl",
+    template="qwen3_vl_nothink",
     multimodal=True,
 )


 register_model_group(
     models={
-        "Qwen/Qwen3-VL-235B-A22B-Instruct": {
-            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
-            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
+        "Qwen3-VL-235B-A22B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
         },
     },
-    template="qwen3_vl_nothink",
+    template="qwen3_vl",
     multimodal=True,
 )
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.49.0,<=4.56.1")
+    check_version("transformers>=4.49.0,<=4.56.2")
     check_version("datasets>=2.16.0,<=4.0.0")
     check_version("accelerate>=1.3.0,<=1.10.1")
     check_version("peft>=0.14.0,<=0.17.1")
@@ -147,7 +147,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.10.0")
+        check_version("vllm>=0.4.3,<=0.10.2")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.5")
@@ -174,7 +174,8 @@ def _check_extra_dependencies(
     if training_args is not None:
         if training_args.deepspeed:
             # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
-            check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
+            check_version("deepspeed", mandatory=True)
+            check_version("deepspeed>=0.10.0,<=0.16.9")

         if training_args.predict_with_generate:
             check_version("jieba", mandatory=True)
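Note (not part of the diff): the deepspeed change splits one mandatory pinned check into a mandatory presence check plus an advisory version pin. A rough approximation of such a check_version-style gate built on the packaging library; this is an assumed reimplementation, not the repo's code:

```python
from importlib.metadata import PackageNotFoundError, version
from packaging.requirements import Requirement


def check_version_sketch(requirement: str, mandatory: bool = False) -> None:
    """Warn, or raise when mandatory, if the installed package misses the spec."""
    req = Requirement(requirement)
    try:
        installed = version(req.name)
    except PackageNotFoundError:
        installed = None

    if installed is None or not req.specifier.contains(installed, prereleases=True):
        message = f"{req.name}: found {installed}, expected {req.specifier or 'any version'}"
        if mandatory:
            raise ImportError(message)
        print("warning:", message)


try:
    check_version_sketch("deepspeed", mandatory=True)  # presence is required
except ImportError as err:
    print("fatal:", err)

check_version_sketch("deepspeed>=0.10.0,<=0.16.9")     # version pin is only advisory
```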
@@ -162,7 +162,7 @@ def load_model(
             load_class = AutoModelForVision2Seq
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
-        elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
+        elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen omni
             load_class = AutoModelForTextToWaveform
         else:
             load_class = AutoModelForCausalLM
@@ -171,8 +171,8 @@ def load_model(
             model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
         else:
             model = load_class.from_pretrained(**init_kwargs)
-            if getattr(model.config, "model_type", None) == "qwen2_5_omni":
-                model = model.thinker  # use part of Omni model
+            if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
+                model = getattr(model, "thinker")

         if model_args.mixture_of_depths == "convert":
             model = convert_pretrained_model_to_mod(model, config, model_args)
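Note (not part of the diff): the omni checkpoints bundle a thinker (multimodal understanding) with a talker (speech generation), and fine-tuning here only needs the former, so the loader swaps the wrapper for its thinker submodule. A schematic with dummy classes, not the transformers ones:

```python
# Schematic only: dummy stand-ins for the Omni wrapper, not transformers classes.
class DummyThinker:
    model_type = "qwen3_omni_moe_thinker"


class DummyOmniWrapper:
    model_type = "qwen3_omni_moe"

    def __init__(self):
        self.thinker = DummyThinker()


model = DummyOmniWrapper()
if model.model_type in ["qwen2_5_omni", "qwen3_omni_moe"]:
    model = getattr(model, "thinker")  # keep only the understanding half

print(type(model).__name__)  # DummyThinker
```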
@@ -298,6 +298,7 @@ _register_composite_model(
     lora_conflict_keys=["audio_projection_layer"],
 )

+
 _register_composite_model(
     model_type="mistral3",
 )
@@ -351,6 +352,33 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="qwen3_vl",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
+_register_composite_model(
+    model_type="qwen3_vl_moe",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
+_register_composite_model(
+    model_type="qwen3_omni_moe_thinker",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="video_llava",
 )
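Note (not part of the diff): a composite-model entry declares which parameter-name prefixes belong to the vision tower, the projector, and the language model, which downstream code can use when freezing modules or choosing LoRA targets. A simplified sketch of prefix-based partitioning under that assumption (matching rules assumed, not the repo's implementation):

```python
# Simplified sketch of how composite-model key lists can partition parameters.
vision_model_keys = ["visual.patch_embed", "visual.blocks"]
language_model_keys = ["language_model", "lm_head"]
projector_key = "visual.merger"

param_names = [
    "visual.patch_embed.proj.weight",
    "visual.blocks.0.attn.qkv.weight",
    "visual.merger.mlp.0.weight",
    "language_model.layers.0.mlp.gate_proj.weight",
    "lm_head.weight",
]

for name in param_names:
    if name.startswith(projector_key):
        part = "projector"
    elif any(name.startswith(key) for key in vision_model_keys):
        part = "vision tower"
    elif any(name.startswith(key) for key in language_model_keys):
        part = "language model"
    else:
        part = "other"
    print(f"{name:45s} -> {part}")
```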