[v1] add batch generator (#9744)

This commit is contained in:
Yaowei Zheng
2026-01-10 04:24:09 +08:00
committed by GitHub
parent d7d734d54c
commit b2effbd77c
26 changed files with 604 additions and 850 deletions

View File

@@ -32,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
# Functional TypedDict syntax is required here because "from" is a Python
# keyword and cannot be used as a field name in class-based syntax.
SharegptMessage = TypedDict(
    "SharegptMessage",
    {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
)
@@ -118,15 +119,8 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"observation": "tool",
"function_call": "assistant",
}
sample = {}
messages = []
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
@@ -157,10 +151,17 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
}
)
sample["messages"] = messages
tools = raw_sample.get("tools")
if tools:
return {"messages": messages, "tools": json.dumps(tools)}
else:
return {"messages": messages}
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample
@DataConverterPlugin("pair").register()
@@ -179,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
def process_message(raw_messages: list[OpenaiMessage]):
    """Convert raw OpenAI-style messages into the internal content-typed format.

    Each input message is a dict with "role" and "content" keys. Messages with
    role "tool" are treated as JSON-encoded tool calls: the content is parsed,
    normalized to a list, and each call re-serialized as a "tool_call" content
    item. Messages whose tool-call content is not valid JSON are logged and
    dropped. All other roles become a single "text" content item.

    Args:
        raw_messages: List of OpenAI-format message dicts ("role", "content").

    Returns:
        List of internal message dicts with "role", "content" (list of typed
        content items) and "loss_weight" (1.0 for assistant turns, else 0.0).
    """
    messages = []
    for message in raw_messages:
        if message["role"] == "tool":
            try:
                # content may encode a single tool call or a list of them
                tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
            except json.JSONDecodeError:
                # best-effort: skip malformed tool calls instead of failing the sample
                logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
                continue

            if not isinstance(tool_calls, list):
                tool_calls = [tool_calls]

            messages.append(
                {
                    "role": message["role"],
                    "content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
                    "loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
                }
            )
        else:
            messages.append(
                {
                    "role": message["role"],
                    "content": [{"type": "text", "value": message["content"]}],
                    "loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
                }
            )

    return messages
chosen_messages = process_message(raw_sample.get("chosen", []))
rejected_messages = process_message(raw_sample.get("rejected", []))
sample = {}
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
@@ -51,12 +50,16 @@ def _update_model_input(
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen_messages(
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
@@ -179,7 +182,15 @@ def render_qwen_messages(
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0

View File

@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
class BatchingPlugin(BasePlugin):
    """Plugin category for batch generation.

    Currently an empty marker subclass of BasePlugin — presumably concrete
    batching strategies register under this plugin type, mirroring the
    DataConverterPlugin / RenderingPlugin pattern used elsewhere in the
    package; TODO(review): confirm against BasePlugin's registry mechanics.
    """

    pass