support llama3 tool prompt

Former-commit-id: dc45d2f56669fd99935a68cda1ec0e8f36229f7f
This commit is contained in:
hiyouga
2024-12-17 15:52:37 +00:00
parent 3d3324be5c
commit 1b8aab0723
5 changed files with 129 additions and 49 deletions

View File

@@ -98,7 +98,7 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
self.function_slots = get_tool_utils(self.tool_format).get_function_slots()
@override
def apply(self, **kwargs) -> SLOTS:
@@ -117,7 +117,7 @@ class FunctionFormatter(Formatter):
elements = []
for name, arguments in functions:
for slot in self.slots:
for slot in self.function_slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
@@ -126,7 +126,7 @@ class FunctionFormatter(Formatter):
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
return elements + self.slots
@dataclass