[data] optimize qwen3 loss computation (#7923)

hoshi-hiyouga
2025-04-30 16:18:00 +08:00
committed by GitHub
parent 73198a6645
commit 052ca871bd
11 changed files with 205 additions and 39 deletions


@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
@@ -59,9 +60,10 @@ class Template:
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        enable_thinking: bool = True,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
@@ -77,7 +79,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
@@ -92,6 +94,19 @@ class Template:
         return list(stop_token_ids)
 
+    def add_thought(self, content: str) -> str:
+        r"""Add empty thought to assistant message."""
+        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+    def remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
+    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Get the token ids of thought words."""
+        return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
+
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
         token_ids = []
@@ -111,18 +126,12 @@ class Template:
         return token_ids
 
-    def _remove_thought(self, content: str) -> str:
-        r"""Remove thought from assistant message."""
-        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
-        return re.sub(pattern, "", content).lstrip("\n")
-
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.
@@ -140,18 +149,14 @@ class Template:
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     elements += self.format_system.apply(content=(system + tool_text))
 
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=content, idx=str(i // 2))
+                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -331,7 +336,6 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
-        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
@@ -345,18 +349,14 @@ class Llama2Template(Template):
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     system_text = self.format_system.apply(content=(system + tool_text))[0]
 
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=system_text + content)
+                elements += self.format_user.apply(content=system_text + message["content"])
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -395,6 +395,60 @@ class Llama2Template(Template):
         return jinja_template
 
 
+@dataclass
+class ReasoningTemplate(Template):
+    r"""A template that adds thought to assistant message."""
+
+    @override
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        enable_thinking: bool = True,
+    ) -> tuple[list[int], list[int]]:
+        messages = deepcopy(messages)
+        for i in range(len(messages)):
+            if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        prompt_ids = []
+        for encoded_ids in encoded_messages[:-1]:
+            prompt_ids += encoded_ids
+
+        if not enable_thinking or (
+            messages[-1]["role"] == Role.ASSISTANT
+            and self.thought_words[0] not in messages[-1]["content"]
+            and self.thought_words[1] not in messages[-1]["content"]
+        ):
+            prompt_ids += self.get_thought_word_ids(tokenizer)
+
+        response_ids = encoded_messages[-1]
+        return prompt_ids, response_ids
+
+    @override
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> list[tuple[list[int], list[int]]]:
+        messages = deepcopy(messages)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        for i in range(len(messages) - 1):
+            if (
+                messages[i + 1]["role"] == Role.ASSISTANT
+                and self.thought_words[0] not in messages[i + 1]["content"]
+                and self.thought_words[1] not in messages[i + 1]["content"]
+            ):
+                encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+
 TEMPLATES: dict[str, "Template"] = {}
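
For readers skimming the diff: a minimal standalone sketch of what the new remove_thought / add_thought helpers do, assuming Qwen3-style thought words ("<think>", "</think>"). Illustrative only, not an extract of the library code:

import re

# Assumed thought delimiters (the Qwen3 template family uses "<think>" / "</think>").
THOUGHT_WORDS = ("<think>", "</think>")


def remove_thought(content: str) -> str:
    # Strip the <think>...</think> span from an earlier assistant turn.
    pattern = re.compile(f"{re.escape(THOUGHT_WORDS[0])}(.*?){re.escape(THOUGHT_WORDS[1])}", re.DOTALL)
    return re.sub(pattern, "", content).lstrip("\n")


def add_empty_thought(content: str) -> str:
    # Prepend an empty thought block, mirroring Template.add_thought.
    return f"{THOUGHT_WORDS[0]}\n\n{THOUGHT_WORDS[1]}\n\n" + content


history = "<think>\nsome reasoning\n</think>\n\nThe answer is 4."
cleaned = remove_thought(history)       # "The answer is 4."
padded = add_empty_thought(cleaned)     # "<think>\n\n</think>\n\nThe answer is 4."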
@@ -778,6 +832,15 @@
 )
 
 
+# copied from deepseek3 template
+register_template(
+    name="deepseekr1",
+    format_user=StringFormatter(slots=["<｜User｜>{{content}}<｜Assistant｜>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=ReasoningTemplate,
+)
+
+
 register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -878,6 +941,22 @@
 )
 
 
+# copied from glm4 template
+register_template(
+    name="glmz1",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    template_class=ReasoningTemplate,
+)
+
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
@@ -1458,6 +1537,7 @@
     format_tools=ToolFormatter(tool_format="qwen"),
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    template_class=ReasoningTemplate,
 )
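
The "loss computation" in the commit title comes down to where the empty thought tokens land. A rough sketch of that effect, using a toy whitespace tokenizer and assumed chat markup rather than the project's real tokenizer and collator: when a sample's final assistant message has no reasoning trace (or enable_thinking is off), ReasoningTemplate appends the empty thought ids to the prompt side, so they are label-masked and contribute nothing to the loss.

IGNORE_INDEX = -100  # conventional Hugging Face "ignore" label value


def toy_encode(text: str) -> list[str]:
    # Stand-in for a real tokenizer: one pseudo-token per whitespace-separated piece.
    return text.split()


prompt_ids = toy_encode("<|im_start|>user 1+1=? <|im_end|> <|im_start|>assistant")
response_ids = toy_encode("2 <|im_end|>")

# No reasoning trace in this sample, so the empty thought goes to the prompt side.
prompt_ids += toy_encode("<think> </think>")

input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids  # only the response is supervised

print(input_ids)
print(labels)  # the empty-thought placeholders carry IGNORE_INDEX and never enter the loss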