mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-01-29 22:02:03 +00:00
213 lines
7.8 KiB
Python
213 lines
7.8 KiB
Python
# Copyright 2025 the LlamaFactory team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import json
|
|
import re
|
|
|
|
from ...utils.constants import IGNORE_INDEX
|
|
from ...utils.helper import get_tokenizer
|
|
from ...utils.plugin import BasePlugin
|
|
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
|
|
|
|
|
class RenderingPlugin(BasePlugin):
|
|
pass
|
|
|
|
|
|
def _update_model_input(
|
|
processor: Processor,
|
|
input_ids: list[int],
|
|
labels: list[int],
|
|
loss_weights: list[int],
|
|
temp_str: str,
|
|
temp_weight: float,
|
|
) -> str:
|
|
"""Update model input with temporary string."""
|
|
if not temp_str:
|
|
return ""
|
|
|
|
tokenizer = get_tokenizer(processor)
|
|
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
|
input_ids.extend(temp_ids)
|
|
loss_weights.extend([temp_weight] * len(temp_ids))
|
|
if temp_weight > 1e-6:
|
|
labels.extend(temp_ids)
|
|
else:
|
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
|
|
|
return ""
|
|
|
|
|
|
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
|
def render_qwen_messages(
|
|
processor: Processor,
|
|
messages: list[Message],
|
|
tools: str | None = None,
|
|
is_generate: bool = False,
|
|
) -> ModelInput:
|
|
input_ids, labels, loss_weights = [], [], []
|
|
temp_str, temp_weight = "", 0.0
|
|
if tools:
|
|
temp_str += "<|im_start|>system\n"
|
|
if messages[0]["role"] == "system":
|
|
for content in messages[0]["content"]:
|
|
if content["type"] == "text":
|
|
temp_str += content["value"]
|
|
else:
|
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
|
|
temp_str += "\n\n"
|
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
|
|
|
temp_str += (
|
|
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
|
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
|
)
|
|
try:
|
|
tools = json.loads(tools)
|
|
except json.JSONDecodeError:
|
|
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
|
|
|
if not isinstance(tools, list):
|
|
tools = [tools]
|
|
|
|
for tool in tools:
|
|
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
|
|
|
temp_str += (
|
|
"\n</tools>\n\nFor each function call, return a json object with function name "
|
|
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
|
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
|
)
|
|
elif messages[0]["role"] == "system":
|
|
temp_str += "<|im_start|>system\n"
|
|
for content in messages[0]["content"]:
|
|
if content["type"] == "text":
|
|
temp_str += content["value"]
|
|
else:
|
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
|
|
temp_str += "<|im_end|>\n"
|
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
|
|
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
|
|
for turn_idx, message in enumerate(messages):
|
|
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
|
for content in message["content"]:
|
|
if content["type"] == "text":
|
|
temp_str += content["value"]
|
|
else:
|
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
|
|
temp_str += "<|im_end|>\n"
|
|
temp_weight = message.get("loss_weight", 0.0)
|
|
elif message["role"] == "assistant":
|
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
|
for val_idx, content in enumerate(message["content"]):
|
|
if content["type"] == "text":
|
|
temp_str += content["value"]
|
|
elif content["type"] == "reasoning":
|
|
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
|
elif content["type"] == "tool_call":
|
|
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
|
temp_str += "\n"
|
|
|
|
try:
|
|
tool_call: ToolCall = json.loads(content["value"])
|
|
except json.JSONDecodeError:
|
|
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
|
|
|
temp_str += (
|
|
'<tool_call>\n{"name": "'
|
|
+ tool_call["name"]
|
|
+ '", "arguments": '
|
|
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
|
+ "}\n</tool_call>"
|
|
)
|
|
|
|
else:
|
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
|
|
temp_str += "<|im_end|>\n"
|
|
temp_weight = message.get("loss_weight", 1.0)
|
|
elif message["role"] == "tool":
|
|
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
|
temp_str += "<|im_start|>user"
|
|
|
|
temp_str += "\n<tool_response>\n"
|
|
for content in message["content"]:
|
|
if content["type"] == "text":
|
|
temp_str += content["value"]
|
|
else:
|
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
|
|
temp_str += "\n</tool_response>"
|
|
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
|
temp_str += "<|im_end|>\n"
|
|
|
|
temp_weight = message.get("loss_weight", 0.0)
|
|
|
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
|
|
if is_generate:
|
|
temp_str += "<|im_start|>assistant\n"
|
|
temp_weight = 0.0
|
|
|
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
|
|
attention_mask = [1] * len(input_ids)
|
|
return ModelInput(
|
|
input_ids=input_ids,
|
|
attention_mask=attention_mask,
|
|
labels=labels,
|
|
loss_weights=loss_weights,
|
|
)
|
|
|
|
|
|
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
|
def parse_qwen_message(generated_text: str) -> Message:
|
|
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
|
content = []
|
|
last_end = 0
|
|
for match in pattern.finditer(generated_text):
|
|
start, end = match.span()
|
|
if start > last_end:
|
|
text = generated_text[last_end:start].strip()
|
|
if text:
|
|
content.append({"type": "text", "value": text})
|
|
|
|
tag_type = match.group(1)
|
|
tag_value = match.group(2).strip()
|
|
if tag_type == "thinking":
|
|
content.append({"type": "reasoning", "value": tag_value.strip()})
|
|
elif tag_type == "tool_call":
|
|
try:
|
|
json.loads(tag_value.strip())
|
|
except json.JSONDecodeError:
|
|
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
|
|
|
content.append({"type": "tool_call", "value": tag_value.strip()})
|
|
|
|
last_end = end
|
|
|
|
if last_end < len(generated_text):
|
|
text = generated_text[last_end:].strip()
|
|
if text:
|
|
content.append({"type": "text", "value": text})
|
|
|
|
return Message(role="assistant", content=content)
|