update api and support abort eval in webui
Former-commit-id: 8661bed68812e9ded9439e8a821b1d7716bc797b
This commit is contained in:
@@ -1,36 +1,29 @@
|
||||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
from .chat import (
|
||||
create_chat_completion_response,
|
||||
create_score_evaluation_response,
|
||||
create_stream_chat_completion_response,
|
||||
)
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
Role,
|
||||
ScoreEvaluationRequest,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
if is_starlette_available():
|
||||
@@ -47,23 +40,8 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||
torch_gc()
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -71,161 +49,58 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
api_key = os.environ.get("API_KEY", None)
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
role_mapping = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
if api_key and (auth is None or auth.credentials != api_key):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
@app.get(
|
||||
"/v1/models",
|
||||
response_model=ModelList,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
@app.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=ChatCompletionResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if not chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
|
||||
else:
|
||||
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
if request.stream:
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
generate = stream_chat_completion(input_messages, system, tools, request)
|
||||
generate = create_stream_chat_completion_response(request, chat_model)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
else:
|
||||
return await create_chat_completion_response(request, chat_model)
|
||||
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n,
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
response_message = ChatCompletionMessage(
|
||||
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
|
||||
)
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(
|
||||
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
|
||||
)
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def stream_chat_completion(
|
||||
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
||||
async for new_token in chat_model.astream_chat(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
):
|
||||
if len(new_token) == 0:
|
||||
continue
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||
@app.post(
|
||||
"/v1/score/evaluation",
|
||||
response_model=ScoreEvaluationResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||
if chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
return await create_score_evaluation_response(request, chat_model)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run_api():
|
||||
def run_api() -> None:
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
||||
176
src/llmtuner/api/chat.py
Normal file
176
src/llmtuner/api/chat.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.packages import is_fastapi_availble
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionStreamResponseChoice,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
Role,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..chat import ChatModel
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
ROLE_MAPPING = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
|
||||
|
||||
async def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
return input_messages, system, tools
|
||||
|
||||
|
||||
async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = await _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n,
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
|
||||
|
||||
|
||||
async def _create_stream_chat_completion_chunk(
|
||||
completion_id: str,
|
||||
model: str,
|
||||
delta: "ChatCompletionMessage",
|
||||
index: Optional[int] = 0,
|
||||
finish_reason: Optional["Finish"] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
|
||||
return jsonify(chunk)
|
||||
|
||||
|
||||
async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = await _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
|
||||
)
|
||||
async for new_token in chat_model.astream_chat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
):
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
||||
)
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
yield "[DONE]"
|
||||
|
||||
|
||||
async def create_score_evaluation_response(
|
||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||
) -> "ScoreEvaluationResponse":
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
20
src/llmtuner/api/common.py
Normal file
20
src/llmtuner/api/common.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
@@ -51,7 +51,7 @@ class FunctionAvailable(BaseModel):
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
id: Literal["call_default"] = "call_default"
|
||||
id: str
|
||||
type: Literal["function"] = "function"
|
||||
function: Function
|
||||
|
||||
@@ -86,7 +86,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
|
||||
index: int
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
@@ -99,7 +99,7 @@ class ChatCompletionResponseUsage(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
@@ -108,11 +108,11 @@ class ChatCompletionResponse(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
choices: List[ChatCompletionStreamResponseChoice]
|
||||
|
||||
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
@@ -122,7 +122,7 @@ class ScoreEvaluationRequest(BaseModel):
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
id: Literal["scoreeval-default"] = "scoreeval-default"
|
||||
id: str
|
||||
object: Literal["score.evaluation"] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
|
||||
Reference in New Issue
Block a user