support qwen2vl vllm infer
Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a

@@ -62,6 +62,7 @@ class BasePlugin:
     def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
         self.image_token = image_token
         self.video_token = video_token
+        self.expand_mm_tokens = True
 
     def _validate_input(
         self,
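
Note: the new `expand_mm_tokens` flag decides whether a multimodal placeholder is expanded into its full per-patch token sequence during preprocessing. When serving with vLLM, the engine performs this expansion itself, so the plugin must leave a single placeholder token in place. A minimal sketch of the intended effect (the helper function and token string are illustrative, not part of this diff):

    # Illustrative sketch, not plugin API: how expand_mm_tokens changes the prompt.
    def render_image_placeholder(image_token: str, image_seqlen: int, expand_mm_tokens: bool) -> str:
        # Training path (HF): one token per visual patch feature.
        # Inference path (vLLM): a single token; the engine expands it.
        return image_token * (image_seqlen if expand_mm_tokens else 1)

    assert render_image_placeholder("<image>", 4, expand_mm_tokens=True) == "<image><image><image><image>"
    assert render_image_placeholder("<image>", 4, expand_mm_tokens=False) == "<image>"
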
@@ -259,7 +260,7 @@ class LlavaPlugin(BasePlugin):
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
-        image_seqlen = getattr(processor, "image_seqlen")
+        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
@@ -310,11 +311,13 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_size = next(image_sizes)
-                orig_height, orig_width = image_size
-                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
-                    image_seqlen -= 1
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1
 
                 num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
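
Note: `image_seqlen` here comes from the processor's `_get_number_of_features`, which counts per-image visual features; with the "default" `vision_feature_select_strategy` one token is subtracted, presumably because the CLS feature is dropped before projection. A toy sketch of the per-placeholder accounting (numbers and semantics are assumptions, not taken from this diff):

    # Toy sketch, assuming the "default" strategy drops one (CLS) feature.
    def effective_seqlen(num_features: int, strategy: str, expand_mm_tokens: bool) -> int:
        if not expand_mm_tokens:
            return 1  # leave a single placeholder for the inference engine
        return num_features - 1 if strategy == "default" else num_features

    assert effective_seqlen(576, "default", expand_mm_tokens=True) == 575
    assert effective_seqlen(576, "full", expand_mm_tokens=True) == 576
    assert effective_seqlen(576, "default", expand_mm_tokens=False) == 1
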
@@ -359,11 +362,13 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_size = next(image_sizes)
-                orig_height, orig_width = image_size
-                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
-                    image_seqlen -= 1
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1
 
                 num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -376,6 +381,7 @@ class LlavaNextVideoPlugin(BasePlugin):
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
             video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            video_seqlen = video_seqlen if self.expand_mm_tokens else 1
             for message in messages:
                 content = message["content"]
                 while VIDEO_PLACEHOLDER in content:
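
Note: the `// 4` reflects the 2x2 average pooling LLaVA-NeXT-Video applies to video features (per the existing comment), so each frame contributes a quarter of the per-image patch count. With illustrative numbers:

    # Illustrative numbers: 336x336 input, 14-pixel patches, 8 sampled frames.
    height = width = 336
    patch_size = 14
    num_frames = 8
    image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
    video_seqlen = image_seqlen // 4 * num_frames  # 144 tokens per frame after 2x2 pooling
    assert video_seqlen == 1152
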
@@ -443,7 +449,7 @@ class PaliGemmaPlugin(BasePlugin):
     ) -> Tuple[List[int], Optional[List[int]]]:
         self._validate_input(images, videos)
         num_images = len(images)
-        image_seqlen = num_images * getattr(processor, "image_seqlen")
+        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
         input_ids = [image_token_id] * image_seqlen + input_ids
         if labels is not None:
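
Note: PaliGemma prepends the image tokens to `input_ids` rather than replacing an inline placeholder, which is why disabling expansion uses 0 here instead of 1. The `if labels is not None:` context suggests the labels are padded in step so the image prefix is not supervised; a sketch of that pattern (the token ids and the exact label handling are illustrative assumptions):

    # Sketch of PaliGemma-style prefixing; ids are made up.
    IGNORE_INDEX = -100
    image_token_id = 99  # stand-in for tokenizer.convert_tokens_to_ids(image_token)
    image_seqlen = 3     # would be 0 when expand_mm_tokens is False
    input_ids = [10, 11, 12]
    labels = [10, 11, 12]
    input_ids = [image_token_id] * image_seqlen + input_ids
    labels = [IGNORE_INDEX] * image_seqlen + labels  # image prefix is not trained on
    assert input_ids == [99, 99, 99, 10, 11, 12]
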
@@ -493,14 +499,18 @@ class PixtralPlugin(BasePlugin):
                 if image_input_sizes is None:
                     raise ValueError("Cannot get image input sizes.")
 
-                image_size = image_input_sizes[0][num_image_tokens]
-                height, width = image_size
-                num_height_tokens = height // patch_size
-                num_width_tokens = width // patch_size
-                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
-                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
-                replace_tokens[-1] = image_end_token
-                replace_str = "".join(replace_tokens)
+                if self.expand_mm_tokens:
+                    image_size = image_input_sizes[0][num_image_tokens]
+                    height, width = image_size
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
+                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                    replace_tokens[-1] = image_end_token
+                    replace_str = "".join(replace_tokens)
+                else:
+                    replace_str = image_token
+
                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                 num_image_tokens += 1
 
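
Note: Pixtral lays image tokens out row by row, closing each row with a break token and the final row with an end token. A standalone sketch of the replacement string for a 2x3 patch grid (the token strings are stand-ins; the real ones come from the processor):

    # Stand-in token strings for a 2-row, 3-column patch grid.
    image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"
    num_height_tokens, num_width_tokens = 2, 3
    rows = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
    flat = [tok for row in rows for tok in row]  # flatten row-major
    flat[-1] = image_end_token  # the last break becomes the end marker
    assert "".join(flat) == "[IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]"
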
@@ -549,10 +559,27 @@ class Qwen2vlPlugin(BasePlugin):
         return image
 
     @override
-    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
-        sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
-        sample_frames = sample_frames // 2 * 2
-        return sample_frames
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        results = []
+        for video in videos:
+            container = av.open(video, "r")
+            video_stream = next(stream for stream in container.streams if stream.type == "video")
+            total_frames = video_stream.frames
+            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
+            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            frames: List["ImageObject"] = []
+            container.seek(0)
+            for frame_idx, frame in enumerate(container.decode(video_stream)):
+                if frame_idx in sample_indices:
+                    frames.append(frame.to_image())
+
+            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
+                frames.append(frames[-1])
+
+            frames = self._regularize_images(frames, **kwargs)
+            results.append(frames)
+
+        return results
 
     @override
     def process_messages(
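
Note: the new `_regularize_videos` samples evenly spaced frame indices with `np.linspace`, then pads the decoded frames to an even count (replacing the old approach of rounding `sample_frames` down to an even number before decoding), since qwen2-vl requires an even number of frames. A sketch of just the index math, with made-up frame counts:

    import numpy as np

    # 10-frame video, 5 samples requested: indices spread over the full range.
    total_frames, sample_frames = 10, 5
    sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
    frames = [f"frame_{i}" for i in sample_indices]  # stand-ins for decoded PIL images
    if len(frames) % 2 != 0:  # qwen2-vl requires an even number of frames
        frames.append(frames[-1])  # duplicate the last frame
    assert list(sample_indices) == [0, 2, 4, 6, 9] and len(frames) == 6
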
@@ -577,12 +604,9 @@ class Qwen2vlPlugin(BasePlugin):
                 if num_image_tokens >= len(image_grid_thw):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
                 )
                 num_image_tokens += 1
 
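
Note: `image_grid_thw[i]` holds the (t, h, w) patch grid for image i, and `merge_length` is the square of the processor's spatial merge size, since the vision encoder merges each merge_size x merge_size patch window into one LLM token. A worked example with assumed values:

    # Assumed values: merge_size = 2 and a 1 x 16 x 16 patch grid.
    merge_size = 2
    merge_length = merge_size ** 2
    t, h, w = 1, 16, 16  # one entry of image_grid_thw
    image_seqlen = (t * h * w) // merge_length
    assert image_seqlen == 64
    # The placeholder then expands to <|vision_start|> + image_token * 64 + <|vision_end|>.
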
@@ -590,12 +614,9 @@ class Qwen2vlPlugin(BasePlugin):
                 if num_video_tokens >= len(video_grid_thw):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
+                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
                 )
                 num_video_tokens += 1
 
@@ -640,19 +661,22 @@ class VideoLlavaPlugin(BasePlugin):
         has_images = "pixel_values_images" in mm_inputs
         has_videos = "pixel_values_videos" in mm_inputs
         if has_images or has_videos:
-            if has_images:
-                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
-                num_frames = 1
+            if self.expand_mm_tokens:
+                if has_images:
+                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                    num_frames = 1
 
-            if has_videos:
-                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(pixel_values_video[0])
-                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+                if has_videos:
+                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                    height, width = get_image_size(pixel_values_video[0])
+                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
 
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
-            video_seqlen = image_seqlen * num_frames
-            if getattr(processor, "vision_feature_select_strategy") == "default":
-                image_seqlen -= 1
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+                video_seqlen = image_seqlen * num_frames
+                if getattr(processor, "vision_feature_select_strategy") == "default":
+                    image_seqlen -= 1
+            else:
+                image_seqlen, video_seqlen = 1, 1
 
             for message in messages:
                 content = message["content"]
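
Note: for VideoLlava the per-image length is the patch grid plus one extra position (presumably CLS), videos multiply that by the frame count, and the "default" feature-select strategy then subtracts one token for images only; the order matters, since `video_seqlen` is computed before the subtraction. With illustrative numbers:

    # Illustrative numbers: 224x224 frames, 14-pixel patches, 8 frames.
    height = width = 224
    patch_size = 14
    num_frames = 8
    image_seqlen = (height // patch_size) * (width // patch_size) + 1  # 257, incl. CLS
    video_seqlen = image_seqlen * num_frames  # 2056, computed before the subtraction
    vision_feature_select_strategy = "default"
    if vision_feature_select_strategy == "default":
        image_seqlen -= 1  # images drop one feature under the "default" strategy
    assert (image_seqlen, video_seqlen) == (256, 2056)
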