[infer] vllm video/audio inference (#7566)

Author: hoshi-hiyouga
Date: 2025-04-02 02:27:04 +08:00
Committed by: GitHub
Parent: 2bfcad2394
Commit: 5e22597ff1
10 changed files with 329 additions and 285 deletions

View File

@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Optional
 from ..data import Role as DataRole
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import is_env_enabled
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
@@ -56,7 +56,7 @@ if is_requests_available():
 if TYPE_CHECKING:
     from ..chat import ChatModel
-    from ..data.mm_plugin import ImageInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -72,7 +72,14 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
+) -> tuple[
+    list[dict[str, str]],
+    Optional[str],
+    Optional[str],
+    Optional[list["ImageInput"]],
+    Optional[list["VideoInput"]],
+    Optional[list["AudioInput"]],
+]:
     if is_env_enabled("API_VERBOSE", "1"):
         logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
@@ -88,7 +95,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
images = []
images, videos, audios = [], [], []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -107,7 +114,7 @@ def _process_request(
             for input_item in message.content:
                 if input_item.type == "text":
                     text_content += input_item.text
-                else:
+                elif input_item.type == "image_url":
                     text_content += IMAGE_PLACEHOLDER
                     image_url = input_item.image_url.url
                     if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
@@ -118,6 +125,28 @@ def _process_request(
                         image_stream = requests.get(image_url, stream=True).raw

                     images.append(Image.open(image_stream).convert("RGB"))
+                elif input_item.type == "video_url":
+                    text_content += VIDEO_PLACEHOLDER
+                    video_url = input_item.video_url.url
+                    if os.path.isfile(video_url):  # local file
+                        video_stream = open(video_url, "rb")
+                    else:  # web uri
+                        video_stream = requests.get(video_url, stream=True).raw
+
+                    videos.append(video_stream)
+                elif input_item.type == "audio_url":
+                    text_content += AUDIO_PLACEHOLDER
+                    audio_url = input_item.audio_url.url
+                    if os.path.isfile(audio_url):  # local file
+                        audio_stream = open(audio_url, "rb")
+                    else:  # web uri
+                        audio_stream = requests.get(audio_url, stream=True).raw
+
+                    audios.append(audio_stream)
+                else:
+                    raise HTTPException(
+                        status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
+                    )

             input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
         else:
@@ -132,7 +161,7 @@ def _process_request(
     else:
         tools = None

-    return input_messages, system, tools, images or None
+    return input_messages, system, tools, images or None, videos or None, audios or None


 def _create_stream_chat_completion_chunk(
@@ -151,12 +180,14 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request)
input_messages, system, tools, images, videos, audios = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
images,
videos,
audios,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
@@ -202,7 +233,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request)
input_messages, system, tools, images, videos, audios = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -217,6 +248,8 @@ async def create_stream_chat_completion_response(
         system,
         tools,
         images,
+        videos,
+        audios,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
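
The new branches in _process_request accept both local paths and web URIs for video_url and audio_url items, mirroring the existing image handling. Below is a minimal client-side sketch of a request that exercises those branches, assuming the API server exposes the usual OpenAI-style /v1/chat/completions route; the port, model name, and media locations are placeholders rather than values taken from this commit.

# Hypothetical client call against a locally served OpenAI-compatible endpoint.
# Endpoint URL, model name, and media paths below are illustrative placeholders.
import requests

payload = {
    "model": "my-omni-model",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe what happens in this clip."},
                # Local path: handled by the os.path.isfile() branch in _process_request.
                {"type": "video_url", "video_url": {"url": "/data/samples/clip.mp4"}},
                # Web URI: fetched with requests.get(..., stream=True).
                {"type": "audio_url", "audio_url": {"url": "https://example.com/voice.wav"}},
            ],
        }
    ],
}

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
print(response.json()["choices"][0]["message"]["content"])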

View File

@@ -70,14 +70,17 @@ class FunctionCall(BaseModel):
     function: Function


-class ImageURL(BaseModel):
+class URL(BaseModel):
     url: str
     detail: Literal["auto", "low", "high"] = "auto"


 class MultimodalInputItem(BaseModel):
-    type: Literal["text", "image_url"]
+    type: Literal["text", "image_url", "video_url", "audio_url"]
     text: Optional[str] = None
-    image_url: Optional[ImageURL] = None
+    image_url: Optional[URL] = None
+    video_url: Optional[URL] = None
+    audio_url: Optional[URL] = None


 class ChatMessage(BaseModel):
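
For reference, the updated request schema can be exercised on its own. This sketch simply restates the two models from the hunk above and assumes pydantic is installed; the example values are illustrative.

# Standalone sketch of the updated protocol models from this commit.
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class URL(BaseModel):
    url: str
    detail: Literal["auto", "low", "high"] = "auto"


class MultimodalInputItem(BaseModel):
    type: Literal["text", "image_url", "video_url", "audio_url"]
    text: Optional[str] = None
    image_url: Optional[URL] = None
    video_url: Optional[URL] = None
    audio_url: Optional[URL] = None


# A video item now validates the same way an image item does.
item = MultimodalInputItem(type="video_url", video_url=URL(url="/tmp/clip.mp4"))
print(item.video_url.url)

# Content types outside the Literal set are still rejected.
try:
    MultimodalInputItem(type="point_cloud")
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["loc"])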