[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override
@@ -46,8 +47,8 @@ class Template:
format_tools: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
thought_words: Tuple[str, str]
stop_words: list[str]
thought_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
@@ -56,13 +57,11 @@ class Template:
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
@@ -74,36 +73,28 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract tool message."""
return self.format_tools.extract(content)
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
r"""
Returns stop token ids.
"""
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Return stop token ids."""
stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
return list(stop_token_ids)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
token_ids = []
for elem in elements:
if isinstance(elem, str):
@@ -124,14 +115,14 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: query resp
Turn t: query resp.
"""
system = system or self.default_system
encoded_messages = []
@@ -161,9 +152,7 @@ class Template:
@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""
Adds or replaces eos token to the tokenizer.
"""
r"""Add or replace eos token to the tokenizer."""
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@@ -176,9 +165,7 @@ class Template:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Adds eos token and pad token to the tokenizer.
"""
r"""Add eos token and pad token to the tokenizer."""
stop_words = self.stop_words
if self.replace_eos:
if not stop_words:
@@ -204,16 +191,12 @@ class Template:
@staticmethod
def _jinja_escape(content: str) -> str:
r"""
Escape single quotes in content.
"""
r"""Escape single quotes in content."""
return content.replace("'", r"\'")
@staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r"""
Converts slots to jinja template.
"""
r"""Convert slots to jinja template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@@ -235,9 +218,7 @@ class Template:
return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
r"""Return the jinja template."""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
@@ -265,9 +246,7 @@ class Template:
return jinja_template
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Replaces the jinja template in the tokenizer.
"""
r"""Replace the jinja template in the tokenizer."""
if tokenizer.chat_template is None or self.replace_jinja_template:
try:
tokenizer.chat_template = self._get_jinja_template(tokenizer)
@@ -278,9 +257,7 @@ class Template:
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
r"""Convert slots to ollama template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@@ -302,9 +279,7 @@ class Template:
return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
r"""Return the ollama template."""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
@@ -316,8 +291,7 @@ class Template:
)
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
r"""Return the ollama modelfile.
TODO: support function calling.
"""
@@ -340,10 +314,10 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: str,
tools: str,
) -> List[List[int]]:
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
@@ -402,7 +376,7 @@ class Llama2Template(Template):
return jinja_template
TEMPLATES: Dict[str, "Template"] = {}
TEMPLATES: dict[str, "Template"] = {}
def register_template(
@@ -416,15 +390,14 @@ def register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None,
thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template,
template_class: type["Template"] = Template,
) -> None:
r"""
Registers a chat template.
r"""Register a chat template.
To add the following chat template:
```
@@ -472,9 +445,7 @@ def register_template(
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r"""
Extracts a chat template from the tokenizer.
"""
r"""Extract a chat template from the tokenizer."""
def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0
@@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
r"""Get chat template and fixes the tokenizer."""
if data_args.template is None:
if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
@@ -1149,7 +1118,8 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"你是一个经过良好训练的AI助手你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
),