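update api: add `max_length` option to chat completion requests (`max_new_tokens` takes precedence when both are set)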

Former-commit-id: a90db46e336a657d5fcf480986bfc68c77ad416b
This commit is contained in:
hiyouga
2023-06-26 13:39:57 +08:00
parent f9332bc329
commit 83346e86af
3 changed files with 17 additions and 4 deletions

View File

@@ -49,6 +49,7 @@ class ChatCompletionRequest(BaseModel):
messages: List[ChatMessage]
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
max_new_tokens: Optional[int] = None
stream: Optional[bool] = False
@@ -100,9 +101,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
"input_ids": inputs["input_ids"],
"temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
"max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
"logits_processor": get_logits_processor()
})
if request.max_length:
gen_kwargs.pop("max_new_tokens", None)
gen_kwargs["max_length"] = request.max_length
if request.max_new_tokens:
gen_kwargs.pop("max_length", None)
gen_kwargs["max_new_tokens"] = request.max_new_tokens
if request.stream:
generate = predict(gen_kwargs, request.model)