update readme
Former-commit-id: 6b08adc8219caacefa8d7b5a618e33ccd6060eec
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements API for fine-tuned models.
+# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 # Visit http://localhost:8000/docs for document.
 
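For readers skimming the diff: "OpenAI's format" means the endpoint accepts the same JSON body as OpenAI's Chat Completions API. Below is a minimal client sketch against the demo server, assuming it is running on localhost:8000 as in the usage comment; the model name is illustrative, and the choices[0].message response shape is assumed to mirror OpenAI's (the exact schemas live in ChatCompletionRequest/ChatCompletionResponse in this file).

import requests

# Smoke test against the local demo server; the payload follows the OpenAI
# chat completion schema referenced in the header comment above.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "local-model",  # illustrative value, not checked against a registry
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])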
@@ -7,11 +7,10 @@
 import time
 import torch
 import uvicorn
-from fastapi import FastAPI, HTTPException
 from threading import Thread
-from contextlib import asynccontextmanager
-
 from pydantic import BaseModel, Field
+from fastapi import FastAPI, HTTPException
+from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     model: str
-    object: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    global model, tokenizer, source_prefix
+    global model, tokenizer, source_prefix, generating_args
 
     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
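The object: str → Literal[...] change makes pydantic enforce the two allowed object types at model construction instead of accepting any string. A self-contained sketch of that behavior:

from typing import Literal

from pydantic import BaseModel, ValidationError

class Demo(BaseModel):
    object: Literal["chat.completion", "chat.completion.chunk"]

Demo(object="chat.completion.chunk")   # accepted
try:
    Demo(object="text_completion")     # any other string is rejected
except ValidationError as err:
    print(err)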
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     prev_messages = request.messages[:-1]
     if len(prev_messages) > 0 and prev_messages[0].role == "system":
-        source_prefix = prev_messages.pop(0).content
+        prefix = prev_messages.pop(0).content
+    else:
+        prefix = source_prefix
 
     history = []
     if len(prev_messages) % 2 == 0:
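This hunk fixes a state leak: the old code wrote a request's system message into the global source_prefix, so it persisted into every later request. The new code resolves a per-request local prefix and leaves the global default untouched. A hypothetical standalone rendering of the fixed logic (using plain dicts instead of the pydantic message objects):

source_prefix = "default system prompt"  # global default, set once at startup

def resolve_prefix(prev_messages):
    # Per-request override: consume a leading system message if present,
    # otherwise fall back to the untouched global default.
    if len(prev_messages) > 0 and prev_messages[0]["role"] == "system":
        return prev_messages.pop(0)["content"]
    return source_prefix

assert resolve_prefix([{"role": "system", "content": "be terse"}]) == "be terse"
assert resolve_prefix([]) == "default system prompt"  # global remains intact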
@@ -91,7 +92,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
             history.append([prev_messages[i].content, prev_messages[i+1].content])
 
-    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
     inputs = inputs.to(model.device)
 
     gen_kwargs = generating_args.to_dict()
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
 
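The chunks above are emitted as server-sent events, one data: {...} line per delta. A sketch of a streaming consumer, assuming the request schema exposes an OpenAI-style "stream" flag and that the stream may close with a "[DONE]" sentinel as OpenAI-style servers commonly do (both are assumptions, not confirmed by this diff):

import json

import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "local-model",  # illustrative value
        "stream": True,          # assumed OpenAI-style flag
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip blank separator lines between events
        payload = line[len(b"data: "):]
        if payload.strip() == b"[DONE]":  # assumed sentinel; break defensively
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)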