add tool test
Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..data import get_template_and_fix_tokenizer, Role
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import dispatch_model, load_model_and_tokenizer
|
||||
from ..hparams import get_infer_args
|
||||
@@ -36,10 +36,19 @@ class ChatModel:
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
messages = []
|
||||
if history is not None:
|
||||
for old_prompt, old_response in history:
|
||||
messages.append({"role": Role.USER, "content": old_prompt})
|
||||
messages.append({"role": Role.ASSISTANT, "content": old_response})
|
||||
|
||||
messages.append({"role": Role.USER, "content": query})
|
||||
messages.append({"role": Role.ASSISTANT, "content": ""})
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
|
||||
)
|
||||
prompt_length = len(prompt)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
@@ -90,6 +99,7 @@ class ChatModel:
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> List[Response]:
|
||||
r"""
|
||||
@@ -97,7 +107,7 @@ class ChatModel:
|
||||
|
||||
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
|
||||
"""
|
||||
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
|
||||
gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(
|
||||
@@ -122,9 +132,10 @@ class ChatModel:
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
|
||||
gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user