[v1] add batch generator (#9744)

This commit is contained in:
Yaowei Zheng
2026-01-10 04:24:09 +08:00
committed by GitHub
parent d7d734d54c
commit b2effbd77c
26 changed files with 604 additions and 850 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
@@ -21,6 +21,7 @@ from .training_args import TrainingArguments
__all__ = [
"BatchingStrategy",
"DataArguments",
"InputArgument",
"ModelArguments",

View File

@@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
VLLM = "vllm"
@unique
class BatchingStrategy(StrEnum):
NORMAL = "normal"
PADDING_FREE = "padding_free"
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_PADDING_FREE = "dynamic_padding_free"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.

View File

@@ -22,7 +22,3 @@ class DataArguments:
default=None,
metadata={"help": "Path to the dataset."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},
)

View File

@@ -16,7 +16,7 @@ import os
from dataclasses import dataclass, field
from uuid import uuid4
from .arg_utils import PluginConfig, get_plugin_config
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
@dataclass
@@ -29,18 +29,30 @@ class TrainingArguments:
default=1,
metadata={"help": "Micro batch size for training."},
)
global_batch_size: int = field(
default=1,
metadata={"help": "Global batch size for training."},
global_batch_size: int | None = field(
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
bf16: bool = field(
default=False,
metadata={"help": "Use bf16 for training."},
)
batching_strategy: BatchingStrategy = field(
default=BatchingStrategy.NORMAL,
metadata={"help": "Batching strategy for training."},
)
batching_workers: int = field(
default=16,
metadata={"help": "Number of workers for batching."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},

View File

@@ -29,7 +29,6 @@ Train Phase:
from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, TorchDataset
from .utils.data_collator import DataCollator
from .utils.rendering import Renderer
@@ -45,7 +44,6 @@ class BaseTrainer:
self.model = model
self.renderer = renderer
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None

View File

@@ -82,14 +82,17 @@ class DataEngine(Dataset):
def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
self.streaming = any(is_streaming)
if all(is_streaming) != any(is_streaming):
raise ValueError("All datasets must be streaming or non-streaming.")
for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
@@ -98,8 +101,7 @@ class DataEngine(Dataset):
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
if self.streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
@@ -185,8 +187,8 @@ class DataEngine(Dataset):
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args

View File

@@ -0,0 +1,244 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching utils supports stateful dataloader.
1. Init stateful dataloader (tokenize)
2. Add to buffer
3. Yield batch indexes (micro batch * grad acc)
a) non pack + non dynamic
b) non pack + dynamic
c) pack + non dynamic
d) pack + dynamic
"""
from collections.abc import Iterator
from typing import Any
from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...accelerator.interface import DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.types import BatchInput, ModelInput, TorchDataset
from .rendering import Renderer
logger = logging.get_logger(__name__)
def default_collate_fn(
buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
) -> tuple[list[ModelInput], int, list[BatchInput]]:
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return buffer, buffer_tokens, None
samples = buffer[:batch_size]
buffer = buffer[batch_size:]
buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))
return buffer, buffer_tokens, batch
class BatchGenerator(Iterator):
def __init__(
self,
dataset: TorchDataset,
renderer: Renderer,
micro_batch_size: int = 1,
global_batch_size: int | None = None,
cutoff_len: int = 2048,
batching_workers: int = 0,
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
pin_memory: bool = True,
drop_last: bool = True,
) -> None:
self.dataset = dataset
self.renderer = renderer
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.cutoff_len = cutoff_len
self.batching_workers = batching_workers
self.batching_strategy = batching_strategy
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity
dp_size = DistributedInterface().get_world_size("dp")
if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
self.num_micro_batch = 1
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
else:
raise ValueError(
"Global batch size must be divisible by DP size and micro batch size. "
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
)
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer: list[ModelInput] = []
self._buffer_tokens: int = 0
self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
f"num micro batch {self.num_micro_batch}, "
f"cutoff len {self.cutoff_len}, "
f"batching workers {self.batching_workers}, "
f"batching strategy {self.batching_strategy}."
)
def _init_data_provider(self) -> None:
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
num_replicas=DistributedInterface().get_world_size("dp"),
rank=DistributedInterface().get_rank("dp"),
shuffle=True,
seed=0,
drop_last=self.drop_last,
)
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length()
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
def __len__(self) -> int:
return self._length
def __iter__(self):
if not self._is_resuming:
self._buffer.clear()
self._buffer_tokens = 0
self._data_iter = iter(self._data_provider)
self._is_resuming = False
return self
def __next__(self):
batch = self._next_batch()
if batch is None:
raise StopIteration
return batch
def _next_batch(self) -> list[BatchInput] | None:
while self._buffer_tokens < self._max_buffer_tokens:
try:
samples: list[ModelInput] = next(self._data_iter)
except StopIteration:
break
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
self._buffer.extend(samples)
self._buffer_tokens += num_tokens
return self._build_batch()
def _build_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
self._buffer, self._buffer_tokens, batch = default_collate_fn(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer,
"buffer_tokens": self._buffer_tokens,
"data_provider": self._data_provider.state_dict(),
}
def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"]
self._buffer_tokens = state["buffer_tokens"]
self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True
def set_epoch(self, epoch: int) -> None:
if hasattr(self._data_provider.sampler, "set_epoch"):
self._data_provider.sampler.set_epoch(epoch)
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.utils.batching \
--model llamafactory/tiny-random-qwen2.5 \
--dataset data/v1_sft_demo.yaml \
--micro_batch_size 2 \
--global_batch_size 4 \
--batching_workers 0
"""
from ...config.arg_parser import get_args
from ..data_engine import DataEngine
from ..model_engine import ModelEngine
data_args, model_args, training_args, _ = get_args()
data_engine = DataEngine(data_args=data_args)
model_engine = ModelEngine(model_args=model_args)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
for batch in batch_generator:
print(batch)
print(len(batch))
print(batch[0]["input_ids"].shape)
break

