@@ -1,10 +1,11 @@
 import json
+import os
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
-from ..extras.packages import is_fastapi_available
+from ..extras.packages import is_fastapi_available, is_pillow_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -25,7 +26,14 @@ if is_fastapi_available():
     from fastapi import HTTPException, status
 
 
+if is_pillow_available():
+    import requests
+    from PIL import Image
+
+
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from ..chat import ChatModel
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
 
@@ -40,7 +48,9 @@ ROLE_MAPPING = {
 }
 
 
-def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
+def _process_request(
+    request: "ChatCompletionRequest",
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
 
     if len(request.messages) == 0:
@@ -49,12 +59,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
     if request.messages[0].role == Role.SYSTEM:
         system = request.messages.pop(0).content
     else:
-        system = ""
+        system = None
 
     if len(request.messages) % 2 == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
     input_messages = []
+    image = None
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -66,6 +77,18 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
             arguments = message.tool_calls[0].function.arguments
             content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+        elif isinstance(message.content, list):
+            for input_item in message.content:
+                if input_item.type == "text":
+                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                else:
+                    image_url = input_item.image_url.url
+                    if os.path.isfile(image_url):
+                        image_path = open(image_url, "rb")
+                    else:
+                        image_path = requests.get(image_url, stream=True).raw
+
+                    image = Image.open(image_path).convert("RGB")
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
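The list-form message.content handled by the new branch above follows the OpenAI multimodal message shape. A hypothetical user message that would exercise both the text and image_url paths (the URL and wording are illustrative placeholders, not part of the patch):

    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }

Each image part overwrites the single image variable, so only the last image in a message is kept, while every text part becomes its own entry in input_messages.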
@@ -76,9 +99,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
         except Exception:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
-        tools = ""
+        tools = None
 
-    return input_messages, system, tools
+    return input_messages, system, tools, image
 
 
 def _create_stream_chat_completion_chunk(
@@ -97,11 +120,12 @@ async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = _process_request(request)
+    input_messages, system, tools, image = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
+        image,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -145,7 +169,7 @@ async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = _process_request(request)
+    input_messages, system, tools, image = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 
@@ -159,6 +183,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
+        image,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
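With both achat() call sites now receiving the extracted image, the OpenAI-compatible endpoint can accept image inputs end to end. A minimal client sketch using the openai Python SDK, assuming the API server is already running locally (the base URL, API key, and model name are placeholders, not values defined by this patch):

    from openai import OpenAI

    # Point the official OpenAI client at the locally served, OpenAI-compatible API.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="0")
    result = client.chat.completions.create(
        model="test",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe the image."},
                    {"type": "image_url", "image_url": {"url": "/path/to/local/image.png"}},
                ],
            }
        ],
    )
    print(result.choices[0].message.content)

Because _process_request falls back to requests.get(image_url, stream=True).raw when the URL is not a local file, the image_url may be either a local path or a remote HTTP(S) URL.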