[data] fix qwen2.5 omni template (#7883)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -43,11 +44,20 @@ MM_MESSAGES = [
|
||||
{"role": "assistant", "content": "A cat."},
|
||||
]
|
||||
|
||||
OMNI_MESSAGES = [
|
||||
{"role": "user", "content": "<image>What is in this image?"},
|
||||
{"role": "assistant", "content": "A cat."},
|
||||
{"role": "user", "content": "<audio>What is in this audio?"},
|
||||
{"role": "assistant", "content": "Nothing."},
|
||||
]
|
||||
|
||||
TEXT_MESSAGES = [
|
||||
{"role": "user", "content": "How are you"},
|
||||
{"role": "assistant", "content": "I am fine!"},
|
||||
]
|
||||
|
||||
AUDIOS = [np.zeros(1600)]
|
||||
|
||||
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
|
||||
|
||||
NO_IMAGES = []
|
||||
@@ -58,6 +68,8 @@ NO_AUDIOS = []
|
||||
|
||||
IMGLENS = [1]
|
||||
|
||||
AUDLENS = [1]
|
||||
|
||||
NO_IMGLENS = [0]
|
||||
|
||||
NO_VIDLENS = [0]
|
||||
@@ -76,6 +88,25 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
|
||||
return image_processor(images=IMAGES, return_tensors="pt")
|
||||
|
||||
|
||||
def _get_omni_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
|
||||
mm_inputs = {}
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
feature_extractor = getattr(processor, "feature_extractor", None)
|
||||
|
||||
mm_inputs.update(image_processor(IMAGES, return_tensors="pt"))
|
||||
mm_inputs.update(
|
||||
feature_extractor(
|
||||
AUDIOS,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
return_attention_mask=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")
|
||||
return mm_inputs
|
||||
|
||||
|
||||
def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
|
||||
assert batch_a.keys() == batch_b.keys()
|
||||
for key in batch_a.keys():
|
||||
@@ -104,6 +135,17 @@ def _check_plugin(
|
||||
expected_mm_inputs: dict[str, Any] = {},
|
||||
expected_no_mm_inputs: dict[str, Any] = {},
|
||||
) -> None:
|
||||
# test omni_messages
|
||||
if plugin.__class__.__name__ == "Qwen2OmniPlugin":
|
||||
assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
|
||||
expected_mm_inputs,
|
||||
)
|
||||
# test mm_messages
|
||||
if plugin.__class__.__name__ != "BasePlugin":
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
|
||||
@@ -279,6 +321,30 @@ def test_pixtral_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Unknown error.")
|
||||
def test_qwen2_omni_plugin():
|
||||
image_seqlen = 4
|
||||
audio_seqlen = 2
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
|
||||
qwen2_omni_plugin = get_mm_plugin(
|
||||
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
|
||||
)
|
||||
check_inputs = {"plugin": qwen2_omni_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{
|
||||
key: (
|
||||
value.replace("<image>", f"<|vision_bos|>{'<|IMAGE|>' * image_seqlen}<|vision_eos|>").replace(
|
||||
"<audio>", f"<|audio_bos|>{'<|AUDIO|>' * audio_seqlen}<|audio_eos|>"
|
||||
)
|
||||
)
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in OMNI_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_omni_inputs(tokenizer_module["processor"])
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_qwen2_vl_plugin():
|
||||
image_seqlen = 4
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||
|
||||
Reference in New Issue
Block a user