disentangle model from tuner and rename modules
Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
This commit is contained in:
@@ -1,14 +1,8 @@
|
||||
import json
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import List, Tuple
|
||||
from pydantic import BaseModel
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.api.protocol import (
|
||||
Role,
|
||||
Finish,
|
||||
@@ -23,10 +17,28 @@ from llmtuner.api.protocol import (
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage
|
||||
)
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.extras.packages import (
|
||||
is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
if is_starlette_available():
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
|
||||
if is_uvicorn_available():
|
||||
import uvicorn
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||
yield
|
||||
torch_gc()
|
||||
|
||||
@@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str:
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
|
||||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
|
||||
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
|
||||
system = prev_messages.pop(0).content
|
||||
else:
|
||||
system = None
|
||||
@@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, system, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
response, (prompt_length, response_length) = chat_model.chat(
|
||||
responses = chat_model.chat(
|
||||
query, history, system,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
@@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
num_return_sequences=request.n
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
choices.append(ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
|
||||
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
))
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length+response_length
|
||||
)
|
||||
|
||||
choices = [ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=choice),
|
||||
finish_reason=Finish.STOP
|
||||
) for i, choice in enumerate(response)]
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
|
||||
Reference in New Issue
Block a user