update stream_chat

Former-commit-id: e57b2152cf1d5c9e481523e36be4ed09b88e1285
This commit is contained in:
hiyouga
2023-07-15 19:51:02 +08:00
parent a8deee27f8
commit 6a0499ef40
2 changed files with 16 additions and 8 deletions

View File

@@ -23,7 +23,7 @@ class ChatModel:
self.generating_args = generating_args
def process_args(
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix if prefix else self.source_prefix
@@ -59,7 +59,7 @@ class ChatModel:
return gen_kwargs, prompt_length
def chat(
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
@@ -69,7 +69,7 @@ class ChatModel:
return response, (prompt_length, response_length)
def stream_chat(
self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)