@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
@@ -41,10 +42,10 @@ class ChatModel:
|
||||
gen_kwargs = self.generating_args.to_dict()
|
||||
gen_kwargs.update(dict(
|
||||
input_ids=inputs["input_ids"],
|
||||
temperature=temperature if temperature else gen_kwargs["temperature"],
|
||||
top_p=top_p if top_p else gen_kwargs["top_p"],
|
||||
top_k=top_k if top_k else gen_kwargs["top_k"],
|
||||
repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
|
||||
temperature=temperature or gen_kwargs["temperature"],
|
||||
top_p=top_p or gen_kwargs["top_p"],
|
||||
top_k=top_k or gen_kwargs["top_k"],
|
||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||
logits_processor=get_logits_processor()
|
||||
))
|
||||
|
||||
@@ -58,6 +59,7 @@ class ChatModel:
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@torch.inference_mode()
|
||||
def chat(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
@@ -68,6 +70,7 @@ class ChatModel:
|
||||
response_length = len(outputs)
|
||||
return response, (prompt_length, response_length)
|
||||
|
||||
@torch.inference_mode()
|
||||
def stream_chat(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
|
||||
Reference in New Issue
Block a user