[data] qwen3 fixes (#8109)
This commit is contained in:
@@ -52,6 +52,7 @@ class Template:
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
replace_jinja_template: bool
|
||||
enable_thinking: bool
|
||||
mm_plugin: "BasePlugin"
|
||||
|
||||
def encode_oneturn(
|
||||
@@ -60,7 +61,6 @@ class Template:
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
enable_thinking: bool = False,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
r"""Return a single pair of token ids representing prompt and response respectively."""
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
@@ -94,7 +94,7 @@ class Template:
|
||||
|
||||
return list(stop_token_ids)
|
||||
|
||||
def add_thought(self, content: str) -> str:
|
||||
def add_thought(self, content: str = "") -> str:
|
||||
r"""Add empty thought to assistant message."""
|
||||
return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
|
||||
|
||||
@@ -105,7 +105,7 @@ class Template:
|
||||
|
||||
def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
|
||||
r"""Get the token ids of thought words."""
|
||||
return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
|
||||
return tokenizer.encode(self.add_thought(), add_special_tokens=False)
|
||||
|
||||
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
|
||||
r"""Convert elements to token ids."""
|
||||
@@ -406,26 +406,21 @@ class ReasoningTemplate(Template):
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
enable_thinking: bool = False,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
messages = deepcopy(messages)
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
prompt_ids = []
|
||||
for encoded_ids in encoded_messages[:-1]:
|
||||
prompt_ids += encoded_ids
|
||||
|
||||
if not enable_thinking and (
|
||||
messages[-1]["role"] == Role.ASSISTANT
|
||||
and self.thought_words[0] not in messages[-1]["content"]
|
||||
prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
|
||||
if (
|
||||
self.thought_words[0] not in messages[-1]["content"]
|
||||
and self.thought_words[1] not in messages[-1]["content"]
|
||||
):
|
||||
prompt_ids += self.get_thought_word_ids(tokenizer)
|
||||
if not self.enable_thinking:
|
||||
prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer)
|
||||
else:
|
||||
response_ids = self.get_thought_word_ids(tokenizer) + response_ids
|
||||
|
||||
response_ids = encoded_messages[-1]
|
||||
return prompt_ids, response_ids
|
||||
|
||||
@override
|
||||
@@ -436,15 +431,16 @@ class ReasoningTemplate(Template):
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
messages = deepcopy(messages)
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
for i in range(len(messages) - 1):
|
||||
for i in range(0, len(messages), 2):
|
||||
if (
|
||||
messages[i + 1]["role"] == Role.ASSISTANT
|
||||
and self.thought_words[0] not in messages[i + 1]["content"]
|
||||
self.thought_words[0] not in messages[i + 1]["content"]
|
||||
and self.thought_words[1] not in messages[i + 1]["content"]
|
||||
):
|
||||
encoded_messages[i] += self.get_thought_word_ids(tokenizer)
|
||||
if not self.enable_thinking:
|
||||
encoded_messages[i] += self.get_thought_word_ids(tokenizer)
|
||||
else:
|
||||
encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
|
||||
|
||||
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
||||
|
||||
@@ -467,6 +463,7 @@ def register_template(
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
replace_jinja_template: bool = False,
|
||||
enable_thinking: bool = True,
|
||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
||||
template_class: type["Template"] = Template,
|
||||
) -> None:
|
||||
@@ -513,6 +510,7 @@ def register_template(
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
replace_jinja_template=replace_jinja_template,
|
||||
enable_thinking=enable_thinking,
|
||||
mm_plugin=mm_plugin,
|
||||
)
|
||||
|
||||
@@ -549,6 +547,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
|
||||
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
|
||||
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
|
||||
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
|
||||
template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
|
||||
assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n") # remove thought tags
|
||||
|
||||
if len(user_slot) > len(user_slot_empty_system):
|
||||
@@ -558,7 +557,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
|
||||
else: # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
|
||||
default_system = ""
|
||||
|
||||
return Template(
|
||||
return template_class(
|
||||
format_user=StringFormatter(slots=[user_slot]),
|
||||
format_assistant=StringFormatter(slots=[assistant_slot]),
|
||||
format_system=StringFormatter(slots=[system_slot]),
|
||||
@@ -572,6 +571,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
|
||||
efficient_eos=False,
|
||||
replace_eos=False,
|
||||
replace_jinja_template=False,
|
||||
enable_thinking=True,
|
||||
mm_plugin=get_mm_plugin(name="base"),
|
||||
)
|
||||
|
||||
@@ -600,6 +600,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
|
||||
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
||||
|
||||
if data_args.default_system is not None:
|
||||
logger.info_rank0(f"Using default system message: {data_args.default_system}.")
|
||||
template.default_system = data_args.default_system
|
||||
|
||||
template.enable_thinking = data_args.enable_thinking
|
||||
template.fix_special_tokens(tokenizer)
|
||||
template.fix_jinja_template(tokenizer)
|
||||
return template
|
||||
|
||||
Reference in New Issue
Block a user