format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
@@ -1,30 +1,29 @@
-import os
-import json
 import asyncio
-from typing import List, Tuple
-from pydantic import BaseModel
+import json
+import os
 from contextlib import asynccontextmanager
+from typing import List, Tuple
+
+from pydantic import BaseModel
 
 from ..chat import ChatModel
 from ..extras.misc import torch_gc
-from ..extras.packages import (
-    is_fastapi_availble, is_starlette_available, is_uvicorn_available
-)
+from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
 from .protocol import (
-    Role,
-    Finish,
-    ModelCard,
-    ModelList,
-    ChatMessage,
-    DeltaMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
-    ChatCompletionStreamResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    DeltaMessage,
+    Finish,
+    ModelCard,
+    ModelList,
+    Role,
     ScoreEvaluationRequest,
-    ScoreEvaluationResponse
+    ScoreEvaluationResponse,
 )
 
 
@@ -42,15 +41,15 @@ if is_uvicorn_available():
 
 
 @asynccontextmanager
-async def lifespan(app: "FastAPI"): # collects GPU memory
+async def lifespan(app: "FastAPI"):  # collects GPU memory
     yield
     torch_gc()
 
 
 def to_json(data: BaseModel) -> str:
-    try: # pydantic v2
+    try:  # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except: # pydantic v1
+    except Exception:  # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)
 
 
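
Reviewer note, not part of the commit: besides the comment spacing, this hunk widens the bare `except:` to `except Exception:`, which no longer swallows `KeyboardInterrupt` and `SystemExit`. The `to_json` helper first tries the pydantic v2 API (`model_dump`) and falls back to the v1 API (`.json`). A self-contained sketch of the same pattern, with a hypothetical `Message` model added for illustration:

import json

from pydantic import BaseModel


class Message(BaseModel):  # hypothetical demo model, not in the source
    role: str
    content: str = ""


def to_json(data: BaseModel) -> str:
    try:  # pydantic v2: model_dump() exists
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
    except Exception:  # pydantic v1: the AttributeError lands here, use .json()
        return data.json(exclude_unset=True, ensure_ascii=False)


print(to_json(Message(role="user")))  # {"role": "user"}; unset fields are omitted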
@@ -90,8 +89,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i+1].content])
+                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
+                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
                 else:
                     raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
         else:
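
Reviewer note, not part of the commit: only the spacing around `i + 1` changes here. The loop enforces that messages before the final user prompt alternate strictly user/assistant and folds them into (user, assistant) pairs. A standalone sketch of that check, using plain dicts in place of the request's message models (the `build_history` helper is hypothetical):

from typing import Dict, List, Tuple


def build_history(prev_messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    # Same rule as the endpoint: history must look like u/a/u/a/...
    if len(prev_messages) % 2 != 0:
        raise ValueError("Only supports u/a/u/a/u...")
    history = []
    for i in range(0, len(prev_messages), 2):
        if prev_messages[i]["role"] == "user" and prev_messages[i + 1]["role"] == "assistant":
            history.append((prev_messages[i]["content"], prev_messages[i + 1]["content"]))
        else:
            raise ValueError("Only supports u/a/u/a/u...")
    return history


assert build_history(
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
) == [("hi", "hello")]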
@@ -107,65 +106,65 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
-            query, history, system,
+            query,
+            history,
+            system,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
-            num_return_sequences=request.n
+            num_return_sequences=request.n,
         )
 
         prompt_length, response_length = 0, 0
         choices = []
         for i, response in enumerate(responses):
-            choices.append(ChatCompletionResponseChoice(
-                index=i,
-                message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
-                finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
-            ))
+            choices.append(
+                ChatCompletionResponseChoice(
+                    index=i,
+                    message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
+                    finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
+                )
+            )
             prompt_length = response.prompt_length
             response_length += response.response_length
 
         usage = ChatCompletionResponseUsage(
             prompt_tokens=prompt_length,
             completion_tokens=response_length,
-            total_tokens=prompt_length+response_length
+            total_tokens=prompt_length + response_length,
         )
 
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(
+        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+    ):
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(role=Role.ASSISTANT, content=""),
-            finish_reason=None
+            index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield to_json(chunk)
 
         for new_text in chat_model.stream_chat(
-            query, history, system,
+            query,
+            history,
+            system,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
-            max_new_tokens=request.max_tokens
+            max_new_tokens=request.max_tokens,
         ):
             if len(new_text) == 0:
                 continue
 
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(content=new_text),
-                finish_reason=None
+                index=0, delta=DeltaMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield to_json(chunk)
 
-        choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(),
-            finish_reason=Finish.STOP
-        )
+        choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield to_json(chunk)
         yield "[DONE]"
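
Two reviewer notes on this hunk, neither part of the commit. First, the usage accounting is unchanged: with `num_return_sequences=request.n`, `prompt_length` is overwritten on each iteration while `response_length` accumulates, so `total_tokens` counts the prompt once plus every generated sequence. A toy illustration with a hypothetical stand-in for the response objects:

from dataclasses import dataclass


@dataclass
class FakeResponse:  # hypothetical stand-in for chat_model.chat() results
    prompt_length: int
    response_length: int


prompt_length, response_length = 0, 0
for response in [FakeResponse(12, 30), FakeResponse(12, 25)]:  # n=2, same prompt
    prompt_length = response.prompt_length  # overwritten, not summed
    response_length += response.response_length  # summed across sequences

assert prompt_length + response_length == 67  # 12 + 30 + 25

Second, the streaming generator emits OpenAI-style chunks over server-sent events and ends with a `Finish.STOP` chunk followed by the literal "[DONE]" sentinel. A minimal client sketch, assuming the server listens on localhost:8000 (URL and model name are illustrative) and `requests` is installed:

import json

import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip blank separator lines between SSE events
        data = line[len(b"data: "):]
        if data == b"[DONE]":  # sentinel yielded after the Finish.STOP chunk
            break
        delta = json.loads(data)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)

The final hunk below applies the same import reordering to the protocol definitions module that these endpoints import from.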
@@ -1,8 +1,9 @@
 import time
 from enum import Enum, unique
-from pydantic import BaseModel, Field
 from typing import List, Optional
 
+from pydantic import BaseModel, Field
+
 
 @unique
 class Role(str, Enum):