View File

@@ -1,277 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import sys
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from typing import Optional
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...utils.batching_queue import BaseBatchingQueue
from ...utils.logging import get_logger
from ...utils.types import Processor, TorchDataset
from .data_collator import DataCollator
logger = get_logger(__name__)
# base dataloader
class DistributedDataloader(StatefulDataLoader):
"""Base Distributed DataLoader."""
dataset: "TorchDataset"
sampler: "StatefulDistributedSampler"
def set_epoch(self, epoch: int) -> None:
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
self.sampler.set_epoch(epoch)
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
@dataclass
class BaseDataLoader:
"""Default DataLoader."""
processor: Processor
def __init__(self, dataset: TorchDataset) -> None:
self.dataset = dataset
# guidlines: fetch until get fixed batchsize.
# save state_dict for buffer.
# resume with state
# 1. Init stateful dataloader (tokenize)
# 2. Add to buffer (2 * max seq len per device)
# 3. Yield batch indexes (micro batch * grad acc)
# a ) non pack + non dynamic
# b ) non pack + dynamic
# c ) pack + non dynamic
# d ) pack + dynamic
def init_dataloader(self) -> None:
### init dataloader
pass
def __iter__(self) -> Iterator:
pass
def __next__(self) -> any:
pass
@dataclass
class DataLoader:
"""Default DataLoader."""
processor: "Processor"
dataloader: "DistributedDataloader"
batching_queue: "BaseBatchingQueue"
collate_fn: "DataCollator"
num_micro_batch: int = 1
length: int = 0
drop_last: bool = True
def __init__(
self,
dataloader: any,
collate_fn: "DataCollator",
num_micro_batch: int = 1,
length: int = 0,
drop_last: bool = True,
batching_queue: Optional["BaseBatchingQueue"] = None,
) -> None:
self.batching_queue = batching_queue
self.num_micro_batch = num_micro_batch
self.step = 0
self._collate_fn = collate_fn
self._dataloader = dataloader
self._drop_last = drop_last
self._data_iter: Iterator
self._resume = False
self._batch_data_iter: Generator
if length > 0:
self._length = length
elif length == -1:
self._length = sys.maxsize
else:
self._length = len(self._dataloader)
def __len__(self):
return self._length
def __iter__(self) -> Iterator:
if not self._resume:
self.step = 0
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
self._resume = False
return self
def __next__(self):
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
def origin_batch_data_generator(self):
"""Standard pass-through generator if do not use batching queue."""
while True:
if self._length > 0 and self.step >= self._length:
return
try:
batch = []
data = next(self._data_iter)
# split data into micro batches
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
if self.step < self._length:
# Restart iterator to fill the requested length
self._data_iter = iter(self._dataloader)
try:
batch = []
data = next(self._data_iter)
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
return
else:
return
except Exception as e:
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
raise
def batch_data_generator(self):
if self.batching_queue is None:
yield from self.origin_batch_data_generator()
return
batch = []
while True:
if self._length and self.step >= self._length:
return
if self.batching_queue.is_full_filled():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
try:
processing_item = next(self._data_iter)
except Exception as e:
if isinstance(e, StopIteration):
if self.step < self._length:
# call iter until reach length
self._data_iter = iter(self._dataloader)
processing_item = next(self._data_iter)
elif not self._drop_last and not self.batching_queue.empty():
while not self.batching_queue.empty():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
while len(batch) < self.num_micro_batch:
padding_batch = copy.deepcopy(micro_batch)
padding_batch["is_padded"] = True
batch.append(padding_batch)
yield batch
self.step += 1
return
else:
return
else:
logger.error(f"DataLoader iter data exception: {e}")
raise
# put processing_item to buffer
if isinstance(processing_item, dict):
processing_item = [processing_item]
for item in processing_item:
self.batching_queue.put_item(item)
def state_dict(self):
# save state
state = self.__dict__.copy()
# remove internal fields
for k in list(state.keys()):
if k.startswith("_"):
del state[k]
# save dataloader state
if hasattr(self._dataloader, "state_dict"):
state["dataloader_state"] = self._dataloader.state_dict()
elif hasattr(self._dataloader, "__getstate__"):
state["dataloader_state"] = self._dataloader.__getstate__()
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy and hasattr(batching_strategy, "state_dict"):
state["batching_strategy_state"] = batching_strategy.state_dict()
if "batching_strategy" in state:
del state["batching_strategy"]
return copy.deepcopy(state)
def load_state_dict(self, state: dict[str, any]):
if state["num_micro_batch"] != self.num_micro_batch:
logger.warning(
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
)
del state["num_micro_batch"]
self.__dict__.update(state)
self._resume = True
if hasattr(self._dataloader, "load_state_dict"):
self._dataloader.load_state_dict(state["dataloader_state"])
elif hasattr(self._dataloader, "__getstate__"):
self._dataloader.__setstate__(state["dataloader_state"])
if "batching_strategy_state" in state:
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy:
batching_strategy.load_state_dict(state["batching_strategy_state"])
del state["batching_strategy_state"]
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
def set_epoch(self, epoch: int) -> None:
if hasattr(self._dataloader, "set_epoch"):
self._dataloader.set_epoch(epoch)

