support qwen2vl vllm infer
Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a
This commit is contained in:
@@ -19,7 +19,7 @@ from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||
from ..model import load_config, load_tokenizer
|
||||
@@ -67,6 +67,7 @@ class VllmEngine(BaseEngine):
|
||||
self.processor = tokenizer_module["processor"]
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
||||
self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
|
||||
self.generating_args = generating_args.to_dict()
|
||||
|
||||
engine_args = {
|
||||
@@ -83,6 +84,9 @@ class VllmEngine(BaseEngine):
|
||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||
"max_lora_rank": model_args.vllm_max_lora_rank,
|
||||
}
|
||||
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
|
||||
|
||||
if isinstance(model_args.vllm_config, dict):
|
||||
engine_args.update(model_args.vllm_config)
|
||||
|
||||
@@ -108,19 +112,21 @@ class VllmEngine(BaseEngine):
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||
if images is not None:
|
||||
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||
|
||||
if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
|
||||
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
|
||||
else:
|
||||
image_str = self.template.mm_plugin.image_token or ""
|
||||
if videos is not None:
|
||||
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
|
||||
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
|
||||
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
|
||||
|
||||
paired_messages = [
|
||||
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
|
||||
for message in messages
|
||||
] + [{"role": "assistant", "content": ""}]
|
||||
messages = self.template.mm_plugin.process_messages(
|
||||
messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
|
||||
)
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
||||
prompt_length = len(prompt_ids)
|
||||
@@ -168,7 +174,7 @@ class VllmEngine(BaseEngine):
|
||||
)
|
||||
|
||||
if images is not None: # add image features
|
||||
image_data = []
|
||||
multi_modal_data = {"image": []}
|
||||
for image in images:
|
||||
if not isinstance(image, (str, ImageObject)):
|
||||
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||
@@ -176,9 +182,7 @@ class VllmEngine(BaseEngine):
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
image_data.append(image)
|
||||
|
||||
multi_modal_data = {"image": image_data}
|
||||
multi_modal_data["image"].append(image)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ class BasePlugin:
|
||||
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
|
||||
self.image_token = image_token
|
||||
self.video_token = video_token
|
||||
self.expand_mm_tokens = True
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
@@ -259,7 +260,7 @@ class LlavaPlugin(BasePlugin):
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
image_seqlen = getattr(processor, "image_seqlen")
|
||||
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
@@ -310,11 +311,13 @@ class LlavaNextPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_size = next(image_sizes)
|
||||
orig_height, orig_width = image_size
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||
image_seqlen -= 1
|
||||
if self.expand_mm_tokens:
|
||||
orig_height, orig_width = next(image_sizes)
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||
image_seqlen -= 1
|
||||
else:
|
||||
image_seqlen = 1
|
||||
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
@@ -359,11 +362,13 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_size = next(image_sizes)
|
||||
orig_height, orig_width = image_size
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||
image_seqlen -= 1
|
||||
if self.expand_mm_tokens:
|
||||
orig_height, orig_width = next(image_sizes)
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||
image_seqlen -= 1
|
||||
else:
|
||||
image_seqlen = 1
|
||||
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
@@ -376,6 +381,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
||||
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
@@ -443,7 +449,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
self._validate_input(images, videos)
|
||||
num_images = len(images)
|
||||
image_seqlen = num_images * getattr(processor, "image_seqlen")
|
||||
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
input_ids = [image_token_id] * image_seqlen + input_ids
|
||||
if labels is not None:
|
||||
@@ -493,14 +499,18 @@ class PixtralPlugin(BasePlugin):
|
||||
if image_input_sizes is None:
|
||||
raise ValueError("Cannot get image input sizes.")
|
||||
|
||||
image_size = image_input_sizes[0][num_image_tokens]
|
||||
height, width = image_size
|
||||
num_height_tokens = height // patch_size
|
||||
num_width_tokens = width // patch_size
|
||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
||||
replace_tokens[-1] = image_end_token
|
||||
replace_str = "".join(replace_tokens)
|
||||
if self.expand_mm_tokens:
|
||||
image_size = image_input_sizes[0][num_image_tokens]
|
||||
height, width = image_size
|
||||
num_height_tokens = height // patch_size
|
||||
num_width_tokens = width // patch_size
|
||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
||||
replace_tokens[-1] = image_end_token
|
||||
replace_str = "".join(replace_tokens)
|
||||
else:
|
||||
replace_str = image_token
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
@@ -549,10 +559,27 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
return image
|
||||
|
||||
@override
|
||||
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
|
||||
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
|
||||
sample_frames = sample_frames // 2 * 2
|
||||
return sample_frames
|
||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
|
||||
results = []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
|
||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
frames: List["ImageObject"] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
|
||||
frames.append(frames[-1])
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)
|
||||
results.append(frames)
|
||||
|
||||
return results
|
||||
|
||||
@override
|
||||
def process_messages(
|
||||
@@ -577,12 +604,9 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
if num_image_tokens >= len(image_grid_thw):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
"<|vision_start|>{}<|vision_end|>".format(
|
||||
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
|
||||
),
|
||||
1,
|
||||
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
@@ -590,12 +614,9 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
if num_video_tokens >= len(video_grid_thw):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER,
|
||||
"<|vision_start|>{}<|vision_end|>".format(
|
||||
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
|
||||
),
|
||||
1,
|
||||
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
|
||||
)
|
||||
num_video_tokens += 1
|
||||
|
||||
@@ -640,19 +661,22 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
has_images = "pixel_values_images" in mm_inputs
|
||||
has_videos = "pixel_values_videos" in mm_inputs
|
||||
if has_images or has_videos:
|
||||
if has_images:
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
|
||||
num_frames = 1
|
||||
if self.expand_mm_tokens:
|
||||
if has_images:
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
|
||||
num_frames = 1
|
||||
|
||||
if has_videos:
|
||||
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||
height, width = get_image_size(pixel_values_video[0])
|
||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||
if has_videos:
|
||||
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||
height, width = get_image_size(pixel_values_video[0])
|
||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||
|
||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
|
||||
video_seqlen = image_seqlen * num_frames
|
||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||
image_seqlen -= 1
|
||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
|
||||
video_seqlen = image_seqlen * num_frames
|
||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||
image_seqlen -= 1
|
||||
else:
|
||||
image_seqlen, video_seqlen = 1, 1
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
|
||||
Reference in New Issue
Block a user