Former-commit-id: a4149fbcd600d4f3815f9353e5e92c569719bed6
This commit is contained in:
hiyouga
2024-01-21 00:03:09 +08:00
parent 5c9815ef6f
commit 50459a39f4
5 changed files with 60 additions and 50 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
import json
import os
from contextlib import asynccontextmanager
from typing import List, Tuple
from typing import Any, Dict, Sequence
from pydantic import BaseModel
@@ -46,10 +46,17 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
torch_gc()
def to_json(data: BaseModel) -> str:
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except Exception: # pydantic v1
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
@@ -79,36 +86,40 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
system = prev_messages.pop(0).content
messages = [dictify(message) for message in request.messages]
if len(messages) and messages[0]["role"] == Role.SYSTEM:
system = messages.pop(0)["content"]
else:
system = None
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i + 1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
else:
if len(messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
for i in range(len(messages)):
if messages[i]["role"] == Role.USER:
if i % 2 == 1:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
elif messages[i]["role"] == Role.ASSISTANT:
if i % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
else:
raise NotImplementedError
tools = "" # TODO: add tools
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, chat_completion, query, history, system, request)
return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)
def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
if request.stream:
generate = stream_chat_completion(query, history, system, request)
generate = stream_chat_completion(messages, system, tools, request)
return EventSourceResponse(generate, media_type="text/event-stream")
responses = chat_model.chat(
query,
history,
messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
@@ -138,18 +149,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
def stream_chat_completion(
query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
):
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk)
yield jsonify(chunk)
for new_text in chat_model.stream_chat(
query,
history,
messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
@@ -162,11 +173,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk)
yield jsonify(chunk)
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield to_json(chunk)
yield jsonify(chunk)
yield "[DONE]"
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)

View File

@@ -3,6 +3,7 @@ from enum import Enum, unique
from typing import List, Optional
from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique
@@ -20,14 +21,14 @@ class Finish(str, Enum):
class ModelCard(BaseModel):
id: str
object: Optional[str] = "model"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
owned_by: Optional[str] = "owner"
object: Literal["model"] = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner"
class ModelList(BaseModel):
object: Optional[str] = "list"
data: Optional[List[ModelCard]] = []
object: Literal["list"] = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
@@ -43,12 +44,12 @@ class DeltaMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
do_sample: Optional[bool] = True
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = 1
n: int = 1
max_tokens: Optional[int] = None
stream: Optional[bool] = False
stream: bool = False
class ChatCompletionResponseChoice(BaseModel):
@@ -70,18 +71,18 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Optional[str] = "chat.completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage
class ChatCompletionStreamResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Optional[str] = "chat.completion.chunk"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
id: Literal["chatcmpl-default"] = "chatcmpl-default"
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
@@ -93,7 +94,7 @@ class ScoreEvaluationRequest(BaseModel):
class ScoreEvaluationResponse(BaseModel):
id: Optional[str] = "scoreeval-default"
object: Optional[str] = "score.evaluation"
id: Literal["scoreeval-default"] = "scoreeval-default"
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]