[v1] add batch generator (#9744)

Yaowei Zheng
2026-01-10 04:24:09 +08:00
committed by GitHub
parent d7d734d54c
commit b2effbd77c
26 changed files with 604 additions and 850 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
@@ -21,6 +21,7 @@ from .training_args import TrainingArguments
__all__ = [
"BatchingStrategy",
"DataArguments",
"InputArgument",
"ModelArguments",

View File

@@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
VLLM = "vllm"
@unique
class BatchingStrategy(StrEnum):
NORMAL = "normal"
PADDING_FREE = "padding_free"
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_PADDING_FREE = "dynamic_padding_free"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.

View File

@@ -22,7 +22,3 @@ class DataArguments:
default=None,
metadata={"help": "Path to the dataset."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},
)

View File

@@ -16,7 +16,7 @@ import os
from dataclasses import dataclass, field
from uuid import uuid4
from .arg_utils import PluginConfig, get_plugin_config
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
@dataclass
@@ -29,18 +29,30 @@ class TrainingArguments:
default=1,
metadata={"help": "Micro batch size for training."},
)
global_batch_size: int = field(
default=1,
metadata={"help": "Global batch size for training."},
global_batch_size: int | None = field(
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
bf16: bool = field(
default=False,
metadata={"help": "Use bf16 for training."},
)
batching_strategy: BatchingStrategy = field(
default=BatchingStrategy.NORMAL,
metadata={"help": "Batching strategy for training."},
)
batching_workers: int = field(
default=16,
metadata={"help": "Number of workers for batching."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},

View File

@@ -29,7 +29,6 @@ Train Phase:
from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, TorchDataset
from .utils.data_collator import DataCollator
from .utils.rendering import Renderer
@@ -45,7 +44,6 @@ class BaseTrainer:
self.model = model
self.renderer = renderer
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None

View File

@@ -82,14 +82,17 @@ class DataEngine(Dataset):
def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
self.streaming = any(is_streaming)
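# all(...) == any(...) only holds when every streaming flag agrees, so mixed streaming/non-streaming dataset configs are rejected below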
if all(is_streaming) != any(is_streaming):
raise ValueError("All datasets must be streaming or non-streaming.")
for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
@@ -98,8 +101,7 @@ class DataEngine(Dataset):
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
if self.streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
@@ -185,8 +187,8 @@ class DataEngine(Dataset):
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args

View File

@@ -0,0 +1,244 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching utils supports stateful dataloader.
1. Init stateful dataloader (tokenize)
2. Add to buffer
3. Yield batch indexes (micro batch * grad acc)
a) non pack + non dynamic
b) non pack + dynamic
c) pack + non dynamic
d) pack + dynamic
"""
from collections.abc import Iterator
from typing import Any
from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...accelerator.interface import DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.types import BatchInput, ModelInput, TorchDataset
from .rendering import Renderer
logger = logging.get_logger(__name__)
def default_collate_fn(
buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
) -> tuple[list[ModelInput], int, list[BatchInput]]:
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return buffer, buffer_tokens, None
samples = buffer[:batch_size]
buffer = buffer[batch_size:]
buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))
return buffer, buffer_tokens, batch
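# Illustrative walk-through of the helper above: with micro_batch_size=2 and
# num_micro_batch=2, a buffer holding 5 samples yields one batch built from the
# first 4 samples (two micro batches of 2, each padded/truncated to cutoff_len and
# collated into tensors); the fifth sample and its token count stay in the returned
# buffer. With fewer than micro_batch_size * num_micro_batch samples buffered, the
# returned batch is None and the buffer is left untouched.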
class BatchGenerator(Iterator):
def __init__(
self,
dataset: TorchDataset,
renderer: Renderer,
micro_batch_size: int = 1,
global_batch_size: int | None = None,
cutoff_len: int = 2048,
batching_workers: int = 0,
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
pin_memory: bool = True,
drop_last: bool = True,
) -> None:
self.dataset = dataset
self.renderer = renderer
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.cutoff_len = cutoff_len
self.batching_workers = batching_workers
self.batching_strategy = batching_strategy
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity
dp_size = DistributedInterface().get_world_size("dp")
if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
self.num_micro_batch = 1
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
else:
raise ValueError(
"Global batch size must be divisible by DP size and micro batch size. "
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
)
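# Worked example with illustrative numbers: dp_size=2, micro_batch_size=2 and
# global_batch_size=8 give num_micro_batch = 8 // 2 // 2 = 2 gradient accumulation
# steps per optimizer step; leaving global_batch_size unset falls back to
# global_batch_size = 2 * 2 = 4 with num_micro_batch = 1.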
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer: list[ModelInput] = []
self._buffer_tokens: int = 0
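# Token budget for one optimizer step on this rank: micro_batch_size * num_micro_batch samples at full cutoff_len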
self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
f"num micro batch {self.num_micro_batch}, "
f"cutoff len {self.cutoff_len}, "
f"batching workers {self.batching_workers}, "
f"batching strategy {self.batching_strategy}."
)
def _init_data_provider(self) -> None:
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
num_replicas=DistributedInterface().get_world_size("dp"),
rank=DistributedInterface().get_rank("dp"),
shuffle=True,
seed=0,
drop_last=self.drop_last,
)
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length()
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
def __len__(self) -> int:
return self._length
def __iter__(self):
if not self._is_resuming:
self._buffer.clear()
self._buffer_tokens = 0
self._data_iter = iter(self._data_provider)
self._is_resuming = False
return self
def __next__(self):
batch = self._next_batch()
if batch is None:
raise StopIteration
return batch
def _next_batch(self) -> list[BatchInput] | None:
while self._buffer_tokens < self._max_buffer_tokens:
try:
samples: list[ModelInput] = next(self._data_iter)
except StopIteration:
break
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
self._buffer.extend(samples)
self._buffer_tokens += num_tokens
return self._build_batch()
def _build_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
self._buffer, self._buffer_tokens, batch = default_collate_fn(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer,
"buffer_tokens": self._buffer_tokens,
"data_provider": self._data_provider.state_dict(),
}
def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"]
self._buffer_tokens = state["buffer_tokens"]
self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True
def set_epoch(self, epoch: int) -> None:
if hasattr(self._data_provider.sampler, "set_epoch"):
self._data_provider.sampler.set_epoch(epoch)
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.utils.batching \
--model llamafactory/tiny-random-qwen2.5 \
--dataset data/v1_sft_demo.yaml \
--micro_batch_size 2 \
--global_batch_size 4 \
--batching_workers 0
"""
from ...config.arg_parser import get_args
from ..data_engine import DataEngine
from ..model_engine import ModelEngine
data_args, model_args, training_args, _ = get_args()
data_engine = DataEngine(data_args=data_args)
model_engine = ModelEngine(model_args=model_args)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
for batch in batch_generator:
print(batch)
print(len(batch))
print(batch[0]["input_ids"].shape)
break
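The state_dict/load_state_dict pair above makes the generator resumable mid-epoch. A minimal sketch of a save/resume flow, assuming a batch_generator constructed as in the __main__ block (the checkpoint path is illustrative):
import torch
# save: captures the sample buffer, the buffered token count and the dataloader position
torch.save({"batch_generator": batch_generator.state_dict()}, "data_state.pt")
# resume: rebuild the generator with the same arguments, restore the saved state, and
# iteration continues from the saved position instead of restarting the epoch
state = torch.load("data_state.pt", weights_only=False)
batch_generator.load_state_dict(state["batch_generator"])
next(iter(batch_generator))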

View File

@@ -1,277 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import sys
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from typing import Optional
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...utils.batching_queue import BaseBatchingQueue
from ...utils.logging import get_logger
from ...utils.types import Processor, TorchDataset
from .data_collator import DataCollator
logger = get_logger(__name__)
# base dataloader
class DistributedDataloader(StatefulDataLoader):
"""Base Distributed DataLoader."""
dataset: "TorchDataset"
sampler: "StatefulDistributedSampler"
def set_epoch(self, epoch: int) -> None:
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
self.sampler.set_epoch(epoch)
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
@dataclass
class BaseDataLoader:
"""Default DataLoader."""
processor: Processor
def __init__(self, dataset: TorchDataset) -> None:
self.dataset = dataset
# guidelines: fetch until a fixed batch size is reached.
# save state_dict for buffer.
# resume with state
# 1. Init stateful dataloader (tokenize)
# 2. Add to buffer (2 * max seq len per device)
# 3. Yield batch indexes (micro batch * grad acc)
# a ) non pack + non dynamic
# b ) non pack + dynamic
# c ) pack + non dynamic
# d ) pack + dynamic
def init_dataloader(self) -> None:
### init dataloader
pass
def __iter__(self) -> Iterator:
pass
def __next__(self) -> any:
pass
@dataclass
class DataLoader:
"""Default DataLoader."""
processor: "Processor"
dataloader: "DistributedDataloader"
batching_queue: "BaseBatchingQueue"
collate_fn: "DataCollator"
num_micro_batch: int = 1
length: int = 0
drop_last: bool = True
def __init__(
self,
dataloader: any,
collate_fn: "DataCollator",
num_micro_batch: int = 1,
length: int = 0,
drop_last: bool = True,
batching_queue: Optional["BaseBatchingQueue"] = None,
) -> None:
self.batching_queue = batching_queue
self.num_micro_batch = num_micro_batch
self.step = 0
self._collate_fn = collate_fn
self._dataloader = dataloader
self._drop_last = drop_last
self._data_iter: Iterator
self._resume = False
self._batch_data_iter: Generator
if length > 0:
self._length = length
elif length == -1:
self._length = sys.maxsize
else:
self._length = len(self._dataloader)
def __len__(self):
return self._length
def __iter__(self) -> Iterator:
if not self._resume:
self.step = 0
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
self._resume = False
return self
def __next__(self):
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
def origin_batch_data_generator(self):
"""Standard pass-through generator if do not use batching queue."""
while True:
if self._length > 0 and self.step >= self._length:
return
try:
batch = []
data = next(self._data_iter)
# split data into micro batches
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
if self.step < self._length:
# Restart iterator to fill the requested length
self._data_iter = iter(self._dataloader)
try:
batch = []
data = next(self._data_iter)
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
return
else:
return
except Exception as e:
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
raise
def batch_data_generator(self):
if self.batching_queue is None:
yield from self.origin_batch_data_generator()
return
batch = []
while True:
if self._length and self.step >= self._length:
return
if self.batching_queue.is_full_filled():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
try:
processing_item = next(self._data_iter)
except Exception as e:
if isinstance(e, StopIteration):
if self.step < self._length:
# call iter until reach length
self._data_iter = iter(self._dataloader)
processing_item = next(self._data_iter)
elif not self._drop_last and not self.batching_queue.empty():
while not self.batching_queue.empty():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
while len(batch) < self.num_micro_batch:
padding_batch = copy.deepcopy(micro_batch)
padding_batch["is_padded"] = True
batch.append(padding_batch)
yield batch
self.step += 1
return
else:
return
else:
logger.error(f"DataLoader iter data exception: {e}")
raise
# put processing_item to buffer
if isinstance(processing_item, dict):
processing_item = [processing_item]
for item in processing_item:
self.batching_queue.put_item(item)
def state_dict(self):
# save state
state = self.__dict__.copy()
# remove internal fields
for k in list(state.keys()):
if k.startswith("_"):
del state[k]
# save dataloader state
if hasattr(self._dataloader, "state_dict"):
state["dataloader_state"] = self._dataloader.state_dict()
elif hasattr(self._dataloader, "__getstate__"):
state["dataloader_state"] = self._dataloader.__getstate__()
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy and hasattr(batching_strategy, "state_dict"):
state["batching_strategy_state"] = batching_strategy.state_dict()
if "batching_strategy" in state:
del state["batching_strategy"]
return copy.deepcopy(state)
def load_state_dict(self, state: dict[str, any]):
if state["num_micro_batch"] != self.num_micro_batch:
logger.warning(
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
)
del state["num_micro_batch"]
self.__dict__.update(state)
self._resume = True
if hasattr(self._dataloader, "load_state_dict"):
self._dataloader.load_state_dict(state["dataloader_state"])
elif hasattr(self._dataloader, "__getstate__"):
self._dataloader.__setstate__(state["dataloader_state"])
if "batching_strategy_state" in state:
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy:
batching_strategy.load_state_dict(state["batching_strategy_state"])
del state["batching_strategy_state"]
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
def set_epoch(self, epoch: int) -> None:
if hasattr(self._dataloader, "set_epoch"):
self._dataloader.set_epoch(epoch)

View File

@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rendering utils.
How to use:
renderer = Renderer(template, processor)
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
renderer.parse_message(text: str) -> Message
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
"""
import numpy as np
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
from ...utils.types import Message, ModelInput, Processor, Sample
def render_chatml_messages(
@@ -64,7 +74,7 @@ def render_chatml_messages(
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
"""Parse a message in ChatML format.
Args:
generated_text (str): The generated text in ChatML format.
@@ -83,6 +93,16 @@ class Renderer:
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
"""Apply template to messages and convert them to model input.
Args:
messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False.
Returns:
ModelInput: The rendered model input.
"""
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
@@ -91,9 +111,59 @@ class Renderer:
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format.
Args:
generated_text (str): The generated text in the template format.
Returns:
Message: The parsed message.
"""
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
"""Process samples to model input.
Args:
samples (list[Sample]): The samples to process.
Returns:
list[ModelInput]: The processed model inputs.
"""
model_inputs = []
for sample in samples:
if "messages" in sample:
model_input = self.render_messages(sample["messages"], sample.get("tools"))
elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
model_input = ModelInput(
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
labels=chosen_input["labels"] + rejected_input["labels"],
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
)
if "position_ids" in chosen_input:
model_input["position_ids"] = np.concatenate(
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
)
else:
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
if "extra_info" in sample:
model_input["extra_info"] = sample["extra_info"]
if "_dataset_name" in sample:
model_input["_dataset_name"] = sample["_dataset_name"]
model_inputs.append(model_input)
return model_inputs
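For reference, a minimal sketch of the two sample shapes process_samples accepts (field values are illustrative, shaped like the converter output elsewhere in this commit):
sft_sample = {
    "messages": [
        {"role": "user", "content": [{"type": "text", "value": "Hi"}], "loss_weight": 0.0},
        {"role": "assistant", "content": [{"type": "text", "value": "Hello!"}], "loss_weight": 1.0},
    ],
}
dpo_sample = {
    "chosen_messages": [...],    # rendered on its own; its tokens get token_type_ids of 0
    "rejected_messages": [...],  # rendered on its own; its tokens get token_type_ids of 1
}
# renderer.process_samples([sft_sample]) renders the messages directly, while the DPO
# sample is rendered twice and concatenated into a single ModelInput whose
# token_type_ids mark which span belongs to the chosen vs. rejected response.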

View File

@@ -32,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
SharegptMessage = TypedDict(
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
"SharegptMessage",
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
)
@@ -118,15 +119,8 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"observation": "tool",
"function_call": "assistant",
}
sample = {}
messages = []
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
@@ -157,10 +151,17 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
}
)
sample["messages"] = messages
tools = raw_sample.get("tools")
if tools:
return {"messages": messages, "tools": json.dumps(tools)}
else:
return {"messages": messages}
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample
@DataConverterPlugin("pair").register()
@@ -179,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
def process_message(raw_messages: list[OpenaiMessage]):
messages = []
for message in raw_messages:
messages.append(
{
"role": message["role"],
"content": [{"type": "text", "value": message["content"]}],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
if message["role"] == "tool":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": message["role"],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
else:
messages.append(
{
"role": message["role"],
"content": [{"type": "text", "value": message["content"]}],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
return messages
chosen_messages = process_message(raw_sample.get("chosen", []))
rejected_messages = process_message(raw_sample.get("rejected", []))
sample = {}
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample
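A minimal sketch of the pair format consumed above (keys and values are illustrative):
raw_pair_sample = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}
# pair_converter(raw_pair_sample) returns {"chosen_messages": [...], "rejected_messages": [...]},
# where each message becomes {"role", "content": [{"type": "text", "value": ...}], "loss_weight"};
# assistant turns get loss_weight 1.0, everything else 0.0, and a "tools" field that parses
# as JSON is re-serialized into sample["tools"].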

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
@@ -51,12 +50,16 @@ def _update_model_input(
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen_messages(
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
@@ -179,7 +182,15 @@ def render_qwen_messages(
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0

View File

@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
class BatchingPlugin(BasePlugin):
pass

View File

@@ -14,7 +14,7 @@
from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..config import InputArgument, get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
@@ -24,15 +24,15 @@ class SFTTrainer(BaseTrainer):
pass
def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
def run_sft(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_engine.model,
processor=model_engine.processor,
renderer=model_engine.renderer,
dataset=data_engine,
)
trainer.fit()

View File

@@ -1,220 +0,0 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
class DynamicBatchSizeBuffer:
"""A buffer to store samples for dynamic batch size."""
def __init__(self):
self._buffer: list[dict[str, any]] = []
self._buffer_sample_lengths: list[int] = []
self._deleted_indices: set[int] = set()
self._current_index: int = 0
self._total_token_count: int = 0
def append(self, item: dict[str, any]) -> None:
"""Append a sample to the buffer.
Args:
item: A sample to append to the buffer.
The sample should be a dict with the following keys:
- input_ids: torch.Tensor of shape (seq_len, )
- attention_mask: torch.Tensor of shape (seq_len, )
"""
self._buffer.append(item)
sample_length = int(item["attention_mask"].sum().item())
self._buffer_sample_lengths.append(sample_length)
self._total_token_count += sample_length
def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]:
"""Get samples from the buffer that fit within the token budget.
Args:
max_tokens_per_iteration: Maximum number of tokens to retrieve.
force: If True, the first available sample will be returned even
if it exceeds the token budget.
Returns:
A list of samples that fit within the token budget.
Raises:
AssertionError: If no samples are found (should not happen in normal operation).
"""
cum_seq_len = 0
samples = []
while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
if self._current_index in self._deleted_indices:
self._current_index += 1
continue
seq_len = self._buffer_sample_lengths[self._current_index]
remaining_tokens = max_tokens_per_iteration - cum_seq_len
# Check if we can add this sample
can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
if can_add:
cum_seq_len += seq_len
samples.append(self._buffer[self._current_index])
self._deleted_indices.add(self._current_index)
self._current_index += 1
assert len(samples) > 0, "No samples found in buffer"
return samples
def __len__(self) -> int:
"""Return the number of samples in the buffer."""
return len(self._buffer)
@property
def total_token_count(self) -> int:
"""Return the total number of tokens in the buffer."""
return self._total_token_count
def flush(self) -> None:
tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
self._total_token_count -= tokens_to_remove
buffer_length = len(self._buffer)
self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
self._buffer_sample_lengths = [
self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
]
self._current_index = 0
self._deleted_indices.clear()
class BaseBatchingQueue(ABC):
"""Base class for batching queue."""
@abstractmethod
def is_full_filled(self) -> bool:
raise NotImplementedError("Subclasses must implement `is_full_filled`")
@abstractmethod
def put_item(self, item: dict[str, any]) -> None:
raise NotImplementedError("Subclasses must implement `put_item`")
@abstractmethod
def get_micro_batch(self, step: int) -> list[dict[str, any]]:
raise NotImplementedError("Subclasses must implement `get_micro_batch`")
@abstractmethod
def empty(self) -> bool:
raise NotImplementedError("Subclasses must implement `empty`")
class IdentityPacker:
def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
def __call__(self, samples):
return samples
def get_token_num_to_request(self, cur_step, warmup):
return (
(self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
+ self.bsz_warmup_init_mbtoken
if warmup
else self.token_micro_bsz
)
class TextBatchingQueue(BaseBatchingQueue):
"""Batching text queue for text data."""
def __init__(
self,
token_micro_bsz,
buffer_size: int = 500,
bsz_warmup_steps: int = -1,
bsz_warmup_init_mbtoken: int = 200,
) -> None:
super().__init__()
self._step = 0
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.buffer_size = buffer_size # minimum samples in buffer
self.buffer = DynamicBatchSizeBuffer()
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args
assert self.bsz_warmup_init_mbtoken >= 0
self.packer = IdentityPacker(
token_micro_bsz=token_micro_bsz,
bsz_warmup_steps=bsz_warmup_steps,
bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
)
def is_full_filled(self) -> bool:
return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz
def put_item(self, item: dict[str, any]):
if len(item["input_ids"]) == 1:
print("WARNING: EMPTY STRING.")
return
self.buffer.append(item)
def get_token_num_to_request(self):
if self.packer is not None:
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
return self.packer.get_token_num_to_request(self._step, warmup=warmup)
else:
return self.get_cur_token_micro_bsz()
def get_cur_token_micro_bsz(self):
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
if warmup:
return (
self.token_micro_bsz - self.bsz_warmup_init_mbtoken
) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
else:
return self.token_micro_bsz
def get_micro_batch(self, step) -> any:
"""Get a micro batch from the buffer according to the current step.
Args:
step: the current step.
Returns:
data: a list of samples.
"""
self._step = step
n_token_per_iter = self.get_token_num_to_request()
cur_token_micro_bsz = self.get_cur_token_micro_bsz()
assert cur_token_micro_bsz % n_token_per_iter == 0, (
"The token num to get for each request should be divisible by token micro bsz."
)
n_iter = int(cur_token_micro_bsz // n_token_per_iter)
data = []
for _ in range(n_iter):
samples = self.buffer.get_samples(n_token_per_iter)
if self.packer:
samples = self.packer(samples) # maybe packed into one sample, but wrapped in list.
data.extend(samples)
self.buffer.flush() # remove the selected samples.
return data
def empty(self) -> bool:
return len(self.buffer) == 0

View File

@@ -32,8 +32,8 @@ class DtypeRegistry:
class DtypeInterface:
"""Type of precision used."""
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface().current_device)
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface().current_device)
_is_fp32_available = True
@staticmethod

View File

@@ -12,9 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import PreTrainedTokenizer
from .types import Processor
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def is_tokenizer(processor: Processor) -> bool:
"""Check if processor is tokenizer.
Args:
processor: Processor.
Returns:
Whether processor is tokenizer.
"""
return not hasattr(processor, "tokenizer")
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
@@ -27,3 +42,34 @@ def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
Tokenizer.
"""
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
def _pad_and_truncate(tensor: Tensor, max_seqlen: int, pad_value: int = 0) -> Tensor:
if tensor.shape[-1] >= max_seqlen:
return tensor[..., :max_seqlen]
pad_shape = list(tensor.shape)
pad_shape[-1] = max_seqlen - tensor.shape[-1]
pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
return torch.cat([tensor, pad_tensor], dim=-1)
def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
max_length = min(max(len(sample["input_ids"]) for sample in samples), max_seqlen)
padded_samples = []
for sample in samples:
padded_sample = {}
for key, value in sample.items():
if "label" in key:
pad_value = IGNORE_INDEX
else:
pad_value = 0
if not isinstance(value, str):
padded_sample[key] = _pad_and_truncate(torch.tensor(value), max_length, pad_value)
else:
padded_sample[key] = value
padded_samples.append(padded_sample)
return padded_samples
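A short worked example of the helper above (sample values are made up; IGNORE_INDEX is assumed to be the usual -100):
samples = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [4, 5, 6, 7, 8], "attention_mask": [1, 1, 1, 1, 1], "labels": [4, 5, 6, 7, 8]},
]
padded = pad_and_truncate(samples, max_seqlen=4)
# max_length = min(5, 4) = 4: the first sample is right-padded (labels with IGNORE_INDEX,
# all other keys with 0) and the second is truncated, so every tensor has shape (4,).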

View File

@@ -144,3 +144,20 @@ class ModelInput(TypedDict, total=False):
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""
token_type_ids: NotRequired[list[int]]
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
class BatchInput(TypedDict, total=False):
input_ids: Tensor
"""Input ids for the model."""
attention_mask: Tensor
"""Attention mask for the model."""
labels: Tensor
"""Labels for the model."""
loss_weights: Tensor
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[Tensor]
"""Position ids for the model (optional)."""
token_type_ids: NotRequired[Tensor]
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""