[data] llama3 multi tool support (#8124)
This commit is contained in:
@@ -50,7 +50,7 @@ def test_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_multi_function_formatter():
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
@@ -85,7 +85,7 @@ def test_default_tool_formatter():
|
||||
|
||||
def test_default_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = (
|
||||
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
|
||||
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
|
||||
)
|
||||
assert formatter.extract(result) == [
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
|
||||
|
||||
def test_llama3_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
|
||||
]
|
||||
|
||||
|
||||
def test_llama3_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
|
||||
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
|
||||
"""<|eot_id|>"""
|
||||
]
|
||||
|
||||
|
||||
def test_llama3_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
date = datetime.now().strftime("%d %b %Y")
|
||||
@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
def test_llama3_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
result = (
|
||||
"""[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
|
||||
"""{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
|
||||
)
|
||||
assert formatter.extract(result) == [
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
|
||||
|
||||
def test_mistral_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
|
||||
Reference in New Issue
Block a user