Modify code style and make minor changes
Former-commit-id: c988477d14dc656450d5fec31895781b7f9f7dce
This commit is contained in:
@@ -4,6 +4,7 @@ from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.image_utils import get_image_size, to_numpy_array
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -447,6 +448,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class PixtralPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
@@ -466,32 +468,28 @@ class PixtralPlugin(BasePlugin):
|
||||
img_kwargs = self._get_mm_inputs(images, videos, processor)
|
||||
image_input_sizes = None
|
||||
|
||||
if img_kwargs.get("pixel_values") is not None:
|
||||
image_input_sizes = img_kwargs["image_sizes"]
|
||||
image_input_sizes = img_kwargs.get("image_sizes", None)
|
||||
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
img_id = 0
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
|
||||
if image_input_sizes is None:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(
|
||||
"The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
|
||||
)
|
||||
|
||||
image_size = image_input_sizes[0][img_id]
|
||||
image_size = image_input_sizes[0][num_image_tokens]
|
||||
height, width = image_size
|
||||
num_height_tokens = height // patch_size
|
||||
num_width_tokens = width // patch_size
|
||||
replace_tokens = [
|
||||
[image_token] * num_width_tokens + [image_break_token]
|
||||
] * num_height_tokens
|
||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||
# Flatten list
|
||||
replace_tokens = [item for sublist in replace_tokens for item in sublist]
|
||||
replace_tokens[-1] = image_end_token
|
||||
replace_str = "".join(replace_tokens)
|
||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||
|
||||
img_id += 1
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
@@ -514,14 +512,13 @@ class PixtralPlugin(BasePlugin):
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
# hack for hf engine
|
||||
if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1:
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0)
|
||||
|
||||
if mm_inputs.get("image_sizes"):
|
||||
del mm_inputs["image_sizes"]
|
||||
if mm_inputs.get("pixel_values"):
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
|
||||
|
||||
mm_inputs.pop("image_sizes", None)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
@override
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
@@ -698,9 +695,10 @@ def get_mm_plugin(
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class == "PixtralPlugin":
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
try:
|
||||
require_version("transformers==4.46.0.dev0")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
|
||||
if plugin_class is None:
|
||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||
|
||||
Reference in New Issue
Block a user