[data] llama3 multi tool support (#8124)
This commit is contained in:
@@ -125,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_text = ""
|
||||
for name, arguments in functions:
|
||||
function_text += f"Action: {name}\nAction Input: {arguments}\n"
|
||||
|
||||
return function_text
|
||||
return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@@ -210,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("Llama-3 does not support parallel functions.")
|
||||
|
||||
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
|
||||
function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
|
||||
return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
try:
|
||||
tool = json.loads(content.strip())
|
||||
tools = json.loads(content.strip())
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
if "name" not in tool or "parameters" not in tool:
|
||||
tools = [tools] if not isinstance(tools, list) else tools
|
||||
try:
|
||||
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
|
||||
except KeyError:
|
||||
return content
|
||||
|
||||
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
|
||||
|
||||
|
||||
class MistralToolUtils(ToolUtils):
|
||||
r"""Mistral v0.3 tool using template."""
|
||||
@@ -244,11 +239,9 @@ class MistralToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
|
||||
|
||||
return "[" + ", ".join(function_texts) + "]"
|
||||
return json.dumps(
|
||||
[{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
|
||||
)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@@ -258,17 +251,11 @@ class MistralToolUtils(ToolUtils):
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
results = []
|
||||
for tool in tools:
|
||||
if "name" not in tool or "arguments" not in tool:
|
||||
return content
|
||||
|
||||
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
|
||||
|
||||
return results
|
||||
tools = [tools] if not isinstance(tools, list) else tools
|
||||
try:
|
||||
return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
|
||||
except KeyError:
|
||||
return content
|
||||
|
||||
|
||||
class QwenToolUtils(ToolUtils):
|
||||
@@ -287,13 +274,11 @@ class QwenToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(
|
||||
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
|
||||
)
|
||||
|
||||
return "\n".join(function_texts)
|
||||
function_texts = [
|
||||
json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
|
||||
for name, arguments in functions
|
||||
]
|
||||
return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user