fix system prompt

Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
hiyouga authored 2023-08-16 01:35:52 +08:00
parent ca9a494d0c
commit baa709674f
15 changed files with 170 additions and 152 deletions


@@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:

     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != Role.USER:
+        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")

         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
         if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
-            prefix = prev_messages.pop(0).content
+            system = prev_messages.pop(0).content
         else:
-            prefix = None
+            system = None

         history = []
         if len(prev_messages) % 2 == 0:
@@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])

         if request.stream:
-            generate = predict(query, history, prefix, request)
+            generate = predict(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")

         response, (prompt_length, response_length) = chat_model.chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
         )

         usage = ChatCompletionResponseUsage(
@@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:

         return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)

-    async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
+    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT),
@@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
         yield chunk.json(exclude_unset=True, ensure_ascii=False)

         for new_text in chat_model.stream_chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
         ):
             if len(new_text) == 0:
                 continue
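
For reference, the request-parsing logic this commit touches can be sketched as a standalone helper. This is a minimal reconstruction under assumptions, not the repository's actual code: the name parse_messages is hypothetical, and the Role constants and ChatMessage class below are stand-ins for the project's own models.

from typing import List, Optional, Tuple


class Role:
    # Stand-in for the project's Role enum (assumed values).
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class ChatMessage:
    # Stand-in for the project's message model.
    def __init__(self, role: str, content: str) -> None:
        self.role = role
        self.content = content


def parse_messages(messages: List[ChatMessage]) -> Tuple[str, Optional[str], List[List[str]]]:
    """Split an OpenAI-style message list into (query, system, history),
    mirroring the logic visible in the diff above."""
    # The guard added by this commit: reject an empty message list as well
    # as a conversation that does not end with a user message.
    if len(messages) < 1 or messages[-1].role != Role.USER:
        raise ValueError("Invalid request")

    query = messages[-1].content
    prev_messages = messages[:-1]

    # A leading system message becomes the system prompt (the variable this
    # commit renames from `prefix` to `system`); otherwise there is none.
    if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
        system = prev_messages.pop(0).content
    else:
        system = None

    # Remaining messages are paired into [user, assistant] history turns.
    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            history.append([prev_messages[i].content, prev_messages[i + 1].content])
    return query, system, history

With an OpenAI-style conversation, the sketch splits the list like so:

messages = [
    ChatMessage(Role.SYSTEM, "You are a helpful assistant."),
    ChatMessage(Role.USER, "Hi"),
    ChatMessage(Role.ASSISTANT, "Hello! How can I help?"),
    ChatMessage(Role.USER, "Summarize this commit."),
]
query, system, history = parse_messages(messages)
# query   == "Summarize this commit."
# system  == "You are a helpful assistant."
# history == [["Hi", "Hello! How can I help?"]]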