finish agent

Former-commit-id: d8d9d3afe32725fe79120fcd1a0970fdcdc45625
This commit is contained in:
hiyouga
2024-01-21 01:47:33 +08:00
parent 50459a39f4
commit 27f281480a
8 changed files with 105 additions and 41 deletions

View File

@@ -7,18 +7,20 @@ from typing import Any, Dict, Sequence
from pydantic import BaseModel
from ..chat import ChatModel
from ..data import Role as DataRole
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
from .protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
Finish,
Function,
FunctionCall,
ModelCard,
ModelList,
Role,
@@ -84,7 +86,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
messages = [dictify(message) for message in request.messages]
if len(messages) and messages[0]["role"] == Role.SYSTEM:
@@ -96,16 +98,21 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
for i in range(len(messages)):
if messages[i]["role"] == Role.USER:
if i % 2 == 1:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
elif messages[i]["role"] == Role.ASSISTANT:
if i % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
else:
raise NotImplementedError
if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif messages[i]["role"] == Role.TOOL:
messages[i]["role"] = DataRole.OBSERVATION
tools = "" # TODO: add tools
tool_list = request.tools
if len(tool_list):
try:
tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = ""
async with semaphore:
loop = asyncio.get_running_loop()
@@ -130,12 +137,24 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
choices.append(
ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
if tools:
result = chat_model.template.format_tools.extract(response.response_text)
else:
result = response.response_text
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
response_message = ChatCompletionMessage(
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
)
prompt_length = response.prompt_length
response_length += response.response_length
@@ -152,7 +171,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
):
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
@@ -170,12 +189,14 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
continue
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
yield "[DONE]"