# Copyright 2025 the LlamaFactory team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import re from ...utils.constants import IGNORE_INDEX from ...utils.helper import get_tokenizer from ...utils.plugin import BasePlugin from ...utils.types import Message, ModelInput, Processor, ToolCall class RenderingPlugin(BasePlugin): pass def _update_model_input( processor: Processor, input_ids: list[int], labels: list[int], loss_weights: list[int], temp_str: str, temp_weight: float, ) -> str: """Update model input with temporary string.""" if not temp_str: return "" tokenizer = get_tokenizer(processor) temp_ids = tokenizer.encode(temp_str, add_special_tokens=False) input_ids.extend(temp_ids) loss_weights.extend([temp_weight] * len(temp_ids)) if temp_weight > 1e-6: labels.extend(temp_ids) else: labels.extend([IGNORE_INDEX] * len(temp_ids)) return "" @RenderingPlugin("qwen3_nothink").register("render_messages") def render_qwen_messages( processor: Processor, messages: list[Message], tools: str | None = None, is_generate: bool = False, ) -> ModelInput: input_ids, labels, loss_weights = [], [], [] temp_str, temp_weight = "", 0.0 if tools: temp_str += "<|im_start|>system\n" if messages[0]["role"] == "system": for content in messages[0]["content"]: if content["type"] == "text": temp_str += content["value"] else: raise ValueError(f"Unsupported content type: {content['type']}") temp_str += "\n\n" temp_weight = messages[0].get("loss_weight", 0.0) temp_str += ( "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" "You are provided with function signatures within XML tags:\n" ) try: tools = json.loads(tools) except json.JSONDecodeError: raise ValueError(f"Invalid tools format: {str(tools)}.") if not isinstance(tools, list): tools = [tools] for tool in tools: temp_str += "\n" + json.dumps(tool, ensure_ascii=False) temp_str += ( "\n\n\nFor each function call, return a json object with function name " 'and arguments within XML tags:\n\n{"name": ' ', "arguments": }\n<|im_end|>\n' ) elif messages[0]["role"] == "system": temp_str += "<|im_start|>system\n" for content in messages[0]["content"]: if content["type"] == "text": temp_str += content["value"] else: raise ValueError(f"Unsupported content type: {content['type']}") temp_str += "<|im_end|>\n" temp_weight = messages[0].get("loss_weight", 0.0) temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) for turn_idx, message in enumerate(messages): if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0): temp_str += "<|im_start|>" + message["role"] + "\n" for content in message["content"]: if content["type"] == "text": temp_str += content["value"] else: raise ValueError(f"Unsupported content type: {content['type']}") temp_str += "<|im_end|>\n" temp_weight = message.get("loss_weight", 0.0) elif message["role"] == "assistant": temp_str += "<|im_start|>" + message["role"] + "\n" for val_idx, content in enumerate(message["content"]): if content["type"] == "text": temp_str += content["value"] elif content["type"] == "reasoning": temp_str += "\n" + content["value"] + "\n\n\n" # avoid using special tokens elif content["type"] == "tool_call": if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]: temp_str += "\n" try: tool_call: ToolCall = json.loads(content["value"]) except json.JSONDecodeError: raise ValueError(f"Invalid tool call format: {content['value']}.") temp_str += ( '\n{"name": "' + tool_call["name"] + '", "arguments": ' + json.dumps(tool_call["arguments"], ensure_ascii=False) + "}\n" ) else: raise ValueError(f"Unsupported content type: {content['type']}") temp_str += "<|im_end|>\n" temp_weight = message.get("loss_weight", 1.0) elif message["role"] == "tool": if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool": temp_str += "<|im_start|>user" temp_str += "\n\n" for content in message["content"]: if content["type"] == "text": temp_str += content["value"] else: raise ValueError(f"Unsupported content type: {content['type']}") temp_str += "\n" if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool": temp_str += "<|im_end|>\n" temp_weight = message.get("loss_weight", 0.0) temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) if is_generate: temp_str += "<|im_start|>assistant\n" temp_weight = 0.0 temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) attention_mask = [1] * len(input_ids) return ModelInput( input_ids=input_ids, attention_mask=attention_mask, labels=labels, loss_weights=loss_weights, ) @RenderingPlugin("qwen3_nothink").register("parse_message") def parse_qwen_message(generated_text: str) -> Message: pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*\s*", re.DOTALL) content = [] last_end = 0 for match in pattern.finditer(generated_text): start, end = match.span() if start > last_end: text = generated_text[last_end:start].strip() if text: content.append({"type": "text", "value": text}) tag_type = match.group(1) tag_value = match.group(2).strip() if tag_type == "thinking": content.append({"type": "reasoning", "value": tag_value.strip()}) elif tag_type == "tool_call": try: json.loads(tag_value.strip()) except json.JSONDecodeError: raise ValueError(f"Invalid tool call format: {tag_value.strip()}.") content.append({"type": "tool_call", "value": tag_value.strip()}) last_end = end if last_end < len(generated_text): text = generated_text[last_end:].strip() if text: content.append({"type": "text", "value": text}) return Message(role="assistant", content=content)