[data] add min resolution option (#6975)

Former-commit-id: 76bd9a98a2fb00f1a1d881e6e1364c02fd36d327
This commit is contained in:
hoshi-hiyouga
2025-02-18 01:40:46 +08:00
committed by GitHub
parent f2fd9d1b25
commit c09b648934
9 changed files with 59 additions and 24 deletions

View File

@@ -104,12 +104,19 @@ class MMPluginMixin:
"This model does not support audio input. Please check whether the correct `template` is used."
)
def _preprocess_image(self, image: "ImageObject", image_resolution: int, **kwargs) -> "ImageObject":
def _preprocess_image(
self, image: "ImageObject", image_max_resolution: int, image_min_resolution: int, **kwargs
) -> "ImageObject":
r"""
Pre-processes a single image.
"""
if (image.width * image.height) > image_resolution:
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
if (image.width * image.height) > image_max_resolution:
resize_factor = math.sqrt(image_max_resolution / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
if (image.width * image.height) < image_min_resolution:
resize_factor = math.sqrt(image_min_resolution / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
@@ -217,14 +224,17 @@ class MMPluginMixin:
if len(images) != 0:
images = self._regularize_images(
images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
images,
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 256 * 256),
image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256),
image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
@@ -606,7 +616,8 @@ class MiniCPMVPlugin(BasePlugin):
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 768 * 768),
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
@@ -626,7 +637,8 @@ class MiniCPMVPlugin(BasePlugin):
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 256 * 256),
image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256),
image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
@@ -774,7 +786,11 @@ class MllamaPlugin(BasePlugin):
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768))
images = self._regularize_images(
images,
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
@@ -1065,14 +1081,17 @@ class Qwen2vlPlugin(BasePlugin):
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
images,
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos, fps_per_video = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 256 * 256),
image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256),
image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)