video datasets
Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
This commit is contained in:
@@ -2,11 +2,10 @@ from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from transformers import ProcessorMixin
|
||||
import numpy as np
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
@@ -14,8 +13,13 @@ if is_pillow_available():
|
||||
from PIL.Image import Image as ImageObject
|
||||
|
||||
|
||||
if is_pyav_available():
|
||||
import av
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
@@ -24,13 +28,14 @@ if TYPE_CHECKING:
|
||||
bytes: Optional[bytes]
|
||||
|
||||
ImageInput = Union[str, EncodedImage, ImageObject]
|
||||
VideoInput = str
|
||||
|
||||
|
||||
def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
|
||||
r"""
|
||||
Regularizes images to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
image_resolution = getattr(processor, "image_resolution", 512)
|
||||
image_resolution: int = getattr(processor, "image_resolution", 512)
|
||||
results = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
@@ -56,7 +61,37 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
|
||||
return results
|
||||
|
||||
|
||||
def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
|
||||
def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]:
|
||||
r"""
|
||||
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
video_fps: float = getattr(processor, "video_fps", 1.0)
|
||||
video_factor: int = getattr(processor, "video_factor", 1)
|
||||
results = []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
|
||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
frames: List["ImageObject"] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
frames = _regularize_images(frames, processor)
|
||||
results.append(frames)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_mm_inputs(
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes visual inputs.
|
||||
|
||||
@@ -70,13 +105,19 @@ def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin")
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
input_dict = {"images": None, "videos": None}
|
||||
if len(images) != 0:
|
||||
images = _regularize_images(images, processor)
|
||||
image_inputs = image_processor(images=images, return_tensors="pt")
|
||||
else:
|
||||
image_inputs = {}
|
||||
input_dict["images"] = images
|
||||
|
||||
return image_inputs
|
||||
if len(videos) != 0:
|
||||
videos = _regularize_videos(videos, processor)
|
||||
input_dict["videos"] = videos
|
||||
|
||||
if input_dict["images"] is not None or input_dict["videos"] is not None:
|
||||
return image_processor(**input_dict, return_tensors="pt")
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(
|
||||
@@ -97,18 +138,32 @@ def _get_paligemma_token_type_ids(
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
def __init__(self, image_token: str) -> None:
|
||||
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
|
||||
self.image_token = image_token
|
||||
self.video_token = video_token
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
) -> None:
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError("This model does not support image input.")
|
||||
|
||||
if len(videos) != 0 and self.video_token is None:
|
||||
raise ValueError("This model does not support video input.")
|
||||
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Pre-processes input messages before tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
@@ -116,24 +171,29 @@ class BasePlugin:
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
r"""
|
||||
Pre-processes token ids after tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
return input_ids, labels
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
r"""
|
||||
Builds batched multimodal inputs for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
return {}
|
||||
|
||||
|
||||
@@ -142,8 +202,10 @@ class LlavaPlugin(BasePlugin):
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
image_seqlen = getattr(processor, "image_seqlen")
|
||||
messages = deepcopy(messages)
|
||||
@@ -163,11 +225,14 @@ class LlavaPlugin(BasePlugin):
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
return _get_mm_inputs(images, processor)
|
||||
self._validate_input(images, videos)
|
||||
return _get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
class PaliGemmaPlugin(BasePlugin):
|
||||
@@ -175,8 +240,10 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@@ -197,9 +264,11 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
self._validate_input(images, videos)
|
||||
num_images = len(images)
|
||||
image_seqlen = num_images * getattr(processor, "image_seqlen")
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
@@ -212,11 +281,14 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
mm_inputs = _get_mm_inputs(images, processor)
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = _get_mm_inputs(images, videos, processor)
|
||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||
return mm_inputs
|
||||
|
||||
@@ -226,16 +298,17 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if len(images) != 0:
|
||||
image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
|
||||
else:
|
||||
image_grid_thw = []
|
||||
mm_inputs = _get_mm_inputs(images, videos, processor)
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
|
||||
num_image_tokens = 0
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
@@ -252,21 +325,40 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(video_grid_thw):
|
||||
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
|
||||
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER,
|
||||
"<|vision_start|>{}<|vision_end|>".format(
|
||||
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
|
||||
),
|
||||
1,
|
||||
)
|
||||
num_video_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
|
||||
|
||||
return messages
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
return _get_mm_inputs(images, processor)
|
||||
self._validate_input(images, videos)
|
||||
return _get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
PLUGINS = {
|
||||
@@ -277,9 +369,13 @@ PLUGINS = {
|
||||
}
|
||||
|
||||
|
||||
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
|
||||
def get_mm_plugin(
|
||||
name: str,
|
||||
image_token: Optional[str] = None,
|
||||
video_token: Optional[str] = None,
|
||||
) -> "BasePlugin":
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class is None:
|
||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||
|
||||
return plugin_class(image_token)
|
||||
return plugin_class(image_token, video_token)
|
||||
|
||||
Reference in New Issue
Block a user