support multiimage inference
Former-commit-id: 8083e4607549e805eb308c4e93c8aa256202f438
This commit is contained in:
@@ -69,7 +69,7 @@ ROLE_MAPPING = {
|
||||
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
|
||||
logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
@@ -84,7 +84,7 @@ def _process_request(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
image = None
|
||||
images = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
@@ -111,10 +111,11 @@ def _process_request(
|
||||
else: # web uri
|
||||
image_stream = requests.get(image_url, stream=True).raw
|
||||
|
||||
image = Image.open(image_stream).convert("RGB")
|
||||
images.append(Image.open(image_stream).convert("RGB"))
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
images = None if len(images) == 0 else images
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
@@ -124,7 +125,7 @@ def _process_request(
|
||||
else:
|
||||
tools = None
|
||||
|
||||
return input_messages, system, tools, image
|
||||
return input_messages, system, tools, images
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
@@ -143,12 +144,12 @@ async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
input_messages, system, tools, images = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
images,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
@@ -194,7 +195,7 @@ async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
input_messages, system, tools, images = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
@@ -208,7 +209,7 @@ async def create_stream_chat_completion_response(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
images,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
|
||||
Reference in New Issue
Block a user