[model] support audio (#6701)
* support qwen2_audio * improve code * lint * fix * fix * fix --------- Co-authored-by: hiyouga <hiyouga@buaa.edu.cn> Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
@@ -9,8 +9,17 @@ import torch
|
||||
from transformers.image_utils import get_image_size, to_numpy_array
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import (
|
||||
is_librosa_available,
|
||||
is_pillow_available,
|
||||
is_pyav_available,
|
||||
is_transformers_version_greater_than,
|
||||
)
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
@@ -31,7 +40,9 @@ if is_transformers_version_greater_than("4.45.0"):
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from av.stream import Stream
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
class EncodedImage(TypedDict):
|
||||
@@ -40,6 +51,7 @@ if TYPE_CHECKING:
|
||||
|
||||
ImageInput = Union[str, bytes, EncodedImage, ImageObject]
|
||||
VideoInput = str
|
||||
AudioInput = Union[str, NDArray]
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(
|
||||
@@ -60,15 +72,17 @@ def _get_paligemma_token_type_ids(
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
|
||||
def __init__(self, image_token: Optional[str], video_token: Optional[str], audio_token: Optional[str]) -> None:
|
||||
self.image_token = image_token
|
||||
self.video_token = video_token
|
||||
self.audio_token = audio_token
|
||||
self.expand_mm_tokens = True
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> None:
|
||||
r"""
|
||||
Validates if this model accepts the input modalities.
|
||||
@@ -83,11 +97,16 @@ class BasePlugin:
|
||||
"This model does not support video input. Please check whether the correct `template` is used."
|
||||
)
|
||||
|
||||
if len(audios) != 0 and self.audio_token is None:
|
||||
raise ValueError(
|
||||
"This model does not support audio input. Please check whether the correct `template` is used."
|
||||
)
|
||||
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
r"""
|
||||
Pre-processes a single image.
|
||||
"""
|
||||
image_resolution: int = kwargs.get("image_resolution")
|
||||
image_resolution: int = kwargs["image_resolution"]
|
||||
if (image.width * image.height) > image_resolution:
|
||||
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
|
||||
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||
@@ -102,8 +121,8 @@ class BasePlugin:
|
||||
r"""
|
||||
Computes video sample frames according to fps.
|
||||
"""
|
||||
video_fps: float = kwargs.get("video_fps")
|
||||
video_maxlen: int = kwargs.get("video_maxlen")
|
||||
video_fps: float = kwargs["video_fps"]
|
||||
video_maxlen: int = kwargs["video_maxlen"]
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
@@ -126,7 +145,7 @@ class BasePlugin:
|
||||
image = Image.open(image["path"])
|
||||
|
||||
if not isinstance(image, ImageObject):
|
||||
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
|
||||
raise ValueError(f"Expect input is a list of images, but got {type(image)}.")
|
||||
|
||||
results.append(self._preprocess_image(image, **kwargs))
|
||||
|
||||
@@ -154,10 +173,28 @@ class BasePlugin:
|
||||
|
||||
return results
|
||||
|
||||
def _regularize_audios(self, audios: Sequence["AudioInput"], **kwargs) -> List["NDArray"]:
|
||||
r"""
|
||||
Regularizes audios to avoid error. Including reading and resampling.
|
||||
"""
|
||||
results = []
|
||||
sampling_rate = kwargs["sampling_rate"]
|
||||
for audio in audios:
|
||||
if isinstance(audio, str):
|
||||
audio = librosa.load(audio, sr=sampling_rate)[0]
|
||||
|
||||
if not isinstance(audio, np.ndarray):
|
||||
raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
|
||||
|
||||
results.append(audio)
|
||||
|
||||
return results
|
||||
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
@@ -172,15 +209,17 @@ class BasePlugin:
|
||||
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
|
||||
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
|
||||
input_dict = {"images": None} # default key
|
||||
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
|
||||
mm_inputs = {}
|
||||
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_resolution=getattr(processor, "image_resolution", 768 * 768),
|
||||
)
|
||||
input_dict["images"] = images
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
|
||||
if len(videos) != 0:
|
||||
videos = self._regularize_videos(
|
||||
@@ -189,16 +228,23 @@ class BasePlugin:
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
input_dict["videos"] = videos
|
||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||
|
||||
mm_inputs = {}
|
||||
if image_processor != video_processor:
|
||||
if input_dict.get("images") is not None:
|
||||
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
|
||||
if input_dict.get("videos") is not None:
|
||||
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
|
||||
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
|
||||
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
|
||||
if len(audios) != 0:
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
)
|
||||
mm_inputs.update(
|
||||
feature_extractor(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
return_attention_mask=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@@ -207,12 +253,13 @@ class BasePlugin:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Pre-processes input messages before tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
@@ -221,21 +268,24 @@ class BasePlugin:
|
||||
labels: Optional[List[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
r"""
|
||||
Pre-processes token ids after tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
return input_ids, labels
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
@@ -247,10 +297,11 @@ class BasePlugin:
|
||||
videos: a list of video inputs, shape (num_videos,)
|
||||
imglens: number of images in each sample, shape (batch_size,)
|
||||
vidlens: number of videos in each sample, shape (batch_size,)
|
||||
audlens: number of audios in each sample, shape (batch_size,)
|
||||
batch_ids: token ids of input samples, shape (batch_size, seq_len)
|
||||
processor: a processor for pre-processing images and videos
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
return {}
|
||||
|
||||
|
||||
@@ -261,9 +312,10 @@ class LlavaPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
|
||||
messages = deepcopy(messages)
|
||||
@@ -285,13 +337,15 @@ class LlavaPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
self._validate_input(images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
class LlavaNextPlugin(BasePlugin):
|
||||
@@ -301,12 +355,13 @@ class LlavaNextPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if "image_sizes" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"])
|
||||
|
||||
@@ -339,13 +394,15 @@ class LlavaNextPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
self._validate_input(images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
class LlavaNextVideoPlugin(BasePlugin):
|
||||
@@ -355,12 +412,13 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if "pixel_values" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"])
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||
@@ -408,13 +466,15 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
self._validate_input(images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
class MiniCPMVPlugin(BasePlugin):
|
||||
@@ -424,26 +484,30 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
num_audio_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
audio_inputs = {}
|
||||
audio_parts = []
|
||||
if len(images) != 0 and len(videos) != 0:
|
||||
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
||||
|
||||
if len(videos) != 0:
|
||||
max_slice_nums = 2
|
||||
use_image_id = False
|
||||
mm_inputs = self._get_mm_inputs([], videos, processor)
|
||||
mm_inputs = self._get_mm_inputs([], videos, [], processor)
|
||||
else:
|
||||
max_slice_nums = image_processor.max_slice_nums
|
||||
use_image_id = image_processor.use_image_id
|
||||
|
||||
for message in messages:
|
||||
for i, message in enumerate(messages):
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
@@ -454,15 +518,25 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
audio_parts.append(i)
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
||||
num_audio_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
|
||||
"{{audio}}", "(<audio>./</audio>)"
|
||||
)
|
||||
|
||||
if num_image_tokens > 0:
|
||||
mm_inputs = self._get_mm_inputs(images, [], processor)
|
||||
mm_inputs = self._get_mm_inputs(images, [], [], processor)
|
||||
|
||||
if num_audio_tokens > 0:
|
||||
audio_parts_ls = [audio_parts]
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls, ret_phs=True)
|
||||
|
||||
if mm_inputs:
|
||||
pattern = "(<image>./</image>)"
|
||||
image_sizes = mm_inputs["image_sizes"]
|
||||
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
image_tags = re.findall(pattern, text)
|
||||
@@ -480,12 +554,29 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
if audio_inputs:
|
||||
pattern = "(<audio>./</audio>)"
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
audio_tags = re.findall(pattern, text)
|
||||
text_chunks = text.split(pattern)
|
||||
final_text = ""
|
||||
for i in range(len(audio_tags)):
|
||||
audio_placeholder = audio_inputs["audio_phs"][0][i]
|
||||
final_text = final_text + text_chunks[i] + audio_placeholder
|
||||
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@@ -493,6 +584,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
**kwargs,
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
@@ -528,6 +620,30 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
if len(audios) != 0:
|
||||
audio_parts_ls = kwargs.get("audio_parts_ls", None)
|
||||
new_audios = []
|
||||
for audio in audios:
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
|
||||
new_audios.append(audio)
|
||||
|
||||
audios_ls = []
|
||||
idx = 0
|
||||
for audio_parts in audio_parts_ls:
|
||||
audios_ls.append(new_audios[idx : idx + len(audio_parts)])
|
||||
idx += len(audio_parts)
|
||||
|
||||
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
|
||||
audios_ls,
|
||||
audio_parts_ls,
|
||||
chunk_input=True,
|
||||
sampling_rate=16000,
|
||||
)
|
||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||
if kwargs.get("ret_phs", False):
|
||||
mm_inputs.update({"audio_phs": audio_phs})
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
@@ -535,12 +651,16 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
|
||||
# image bound
|
||||
image_bounds_list = []
|
||||
valid_image_nums_ls = []
|
||||
for i, input_ids in enumerate(batch_ids):
|
||||
@@ -561,8 +681,38 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
)
|
||||
image_bounds_list.append(image_bounds)
|
||||
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
|
||||
if "tgt_sizes" not in mm_inputs:
|
||||
dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
|
||||
mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
|
||||
|
||||
mm_inputs.update({"image_bound": image_bounds_list})
|
||||
|
||||
if len(audios) > 0:
|
||||
# audio bound
|
||||
audio_bounds_ls = []
|
||||
spk_bounds_ls = []
|
||||
audio_parts_ls = []
|
||||
|
||||
for input_ids, audiolen in zip(batch_ids, audlens):
|
||||
input_ids_ = torch.tensor(input_ids)
|
||||
audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
|
||||
audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
|
||||
assert len(audio_start_idx) == len(audio_end_idx)
|
||||
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
|
||||
audio_bounds_ls.append(audio_bounds)
|
||||
audio_parts_ls.append(list(range(audiolen)))
|
||||
|
||||
spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
|
||||
spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
|
||||
assert len(spk_start_idx) == len(spk_end_idx)
|
||||
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
|
||||
spk_bounds_ls.append(spk_bounds)
|
||||
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls)
|
||||
mm_inputs.update(audio_inputs)
|
||||
mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@@ -573,9 +723,10 @@ class MllamaPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@@ -593,6 +744,7 @@ class MllamaPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
**kwargs,
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
@@ -617,17 +769,20 @@ class MllamaPlugin(BasePlugin):
|
||||
|
||||
return image_processor(batch_images, return_tensors="pt")
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
|
||||
self._validate_input(images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens=imglens)
|
||||
num_tiles = mm_inputs.pop("num_tiles")
|
||||
image_token_id = getattr(processor, "image_token_id")
|
||||
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
|
||||
@@ -652,9 +807,10 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@@ -677,10 +833,11 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
labels: Optional[List[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_images = len(images)
|
||||
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
@@ -695,14 +852,16 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
seqlens = [len(input_ids) for input_ids in batch_ids]
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||
return mm_inputs
|
||||
|
||||
@@ -714,9 +873,10 @@ class PixtralPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
patch_size = getattr(processor, "patch_size")
|
||||
image_token = getattr(processor, "image_token")
|
||||
image_break_token = getattr(processor, "image_break_token")
|
||||
@@ -724,7 +884,7 @@ class PixtralPlugin(BasePlugin):
|
||||
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
image_input_sizes = mm_inputs.get("image_sizes", None)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
@@ -759,13 +919,15 @@ class PixtralPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
self._validate_input(images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if mm_inputs.get("pixel_values"):
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
|
||||
|
||||
@@ -773,6 +935,58 @@ class PixtralPlugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Qwen2AudioPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
bos_token: str = getattr(processor, "audio_bos_token")
|
||||
eos_token: str = getattr(processor, "audio_eos_token")
|
||||
mm_inputs = self._get_mm_inputs([], [], audios, processor)
|
||||
if "feature_attention_mask" in mm_inputs:
|
||||
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
|
||||
|
||||
num_audio_tokens = 0
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
audio_length = audio_lengths.pop(0)
|
||||
input_length = (audio_length - 1) // 2 + 1
|
||||
audio_seqlen = (input_length - 2) // 2 + 1 if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
|
||||
)
|
||||
num_audio_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
@override
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
@@ -820,12 +1034,13 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
|
||||
@@ -868,13 +1083,15 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
self._validate_input(images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
|
||||
video_fps = getattr(processor, "video_fps", 2.0)
|
||||
@@ -892,12 +1109,13 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
self._validate_input(images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
num_frames = 0
|
||||
has_images = "pixel_values_images" in mm_inputs
|
||||
has_videos = "pixel_values_videos" in mm_inputs
|
||||
@@ -945,13 +1163,15 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
self._validate_input(images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
PLUGINS = {
|
||||
@@ -963,6 +1183,7 @@ PLUGINS = {
|
||||
"mllama": MllamaPlugin,
|
||||
"paligemma": PaliGemmaPlugin,
|
||||
"pixtral": PixtralPlugin,
|
||||
"qwen2_audio": Qwen2AudioPlugin,
|
||||
"qwen2_vl": Qwen2vlPlugin,
|
||||
"video_llava": VideoLlavaPlugin,
|
||||
}
|
||||
@@ -972,9 +1193,10 @@ def get_mm_plugin(
|
||||
name: str,
|
||||
image_token: Optional[str] = None,
|
||||
video_token: Optional[str] = None,
|
||||
audio_token: Optional[str] = None,
|
||||
) -> "BasePlugin":
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class is None:
|
||||
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
||||
|
||||
return plugin_class(image_token, video_token)
|
||||
return plugin_class(image_token, video_token, audio_token)
|
||||
|
||||
Reference in New Issue
Block a user