add docstrings, refactor logger

Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 93d4570a59
commit 7f71276ad8
30 changed files with 334 additions and 57 deletions

View File

@@ -15,9 +15,12 @@
import json
import re
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
)
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
@dataclass
class ToolUtils(ABC):
@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS: ...
"""
Base class for tool utilities.
"""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
def get_function_slots() -> SLOTS:
r"""
Gets a list of slots corresponding to a single function call.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the response message.
"""
...
class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
if "\n" not in content:
return content