View File

@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rendering utils.
How to use:
renderer = Renderer(template, processor)
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
renderer.parse_message(text: str) -> Message
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
"""
import numpy as np
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
from ...utils.types import Message, ModelInput, Processor, Sample
def render_chatml_messages(
@@ -64,7 +74,7 @@ def render_chatml_messages(
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
"""Parse a message in ChatML format.
Args:
generated_text (str): The generated text in ChatML format.
@@ -83,6 +93,16 @@ class Renderer:
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
"""Apply template to messages and convert them to model input.
Args:
messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False.
Returns:
ModelInput: The rendered model input.
"""
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
@@ -91,9 +111,59 @@ class Renderer:
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format.
Args:
generated_text (str): The generated text in the template format.
Returns:
Message: The parsed message.
"""
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
"""Process samples to model input.
Args:
samples (list[Sample]): The samples to process.
Returns:
list[ModelInput]: The processed model inputs.
"""
model_inputs = []
for sample in samples:
if "messages" in sample:
model_input = self.render_messages(sample["messages"], sample.get("tools"))
elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
model_input = ModelInput(
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
labels=chosen_input["labels"] + rejected_input["labels"],
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
)
if "position_ids" in chosen_input:
model_input["position_ids"] = np.concatenate(
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
)
else:
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
if "extra_info" in sample:
model_input["extra_info"] = sample["extra_info"]
if "_dataset_name" in sample:
model_input["_dataset_name"] = sample["_dataset_name"]
model_inputs.append(model_input)
return model_inputs

View File

@@ -32,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
SharegptMessage = TypedDict(
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
"SharegptMessage",
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
)
@@ -118,15 +119,8 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"observation": "tool",
"function_call": "assistant",
}
sample = {}
messages = []
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
@@ -157,10 +151,17 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
}
)
sample["messages"] = messages
tools = raw_sample.get("tools")
if tools:
return {"messages": messages, "tools": json.dumps(tools)}
else:
return {"messages": messages}
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample
@DataConverterPlugin("pair").register()
@@ -179,6 +180,24 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
def process_message(raw_messages: list[OpenaiMessage]):
messages = []
for message in raw_messages:
if message["role"] == "tool":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": message["role"],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
else:
messages.append(
{
"role": message["role"],
@@ -189,7 +208,16 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
return messages
chosen_messages = process_message(raw_sample.get("chosen", []))
rejected_messages = process_message(raw_sample.get("rejected", []))
sample = {}
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
@@ -51,12 +50,16 @@ def _update_model_input(
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen_messages(
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
@@ -179,7 +182,15 @@ def render_qwen_messages(
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0

View File

@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
class BatchingPlugin(BasePlugin):
pass

View File

@@ -14,7 +14,7 @@
from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..config import InputArgument, get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
@@ -24,15 +24,15 @@ class SFTTrainer(BaseTrainer):
pass
def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
def run_sft(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_engine.model,
processor=model_engine.processor,
renderer=model_engine.renderer,
dataset=data_engine,
)
trainer.fit()

View File

@@ -1,220 +0,0 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
class DynamicBatchSizeBuffer:
"""A buffer to store samples for dynamic batch size."""
def __init__(self):
self._buffer: list[dict[str, any]] = []
self._buffer_sample_lengths: list[int] = []
self._deleted_indices: set[int] = set()
self._current_index: int = 0
self._total_token_count: int = 0
def append(self, item: dict[str, any]) -> None:
"""Append a sample to the buffer.
Args:
item: A sample to append to the buffer.
The sample should be a dict with the following keys:
- input_ids: torch.Tensor of shape (seq_len, )
- attention_mask: torch.Tensor of shape (seq_len, )
"""
self._buffer.append(item)
sample_length = int(item["attention_mask"].sum().item())
self._buffer_sample_lengths.append(sample_length)
self._total_token_count += sample_length
def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]:
"""Get samples from the buffer that fit within the token budget.
Args:
max_tokens_per_iteration: Maximum number of tokens to retrieve.
force: If True, the first available sample will be returned even
if it exceeds the token budget.
Returns:
A list of samples that fit within the token budget.
Raises:
AssertionError: If no samples are found (should not happen in normal operation).
"""
cum_seq_len = 0
samples = []
while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
if self._current_index in self._deleted_indices:
self._current_index += 1
continue
seq_len = self._buffer_sample_lengths[self._current_index]
remaining_tokens = max_tokens_per_iteration - cum_seq_len
# Check if we can add this sample
can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
if can_add:
cum_seq_len += seq_len
samples.append(self._buffer[self._current_index])
self._deleted_indices.add(self._current_index)
self._current_index += 1
assert len(samples) > 0, "No samples found in buffer"
return samples
def __len__(self) -> int:
"""Return the number of samples in the buffer."""
return len(self._buffer)
@property
def total_token_count(self) -> int:
"""Return the total number of tokens in the buffer."""
return self._total_token_count
def flush(self) -> None:
tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
self._total_token_count -= tokens_to_remove
buffer_length = len(self._buffer)
self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
self._buffer_sample_lengths = [
self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
]
self._current_index = 0
self._deleted_indices.clear()
class BaseBatchingQueue(ABC):
"""Base class for batching queue."""
@abstractmethod
def is_full_filled(self) -> bool:
raise NotImplementedError("Subclasses must implement `is_full_filled`")
@abstractmethod
def put_item(self, item: dict[str, any]) -> None:
raise NotImplementedError("Subclasses must implement `put_item`")
@abstractmethod
def get_micro_batch(self, step: int) -> list[dict[str, any]]:
raise NotImplementedError("Subclasses must implement `get_micro_batch`")
@abstractmethod
def empty(self) -> bool:
raise NotImplementedError("Subclasses must implement `empty`")
class IdentityPacker:
def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
def __call__(self, samples):
return samples
def get_token_num_to_request(self, cur_step, warmup):
return (
(self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
+ self.bsz_warmup_init_mbtoken
if warmup
else self.token_micro_bsz
)
class TextBatchingQueue(BaseBatchingQueue):
"""Batching text queue for text data."""
def __init__(
self,
token_micro_bsz,
buffer_size: int = 500,
bsz_warmup_steps: int = -1,
bsz_warmup_init_mbtoken: int = 200,
) -> None:
super().__init__()
self._step = 0
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.buffer_size = buffer_size # minimum samples in buffer
self.buffer = DynamicBatchSizeBuffer()
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args
assert self.bsz_warmup_init_mbtoken >= 0
self.packer = IdentityPacker(
token_micro_bsz=token_micro_bsz,
bsz_warmup_steps=bsz_warmup_steps,
bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
)
def is_full_filled(self) -> bool:
return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz
def put_item(self, item: dict[str, any]):
if len(item["input_ids"]) == 1:
print("WARNING: EMPTY STRING.")
return
self.buffer.append(item)
def get_token_num_to_request(self):
if self.packer is not None:
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
return self.packer.get_token_num_to_request(self._step, warmup=warmup)
else:
return self.get_cur_token_micro_bsz()
def get_cur_token_micro_bsz(self):
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
if warmup:
return (
self.token_micro_bsz - self.bsz_warmup_init_mbtoken
) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
else:
return self.token_micro_bsz
def get_micro_batch(self, step) -> any:
"""Get a micro batch from the buffer according to the current step.
Args:
step: the current step.
Returns:
data: a list of samples.
"""
self._step = step
n_token_per_iter = self.get_token_num_to_request()
cur_token_micro_bsz = self.get_cur_token_micro_bsz()
assert cur_token_micro_bsz % n_token_per_iter == 0, (
"The token num to get for each request should be divisible by token micro bsz."
)
n_iter = int(cur_token_micro_bsz // n_token_per_iter)
data = []
for _ in range(n_iter):
samples = self.buffer.get_samples(n_token_per_iter)
if self.packer:
samples = self.packer(samples) # maybe packed into one sample, but wrapped in list.
data.extend(samples)
self.buffer.flush() # remove the selected samples.
return data
def empty(self) -> bool:
return len(self.buffer) == 0

View File

@@ -32,8 +32,8 @@ class DtypeRegistry:
class DtypeInterface:
"""Type of precision used."""
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface().current_device)
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface().current_device)
_is_fp32_available = True
@staticmethod

View File

@@ -12,9 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import PreTrainedTokenizer
from .types import Processor
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def is_tokenizer(processor: Processor) -> bool:
"""Check if processor is tokenizer.
Args:
processor: Processor.
Returns:
Whether processor is tokenizer.
"""
return not hasattr(processor, "tokenizer")
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
@@ -27,3 +42,34 @@ def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
Tokenizer.
"""
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
def _pad_and_truncate(tensor: Tensor, max_seqlen: int, pad_value: int = 0) -> Tensor:
if tensor.shape[-1] >= max_seqlen:
return tensor[..., :max_seqlen]
pad_shape = list(tensor.shape)
pad_shape[-1] = max_seqlen - tensor.shape[-1]
pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
return torch.cat([tensor, pad_tensor], dim=-1)
def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
max_length = min(max(len(sample["input_ids"]) for sample in samples), max_seqlen)
padded_samples = []
for sample in samples:
padded_sample = {}
for key, value in sample.items():
if "label" in key:
pad_value = IGNORE_INDEX
else:
pad_value = 0
if not isinstance(value, str):
padded_sample[key] = _pad_and_truncate(torch.tensor(value), max_length, pad_value)
else:
padded_sample[key] = value
padded_samples.append(padded_sample)
return padded_samples

