Support Mistral format tools

Former-commit-id: e42d0e54b7a64a3f017a09e99846d174db7b438f
This commit is contained in:
ylfeng
2024-09-18 21:45:25 +08:00
committed by hiyouga
parent ebf6a07681
commit 469c7cd462
6 changed files with 160 additions and 60 deletions

View File

@@ -16,16 +16,12 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Union
from typing_extensions import override
from .data_utils import SLOTS
from .tool_utils import get_tool_utils
if TYPE_CHECKING:
from .tool_utils import FunctionCall
from .tool_utils import FunctionCall, get_tool_utils
@dataclass
@@ -98,19 +94,21 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
self.function_slots = get_tool_utils(self.tool_format).get_function_slots()
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
functions: List["FunctionCall"] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
for tool_call in tool_calls:
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
functions.append(
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
@@ -118,15 +116,7 @@ class FunctionFormatter(Formatter):
elements = []
for slot in self.slots:
if slot == "{{content}}":
for name, arguments in functions:
for slot in self.function_slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
elements += self.tool_utils.function_formatter(functions)
else:
elements.append(slot)