Former-commit-id: d0340391e8075cff0d84b3ef879c2101b66ca1dc
This commit is contained in:
hiyouga
2024-04-01 21:35:18 +08:00
parent e7f13098c6
commit 40211db275
7 changed files with 37 additions and 16 deletions

View File

@@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
input_messages.append({"role": role_mapping[message.role], "content": message.content})
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
else:
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:

View File

@@ -1,6 +1,6 @@
import time
from enum import Enum, unique
from typing import List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from typing_extensions import Literal
@@ -39,6 +39,17 @@ class Function(BaseModel):
arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function"
@@ -47,7 +58,8 @@ class FunctionCall(BaseModel):
class ChatMessage(BaseModel):
role: Role
content: str
content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
@@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: list = []
tools: Optional[List[FunctionAvailable]] = None
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None