[model] add Qwen2.5-Omni model (#7537)
* preserve image_sizes
* preserve image_sizes
* init plugin
* support audio-text2text lora
* nit
* support image/video-text2text, audio-text2text
* remove args
* remove lines
* add docs && nit
* remove some comments
* fix && add merge part script
* add license
@@ -146,6 +146,12 @@ class MMPluginMixin:
        video_processor: BaseImageProcessor = getattr(
            processor, "video_processor", getattr(processor, "image_processor", None)
        )
        if image_processor is None and video_processor is None:  # hack for qwen2_5_omni
            image_processor, video_processor = (
                getattr(processor, "omni_processor", None),
                getattr(processor, "omni_processor", None),
            )

        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        if len(images) != 0 and self.image_token is None:
            raise ValueError(
@@ -1104,6 +1110,186 @@ class Qwen2AudioPlugin(BasePlugin):
        return self._get_mm_inputs(images, videos, audios, processor)


class Qwen2OmniPlugin(BasePlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        mm_inputs = {}
        if len(images) != 0:
            image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None)  # FIXME
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if imglens is not None:
                images = _make_batched_images(images, imglens)

            image_processor_kwargs = {}
            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "omni_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
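                # fps and video_second_per_grid describe the temporal sampling of each video; the value is
                # consumed in process_messages below (scaled by position_id_per_seconds) to place video
                # tokens on the temporal position grid when audio is interleaved with video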
                fps = [2.0] * len(videos)  # FIXME hardcode
                video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
                mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)

            else:
                raise NotImplementedError

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
            )
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
            use_audio_in_video = getattr(processor, "use_audio_in_video", False)

            # get length or size from mm_inputs
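            # number of audio placeholder tokens per clip: mel-frame count from the feature-extractor
            # mask, downsampled twice ((L - 1) // 2 + 1 then (L - 2) // 2 + 1), matching the audio
            # encoder's output-length reduction so text-side tokens line up with encoder frames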
if "feature_attention_mask" in mm_inputs:
|
||||
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
||||
audio_lengths = (input_lengths - 2) // 2 + 1
|
||||
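            # vision token counts: each image/video contributes grid_thw.prod() patches, merged
            # spatially into one token per merge_size**2 patches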
if mm_inputs.get("image_grid_thw", None) is not None:
|
||||
image_grid_thw = mm_inputs["image_grid_thw"]
|
||||
merge_length = processor.omni_processor.merge_size**2
|
||||
if mm_inputs.get("video_grid_thw", None) is not None:
|
||||
video_grid_thw = mm_inputs["video_grid_thw"]
|
||||
merge_length = processor.omni_processor.merge_size**2
|
||||
|
||||
if use_audio_in_video:
|
||||
assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
|
||||
assert mm_inputs.get("video_grid_thw", None) is not None, (
|
||||
"video_grid_thw should be exist when use_audio_in_video is `True`"
|
||||
)
|
||||
positions_list = []
|
||||
for i, message in enumerate(messages): # get multimodal index when use_audio
|
||||
positions = []
|
||||
for special_token in [self.audio_token, self.image_token, self.video_token]:
|
||||
start = 0
|
||||
while True:
|
||||
pos = message[i].find(special_token, start)
|
||||
if pos == -1:
|
||||
break
|
||||
positions.append((pos, special_token))
|
||||
start = pos + len(special_token)
|
||||
positions_list.append(positions.sort(key=lambda x: x[0]))
|
||||
|
||||
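            # rewrite each placeholder in the text into bos/eos markers around the matching number
            # of model-specific image/video/audio tokens computed above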
            for message in messages:
                content = message["content"]
                # images first; audio/video handling depends on use_audio_in_video below
                while IMAGE_PLACEHOLDER in content:
                    image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
                    content = content.replace(
                        IMAGE_PLACEHOLDER,
                        f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
                        1,
                    )
                    num_image_tokens += 1

                if not use_audio_in_video:
                    while AUDIO_PLACEHOLDER in content:
                        audio_token_replace_length = audio_lengths[num_audio_tokens]
                        content = content.replace(
                            AUDIO_PLACEHOLDER,
                            f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
                            1,
                        )
                        num_audio_tokens += 1
                    # TODO handle video_input and use_audio_in_video
                    while VIDEO_PLACEHOLDER in content:
                        video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
                        content = content.replace(
                            VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
                        )
                        num_video_tokens += 1
                else:  # use the audio track of the video: handle video and audio tokens together
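                    # interleave video and audio tokens chunk by chunk along the time axis:
                    # temporal position ids advance at 25 per second and each chunk spans 50 of
                    # them (about 2 s), so co-occurring video and audio spans sit next to each
                    # other inside a single <|vision_bos|> ... <|vision_eos|> block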
                    while VIDEO_PLACEHOLDER in content:
                        audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                        video_t_index = (
                            torch.arange(video_grid_thw[num_video_tokens][0])
                            .view(-1, 1, 1)
                            .expand(
                                -1,
                                video_grid_thw[num_video_tokens][1] // processor.omni_processor.merge_size,
                                video_grid_thw[num_video_tokens][2] // processor.omni_processor.merge_size,
                            )
                            .flatten()
                            * mm_inputs["video_second_per_grid"][num_video_tokens]
                            * 25  # FIXME hardcode of position_id_per_seconds=25
                        ).long()
                        t_ntoken_per_chunk = 50  # FIXME hardcode: position_id_per_seconds (25) * 2 seconds
                        video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
                        audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                        placeholder_string = ""
                        placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                        for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                            video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                            audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
                            if video_chunk_index is not None:
                                placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
                            if audio_chunk_index is not None:
                                placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                        content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                        content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                        num_audio_tokens += 1
                        num_video_tokens += 1
                message["content"] = content

            if len(audios) != num_audio_tokens:
                raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
            if len(images) != num_image_tokens:
                raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
            if len(videos) != num_video_tokens:
                raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages


@dataclass
class Qwen2VLPlugin(BasePlugin):
    @override
@@ -1328,6 +1514,7 @@ PLUGINS = {
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
    "qwen2_audio": Qwen2AudioPlugin,
    "qwen2_omni": Qwen2OmniPlugin,
    "qwen2_vl": Qwen2VLPlugin,
    "video_llava": VideoLlavaPlugin,
}
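For context, the new registry entry is looked up the same way as the other plugins. A minimal sketch, assuming the module's get_mm_plugin helper and the Qwen2.5-Omni special-token strings (neither appears in this diff, so treat both as assumptions rather than the project's confirmed API):

    # hypothetical lookup of the plugin registered above; token strings are assumptions
    from llamafactory.data.mm_plugin import get_mm_plugin

    qwen2_omni_plugin = get_mm_plugin(
        name="qwen2_omni",
        image_token="<|IMAGE|>",
        video_token="<|VIDEO|>",
        audio_token="<|AUDIO|>",
    )
    # expands the IMAGE_PLACEHOLDER / VIDEO_PLACEHOLDER / AUDIO_PLACEHOLDER markers in the messages
    # processed = qwen2_omni_plugin.process_messages(messages, images, videos, audios, processor)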