Support Mistral format tools
Former-commit-id: e42d0e54b7a64a3f017a09e99846d174db7b438f
This commit is contained in:
@@ -15,17 +15,18 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, NamedTuple, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
||||
class FunctionCall(NamedTuple):
    r"""
    A single tool call, stored as a (name, arguments) pair.
    """

    name: str  # name of the function being called
    arguments: str  # call arguments as text; the extractors in this file store a JSON string here
|
||||
|
||||
|
||||
DEFAULT_TOOL_PROMPT = (
|
||||
@@ -38,13 +39,11 @@ DEFAULT_TOOL_PROMPT = (
|
||||
"```\n"
|
||||
)
|
||||
|
||||
|
||||
# Prompt template used for GLM-4 tool calling (the prompt text itself is in Chinese).
# The `{tool_text}` placeholder is filled in through `str.format` with the rendered tool list.
GLM4_TOOL_PROMPT = (
    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
|
||||
|
||||
|
||||
LLAMA3_TOOL_PROMPT = (
|
||||
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
||||
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||
@@ -59,14 +58,6 @@ class ToolUtils(ABC):
|
||||
Base class for tool utilities.
|
||||
"""
|
||||
|
||||
@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS:
    r"""
    Gets a list of slots corresponding to a single function call.

    Returns:
        The slot list for one call. The concrete implementations visible in this file
        return a single template string with `{{name}}` and `{{arguments}}` placeholders.
    """
    ...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
@@ -75,21 +66,24 @@ class ToolUtils(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
    r"""
    Generates the assistant message including all the tool calls.

    Args:
        functions: the tool calls to render.

    Returns:
        Slots holding the rendered text. The implementations in this file return a
        one-element list containing the formatted string.
    """
    ...
|
||||
|
||||
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
    r"""
    Extracts all the function calls from the assistant message.

    Args:
        content: raw text of the message.

    Returns:
        The list of extracted function calls, or `content` itself when it cannot be parsed
        as tool calls (the implementations in this file do so on malformed input).
    """
    ...
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
@override
@staticmethod
def get_function_slots() -> SLOTS:
    r"""
    Gets the slots for a single function call in the `Action:` / `Action Input:` layout.

    Returns:
        A one-element list holding the template with `{{name}}` and `{{arguments}}` placeholders.
    """
    call_template = "Action: {{name}}\nAction Input: {{arguments}}\n"
    return [call_template]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
@@ -124,6 +118,15 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
    r"""
    Generates the assistant message for the given tool calls.

    Every call contributes one `Action: <name>` line followed by one
    `Action Input: <arguments>` line; the entries are concatenated in order.

    Args:
        functions: the tool calls to render, as (name, arguments) pairs.

    Returns:
        A one-element list containing the concatenated text (an empty string when
        `functions` is empty).
    """
    rendered_calls = [
        f"Action: {call_name}\nAction Input: {call_args}\n" for call_name, call_args in functions
    ]
    return ["".join(rendered_calls)]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
@@ -138,7 +141,7 @@ class DefaultToolUtils(ToolUtils):
|
||||
tool_input = match[1].strip().strip('"').strip("```")
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
@@ -146,11 +149,6 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
@override
@staticmethod
def get_function_slots() -> SLOTS:
    r"""
    Gets the slots for a single GLM-4 function call: the name on one line, the arguments on the next.

    Returns:
        A one-element list holding the template with `{{name}}` and `{{arguments}}` placeholders.
    """
    call_template = "{{name}}\n{{arguments}}"
    return [call_template]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
@@ -162,6 +160,14 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
    r"""
    Generates the assistant message for a single GLM-4 tool call.

    Args:
        functions: the tool calls to render; only one call is accepted.

    Returns:
        A one-element list with the function name and its arguments separated by a newline.

    Raises:
        ValueError: if more than one call is given.
    """
    call_count = len(functions)
    if call_count > 1:
        raise ValueError("GLM-4 does not support parallel functions.")

    only_call = functions[0]
    return [f"{only_call.name}\n{only_call.arguments}"]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
@@ -174,7 +180,7 @@ class GLM4ToolUtils(ToolUtils):
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
|
||||
|
||||
class Llama3ToolUtils(ToolUtils):
|
||||
@@ -184,11 +190,6 @@ class Llama3ToolUtils(ToolUtils):
|
||||
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
|
||||
"""
|
||||
|
||||
@override
@staticmethod
def get_function_slots() -> SLOTS:
    r"""
    Gets the slots for a single Llama 3 function call, laid out as a JSON object
    with `name` and `parameters` keys.

    Returns:
        A one-element list holding the template with `{{name}}` and `{{arguments}}` placeholders.
    """
    call_template = """{"name": "{{name}}", "parameters": {{arguments}}}"""
    return [call_template]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
@@ -200,6 +201,14 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
|
||||
|
||||
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
    r"""
    Generates the assistant message for a single Llama 3 tool call.

    The call is written as `{"name": "<name>", "parameters": <arguments>}`; the arguments
    text is inserted as-is, without extra quoting.

    Args:
        functions: the tool calls to render; only one call is accepted.

    Returns:
        A one-element list containing the rendered call.

    Raises:
        ValueError: if more than one call is given.
    """
    call_count = len(functions)
    if call_count > 1:
        raise ValueError("Llama 3 does not support parallel functions.")

    only_call = functions[0]
    return [f'{{"name": "{only_call.name}", "parameters": {only_call.arguments}}}']
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
@@ -211,13 +220,54 @@ class Llama3ToolUtils(ToolUtils):
|
||||
if "name" not in tool or "parameters" not in tool:
|
||||
return content
|
||||
|
||||
return [(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
|
||||
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
|
||||
|
||||
|
||||
class MistralToolUtils(ToolUtils):
    r"""
    Tool utilities for the Mistral tool-calling format.

    Tool definitions are wrapped between `[AVAILABLE_TOOLS]` and `[/AVAILABLE_TOOLS]`, and
    tool calls are exchanged as a JSON list of `{"name": ..., "arguments": ...}` objects.
    """

    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        r"""
        Renders the tool definitions as one JSON list between the `[AVAILABLE_TOOLS]` markers.

        Args:
            tools: the tool definitions; each one is wrapped as
                `{"type": "function", "function": tool}` before serialization.

        Returns:
            The marker-delimited JSON text.
        """
        wrapped_tools = [{"type": "function", "function": tool} for tool in tools]
        return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

    @override
    @staticmethod
    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
        r"""
        Generates the assistant message including all the tool calls.

        Args:
            functions: the tool calls to render, as (name, arguments) pairs. The arguments
                text is inserted as-is, so it is expected to already be JSON text.

        Returns:
            A one-element list containing a bracketed, comma-separated list of
            `{"name": ..., "arguments": ...}` entries (`"[]"` when `functions` is empty).
        """
        function_texts = [f'{{"name": "{name}", "arguments": {arguments}}}' for name, arguments in functions]
        return ["[" + ", ".join(function_texts) + "]"]

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        r"""
        Extracts all the function calls from the assistant message.

        Args:
            content: raw text of the message; expected to be a JSON object or a JSON list
                of objects, each carrying `name` and `arguments` keys.

        Returns:
            The extracted calls with their arguments re-serialized as JSON text, or `content`
            unchanged when it is not valid JSON or any entry is not a well-formed call.
        """
        try:
            tools = json.loads(content.strip())
        except json.JSONDecodeError:
            return content

        # A single call may be given as a bare object instead of a list.
        if not isinstance(tools, list):
            tools = [tools]

        results = []
        for tool in tools:
            # Valid JSON is not necessarily an object: numbers, strings, nulls or nested lists
            # would otherwise raise TypeError on the key checks or lookups below.
            if not isinstance(tool, dict) or "name" not in tool or "arguments" not in tool:
                return content

            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))

        return results
|
||||
|
||||
|
||||
# Registry mapping a tool format name to an instance of the matching tool utilities class.
TOOLS = {
    "default": DefaultToolUtils(),
    "glm4": GLM4ToolUtils(),
    "llama3": Llama3ToolUtils(),
    "mistral": MistralToolUtils(),
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user