support qwen tool format
Former-commit-id: cbef4cb501fa1b50fa611e7054a856ce2c5ed10e
This commit is contained in:
@@ -51,6 +51,14 @@ LLAMA3_TOOL_PROMPT = (
|
||||
"Do not use variables.\n\n{tool_text}"
|
||||
)
|
||||
|
||||
QWEN_TOOL_PROMPT = (
|
||||
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
|
||||
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
|
||||
""""arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@@ -79,11 +87,17 @@ class ToolUtils(ABC):
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts all the function calls from the assistant message.
|
||||
|
||||
It should be an inverse function of `function_formatter`.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
r"""
|
||||
Default tool using template.
|
||||
"""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
@@ -149,6 +163,10 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
r"""
|
||||
GLM-4 tool using template.
|
||||
"""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
@@ -205,7 +223,7 @@ class Llama3ToolUtils(ToolUtils):
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("Llama 3 does not support parallel functions.")
|
||||
raise ValueError("Llama-3 does not support parallel functions.")
|
||||
|
||||
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
|
||||
|
||||
@@ -224,6 +242,10 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class MistralToolUtils(ToolUtils):
|
||||
r"""
|
||||
Mistral v0.3 tool using template.
|
||||
"""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
@@ -263,11 +285,61 @@ class MistralToolUtils(ToolUtils):
|
||||
return results
|
||||
|
||||
|
||||
class QwenToolUtils(ToolUtils):
|
||||
r"""
|
||||
Qwen 2.5 tool using template.
|
||||
"""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
wrapped_tool = {"type": "function", "function": tool}
|
||||
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
|
||||
|
||||
return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(
|
||||
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
|
||||
)
|
||||
|
||||
return ["\n".join(function_texts)]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
|
||||
tool_match: List[str] = re.findall(regex, content)
|
||||
if not tool_match:
|
||||
return content
|
||||
|
||||
results = []
|
||||
for tool in tool_match:
|
||||
try:
|
||||
tool = json.loads(tool.strip())
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
if "name" not in tool or "arguments" not in tool:
|
||||
return content
|
||||
|
||||
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
"qwen": QwenToolUtils(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user