[data] llama3 multi tool support (#8124)

This commit is contained in:
hoshi-hiyouga
2025-05-21 02:01:12 +08:00
committed by GitHub
parent c2f6f2fa77
commit 56926d76f9
4 changed files with 55 additions and 50 deletions

View File

@@ -125,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return function_text
return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
@override
@staticmethod
@@ -210,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
tools = json.loads(content.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "parameters" not in tool:
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template."""
@@ -244,11 +239,9 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return "[" + ", ".join(function_texts) + "]"
return json.dumps(
[{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
)
@override
@staticmethod
@@ -258,17 +251,11 @@ class MistralToolUtils(ToolUtils):
except json.JSONDecodeError:
return content
if not isinstance(tools, list):
tools = [tools]
results = []
for tool in tools:
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
class QwenToolUtils(ToolUtils):
@@ -287,13 +274,11 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return "\n".join(function_texts)
function_texts = [
json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
for name, arguments in functions
]
return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
@override
@staticmethod