[misc] upgrade format to py39 (#7256)
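The diff below applies the Python 3.9 style allowed by PEP 585 across the file: `typing.List`, `Dict`, `Tuple`, and `Type` become the builtin generics `list`, `dict`, `tuple`, and `type`; `Sequence` is imported from `collections.abc` instead of `typing`; and multi-line docstrings collapse into one-line summaries. A minimal sketch of the typing pattern, assuming Python >= 3.9 (the function below is hypothetical, not part of the commit):

# Illustrative sketch only: the py39 pattern this commit applies file-wide.
from collections.abc import Sequence  # typing.Sequence is deprecated by PEP 585
from typing import Optional


def encode_pairs(messages: Sequence[dict[str, str]], system: Optional[str] = None) -> list[tuple[int, int]]:
    """Return (turn index, content length) pairs for each message."""
    # Builtin generics (dict[str, str], list[...], tuple[...]) replace
    # typing.Dict/List/Tuple and need no typing imports on Python >= 3.9.
    pairs = [(i, len(message["content"])) for i, message in enumerate(messages)]
    if system is not None:
        pairs.insert(0, (-1, len(system)))

    return pairs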
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import TYPE_CHECKING, Optional, Union

 from typing_extensions import override

@@ -46,8 +47,8 @@ class Template:
     format_tools: "Formatter"
     format_prefix: "Formatter"
     default_system: str
-    stop_words: List[str]
-    thought_words: Tuple[str, str]
+    stop_words: list[str]
+    thought_words: tuple[str, str]
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
@@ -56,13 +57,11 @@ class Template:
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-    ) -> Tuple[List[int], List[int]]:
-        r"""
-        Returns a single pair of token ids representing prompt and response respectively.
-        """
+    ) -> tuple[list[int], list[int]]:
+        r"""Return a single pair of token ids representing prompt and response respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
@@ -74,36 +73,28 @@ class Template:
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-    ) -> List[Tuple[List[int], List[int]]]:
-        r"""
-        Returns multiple pairs of token ids representing prompts and responses respectively.
-        """
+    ) -> list[tuple[list[int], list[int]]]:
+        r"""Return multiple pairs of token ids representing prompts and responses respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

-    def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extracts tool message.
-        """
+    def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract tool message."""
         return self.format_tools.extract(content)

-    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
-        r"""
-        Returns stop token ids.
-        """
+    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Return stop token ids."""
         stop_token_ids = {tokenizer.eos_token_id}
         for token in self.stop_words:
             stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))

         return list(stop_token_ids)

-    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
-        r"""
-        Converts elements to token ids.
-        """
+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
+        r"""Convert elements to token ids."""
         token_ids = []
         for elem in elements:
             if isinstance(elem, str):
@@ -124,14 +115,14 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-    ) -> List[List[int]]:
-        r"""
-        Encodes formatted inputs to pairs of token ids.
+    ) -> list[list[int]]:
+        r"""Encode formatted inputs to pairs of token ids.
+
         Turn 0: prefix + system + query resp
-        Turn t: query resp
+        Turn t: query resp.
         """
         system = system or self.default_system
         encoded_messages = []
@@ -161,9 +152,7 @@ class Template:

     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
-        r"""
-        Adds or replaces eos token to the tokenizer.
-        """
+        r"""Add or replace eos token to the tokenizer."""
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

@@ -176,9 +165,7 @@ class Template:
             logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

     def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
-        r"""
-        Adds eos token and pad token to the tokenizer.
-        """
+        r"""Add eos token and pad token to the tokenizer."""
         stop_words = self.stop_words
         if self.replace_eos:
             if not stop_words:
@@ -204,16 +191,12 @@ class Template:

     @staticmethod
     def _jinja_escape(content: str) -> str:
-        r"""
-        Escape single quotes in content.
-        """
+        r"""Escape single quotes in content."""
         return content.replace("'", r"\'")

     @staticmethod
     def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
-        r"""
-        Converts slots to jinja template.
-        """
+        r"""Convert slots to jinja template."""
         slot_items = []
         for slot in slots:
             if isinstance(slot, str):
@@ -235,9 +218,7 @@ class Template:
         return " + ".join(slot_items)

     def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the jinja template.
-        """
+        r"""Return the jinja template."""
         prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
         system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
         user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
@@ -265,9 +246,7 @@ class Template:
         return jinja_template

     def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
-        r"""
-        Replaces the jinja template in the tokenizer.
-        """
+        r"""Replace the jinja template in the tokenizer."""
         if tokenizer.chat_template is None or self.replace_jinja_template:
             try:
                 tokenizer.chat_template = self._get_jinja_template(tokenizer)
@@ -278,9 +257,7 @@ class Template:
    def _convert_slots_to_ollama(
        slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
    ) -> str:
-        r"""
-        Converts slots to ollama template.
-        """
+        r"""Convert slots to ollama template."""
        slot_items = []
        for slot in slots:
            if isinstance(slot, str):
@@ -302,9 +279,7 @@ class Template:
         return "".join(slot_items)

     def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the ollama template.
-        """
+        r"""Return the ollama template."""
         prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
         system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
         user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
@@ -316,8 +291,7 @@ class Template:
         )

     def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the ollama modelfile.
+        r"""Return the ollama modelfile.

         TODO: support function calling.
         """
@@ -340,10 +314,10 @@ class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: str,
         tools: str,
-    ) -> List[List[int]]:
+    ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
@@ -402,7 +376,7 @@ class Llama2Template(Template):
         return jinja_template


-TEMPLATES: Dict[str, "Template"] = {}
+TEMPLATES: dict[str, "Template"] = {}


 def register_template(
@@ -416,15 +390,14 @@ def register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: Optional[Sequence[str]] = None,
-    thought_words: Optional[Tuple[str, str]] = None,
+    thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
-    template_class: Type["Template"] = Template,
+    template_class: type["Template"] = Template,
 ) -> None:
-    r"""
-    Registers a chat template.
+    r"""Register a chat template.

     To add the following chat template:
     ```
@@ -472,9 +445,7 @@ def register_template(


 def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
-    r"""
-    Extracts a chat template from the tokenizer.
-    """
+    r"""Extract a chat template from the tokenizer."""

     def find_diff(short_str: str, long_str: str) -> str:
         i, j = 0, 0
@@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":


 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
-    r"""
-    Gets chat template and fixes the tokenizer.
-    """
+    r"""Get chat template and fixes the tokenizer."""
     if data_args.template is None:
         if isinstance(tokenizer.chat_template, str):
             logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
@@ -1149,7 +1118,8 @@ register_template(
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     default_system=(
-        "你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
+        "你是一个经过良好训练的AI助手,你的名字是Marco-o1."
+        "由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
         "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
         "<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
     ),
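Alongside the typing rewrites, every docstring in the diff collapses from a three-line block into a one-line summary in the imperative mood ("Return ..." instead of "Returns ..."). This matches pydocstyle's D200 (one-line docstring should fit on one line) and D401 (first line should be in imperative mood); whether those exact lint rules motivated the change is an assumption. A before/after sketch with a hypothetical function:

# Hypothetical minimal example, not code from the commit.
def stop_ids_before() -> list[int]:
    r"""
    Returns stop token ids.
    """
    return []


def stop_ids_after() -> list[int]:
    r"""Return stop token ids."""  # D200 one-liner, D401 imperative mood
    return []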