[model] add qwen3-vl/qwen3-omni (#9196)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
@@ -56,10 +56,17 @@ TEXT_MESSAGES = [
|
||||
{"role": "assistant", "content": "I am fine!"},
|
||||
]
|
||||
|
||||
# Two-turn conversation whose user turn carries a single `<video>` placeholder;
# used to check how a plugin expands video placeholders in message text.
VIDEO_MESSAGES = [
    {"role": "user", "content": "<video>What is in this video?"},
    {"role": "assistant", "content": "A cat."},
]
|
||||
|
||||
# One dummy audio clip: 1600 zero-valued samples.
AUDIOS = [np.zeros(shape=1600)]

# One dummy image: a 32x32 all-white RGB picture.
IMAGES = [Image.new(mode="RGB", size=(32, 32), color=(255, 255, 255))]

# One dummy video: the same 32x32 white frame object repeated four times.
VIDEOS = [[Image.new(mode="RGB", size=(32, 32), color=(255, 255, 255))] * 4]

# Empty placeholders for calls that pass no images / no videos.
NO_IMAGES = []

NO_VIDEOS = []
|
||||
@@ -145,6 +152,8 @@ def _check_plugin(
|
||||
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
|
||||
expected_mm_inputs,
|
||||
)
|
||||
elif plugin.__class__.__name__ == "Qwen3VLPlugin": # only check replacement
|
||||
assert plugin.process_messages(VIDEO_MESSAGES, NO_IMAGES, VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
|
||||
elif plugin.__class__.__name__ != "BasePlugin": # test mm_messages
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
|
||||
@@ -357,6 +366,27 @@ def test_qwen2_vl_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
def test_qwen3_vl_plugin():
    """Check how the ``qwen3_vl`` plugin expands ``<video>`` in ``VIDEO_MESSAGES``.

    Every ``<video>`` placeholder is expected to become two timestamped segments,
    each wrapping ``frame_seqlen`` video pad tokens in vision start/end markers.
    """
    frame_seqlen = 1
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-235B-A22B-Instruct")
    qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")

    video_pads = "<|video_pad|>" * frame_seqlen
    # Slightly different from the original processor because of the default `fps=2` in our repo.
    template = "<0.2 seconds><|vision_start|>{}<|vision_end|><1.2 seconds><|vision_start|>{}<|vision_end|>"
    expanded_video = template.format(video_pads, video_pads)

    expected_mm_messages = []
    for message in VIDEO_MESSAGES:
        expanded_message = {key: value.replace("<video>", expanded_video) for key, value in message.items()}
        expected_mm_messages.append(expanded_message)

    # Later keys win, matching the original "build dict, then assign" precedence.
    check_inputs = {
        "plugin": qwen3_vl_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_mm_messages,
    }
    _check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
|
||||
def test_video_llava_plugin():
|
||||
image_seqlen = 256
|
||||
|
||||
Reference in New Issue
Block a user