add pixtral template
Former-commit-id: c7b4e47e0fda955272ccd6340b2047fd92acbfcf
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
    from av.stream import Stream
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor
    from transformers.processing_utils import _validate_images_text_input_order, ProcessingKwargs


class EncodedImage(TypedDict):
    path: Optional[str]
@@ -324,11 +325,65 @@ class PaliGemmaPlugin(BasePlugin):
        return mm_inputs


class PixtralPlugin(BasePlugin):
    # TODO: preprocess images the way the Pixtral integration in Hugging Face transformers does
    # (see LlavaForConditionalGeneration and the Pixtral processor).
    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        # Stub for now: fall back to the default BasePlugin behavior instead of returning None.
        return super()._preprocess_image(image, **kwargs)

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        patch_size = processor.patch_size
        image_token = processor.image_token
        image_break_token = processor.image_break_token
        image_end_token = processor.image_end_token

        self._validate_input(images, videos)
        num_image_tokens = 0
        image_input_sizes = self._get_mm_inputs(images, videos, processor)["image_sizes"]
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            img_id = 0
            while IMAGE_PLACEHOLDER in content:
                # image_sizes holds the (height, width) of each processed image for this sample.
                image_size = image_input_sizes[0][img_id]
                height, width = image_size
                num_height_tokens = height // patch_size
                num_width_tokens = width // patch_size
                # One entry per patch row: image tokens followed by a row-break token,
                # with the very last token of the grid replaced by the image end token.
                replace_tokens = [
                    [image_token] * num_width_tokens + [image_break_token]
                ] * num_height_tokens
                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten
                replace_tokens[-1] = image_end_token
                replace_str = "".join(replace_tokens)
                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)

                img_id += 1
                num_image_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)

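For reference, a minimal standalone sketch of the token-grid expansion that PixtralPlugin.process_messages performs; the patch size and the [IMG], [IMG_BREAK], [IMG_END] literals below are illustrative assumptions standing in for the values the plugin reads from the Pixtral processor:

# Illustrative sketch only; patch_size and the special-token strings are assumptions here,
# the plugin takes the real values from the processor at runtime.
patch_size = 16
image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"

height, width = 64, 48                      # a hypothetical processed image size
num_height_tokens = height // patch_size    # 4 rows of patches
num_width_tokens = width // patch_size      # 3 patches per row

rows = [[image_token] * num_width_tokens + [image_break_token] for _ in range(num_height_tokens)]
tokens = [tok for row in rows for tok in row]
tokens[-1] = image_end_token                # the final row break becomes the end-of-image marker
replace_str = "".join(tokens)
# replace_str == ("[IMG]" * 3 + "[IMG_BREAK]") * 3 + "[IMG]" * 3 + "[IMG_END]"
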
class Qwen2vlPlugin(BasePlugin):
    @override
@@ -428,6 +483,7 @@ PLUGINS = {
    "llava": LlavaPlugin,
    "paligemma": PaliGemmaPlugin,
    "qwen2_vl": Qwen2vlPlugin,
    "pixtral": PixtralPlugin,
}
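
A rough usage sketch for the new registry entry; the get_mm_plugin helper, its arguments, the "<image>" placeholder, and the checkpoint id are assumptions based on how the other plugins in this module are wired up, not part of this diff:

# Hypothetical wiring, for illustration only.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")  # assumed checkpoint id
plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")  # helper assumed to exist in this module
messages = [{"role": "user", "content": "<image>Describe this picture."}]
expanded = plugin.process_messages(
    messages, images=[Image.open("example.jpg")], videos=[], processor=processor
)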