add some
Former-commit-id: 81176fe226da89eace89cb202bad68e73b7c2a02
This commit is contained in:
@@ -2,6 +2,7 @@ import math
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -249,6 +250,130 @@ class BasePlugin:
|
||||
return {}
|
||||
|
||||
|
||||
class CpmOPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
|
||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||
|
||||
if num_image_tokens>0:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
|
||||
pattern = "(<image>./</image>)"
|
||||
images, image_sizes, tgt_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
|
||||
|
||||
input_ids_list = []
|
||||
image_bounds_list = []
|
||||
image_index = 0
|
||||
for index, message in enumerate(messages):
|
||||
text = message['content']
|
||||
image_tags = re.findall(pattern, text)
|
||||
text_chunks = text.split(pattern)
|
||||
final_text = ""
|
||||
for i in range(len(image_tags)):
|
||||
final_text = final_text + text_chunks[i] + \
|
||||
image_processor.get_slice_image_placeholder(
|
||||
image_sizes[image_index][i],
|
||||
i,
|
||||
image_processor.max_slice_nums,
|
||||
image_processor.use_image_id,
|
||||
)
|
||||
image_index += 1
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]['content'] = final_text
|
||||
# print(messages)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
|
||||
mm_inputs = {}
|
||||
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
||||
)
|
||||
image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt")
|
||||
mm_inputs.update(image_inputs)
|
||||
|
||||
if len(videos) != 0:
|
||||
videos = self._regularize_videos(
|
||||
videos,
|
||||
image_resolution=getattr(processor, "video_resolution", 128 * 128),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||
)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
image_bounds_list = []
|
||||
position_ids = []
|
||||
for input_ids in batch_ids:
|
||||
input_ids_ = torch.tensor(input_ids)
|
||||
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id)
|
||||
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
||||
image_start_tokens = torch.where(start_cond)[0]
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(end_cond)[0]
|
||||
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
||||
image_bounds = torch.hstack(
|
||||
[
|
||||
image_start_tokens[:valid_image_nums].unsqueeze(-1),
|
||||
image_end_tokens[:valid_image_nums].unsqueeze(-1),
|
||||
]
|
||||
)
|
||||
image_bounds_list.append(image_bounds)
|
||||
position_ids_ = list(range(input_ids_.size(0)))
|
||||
# print(input_ids_.shape, len(position_ids_)
|
||||
position_ids.append(position_ids_)
|
||||
position_ids = torch.tensor(position_ids, dtype=torch.int64)
|
||||
mm_inputs.update({
|
||||
"image_bound": image_bounds_list,
|
||||
"position_ids": position_ids,
|
||||
})
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
@@ -790,6 +915,7 @@ class MllamaPlugin(BasePlugin):
|
||||
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"cpm_o": CpmOPlugin,
|
||||
"llava": LlavaPlugin,
|
||||
"llava_next": LlavaNextPlugin,
|
||||
"llava_next_video": LlavaNextVideoPlugin,
|
||||
|
||||
Reference in New Issue
Block a user