update readme

Former-commit-id: 6b08adc8219caacefa8d7b5a618e33ccd6060eec
Author: hiyouga
Date: 2023-06-23 00:17:05 +08:00
parent 7daf6c8b8e
commit f9332bc329
3 changed files with 22 additions and 27 deletions

api_demo.py

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements API for fine-tuned models.
+# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 # Visit http://localhost:8000/docs for document.
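For orientation, a request against the updated OpenAI-style endpoint could look like the sketch below, assuming the demo's default host and port; the model id is a placeholder and the response shape is assumed from the OpenAI chat format.

    # Hypothetical client call against the demo's default address.
    import requests

    payload = {
        "model": "local-model",  # placeholder id; the server echoes it back
        "messages": [{"role": "user", "content": "Hello!"}],
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])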
@@ -7,11 +7,10 @@
 import time
 import torch
 import uvicorn
-from fastapi import FastAPI, HTTPException
 from threading import Thread
-from contextlib import asynccontextmanager
 from pydantic import BaseModel, Field
+from fastapi import FastAPI, HTTPException
+from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
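The reordered imports keep asynccontextmanager next to FastAPI, which points at the usual lifespan pattern; a minimal sketch of that pattern, with the shutdown cleanup step assumed rather than taken from this diff:

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        yield  # the app serves requests between startup and shutdown
        # assumed cleanup, e.g. torch.cuda.empty_cache() to free GPU memory

    app = FastAPI(lifespan=lifespan)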
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 class ChatCompletionResponse(BaseModel):
     model: str
-    object: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))


 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    global model, tokenizer, source_prefix
+    global model, tokenizer, source_prefix, generating_args
     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
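Narrowing object from str to Literal lets pydantic validate the field instead of accepting any string; a minimal sketch of the effect:

    from typing import Literal
    from pydantic import BaseModel

    class Demo(BaseModel):
        object: Literal["chat.completion", "chat.completion.chunk"]

    Demo(object="chat.completion.chunk")  # accepted
    Demo(object="anything.else")          # raises pydantic.ValidationError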
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     prev_messages = request.messages[:-1]
     if len(prev_messages) > 0 and prev_messages[0].role == "system":
-        source_prefix = prev_messages.pop(0).content
+        prefix = prev_messages.pop(0).content
+    else:
+        prefix = source_prefix

     history = []
     if len(prev_messages) % 2 == 0:
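This hunk fixes a state leak: the old code overwrote the global source_prefix with a request's system message, so one request's override persisted into later requests. A minimal sketch of the corrected flow, with names mirroring the diff and the default value assumed:

    source_prefix = "You are a helpful assistant."  # global default (assumed value)

    def resolve_prefix(prev_messages):
        if prev_messages and prev_messages[0]["role"] == "system":
            return prev_messages.pop(0)["content"]  # per-request override
        return source_prefix  # fall back to the global default

    msgs = [{"role": "system", "content": "Reply in French."}]
    assert resolve_prefix(msgs) == "Reply in French."
    assert source_prefix == "You are a helpful assistant."  # global left intact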
@@ -91,7 +92,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
             history.append([prev_messages[i].content, prev_messages[i+1].content])

-    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
     inputs = inputs.to(model.device)

     gen_kwargs = generating_args.to_dict()
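For context, the surrounding code pairs alternating user/assistant turns into the history list that prompt_template.get_prompt consumes; a small sketch of that pairing, with the list shape assumed from the diff:

    prev_messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
    history = [
        [prev_messages[i]["content"], prev_messages[i + 1]["content"]]
        for i in range(0, len(prev_messages), 2)
        if prev_messages[i]["role"] == "user"
        and prev_messages[i + 1]["role"] == "assistant"
    ]
    # history == [["Hi", "Hello!"]]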
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

     for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

     choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))