update stream_chat
Former-commit-id: e57b2152cf1d5c9e481523e36be4ed09b88e1285
This commit is contained in:
@@ -23,7 +23,7 @@ class ChatModel:
|
||||
self.generating_args = generating_args
|
||||
|
||||
def process_args(
|
||||
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
prefix = prefix if prefix else self.source_prefix
|
||||
|
||||
@@ -59,7 +59,7 @@ class ChatModel:
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
def chat(
|
||||
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
@@ -69,7 +69,7 @@ class ChatModel:
|
||||
return response, (prompt_length, response_length)
|
||||
|
||||
def stream_chat(
|
||||
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
Reference in New Issue
Block a user