add tool test

Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
hiyouga
2024-01-18 10:26:26 +08:00
parent a423274fd9
commit d8affd3967
9 changed files with 63 additions and 37 deletions


@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from ..data import get_template_and_fix_tokenizer
+from ..data import get_template_and_fix_tokenizer, Role
 from ..extras.misc import get_logits_processor
 from ..model import dispatch_model, load_model_and_tokenizer
 from ..hparams import get_infer_args
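The new import pulls in Role alongside the template helper. Only Role.USER and Role.ASSISTANT appear in this diff; a minimal sketch of an enum compatible with that usage (the string values are assumptions, not shown here):

from enum import Enum

class Role(str, Enum):
    # Subclassing str lets Role members drop into {"role": ..., "content": ...}
    # dicts and serialize like plain strings.
    USER = "user"
    ASSISTANT = "assistant"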
@@ -36,10 +36,19 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
+        messages = []
+        if history is not None:
+            for old_prompt, old_response in history:
+                messages.append({"role": Role.USER, "content": old_prompt})
+                messages.append({"role": Role.ASSISTANT, "content": old_response})
+
+        messages.append({"role": Role.USER, "content": query})
+        messages.append({"role": Role.ASSISTANT, "content": ""})
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+            tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
         )
         prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
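The hunk above is the core of the change: _process_args now builds an OpenAI-style message list instead of passing query/resp/history separately, and forwards the new tools string to encode_oneturn. A self-contained sketch of the same flattening logic, with plain string roles standing in for the Role enum:

from typing import Dict, List, Optional, Tuple

def build_messages(query: str, history: Optional[List[Tuple[str, str]]] = None) -> List[Dict[str, str]]:
    # Flatten (prompt, response) pairs into alternating user/assistant messages.
    messages = []
    for old_prompt, old_response in history or []:
        messages.append({"role": "user", "content": old_prompt})
        messages.append({"role": "assistant", "content": old_response})
    # Append the current query, plus an empty assistant slot for the model to fill.
    messages.append({"role": "user", "content": query})
    messages.append({"role": "assistant", "content": ""})
    return messages

print(build_messages("How are you?", history=[("Hi", "Hello!")]))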
@@ -90,6 +99,7 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> List[Response]:
         r"""
@@ -97,7 +107,7 @@ class ChatModel:
 
         Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
         """
-        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
         response = self.tokenizer.batch_decode(
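With tools threaded through _process_args, a caller can pass a serialized tool schema next to the query. A hypothetical usage sketch; the constructor arguments and the schema shape are assumptions, and only the chat(..., tools=...) signature comes from this diff:

import json

# Hypothetical setup; the real init args are parsed by get_infer_args elsewhere in the package.
chat_model = ChatModel(dict(model_name_or_path="path/to/model", template="default"))

tools = json.dumps([{
    "name": "get_weather",
    "description": "Query current weather for a city",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}])

responses = chat_model.chat("What's the weather in Paris?", tools=tools)
# Per the docstring above, each Response carries the decoded text.
print(responses[0].response_text)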
@@ -122,9 +132,10 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
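stream_chat gets the same passthrough, so tool-augmented prompts also work in streaming mode. Consuming the generator (a sketch, reusing the assumed chat_model and tools from above):

# stream_chat yields decoded text fragments as the background generate() produces them.
for new_text in chat_model.stream_chat("What's the weather in Paris?", tools=tools):
    print(new_text, end="", flush=True)
print()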