[infer] vllm video/audio inference (#7566)
This commit is contained in:
@@ -21,7 +21,7 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -68,9 +68,9 @@ if TYPE_CHECKING:
|
||||
path: Optional[str]
|
||||
bytes: Optional[bytes]
|
||||
|
||||
ImageInput = Union[str, bytes, EncodedImage, ImageObject]
|
||||
VideoInput = str
|
||||
AudioInput = Union[str, NDArray]
|
||||
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
||||
VideoInput = Union[str, BinaryIO]
|
||||
AudioInput = Union[str, BinaryIO, NDArray]
|
||||
|
||||
class MMProcessor(ProcessorMixin):
|
||||
patch_size: int
|
||||
@@ -146,12 +146,6 @@ class MMPluginMixin:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||
)
|
||||
if image_processor is None and video_processor is None: # hack for qwen2_5_omni
|
||||
image_processor, video_processor = (
|
||||
getattr(processor, "omni_processor", None),
|
||||
getattr(processor, "omni_processor", None),
|
||||
)
|
||||
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError(
|
||||
@@ -211,11 +205,11 @@ class MMPluginMixin:
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
|
||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]:
|
||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
|
||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||
results = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
if isinstance(image, (str, BinaryIO)):
|
||||
image = Image.open(image)
|
||||
elif isinstance(image, bytes):
|
||||
image = Image.open(BytesIO(image))
|
||||
@@ -230,9 +224,9 @@ class MMPluginMixin:
|
||||
|
||||
results.append(self._preprocess_image(image, **kwargs))
|
||||
|
||||
return results
|
||||
return {"images": results}
|
||||
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
|
||||
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
||||
results = []
|
||||
for video in videos:
|
||||
@@ -245,24 +239,27 @@ class MMPluginMixin:
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
|
||||
return results
|
||||
return {"videos": results}
|
||||
|
||||
def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
|
||||
def _regularize_audios(
|
||||
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
|
||||
) -> dict[str, Union[list["NDArray"], list[float]]]:
|
||||
r"""Regularizes audios to avoid error. Including reading and resampling."""
|
||||
results = []
|
||||
results, sampling_rates = [], []
|
||||
for audio in audios:
|
||||
if isinstance(audio, str):
|
||||
audio = librosa.load(audio, sr=sampling_rate)[0]
|
||||
if isinstance(audio, (str, BinaryIO)):
|
||||
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
|
||||
|
||||
if not isinstance(audio, np.ndarray):
|
||||
raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
|
||||
|
||||
results.append(audio)
|
||||
sampling_rates.append(sampling_rate)
|
||||
|
||||
return results
|
||||
return {"audios": results, "sampling_rates": sampling_rates}
|
||||
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
@@ -298,8 +295,8 @@ class MMPluginMixin:
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)
|
||||
if imglens is not None:
|
||||
)["images"]
|
||||
if imglens is not None: # if imglens are provided, make batched images
|
||||
images = _make_batched_images(images, imglens)
|
||||
|
||||
image_processor_kwargs = {}
|
||||
@@ -325,7 +322,7 @@ class MMPluginMixin:
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
)["videos"]
|
||||
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
|
||||
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
|
||||
else: # for llava_next_video
|
||||
@@ -335,12 +332,12 @@ class MMPluginMixin:
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
)
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)["audios"]
|
||||
mm_inputs.update(
|
||||
feature_extractor(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
return_attention_mask=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
@@ -726,14 +723,13 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
**kwargs,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)
|
||||
)["images"]
|
||||
if "valid_image_nums_ls" in kwargs:
|
||||
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
|
||||
new_images = []
|
||||
@@ -756,15 +752,15 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
)["videos"]
|
||||
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
if len(audios) != 0:
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
)
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)["audios"]
|
||||
if "valid_audio_nums_ls" in kwargs:
|
||||
valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
|
||||
audios_ls = []
|
||||
@@ -778,7 +774,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
|
||||
audios_ls,
|
||||
chunk_input=True,
|
||||
sampling_rate=16000,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)
|
||||
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
|
||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||
@@ -1110,195 +1106,6 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
class Qwen2OmniPlugin(BasePlugin):
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
imglens: Optional[list[int]] = None,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None) # FIXME
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)
|
||||
if imglens is not None:
|
||||
images = _make_batched_images(images, imglens)
|
||||
|
||||
image_processor_kwargs = {}
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
|
||||
|
||||
if len(videos) != 0:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
processor, "video_processor", getattr(processor, "omni_processor", None)
|
||||
)
|
||||
videos = self._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
|
||||
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
|
||||
fps = [2.0] * len(videos) # FIXME hardcode
|
||||
video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
|
||||
mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if len(audios) != 0:
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
)
|
||||
mm_inputs.update(
|
||||
feature_extractor(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
return_attention_mask=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
else:
|
||||
mm_inputs = {}
|
||||
|
||||
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
|
||||
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
|
||||
|
||||
# get length or size from mm_inputs
|
||||
if "feature_attention_mask" in mm_inputs:
|
||||
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
||||
audio_lengths = (input_lengths - 2) // 2 + 1
|
||||
|
||||
if mm_inputs.get("image_grid_thw", None) is not None:
|
||||
image_grid_thw = mm_inputs["image_grid_thw"]
|
||||
merge_length = processor.omni_processor.merge_size**2
|
||||
|
||||
if mm_inputs.get("video_grid_thw", None) is not None:
|
||||
video_grid_thw = mm_inputs["video_grid_thw"]
|
||||
merge_length = processor.omni_processor.merge_size**2
|
||||
|
||||
if use_audio_in_video:
|
||||
if audio_lengths is None:
|
||||
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
|
||||
|
||||
if not mm_inputs.get("video_grid_thw", None):
|
||||
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
|
||||
|
||||
positions_list = []
|
||||
for i, message in enumerate(messages): # get multimodal index when use_audio
|
||||
positions = []
|
||||
for special_token in [self.audio_token, self.image_token, self.video_token]:
|
||||
start = 0
|
||||
while True:
|
||||
pos = message[i].find(special_token, start)
|
||||
if pos == -1:
|
||||
break
|
||||
positions.append((pos, special_token))
|
||||
start = pos + len(special_token)
|
||||
|
||||
positions_list.append(positions.sort(key=lambda x: x[0]))
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
# separate with audio-video
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
|
||||
1,
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
if not use_audio_in_video:
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
audio_token_replace_length = audio_lengths[num_audio_tokens]
|
||||
content = content.replace(
|
||||
AUDIO_PLACEHOLDER,
|
||||
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
|
||||
1,
|
||||
)
|
||||
num_audio_tokens += 1
|
||||
# TODO handle video_input and use_audio_in_video
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
|
||||
)
|
||||
num_video_tokens += 1
|
||||
else: # if use the audio of video # deal video token and audio token togather
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||
video_t_index = (
|
||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||
.view(-1, 1, 1)
|
||||
.expand(
|
||||
-1,
|
||||
video_grid_thw[num_video_tokens][1] // self.omni_processor.merge_size,
|
||||
video_grid_thw[num_video_tokens][2] // self.omni_processor.merge_size,
|
||||
)
|
||||
.flatten()
|
||||
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||
* 25 # FIXME hardcode of position_id_per_seconds=25
|
||||
).long()
|
||||
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||
audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||
placeholder_string = ""
|
||||
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
|
||||
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
|
||||
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
||||
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
|
||||
if video_chunk_index is not None:
|
||||
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
|
||||
if audio_chunk_index is not None:
|
||||
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
||||
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
||||
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
||||
num_audio_tokens += 1
|
||||
num_video_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
@@ -1321,7 +1128,7 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
def _regularize_videos(
|
||||
self, videos: list["VideoInput"], **kwargs
|
||||
) -> tuple[list[list["ImageObject"]], list[float]]:
|
||||
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
|
||||
results, fps_per_video = [], []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
@@ -1336,14 +1143,14 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
|
||||
frames.append(frames[-1])
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
if video_stream.duration is None:
|
||||
fps_per_video.append(2.0)
|
||||
else:
|
||||
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
||||
|
||||
return results, fps_per_video
|
||||
return {"videos": results, "fps_per_video": fps_per_video}
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
@@ -1360,19 +1167,19 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)
|
||||
)["images"]
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
|
||||
if len(videos) != 0:
|
||||
videos, fps_per_video = self._regularize_videos(
|
||||
video_data = self._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
|
||||
mm_inputs["fps_per_video"] = fps_per_video
|
||||
mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
|
||||
mm_inputs["fps_per_video"] = video_data["fps_per_video"]
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@@ -1454,6 +1261,186 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)["images"]
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
|
||||
if len(videos) != 0:
|
||||
video_dict = self._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
|
||||
mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
|
||||
|
||||
if len(audios) != 0:
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)["audios"]
|
||||
mm_inputs.update(
|
||||
feature_extractor(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
return_attention_mask=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
else:
|
||||
mm_inputs = {}
|
||||
|
||||
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
|
||||
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
|
||||
|
||||
# get length or size from mm_inputs
|
||||
if "feature_attention_mask" in mm_inputs:
|
||||
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
||||
audio_lengths = (input_lengths - 2) // 2 + 1
|
||||
|
||||
if mm_inputs.get("image_grid_thw", None) is not None:
|
||||
image_grid_thw = mm_inputs["image_grid_thw"]
|
||||
merge_length = processor.image_processor.merge_size**2
|
||||
|
||||
if mm_inputs.get("video_grid_thw", None) is not None:
|
||||
video_grid_thw = mm_inputs["video_grid_thw"]
|
||||
merge_length = processor.image_processor.merge_size**2
|
||||
|
||||
if use_audio_in_video:
|
||||
if audio_lengths is None:
|
||||
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
|
||||
|
||||
if not mm_inputs.get("video_grid_thw", None):
|
||||
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
|
||||
|
||||
positions_list = []
|
||||
for i, message in enumerate(messages): # get multimodal index when use_audio
|
||||
positions = []
|
||||
for special_token in [self.audio_token, self.image_token, self.video_token]:
|
||||
start = 0
|
||||
while True:
|
||||
pos = message[i].find(special_token, start)
|
||||
if pos == -1:
|
||||
break
|
||||
positions.append((pos, special_token))
|
||||
start = pos + len(special_token)
|
||||
|
||||
positions_list.append(positions.sort(key=lambda x: x[0]))
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
# separate with audio-video
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
|
||||
1,
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
if not use_audio_in_video:
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
audio_token_replace_length = audio_lengths[num_audio_tokens]
|
||||
content = content.replace(
|
||||
AUDIO_PLACEHOLDER,
|
||||
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
|
||||
1,
|
||||
)
|
||||
num_audio_tokens += 1
|
||||
|
||||
# TODO handle video_input and use_audio_in_video
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
|
||||
)
|
||||
num_video_tokens += 1
|
||||
|
||||
else: # if use the audio of video # deal video token and audio token togather
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||
video_t_index = (
|
||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||
.view(-1, 1, 1)
|
||||
.expand(
|
||||
-1,
|
||||
video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
|
||||
video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
|
||||
)
|
||||
.flatten()
|
||||
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||
* 25 # FIXME hardcode of position_id_per_seconds=25
|
||||
).long()
|
||||
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||
audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||
placeholder_string = ""
|
||||
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
|
||||
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
|
||||
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
||||
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
|
||||
if video_chunk_index is not None:
|
||||
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
|
||||
if audio_chunk_index is not None:
|
||||
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
||||
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
||||
|
||||
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
||||
num_audio_tokens += 1
|
||||
num_video_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoLlavaPlugin(BasePlugin):
|
||||
@override
|
||||
|
||||
Reference in New Issue
Block a user