[model] support intern-VL 2.5-3 series (#7258)
* add internvl and rebase * fix for internvl2&3 * remove lines * fix video_inputs & lint * nit * add constants * remove lines * fix * fix error * pass ci * pass ci * skip internvl & nit
This commit is contained in:
@@ -25,7 +25,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.image_utils import get_image_size, to_numpy_array
|
||||
from transformers.image_utils import (
|
||||
get_image_size,
|
||||
make_batched_videos,
|
||||
make_flat_list_of_images,
|
||||
to_numpy_array,
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
@@ -82,6 +87,20 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _concatenate_list(input_list):
|
||||
r"""Concatenate a list of lists, numpy arrays or torch tensors.
|
||||
|
||||
Returns:
|
||||
a list of numpy arrays or torch tensors.
|
||||
"""
|
||||
if isinstance(input_list[0], list):
|
||||
return [item for sublist in input_list for item in sublist]
|
||||
elif isinstance(input_list[0], np.ndarray):
|
||||
return np.concatenate(input_list, axis=0)
|
||||
elif isinstance(input_list[0], torch.Tensor):
|
||||
return torch.cat(input_list, dim=0)
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
|
||||
r"""Get paligemma token type ids for computing loss.
|
||||
|
||||
@@ -467,6 +486,163 @@ class Gemma3Plugin(BasePlugin):
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternVLPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
image_pixel_patch_list = mm_inputs.get("image_num_patches", None) # pathes of images
|
||||
video_num_patches = mm_inputs.get("video_num_patches", None) # all patches for frames of videos
|
||||
video_patch_indices = mm_inputs.get("video_patch_indices", None) # num frames of per video
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(image_pixel_patch_list):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
|
||||
1,
|
||||
)
|
||||
num_image_tokens += 1
|
||||
message["content"] = content
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(video_patch_indices):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
|
||||
end_patch_index = video_patch_indices[num_video_tokens]
|
||||
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
|
||||
video_replaced_prompt = "\n".join(
|
||||
f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
|
||||
for i in range(len(num_patches))
|
||||
)
|
||||
content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
|
||||
num_video_tokens += 1
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
**kwargs,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
attributes = ["crop_to_patches", "min_patches", "max_patches"] # need for image processor
|
||||
image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
|
||||
|
||||
mm_inputs = {}
|
||||
image_video_patches = []
|
||||
|
||||
if len(images) != 0 and isinstance(images[0], str):
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)["images"]
|
||||
|
||||
if len(videos) != 0 and isinstance(videos[0], str):
|
||||
videos = self._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)["videos"]
|
||||
|
||||
if len(images) != 0:
|
||||
images = make_flat_list_of_images(images)
|
||||
image_inputs = image_processor(images=images, **image_kwargs)
|
||||
image_num_patches = image_inputs.pop("num_patches")
|
||||
image_pixel_values = image_inputs.pop("pixel_values")
|
||||
image_num_patches_indices = np.cumsum(image_num_patches)
|
||||
|
||||
if len(videos) != 0:
|
||||
videos = make_batched_videos(videos)
|
||||
num_frames_per_video = [len(video) for video in videos]
|
||||
patch_indices = np.cumsum(num_frames_per_video)
|
||||
image_kwargs["crop_to_patches"] = False
|
||||
video_inputs = image_processor(images=videos, **image_kwargs)
|
||||
video_num_patches = video_inputs.pop("num_patches")
|
||||
video_pixel_values = video_inputs.pop("pixel_values")
|
||||
video_num_patches_indices = np.cumsum(video_num_patches)
|
||||
|
||||
# NOT SUPPORT IMAGE VIDEO INTERLEAVED
|
||||
if len(images) != 0 and image_pixel_values is not None:
|
||||
for i in range(len(images)):
|
||||
start_index = image_num_patches_indices[i - 1] if i > 0 else 0
|
||||
end_index = image_num_patches_indices[i]
|
||||
image_video_patches.append(image_pixel_values[start_index:end_index])
|
||||
|
||||
if len(videos) != 0 and video_pixel_values is not None:
|
||||
for i in range(len(videos)):
|
||||
current_patch_index = patch_indices[i - 1] if i > 0 else 0
|
||||
end_patch_index = patch_indices[i]
|
||||
start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
|
||||
end_index = video_num_patches_indices[end_patch_index - 1]
|
||||
image_video_patches.append(video_pixel_values[start_index:end_index])
|
||||
|
||||
if len(images) != 0 or len(videos) != 0:
|
||||
pixel_values_list = _concatenate_list(image_video_patches)
|
||||
mm_inputs["pixel_values"] = torch.stack(
|
||||
[torch.tensor(patch_ndarray) for patch_ndarray in pixel_values_list]
|
||||
)
|
||||
|
||||
if len(images) != 0:
|
||||
mm_inputs.update({"image_num_patches": image_num_patches})
|
||||
|
||||
if len(videos) != 0:
|
||||
mm_inputs.update({"video_patch_indices": patch_indices})
|
||||
mm_inputs.update({"video_num_patches": video_num_patches})
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs.pop("image_num_patches", None)
|
||||
mm_inputs.pop("video_patch_indices", None)
|
||||
mm_inputs.pop("video_num_patches", None)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class KimiVLPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(self, messages, images, videos, audios, processor):
|
||||
@@ -1603,6 +1779,7 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"intern_vl": InternVLPlugin,
|
||||
"kimi_vl": KimiVLPlugin,
|
||||
"llama4": Llama4Plugin,
|
||||
"llava": LlavaPlugin,
|
||||
|
||||
Reference in New Issue
Block a user