From 487f8b81916a545f4720e26bc83f1dc4c3096d84 Mon Sep 17 00:00:00 2001 From: xvxuopop <127376094+xvxuopop@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:30:50 +0800 Subject: [PATCH] [v1] add qwen3 templates and fix rendering plugin. (#10212) Co-authored-by: Yaowei Zheng --- src/llamafactory/v1/core/utils/rendering.py | 11 +- .../v1/plugins/model_plugins/rendering.py | 229 ++-------------- .../model_plugins/templates/__init__.py | 13 + .../plugins/model_plugins/templates/qwen3.py | 259 ++++++++++++++++++ .../model_plugins/templates/qwen3_nothink.py | 209 ++++++++++++++ src/llamafactory/v1/utils/types.py | 2 +- 6 files changed, 516 insertions(+), 207 deletions(-) create mode 100644 src/llamafactory/v1/plugins/model_plugins/templates/__init__.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/templates/qwen3_nothink.py diff --git a/src/llamafactory/v1/core/utils/rendering.py b/src/llamafactory/v1/core/utils/rendering.py index cbe8383c4..2cb22bfd6 100644 --- a/src/llamafactory/v1/core/utils/rendering.py +++ b/src/llamafactory/v1/core/utils/rendering.py @@ -91,7 +91,11 @@ class Renderer: self.processor = processor def render_messages( - self, messages: list[Message], tools: str | None = None, is_generate: bool = False + self, + messages: list[Message], + tools: str | None = None, + is_generate: bool = False, + enable_thinking: bool = False, ) -> ModelInput: """Apply template to messages and convert them to model input. @@ -99,6 +103,7 @@ class Renderer: messages (list[Message]): The messages to render. tools (str | None, optional): The tools to use. Defaults to None. is_generate (bool, optional): Whether to render for generation. Defaults to False. + enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False. Returns: ModelInput: The rendered model input. @@ -108,7 +113,9 @@ class Renderer: else: from ...plugins.model_plugins.rendering import RenderingPlugin - return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate) + return RenderingPlugin(self.template).render_messages( + self.processor, messages, tools, is_generate, enable_thinking + ) def parse_message(self, generated_text: str) -> Message: """Parse a message in the template format. diff --git a/src/llamafactory/v1/plugins/model_plugins/rendering.py b/src/llamafactory/v1/plugins/model_plugins/rendering.py index 8ca8b43fc..566c83f50 100644 --- a/src/llamafactory/v1/plugins/model_plugins/rendering.py +++ b/src/llamafactory/v1/plugins/model_plugins/rendering.py @@ -12,224 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import re +import importlib -from ...utils.constants import IGNORE_INDEX -from ...utils.helper import get_tokenizer +from ...utils import logging from ...utils.plugin import BasePlugin -from ...utils.types import Message, ModelInput, Processor, ToolCall +from ...utils.types import Message, ModelInput, Processor + + +logger = logging.get_logger(__name__) class RenderingPlugin(BasePlugin): + _attempted_template_imports: set[str] = set() + + def _ensure_template_imported(self) -> None: + if self.name is None or self.name in self._attempted_template_imports: + return + + full_module_name = f"{__package__}.templates.{self.name}" + self._attempted_template_imports.add(self.name) + try: + importlib.import_module(full_module_name) + except Exception as exc: + logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}") + + def __getitem__(self, method_name: str): + self._ensure_template_imported() + return super().__getitem__(method_name) + def render_messages( self, processor: Processor, messages: list[Message], tools: str | None = None, is_generate: bool = False, + enable_thinking: bool = False, ) -> ModelInput: """Render messages in the template format.""" - return self["render_messages"](processor, messages, tools, is_generate) + return self["render_messages"](processor, messages, tools, is_generate, enable_thinking) def parse_messages(self, generated_text: str) -> Message: """Parse messages in the template format.""" return self["parse_messages"](generated_text) - - -def _update_model_input( - processor: Processor, - input_ids: list[int], - labels: list[int], - loss_weights: list[int], - temp_str: str, - temp_weight: float, -) -> str: - """Update model input with temporary string.""" - if not temp_str: - return "" - - tokenizer = get_tokenizer(processor) - temp_ids = tokenizer.encode(temp_str, add_special_tokens=False) - input_ids.extend(temp_ids) - loss_weights.extend([temp_weight] * len(temp_ids)) - if temp_weight > 1e-6: - labels.extend(temp_ids) - else: - labels.extend([IGNORE_INDEX] * len(temp_ids)) - - return "" - - -@RenderingPlugin("qwen3_nothink").register("render_messages") -def render_qwen3_nothink_messages( - processor: Processor, - messages: list[Message], - tools: str | None = None, - is_generate: bool = False, -) -> ModelInput: - """Render messages in the Qwen3 nothink template format. - - See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507 - """ - input_ids, labels, loss_weights = [], [], [] - temp_str, temp_weight = "", 0.0 - if tools: - temp_str += "<|im_start|>system\n" - if messages[0]["role"] == "system": - for content in messages[0]["content"]: - if content["type"] == "text": - temp_str += content["value"] - else: - raise ValueError(f"Unsupported content type: {content['type']}") - - temp_str += "\n\n" - temp_weight = messages[0].get("loss_weight", 0.0) - - temp_str += ( - "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" - "You are provided with function signatures within XML tags:\n" - ) - try: - tools = json.loads(tools) - except json.JSONDecodeError: - raise ValueError(f"Invalid tools format: {str(tools)}.") - - if not isinstance(tools, list): - tools = [tools] - - for tool in tools: - temp_str += "\n" + json.dumps(tool, ensure_ascii=False) - - temp_str += ( - "\n\n\nFor each function call, return a json object with function name " - 'and arguments within XML tags:\n\n{"name": ' - ', "arguments": }\n<|im_end|>\n' - ) - elif messages[0]["role"] == "system": - temp_str += "<|im_start|>system\n" - for content in messages[0]["content"]: - if content["type"] == "text": - temp_str += content["value"] - else: - raise ValueError(f"Unsupported content type: {content['type']}") - - temp_str += "<|im_end|>\n" - temp_weight = messages[0].get("loss_weight", 0.0) - - temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) - - for turn_idx, message in enumerate(messages): - if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0): - temp_str += "<|im_start|>" + message["role"] + "\n" - for content in message["content"]: - if content["type"] == "text": - temp_str += content["value"] - else: - raise ValueError(f"Unsupported content type: {content['type']}") - - temp_str += "<|im_end|>\n" - temp_weight = message.get("loss_weight", 0.0) - elif message["role"] == "assistant": - temp_str += "<|im_start|>" + message["role"] + "\n" - for val_idx, content in enumerate(message["content"]): - if content["type"] == "text": - temp_str += content["value"] - elif content["type"] == "reasoning": - temp_str += "\n" + content["value"] + "\n\n\n" # avoid using special tokens - elif content["type"] == "tool_call": - if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]: - temp_str += "\n" - - try: - tool_call: ToolCall = json.loads(content["value"]) - except json.JSONDecodeError: - raise ValueError(f"Invalid tool call format: {content['value']}.") - - temp_str += ( - '\n{"name": "' - + tool_call["name"] - + '", "arguments": ' - + json.dumps(tool_call["arguments"], ensure_ascii=False) - + "}\n" - ) - - else: - raise ValueError(f"Unsupported content type: {content['type']}") - - temp_str += "<|im_end|>\n" - temp_weight = message.get("loss_weight", 1.0) - elif message["role"] == "tool": - if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool": - temp_str += "<|im_start|>user" - - temp_str += "\n\n" - for content in message["content"]: - if content["type"] == "text": - temp_str += content["value"] - else: - raise ValueError(f"Unsupported content type: {content['type']}") - - temp_str += "\n" - if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool": - temp_str += "<|im_end|>\n" - - temp_weight = message.get("loss_weight", 0.0) - - temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) - - if is_generate: - temp_str += "<|im_start|>assistant\n" - temp_weight = 0.0 - - temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) - - attention_mask = [1] * len(input_ids) - return ModelInput( - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - loss_weights=loss_weights, - ) - - -@RenderingPlugin("qwen3_nothink").register("parse_message") -def parse_qwen3_nothink_message(generated_text: str) -> Message: - """Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls. - - Args: - generated_text (str): The generated text in the Qwen3 nothink template format. - - Returns: - Message: The parsed message. - """ - pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*\s*", re.DOTALL) - content = [] - last_end = 0 - for match in pattern.finditer(generated_text): - start, end = match.span() - if start > last_end: - text = generated_text[last_end:start].strip() - if text: - content.append({"type": "text", "value": text}) - - tag_type = match.group(1) - tag_value = match.group(2).strip() - if tag_type == "thinking": - content.append({"type": "reasoning", "value": tag_value.strip()}) - elif tag_type == "tool_call": - try: - json.loads(tag_value.strip()) - except json.JSONDecodeError: - raise ValueError(f"Invalid tool call format: {tag_value.strip()}.") - - content.append({"type": "tool_call", "value": tag_value.strip()}) - - last_end = end - - if last_end < len(generated_text): - text = generated_text[last_end:].strip() - if text: - content.append({"type": "text", "value": text}) - - return Message(role="assistant", content=content) diff --git a/src/llamafactory/v1/plugins/model_plugins/templates/__init__.py b/src/llamafactory/v1/plugins/model_plugins/templates/__init__.py new file mode 100644 index 000000000..ec0d62554 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/templates/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py b/src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py new file mode 100644 index 000000000..e9f479941 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py @@ -0,0 +1,259 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re + +from ....utils.constants import IGNORE_INDEX +from ....utils.helper import get_tokenizer +from ....utils.types import Message, ModelInput, Processor, ToolCall +from ..rendering import RenderingPlugin + + +def _update_model_input( + processor: Processor, + input_ids: list[int], + labels: list[int], + loss_weights: list[int], + temp_str: str, + temp_weight: float, +) -> str: + """Update model input with temporary string.""" + if not temp_str: + return "" + + tokenizer = get_tokenizer(processor) + temp_ids = tokenizer.encode(temp_str, add_special_tokens=False) + input_ids.extend(temp_ids) + loss_weights.extend([temp_weight] * len(temp_ids)) + if temp_weight > 1e-6: + labels.extend(temp_ids) + else: + labels.extend([IGNORE_INDEX] * len(temp_ids)) + + return "" + + +def _concat_text_content(message: Message) -> str: + """Concatenate text fields in a message.""" + message_text = "" + for content in message["content"]: + if content["type"] == "text": + message_text += content["value"] + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + return message_text + + +def _get_last_query_index(messages: list[Message]) -> int: + """Find the last user query index, excluding wrapped tool responses.""" + last_query_index = len(messages) - 1 + for idx in range(len(messages) - 1, -1, -1): + message = messages[idx] + if message["role"] != "user": + continue + + user_text = "" + is_plain_text = True + for content in message["content"]: + if content["type"] != "text": + is_plain_text = False + break + user_text += content["value"] + + if not is_plain_text: + continue + + if not (user_text.startswith("") and user_text.endswith("")): + last_query_index = idx + break + + return last_query_index + + +def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]: + """Split assistant message into text, reasoning and tool calls.""" + text_content = "" + reasoning_content = "" + tool_calls: list[ToolCall] = [] + + for content in message["content"]: + if content["type"] == "text": + text_content += content["value"] + elif content["type"] == "reasoning": + reasoning_content += content["value"] + elif content["type"] == "tool_call": + try: + tool_call: ToolCall = json.loads(content["value"]) + except json.JSONDecodeError: + raise ValueError(f"Invalid tool call format: {content['value']}.") + + tool_calls.append(tool_call) + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + return text_content, reasoning_content, tool_calls + + +@RenderingPlugin("qwen3").register("render_messages") +def render_qwen3_messages( + processor: Processor, + messages: list[Message], + tools: str | None = None, + is_generate: bool = False, + enable_thinking: bool = False, +) -> ModelInput: + """Render messages in the Qwen3 template format. + + See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B + """ + input_ids, labels, loss_weights = [], [], [] + temp_str, temp_weight = "", 0.0 + if tools: + temp_str += "<|im_start|>system\n" + if messages[0]["role"] == "system": + temp_str += _concat_text_content(messages[0]) + "\n\n" + temp_weight = messages[0].get("loss_weight", 0.0) + + temp_str += ( + "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + ) + try: + tools = json.loads(tools) + except json.JSONDecodeError: + raise ValueError(f"Invalid tools format: {str(tools)}.") + + if not isinstance(tools, list): + tools = [tools] + + for tool in tools: + temp_str += "\n" + json.dumps(tool, ensure_ascii=False) + + temp_str += ( + "\n\n\nFor each function call, return a json object with function name " + 'and arguments within XML tags:\n\n{"name": ' + ', "arguments": }\n<|im_end|>\n' + ) + elif messages[0]["role"] == "system": + temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n" + temp_weight = messages[0].get("loss_weight", 0.0) + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + last_query_index = _get_last_query_index(messages) + + for turn_idx, message in enumerate(messages): + if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0): + temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n" + temp_weight = message.get("loss_weight", 0.0) + elif message["role"] == "assistant": + temp_str += "<|im_start|>" + message["role"] + "\n" + + text_content, reasoning_content, tool_calls = _split_assistant_content(message) + if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content): + temp_str += "\n" + reasoning_content.strip("\n") + "\n\n\n" + text_content.lstrip("\n") + else: + temp_str += text_content + + for tool_call_idx, tool_call in enumerate(tool_calls): + if (tool_call_idx == 0 and text_content) or tool_call_idx > 0: + temp_str += "\n" + + arguments = tool_call.get("arguments") + if isinstance(arguments, str): + arguments_str = arguments + else: + arguments_str = json.dumps(arguments, ensure_ascii=False) + + temp_str += ( + '\n{"name": "' + + tool_call["name"] + + '", "arguments": ' + + arguments_str + + "}\n" + ) + + temp_str += "<|im_end|>\n" + temp_weight = message.get("loss_weight", 1.0) + elif message["role"] == "tool": + if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool": + temp_str += "<|im_start|>user" + + temp_str += "\n\n" + _concat_text_content(message) + "\n" + if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool": + temp_str += "<|im_end|>\n" + + temp_weight = message.get("loss_weight", 0.0) + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + if is_generate: + temp_str += "<|im_start|>assistant\n" + temp_weight = 0.0 + if enable_thinking is False: + temp_str += "\n\n\n\n" + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + attention_mask = [1] * len(input_ids) + return ModelInput( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + loss_weights=loss_weights, + ) + + +@RenderingPlugin("qwen3").register("parse_message") +def parse_qwen3_message(generated_text: str) -> Message: + """Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls. + + Args: + generated_text (str): The generated text in the Qwen3 template format. + + Returns: + Message: The parsed message. + """ + pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*\s*", re.DOTALL) + content = [] + last_end = 0 + + for match in pattern.finditer(generated_text): + start, end = match.span() + if start > last_end: + text = generated_text[last_end:start].strip() + if text: + content.append({"type": "text", "value": text}) + + tag_type = match.group(1) + tag_value = match.group(2).strip() + if tag_type == "think": + content.append({"type": "reasoning", "value": tag_value.strip()}) + elif tag_type == "tool_call": + try: + json.loads(tag_value.strip()) + except json.JSONDecodeError: + raise ValueError(f"Invalid tool call format: {tag_value.strip()}.") + + content.append({"type": "tool_call", "value": tag_value.strip()}) + + last_end = end + + if last_end < len(generated_text): + text = generated_text[last_end:].strip() + if text: + content.append({"type": "text", "value": text}) + + return Message(role="assistant", content=content) diff --git a/src/llamafactory/v1/plugins/model_plugins/templates/qwen3_nothink.py b/src/llamafactory/v1/plugins/model_plugins/templates/qwen3_nothink.py new file mode 100644 index 000000000..244835e0f --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/templates/qwen3_nothink.py @@ -0,0 +1,209 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re + +from ....utils.constants import IGNORE_INDEX +from ....utils.helper import get_tokenizer +from ....utils.types import Message, ModelInput, Processor, ToolCall +from ..rendering import RenderingPlugin + + +def _update_model_input( + processor: Processor, + input_ids: list[int], + labels: list[int], + loss_weights: list[int], + temp_str: str, + temp_weight: float, +) -> str: + """Update model input with temporary string.""" + if not temp_str: + return "" + + tokenizer = get_tokenizer(processor) + temp_ids = tokenizer.encode(temp_str, add_special_tokens=False) + input_ids.extend(temp_ids) + loss_weights.extend([temp_weight] * len(temp_ids)) + if temp_weight > 1e-6: + labels.extend(temp_ids) + else: + labels.extend([IGNORE_INDEX] * len(temp_ids)) + + return "" + + +def _concat_text_content(message: Message) -> str: + """Concatenate text fields in a message.""" + message_text = "" + for content in message["content"]: + if content["type"] == "text": + message_text += content["value"] + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + return message_text + + +@RenderingPlugin("qwen3_nothink").register("render_messages") +def render_qwen3_nothink_messages( + processor: Processor, + messages: list[Message], + tools: str | None = None, + is_generate: bool = False, + enable_thinking: bool = False, +) -> ModelInput: + """Render messages in the Qwen3 nothink template format. + + See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507 + """ + input_ids, labels, loss_weights = [], [], [] + temp_str, temp_weight = "", 0.0 + if tools: + temp_str += "<|im_start|>system\n" + if messages[0]["role"] == "system": + temp_str += _concat_text_content(messages[0]) + "\n\n" + temp_weight = messages[0].get("loss_weight", 0.0) + + temp_str += ( + "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + ) + + try: + tools = json.loads(tools) + except json.JSONDecodeError: + raise ValueError(f"Invalid tools format: {str(tools)}.") + + if not isinstance(tools, list): + tools = [tools] + + for tool in tools: + temp_str += "\n" + json.dumps(tool, ensure_ascii=False) + + temp_str += ( + "\n\n\nFor each function call, return a json object with function name " + 'and arguments within XML tags:\n\n{"name": ' + ', "arguments": }\n<|im_end|>\n' + ) + elif messages[0]["role"] == "system": + temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n" + temp_weight = messages[0].get("loss_weight", 0.0) + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + for turn_idx, message in enumerate(messages): + if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0): + temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n" + temp_weight = message.get("loss_weight", 0.0) + elif message["role"] == "assistant": + temp_str += "<|im_start|>" + message["role"] + "\n" + for val_idx, content in enumerate(message["content"]): + if content["type"] == "text": + temp_str += content["value"] + elif content["type"] == "reasoning": + temp_str += "\n" + content["value"] + "\n\n\n" # avoid using special tokens + elif content["type"] == "tool_call": + if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]: + temp_str += "\n" + + try: + tool_call: ToolCall = json.loads(content["value"]) + except json.JSONDecodeError: + raise ValueError(f"Invalid tool call format: {content['value']}.") + + temp_str += ( + '\n{"name": "' + + tool_call["name"] + + '", "arguments": ' + + json.dumps(tool_call["arguments"], ensure_ascii=False) + + "}\n" + ) + + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + temp_str += "<|im_end|>\n" + temp_weight = message.get("loss_weight", 1.0) + elif message["role"] == "tool": + if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool": + temp_str += "<|im_start|>user" + + temp_str += "\n\n" + _concat_text_content(message) + "\n" + if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool": + temp_str += "<|im_end|>\n" + + temp_weight = message.get("loss_weight", 0.0) + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + if is_generate: + temp_str += "<|im_start|>assistant\n" + temp_weight = 0.0 + if enable_thinking: + raise ValueError("The qwen3_nothink template does not support thinking mode.") + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + attention_mask = [1] * len(input_ids) + return ModelInput( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + loss_weights=loss_weights, + ) + + +@RenderingPlugin("qwen3_nothink").register("parse_message") +def parse_qwen3_nothink_message(generated_text: str) -> Message: + """Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls. + + Args: + generated_text (str): The generated text in the Qwen3 nothink template format. + + Returns: + Message: The parsed message. + """ + pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*\s*", re.DOTALL) + content = [] + last_end = 0 + + for match in pattern.finditer(generated_text): + start, end = match.span() + if start > last_end: + text = generated_text[last_end:start].strip() + if text: + content.append({"type": "text", "value": text}) + + tag_type = match.group(1) + tag_value = match.group(2).strip() + if tag_type == "thinking": + content.append({"type": "reasoning", "value": tag_value.strip()}) + elif tag_type == "tool_call": + try: + json.loads(tag_value.strip()) + except json.JSONDecodeError: + raise ValueError(f"Invalid tool call format: {tag_value.strip()}.") + + content.append({"type": "tool_call", "value": tag_value.strip()}) + + last_end = end + + if last_end < len(generated_text): + text = generated_text[last_end:].strip() + if text: + content.append({"type": "text", "value": text}) + + return Message(role="assistant", content=content) diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py index 2f3906968..b1f7d52cf 100644 --- a/src/llamafactory/v1/utils/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False): class Content(TypedDict): - type: Literal["text", "reasoning", "tool_call", "image_url"] + type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"] """Type of the content.""" value: str """Value of the content."""