fix data entry

Former-commit-id: e5c116816f2d00e3bfe1a9be5886fe1e41d93212
This commit is contained in:
hiyouga
2024-02-23 18:29:24 +08:00
parent 89f86cc970
commit 75603c45fc
5 changed files with 37 additions and 34 deletions

View File

@@ -75,11 +75,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
role_mapping = {
Role.USER: DataRole.USER,
Role.ASSISTANT: DataRole.ASSISTANT,
Role.SYSTEM: DataRole.SYSTEM,
Role.FUNCTION: DataRole.FUNCTION,
Role.TOOL: DataRole.OBSERVATION,
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
@app.get("/v1/models", response_model=ModelList)
@@ -95,7 +95,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content
else:
system = ""
@@ -105,11 +105,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
input_messages = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
input_messages.append({"role": role_mapping[message.role], "content": message.content})
if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):