[model] add qwen3-vl/qwen3-omni (#9196)

Co-authored-by: kingsley <kingsleydodonow@gmail.com>
Author: xvxuopop
Date: 2025-09-27 01:21:47 +08:00
Committed by: GitHub
Parent: abc3b1e1c4
Commit: 0761a4448f
5 changed files with 268 additions and 2 deletions


@@ -56,10 +56,17 @@ TEXT_MESSAGES = [
{"role": "assistant", "content": "I am fine!"},
]
VIDEO_MESSAGES = [
{"role": "user", "content": "<video>What is in this viode?"},
{"role": "assistant", "content": "A cat."},
]
AUDIOS = [np.zeros(1600)]
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
VIDEOS = [[Image.new("RGB", (32, 32), (255, 255, 255))] * 4]
NO_IMAGES = []
NO_VIDEOS = []
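
For reference, the dummy fixtures in this hunk are deliberately tiny. The sketch below spells out what they represent; the 16 kHz sampling rate is an assumption about typical audio feature extractors, not something this diff states.

```python
import numpy as np
from PIL import Image

audio = np.zeros(1600)  # 1600 samples ~= 0.1 s of silence at an assumed 16 kHz rate
image = Image.new("RGB", (32, 32), (255, 255, 255))  # one white 32x32 image
video = [image.copy() for _ in range(4)]  # a "video" here is just a list of 4 such frames

print(audio.shape, image.size, len(video))  # (1600,) (32, 32) 4
```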
@@ -145,6 +152,8 @@ def _check_plugin(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
elif plugin.__class__.__name__ == "Qwen3VLPlugin": # only check replacement
assert plugin.process_messages(VIDEO_MESSAGES, NO_IMAGES, VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
elif plugin.__class__.__name__ != "BasePlugin": # test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
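
The hunk above only adds a new branch to `_check_plugin`: Qwen3VLPlugin is validated solely on its `<video>` placeholder replacement, while other plugins keep the existing image-message check. A simplified, illustrative sketch of that dispatch (using the module-level fixtures from this test file, not the helper's full implementation) might look like this:

```python
def check_messages(plugin, processor, expected_mm_messages):
    name = plugin.__class__.__name__
    if name == "Qwen3VLPlugin":
        # only the video placeholder expansion is asserted for Qwen3-VL
        assert plugin.process_messages(VIDEO_MESSAGES, NO_IMAGES, VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
    elif name != "BasePlugin":
        # every other multimodal plugin is checked on the image messages
        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
```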
@@ -357,6 +366,27 @@ def test_qwen2_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
def test_qwen3_vl_plugin():
frame_seqlen = 1
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-235B-A22B-Instruct")
qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
check_inputs = {"plugin": qwen3_vl_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{
key: value.replace(
"<video>", # little different with original processor for default `fps=2` in our repo
"<0.2 seconds><|vision_start|>{}<|vision_end|><1.2 seconds><|vision_start|>{}<|vision_end|>".format(
"<|video_pad|>" * frame_seqlen, "<|video_pad|>" * frame_seqlen
),
)
for key, value in message.items()
}
for message in VIDEO_MESSAGES
]
_check_plugin(**check_inputs)
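
To make the expected-output construction in `test_qwen3_vl_plugin` easier to follow, here is a standalone sketch that rebuilds the expected messages by hand. The 0.2 s / 1.2 s timestamps are copied from the test (they follow from the repo's default `fps=2` applied to the 4-frame dummy video), not recomputed from the processor:

```python
frame_seqlen = 1
video_messages = [
    {"role": "user", "content": "<video>What is in this video?"},
    {"role": "assistant", "content": "A cat."},
]

# Two timestamped vision segments, each holding `frame_seqlen` video pad tokens.
expected_video = (
    "<0.2 seconds><|vision_start|>{}<|vision_end|>"
    "<1.2 seconds><|vision_start|>{}<|vision_end|>"
).format("<|video_pad|>" * frame_seqlen, "<|video_pad|>" * frame_seqlen)

# Substitute the expansion for the <video> placeholder in every message field.
expected_mm_messages = [
    {key: value.replace("<video>", expected_video) for key, value in message.items()}
    for message in video_messages
]
print(expected_mm_messages[0]["content"])
```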
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
image_seqlen = 256