fix api: switch chat endpoints from (query, history) to OpenAI-style messages
Former-commit-id: a4149fbcd600d4f3815f9353e5e92c569719bed6
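
This change replaces the (query, history) tuple interface of the chat endpoints with a flat list of OpenAI-style role/content messages, adds pydantic v1/v2-compatible dictify/jsonify helpers, and threads a tools parameter through (currently a stub). A minimal sketch of the new request shape, assuming the usual OpenAI-compatible route on a local server; neither the URL nor the client code below appears in this diff:

    # Hypothetical client call: the endpoint URL and port are assumptions.
    import requests

    payload = {
        "model": "default",
        "messages": [  # optional system turn first, then user/assistant alternation
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    }
    print(requests.post("http://localhost:8000/v1/chat/completions", json=payload).json())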
@@ -2,7 +2,7 @@ import asyncio
 import json
 import os
 from contextlib import asynccontextmanager
-from typing import List, Tuple
+from typing import Any, Dict, Sequence

 from pydantic import BaseModel

@@ -46,10 +46,17 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
     torch_gc()


-def to_json(data: BaseModel) -> str:
+def dictify(data: "BaseModel") -> Dict[str, Any]:
     try: # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError: # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
+    try: # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except Exception: # pydantic v1
+    except AttributeError: # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)


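Narrowing the fallback from except Exception to except AttributeError is the substantive fix in this hunk: a pydantic v1 model simply lacks .model_dump(), so the v2 call raises AttributeError and the helper falls back cleanly, while genuine serialization errors still propagate instead of being swallowed. A standalone sketch mirroring the helpers above (the Message model is a made-up stand-in, not from the repo):

    import json
    from typing import Any, Dict

    from pydantic import BaseModel


    class Message(BaseModel):  # stand-in model for demonstration only
        role: str
        content: str


    def dictify(data: BaseModel) -> Dict[str, Any]:
        try:  # pydantic v2
            return data.model_dump(exclude_unset=True)
        except AttributeError:  # pydantic v1 has no .model_dump()
            return data.dict(exclude_unset=True)


    def jsonify(data: BaseModel) -> str:
        try:  # pydantic v2
            return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
        except AttributeError:  # pydantic v1
            return data.json(exclude_unset=True, ensure_ascii=False)


    print(dictify(Message(role="user", content="hi")))  # {'role': 'user', 'content': 'hi'}
    print(jsonify(Message(role="user", content="hi")))  # {"role": "user", "content": "hi"}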
@@ -79,36 +86,40 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

-        query = request.messages[-1].content
-        prev_messages = request.messages[:-1]
-        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
-            system = prev_messages.pop(0).content
+        messages = [dictify(message) for message in request.messages]
+        if len(messages) and messages[0]["role"] == Role.SYSTEM:
+            system = messages.pop(0)["content"]
         else:
             system = None

-        history = []
-        if len(prev_messages) % 2 == 0:
-            for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
-                else:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
-        else:
+        if len(messages) % 2 == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

+        for i in range(len(messages)):
+            if messages[i]["role"] == Role.USER:
+                if i % 2 == 1:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            elif messages[i]["role"] == Role.ASSISTANT:
+                if i % 2 == 0:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            else:
+                raise NotImplementedError
+
+        tools = "" # TODO: add tools
+
         async with semaphore:
             loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+            return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)

-    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = stream_chat_completion(query, history, system, request)
+            generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")

         responses = chat_model.chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
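The rewritten validation walks the flattened list once: after the optional leading system turn is popped, user turns must sit at even indices, assistant turns at odd ones, and the list must end on a user turn (hence an odd length). The same rule as a standalone, hypothetical helper; the diff itself raises HTTPException / NotImplementedError rather than ValueError:

    from typing import Dict, Sequence


    def validate_turns(messages: Sequence[Dict[str, str]]) -> None:
        """Enforce u/a/u/a/.../u ordering on a system-stripped message list."""
        if len(messages) % 2 == 0:  # even length: does not end with a user turn
            raise ValueError("Only supports u/a/u/a/u...")
        for i, message in enumerate(messages):
            expected = "user" if i % 2 == 0 else "assistant"
            if message["role"] != expected:
                raise ValueError("Only supports u/a/u/a/u...")


    validate_turns([{"role": "user", "content": "hi"}])  # passes
    validate_turns([
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
        {"role": "user", "content": "bye"},
    ])  # passes; an even-length or misordered list raises ValueError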
@@ -138,18 +149,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

     def stream_chat_completion(
-        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)

         for new_text in chat_model.stream_chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
@@ -162,11 +173,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
                 index=0, delta=DeltaMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield to_json(chunk)
+            yield jsonify(chunk)

         choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
         yield "[DONE]"

     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
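
On the wire, every yield above becomes one SSE data: event carrying a jsonify()'d chunk, and the literal [DONE] string closes the stream. A hypothetical client loop for consuming it (URL, port, and payload are assumptions; only the server side appears in this diff):

    import json

    import requests

    with requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed route
        json={"model": "default", "stream": True,
              "messages": [{"role": "user", "content": "Hello!"}]},
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank keep-alive lines between events
            data = line[len("data: "):]
            if data == "[DONE]":  # sentinel sent after the Finish.STOP chunk
                break
            delta = json.loads(data)["choices"][0]["delta"]
            print(delta.get("content") or "", end="", flush=True)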