modify style & little change

Former-commit-id: c988477d14dc656450d5fec31895781b7f9f7dce
This commit is contained in:
KUANGDD
2024-10-23 15:24:07 +08:00
parent 7d135bbdb8
commit d0889012c2
7 changed files with 45 additions and 25 deletions

View File

@@ -4,6 +4,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
@@ -447,6 +448,7 @@ class PaliGemmaPlugin(BasePlugin):
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
class PixtralPlugin(BasePlugin):
@override
def process_messages(
@@ -466,32 +468,28 @@ class PixtralPlugin(BasePlugin):
img_kwargs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = None
if img_kwargs.get("pixel_values") is not None:
image_input_sizes = img_kwargs["image_sizes"]
image_input_sizes = img_kwargs.get("image_sizes", None)
messages = deepcopy(messages)
for message in messages:
content = message["content"]
img_id = 0
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(
"The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
)
image_size = image_input_sizes[0][img_id]
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [
[image_token] * num_width_tokens + [image_break_token]
] * num_height_tokens
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
# Flatten list
replace_tokens = [item for sublist in replace_tokens for item in sublist]
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
img_id += 1
num_image_tokens += 1
message["content"] = content
@@ -514,14 +512,13 @@ class PixtralPlugin(BasePlugin):
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
# hack for hf engine
if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1:
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0)
if mm_inputs.get("image_sizes"):
del mm_inputs["image_sizes"]
if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
mm_inputs.pop("image_sizes", None)
return mm_inputs
class Qwen2vlPlugin(BasePlugin):
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@@ -698,9 +695,10 @@ def get_mm_plugin(
plugin_class = PLUGINS.get(name, None)
if plugin_class == "PixtralPlugin":
from transformers.utils.versions import require_version
try:
require_version("transformers==4.46.0.dev0")
except Exception as e:
except Exception:
raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name))