[test] align test cases (#6865)
* align test cases * fix function formatter Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032
This commit is contained in:
@@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
class FunctionCall(NamedTuple):
|
||||
name: str
|
||||
@@ -76,7 +74,7 @@ class ToolUtils(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
r"""
|
||||
Generates the assistant message including all the tool calls.
|
||||
"""
|
||||
@@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
function_text = ""
|
||||
for name, arguments in functions:
|
||||
function_text += f"Action: {name}\nAction Input: {arguments}\n"
|
||||
|
||||
return [function_text]
|
||||
return function_text
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("GLM-4 does not support parallel functions.")
|
||||
|
||||
return [f"{functions[0].name}\n{functions[0].arguments}"]
|
||||
return f"{functions[0].name}\n{functions[0].arguments}"
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("Llama-3 does not support parallel functions.")
|
||||
|
||||
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
|
||||
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
|
||||
|
||||
return ["[" + ", ".join(function_texts) + "]"]
|
||||
return "[" + ", ".join(function_texts) + "]"
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(
|
||||
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
|
||||
)
|
||||
|
||||
return ["\n".join(function_texts)]
|
||||
return "\n".join(function_texts)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user