@@ -1,3 +1,20 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
import re
@@ -5,7 +22,7 @@ from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union

import numpy as np
import torch
@@ -56,24 +73,63 @@ if TYPE_CHECKING:
    VideoInput = str
    AudioInput = Union[str, NDArray]

    class MMProcessor(ProcessorMixin):
        patch_size: int
        image_seq_length: int
        num_additional_image_tokens: int
        vision_feature_select_strategy: Literal["default", "full"]

        def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
            pass


def _get_paligemma_token_type_ids(
    imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
    imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
) -> list[list[int]]:
    r"""Get paligemma token type ids for computing loss.

    It is slightly different with the original token type ids where the prompt part is 0.

    Returns:
        batch_token_type_ids: shape (batch_size, sequence_length)
        batch_token_type_ids: shape (batch_size, seq_length)

    """
    batch_token_type_ids = []
    for imglen, seqlen in zip(imglens, seqlens):
        image_seqlen = imglen * getattr(processor, "image_seqlen")
        image_seqlen = imglen * processor.image_seq_length
        batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

    return batch_token_type_ids


def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
    r"""Get gemma3 token type ids for computing loss.

    Returns:
        batch_token_type_ids: shape (batch_size, seq_length)

    """
    image_token_id: int = getattr(processor, "image_token_id")
    batch_token_type_ids = []
    for token_ids in batch_ids:
        token_ids = np.array(token_ids)
        token_type_ids = np.zeros_like(token_ids)
        token_type_ids[token_ids == image_token_id] = 1
        batch_token_type_ids.append(token_type_ids.tolist())

    return batch_token_type_ids


def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
    r"""Make nested list of images."""
    batch_images = []
    for imglen in imglens:
        batch_images.append(images[:imglen])
        images = images[imglen:]

    return batch_images


@dataclass
class MMPluginMixin:
    image_token: Optional[str]
@@ -83,7 +139,7 @@ class MMPluginMixin:

    def _validate_input(
        self,
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
@@ -204,7 +260,8 @@ class MMPluginMixin:
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        r"""Process visual inputs.

@@ -214,23 +271,34 @@ class MMPluginMixin:
        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
                            where num_patches == torch.prod(image_grid_thw)

        Returns: (mllama)
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).

        It holds num_patches == torch.prod(image_grid_thw)
        """
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}

        if len(images) != 0:
            image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if imglens is not None:
                images = _make_batched_images(images, imglens)

            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "image_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