View File

@@ -144,3 +144,20 @@ class ModelInput(TypedDict, total=False):
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""
token_type_ids: NotRequired[list[int]]
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
class BatchInput(TypedDict, total=False):
input_ids: Tensor
"""Input ids for the model."""
attention_mask: Tensor
"""Attention mask for the model."""
labels: Tensor
"""Labels for the model."""
loss_weights: Tensor
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[Tensor]
"""Position ids for the model (optional)."""
token_type_ids: NotRequired[Tensor]
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.5.103
0.9.5.104

View File

@@ -56,4 +56,5 @@ def test_all_device():
@pytest.mark.require_distributed(2)
def test_multi_device():
master_port = find_available_port()
mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)
world_size = 2
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)

View File

@@ -14,28 +14,24 @@
import torch
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_engine import ModelEngine
def test_tiny_qwen():
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
model_engine = ModelEngine(model_args)
assert isinstance(model_engine.processor, Qwen2TokenizerFast)
assert isinstance(model_engine.model_config, Qwen2Config)
assert isinstance(model_engine.model, Qwen2ForCausalLM)
assert "Qwen2Tokenizer" in model_engine.processor.__class__.__name__
assert "Qwen3Config" in model_engine.model_config.__class__.__name__
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
assert model_engine.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
from transformers import Qwen2ForCausalLM
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
model="llamafactory/tiny-random-qwen3", kernel_config={"name": "auto", "include_kernels": "auto"}
)
model_engine = ModelEngine(model_args)
# test enable apply kernel plugin
@@ -44,7 +40,7 @@ def test_tiny_qwen_with_kernel_plugin():
else:
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_engine.model, Qwen2ForCausalLM)
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
if __name__ == "__main__":

