match api with OpenAI format

Former-commit-id: 9cbe2b98b024393817e86ff8e3ff1636776fa263
hiyouga
2023-06-22 20:27:00 +08:00
parent 84b66010a3
commit 391bf1c699
4 changed files with 144 additions and 135 deletions

api_demo.py

@@ -2,157 +2,157 @@
 # Implements API for fine-tuned models.
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-# Request:
-# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'
-# Response:
-# {
-#   "response": "'Hi there!'",
-#   "history": "[('Hello there!', 'Hi there!')]",
-#   "status": 200,
-#   "time": "2000-00-00 00:00:00"
-# }
-import json
-import datetime
+import time
 import torch
 import uvicorn
+from fastapi import FastAPI
 from threading import Thread
-from fastapi import FastAPI, Request
-from starlette.responses import StreamingResponse
+from contextlib import asynccontextmanager
+from pydantic import BaseModel, Field
 from transformers import TextIteratorStreamer
+from starlette.responses import StreamingResponse
+from typing import Any, Dict, List, Literal, Optional, Union
-from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor
+from utils import (
+    Template,
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)
-def torch_gc():
+@asynccontextmanager
+async def lifespan(app: FastAPI): # collects GPU memory
+    yield
     if torch.cuda.is_available():
-        num_gpus = torch.cuda.device_count()
-        for device_id in range(num_gpus):
-            with torch.cuda.device(device_id):
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
-@app.post("/v1/chat/completions")
-async def create_item(request: Request):
-    global model, tokenizer
+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
-    json_post_raw = await request.json()
-    prompt = json_post_raw.get("messages")[-1]["content"]
-    history = json_post_raw.get("messages")[:-1]
-    max_token = json_post_raw.get("max_tokens", None)
-    top_p = json_post_raw.get("top_p", None)
-    temperature = json_post_raw.get("temperature", None)
-    stream = check_stream(json_post_raw.get("stream"))
-    if stream:
-        generate = predict(prompt, max_token, top_p, temperature, history)
-        return StreamingResponse(generate, media_type="text/event-stream")
+class DeltaMessage(BaseModel):
+    role: Optional[Literal["system", "user", "assistant"]] = None
+    content: Optional[str] = None
-    input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")[
-        "input_ids"]
-    input_ids = input_ids.to(model.device)
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    max_new_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length"]
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]]
+class ChatCompletionResponse(BaseModel):
+    model: str
+    object: str
+    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def create_chat_completion(request: ChatCompletionRequest):
+    global model, tokenizer, source_prefix
+    query = request.messages[-1].content
+    prev_messages = request.messages[:-1]
+    if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        source_prefix = prev_messages.pop(0).content
+    history = []
+    if len(prev_messages) % 2 == 0:
+        for i in range(0, len(prev_messages), 2):
+            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i+1].content])
+    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = inputs.to(model.device)
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["input_ids"] = input_ids
-    gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
-    gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
-    gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
+    gen_kwargs.update({
+        "input_ids": inputs["input_ids"],
+        "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
+        "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
+        "max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
+        "logits_processor": get_logits_processor()
+    })
+    if request.stream:
+        generate = predict(gen_kwargs, request.model)
+        return StreamingResponse(generate, media_type="text/event-stream")
     generation_output = model.generate(**gen_kwargs)
-    outputs = generation_output.tolist()[0][len(input_ids[0]):]
+    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": response
-                }
-            }
-        ]
-    }
-    log = (
-        "["
-        + time
-        + "] "
-        + "\", prompt:\""
-        + prompt
-        + "\", response:\""
-        + repr(response)
-        + "\""
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=response),
+        finish_reason="stop"
     )
-    print(log)
-    torch_gc()
-    return answer
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
-def check_stream(stream):
-    if isinstance(stream, bool):
-        # stream is already a boolean, use it as is
-        stream_value = stream
-    else:
-        # not a boolean, try to convert it
-        if isinstance(stream, str):
-            stream = stream.lower()
-            if stream in ["true", "false"]:
-                # convert the string value to a boolean
-                stream_value = stream == "true"
-            else:
-                # invalid string value
-                stream_value = False
-        else:
-            # neither a boolean nor a string
-            stream_value = False
-    return stream_value
-async def predict(query, max_length, top_p, temperature, history):
+async def predict(gen_kwargs: Dict[str, Any], model_id: str):
     global model, tokenizer
-    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
-    input_ids = input_ids.to(model.device)
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-    gen_kwargs = {
-        "input_ids": input_ids,
-        "do_sample": generating_args.do_sample,
-        "top_p": top_p,
-        "temperature": temperature,
-        "num_beams": generating_args.num_beams,
-        "max_length": max_length,
-        "repetition_penalty": generating_args.repetition_penalty,
-        "logits_processor": get_logits_processor(),
-        "streamer": streamer
-    }
+    gen_kwargs["streamer"] = streamer
     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(role="assistant"),
+        finish_reason=None
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
     for new_text in streamer:
-        answer = {
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": new_text
-                    }
-                }
-            ]
-        }
-        yield "data: " + json.dumps(answer) + '\n\n'
+        if len(new_text) == 0:
+            continue
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(content=new_text),
+            finish_reason=None
+        )
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(),
+        finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 if __name__ == "__main__":
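
For reference, a minimal client sketch against the new /v1/chat/completions route. This is not part of the commit: it assumes the demo is served at 127.0.0.1:8000, that the model field is free-form (the handler echoes it back without validating it), and it uses max_new_tokens as defined by the ChatCompletionRequest schema above rather than OpenAI's max_tokens.

import json
import requests

API_URL = "http://127.0.0.1:8000/v1/chat/completions"  # assumed host and port

payload = {
    "model": "llama-7b-sft",  # placeholder name; not checked by the server
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello there!"}
    ],
    "temperature": 0.7,
    "max_new_tokens": 128,
    "stream": False
}

# Non-streaming call: the body parses as a ChatCompletionResponse.
resp = requests.post(API_URL, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

# Streaming call: the endpoint emits "data: {...}" server-sent events,
# ending with a chunk whose finish_reason is "stop".
payload["stream"] = True
with requests.post(API_URL, json=payload, stream=True, timeout=120) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        chunk = json.loads(line[len("data: "):])
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)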