support llama3 tool prompt

Former-commit-id: dc45d2f56669fd99935a68cda1ec0e8f36229f7f
This commit is contained in:
hiyouga
2024-12-17 15:52:37 +00:00
parent 3d3324be5c
commit 1b8aab0723
5 changed files with 129 additions and 49 deletions

View File

@@ -13,10 +13,29 @@
# limitations under the License.
import json
from datetime import datetime
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
TOOLS = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"]
@@ -28,39 +47,27 @@ def test_string_formatter():
def test_function_formatter():
formatter = FunctionFormatter(slots=[], tool_format="default")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
]
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=[], tool_format="default")
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
]
def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default")
tools = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
assert formatter.apply(content=json.dumps(TOOLS)) == [
"You have access to the following tools:\n"
"> Tool Name: test_tool\n"
"Tool Description: tool_desc\n"
@@ -94,26 +101,18 @@ def test_default_multi_tool_extractor():
]
def test_glm4_function_formatter():
formatter = FunctionFormatter(tool_format="glm4")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4")
tools = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
assert formatter.apply(content=json.dumps(TOOLS)) == [
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
"## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4))
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
]
@@ -121,3 +120,29 @@ def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_function_formatter():
formatter = FunctionFormatter(tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}"""
]
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
cur_time = datetime.now().strftime("%d %b %Y")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
f"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
]
def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]