support qwen2vl vllm infer

Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a
This commit is contained in:
hiyouga
2024-12-05 10:17:26 +00:00
parent 1fef702382
commit bbd432415d
4 changed files with 123 additions and 67 deletions

View File

@@ -62,6 +62,7 @@ class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
self.expand_mm_tokens = True
def _validate_input(
self,
@@ -259,7 +260,7 @@ class LlavaPlugin(BasePlugin):
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen")
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
for message in messages:
content = message["content"]
@@ -310,11 +311,13 @@ class LlavaNextPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -359,11 +362,13 @@ class LlavaNextVideoPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -376,6 +381,7 @@ class LlavaNextVideoPlugin(BasePlugin):
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
@@ -443,7 +449,7 @@ class PaliGemmaPlugin(BasePlugin):
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
@@ -493,14 +499,18 @@ class PixtralPlugin(BasePlugin):
if image_input_sizes is None:
raise ValueError("Cannot get image input sizes.")
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
if self.expand_mm_tokens:
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
else:
replace_str = image_token
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
@@ -549,10 +559,27 @@ class Qwen2vlPlugin(BasePlugin):
return image
@override
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
sample_frames = sample_frames // 2 * 2
return sample_frames
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
frames.append(frames[-1])
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
return results
@override
def process_messages(
@@ -577,12 +604,9 @@ class Qwen2vlPlugin(BasePlugin):
if num_image_tokens >= len(image_grid_thw):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
),
1,
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
)
num_image_tokens += 1
@@ -590,12 +614,9 @@ class Qwen2vlPlugin(BasePlugin):
if num_video_tokens >= len(video_grid_thw):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
),
1,
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
)
num_video_tokens += 1
@@ -640,19 +661,22 @@ class VideoLlavaPlugin(BasePlugin):
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if self.expand_mm_tokens:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1
for message in messages:
content = message["content"]