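Update API: add `max_length` to ChatCompletionRequest and let a request's `max_length` or `max_new_tokens` override the default generation length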
Former-commit-id: a90db46e336a657d5fcf480986bfc68c77ad416b
This commit is contained in:
@@ -49,6 +49,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
max_length: Optional[int] = None
|
||||
max_new_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
|
||||
@@ -100,9 +101,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
"input_ids": inputs["input_ids"],
|
||||
"temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
|
||||
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
|
||||
"max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
|
||||
"logits_processor": get_logits_processor()
|
||||
})
|
||||
if request.max_length:
|
||||
gen_kwargs.pop("max_new_tokens", None)
|
||||
gen_kwargs["max_length"] = request.max_length
|
||||
if request.max_new_tokens:
|
||||
gen_kwargs.pop("max_length", None)
|
||||
gen_kwargs["max_new_tokens"] = request.max_new_tokens
|
||||
|
||||
if request.stream:
|
||||
generate = predict(gen_kwargs, request.model)
|
||||
|
||||
Reference in New Issue
Block a user