[data] llama3 multi tool support (#8124)

This commit is contained in:
hoshi-hiyouga
2025-05-21 02:01:12 +08:00
committed by GitHub
parent c2f6f2fa77
commit 56926d76f9
4 changed files with 55 additions and 50 deletions

View File

@@ -50,7 +50,7 @@ def test_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
"</s>",
]
@@ -60,7 +60,7 @@ def test_multi_function_formatter():
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
"</s>",
]
@@ -85,7 +85,7 @@ def test_default_tool_formatter():
def test_default_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = (
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
]
def test_llama3_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
"""<|eot_id|>"""
]
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_multi_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = (
"""[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)