@@ -244,6 +312,7 @@ class MMPluginMixin:
            mm_inputs.update(video_processor(videos, return_tensors="pt"))

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
@@ -270,9 +339,9 @@ class BasePlugin(MMPluginMixin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Pre-processes input messages before tokenization for VLMs."""
        r"""Pre-process input messages before tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return messages

@@ -284,9 +353,9 @@ class BasePlugin(MMPluginMixin):
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        r"""Pre-processes token ids after tokenization for VLMs."""
        r"""Pre-process token ids after tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return input_ids, labels

@@ -299,7 +368,7 @@ class BasePlugin(MMPluginMixin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build batched multimodal inputs for VLMs.

@@ -315,11 +384,11 @@ class BasePlugin(MMPluginMixin):

        """
        self._validate_input(processor, images, videos, audios)
        return {}
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class LlavaPlugin(BasePlugin):
class Gemma3Plugin(BasePlugin):
    @override
    def process_messages(
        self,
@@ -327,19 +396,21 @@ class LlavaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
        boi_token: str = getattr(processor, "boi_token")
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)
            message["content"] = content.replace("{{image}}", image_str)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -356,10 +427,53 @@ class LlavaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("num_crops", None)
        mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
        return mm_inputs


@dataclass
class LlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
                image_seqlen = (height // processor.patch_size) * (
                    width // processor.patch_size
                ) + processor.num_additional_image_tokens
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1
        else:
            image_seqlen = 1

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages


@dataclass
@@ -371,15 +485,16 @@ class LlavaNextPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
@@ -387,7 +502,7 @@ class LlavaNextPlugin(BasePlugin):
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1
@@ -402,21 +517,6 @@ class LlavaNextPlugin(BasePlugin):

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class LlavaNextVideoPlugin(BasePlugin):
@@ -427,48 +527,50 @@ class LlavaNextVideoPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    if self.expand_mm_tokens:
                        orig_height, orig_width = next(image_sizes)
                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                        if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                            image_seqlen -= 1
                    else:
                        image_seqlen = 1
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                    num_image_tokens += 1
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                message["content"] = content.replace("{{image}}", self.image_token)
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

        if "pixel_values_videos" in mm_inputs:
            if self.expand_mm_tokens:
                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(pixel_values_video[0])
                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
            message["content"] = content.replace("{{image}}", self.image_token)

        if self.expand_mm_tokens:
            if "pixel_values_videos" in mm_inputs:
                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(one_video[0])
                num_frames = one_video.shape[0]  # frame dim is always after batch dim
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
            else:
                video_seqlen = 1
        else:
            video_seqlen = 1

            for message in messages:
                content = message["content"]
                while VIDEO_PLACEHOLDER in content:
                    num_video_tokens += 1
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
        for message in messages:
            content = message["content"]
            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                num_video_tokens += 1

                message["content"] = content.replace("{{video}}", self.video_token)
            message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -478,21 +580,6 @@ class LlavaNextVideoPlugin(BasePlugin):

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class MiniCPMVPlugin(BasePlugin):
@@ -503,7 +590,7 @@ class MiniCPMVPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
@@ -602,7 +689,7 @@ class MiniCPMVPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        processor: "MMProcessor",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -677,7 +764,7 @@ class MiniCPMVPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        # image bound
@@ -745,7 +832,7 @@ class MllamaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
@@ -760,43 +847,6 @@ class MllamaPlugin(BasePlugin):

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        imglens: list[int],
    ) -> dict[str, "torch.Tensor"]:
        r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

        Returns:
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).

        """
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs = {}
        if len(images) > 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            batch_images = []
            for image_length in imglens:
                batch_images.append(images[:image_length])
                images = images[image_length:]

            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))

        return mm_inputs

    @override
    def get_mm_inputs(
        self,
@@ -807,14 +857,14 @@ class MllamaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
        if mm_inputs:
            num_tiles = mm_inputs.pop("num_tiles")
            image_token_id = getattr(processor, "image_token_id")
            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
            image_token_id: int = getattr(processor, "image_token_id")
            max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
            cross_attention_token_mask = [
                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
            ]
@@ -839,7 +889,7 @@ class PaliGemmaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
@@ -847,10 +897,10 @@ class PaliGemmaPlugin(BasePlugin):
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", "")
            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -866,15 +916,15 @@ class PaliGemmaPlugin(BasePlugin):
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        self._validate_input(processor, images, videos, audios)
        num_images = len(images)
        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
        image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * image_seqlen + input_ids
        input_ids = [image_token_id] * num_images * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * image_seqlen + labels
            labels = [IGNORE_INDEX] * num_images * image_seqlen + labels

        return input_ids, labels

@@ -888,7 +938,7 @@ class PaliGemmaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        seqlens = [len(input_ids) for input_ids in batch_ids]
@@ -906,33 +956,31 @@ class PixtralPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        patch_size = getattr(processor, "patch_size")
        image_token = getattr(processor, "image_token")
        image_break_token = getattr(processor, "image_break_token")
        image_end_token = getattr(processor, "image_end_token")

        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                image_break_token: str = getattr(processor, "image_break_token")
                image_end_token: str = getattr(processor, "image_end_token")

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    height, width = next(image_sizes)
                    num_height_tokens = height // patch_size
                    num_width_tokens = width // patch_size
                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    num_height_tokens = height // processor.patch_size
                    num_width_tokens = width // processor.patch_size
                    replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                    replace_tokens[-1] = image_end_token
                    replace_str = "".join(replace_tokens)
                else:
                    replace_str = image_token
                    replace_str = self.image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                num_image_tokens += 1
@@ -954,7 +1002,7 @@ class PixtralPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -971,17 +1019,18 @@ class Qwen2AudioPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        bos_token: str = getattr(processor, "audio_bos_token")
        eos_token: str = getattr(processor, "audio_eos_token")
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs([], [], audios, processor)
        if "feature_attention_mask" in mm_inputs:
            audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()

        num_audio_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs([], [], audios, processor)
            if "feature_attention_mask" in mm_inputs:
                audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()

        for message in messages:
            content = message["content"]
            while AUDIO_PLACEHOLDER in content:
@@ -1014,7 +1063,7 @@ class Qwen2AudioPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)
@@ -1072,7 +1121,7 @@ class Qwen2VLPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        mm_inputs = {}
@@ -1104,7 +1153,7 @@ class Qwen2VLPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
@@ -1162,14 +1211,15 @@ class Qwen2VLPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        fps_per_video = mm_inputs.pop("fps_per_video", [])
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]

        return mm_inputs

@@ -1183,45 +1233,45 @@ class VideoLlavaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        num_frames = 0
        has_images = "pixel_values_images" in mm_inputs
        has_videos = "pixel_values_videos" in mm_inputs
        if has_images or has_videos:
            if self.expand_mm_tokens:
                if has_images:
                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                    num_frames = 1
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values_images" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
                num_frames = 1

                if has_videos:
                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                    height, width = get_image_size(pixel_values_video[0])
                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
            if "pixel_values_videos" in mm_inputs:
                one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
                height, width = get_image_size(one_video[0])
                num_frames = one_video.shape[0]  # frame dim is always after batch dim

                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
            if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
                image_seqlen = (height // processor.patch_size) * (
                    width // processor.patch_size
                ) + processor.num_additional_image_tokens
                video_seqlen = image_seqlen * num_frames
                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1
            else:
                image_seqlen, video_seqlen = 1, 1
        else:
            image_seqlen, video_seqlen = 1, 1

            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                    num_image_tokens += 1
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

                while VIDEO_PLACEHOLDER in content:
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                    num_video_tokens += 1
            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                num_video_tokens += 1

                content = content.replace("{{image}}", self.image_token)
                message["content"] = content.replace("{{video}}", self.video_token)
            content = content.replace("{{image}}", self.image_token)
            message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -1231,24 +1281,10 @@ class VideoLlavaPlugin(BasePlugin):

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


PLUGINS = {
    "base": BasePlugin,
    "gemma3": Gemma3Plugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,