fix tool formatter, allow parallel function #4362
Former-commit-id: b8f16c976db4ecec1cc8558851c8cbfb6a5b7e9c
This commit is contained in:
@@ -22,29 +22,20 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
||||
|
||||
JSON_FORMAT_PROMPT = (
|
||||
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
|
||||
)
|
||||
|
||||
|
||||
TOOL_SYSTEM_PROMPT = (
|
||||
DEFAULT_TOOL_PROMPT = (
|
||||
"You have access to the following tools:\n{tool_text}"
|
||||
"Use the following format if using a tool:\n"
|
||||
"```\n"
|
||||
"Action: tool name (one of [{tool_names}]).\n"
|
||||
"Action Input: the input to the tool{format_prompt}.\n"
|
||||
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
|
||||
"```\n"
|
||||
)
|
||||
|
||||
|
||||
GLM4_TOOL_SUFFIX_PROMPT = (
|
||||
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
||||
)
|
||||
|
||||
GLM4_TOOL_PROMPT = (
|
||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持,"
|
||||
"{tool_text}"
|
||||
|
||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
|
||||
)
|
||||
|
||||
|
||||
@@ -73,32 +64,19 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
)
|
||||
tool_names.append(tool["name"])
|
||||
|
||||
return TOOL_SYSTEM_PROMPT.format(
|
||||
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
||||
)
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
|
||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_name = tool["name"]
|
||||
tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)
|
||||
action_match = re.findall(regex, content)
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
results = []
|
||||
|
||||
for match in action_match:
|
||||
tool_name, tool_input = match
|
||||
tool_name = tool_name.strip()
|
||||
tool_input = tool_input.strip().strip('"').strip("```")
|
||||
|
||||
tool_name = match[0].strip()
|
||||
tool_input = match[1].strip().strip('"').strip("```")
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
@@ -108,19 +86,28 @@ def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
return results
|
||||
|
||||
|
||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
||||
)
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
|
||||
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) != 2:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
tool_name = lines[0].strip()
|
||||
tool_input = lines[1].strip()
|
||||
|
||||
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
@@ -193,22 +180,28 @@ class FunctionFormatter(Formatter):
|
||||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
functions: List[Tuple[str, str]] = []
|
||||
try:
|
||||
function = json.loads(content)
|
||||
name = function["name"]
|
||||
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
||||
except Exception:
|
||||
name, arguments = "", ""
|
||||
tool_calls = json.loads(content)
|
||||
if not isinstance(tool_calls, list): # parallel function call
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
functions = []
|
||||
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
for name, arguments in functions:
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
@@ -216,29 +209,22 @@ class FunctionFormatter(Formatter):
|
||||
@dataclass
|
||||
class ToolFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
if self.tool_format is None:
|
||||
if self.tool_format == "default":
|
||||
self._tool_formatter = default_tool_formatter
|
||||
self._tool_extractor = default_tool_extractor
|
||||
elif self.tool_format == "glm4":
|
||||
self._tool_formatter = glm4_tool_formatter
|
||||
self._tool_extractor = glm4_tool_extractor
|
||||
else:
|
||||
raise ValueError("Tool format was not found.")
|
||||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
tools = json.loads(content)
|
||||
if not len(tools):
|
||||
return [""]
|
||||
|
||||
if self.tool_format == "default":
|
||||
return [default_tool_formatter(tools)]
|
||||
elif self.tool_format == "glm4":
|
||||
return [glm4_tool_formatter(tools)]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
except Exception:
|
||||
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
except json.JSONDecodeError:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
if self.tool_format == "default":
|
||||
return default_tool_extractor(content)
|
||||
elif self.tool_format == "glm4":
|
||||
return glm4_tool_extractor(content)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return self._tool_extractor(content)
|
||||
|
||||
Reference in New Issue
Block a user