[infer] vllm video/audio inference (#7566)
This commit is contained in:
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.misc import is_env_enabled
|
||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||
from .common import dictify, jsonify
|
||||
@@ -56,7 +56,7 @@ if is_requests_available():
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..chat import ChatModel
|
||||
from ..data.mm_plugin import ImageInput
|
||||
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
@@ -72,7 +72,14 @@ ROLE_MAPPING = {
|
||||
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
|
||||
) -> tuple[
|
||||
list[dict[str, str]],
|
||||
Optional[str],
|
||||
Optional[str],
|
||||
Optional[list["ImageInput"]],
|
||||
Optional[list["VideoInput"]],
|
||||
Optional[list["AudioInput"]],
|
||||
]:
|
||||
if is_env_enabled("API_VERBOSE", "1"):
|
||||
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
|
||||
|
||||
@@ -88,7 +95,7 @@ def _process_request(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
images = []
|
||||
images, videos, audios = [], [], []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
@@ -107,7 +114,7 @@ def _process_request(
|
||||
for input_item in message.content:
|
||||
if input_item.type == "text":
|
||||
text_content += input_item.text
|
||||
else:
|
||||
elif input_item.type == "image_url":
|
||||
text_content += IMAGE_PLACEHOLDER
|
||||
image_url = input_item.image_url.url
|
||||
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
|
||||
@@ -118,6 +125,28 @@ def _process_request(
|
||||
image_stream = requests.get(image_url, stream=True).raw
|
||||
|
||||
images.append(Image.open(image_stream).convert("RGB"))
|
||||
elif input_item.type == "video_url":
|
||||
text_content += VIDEO_PLACEHOLDER
|
||||
video_url = input_item.video_url.url
|
||||
if os.path.isfile(video_url): # local file
|
||||
video_stream = open(video_url, "rb")
|
||||
else: # web uri
|
||||
video_stream = requests.get(video_url, stream=True).raw
|
||||
|
||||
videos.append(video_stream)
|
||||
elif input_item.type == "audio_url":
|
||||
text_content += AUDIO_PLACEHOLDER
|
||||
audio_url = input_item.audio_url.url
|
||||
if os.path.isfile(audio_url): # local file
|
||||
audio_stream = open(audio_url, "rb")
|
||||
else: # web uri
|
||||
audio_stream = requests.get(audio_url, stream=True).raw
|
||||
|
||||
audios.append(audio_stream)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
|
||||
)
|
||||
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
|
||||
else:
|
||||
@@ -132,7 +161,7 @@ def _process_request(
|
||||
else:
|
||||
tools = None
|
||||
|
||||
return input_messages, system, tools, images or None
|
||||
return input_messages, system, tools, images or None, videos or None, audios or None
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
@@ -151,12 +180,14 @@ async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, images = _process_request(request)
|
||||
input_messages, system, tools, images, videos, audios = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
images,
|
||||
videos,
|
||||
audios,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
@@ -202,7 +233,7 @@ async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, images = _process_request(request)
|
||||
input_messages, system, tools, images, videos, audios = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
@@ -217,6 +248,8 @@ async def create_stream_chat_completion_response(
|
||||
system,
|
||||
tools,
|
||||
images,
|
||||
videos,
|
||||
audios,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
|
||||
@@ -70,14 +70,17 @@ class FunctionCall(BaseModel):
|
||||
function: Function
|
||||
|
||||
|
||||
class ImageURL(BaseModel):
|
||||
class URL(BaseModel):
|
||||
url: str
|
||||
detail: Literal["auto", "low", "high"] = "auto"
|
||||
|
||||
|
||||
class MultimodalInputItem(BaseModel):
|
||||
type: Literal["text", "image_url"]
|
||||
type: Literal["text", "image_url", "video_url", "audio_url"]
|
||||
text: Optional[str] = None
|
||||
image_url: Optional[ImageURL] = None
|
||||
image_url: Optional[URL] = None
|
||||
video_url: Optional[URL] = None
|
||||
audio_url: Optional[URL] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user