update gradio, support multiple responses in api

Former-commit-id: a34263e7c0e07a080276d164cdab9f12f1d767d2
hiyouga
2023-11-01 23:02:16 +08:00
parent 2406200914
commit bff8b02543
10 changed files with 54 additions and 42 deletions


@@ -1,9 +1,11 @@
import json
import uvicorn
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple
+from pydantic import BaseModel
from llmtuner.extras.misc import torch_gc
from llmtuner.chat import ChatModel
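
For context on the streaming half of this endpoint: the generator returned by predict() below is presumably served through sse_starlette's EventSourceResponse (imported above), which turns each yielded string into one server-sent event. A minimal, self-contained sketch of that mechanism; the /demo route and app name are made up for illustration and are not part of this commit:

# Minimal sketch of how EventSourceResponse streams an async generator as SSE.
import asyncio
from fastapi import FastAPI
from sse_starlette import EventSourceResponse

demo_app = FastAPI()

@demo_app.get("/demo")
async def stream_numbers():
    async def numbers():
        for i in range(3):
            yield str(i)              # each yielded string becomes one "data: ..." event
            await asyncio.sleep(0.1)
    return EventSourceResponse(numbers())
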
@@ -29,6 +31,13 @@ async def lifespan(app: FastAPI): # collects GPU memory
torch_gc()
+def to_json(data: BaseModel) -> str:
+try:
+return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+except:
+return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: ChatModel) -> FastAPI:
app = FastAPI(lifespan=lifespan)
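
The new to_json() helper exists because serialization differs between pydantic versions: v2 models expose model_dump(), while v1 models accept ensure_ascii directly in .json(). A quick illustration of the two paths it covers; the Demo model is made up for this snippet:

import json
from pydantic import BaseModel

class Demo(BaseModel):
    role: str
    content: str = ""

msg = Demo(role="assistant", content="你好")

try:
    # pydantic v2: dump to a dict, then serialize without escaping non-ASCII text
    out = json.dumps(msg.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError:
    # pydantic v1: .json() forwards ensure_ascii to json.dumps itself
    out = msg.json(exclude_unset=True, ensure_ascii=False)

print(out)  # {"role": "assistant", "content": "你好"}
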
@@ -45,10 +54,10 @@ def create_app(chat_model: ChatModel) -> FastAPI:
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
-raise HTTPException(status_code=400, detail="Invalid request")
+raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
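
The hunk above swaps bare numeric status codes for fastapi's named constants; these are plain integers re-exported from starlette.status, so the responses themselves do not change:

from fastapi import HTTPException, status

assert status.HTTP_200_OK == 200
assert status.HTTP_400_BAD_REQUEST == 400

# Equivalent to raise HTTPException(status_code=400, ...), but self-documenting.
exc = HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
print(exc.status_code, exc.detail)  # 400 Invalid request
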
@@ -62,6 +71,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
+else:
+raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
if request.stream:
generate = predict(query, history, system, request)
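
The new else/raise branch above enforces that everything before the final user message alternates user/assistant strictly before being folded into [user, assistant] pairs. A standalone sketch of that pairing rule, with plain dicts standing in for the request's message objects:

prev_messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke"},
    {"role": "assistant", "content": "Why did the tensor cross the road?"},
]
history = []
for i in range(0, len(prev_messages), 2):
    if prev_messages[i]["role"] == "user" and prev_messages[i + 1]["role"] == "assistant":
        history.append([prev_messages[i]["content"], prev_messages[i + 1]["content"]])
    else:
        raise ValueError("Only supports u/a/u/a/u...")  # the endpoint answers this case with HTTP 400
print(history)
# [['Hi', 'Hello! How can I help?'], ['Tell me a joke', 'Why did the tensor cross the road?']]
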
@@ -72,7 +83,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
-max_new_tokens=request.max_tokens
+max_new_tokens=request.max_tokens,
+num_return_sequences=request.n
)
usage = ChatCompletionResponseUsage(
@@ -81,13 +93,13 @@ def create_app(chat_model: ChatModel) -> FastAPI:
total_tokens=prompt_length+response_length
)
-choice_data = ChatCompletionResponseChoice(
-index=0,
-message=ChatMessage(role=Role.ASSISTANT, content=response),
+choices = [ChatCompletionResponseChoice(
+index=i,
+message=ChatMessage(role=Role.ASSISTANT, content=choice),
finish_reason=Finish.STOP
-)
+) for i, choice in enumerate(response)]
-return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
+return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
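
With request.n forwarded as num_return_sequences, the non-streaming branch above now returns one choice per generated sequence. A client-side sketch of using it (not part of the diff); the endpoint path and field names come from the code above, while the host, port, and prompt are assumptions:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Give me a haiku about autumn."}],
        "n": 2,            # forwarded to the model as num_return_sequences
        "max_tokens": 64,
    },
    timeout=60,
)
for choice in resp.json()["choices"]:
    print(choice["index"], choice["message"]["content"])
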
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
@@ -96,7 +108,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-yield chunk.json(exclude_unset=True, ensure_ascii=False)
+yield to_json(chunk)
for new_text in chat_model.stream_chat(
query, history, system,
@@ -114,7 +126,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-yield chunk.json(exclude_unset=True, ensure_ascii=False)
+yield to_json(chunk)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@@ -122,7 +134,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-yield chunk.json(exclude_unset=True, ensure_ascii=False)
+yield to_json(chunk)
yield "[DONE]"
return app
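
Finally, a sketch of consuming the streaming branch, which emits the to_json() chunks as server-sent events and closes with "[DONE]". The URL is again an assumption, and the per-chunk "delta" field is assumed from the OpenAI-style stream schema used by the protocol models:

import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Stream a short answer."}],
        "stream": True,
    },
    stream=True,
    timeout=60,
) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data:"):
            continue                      # skip keep-alives and non-data lines
        payload = raw[len("data:"):].strip()
        if payload == "[DONE]":
            break                         # sentinel yielded by predict() above
        chunk = json.loads(payload)
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)
print()
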