Merge pull request #6367 from hiyouga/hiyouga/add_model
[model&template] add llama3.3 & support llama3 tool prompt

Former-commit-id: c32012c5e4943a30c3061716ed780d6124b6c90d
This commit is contained in:
@@ -98,7 +98,7 @@ class StringFormatter(Formatter):
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
||||
self.function_slots = get_tool_utils(self.tool_format).get_function_slots()
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
@@ -117,7 +117,7 @@ class FunctionFormatter(Formatter):
|
||||
|
||||
elements = []
|
||||
for name, arguments in functions:
|
||||
for slot in self.slots:
|
||||
for slot in self.function_slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
@@ -126,7 +126,7 @@ class FunctionFormatter(Formatter):
|
||||
else:
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
|
||||
return elements
|
||||
return elements + self.slots
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -750,16 +750,18 @@ _register_template(
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="llama3"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
)
|
||||
@@ -777,16 +779,18 @@ _register_template(
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="llama3"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
|
||||
@@ -829,16 +833,18 @@ _register_template(
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="llama3"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
|
||||
@@ -17,6 +17,7 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
@@ -24,6 +25,9 @@ from typing_extensions import override
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
||||
|
||||
|
||||
DEFAULT_TOOL_PROMPT = (
|
||||
"You have access to the following tools:\n{tool_text}"
|
||||
"Use the following format if using a tool:\n"
|
||||
@@ -41,7 +45,12 @@ GLM4_TOOL_PROMPT = (
|
||||
)
|
||||
|
||||
|
||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
||||
# System prompt template for Llama 3 tool calling. It is filled with `str.format`
# using the `cur_time` and `tool_text` fields; the doubled braces around the JSON
# example are format escapes and come out as single literal braces.
LLAMA3_TOOL_PROMPT = (
    "Environment: ipython\n"
    "Cutting Knowledge Date: December 2023\n"
    "Today Date: {cur_time}\n\n"
    "You have access to the following functions. "
    "To call a function, please respond with JSON for a function call. "
    'Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. '
    "Do not use variables.\n\n"
    "{tool_text}"
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -161,16 +170,52 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
arguments = json.loads(tool_input.strip())
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
|
||||
|
||||
class Llama3ToolUtils(ToolUtils):
    r"""
    Llama 3.x tool using template with `tools_in_user_message=False`.

    A function call is represented as a single JSON object with the keys
    ``name`` and ``parameters``.
    """

    @override
    @staticmethod
    def get_function_slots() -> SLOTS:
        r"""
        Returns the slot template for one function call.

        The ``{{name}}`` and ``{{arguments}}`` placeholders are left in place
        for the formatter to substitute.
        """
        return ["""{"name": "{{name}}", "parameters": {{arguments}}}"""]

    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        r"""
        Renders the tool descriptions into the Llama 3 tool prompt.

        Each tool is wrapped as ``{"type": "function", "function": tool}`` and
        dumped as indented JSON followed by a blank line. The prompt is stamped
        with the current local date.
        """
        cur_time = datetime.now().strftime("%d %b %Y")
        tool_text = "".join(
            json.dumps({"type": "function", "function": tool}, indent=4, ensure_ascii=False) + "\n\n"
            for tool in tools
        )
        return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        r"""
        Parses a model response into a function call.

        Returns a one-element list of ``(name, arguments_json)`` when the
        stripped content is a JSON object holding both ``name`` and
        ``parameters``; otherwise returns ``content`` unchanged.
        """
        try:
            tool = json.loads(content.strip())
        except json.JSONDecodeError:
            return content

        # Valid JSON is not necessarily an object: `42`, `null`, `true`, a bare string or a list
        # would make the key checks below raise TypeError. Treat those as ordinary text.
        if not isinstance(tool, dict):
            return content

        if "name" not in tool or "parameters" not in tool:
            return content

        return [(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
|
||||
|
||||
|
||||
# Tool utility instances keyed by tool format name ("default", "glm4", "llama3").
TOOLS = dict(
    default=DefaultToolUtils(),
    glm4=GLM4ToolUtils(),
    llama3=Llama3ToolUtils(),
)
|
||||
|
||||
|
||||
|
||||
@@ -837,6 +837,10 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
|
||||
},
|
||||
"Llama-3.3-70B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-3.3-70B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.3-70B-Instruct",
|
||||
},
|
||||
},
|
||||
template="llama3",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user