[test] align test cases (#6865)

* align test cases

* fix function formatter

Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032
This commit is contained in:
hoshi-hiyouga
2025-02-09 01:03:49 +08:00
committed by GitHub
parent 94726bdc8d
commit 72d5b06b08
3 changed files with 32 additions and 42 deletions

View File

@@ -86,19 +86,20 @@ class StringFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
return elements
@dataclass
class FunctionFormatter(Formatter):
class FunctionFormatter(StringFormatter):
def __post_init__(self):
super().__post_init__()
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
content: str = kwargs.pop("content")
regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
thought = re.search(regex, content)
if thought:
@@ -116,19 +117,13 @@ class FunctionFormatter(Formatter):
)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
elements = []
for slot in self.slots:
if slot == "{{content}}":
if thought:
elements.append(thought.group(1))
function_str = self.tool_utils.function_formatter(functions)
if thought:
function_str = thought.group(1) + function_str
elements += self.tool_utils.function_formatter(functions)
else:
elements.append(slot)
return elements
return super().apply(content=function_str)
@dataclass
@@ -143,7 +138,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: