Support Mistral format tools
Former-commit-id: e42d0e54b7a64a3f017a09e99846d174db7b438f
This commit is contained in:
@@ -16,16 +16,12 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
from .tool_utils import get_tool_utils
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .tool_utils import FunctionCall
|
||||
from .tool_utils import FunctionCall, get_tool_utils
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -98,19 +94,21 @@ class StringFormatter(Formatter):
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.function_slots = get_tool_utils(self.tool_format).get_function_slots()
|
||||
self.tool_utils = get_tool_utils(self.tool_format)
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
functions: List[Tuple[str, str]] = []
|
||||
functions: List["FunctionCall"] = []
|
||||
try:
|
||||
tool_calls = json.loads(content)
|
||||
if not isinstance(tool_calls, list): # parallel function call
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||
functions.append(
|
||||
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||
@@ -118,15 +116,7 @@ class FunctionFormatter(Formatter):
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if slot == "{{content}}":
|
||||
for name, arguments in functions:
|
||||
for slot in self.function_slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
elements += self.tool_utils.function_formatter(functions)
|
||||
else:
|
||||
elements.append(slot)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user