fix tool formatter, allow parallel function #4362
Former-commit-id: b8f16c976db4ecec1cc8558851c8cbfb6a5b7e9c
This commit is contained in:
@@ -79,6 +79,12 @@ class Template:
|
||||
"""
|
||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
|
||||
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
r"""
|
||||
Extracts tool message.
|
||||
"""
|
||||
return self.format_tools.extract(content)
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
@@ -100,7 +106,8 @@ class Template:
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
elif i > 0 and i % 2 == 0:
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
@@ -191,7 +198,8 @@ class Llama2Template(Template):
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
elif i > 0 and i % 2 == 0:
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
@@ -259,7 +267,9 @@ def _register_template(
|
||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(
|
||||
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
|
||||
)
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
TEMPLATES[name] = template_class(
|
||||
|
||||
Reference in New Issue
Block a user