fix some
Former-commit-id: aeca8c0f978cb9754e0526b40cd431aaf867044f
@@ -275,28 +275,23 @@ class LlavaNextPlugin(BasePlugin):
         self._validate_input(images, videos)
         num_image_tokens = 0
         messages = deepcopy(messages)
-        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
-            for message in messages:
-                content = message["content"]
-                while self.image_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-        else:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
-            if "image_sizes" in mm_inputs:
-                image_sizes = iter(mm_inputs["image_sizes"])
-            if "pixel_values" in mm_inputs:
-                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
-            for message in messages:
-                content = message["content"]
-                while self.image_token in content:
-                    image_size = next(image_sizes)
-                    orig_height, orig_width = image_size
-                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if processor.vision_feature_select_strategy == "default":
-                        image_seqlen -= 1
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
-                message['content'] = content.replace("{{image}}", self.image_token)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "image_sizes" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+        if "pixel_values" in mm_inputs:
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+        for message in messages:
+            content = message["content"]
+            while self.image_token in content:
+                image_size = next(image_sizes)
+                orig_height, orig_width = image_size
+                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                if processor.vision_feature_select_strategy == "default":
+                    image_seqlen -= 1
+                num_image_tokens += 1
+                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+            message['content'] = content.replace("{{image}}", self.image_token)
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
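Note: the code above expands each image marker into image_seqlen copies of a "{{image}}" placeholder and only afterwards converts the placeholders back into real image tokens, so the while loop never re-matches text it has already expanded. A standalone sketch of that two-step replacement, not part of the commit (the token string and lengths here are hypothetical):

    def expand_image_tokens(content, image_token, seqlens):
        # Step 1: replace each marker with seqlen copies of a placeholder.
        it = iter(seqlens)
        while image_token in content:
            content = content.replace(image_token, "{{image}}" * next(it), 1)
        # Step 2: turn the placeholders back into real image tokens.
        return content.replace("{{image}}", image_token)

    assert expand_image_tokens("a <image> b", "<image>", [3]) == "a <image><image><image> b"
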
@@ -316,6 +311,7 @@ class LlavaNextPlugin(BasePlugin):
         res = self._get_mm_inputs(images, videos, processor)
         return res
 
 
 class LlavaNextVideoPlugin(BasePlugin):
+    @override
     def process_messages(
@@ -329,47 +325,37 @@ class LlavaNextVideoPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
-            if "pixel_values" in mm_inputs:
-                image_sizes = iter(mm_inputs["image_sizes"])
-                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
-                for message in messages:
-                    content = message["content"]
-
-                    while self.image_token in content:
-                        image_size = next(image_sizes)
-                        orig_height, orig_width = image_size
-                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                        if processor.vision_feature_select_strategy == "default":
-                            image_seqlen -= 1
-                        num_image_tokens += 1
-                        content = content.replace(self.image_token, "{{image}}", 1)
-
-                    message['content'] = content.replace("{{image}}", self.image_token)
-
-            if "pixel_values_videos" in mm_inputs:
-                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(one_video[0])
-                num_frames = one_video.shape[0]  # frame dim is always after batch dim
-                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
-                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
-
-                for message in messages:
-                    content = message["content"]
-                    while self.video_token in content:
-                        num_video_tokens += 1
-                        content = content.replace(self.video_token, "{{video}}", 1)
-                    message['content'] = content.replace("{{video}}", self.video_token * video_seqlen)
-        else:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+            for message in messages:
+                content = message["content"]
+
+                while self.image_token in content:
+                    image_size = next(image_sizes)
+                    orig_height, orig_width = image_size
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if processor.vision_feature_select_strategy == "default":
+                        image_seqlen -= 1
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+                message['content'] = content.replace("{{image}}", self.image_token)
+
+        if "pixel_values_videos" in mm_inputs:
+            one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+            height, width = get_image_size(one_video[0])
+            num_frames = one_video.shape[0]  # frame dim is always after batch dim
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+
+            for message in messages:
+                content = message["content"]
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+                message['content'] = content.replace("{{video}}", self.video_token * video_seqlen)
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
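For scale, the video arithmetic in the hunk above works out as follows; the sizes are hypothetical, and the // 4 reflects the average-pooling layer noted in the code comment:

    patch_size, height, width, num_frames = 14, 336, 336, 8
    image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
    video_seqlen = image_seqlen // 4 * num_frames  # 144 tokens per frame * 8 = 1152
    assert (image_seqlen, video_seqlen) == (576, 1152)
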
@@ -380,36 +366,38 @@ class LlavaNextVideoPlugin(BasePlugin):
         return messages
 
     @override
-    def get_mm_inputs(
+    def _get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         video_processor = getattr(processor, "video_processor")
-        res = self._get_mm_inputs(images, [], processor)
+        res = super()._get_mm_inputs(images, [], processor)
         if len(videos) != 0:
-            videos = self._regularize_videos(videos)
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "image_resolution", 168),
+                video_fps=getattr(processor, "video_fps", 1.0),
+                video_maxlen=getattr(processor, "video_maxlen", 16),
+            )
             video_res = video_processor(videos, return_tensors="pt")
             res.update(video_res)
         return res
 
     @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        r"""
-        Regularizes videos to avoid error. Including reading, resizing and converting.
-        """
-        videos = super()._regularize_videos(
-            videos,
-            image_resolution=168,
-            video_fps=1.0,
-            video_maxlen=16,
-        )
-        return videos
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)
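The refactor above splits the work template-method style: the public get_mm_inputs keeps the full signature, validates, and delegates, while _get_mm_inputs builds the tensors and reads video settings off the processor with getattr fallbacks instead of a separate _regularize_videos override. A minimal sketch of the pattern, with simplified names that are not the plugin's real API:

    class Base:
        def _get_mm_inputs(self, images, videos, processor):
            return {"pixel_values": list(images)}

    class VideoAware(Base):
        def _get_mm_inputs(self, images, videos, processor):
            res = super()._get_mm_inputs(images, [], processor)  # images only
            if videos:
                # per-processor setting with a default, as in the hunk above
                maxlen = getattr(processor, "video_maxlen", 16)
                res["pixel_values_videos"] = [v[:maxlen] for v in videos]
            return res

        def get_mm_inputs(self, images, videos, imglens, vidlens, seqlens, processor):
            return self._get_mm_inputs(images, videos, processor)
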
 
 
 class PaliGemmaPlugin(BasePlugin):
@@ -579,7 +567,22 @@ class VideoLlavaPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        num_frames = 0
+        exist_images = "pixel_values_images" in mm_inputs
+        exist_videos = "pixel_values_videos" in mm_inputs
+        if exist_videos or exist_images:
+            if exist_images:
+                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                num_frames = 1
+            if exist_videos:
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+            video_seqlen = image_seqlen * num_frames
+            if processor.vision_feature_select_strategy == "default":
+                image_seqlen -= 1
             for message in messages:
                 content = message["content"]
                 while self.image_token in content:
@@ -588,39 +591,15 @@ class VideoLlavaPlugin(BasePlugin):
                 while self.video_token in content:
                     num_video_tokens += 1
                     content = content.replace(self.video_token, "{{video}}", 1)
-        else:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
-            if "pixel_values_images" in mm_inputs.keys():
-                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
-                num_frames = 1
-
-            if "pixel_values_videos" in mm_inputs.keys():
-                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(one_video[0])
-                num_frames = one_video.shape[0]  # frame dim is always after batch dim
-
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
-            video_seqlen = num_image_tokens * num_frames
-            if processor.vision_feature_select_strategy == "default":
-                image_seqlen -= 1
-
-            for message in messages:
-                content = message["content"]
-                while self.image_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-                while self.video_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
 
-                message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+                content = content.replace("{{image}}", self.image_token * image_seqlen)
                 message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
 
         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
 
         if len(videos) != num_video_tokens:
-            raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
 
         return messages
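For scale, the Video-LLaVA token counts above work out as follows; the sizes are hypothetical, and note the hunk computes video_seqlen before the "default" feature-select strategy drops the CLS token from single images:

    patch_size, height, width, num_frames = 14, 224, 224, 8
    image_seqlen = (height // patch_size) * (width // patch_size) + 1  # 16 * 16 + 1 = 257
    video_seqlen = image_seqlen * num_frames  # 257 * 8 = 2056
    image_seqlen -= 1  # vision_feature_select_strategy == "default"
    assert (image_seqlen, video_seqlen) == (256, 2056)
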
@@ -637,19 +616,6 @@ class VideoLlavaPlugin(BasePlugin):
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)
 
-    @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        r"""
-        Regularizes videos to avoid error. Including reading, resizing and converting.
-        """
-        videos = super()._regularize_videos(
-            videos,
-            image_resolution=224,
-            video_fps=1.0,
-            video_maxlen=8,
-        )
-        return videos
-
 
 PLUGINS = {
     "base": BasePlugin,