[test] align test cases (#6865)
* align test cases * fix function formatter Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032
This commit is contained in:
@@ -86,19 +86,20 @@ class StringFormatter(Formatter):
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
class FunctionFormatter(StringFormatter):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.tool_utils = get_tool_utils(self.tool_format)
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
content: str = kwargs.pop("content")
|
||||
regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
|
||||
thought = re.search(regex, content)
|
||||
if thought:
|
||||
@@ -116,19 +117,13 @@ class FunctionFormatter(Formatter):
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
|
||||
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if slot == "{{content}}":
|
||||
if thought:
|
||||
elements.append(thought.group(1))
|
||||
function_str = self.tool_utils.function_formatter(functions)
|
||||
if thought:
|
||||
function_str = thought.group(1) + function_str
|
||||
|
||||
elements += self.tool_utils.function_formatter(functions)
|
||||
else:
|
||||
elements.append(slot)
|
||||
|
||||
return elements
|
||||
return super().apply(content=function_str)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -143,7 +138,7 @@ class ToolFormatter(Formatter):
|
||||
tools = json.loads(content)
|
||||
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
|
||||
Reference in New Issue
Block a user