fix streaming response in API

Former-commit-id: 72a17ae3b4fac2dc93b04a816f16f863120bc71b
This commit is contained in:
hiyouga
2023-07-05 22:42:31 +08:00
parent d659907f34
commit 982e76978b
3 changed files with 10 additions and 8 deletions

View File

@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from transformers import TextIteratorStreamer
from starlette.responses import StreamingResponse
from sse_starlette import EventSourceResponse
from typing import Any, Dict, List, Literal, Optional, Union
from utils import (
@@ -144,7 +144,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if request.stream:
generate = predict(gen_kwargs, request.model)
return StreamingResponse(generate, media_type="text/event-stream")
return EventSourceResponse(generate, media_type="text/event-stream")
generation_output = model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
@@ -174,7 +174,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield chunk.json(exclude_unset=True, ensure_ascii=False)
for new_text in streamer:
if len(new_text) == 0:
@@ -186,7 +186,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield chunk.json(exclude_unset=True, ensure_ascii=False)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@@ -194,7 +194,8 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield "[DONE]"
if __name__ == "__main__":
@@ -204,4 +205,4 @@ if __name__ == "__main__":
prompt_template = Template(data_args.prompt_template)
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)