[data] qwen3 fixes (#8109)

hoshi-hiyouga
2025-05-20 02:00:30 +08:00
committed by GitHub
parent 45030ff803
commit 9b5baa97f0
13 changed files with 197 additions and 160 deletions

View File

@@ -104,10 +104,7 @@ class HuggingfaceEngine(BaseEngine):
             messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else generating_args["enable_thinking"]
-        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
         prompt_ids, _ = template.mm_plugin.process_token_ids(
             prompt_ids,
             None,

View File

@@ -160,10 +160,7 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
-        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
 
         temperature: Optional[float] = input_kwargs.pop("temperature", None)

View File

@@ -124,10 +124,7 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
-        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
 
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
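
All three engines receive the same fix: the per-request `enable_thinking` kwarg and the `default_system` fallback from the generating arguments are dropped, and `encode_oneturn` now reads both settings from the template itself. A minimal sketch of the new calling convention (the bare `template` and `tokenizer` names stand in for the engine attributes above):

```python
# Sketch only: thinking mode is now a template-level switch, not a per-call
# argument, so engines no longer pop "enable_thinking" from input_kwargs.
paired_messages = messages + [{"role": "assistant", "content": ""}]
template.enable_thinking = False  # set once, e.g. from DataArguments (see below)
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
```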

View File

@@ -52,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: bool
     mm_plugin: "BasePlugin"
 
     def encode_oneturn(
@@ -60,7 +61,6 @@ class Template:
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        enable_thinking: bool = False,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -94,7 +94,7 @@ class Template:
         return list(stop_token_ids)
 
-    def add_thought(self, content: str) -> str:
+    def add_thought(self, content: str = "") -> str:
         r"""Add empty thought to assistant message."""
         return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
 
@@ -105,7 +105,7 @@
     def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
         r"""Get the token ids of thought words."""
-        return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
 
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
@@ -406,26 +406,21 @@ class ReasoningTemplate(Template):
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        enable_thinking: bool = False,
     ) -> tuple[list[int], list[int]]:
         messages = deepcopy(messages)
-        for i in range(len(messages)):
-            if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
 
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
-        prompt_ids = []
-        for encoded_ids in encoded_messages[:-1]:
-            prompt_ids += encoded_ids
-
-        if not enable_thinking and (
-            messages[-1]["role"] == Role.ASSISTANT
-            and self.thought_words[0] not in messages[-1]["content"]
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
             and self.thought_words[1] not in messages[-1]["content"]
         ):
-            prompt_ids += self.get_thought_word_ids(tokenizer)
+            if not self.enable_thinking:
+                prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer)
+            else:
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
 
-        response_ids = encoded_messages[-1]
         return prompt_ids, response_ids
 
     @override
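
The rewritten `encode_oneturn` delegates token encoding to the parent `Template` and only decides where the empty thought block goes when the final assistant message carries no thought tags: with thinking disabled it is appended to the prompt (excluded from the loss), with thinking enabled it is prepended to the response (included in the loss). A rough behavioral sketch, assuming `template` is a loaded `ReasoningTemplate` with `("<think>", "</think>")` thought words and `tokenizer` its tokenizer:

```python
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},  # answer contains no <think> tags
]

template.enable_thinking = False
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)
# empty-thought token ids end up at the tail of prompt_ids (no loss computed)

template.enable_thinking = True
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)
# empty-thought token ids end up at the head of response_ids (loss computed)
```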
@@ -436,15 +431,16 @@ class ReasoningTemplate(Template):
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         messages = deepcopy(messages)
         encoded_messages = self._encode(tokenizer, messages, system, tools)
-        for i in range(len(messages) - 1):
+        for i in range(0, len(messages), 2):
             if (
-                messages[i + 1]["role"] == Role.ASSISTANT
-                and self.thought_words[0] not in messages[i + 1]["content"]
+                self.thought_words[0] not in messages[i + 1]["content"]
                 and self.thought_words[1] not in messages[i + 1]["content"]
             ):
-                encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                if not self.enable_thinking:
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
 
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
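
`encode_multiturn` gets the matching change: the loop now walks user/assistant pairs directly (even `i` paired with `i + 1`), relying on strict role alternation instead of checking `Role.ASSISTANT`, and applies the same prompt-versus-response placement of the empty thought to every pair. Illustratively (hypothetical messages, same assumptions as the sketch above):

```python
messages = [
    {"role": "user", "content": "hi"},      # i = 0 -> pair (0, 1)
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "2 + 2?"},  # i = 2 -> pair (2, 3)
    {"role": "assistant", "content": "4"},
]
pairs = template.encode_multiturn(tokenizer, messages)  # [(prompt_ids, response_ids), ...]
```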
@@ -467,6 +463,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: bool = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:
@@ -513,6 +510,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
@@ -549,6 +547,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
     if len(user_slot) > len(user_slot_empty_system):
@@ -558,7 +557,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
         default_system = ""
 
-    return Template(
+    return template_class(
         format_user=StringFormatter(slots=[user_slot]),
         format_assistant=StringFormatter(slots=[assistant_slot]),
         format_system=StringFormatter(slots=[system_slot]),
@@ -572,6 +571,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
+        enable_thinking=True,
         mm_plugin=get_mm_plugin(name="base"),
     )
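
`parse_template` now sniffs the tokenizer's chat template: if the rendered assistant slot contains `<think>`, the parsed template is instantiated as a `ReasoningTemplate` rather than a plain `Template`, with `enable_thinking=True` as the default. A hypothetical check (model name for illustration only):

```python
from transformers import AutoTokenizer

# Qwen3 chat templates embed <think> in the assistant turn, so the auto-parsed
# template is expected to come back as a ReasoningTemplate.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
template = parse_template(tokenizer)
print(type(template).__name__)  # ReasoningTemplate
```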
@@ -600,6 +600,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
+    if data_args.default_system is not None:
+        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+        template.default_system = data_args.default_system
+
+    template.enable_thinking = data_args.enable_thinking
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
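
With `default_system` and `enable_thinking` moved onto `DataArguments` (next file), `get_template_and_fix_tokenizer` copies both onto the template, so one setting governs training and inference alike. A sketch of the flow (field values are illustrative):

```python
data_args = DataArguments(
    template="qwen3",
    default_system="You are a helpful assistant.",
    enable_thinking=False,  # empty <think> blocks go into the prompt side
)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
assert template.enable_thinking is False
```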

View File

@@ -115,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={
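
These two fields previously lived on `GeneratingArguments` (see the removals in the next file), so configs that set `default_system` or `enable_thinking` under the generation arguments should move them to the data arguments. They parse like any other dataclass field; a minimal sketch with `HfArgumentParser` (flag values are illustrative):

```python
from transformers import HfArgumentParser

parser = HfArgumentParser(DataArguments)
# e.g. llamafactory-cli train ... --default_system "..." --enable_thinking false
(data_args,) = parser.parse_args_into_dataclasses(
    ["--default_system", "You are a helpful assistant.", "--enable_thinking", "false"]
)
```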

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any
 
 from transformers import GenerationConfig
@@ -62,18 +62,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
     )
-    enable_thinking: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
-    )
 
     def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
         args = asdict(self)

View File

@@ -15,6 +15,7 @@
 import json
 import os
 from collections.abc import Generator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional
 
 from transformers.utils import is_torch_npu_available
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
     )
 
 
+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+    old_value = getattr(obj, name, None)
+    setattr(obj, name, value)
+    yield
+    setattr(obj, name, old_value)
+
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
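
The new `update_attr` context manager gives the web UI a scoped override of the template's `enable_thinking` flag, so the checkbox value applies only for the duration of one streamed turn. A small usage sketch (note that, as written above, the old value is restored only on normal exit, since the generator body has no try/finally):

```python
template.enable_thinking = True  # global setting from data args

with update_attr(template, "enable_thinking", False):
    assert template.enable_thinking is False  # override active inside the block

assert template.enable_thinking is True  # restored after the block exits
```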
@@ -198,35 +207,35 @@
         Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
         Output: infer.chatbot, infer.messages
         """
-        chatbot.append({"role": "assistant", "content": ""})
-        response = ""
-        for new_text in self.stream_chat(
-            messages,
-            system,
-            tools,
-            images=[image] if image else None,
-            videos=[video] if video else None,
-            audios=[audio] if audio else None,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            temperature=temperature,
-            skip_special_tokens=skip_special_tokens,
-            enable_thinking=enable_thinking,
-        ):
-            response += new_text
-            if tools:
-                result = self.engine.template.extract_tool(response)
-            else:
-                result = response
-
-            if isinstance(result, list):
-                tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
-                tool_calls = json.dumps(tool_calls, ensure_ascii=False)
-                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
-                bot_text = "```json\n" + tool_calls + "\n```"
-            else:
-                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
-                bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
-
-            chatbot[-1] = {"role": "assistant", "content": bot_text}
-            yield chatbot, output_messages
+        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
+            chatbot.append({"role": "assistant", "content": ""})
+            response = ""
+            for new_text in self.stream_chat(
+                messages,
+                system,
+                tools,
+                images=[image] if image else None,
+                videos=[video] if video else None,
+                audios=[audio] if audio else None,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                temperature=temperature,
+                skip_special_tokens=skip_special_tokens,
+            ):
+                response += new_text
+                if tools:
+                    result = self.engine.template.extract_tool(response)
+                else:
+                    result = response
+
+                if isinstance(result, list):
+                    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
+                    tool_calls = json.dumps(tool_calls, ensure_ascii=False)
+                    output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+                    bot_text = "```json\n" + tool_calls + "\n```"
+                else:
+                    output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                    bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
+
+                chatbot[-1] = {"role": "assistant", "content": bot_text}
+                yield chatbot, output_messages