[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -17,7 +17,7 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, NamedTuple, Tuple, Union
|
||||
from typing import Any, NamedTuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
"""
|
||||
Base class for tool utilities.
|
||||
"""
|
||||
"""Base class for tool utilities."""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
r"""
|
||||
Generates the system message describing all the available tools.
|
||||
"""
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
r"""Generate the system message describing all the available tools."""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
r"""
|
||||
Generates the assistant message including all the tool calls.
|
||||
"""
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
r"""Generate the assistant message including all the tool calls."""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts all the function calls from the assistant message.
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
r"""Extract all the function calls from the assistant message.
|
||||
|
||||
It should be an inverse function of `function_formatter`.
|
||||
"""
|
||||
@@ -92,13 +85,11 @@ class ToolUtils(ABC):
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
r"""
|
||||
Default tool using template.
|
||||
"""
|
||||
r"""Default tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_text = ""
|
||||
for name, arguments in functions:
|
||||
function_text += f"Action: {name}\nAction Input: {arguments}\n"
|
||||
@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
action_match: list[tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
r"""
|
||||
GLM-4 tool using template.
|
||||
"""
|
||||
r"""GLM-4 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("GLM-4 does not support parallel functions.")
|
||||
|
||||
@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class Llama3ToolUtils(ToolUtils):
|
||||
r"""
|
||||
Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||
r"""Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||
|
||||
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
|
||||
"""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
date = datetime.now().strftime("%d %b %Y")
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("Llama-3 does not support parallel functions.")
|
||||
|
||||
@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
try:
|
||||
tool = json.loads(content.strip())
|
||||
except json.JSONDecodeError:
|
||||
@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class MistralToolUtils(ToolUtils):
|
||||
r"""
|
||||
Mistral v0.3 tool using template.
|
||||
"""
|
||||
r"""Mistral v0.3 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
wrapped_tools = []
|
||||
for tool in tools:
|
||||
wrapped_tools.append({"type": "function", "function": tool})
|
||||
@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
|
||||
@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
try:
|
||||
tools = json.loads(content.strip())
|
||||
except json.JSONDecodeError:
|
||||
@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class QwenToolUtils(ToolUtils):
|
||||
r"""
|
||||
Qwen 2.5 tool using template.
|
||||
"""
|
||||
r"""Qwen 2.5 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
wrapped_tool = {"type": "function", "function": tool}
|
||||
@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(
|
||||
@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
|
||||
tool_match: List[str] = re.findall(regex, content)
|
||||
tool_match: list[str] = re.findall(regex, content)
|
||||
if not tool_match:
|
||||
return content
|
||||
|
||||
|
||||
Reference in New Issue
Block a user