View File

@@ -0,0 +1,49 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
def test_normal_batching():
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args=data_args)
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
model_engine = ModelEngine(model_args=model_args)
training_args = TrainingArguments(
micro_batch_size=4,
global_batch_size=8,
cutoff_len=10,
batching_workers=0,
batching_strategy="normal",
)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
assert len(batch_generator) == len(data_engine) // training_args.global_batch_size
batch = next(iter(batch_generator))
assert len(batch) == 2
assert batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
test_normal_batching()

View File

@@ -1,171 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
# import torch
# from torch.utils.data import DataLoader as TorchDataLoader
# from torch.utils.data import Dataset
# from transformers import AutoTokenizer
# from llamafactory.v1.config.data_args import DataArguments
# from llamafactory.v1.core.data_engine import DataEngine
# from llamafactory.v1.core.utils.data_collator import DefaultCollator
# from llamafactory.v1.core.utils.data_loader import DataLoader
# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
# class TensorDataset(Dataset):
# """Wrapper dataset that converts DataEngine samples to tensor format."""
# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
# self.data_engine = data_engine
# self.processor = processor
# self.template = template
# self.max_samples = max_samples or len(data_engine)
# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
# def __len__(self):
# return min(self.max_samples, len(self.data_engine))
# def __getitem__(self, idx):
# # Get sample from DataEngine
# sample = self.data_engine[idx]
# # Extract messages from sample
# # DataEngine returns samples with format like {"messages": [...], ...}
# # For llamafactory/v1-sft-demo, the format should have "messages" field
# messages = None
# if "messages" in sample:
# messages = sample["messages"]
# elif "conversations" in sample:
# messages = sample["conversations"]
# elif "conversation" in sample:
# messages = sample["conversation"]
# else:
# # Try to find message-like fields (skip _dataset_name)
# for key, value in sample.items():
# if key.startswith("_"):
# continue
# if isinstance(value, list) and len(value) > 0:
# # Check if it looks like a message list
# if isinstance(value[0], dict) and "role" in value[0]:
# messages = value
# break
# if messages is None:
# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# # Encode messages using template
# encoded = self.template.encode_messages(self.tokenizer, messages)
# # Convert to tensors
# return {
# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
# "labels": torch.tensor(encoded["labels"], dtype=torch.long),
# }
# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
# """Create a real dataset using DataEngine."""
# data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
# data_engine = DataEngine(data_args)
# # Create processor and template
# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
# template = QwenTemplate()
# # Create tensor dataset
# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# # Create torch DataLoader
# torch_dataloader = TorchDataLoader(
# raw_data_dataset,
# batch_size=batch_size,
# shuffle=False,
# collate_fn=lambda x: x,
# )
# return torch_dataloader, processor, template
# class TestDataLoaderNonPackNonDynamic:
# """Test case a) non pack + non dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing and without dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# # Create collator (non-packing)
# collator = DefaultCollator(processor=processor, template=template)
# # Create DataLoader without batching_queue (non-dynamic)
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=1,
# batching_queue=None,
# )
# # Iterate and check results
# batches = list(iter(data_loader))
# assert len(batches) > 0
# # Check first batch
# one_batch = batches[0]
# micro_batches = one_batch[0]
# assert "input_ids" in micro_batches
# assert "attention_mask" in micro_batches
# assert "labels" in micro_batches
# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
# class TestDataLoaderNonPackDynamic:
# """Test case b) non pack + dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing but with dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# collator = DefaultCollator(processor=processor, template=template)
# # Create batching queue for dynamic batching
# batching_queue = TextBatchingQueue(
# token_micro_bsz=120,
# buffer_size=8,
# )
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=4,
# batching_queue=batching_queue,
# )
# # Iterate and check
# batches = list(iter(data_loader))
# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
# assert len(batches) > 0

