support llama3 tool prompt

Former-commit-id: dc45d2f56669fd99935a68cda1ec0e8f36229f7f
hiyouga committed 2024-12-17 15:52:37 +00:00
parent 3d3324be5c
commit 1b8aab0723
5 changed files with 129 additions and 49 deletions


@@ -17,6 +17,7 @@ import re
 from abc import ABC, abstractmethod
 from collections import namedtuple
 from dataclasses import dataclass
+from datetime import datetime
 from typing import Any, Dict, List, Tuple, Union

 from typing_extensions import override
@@ -24,6 +25,9 @@ from typing_extensions import override
 from .data_utils import SLOTS


+FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
+
+
 DEFAULT_TOOL_PROMPT = (
     "You have access to the following tools:\n{tool_text}"
     "Use the following format if using a tool:\n"
@@ -41,7 +45,12 @@ GLM4_TOOL_PROMPT = (
 )


-FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
+LLAMA3_TOOL_PROMPT = (
+    "Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
+    "You have access to the following functions. To call a function, please respond with JSON for a function call. "
+    """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
+    "Do not use variables.\n\n{tool_text}"
+)


 @dataclass
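As a reference point (not part of this diff), a minimal sketch of how the new template renders once formatted; the date and tool_text values below are invented placeholders:

# Hypothetical illustration only: the doubled braces in LLAMA3_TOOL_PROMPT
# become literal JSON braces after str.format().
sample = LLAMA3_TOOL_PROMPT.format(
    cur_time="17 Dec 2024",  # invented date for the example
    tool_text='{"type": "function", "function": {"name": "get_weather"}}\n\n',
)
# sample starts with:
#   Environment: ipython
#   Cutting Knowledge Date: December 2023
#   Today Date: 17 Dec 2024
# and ends with the instruction to respond as
#   {"name": function name, "parameters": dictionary of argument name and its value}
# followed by the serialized tool definitions.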
@@ -161,16 +170,52 @@ class GLM4ToolUtils(ToolUtils):
         tool_name, tool_input = content.split("\n", maxsplit=1)
         try:
-            arguments = json.loads(tool_input)
+            arguments = json.loads(tool_input.strip())
         except json.JSONDecodeError:
             return content

         return [(tool_name, json.dumps(arguments, ensure_ascii=False))]


+class Llama3ToolUtils(ToolUtils):
+    r"""
+    Llama 3.x tool using template with `tools_in_user_message=False`.
+    """
+
+    @override
+    @staticmethod
+    def get_function_slots() -> SLOTS:
+        return ["""{"name": "{{name}}", "parameters": {{arguments}}}"""]
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        cur_time = datetime.now().strftime("%d %b %Y")
+        tool_text = ""
+        for tool in tools:
+            wrapped_tool = {"type": "function", "function": tool}
+            tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
+
+        return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text)
+
+    @override
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+        try:
+            tool = json.loads(content.strip())
+        except json.JSONDecodeError:
+            return content
+
+        if "name" not in tool or "parameters" not in tool:
+            return content
+
+        return [(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+
+
 TOOLS = {
     "default": DefaultToolUtils(),
     "glm4": GLM4ToolUtils(),
+    "llama3": Llama3ToolUtils(),
 }
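A hedged usage sketch of the new utilities (again not part of the commit): the tool specification and model output below are made-up examples, assuming an OpenAI-style function schema like the one tool_formatter wraps.

# Hypothetical round trip through Llama3ToolUtils; "get_weather" is invented.
utils = TOOLS["llama3"]

system_prompt = utils.tool_formatter([
    {
        "name": "get_weather",
        "description": "Query the current weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }
])  # embeds each tool, wrapped as {"type": "function", ...}, into LLAMA3_TOOL_PROMPT

model_output = '{"name": "get_weather", "parameters": {"city": "Paris"}}'
print(utils.tool_extractor(model_output))
# -> [('get_weather', '{"city": "Paris"}')]
# Anything that is not valid JSON carrying "name" and "parameters" is returned unchanged.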