support qwen2vl vllm infer
Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a
@@ -19,7 +19,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -67,6 +67,7 @@ class VllmEngine(BaseEngine):
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
         self.generating_args = generating_args.to_dict()

         engine_args = {
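
A note on the expand_mm_tokens flag added above: when generating through vLLM, the template's multimodal plugin should leave each image or video placeholder as a single marker, because vLLM expands it internally from the supplied pixel data, whereas the HF generation path expands it in the template itself. The toy sketch below (a hypothetical helper with illustrative values, not the project's plugin code) only demonstrates that difference.

    # Toy illustration of placeholder expansion (hypothetical helper, not project code).
    def render_prompt(text: str, image_seq_len: int, expand_mm_tokens: bool) -> str:
        token = "<image>"  # assumed placeholder value
        return text.replace(token, token * image_seq_len) if expand_mm_tokens else text

    print(render_prompt("<image>What is shown?", image_seq_len=4, expand_mm_tokens=True))   # HF path
    print(render_prompt("<image>What is shown?", image_seq_len=4, expand_mm_tokens=False))  # vLLM path
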
@@ -83,6 +84,9 @@ class VllmEngine(BaseEngine):
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
+            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+
         if isinstance(model_args.vllm_config, dict):
             engine_args.update(model_args.vllm_config)

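
For orientation, here is a minimal sketch of how an engine_args dict like the one assembled above is typically turned into a running vLLM engine. The checkpoint name and flag values are illustrative and not taken from this commit; limit_mm_per_prompt is the vLLM engine argument that caps the number of images and videos accepted per request.

    # Minimal sketch, assuming vLLM's AsyncEngineArgs / AsyncLLMEngine API.
    from vllm import AsyncEngineArgs, AsyncLLMEngine

    engine_args = {
        "model": "Qwen/Qwen2-VL-7B-Instruct",  # example multimodal checkpoint
        "trust_remote_code": True,
        "enable_lora": False,
        "limit_mm_per_prompt": {"image": 4, "video": 2},  # mirrors the cap added above
    }
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
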
@@ -108,19 +112,21 @@ class VllmEngine(BaseEngine):
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
+        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
         if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
-            image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
-        else:
-            image_str = self.template.mm_plugin.image_token or ""
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

-        paired_messages = [
-            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
-            for message in messages
-        ] + [{"role": "assistant", "content": ""}]
+        messages = self.template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
+        )
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
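
To make the new prompt path concrete, the standalone snippet below mimics the rule applied above: if the incoming messages contain no placeholder for a supplied modality, one placeholder per item is prepended to the first message before the mm_plugin processes the conversation and the template encodes it. This is not the project's API; "<image>" and "<video>" are assumed placeholder values.

    # Standalone illustration of the placeholder-prepending rule (assumed placeholder strings).
    IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER = "<image>", "<video>"

    def prepend_placeholders(messages, images=None, videos=None):
        images, videos = images or [], videos or []
        if images and not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
        if videos and not any(VIDEO_PLACEHOLDER in m["content"] for m in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
        return messages

    print(prepend_placeholders([{"role": "user", "content": "Describe this."}], images=["cat.png"]))
    # [{'role': 'user', 'content': '<image>Describe this.'}]
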
@@ -168,7 +174,7 @@ class VllmEngine(BaseEngine):
         )

         if images is not None:  # add image features
-            image_data = []
+            multi_modal_data = {"image": []}
             for image in images:
                 if not isinstance(image, (str, ImageObject)):
                     raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
@@ -176,9 +182,7 @@ class VllmEngine(BaseEngine):
                 if isinstance(image, str):
                     image = Image.open(image).convert("RGB")

-                image_data.append(image)
-
-            multi_modal_data = {"image": image_data}
+                multi_modal_data["image"].append(image)
         else:
             multi_modal_data = None

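
The multi_modal_data dict built above (PIL images under the "image" key, or None when no media is supplied) is what vLLM consumes next to the already-tokenized prompt. The sketch below shows, assuming vLLM's standard AsyncLLMEngine.generate interface, how such a request is typically streamed; the engine object, prompt ids, and sampling values are illustrative rather than copied from this file.

    # Hedged sketch: streaming a multimodal request through vLLM (illustrative values).
    from vllm import SamplingParams

    async def stream_chat(engine, prompt_ids, multi_modal_data, request_id):
        sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
        async for request_output in engine.generate(
            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
        ):
            yield request_output  # each RequestOutput holds the tokens generated so far
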