fix llama3 tool template

Former-commit-id: 63f28a594a44c011f2e6d418f22ddbfc445db163
This commit is contained in:
hiyouga
2024-12-17 17:04:02 +00:00
parent ab7567693d
commit 53f0fff513
6 changed files with 40 additions and 33 deletions

View File

@@ -116,17 +116,21 @@ class FunctionFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = []
for name, arguments in functions:
for slot in self.function_slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
for slot in self.slots:
if slot == "{{content}}":
for name, arguments in functions:
for slot in self.function_slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
else:
elements.append(slot)
return elements + self.slots
return elements
@dataclass

View File

@@ -244,11 +244,11 @@ def _register_template(
)
```
"""
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
default_assistant_formatter = StringFormatter(slots=default_slots)
default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
@@ -371,8 +371,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if data_args.tool_format is not None:
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words
@@ -490,7 +490,7 @@ _register_template(
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
@@ -535,7 +535,7 @@ _register_template(
name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
@@ -684,7 +684,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
@@ -750,7 +750,7 @@ _register_template(
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
@@ -779,7 +779,7 @@ _register_template(
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
@@ -833,7 +833,7 @@ _register_template(
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(

View File

@@ -46,7 +46,7 @@ GLM4_TOOL_PROMPT = (
LLAMA3_TOOL_PROMPT = (
"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
"Do not use variables.\n\n{tool_text}"
@@ -180,6 +180,8 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils):
r"""
Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""
@override
@@ -190,13 +192,13 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
cur_time = datetime.now().strftime("%d %b %Y")
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text)
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@override
@staticmethod