Merge pull request #4173 from mMrBun/main
Implemented the tool_formatter and tool_extractor for glm4 and Qwen2 tool_format Former-commit-id: 36b02ceed40198ecd5d559ee4ebef9205442ded2
This commit is contained in:
@@ -37,6 +37,17 @@ TOOL_SYSTEM_PROMPT = (
|
||||
)
|
||||
|
||||
|
||||
GLM4_TOOL_SUFFIX_PROMPT = (
|
||||
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
||||
)
|
||||
|
||||
GLM4_TOOL_PROMPT = (
|
||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持,"
|
||||
"{tool_text}"
|
||||
|
||||
)
|
||||
|
||||
|
||||
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
tool_names = []
|
||||
@@ -67,31 +78,59 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
|
||||
action_match = re.search(regex, content)
|
||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_name = tool["name"]
|
||||
tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)
|
||||
action_match = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
tool_name = action_match.group(1).strip()
|
||||
tool_input = action_match.group(2).strip().strip('"').strip("```")
|
||||
results = []
|
||||
|
||||
for match in action_match:
|
||||
tool_name, tool_input = match
|
||||
tool_name = tool_name.strip()
|
||||
tool_input = tool_input.strip().strip('"').strip("```")
|
||||
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) != 2:
|
||||
return content
|
||||
tool_name = lines[0].strip()
|
||||
tool_input = lines[1].strip()
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
|
||||
return tool_name, json.dumps(arguments, ensure_ascii=False)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[Literal["default"]] = None
|
||||
tool_format: Optional[Literal["default", "glm4"]] = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS: ...
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -189,13 +228,17 @@ class ToolFormatter(Formatter):
|
||||
|
||||
if self.tool_format == "default":
|
||||
return [default_tool_formatter(tools)]
|
||||
elif self.tool_format == "glm4":
|
||||
return [glm4_tool_formatter(tools)]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
except Exception:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
if self.tool_format == "default":
|
||||
return default_tool_extractor(content)
|
||||
elif self.tool_format == "glm4":
|
||||
return glm4_tool_extractor(content)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user