add pixtral template

Former-commit-id: c7b4e47e0fda955272ccd6340b2047fd92acbfcf
This commit is contained in:
Kingsley
2024-09-26 17:14:51 +08:00
parent c4a585f232
commit 9390927875
2 changed files with 60 additions and 41 deletions

View File

@@ -24,6 +24,7 @@ if TYPE_CHECKING:
from av.stream import Stream
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import _validate_images_text_input_order, ProcessingKwargs
class EncodedImage(TypedDict):
path: Optional[str]
@@ -324,11 +325,65 @@ class PaliGemmaPlugin(BasePlugin):
return mm_inputs
class PixtralPlugin(BasePlugin):
#TODO preprocess according to Pixtral hf
from transformers import LlavaForConditionalGeneration
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
pass
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
patch_size = processor.patch_size
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
self._validate_input(images, videos)
num_image_tokens = 0
image_input_sizes = self._get_mm_inputs(images, videos, processor)["image_sizes"]
messages = deepcopy(messages)
print(image_input_sizes[0], messages)
for message in messages:
content = message["content"]
img_id = 0
while IMAGE_PLACEHOLDER in content:
# only support one image for one time?
image_size = image_input_sizes[0][0]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [
[image_token] * num_width_tokens + [image_break_token]
] * num_height_tokens
# Flatten list
replace_tokens = [item for sublist in replace_tokens for item in sublist]
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
img_id += 1
num_image_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class Qwen2vlPlugin(BasePlugin):
@override
@@ -428,6 +483,7 @@ PLUGINS = {
"llava": LlavaPlugin,
"paligemma": PaliGemmaPlugin,
"qwen2_vl": Qwen2vlPlugin,
"pixtral": PixtralPlugin,
}