[model] add Qwen2.5-Omni model (#7537)

* preserve image_sizes

* preserve image_sizes

* init plugin

* support audio-text2text lora

* nit

* support image/video-text2text, audio-text2text

* remove args

* remove lines

* add docs && nit

* remove some comments

* fix && add merge part script

* add license
This commit is contained in:
Kingsley
2025-03-31 20:39:35 +08:00
committed by GitHub
parent 0f8296626a
commit 7eed496336
10 changed files with 348 additions and 2 deletions

View File

@@ -190,10 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"],
}
if "second_per_grid_ts" in mm_inputs:
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni": # for qwen2omni
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(
feature_attention_mask, dim=1
) # FIXME need to get video image lengths
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
# avoid conflict
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
features["position_ids"], features["rope_deltas"] = (
new_position_ids.clone(),
rope_deltas - delta0,
) # avoid inplace operation FIXME
else: # for qwen2vl
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")

View File

@@ -146,6 +146,12 @@ class MMPluginMixin:
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None)
)
if image_processor is None and video_processor is None: # hack for qwen2_5_omni
image_processor, video_processor = (
getattr(processor, "omni_processor", None),
getattr(processor, "omni_processor", None),
)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
if len(images) != 0 and self.image_token is None:
raise ValueError(
@@ -1104,6 +1110,186 @@ class Qwen2AudioPlugin(BasePlugin):
return self._get_mm_inputs(images, videos, audios, processor)
class Qwen2OmniPlugin(BasePlugin):
@override
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
imglens: Optional[list[int]] = None,
) -> dict[str, "torch.Tensor"]:
mm_inputs = {}
if len(images) != 0:
image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None) # FIXME
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
if imglens is not None:
images = _make_batched_images(images, imglens)
image_processor_kwargs = {}
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
if len(videos) != 0:
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "omni_processor", None)
)
videos = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
fps = [2.0] * len(videos) # FIXME hardcode
video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)
else:
raise NotImplementedError
if len(audios) != 0:
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
audios = self._regularize_audios(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
)
mm_inputs.update(
feature_extractor(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
return_attention_mask=True,
padding="max_length",
return_tensors="pt",
)
)
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
return mm_inputs
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
# get length or size from mm_inputs
if "feature_attention_mask" in mm_inputs:
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1
if mm_inputs.get("image_grid_thw", None) is not None:
image_grid_thw = mm_inputs["image_grid_thw"]
merge_length = processor.omni_processor.merge_size**2
if mm_inputs.get("video_grid_thw", None) is not None:
video_grid_thw = mm_inputs["video_grid_thw"]
merge_length = processor.omni_processor.merge_size**2
if use_audio_in_video:
assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
assert mm_inputs.get("video_grid_thw", None) is not None, (
"video_grid_thw should be exist when use_audio_in_video is `True`"
)
positions_list = []
for i, message in enumerate(messages): # get multimodal index when use_audio
positions = []
for special_token in [self.audio_token, self.image_token, self.video_token]:
start = 0
while True:
pos = message[i].find(special_token, start)
if pos == -1:
break
positions.append((pos, special_token))
start = pos + len(special_token)
positions_list.append(positions.sort(key=lambda x: x[0]))
for message in messages:
content = message["content"]
# separate with audio-video
while IMAGE_PLACEHOLDER in content:
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
content = content.replace(
IMAGE_PLACEHOLDER,
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
1,
)
num_image_tokens += 1
if not use_audio_in_video:
while AUDIO_PLACEHOLDER in content:
audio_token_replace_length = audio_lengths[num_audio_tokens]
content = content.replace(
AUDIO_PLACEHOLDER,
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
1,
)
num_audio_tokens += 1
# TODO handle video_input and use_audio_in_video
while VIDEO_PLACEHOLDER in content:
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
)
num_video_tokens += 1
else: # if use the audio of video # deal video token and audio token togather
while VIDEO_PLACEHOLDER in content:
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
video_t_index = (
torch.arange(video_grid_thw[num_video_tokens][0])
.view(-1, 1, 1)
.expand(
-1,
video_grid_thw[num_video_tokens][1] // self.omni_processor.merge_size,
video_grid_thw[num_video_tokens][2] // self.omni_processor.merge_size,
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
* 25 # FIXME hardcode of position_id_per_seconds=25
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
if video_chunk_index is not None:
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
num_audio_tokens += 1
num_video_tokens += 1
message["content"] = content
if len(audios) != num_audio_tokens:
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@dataclass
class Qwen2VLPlugin(BasePlugin):
@override
@@ -1328,6 +1514,7 @@ PLUGINS = {
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_audio": Qwen2AudioPlugin,
"qwen2_omni": Qwen2OmniPlugin,
"qwen2_vl": Qwen2VLPlugin,
"video_llava": VideoLlavaPlugin,
}

View File

@@ -1367,6 +1367,24 @@ register_template(
)
# copied from qwen template
register_template(
name="qwen2_omni",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
),
)
# copied from qwen template
register_template(
name="qwen2_vl",

View File

@@ -2270,6 +2270,18 @@ register_model_group(
)
register_model_group(
models={
"Qwen2.5-Omni-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
}
},
template="qwen2_omni",
multimodal=True,
)
register_model_group(
models={
"Qwen2-VL-2B": {

View File

@@ -222,6 +222,10 @@ class ProcessorArguments:
default=False,
metadata={"help": "Use pan and scan to process image for gemma3."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
video_max_pixels: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},

View File

@@ -21,6 +21,7 @@ from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForTextToWaveform,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
@@ -147,6 +148,8 @@ def load_model(
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni
load_class = AutoModelForTextToWaveform
else:
load_class = AutoModelForCausalLM
@@ -154,6 +157,8 @@ def load_model(
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
if load_class is AutoModelForTextToWaveform:
model = model.thinker # use part of Omni model
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)

View File

@@ -257,6 +257,17 @@ _register_composite_model(
)
_register_composite_model(
model_type="qwen2_5_omni_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=[
"patch_embed",
],
)
_register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",