[misc] update format (#7277)

This commit is contained in:
hoshi-hiyouga
2025-03-13 02:53:08 +08:00
committed by GitHub
parent 4b9d8da5a4
commit 650a9a9057
62 changed files with 384 additions and 288 deletions

View File

@@ -18,7 +18,6 @@
import inspect
import math
import re
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
@@ -83,9 +82,7 @@ if TYPE_CHECKING:
pass
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
) -> list[list[int]]:
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
r"""Get paligemma token type ids for computing loss.
It is slightly different with the original token type ids where the prompt part is 0.
@@ -120,7 +117,7 @@ def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcess
return batch_token_type_ids
def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
r"""Make nested list of images."""
batch_images = []
for imglen in imglens:
@@ -140,9 +137,9 @@ class MMPluginMixin:
def _validate_input(
self,
processor: Optional["MMProcessor"],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> None:
r"""Validate if this model accepts the input modalities."""
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
@@ -202,7 +199,7 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]:
r"""Regularize images to avoid error. Including reading and pre-processing."""
results = []
for image in images:
@@ -223,7 +220,7 @@ class MMPluginMixin:
return results
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
results = []
for video in videos:
@@ -241,7 +238,7 @@ class MMPluginMixin:
return results
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
r"""Regularizes audios to avoid error. Including reading and resampling."""
results = []
for audio in audios:
@@ -257,9 +254,9 @@ class MMPluginMixin:
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
imglens: Optional[list[int]] = None,
) -> dict[str, "torch.Tensor"]:
@@ -335,10 +332,10 @@ class MMPluginMixin:
class BasePlugin(MMPluginMixin):
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
r"""Pre-process input messages before tokenization for VLMs."""
@@ -349,9 +346,9 @@ class BasePlugin(MMPluginMixin):
self,
input_ids: list[int],
labels: Optional[list[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
@@ -361,13 +358,13 @@ class BasePlugin(MMPluginMixin):
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
r"""Build batched multimodal inputs for VLMs.
@@ -392,10 +389,10 @@ class Gemma3Plugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -420,13 +417,13 @@ class Gemma3Plugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
@@ -441,10 +438,10 @@ class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -481,10 +478,10 @@ class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -523,10 +520,10 @@ class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -586,10 +583,10 @@ class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -686,9 +683,9 @@ class MiniCPMVPlugin(BasePlugin):
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
**kwargs,
) -> dict[str, "torch.Tensor"]:
@@ -757,13 +754,13 @@ class MiniCPMVPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
@@ -828,10 +825,10 @@ class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -850,13 +847,13 @@ class MllamaPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
@@ -885,10 +882,10 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -912,9 +909,9 @@ class PaliGemmaPlugin(BasePlugin):
self,
input_ids: list[int],
labels: Optional[list[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
@@ -931,13 +928,13 @@ class PaliGemmaPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
@@ -952,10 +949,10 @@ class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -995,13 +992,13 @@ class PixtralPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
@@ -1015,10 +1012,10 @@ class Qwen2AudioPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -1056,13 +1053,13 @@ class Qwen2AudioPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
@@ -1090,7 +1087,7 @@ class Qwen2VLPlugin(BasePlugin):
@override
def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs
self, videos: list["VideoInput"], **kwargs
) -> tuple[list[list["ImageObject"]], list[float]]:
results, fps_per_video = [], []
for video in videos:
@@ -1118,9 +1115,9 @@ class Qwen2VLPlugin(BasePlugin):
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
@@ -1149,10 +1146,10 @@ class Qwen2VLPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
@@ -1204,13 +1201,13 @@ class Qwen2VLPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
@@ -1229,10 +1226,10 @@ class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)