format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
Author: hiyouga
Date:   2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,30 +1,29 @@
-import os
-import json
-import asyncio
-from typing import List, Tuple
-from pydantic import BaseModel
+import json
+import os
 from contextlib import asynccontextmanager
+from typing import List, Tuple
+
+from pydantic import BaseModel
+
+from ..chat import ChatModel
+from ..extras.misc import torch_gc
+from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
 from .protocol import (
-    Role,
-    Finish,
-    ModelCard,
-    ModelList,
-    ChatMessage,
-    DeltaMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
-    ChatCompletionStreamResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    DeltaMessage,
+    Finish,
+    ModelCard,
+    ModelList,
+    Role,
     ScoreEvaluationRequest,
-    ScoreEvaluationResponse
-)
-
-from ..chat import ChatModel
-from ..extras.misc import torch_gc
-from ..extras.packages import (
-    is_fastapi_availble, is_starlette_available, is_uvicorn_available
+    ScoreEvaluationResponse,
 )
@@ -42,15 +41,15 @@ if is_uvicorn_available():


 @asynccontextmanager
-async def lifespan(app: "FastAPI"): # collects GPU memory
+async def lifespan(app: "FastAPI"):  # collects GPU memory
     yield
     torch_gc()


 def to_json(data: BaseModel) -> str:
-    try: # pydantic v2
+    try:  # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except: # pydantic v1
+    except Exception:  # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)
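
The reformatted to_json above picks the pydantic v2 serialization path and falls back to v1, now catching Exception instead of a bare except. A minimal standalone sketch of that behavior, assuming pydantic v2 is installed (the Ping model is hypothetical, for demonstration only):

    import json
    from typing import Optional
    from pydantic import BaseModel

    class Ping(BaseModel):  # hypothetical model, not part of the commit
        message: str = "pong"
        note: Optional[str] = None

    def to_json(data: BaseModel) -> str:
        try:  # pydantic v2 exposes model_dump()
            return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
        except Exception:  # on pydantic v1 the call above raises; fall back to .json()
            return data.json(exclude_unset=True, ensure_ascii=False)

    print(to_json(Ping()))                 # {} -- unset fields are excluded
    print(to_json(Ping(message="hello")))  # {"message": "hello"}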
@@ -90,8 +89,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i+1].content])
+                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
+                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
                 else:
                     raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
         else:
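
The hunk above only reformats the index arithmetic, but the surrounding logic is the alternating-roles check: prior messages must come in user/assistant pairs or the request is rejected. A standalone sketch of that pairing rule under simplified types (plain (role, content) tuples stand in for the request's ChatMessage objects):

    from typing import List, Tuple

    def pair_history(prev: List[Tuple[str, str]]) -> List[List[str]]:
        # prev holds (role, content) pairs; roles must alternate u/a/u/a...
        if len(prev) % 2 != 0:
            raise ValueError("history must hold an even number of messages")
        history = []
        for i in range(0, len(prev), 2):
            if prev[i][0] == "user" and prev[i + 1][0] == "assistant":
                history.append([prev[i][1], prev[i + 1][1]])
            else:
                raise ValueError("Only supports u/a/u/a/u...")
        return history

    print(pair_history([("user", "hi"), ("assistant", "hello")]))  # [['hi', 'hello']]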
@@ -107,65 +106,65 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             return EventSourceResponse(generate, media_type="text/event-stream")

         responses = chat_model.chat(
-            query, history, system,
+            query,
+            history,
+            system,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
-            num_return_sequences=request.n
+            num_return_sequences=request.n,
         )

         prompt_length, response_length = 0, 0
         choices = []
         for i, response in enumerate(responses):
-            choices.append(ChatCompletionResponseChoice(
-                index=i,
-                message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
-                finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
-            ))
+            choices.append(
+                ChatCompletionResponseChoice(
+                    index=i,
+                    message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
+                    finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
+                )
+            )
             prompt_length = response.prompt_length
             response_length += response.response_length

         usage = ChatCompletionResponseUsage(
             prompt_tokens=prompt_length,
             completion_tokens=response_length,
-            total_tokens=prompt_length+response_length
+            total_tokens=prompt_length + response_length,
         )

         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

-    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(
+        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+    ):
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(role=Role.ASSISTANT, content=""),
-            finish_reason=None
+            index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield to_json(chunk)

         for new_text in chat_model.stream_chat(
-            query, history, system,
+            query,
+            history,
+            system,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
-            max_new_tokens=request.max_tokens
+            max_new_tokens=request.max_tokens,
         ):
             if len(new_text) == 0:
                 continue

             choice_data = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(content=new_text),
-                finish_reason=None
+                index=0, delta=DeltaMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield to_json(chunk)

-        choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(),
-            finish_reason=Finish.STOP
-        )
+        choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield to_json(chunk)
         yield "[DONE]"

View File

@@ -1,8 +1,9 @@
 import time
 from enum import Enum, unique
-from pydantic import BaseModel, Field
 from typing import List, Optional

+from pydantic import BaseModel, Field
+

 @unique
 class Role(str, Enum):
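
The str/Enum mixin used here is what lets the app.py code above compare message roles directly against strings. An illustrative sketch of the pattern (the member values are assumptions matching the usual OpenAI role names, not copied from this file):

    from enum import Enum, unique

    @unique
    class Role(str, Enum):
        USER = "user"          # assumed value
        ASSISTANT = "assistant"  # assumed value
        SYSTEM = "system"      # assumed value

    assert Role.USER == "user"                  # str mixin: members equal their raw values
    assert Role("assistant") is Role.ASSISTANT  # lookup by value returns the member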