[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -17,7 +17,7 @@ import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Union
from typing_extensions import override
@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
@dataclass
class ToolUtils(ABC):
"""
Base class for tool utilities.
"""
"""Base class for tool utilities."""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
def tool_formatter(tools: list[dict[str, Any]]) -> str:
r"""Generate the system message describing all the available tools."""
...
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
def function_formatter(functions: list["FunctionCall"]) -> str:
r"""Generate the assistant message including all the tool calls."""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract all the function calls from the assistant message.
It should be an inverse function of `function_formatter`.
"""
@@ -92,13 +85,11 @@ class ToolUtils(ABC):
class DefaultToolUtils(ToolUtils):
r"""
Default tool using template.
"""
r"""Default tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
action_match: list[tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content
@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
r"""
GLM-4 tool using template.
"""
r"""GLM-4 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
if "\n" not in content:
return content
@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils):
r"""
Llama 3.x tool using template with `tools_in_user_message=False`.
r"""Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
except json.JSONDecodeError:
@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils):
r"""
Mistral v0.3 tool using template.
"""
r"""Mistral v0.3 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tools = json.loads(content.strip())
except json.JSONDecodeError:
@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
class QwenToolUtils(ToolUtils):
r"""
Qwen 2.5 tool using template.
"""
r"""Qwen 2.5 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content)
tool_match: list[str] = re.findall(regex, content)
if not tool_match:
return content