[test] align test cases (#6865)

* align test cases

* fix function formatter

Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032
This commit is contained in:
hoshi-hiyouga
2025-02-09 01:03:49 +08:00
committed by GitHub
parent 94726bdc8d
commit 72d5b06b08
3 changed files with 32 additions and 42 deletions

View File

@@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
class FunctionCall(NamedTuple):
name: str
@@ -76,7 +74,7 @@ class ToolUtils(ABC):
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
@@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return [function_text]
return function_text
@override
@staticmethod
@@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
return [f"{functions[0].name}\n{functions[0].arguments}"]
return f"{functions[0].name}\n{functions[0].arguments}"
@override
@staticmethod
@@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
@override
@staticmethod
@@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return ["[" + ", ".join(function_texts) + "]"]
return "[" + ", ".join(function_texts) + "]"
@override
@staticmethod
@@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return ["\n".join(function_texts)]
return "\n".join(function_texts)
@override
@staticmethod