support image input in api #3971 #4061

Former-commit-id: c70aaf763ef22fb83ce3635e8ffd5ec4c89c1cb0
Author: hiyouga
Date:   2024-06-06 02:29:55 +08:00
Parent: 35379c7c0e
Commit: 639a7f6796

4 changed files with 49 additions and 8 deletions


@@ -1,10 +1,11 @@
 import json
+import os
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
-from ..extras.packages import is_fastapi_available
+from ..extras.packages import is_fastapi_available, is_pillow_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -25,7 +26,14 @@ if is_fastapi_available():
     from fastapi import HTTPException, status
 
 
+if is_pillow_available():
+    import requests
+    from PIL import Image
+
+
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from ..chat import ChatModel
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
 
@@ -40,7 +48,9 @@ ROLE_MAPPING = {
 }
 
 
-def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
+def _process_request(
+    request: "ChatCompletionRequest",
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
 
     if len(request.messages) == 0:
@@ -49,12 +59,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
     if request.messages[0].role == Role.SYSTEM:
         system = request.messages.pop(0).content
     else:
-        system = ""
+        system = None
 
     if len(request.messages) % 2 == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
     input_messages = []
+    image = None
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -66,6 +77,18 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
             arguments = message.tool_calls[0].function.arguments
             content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+        elif isinstance(message.content, list):
+            for input_item in message.content:
+                if input_item.type == "text":
+                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                else:
+                    image_url = input_item.image_url.url
+                    if os.path.isfile(image_url):
+                        image_path = open(image_url, "rb")
+                    else:
+                        image_path = requests.get(image_url, stream=True).raw
+
+                    image = Image.open(image_path).convert("RGB")
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
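For reference, a minimal sketch of a request body the new `elif isinstance(message.content, list)` branch can parse, written as a Python dict; the model name and image URL are placeholders, not values from this commit. Note that `image` is a single variable, so if a request carries several image parts, only the last one parsed is kept.

# Illustrative only: an OpenAI-style multimodal user message.
# The loop above reads `type`, `text`, and `image_url.url` from each content part.
payload = {
    "model": "test-model",  # placeholder
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this picture."},
                # Either an http(s) URL or a file path readable by the server.
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
}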
@@ -76,9 +99,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
         except Exception:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
-        tools = ""
+        tools = None
 
-    return input_messages, system, tools
+    return input_messages, system, tools, image
 
 
 def _create_stream_chat_completion_chunk(
@@ -97,11 +120,12 @@ async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = _process_request(request)
+    input_messages, system, tools, image = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
+        image,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
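The updated call sites suggest that `ChatModel.achat` (and `astream_chat` below) now take the image positionally between `tools` and the sampling kwargs. A hedged sketch of that implied signature; the real definition lives in the `..chat` module and is not part of this diff:

# Assumed shape, inferred from the call above; not shown in this commit.
async def achat(
    self,
    messages: List[Dict[str, str]],
    system: Optional[str] = None,
    tools: Optional[str] = None,
    image: Optional["NDArray"] = None,
    **input_kwargs,
) -> List["Response"]:
    ...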
@@ -145,7 +169,7 @@ async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = _process_request(request)
+    input_messages, system, tools, image = _process_request(request)
 
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -159,6 +183,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
+        image,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
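Putting it together, a minimal client sketch against the OpenAI-compatible chat endpoint; the host, port, route, and model name are assumptions about a typical deployment, not values taken from this commit:

import requests

# Assumes the API server runs locally and exposes /v1/chat/completions.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "test-model",  # placeholder
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is shown in this image?"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/dog.jpg"}},
                ],
            }
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])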