[test] align test cases (#6865)

* align test cases

* fix function formatter

Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032
This commit is contained in:
hoshi-hiyouga
2025-02-09 01:03:49 +08:00
committed by GitHub
parent 94726bdc8d
commit 72d5b06b08
3 changed files with 32 additions and 42 deletions

View File

@@ -123,11 +123,10 @@ def test_glm4_tool_extractor():
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
"<|eot_id|>",
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
]
@@ -150,20 +149,19 @@ def test_llama3_tool_extractor():
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"[TOOL_CALLS] "
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
@@ -197,21 +195,20 @@ def test_mistral_multi_tool_extractor():
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call><|im_end|>\n"""
]
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
"<|im_end|>\n"
]