Former-commit-id: 412d856eeada2abcea598fac0a8d35ae90cc9c01
This commit is contained in:
hiyouga
2024-02-06 15:23:08 +08:00
parent 0dd68d1e06
commit b564b97b7e
2 changed files with 15 additions and 1 deletions

View File

@@ -94,6 +94,9 @@ class ChatModel:
tools: Optional[str] = None,
**input_kwargs,
) -> List[Response]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -123,6 +126,9 @@ class ChatModel:
tools: Optional[str] = None,
**input_kwargs,
) -> Generator[str, None, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -134,9 +140,11 @@ class ChatModel:
@torch.inference_mode()
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
inputs = self.tokenizer(
batch_input,
padding=True,