7 Commits

Author SHA1 Message Date
LincolnBurrows2017
2c4f121817 [fix] handle empty content list in system message (#10291)
Co-authored-by: AI Assistant <assistant@example.com>
2026-03-18 12:05:49 +08:00
xvxuopop
487f8b8191 [v1] add qwen3 templates and fix rendering plugin. (#10212)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-18 11:30:50 +08:00
SnowCharm
78cad1e332 [fix] unused keys in ray example (#10290) 2026-03-18 00:23:53 +08:00
LincolnBurrows2017
70653026f5 [fix] make position_id_per_seconds configurable for Qwen2OmniPlugin (#10281)
Co-authored-by: LincolnBurrows2017 <lincoln@example.com>
2026-03-16 19:42:38 +08:00
Ruijie Hou
246192abd2 [data] correct gpt_oss template format_assistant (#10269) 2026-03-10 21:36:38 +08:00
浮梦
0258dc14d0 [docker] update npu docker (#10268)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-03-10 19:37:27 +08:00
xxddccaa
3045adf0ba [fix] fallback to audio_processor when feature_extractor is missing (#10267)
Co-authored-by: kevin <742971636@qq.com>
2026-03-10 19:36:41 +08:00
13 changed files with 541 additions and 225 deletions

View File

@@ -1,6 +1,6 @@
# https://hub.docker.com/r/ascendai/cann/tags
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
FROM ${BASE_IMAGE}
# Installation arguments
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
COPY . /app
# Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir -e . --no-build-isolation && \
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
RUN pip uninstall -y torch torchvision torchaudio
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
# Set up volumes

View File

@@ -33,7 +33,7 @@ services:
dockerfile: ./docker/docker-npu/Dockerfile
context: ../..
args:
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory-a3
image: llamafactory:npu-a3

View File

@@ -28,12 +28,7 @@ save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### ray
ray_run_name: qwen3_4b_sft_lora
ray_storage_path: ./saves
ray_num_workers: 4 # Number of GPUs to use.
placement_strategy: PACK
resources_per_worker:
GPU: 1
# ray_init_kwargs:
# runtime_env:
# env_vars:

View File

@@ -1,4 +1,4 @@
torch==2.7.1
torch-npu==2.7.1
torch-npu==2.7.1.post2
torchvision==0.22.1
torchaudio==2.7.1

View File

@@ -88,7 +88,10 @@ def _process_request(
if request.messages[0].role == Role.SYSTEM:
content = request.messages.pop(0).content
system = content[0].text if isinstance(content, list) else content
if isinstance(content, list):
system = content[0].text if content else ""
else:
system = content
else:
system = None

View File

@@ -161,7 +161,9 @@ class MMPluginMixin:
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None)
)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
if len(images) != 0 and self.image_token is None:
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
@@ -390,7 +392,9 @@ class MMPluginMixin:
mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0:
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
audios = self._regularize_audios(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
@@ -1876,7 +1880,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@@ -1981,6 +1987,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
)
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
video_t_index = (
torch.arange(video_grid_thw[num_video_tokens][0])
@@ -1992,9 +1999,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
* 25 # FIXME hardcode of position_id_per_seconds=25
* position_id_per_seconds
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
t_ntoken_per_chunk = position_id_per_seconds * 2
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""

View File

@@ -1113,7 +1113,7 @@ register_template(
register_template(
name="gpt_oss",
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
default_system="You are ChatGPT, a large language model trained by OpenAI.",
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),

View File

@@ -91,7 +91,11 @@ class Renderer:
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
self,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Apply template to messages and convert them to model input.
@@ -99,6 +103,7 @@ class Renderer:
messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False.
enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.
Returns:
ModelInput: The rendered model input.
@@ -108,7 +113,9 @@ class Renderer:
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
return RenderingPlugin(self.template).render_messages(
self.processor, messages, tools, is_generate, enable_thinking
)
def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format.

View File

@@ -12,224 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import importlib
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor, ToolCall
from ...utils.types import Message, ModelInput, Processor
logger = logging.get_logger(__name__)
class RenderingPlugin(BasePlugin):
_attempted_template_imports: set[str] = set()
def _ensure_template_imported(self) -> None:
if self.name is None or self.name in self._attempted_template_imports:
return
full_module_name = f"{__package__}.templates.{self.name}"
self._attempted_template_imports.add(self.name)
try:
importlib.import_module(full_module_name)
except Exception as exc:
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
def __getitem__(self, method_name: str):
self._ensure_template_imported()
return super().__getitem__(method_name)
def render_messages(
self,
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the template format."""
return self["render_messages"](processor, messages, tools, is_generate)
return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
def parse_messages(self, generated_text: str) -> Message:
"""Parse messages in the template format."""
return self["parse_messages"](generated_text)
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,259 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
def _get_last_query_index(messages: list[Message]) -> int:
"""Find the last user query index, excluding wrapped tool responses."""
last_query_index = len(messages) - 1
for idx in range(len(messages) - 1, -1, -1):
message = messages[idx]
if message["role"] != "user":
continue
user_text = ""
is_plain_text = True
for content in message["content"]:
if content["type"] != "text":
is_plain_text = False
break
user_text += content["value"]
if not is_plain_text:
continue
if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
last_query_index = idx
break
return last_query_index
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
"""Split assistant message into text, reasoning and tool calls."""
text_content = ""
reasoning_content = ""
tool_calls: list[ToolCall] = []
for content in message["content"]:
if content["type"] == "text":
text_content += content["value"]
elif content["type"] == "reasoning":
reasoning_content += content["value"]
elif content["type"] == "tool_call":
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
tool_calls.append(tool_call)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return text_content, reasoning_content, tool_calls
@RenderingPlugin("qwen3").register("render_messages")
def render_qwen3_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
last_query_index = _get_last_query_index(messages)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
text_content, reasoning_content, tool_calls = _split_assistant_content(message)
if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
else:
temp_str += text_content
for tool_call_idx, tool_call in enumerate(tool_calls):
if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
temp_str += "\n"
arguments = tool_call.get("arguments")
if isinstance(arguments, str):
arguments_str = arguments
else:
arguments_str = json.dumps(arguments, ensure_ascii=False)
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ arguments_str
+ "}\n</tool_call>"
)
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking is False:
temp_str += "<think>\n\n</think>\n\n"
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3").register("parse_message")
def parse_qwen3_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "think":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -0,0 +1,209 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking:
raise ValueError("The qwen3_nothink template does not support thinking mode.")
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict):
type: Literal["text", "reasoning", "tool_call", "image_url"]
type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"]
"""Type of the content."""
value: str
"""Value of the content."""