fix tool formatter, allow parallel function #4362

Former-commit-id: b8f16c976db4ecec1cc8558851c8cbfb6a5b7e9c
This commit is contained in:
hiyouga
2024-06-19 03:23:51 +08:00
parent 19ffcfea76
commit e36a994fe6
5 changed files with 207 additions and 86 deletions

View File

@@ -79,6 +79,12 @@ class Template:
"""
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
r"""
Extracts tool message.
"""
return self.format_tools.extract(content)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
@@ -100,7 +106,8 @@ class Template:
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@@ -191,7 +198,8 @@ class Llama2Template(Template):
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@@ -259,7 +267,9 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_function_formatter = FunctionFormatter(
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(