View File

@@ -184,6 +184,40 @@ def test_qwen3_nothink_rendering_remote(num_samples: int):
assert v1_inputs["input_ids"][: len(prefix)] == prefix
def test_process_sft_samples():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
model_inputs = renderer.process_samples(samples)
assert len(model_inputs) == 1
assert model_inputs[0]["input_ids"] == hf_inputs
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
def test_process_dpo_samples():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
samples = [
{
"chosen_messages": V1_MESSAGES,
"rejected_messages": V1_MESSAGES,
"extra_info": "test",
"_dataset_name": "default",
}
]
model_inputs = renderer.process_samples(samples)
assert len(model_inputs) == 1
assert model_inputs[0]["input_ids"] == hf_inputs * 2
assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
test_chatml_rendering()
test_chatml_parse()
@@ -191,3 +225,5 @@ if __name__ == "__main__":
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)
test_process_sft_samples()
test_process_dpo_samples()

View File

@@ -21,7 +21,7 @@ from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
_, model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen2.5",
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_meta"},
)
)
@@ -32,7 +32,7 @@ def test_init_on_meta():
def test_init_on_rank0():
_, model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen2.5",
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_rank0"},
)
)
@@ -46,7 +46,7 @@ def test_init_on_rank0():
def test_init_on_default():
_, model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen2.5",
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_default"},
)
)

