Former-commit-id: fd557ebb5e3ef2ca330b4d97731af43f4a5a5fc5
This commit is contained in:
hiyouga
2023-07-17 18:07:17 +08:00
parent e9736b2ba0
commit c08ff734a7
5 changed files with 38 additions and 12 deletions

View File

@@ -1,3 +1,4 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
@@ -41,10 +42,10 @@ class ChatModel:
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
temperature=temperature if temperature else gen_kwargs["temperature"],
top_p=top_p if top_p else gen_kwargs["top_p"],
top_k=top_k if top_k else gen_kwargs["top_k"],
repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
))
@@ -58,6 +59,7 @@ class ChatModel:
return gen_kwargs, prompt_length
@torch.inference_mode()
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[str, Tuple[int, int]]:
@@ -68,6 +70,7 @@ class ChatModel:
response_length = len(outputs)
return response, (prompt_length, response_length)
@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Generator[str, None, None]: