mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[v1] add renderer ut (#9722)
This commit is contained in:
@@ -25,7 +25,7 @@ class ModelArguments:
|
||||
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
||||
)
|
||||
template: str = field(
|
||||
default="chatml",
|
||||
default="qwen3_nothink",
|
||||
metadata={"help": "Template for the model."},
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
|
||||
@@ -12,38 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
|
||||
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def render_chatml_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
@@ -52,123 +26,38 @@ def render_chatml_messages(
|
||||
) -> ModelInput:
|
||||
"""Apply chatml template to messages and convert them to model input.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
|
||||
"""
|
||||
tokenizer = get_tokenizer(processor)
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n"
|
||||
for content in messages[0]["content"]:
|
||||
for message in messages:
|
||||
temp_str = "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([0.0] * len(temp_ids))
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=[1] * len(input_ids),
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
@@ -183,36 +72,7 @@ def parse_chatml_message(generated_text: str) -> Message:
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
|
||||
|
||||
|
||||
class Renderer:
|
||||
|
||||
@@ -158,7 +158,7 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
)
|
||||
|
||||
if tools:
|
||||
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
|
||||
return {"messages": messages, "tools": json.dumps(tools)}
|
||||
else:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@@ -13,24 +13,200 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
||||
|
||||
|
||||
class RenderingPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
|
||||
@RenderingPlugin("qwen").register("render_messages")
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||
def render_qwen_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
raise NotImplementedError()
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n"
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call: ToolCall = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
@RenderingPlugin("qwen").register("parse_message")
|
||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||
def parse_qwen_message(generated_text: str) -> Message:
|
||||
raise NotImplementedError()
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -102,8 +102,10 @@ class Message(TypedDict):
|
||||
class SFTSample(TypedDict):
|
||||
messages: list[Message]
|
||||
"""Messages in the sample."""
|
||||
tools: NotRequired[str]
|
||||
"""Tools for the sample in JSON string format."""
|
||||
extra_info: NotRequired[str]
|
||||
"""Extra information for the sample, including tools, kto_labels."""
|
||||
"""Extra information for the sample, e.g. kto_labels."""
|
||||
_dataset_name: NotRequired[str]
|
||||
"""Dataset name for the sample."""
|
||||
|
||||
@@ -113,8 +115,10 @@ class DPOSample(TypedDict):
|
||||
"""Chosen messages in the sample."""
|
||||
rejected_messages: list[Message]
|
||||
"""Rejected messages in the sample."""
|
||||
tools: NotRequired[str]
|
||||
"""Tools for the sample in JSON string format."""
|
||||
extra_info: NotRequired[str]
|
||||
"""Extra information for the sample, including tools, kto_labels."""
|
||||
"""Extra information for the sample, e.g. kto_labels."""
|
||||
_dataset_name: NotRequired[str]
|
||||
"""Dataset name for the sample."""
|
||||
|
||||
@@ -125,7 +129,7 @@ Sample = Union[SFTSample, DPOSample]
|
||||
class ToolCall(TypedDict):
|
||||
name: str
|
||||
"""Function name."""
|
||||
arguments: str
|
||||
arguments: dict[str, Any]
|
||||
"""Function arguments."""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user