Support Mistral format tools

Former-commit-id: e42d0e54b7a64a3f017a09e99846d174db7b438f
This commit is contained in:
ylfeng
2024-09-18 21:45:25 +08:00
committed by hiyouga
parent ebf6a07681
commit 469c7cd462
6 changed files with 160 additions and 60 deletions

View File

@@ -59,7 +59,7 @@ def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
]
@@ -112,7 +112,7 @@ def test_glm4_tool_formatter():
assert formatter.apply(content=json.dumps(TOOLS)) == [
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
]
@@ -139,7 +139,7 @@ def test_llama3_tool_formatter():
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
]
@@ -147,3 +147,50 @@ def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_tool_formatter():
formatter = ToolFormatter(tool_format="mistral")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
]
def test_mistral_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_multi_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = (
"""[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]