View File

@@ -43,7 +43,7 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
@@ -62,7 +62,7 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward

View File

@@ -1,112 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue
def create_sample(length: int):
"""Helper to create a mock sample with a specific token length."""
return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)}
class TestDynamicBatchSizeBuffer:
def test_append_and_token_count(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(10))
buffer.append(create_sample(20))
assert len(buffer) == 2
assert buffer.total_token_count == 30
def test_get_samples_within_budget(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(10))
buffer.append(create_sample(10))
buffer.append(create_sample(50)) # This one is large
# Request 25 tokens. Should get the first two (20 tokens total)
samples = buffer.get_samples(max_tokens_per_iteration=25)
assert len(samples) == 2
def test_force_return_first_sample(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(100))
# Even though budget is 50, force=True (default) should return the 100-token sample
samples = buffer.get_samples(max_tokens_per_iteration=50, force=True)
assert len(samples) == 1
assert len(samples[0]["input_ids"]) == 100
def test_flush_removes_used_samples(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(10))
buffer.append(create_sample(20))
# Take the first sample
buffer.get_samples(max_tokens_per_iteration=15)
buffer.flush()
assert len(buffer) == 1
assert buffer.total_token_count == 20
# The remaining sample should now be at the start
remaining = buffer.get_samples(max_tokens_per_iteration=50)
assert len(remaining[0]["input_ids"]) == 20
class TestTextBatchingQueue:
def test_is_full_filled(self):
queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2)
queue.put_item(create_sample(10))
assert not queue.is_full_filled() # Only 1 sample, buffer_size=2
queue.put_item(create_sample(10))
assert not queue.is_full_filled() # 2 samples, but only 20 tokens (min 100)
queue.put_item(create_sample(90))
assert queue.is_full_filled() # Meets both conditions
def test_warmup_logic(self):
# token_micro_bsz=1000, starts at 200, reaches 1000 at step 10
queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200)
# Step 0: should be init value
assert queue.get_cur_token_micro_bsz() == 200
# Step 5: halfway through warmup (200 + (800 * 5/10)) = 600
queue._step = 5
assert queue.get_cur_token_micro_bsz() == 600
# Step 11: past warmup
queue._step = 11
assert queue.get_cur_token_micro_bsz() == 1000
def test_get_micro_batch_integration(self):
queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1)
queue.put_item(create_sample(20))
queue.put_item(create_sample(20))
queue.put_item(create_sample(20))
# At step 0 (warmup not triggered as bsz_warmup_steps is -1 default),
# it should take samples up to 50 tokens.
batch = queue.get_micro_batch(step=0)
assert len(batch) == 2
assert queue.empty() is False
batch_2 = queue.get_micro_batch(step=1)
assert len(batch_2) == 1
assert queue.empty() is True