add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
@@ -15,9 +15,12 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
|
||||
)
|
||||
|
||||
|
||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_function_slots() -> SLOTS: ...
|
||||
"""
|
||||
Base class for tool utilities.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
|
||||
def get_function_slots() -> SLOTS:
|
||||
r"""
|
||||
Gets a list of slots corresponding to a single function call.
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
r"""
|
||||
Generates the system message describing all the available tools.
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts all the function calls from the response message.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["{{name}}\n{{arguments}}"]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
|
||||
Reference in New Issue
Block a user