[data] fix internvl plugin (#7817)
This commit is contained in:
@@ -21,7 +21,7 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -86,20 +86,6 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
|
||||
r"""Concatenate a list of lists, numpy arrays or torch tensors.
|
||||
|
||||
Returns:
|
||||
a list of numpy arrays or torch tensors.
|
||||
"""
|
||||
if isinstance(input_list[0], list):
|
||||
return [item for sublist in input_list for item in sublist]
|
||||
elif isinstance(input_list[0], np.ndarray):
|
||||
return np.concatenate(input_list, axis=0)
|
||||
elif isinstance(input_list[0], torch.Tensor):
|
||||
return torch.cat(input_list, dim=0)
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
|
||||
r"""Get paligemma token type ids for computing loss.
|
||||
|
||||
@@ -496,8 +482,15 @@ class InternVLPlugin(BasePlugin):
|
||||
**kwargs,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
attributes = ["crop_to_patches", "min_patches", "max_patches"] # need for image processor
|
||||
image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
|
||||
image_processor_kwargs = {}
|
||||
if getattr(processor, "crop_to_patches", False):
|
||||
image_processor_kwargs.update(
|
||||
{
|
||||
"crop_to_patches": True,
|
||||
"max_patches": 12,
|
||||
"min_patches": 1,
|
||||
}
|
||||
)
|
||||
|
||||
mm_inputs = {}
|
||||
image_video_patches = []
|
||||
@@ -520,7 +513,7 @@ class InternVLPlugin(BasePlugin):
|
||||
|
||||
if len(images) != 0:
|
||||
images = make_flat_list_of_images(images)
|
||||
image_inputs = image_processor(images=images, **image_kwargs)
|
||||
image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
|
||||
image_num_patches = image_inputs.pop("num_patches")
|
||||
image_pixel_values = image_inputs.pop("pixel_values")
|
||||
image_num_patches_indices = np.cumsum(image_num_patches)
|
||||
@@ -529,8 +522,8 @@ class InternVLPlugin(BasePlugin):
|
||||
videos = make_batched_videos(videos)
|
||||
num_frames_per_video = [len(video) for video in videos]
|
||||
patch_indices = np.cumsum(num_frames_per_video)
|
||||
image_kwargs["crop_to_patches"] = False
|
||||
video_inputs = image_processor(images=videos, **image_kwargs)
|
||||
image_processor_kwargs["crop_to_patches"] = False
|
||||
video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
|
||||
video_num_patches = video_inputs.pop("num_patches")
|
||||
video_pixel_values = video_inputs.pop("pixel_values")
|
||||
video_num_patches_indices = np.cumsum(video_num_patches)
|
||||
@@ -543,18 +536,16 @@ class InternVLPlugin(BasePlugin):
|
||||
image_video_patches.append(image_pixel_values[start_index:end_index])
|
||||
|
||||
if len(videos) != 0 and video_pixel_values is not None:
|
||||
patch_indices_with_prefix = [0] + list(patch_indices)
|
||||
for i in range(len(videos)):
|
||||
current_patch_index = patch_indices[i - 1] if i > 0 else 0
|
||||
end_patch_index = patch_indices[i]
|
||||
start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
|
||||
current_patch_index = patch_indices_with_prefix[i]
|
||||
end_patch_index = patch_indices_with_prefix[i + 1]
|
||||
start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
|
||||
end_index = video_num_patches_indices[end_patch_index - 1]
|
||||
image_video_patches.append(video_pixel_values[start_index:end_index])
|
||||
|
||||
if len(images) != 0 or len(videos) != 0:
|
||||
pixel_values_list = _concatenate_list(image_video_patches)
|
||||
# in the latest version of transformers,
|
||||
# the pixel_values is a list of tensors not ndarray
|
||||
mm_inputs["pixel_values"] = torch.stack(pixel_values_list)
|
||||
mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
|
||||
|
||||
if len(images) != 0:
|
||||
mm_inputs.update({"image_num_patches": image_num_patches})
|
||||
|
||||
Reference in New Issue
Block a user