support vllm

commit 056d2d956a (parent 9a69cadab3)
Author: hiyouga
Date:   2024-03-07 20:26:31 +08:00
Former-commit-id: 889f6e910e654d8ec3922c2185042d737ffbf1c3

32 changed files with 752 additions and 316 deletions


@@ -1,4 +1,3 @@
-import asyncio
 import json
 import os
 from contextlib import asynccontextmanager
@@ -73,7 +72,6 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )
-    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
 
     role_mapping = {
         Role.USER: DataRole.USER.value,
         Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -89,7 +87,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if not chat_model.can_generate:
+        if not chat_model.engine.can_generate:
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         if len(request.messages) == 0:
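The capability check now goes through `chat_model.engine`, and later hunks call async entry points (`achat`, `astream_chat`, `aget_scores`) on the chat model. The diff doesn't show how the model/engine split is defined, so the following Protocol is only a guess at the combined async surface implied by these hunks; every name and signature here is an assumption for illustration:

```python
# Hypothetical sketch of the async backend surface implied by this diff;
# names and signatures are assumptions, not code from the commit.
from typing import Any, AsyncGenerator, Dict, List, Protocol, Sequence


class AsyncChatBackend(Protocol):
    can_generate: bool  # False when the model only scores (e.g. a reward model)

    async def achat(
        self, messages: Sequence[Dict[str, str]], system: str, tools: str, **gen_kwargs: Any
    ) -> List[Any]:
        """Return finished responses (objects exposing .response_text)."""
        ...

    # Declared non-async in the Protocol because the implementation is an
    # async generator, consumed with `async for`.
    def astream_chat(
        self, messages: Sequence[Dict[str, str]], system: str, tools: str, **gen_kwargs: Any
    ) -> AsyncGenerator[str, None]:
        """Yield newly generated tokens one by one."""
        ...

    async def aget_scores(
        self, texts: List[str], max_length: int
    ) -> List[float]:
        """Score inputs with generation disabled."""
        ...
```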
@@ -121,20 +119,15 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         else:
             tools = ""
 
-        async with semaphore:
-            loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
-
-    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
             if tools:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 
-            generate = stream_chat_completion(messages, system, tools, request)
+            generate = stream_chat_completion(input_messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
-        responses = chat_model.chat(
-            messages,
+        responses = await chat_model.achat(
+            input_messages,
             system,
             tools,
             do_sample=request.do_sample,
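This hunk is the core of the refactor: previously the blocking `chat()` had to be serialized behind the `MAX_CONCURRENT` semaphore and pushed onto a worker thread with `run_in_executor`; with an async engine the route handler simply awaits the coroutine and leaves scheduling to the backend. A self-contained sketch of the before/after pattern (the function names here are illustrative, not from the codebase):

```python
import asyncio
import time


def blocking_chat(prompt: str) -> str:
    # Stands in for a blocking model.generate() call.
    time.sleep(0.1)
    return f"echo: {prompt}"


async def old_style(prompt: str, semaphore: asyncio.Semaphore) -> str:
    # Removed pattern: cap concurrency with a semaphore and offload the
    # blocking call onto a thread so the event loop stays responsive.
    async with semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, blocking_chat, prompt)


async def async_chat(prompt: str) -> str:
    # Stands in for a natively async engine call (what achat presumably wraps).
    await asyncio.sleep(0.1)
    return f"echo: {prompt}"


async def new_style(prompt: str) -> str:
    # New pattern: just await the coroutine; request scheduling and batching
    # become the engine's job, so the MAX_CONCURRENT guard can go away.
    return await async_chat(prompt)


async def main() -> None:
    semaphore = asyncio.Semaphore(1)  # mirrors the removed MAX_CONCURRENT guard
    print(await old_style("hi", semaphore))
    print(await new_style("hi"))


if __name__ == "__main__":
    asyncio.run(main())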
@@ -148,7 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         choices = []
         for i, response in enumerate(responses):
             if tools:
-                result = chat_model.template.format_tools.extract(response.response_text)
+                result = chat_model.engine.template.format_tools.extract(response.response_text)
             else:
                 result = response.response_text
@@ -177,7 +170,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    def stream_chat_completion(
+    async def stream_chat_completion(
        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
@@ -186,7 +179,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield jsonify(chunk)
 
-        for new_text in chat_model.stream_chat(
+        async for new_token in chat_model.astream_chat(
             messages,
             system,
             tools,
@@ -195,11 +188,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
         ):
-            if len(new_text) == 0:
+            if len(new_token) == 0:
                 continue
 
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
+                index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield jsonify(chunk)
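Turning `stream_chat_completion` into an async generator lets `EventSourceResponse` consume it directly on the event loop, with no background thread in between. A stripped-down, runnable sketch of the same shape, using a fake token source in place of `chat_model.astream_chat` (run with `uvicorn sketch:app`; the payload fields follow the OpenAI-style delta chunks seen above):

```python
import asyncio
import json

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


async def fake_astream_chat(prompt: str):
    # Stand-in for an async engine's token stream.
    for token in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0.05)
        yield token


@app.post("/v1/chat/completions")
async def create_chat_completion(prompt: str = "hi"):
    async def event_stream():
        async for new_token in fake_astream_chat(prompt):
            # One SSE event per delta chunk, OpenAI-style.
            yield json.dumps({"choices": [{"delta": {"content": new_token}}]})
        yield "[DONE]"

    return EventSourceResponse(event_stream(), media_type="text/event-stream")
```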
@@ -213,18 +206,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
     async def create_score_evaluation(request: ScoreEvaluationRequest):
-        if chat_model.can_generate:
+        if chat_model.engine.can_generate:
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
-        async with semaphore:
-            loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, get_score, request)
-
-    def get_score(request: ScoreEvaluationRequest):
-        scores = chat_model.get_scores(request.messages, max_length=request.max_length)
+        scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)
 
     return app
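End to end, the endpoint shape is unchanged (OpenAI-style chat completions), so any plain HTTP client still works against the now-async server. A quick smoke test, assuming the API from this commit is listening on localhost:8000 and the model name is a placeholder:

```python
import requests

# Assumes the server built by create_app() is running locally on port 8000.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```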