mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[v1] add batch generator (#9744)
This commit is contained in:
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .arg_parser import InputArgument, get_args
|
from .arg_parser import InputArgument, get_args
|
||||||
from .arg_utils import ModelClass, SampleBackend
|
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
|
||||||
from .data_args import DataArguments
|
from .data_args import DataArguments
|
||||||
from .model_args import ModelArguments
|
from .model_args import ModelArguments
|
||||||
from .sample_args import SampleArguments
|
from .sample_args import SampleArguments
|
||||||
@@ -21,6 +21,7 @@ from .training_args import TrainingArguments
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"BatchingStrategy",
|
||||||
"DataArguments",
|
"DataArguments",
|
||||||
"InputArgument",
|
"InputArgument",
|
||||||
"ModelArguments",
|
"ModelArguments",
|
||||||
|
|||||||
@@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
|
|||||||
VLLM = "vllm"
|
VLLM = "vllm"
|
||||||
|
|
||||||
|
|
||||||
|
@unique
|
||||||
|
class BatchingStrategy(StrEnum):
|
||||||
|
NORMAL = "normal"
|
||||||
|
PADDING_FREE = "padding_free"
|
||||||
|
DYNAMIC_BATCHING = "dynamic_batching"
|
||||||
|
DYNAMIC_PADDING_FREE = "dynamic_padding_free"
|
||||||
|
|
||||||
|
|
||||||
def _convert_str_dict(data: dict) -> dict:
|
def _convert_str_dict(data: dict) -> dict:
|
||||||
"""Parse string representation inside the dictionary.
|
"""Parse string representation inside the dictionary.
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,3 @@ class DataArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to the dataset."},
|
metadata={"help": "Path to the dataset."},
|
||||||
)
|
)
|
||||||
cutoff_len: int = field(
|
|
||||||
default=2048,
|
|
||||||
metadata={"help": "Cutoff length for the dataset."},
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import os
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from .arg_utils import PluginConfig, get_plugin_config
|
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -29,18 +29,30 @@ class TrainingArguments:
|
|||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "Micro batch size for training."},
|
metadata={"help": "Micro batch size for training."},
|
||||||
)
|
)
|
||||||
global_batch_size: int = field(
|
global_batch_size: int | None = field(
|
||||||
default=1,
|
default=None,
|
||||||
metadata={"help": "Global batch size for training."},
|
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
|
||||||
)
|
)
|
||||||
learning_rate: float = field(
|
learning_rate: float = field(
|
||||||
default=1e-4,
|
default=1e-4,
|
||||||
metadata={"help": "Learning rate for training."},
|
metadata={"help": "Learning rate for training."},
|
||||||
)
|
)
|
||||||
|
cutoff_len: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "Maximum sequence length for training."},
|
||||||
|
)
|
||||||
bf16: bool = field(
|
bf16: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use bf16 for training."},
|
metadata={"help": "Use bf16 for training."},
|
||||||
)
|
)
|
||||||
|
batching_strategy: BatchingStrategy = field(
|
||||||
|
default=BatchingStrategy.NORMAL,
|
||||||
|
metadata={"help": "Batching strategy for training."},
|
||||||
|
)
|
||||||
|
batching_workers: int = field(
|
||||||
|
default=16,
|
||||||
|
metadata={"help": "Number of workers for batching."},
|
||||||
|
)
|
||||||
dist_config: PluginConfig | None = field(
|
dist_config: PluginConfig | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Distribution configuration for training."},
|
metadata={"help": "Distribution configuration for training."},
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ Train Phase:
|
|||||||
|
|
||||||
from ..config.training_args import TrainingArguments
|
from ..config.training_args import TrainingArguments
|
||||||
from ..utils.types import HFModel, TorchDataset
|
from ..utils.types import HFModel, TorchDataset
|
||||||
from .utils.data_collator import DataCollator
|
|
||||||
from .utils.rendering import Renderer
|
from .utils.rendering import Renderer
|
||||||
|
|
||||||
|
|
||||||
@@ -45,7 +44,6 @@ class BaseTrainer:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.renderer = renderer
|
self.renderer = renderer
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.data_collator = DataCollator()
|
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
|||||||
@@ -82,14 +82,17 @@ class DataEngine(Dataset):
|
|||||||
|
|
||||||
def _load_dataset(self) -> None:
|
def _load_dataset(self) -> None:
|
||||||
"""Load datasets according to dataset info."""
|
"""Load datasets according to dataset info."""
|
||||||
|
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
|
||||||
|
self.streaming = any(is_streaming)
|
||||||
|
if all(is_streaming) != any(is_streaming):
|
||||||
|
raise ValueError("All datasets must be streaming or non-streaming.")
|
||||||
|
|
||||||
for dataset_name, dataset_info in self.dataset_infos.items():
|
for dataset_name, dataset_info in self.dataset_infos.items():
|
||||||
split = dataset_info.get("split", "train")
|
split = dataset_info.get("split", "train")
|
||||||
streaming = dataset_info.get("streaming", False)
|
|
||||||
self.streaming |= streaming
|
|
||||||
if dataset_info.get("source", "hf_hub") == "hf_hub":
|
if dataset_info.get("source", "hf_hub") == "hf_hub":
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
|
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
|
||||||
else: # data loader plugin
|
else: # data loader plugin
|
||||||
from ..plugins.data_plugins.loader import DataLoaderPlugin
|
from ..plugins.data_plugins.loader import DataLoaderPlugin
|
||||||
|
|
||||||
@@ -98,8 +101,7 @@ class DataEngine(Dataset):
|
|||||||
def _build_data_index(self) -> None:
|
def _build_data_index(self) -> None:
|
||||||
"""Build dataset index."""
|
"""Build dataset index."""
|
||||||
for dataset_name, dataset in self.datasets.items():
|
for dataset_name, dataset in self.datasets.items():
|
||||||
streaming = self.dataset_infos[dataset_name].get("streaming", False)
|
if self.streaming:
|
||||||
if streaming:
|
|
||||||
data_index = [(dataset_name, -1) for _ in range(1000)]
|
data_index = [(dataset_name, -1) for _ in range(1000)]
|
||||||
else:
|
else:
|
||||||
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
|
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
|
||||||
@@ -185,8 +187,8 @@ class DataEngine(Dataset):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
|
python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
|
||||||
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
|
python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
|
||||||
"""
|
"""
|
||||||
from ..config.arg_parser import get_args
|
from ..config.arg_parser import get_args
|
||||||
|
|
||||||
|
|||||||
244
src/llamafactory/v1/core/utils/batching.py
Normal file
244
src/llamafactory/v1/core/utils/batching.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Batching utils supports stateful dataloader.
|
||||||
|
|
||||||
|
1. Init stateful dataloader (tokenize)
|
||||||
|
2. Add to buffer
|
||||||
|
3. Yield batch indexes (micro batch * grad acc)
|
||||||
|
a) non pack + non dynamic
|
||||||
|
b) non pack + dynamic
|
||||||
|
c) pack + non dynamic
|
||||||
|
d) pack + dynamic
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from torch.utils.data import default_collate
|
||||||
|
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||||
|
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||||
|
|
||||||
|
from ...accelerator.interface import DistributedInterface
|
||||||
|
from ...config import BatchingStrategy
|
||||||
|
from ...utils import logging
|
||||||
|
from ...utils.helper import pad_and_truncate
|
||||||
|
from ...utils.types import BatchInput, ModelInput, TorchDataset
|
||||||
|
from .rendering import Renderer
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def default_collate_fn(
|
||||||
|
buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
|
||||||
|
) -> tuple[list[ModelInput], int, list[BatchInput]]:
|
||||||
|
batch_size = micro_batch_size * num_micro_batch
|
||||||
|
if len(buffer) < batch_size:
|
||||||
|
return buffer, buffer_tokens, None
|
||||||
|
|
||||||
|
samples = buffer[:batch_size]
|
||||||
|
buffer = buffer[batch_size:]
|
||||||
|
buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)
|
||||||
|
|
||||||
|
batch = []
|
||||||
|
for i in range(num_micro_batch):
|
||||||
|
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
|
||||||
|
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))
|
||||||
|
|
||||||
|
return buffer, buffer_tokens, batch
|
||||||
|
|
||||||
|
|
||||||
|
class BatchGenerator(Iterator):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: TorchDataset,
|
||||||
|
renderer: Renderer,
|
||||||
|
micro_batch_size: int = 1,
|
||||||
|
global_batch_size: int | None = None,
|
||||||
|
cutoff_len: int = 2048,
|
||||||
|
batching_workers: int = 0,
|
||||||
|
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
||||||
|
pin_memory: bool = True,
|
||||||
|
drop_last: bool = True,
|
||||||
|
) -> None:
|
||||||
|
self.dataset = dataset
|
||||||
|
self.renderer = renderer
|
||||||
|
|
||||||
|
self.micro_batch_size = micro_batch_size
|
||||||
|
self.global_batch_size = global_batch_size
|
||||||
|
self.cutoff_len = cutoff_len
|
||||||
|
self.batching_workers = batching_workers
|
||||||
|
self.batching_strategy = batching_strategy
|
||||||
|
self.pin_memory = pin_memory
|
||||||
|
self.drop_last = drop_last
|
||||||
|
# TODO: support length and infinity
|
||||||
|
|
||||||
|
dp_size = DistributedInterface().get_world_size("dp")
|
||||||
|
|
||||||
|
if self.global_batch_size is None:
|
||||||
|
self.global_batch_size = dp_size * micro_batch_size
|
||||||
|
self.num_micro_batch = 1
|
||||||
|
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
|
||||||
|
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Global batch size must be divisible by DP size and micro batch size. "
|
||||||
|
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.drop_last:
|
||||||
|
raise ValueError("Drop last must be True.")
|
||||||
|
|
||||||
|
self._init_data_provider()
|
||||||
|
|
||||||
|
self._is_resuming: bool = False
|
||||||
|
self._data_iter = iter(self._data_provider)
|
||||||
|
self._buffer: list[ModelInput] = []
|
||||||
|
self._buffer_tokens: int = 0
|
||||||
|
self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len
|
||||||
|
|
||||||
|
logger.info_rank0(
|
||||||
|
f"Init unified data loader with global batch size {self.global_batch_size}, "
|
||||||
|
f"micro batch size {self.micro_batch_size}, "
|
||||||
|
f"num micro batch {self.num_micro_batch}, "
|
||||||
|
f"cutoff len {self.cutoff_len}, "
|
||||||
|
f"batching workers {self.batching_workers}, "
|
||||||
|
f"batching strategy {self.batching_strategy}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_data_provider(self) -> None:
|
||||||
|
if len(self.dataset) != -1:
|
||||||
|
sampler = StatefulDistributedSampler(
|
||||||
|
self.dataset,
|
||||||
|
num_replicas=DistributedInterface().get_world_size("dp"),
|
||||||
|
rank=DistributedInterface().get_rank("dp"),
|
||||||
|
shuffle=True,
|
||||||
|
seed=0,
|
||||||
|
drop_last=self.drop_last,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||||
|
|
||||||
|
self._data_provider = StatefulDataLoader(
|
||||||
|
self.dataset,
|
||||||
|
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=self.batching_workers,
|
||||||
|
collate_fn=self.renderer.process_samples,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
drop_last=self.drop_last,
|
||||||
|
)
|
||||||
|
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||||
|
self._length = len(self._data_provider)
|
||||||
|
else:
|
||||||
|
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||||
|
|
||||||
|
self._length = BatchingPlugin(self.batching_strategy).compute_length()
|
||||||
|
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if not self._is_resuming:
|
||||||
|
self._buffer.clear()
|
||||||
|
self._buffer_tokens = 0
|
||||||
|
|
||||||
|
self._data_iter = iter(self._data_provider)
|
||||||
|
self._is_resuming = False
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
batch = self._next_batch()
|
||||||
|
if batch is None:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def _next_batch(self) -> list[BatchInput] | None:
|
||||||
|
while self._buffer_tokens < self._max_buffer_tokens:
|
||||||
|
try:
|
||||||
|
samples: list[ModelInput] = next(self._data_iter)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
|
||||||
|
self._buffer.extend(samples)
|
||||||
|
self._buffer_tokens += num_tokens
|
||||||
|
|
||||||
|
return self._build_batch()
|
||||||
|
|
||||||
|
def _build_batch(self) -> list[BatchInput] | None:
|
||||||
|
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||||
|
self._buffer, self._buffer_tokens, batch = default_collate_fn(
|
||||||
|
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
|
||||||
|
)
|
||||||
|
return batch
|
||||||
|
else:
|
||||||
|
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||||
|
|
||||||
|
self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
|
||||||
|
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
|
||||||
|
)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"buffer": self._buffer,
|
||||||
|
"buffer_tokens": self._buffer_tokens,
|
||||||
|
"data_provider": self._data_provider.state_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||||
|
self._buffer = state["buffer"]
|
||||||
|
self._buffer_tokens = state["buffer_tokens"]
|
||||||
|
self._data_provider.load_state_dict(state["data_provider"])
|
||||||
|
self._is_resuming = True
|
||||||
|
|
||||||
|
def set_epoch(self, epoch: int) -> None:
|
||||||
|
if hasattr(self._data_provider.sampler, "set_epoch"):
|
||||||
|
self._data_provider.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
python -m llamafactory.v1.core.utils.batching \
|
||||||
|
--model llamafactory/tiny-random-qwen2.5 \
|
||||||
|
--dataset data/v1_sft_demo.yaml \
|
||||||
|
--micro_batch_size 2 \
|
||||||
|
--global_batch_size 4 \
|
||||||
|
--batching_workers 0
|
||||||
|
"""
|
||||||
|
from ...config.arg_parser import get_args
|
||||||
|
from ..data_engine import DataEngine
|
||||||
|
from ..model_engine import ModelEngine
|
||||||
|
|
||||||
|
data_args, model_args, training_args, _ = get_args()
|
||||||
|
data_engine = DataEngine(data_args=data_args)
|
||||||
|
model_engine = ModelEngine(model_args=model_args)
|
||||||
|
batch_generator = BatchGenerator(
|
||||||
|
data_engine,
|
||||||
|
model_engine.renderer,
|
||||||
|
micro_batch_size=training_args.micro_batch_size,
|
||||||
|
global_batch_size=training_args.global_batch_size,
|
||||||
|
cutoff_len=training_args.cutoff_len,
|
||||||
|
batching_workers=training_args.batching_workers,
|
||||||
|
batching_strategy=training_args.batching_strategy,
|
||||||
|
)
|
||||||
|
for batch in batch_generator:
|
||||||
|
print(batch)
|
||||||
|
print(len(batch))
|
||||||
|
print(batch[0]["input_ids"].shape)
|
||||||
|
break
|
||||||
@@ -1,277 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import sys
|
|
||||||
from collections.abc import Generator, Iterator
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
|
||||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
|
||||||
|
|
||||||
from ...utils.batching_queue import BaseBatchingQueue
|
|
||||||
from ...utils.logging import get_logger
|
|
||||||
from ...utils.types import Processor, TorchDataset
|
|
||||||
from .data_collator import DataCollator
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# base dataloader
|
|
||||||
class DistributedDataloader(StatefulDataLoader):
|
|
||||||
"""Base Distributed DataLoader."""
|
|
||||||
|
|
||||||
dataset: "TorchDataset"
|
|
||||||
sampler: "StatefulDistributedSampler"
|
|
||||||
|
|
||||||
def set_epoch(self, epoch: int) -> None:
|
|
||||||
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
|
|
||||||
self.sampler.set_epoch(epoch)
|
|
||||||
elif hasattr(self.dataset, "set_epoch"):
|
|
||||||
self.dataset.set_epoch(epoch)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseDataLoader:
|
|
||||||
"""Default DataLoader."""
|
|
||||||
|
|
||||||
processor: Processor
|
|
||||||
|
|
||||||
def __init__(self, dataset: TorchDataset) -> None:
|
|
||||||
self.dataset = dataset
|
|
||||||
# guidlines: fetch until get fixed batchsize.
|
|
||||||
# save state_dict for buffer.
|
|
||||||
# resume with state
|
|
||||||
|
|
||||||
# 1. Init stateful dataloader (tokenize)
|
|
||||||
# 2. Add to buffer (2 * max seq len per device)
|
|
||||||
# 3. Yield batch indexes (micro batch * grad acc)
|
|
||||||
# a ) non pack + non dynamic
|
|
||||||
# b ) non pack + dynamic
|
|
||||||
# c ) pack + non dynamic
|
|
||||||
# d ) pack + dynamic
|
|
||||||
|
|
||||||
def init_dataloader(self) -> None:
|
|
||||||
### init dataloader
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __next__(self) -> any:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DataLoader:
|
|
||||||
"""Default DataLoader."""
|
|
||||||
|
|
||||||
processor: "Processor"
|
|
||||||
dataloader: "DistributedDataloader"
|
|
||||||
batching_queue: "BaseBatchingQueue"
|
|
||||||
collate_fn: "DataCollator"
|
|
||||||
num_micro_batch: int = 1
|
|
||||||
length: int = 0
|
|
||||||
drop_last: bool = True
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dataloader: any,
|
|
||||||
collate_fn: "DataCollator",
|
|
||||||
num_micro_batch: int = 1,
|
|
||||||
length: int = 0,
|
|
||||||
drop_last: bool = True,
|
|
||||||
batching_queue: Optional["BaseBatchingQueue"] = None,
|
|
||||||
) -> None:
|
|
||||||
self.batching_queue = batching_queue
|
|
||||||
self.num_micro_batch = num_micro_batch
|
|
||||||
self.step = 0
|
|
||||||
self._collate_fn = collate_fn
|
|
||||||
self._dataloader = dataloader
|
|
||||||
self._drop_last = drop_last
|
|
||||||
self._data_iter: Iterator
|
|
||||||
self._resume = False
|
|
||||||
self._batch_data_iter: Generator
|
|
||||||
|
|
||||||
if length > 0:
|
|
||||||
self._length = length
|
|
||||||
elif length == -1:
|
|
||||||
self._length = sys.maxsize
|
|
||||||
else:
|
|
||||||
self._length = len(self._dataloader)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self._length
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
|
||||||
if not self._resume:
|
|
||||||
self.step = 0
|
|
||||||
self._data_iter = iter(self._dataloader)
|
|
||||||
self._batch_data_iter = self.batch_data_generator()
|
|
||||||
self._resume = False
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
|
|
||||||
|
|
||||||
def origin_batch_data_generator(self):
|
|
||||||
"""Standard pass-through generator if do not use batching queue."""
|
|
||||||
while True:
|
|
||||||
if self._length > 0 and self.step >= self._length:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
batch = []
|
|
||||||
data = next(self._data_iter)
|
|
||||||
# split data into micro batches
|
|
||||||
for i in range(0, len(data), self.num_micro_batch):
|
|
||||||
micro_batch = data[i : i + self.num_micro_batch]
|
|
||||||
if self._collate_fn:
|
|
||||||
micro_batch = self._collate_fn(micro_batch)
|
|
||||||
batch.append(micro_batch)
|
|
||||||
yield batch
|
|
||||||
self.step += 1
|
|
||||||
except StopIteration:
|
|
||||||
if self.step < self._length:
|
|
||||||
# Restart iterator to fill the requested length
|
|
||||||
self._data_iter = iter(self._dataloader)
|
|
||||||
try:
|
|
||||||
batch = []
|
|
||||||
data = next(self._data_iter)
|
|
||||||
for i in range(0, len(data), self.num_micro_batch):
|
|
||||||
micro_batch = data[i : i + self.num_micro_batch]
|
|
||||||
if self._collate_fn:
|
|
||||||
micro_batch = self._collate_fn(micro_batch)
|
|
||||||
batch.append(micro_batch)
|
|
||||||
yield batch
|
|
||||||
self.step += 1
|
|
||||||
except StopIteration:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def batch_data_generator(self):
|
|
||||||
if self.batching_queue is None:
|
|
||||||
yield from self.origin_batch_data_generator()
|
|
||||||
return
|
|
||||||
|
|
||||||
batch = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if self._length and self.step >= self._length:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.batching_queue.is_full_filled():
|
|
||||||
micro_batch = self.batching_queue.get_micro_batch(self.step)
|
|
||||||
if self._collate_fn:
|
|
||||||
micro_batch = self._collate_fn(micro_batch)
|
|
||||||
batch.append(micro_batch)
|
|
||||||
if len(batch) == self.num_micro_batch:
|
|
||||||
yield batch
|
|
||||||
self.step += 1
|
|
||||||
batch = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
processing_item = next(self._data_iter)
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, StopIteration):
|
|
||||||
if self.step < self._length:
|
|
||||||
# call iter until reach length
|
|
||||||
self._data_iter = iter(self._dataloader)
|
|
||||||
processing_item = next(self._data_iter)
|
|
||||||
elif not self._drop_last and not self.batching_queue.empty():
|
|
||||||
while not self.batching_queue.empty():
|
|
||||||
micro_batch = self.batching_queue.get_micro_batch(self.step)
|
|
||||||
if self._collate_fn:
|
|
||||||
micro_batch = self._collate_fn(micro_batch)
|
|
||||||
batch.append(micro_batch)
|
|
||||||
if len(batch) == self.num_micro_batch:
|
|
||||||
yield batch
|
|
||||||
self.step += 1
|
|
||||||
batch = []
|
|
||||||
|
|
||||||
while len(batch) < self.num_micro_batch:
|
|
||||||
padding_batch = copy.deepcopy(micro_batch)
|
|
||||||
padding_batch["is_padded"] = True
|
|
||||||
batch.append(padding_batch)
|
|
||||||
yield batch
|
|
||||||
self.step += 1
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
logger.error(f"DataLoader iter data exception: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# put processing_item to buffer
|
|
||||||
if isinstance(processing_item, dict):
|
|
||||||
processing_item = [processing_item]
|
|
||||||
|
|
||||||
for item in processing_item:
|
|
||||||
self.batching_queue.put_item(item)
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
# save state
|
|
||||||
state = self.__dict__.copy()
|
|
||||||
# remove internal fields
|
|
||||||
for k in list(state.keys()):
|
|
||||||
if k.startswith("_"):
|
|
||||||
del state[k]
|
|
||||||
|
|
||||||
# save dataloader state
|
|
||||||
if hasattr(self._dataloader, "state_dict"):
|
|
||||||
state["dataloader_state"] = self._dataloader.state_dict()
|
|
||||||
elif hasattr(self._dataloader, "__getstate__"):
|
|
||||||
state["dataloader_state"] = self._dataloader.__getstate__()
|
|
||||||
|
|
||||||
batching_strategy = getattr(self, "batching_strategy", None)
|
|
||||||
if batching_strategy and hasattr(batching_strategy, "state_dict"):
|
|
||||||
state["batching_strategy_state"] = batching_strategy.state_dict()
|
|
||||||
if "batching_strategy" in state:
|
|
||||||
del state["batching_strategy"]
|
|
||||||
|
|
||||||
return copy.deepcopy(state)
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, any]):
|
|
||||||
if state["num_micro_batch"] != self.num_micro_batch:
|
|
||||||
logger.warning(
|
|
||||||
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
|
|
||||||
)
|
|
||||||
del state["num_micro_batch"]
|
|
||||||
self.__dict__.update(state)
|
|
||||||
self._resume = True
|
|
||||||
|
|
||||||
if hasattr(self._dataloader, "load_state_dict"):
|
|
||||||
self._dataloader.load_state_dict(state["dataloader_state"])
|
|
||||||
elif hasattr(self._dataloader, "__getstate__"):
|
|
||||||
self._dataloader.__setstate__(state["dataloader_state"])
|
|
||||||
|
|
||||||
if "batching_strategy_state" in state:
|
|
||||||
batching_strategy = getattr(self, "batching_strategy", None)
|
|
||||||
if batching_strategy:
|
|
||||||
batching_strategy.load_state_dict(state["batching_strategy_state"])
|
|
||||||
del state["batching_strategy_state"]
|
|
||||||
|
|
||||||
self._data_iter = iter(self._dataloader)
|
|
||||||
self._batch_data_iter = self.batch_data_generator()
|
|
||||||
|
|
||||||
def set_epoch(self, epoch: int) -> None:
|
|
||||||
if hasattr(self._dataloader, "set_epoch"):
|
|
||||||
self._dataloader.set_epoch(epoch)
|
|
||||||
@@ -12,10 +12,20 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Rendering utils.
|
||||||
|
|
||||||
|
How to use:
|
||||||
|
renderer = Renderer(template, processor)
|
||||||
|
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
|
||||||
|
renderer.parse_message(text: str) -> Message
|
||||||
|
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...utils.constants import IGNORE_INDEX
|
from ...utils.constants import IGNORE_INDEX
|
||||||
from ...utils.helper import get_tokenizer
|
from ...utils.helper import get_tokenizer
|
||||||
from ...utils.types import Message, ModelInput, Processor
|
from ...utils.types import Message, ModelInput, Processor, Sample
|
||||||
|
|
||||||
|
|
||||||
def render_chatml_messages(
|
def render_chatml_messages(
|
||||||
@@ -64,7 +74,7 @@ def render_chatml_messages(
|
|||||||
|
|
||||||
|
|
||||||
def parse_chatml_message(generated_text: str) -> Message:
|
def parse_chatml_message(generated_text: str) -> Message:
|
||||||
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
|
"""Parse a message in ChatML format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
generated_text (str): The generated text in ChatML format.
|
generated_text (str): The generated text in ChatML format.
|
||||||
@@ -83,6 +93,16 @@ class Renderer:
|
|||||||
def render_messages(
|
def render_messages(
|
||||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
|
"""Apply template to messages and convert them to model input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list[Message]): The messages to render.
|
||||||
|
tools (str | None, optional): The tools to use. Defaults to None.
|
||||||
|
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInput: The rendered model input.
|
||||||
|
"""
|
||||||
if self.template == "chatml":
|
if self.template == "chatml":
|
||||||
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
||||||
else:
|
else:
|
||||||
@@ -91,9 +111,59 @@ class Renderer:
|
|||||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||||
|
|
||||||
def parse_message(self, generated_text: str) -> Message:
|
def parse_message(self, generated_text: str) -> Message:
|
||||||
|
"""Parse a message in the template format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_text (str): The generated text in the template format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message: The parsed message.
|
||||||
|
"""
|
||||||
if self.template == "chatml":
|
if self.template == "chatml":
|
||||||
return parse_chatml_message(generated_text)
|
return parse_chatml_message(generated_text)
|
||||||
else:
|
else:
|
||||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||||
|
|
||||||
return RenderingPlugin(self.template).parse_message(generated_text)
|
return RenderingPlugin(self.template).parse_message(generated_text)
|
||||||
|
|
||||||
|
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
|
||||||
|
"""Process samples to model input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples (list[Sample]): The samples to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[ModelInput]: The processed model inputs.
|
||||||
|
"""
|
||||||
|
model_inputs = []
|
||||||
|
for sample in samples:
|
||||||
|
if "messages" in sample:
|
||||||
|
model_input = self.render_messages(sample["messages"], sample.get("tools"))
|
||||||
|
elif "chosen_messages" in sample and "rejected_messages" in sample:
|
||||||
|
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
|
||||||
|
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
|
||||||
|
chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
|
||||||
|
rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
|
||||||
|
model_input = ModelInput(
|
||||||
|
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
|
||||||
|
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
|
||||||
|
labels=chosen_input["labels"] + rejected_input["labels"],
|
||||||
|
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
|
||||||
|
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
|
||||||
|
)
|
||||||
|
if "position_ids" in chosen_input:
|
||||||
|
model_input["position_ids"] = np.concatenate(
|
||||||
|
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
|
||||||
|
|
||||||
|
if "extra_info" in sample:
|
||||||
|
model_input["extra_info"] = sample["extra_info"]
|
||||||
|
|
||||||
|
if "_dataset_name" in sample:
|
||||||
|
model_input["_dataset_name"] = sample["_dataset_name"]
|
||||||
|
|
||||||
|
model_inputs.append(model_input)
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
|
|||||||
|
|
||||||
|
|
||||||
SharegptMessage = TypedDict(
|
SharegptMessage = TypedDict(
|
||||||
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
|
"SharegptMessage",
|
||||||
|
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -118,15 +119,8 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
|||||||
"observation": "tool",
|
"observation": "tool",
|
||||||
"function_call": "assistant",
|
"function_call": "assistant",
|
||||||
}
|
}
|
||||||
|
sample = {}
|
||||||
messages = []
|
messages = []
|
||||||
tools = raw_sample.get("tools")
|
|
||||||
if tools:
|
|
||||||
try:
|
|
||||||
tools: list[dict[str, Any]] = json.loads(tools)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
|
||||||
tools = []
|
|
||||||
|
|
||||||
for message in raw_sample.get("conversations", []):
|
for message in raw_sample.get("conversations", []):
|
||||||
tag = message["from"]
|
tag = message["from"]
|
||||||
if tag not in tag_mapping:
|
if tag not in tag_mapping:
|
||||||
@@ -157,10 +151,17 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sample["messages"] = messages
|
||||||
|
|
||||||
|
tools = raw_sample.get("tools")
|
||||||
if tools:
|
if tools:
|
||||||
return {"messages": messages, "tools": json.dumps(tools)}
|
try:
|
||||||
else:
|
tools: list[dict[str, Any]] = json.loads(tools)
|
||||||
return {"messages": messages}
|
sample["tools"] = json.dumps(tools)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
@DataConverterPlugin("pair").register()
|
@DataConverterPlugin("pair").register()
|
||||||
@@ -179,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
|
|||||||
def process_message(raw_messages: list[OpenaiMessage]):
|
def process_message(raw_messages: list[OpenaiMessage]):
|
||||||
messages = []
|
messages = []
|
||||||
for message in raw_messages:
|
for message in raw_messages:
|
||||||
messages.append(
|
if message["role"] == "tool":
|
||||||
{
|
try:
|
||||||
"role": message["role"],
|
tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
|
||||||
"content": [{"type": "text", "value": message["content"]}],
|
except json.JSONDecodeError:
|
||||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
|
||||||
}
|
continue
|
||||||
)
|
|
||||||
|
if not isinstance(tool_calls, list):
|
||||||
|
tool_calls = [tool_calls]
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": message["role"],
|
||||||
|
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||||
|
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": message["role"],
|
||||||
|
"content": [{"type": "text", "value": message["content"]}],
|
||||||
|
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
chosen_messages = process_message(raw_sample.get("chosen", []))
|
sample = {}
|
||||||
rejected_messages = process_message(raw_sample.get("rejected", []))
|
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
|
||||||
|
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
|
||||||
|
|
||||||
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
|
tools = raw_sample.get("tools")
|
||||||
|
if tools:
|
||||||
|
try:
|
||||||
|
tools: list[dict[str, Any]] = json.loads(tools)
|
||||||
|
sample["tools"] = json.dumps(tools)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -51,12 +50,16 @@ def _update_model_input(
|
|||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||||
def render_qwen_messages(
|
def render_qwen3_nothink_messages(
|
||||||
processor: Processor,
|
processor: Processor,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
tools: str | None = None,
|
tools: str | None = None,
|
||||||
is_generate: bool = False,
|
is_generate: bool = False,
|
||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
|
"""Render messages in the Qwen3 nothink template format.
|
||||||
|
|
||||||
|
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
|
||||||
|
"""
|
||||||
input_ids, labels, loss_weights = [], [], []
|
input_ids, labels, loss_weights = [], [], []
|
||||||
temp_str, temp_weight = "", 0.0
|
temp_str, temp_weight = "", 0.0
|
||||||
if tools:
|
if tools:
|
||||||
@@ -179,7 +182,15 @@ def render_qwen_messages(
|
|||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||||
def parse_qwen_message(generated_text: str) -> Message:
|
def parse_qwen3_nothink_message(generated_text: str) -> Message:
|
||||||
|
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_text (str): The generated text in the Qwen3 nothink template format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message: The parsed message.
|
||||||
|
"""
|
||||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||||
content = []
|
content = []
|
||||||
last_end = 0
|
last_end = 0
|
||||||
|
|||||||
19
src/llamafactory/v1/plugins/trainer_plugins/batching.py
Normal file
19
src/llamafactory/v1/plugins/trainer_plugins/batching.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ...utils.plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class BatchingPlugin(BasePlugin):
|
||||||
|
pass
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
|
|
||||||
from ..accelerator.interface import DistributedInterface
|
from ..accelerator.interface import DistributedInterface
|
||||||
from ..config.arg_parser import get_args
|
from ..config import InputArgument, get_args
|
||||||
from ..core.base_trainer import BaseTrainer
|
from ..core.base_trainer import BaseTrainer
|
||||||
from ..core.data_engine import DataEngine
|
from ..core.data_engine import DataEngine
|
||||||
from ..core.model_engine import ModelEngine
|
from ..core.model_engine import ModelEngine
|
||||||
@@ -24,15 +24,15 @@ class SFTTrainer(BaseTrainer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def run_sft(user_args):
|
def run_sft(args: InputArgument = None):
|
||||||
model_args, data_args, training_args, _ = get_args(user_args)
|
model_args, data_args, training_args, _ = get_args(args)
|
||||||
DistributedInterface(training_args.dist_config)
|
DistributedInterface(training_args.dist_config)
|
||||||
data_engine = DataEngine(data_args)
|
data_engine = DataEngine(data_args)
|
||||||
model_engine = ModelEngine(model_args)
|
model_engine = ModelEngine(model_args)
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
args=training_args,
|
args=training_args,
|
||||||
model=model_engine.model,
|
model=model_engine.model,
|
||||||
processor=model_engine.processor,
|
renderer=model_engine.renderer,
|
||||||
dataset=data_engine,
|
dataset=data_engine,
|
||||||
)
|
)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|||||||
@@ -1,220 +0,0 @@
|
|||||||
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
|
|
||||||
#
|
|
||||||
# This code is inspired by the Bytedance's VeOmni library.
|
|
||||||
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicBatchSizeBuffer:
|
|
||||||
"""A buffer to store samples for dynamic batch size."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._buffer: list[dict[str, any]] = []
|
|
||||||
self._buffer_sample_lengths: list[int] = []
|
|
||||||
self._deleted_indices: set[int] = set()
|
|
||||||
self._current_index: int = 0
|
|
||||||
self._total_token_count: int = 0
|
|
||||||
|
|
||||||
def append(self, item: dict[str, any]) -> None:
|
|
||||||
"""Append a sample to the buffer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
item: A sample to append to the buffer.
|
|
||||||
The sample should be a dict with the following keys:
|
|
||||||
- input_ids: torch.Tensor of shape (seq_len, )
|
|
||||||
- attention_mask: torch.Tensor of shape (seq_len, )
|
|
||||||
"""
|
|
||||||
self._buffer.append(item)
|
|
||||||
sample_length = int(item["attention_mask"].sum().item())
|
|
||||||
self._buffer_sample_lengths.append(sample_length)
|
|
||||||
self._total_token_count += sample_length
|
|
||||||
|
|
||||||
def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]:
|
|
||||||
"""Get samples from the buffer that fit within the token budget.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_tokens_per_iteration: Maximum number of tokens to retrieve.
|
|
||||||
force: If True, the first available sample will be returned even
|
|
||||||
if it exceeds the token budget.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of samples that fit within the token budget.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AssertionError: If no samples are found (should not happen in normal operation).
|
|
||||||
"""
|
|
||||||
cum_seq_len = 0
|
|
||||||
samples = []
|
|
||||||
|
|
||||||
while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
|
|
||||||
if self._current_index in self._deleted_indices:
|
|
||||||
self._current_index += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
seq_len = self._buffer_sample_lengths[self._current_index]
|
|
||||||
remaining_tokens = max_tokens_per_iteration - cum_seq_len
|
|
||||||
|
|
||||||
# Check if we can add this sample
|
|
||||||
can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
|
|
||||||
|
|
||||||
if can_add:
|
|
||||||
cum_seq_len += seq_len
|
|
||||||
samples.append(self._buffer[self._current_index])
|
|
||||||
self._deleted_indices.add(self._current_index)
|
|
||||||
|
|
||||||
self._current_index += 1
|
|
||||||
|
|
||||||
assert len(samples) > 0, "No samples found in buffer"
|
|
||||||
return samples
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Return the number of samples in the buffer."""
|
|
||||||
return len(self._buffer)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def total_token_count(self) -> int:
|
|
||||||
"""Return the total number of tokens in the buffer."""
|
|
||||||
return self._total_token_count
|
|
||||||
|
|
||||||
def flush(self) -> None:
|
|
||||||
tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
|
|
||||||
self._total_token_count -= tokens_to_remove
|
|
||||||
|
|
||||||
buffer_length = len(self._buffer)
|
|
||||||
self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
|
|
||||||
self._buffer_sample_lengths = [
|
|
||||||
self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
|
|
||||||
]
|
|
||||||
|
|
||||||
self._current_index = 0
|
|
||||||
self._deleted_indices.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBatchingQueue(ABC):
|
|
||||||
"""Base class for batching queue."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def is_full_filled(self) -> bool:
|
|
||||||
raise NotImplementedError("Subclasses must implement `is_full_filled`")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def put_item(self, item: dict[str, any]) -> None:
|
|
||||||
raise NotImplementedError("Subclasses must implement `put_item`")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_micro_batch(self, step: int) -> list[dict[str, any]]:
|
|
||||||
raise NotImplementedError("Subclasses must implement `get_micro_batch`")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def empty(self) -> bool:
|
|
||||||
raise NotImplementedError("Subclasses must implement `empty`")
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityPacker:
|
|
||||||
def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
|
|
||||||
self.token_micro_bsz = token_micro_bsz
|
|
||||||
self.bsz_warmup_steps = bsz_warmup_steps
|
|
||||||
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
|
|
||||||
|
|
||||||
def __call__(self, samples):
|
|
||||||
return samples
|
|
||||||
|
|
||||||
def get_token_num_to_request(self, cur_step, warmup):
|
|
||||||
return (
|
|
||||||
(self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
|
|
||||||
+ self.bsz_warmup_init_mbtoken
|
|
||||||
if warmup
|
|
||||||
else self.token_micro_bsz
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TextBatchingQueue(BaseBatchingQueue):
|
|
||||||
"""Batching text queue for text data."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
token_micro_bsz,
|
|
||||||
buffer_size: int = 500,
|
|
||||||
bsz_warmup_steps: int = -1,
|
|
||||||
bsz_warmup_init_mbtoken: int = 200,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._step = 0
|
|
||||||
self.token_micro_bsz = token_micro_bsz
|
|
||||||
self.bsz_warmup_steps = bsz_warmup_steps
|
|
||||||
self.buffer_size = buffer_size # minimum samples in buffer
|
|
||||||
self.buffer = DynamicBatchSizeBuffer()
|
|
||||||
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args
|
|
||||||
assert self.bsz_warmup_init_mbtoken >= 0
|
|
||||||
|
|
||||||
self.packer = IdentityPacker(
|
|
||||||
token_micro_bsz=token_micro_bsz,
|
|
||||||
bsz_warmup_steps=bsz_warmup_steps,
|
|
||||||
bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_full_filled(self) -> bool:
|
|
||||||
return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz
|
|
||||||
|
|
||||||
def put_item(self, item: dict[str, any]):
|
|
||||||
if len(item["input_ids"]) == 1:
|
|
||||||
print("WARNING: EMPTY STRING.")
|
|
||||||
return
|
|
||||||
self.buffer.append(item)
|
|
||||||
|
|
||||||
def get_token_num_to_request(self):
|
|
||||||
if self.packer is not None:
|
|
||||||
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
|
|
||||||
return self.packer.get_token_num_to_request(self._step, warmup=warmup)
|
|
||||||
else:
|
|
||||||
return self.get_cur_token_micro_bsz()
|
|
||||||
|
|
||||||
def get_cur_token_micro_bsz(self):
|
|
||||||
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
|
|
||||||
if warmup:
|
|
||||||
return (
|
|
||||||
self.token_micro_bsz - self.bsz_warmup_init_mbtoken
|
|
||||||
) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
|
|
||||||
else:
|
|
||||||
return self.token_micro_bsz
|
|
||||||
|
|
||||||
def get_micro_batch(self, step) -> any:
|
|
||||||
"""Get a micro batch from the buffer according to the current step.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
step: the current step.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
data: a list of samples.
|
|
||||||
"""
|
|
||||||
self._step = step
|
|
||||||
n_token_per_iter = self.get_token_num_to_request()
|
|
||||||
cur_token_micro_bsz = self.get_cur_token_micro_bsz()
|
|
||||||
assert cur_token_micro_bsz % n_token_per_iter == 0, (
|
|
||||||
"The token num to get for each request should be divisible by token micro bsz."
|
|
||||||
)
|
|
||||||
n_iter = int(cur_token_micro_bsz // n_token_per_iter)
|
|
||||||
data = []
|
|
||||||
for _ in range(n_iter):
|
|
||||||
samples = self.buffer.get_samples(n_token_per_iter)
|
|
||||||
if self.packer:
|
|
||||||
samples = self.packer(samples) # maybe packed into one sample, but wrapped in list.
|
|
||||||
data.extend(samples)
|
|
||||||
self.buffer.flush() # remove the selected samples.
|
|
||||||
return data
|
|
||||||
|
|
||||||
def empty(self) -> bool:
|
|
||||||
return len(self.buffer) == 0
|
|
||||||
@@ -32,8 +32,8 @@ class DtypeRegistry:
|
|||||||
class DtypeInterface:
|
class DtypeInterface:
|
||||||
"""Type of precision used."""
|
"""Type of precision used."""
|
||||||
|
|
||||||
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
|
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface().current_device)
|
||||||
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
|
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface().current_device)
|
||||||
_is_fp32_available = True
|
_is_fp32_available = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -12,9 +12,24 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from .types import Processor
|
from .constants import IGNORE_INDEX
|
||||||
|
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def is_tokenizer(processor: Processor) -> bool:
|
||||||
|
"""Check if processor is tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
processor: Processor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether processor is tokenizer.
|
||||||
|
"""
|
||||||
|
return not hasattr(processor, "tokenizer")
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
|
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
|
||||||
@@ -27,3 +42,34 @@ def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
|
|||||||
Tokenizer.
|
Tokenizer.
|
||||||
"""
|
"""
|
||||||
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||||
|
|
||||||
|
|
||||||
|
def _pad_and_truncate(tensor: Tensor, max_seqlen: int, pad_value: int = 0) -> Tensor:
|
||||||
|
if tensor.shape[-1] >= max_seqlen:
|
||||||
|
return tensor[..., :max_seqlen]
|
||||||
|
|
||||||
|
pad_shape = list(tensor.shape)
|
||||||
|
pad_shape[-1] = max_seqlen - tensor.shape[-1]
|
||||||
|
pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
return torch.cat([tensor, pad_tensor], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
|
||||||
|
max_length = min(max(len(sample["input_ids"]) for sample in samples), max_seqlen)
|
||||||
|
padded_samples = []
|
||||||
|
for sample in samples:
|
||||||
|
padded_sample = {}
|
||||||
|
for key, value in sample.items():
|
||||||
|
if "label" in key:
|
||||||
|
pad_value = IGNORE_INDEX
|
||||||
|
else:
|
||||||
|
pad_value = 0
|
||||||
|
|
||||||
|
if not isinstance(value, str):
|
||||||
|
padded_sample[key] = _pad_and_truncate(torch.tensor(value), max_length, pad_value)
|
||||||
|
else:
|
||||||
|
padded_sample[key] = value
|
||||||
|
|
||||||
|
padded_samples.append(padded_sample)
|
||||||
|
|
||||||
|
return padded_samples
|
||||||
|
@@ -144,3 +144,20 @@ class ModelInput(TypedDict, total=False):
     """Loss weight for each token, default to 1.0."""
     position_ids: NotRequired[list[int] | list[list[int]]]
     """Position ids for the model (optional)."""
+    token_type_ids: NotRequired[list[int]]
+    """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
+
+
+class BatchInput(TypedDict, total=False):
+    input_ids: Tensor
+    """Input ids for the model."""
+    attention_mask: Tensor
+    """Attention mask for the model."""
+    labels: Tensor
+    """Labels for the model."""
+    loss_weights: Tensor
+    """Loss weight for each token, default to 1.0."""
+    position_ids: NotRequired[Tensor]
+    """Position ids for the model (optional)."""
+    token_type_ids: NotRequired[Tensor]
+    """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
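Note: illustrative only, not part of the patch. A `BatchInput` is the batched, tensorized counterpart of `ModelInput`; the literal below (dummy values, assuming `BatchInput` is importable from this types module) shows a micro-batch of two sequences padded to length 5.

```python
import torch

batch: BatchInput = {
    "input_ids": torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]),
    "labels": torch.tensor([[-100, 2, 3, -100, -100], [-100, -100, 6, 7, 8]]),
    "loss_weights": torch.ones(2, 5),
}
assert batch["input_ids"].shape == (2, 5)
```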
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.5.103
+0.9.5.104
@@ -56,4 +56,5 @@ def test_all_device():
 @pytest.mark.require_distributed(2)
 def test_multi_device():
     master_port = find_available_port()
-    mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)
+    world_size = 2
+    mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
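Note: illustrative sketch of the spawn pattern this refactor makes explicit, not the repo's test code. `torch.multiprocessing.spawn` calls the worker as `fn(rank, *args)`, so the same `world_size` has to appear both inside `args` and as `nprocs`; the worker name and port below are placeholders.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, master_port: int) -> None:
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{master_port}",
        rank=rank,
        world_size=world_size,
    )
    tensor = torch.ones(1) * rank
    dist.all_reduce(tensor)  # default op is SUM, so ranks 0..world_size-1 add up
    assert tensor.item() == sum(range(world_size))
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size, 29500), nprocs=world_size)
```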
@@ -14,28 +14,24 @@
 
 import torch
 
-from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
+from llamafactory.v1.config.model_args import ModelArguments
 from llamafactory.v1.core.model_engine import ModelEngine
 
 
 def test_tiny_qwen():
-    from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
-
-    model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
+    model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
     model_engine = ModelEngine(model_args)
-    assert isinstance(model_engine.processor, Qwen2TokenizerFast)
-    assert isinstance(model_engine.model_config, Qwen2Config)
-    assert isinstance(model_engine.model, Qwen2ForCausalLM)
+    assert "Qwen2Tokenizer" in model_engine.processor.__class__.__name__
+    assert "Qwen3Config" in model_engine.model_config.__class__.__name__
+    assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
     assert model_engine.model.dtype == torch.bfloat16
 
 
 def test_tiny_qwen_with_kernel_plugin():
-    from transformers import Qwen2ForCausalLM
-
     from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
 
     model_args = ModelArguments(
-        model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
+        model="llamafactory/tiny-random-qwen3", kernel_config={"name": "auto", "include_kernels": "auto"}
     )
     model_engine = ModelEngine(model_args)
     # test enable apply kernel plugin
@@ -44,7 +40,7 @@ def test_tiny_qwen_with_kernel_plugin():
     else:
         assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
 
-    assert isinstance(model_engine.model, Qwen2ForCausalLM)
+    assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
 
 
 if __name__ == "__main__":
tests_v1/core/utils/test_batching.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArguments
+from llamafactory.v1.core.data_engine import DataEngine
+from llamafactory.v1.core.model_engine import ModelEngine
+from llamafactory.v1.core.utils.batching import BatchGenerator
+
+
+def test_normal_batching():
+    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
+    data_engine = DataEngine(data_args=data_args)
+    model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
+    model_engine = ModelEngine(model_args=model_args)
+    training_args = TrainingArguments(
+        micro_batch_size=4,
+        global_batch_size=8,
+        cutoff_len=10,
+        batching_workers=0,
+        batching_strategy="normal",
+    )
+    batch_generator = BatchGenerator(
+        data_engine,
+        model_engine.renderer,
+        micro_batch_size=training_args.micro_batch_size,
+        global_batch_size=training_args.global_batch_size,
+        cutoff_len=training_args.cutoff_len,
+        batching_workers=training_args.batching_workers,
+        batching_strategy=training_args.batching_strategy,
+    )
+    assert len(batch_generator) == len(data_engine) // training_args.global_batch_size
+    batch = next(iter(batch_generator))
+    assert len(batch) == 2
+    assert batch[0]["input_ids"].shape == (4, 10)
+
+
+if __name__ == "__main__":
+    test_normal_batching()
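Note: illustrative only. The test pins down the `normal` strategy's contract: render the samples, group them into global batches of `global_batch_size`, split each into `micro_batch_size`-sized micro-batches, and pad/truncate to `cutoff_len` (hence two micro-batches of shape (4, 10) above). The sketch below shows that grouping logic with a hypothetical helper name; it is not the real `BatchGenerator`.

```python
from collections.abc import Iterator


def iter_normal_batches(
    samples: list[dict],
    micro_batch_size: int,
    global_batch_size: int,
) -> Iterator[list[list[dict]]]:
    """Yield global batches, each split into fixed-size micro-batches."""
    num_global_batches = len(samples) // global_batch_size  # remainder is dropped
    for i in range(num_global_batches):
        global_batch = samples[i * global_batch_size : (i + 1) * global_batch_size]
        # With global_batch_size=8 and micro_batch_size=4 this yields 2 micro-batches,
        # matching `len(batch) == 2` in the test above. Each micro-batch would then be
        # padded/truncated to cutoff_len and collated into (micro_batch_size, cutoff_len)
        # tensors, e.g. via the pad_and_truncate helper added earlier in this patch.
        yield [
            global_batch[j : j + micro_batch_size]
            for j in range(0, global_batch_size, micro_batch_size)
        ]
```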
@@ -1,171 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
-
-Tests the 4 scenarios:
-a) non pack + non dynamic.
-b) non pack + dynamic.
-c) pack + non dynamic.
-d) pack + dynamic.
-"""
-
-# import torch
-# from torch.utils.data import DataLoader as TorchDataLoader
-# from torch.utils.data import Dataset
-# from transformers import AutoTokenizer
-
-# from llamafactory.v1.config.data_args import DataArguments
-# from llamafactory.v1.core.data_engine import DataEngine
-# from llamafactory.v1.core.utils.data_collator import DefaultCollator
-# from llamafactory.v1.core.utils.data_loader import DataLoader
-# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
-# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
-
-
-# class TensorDataset(Dataset):
-#     """Wrapper dataset that converts DataEngine samples to tensor format."""
-
-#     def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
-#         self.data_engine = data_engine
-#         self.processor = processor
-#         self.template = template
-#         self.max_samples = max_samples or len(data_engine)
-#         self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
-
-#     def __len__(self):
-#         return min(self.max_samples, len(self.data_engine))
-
-#     def __getitem__(self, idx):
-#         # Get sample from DataEngine
-#         sample = self.data_engine[idx]
-
-#         # Extract messages from sample
-#         # DataEngine returns samples with format like {"messages": [...], ...}
-#         # For llamafactory/v1-sft-demo, the format should have "messages" field
-#         messages = None
-#         if "messages" in sample:
-#             messages = sample["messages"]
-#         elif "conversations" in sample:
-#             messages = sample["conversations"]
-#         elif "conversation" in sample:
-#             messages = sample["conversation"]
-#         else:
-#             # Try to find message-like fields (skip _dataset_name)
-#             for key, value in sample.items():
-#                 if key.startswith("_"):
-#                     continue
-#                 if isinstance(value, list) and len(value) > 0:
-#                     # Check if it looks like a message list
-#                     if isinstance(value[0], dict) and "role" in value[0]:
-#                         messages = value
-#                         break
-
-#         if messages is None:
-#             raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
-
-#         # Encode messages using template
-#         encoded = self.template.encode_messages(self.tokenizer, messages)
-
-#         # Convert to tensors
-#         return {
-#             "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
-#             "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
-#             "labels": torch.tensor(encoded["labels"], dtype=torch.long),
-#         }
-
-
-# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
-#     """Create a real dataset using DataEngine."""
-#     data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
-#     data_engine = DataEngine(data_args)
-
-#     # Create processor and template
-#     processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
-#     template = QwenTemplate()
-
-#     # Create tensor dataset
-#     raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
-
-#     # Create torch DataLoader
-#     torch_dataloader = TorchDataLoader(
-#         raw_data_dataset,
-#         batch_size=batch_size,
-#         shuffle=False,
-#         collate_fn=lambda x: x,
-#     )
-
-#     return torch_dataloader, processor, template
-
-
-# class TestDataLoaderNonPackNonDynamic:
-#     """Test case a) non pack + non dynamic."""
-
-#     def test_basic_functionality(self):
-#         """Test DataLoader without packing and without dynamic batching."""
-#         # Create real dataset
-#         torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
-
-#         # Create collator (non-packing)
-#         collator = DefaultCollator(processor=processor, template=template)
-
-#         # Create DataLoader without batching_queue (non-dynamic)
-#         data_loader = DataLoader(
-#             dataloader=torch_dataloader,
-#             collate_fn=collator,
-#             num_micro_batch=1,
-#             batching_queue=None,
-#         )
-
-#         # Iterate and check results
-#         batches = list(iter(data_loader))
-#         assert len(batches) > 0
-
-#         # Check first batch
-#         one_batch = batches[0]
-#         micro_batches = one_batch[0]
-#         assert "input_ids" in micro_batches
-#         assert "attention_mask" in micro_batches
-#         assert "labels" in micro_batches
-#         assert micro_batches["input_ids"].shape[0] == 1  # batch_size=1
-#         assert micro_batches["input_ids"].ndim == 2  # [batch_size, seq_len]
-
-
-# class TestDataLoaderNonPackDynamic:
-#     """Test case b) non pack + dynamic."""
-
-#     def test_basic_functionality(self):
-#         """Test DataLoader without packing but with dynamic batching."""
-#         # Create real dataset
-#         torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
-#         collator = DefaultCollator(processor=processor, template=template)
-
-#         # Create batching queue for dynamic batching
-#         batching_queue = TextBatchingQueue(
-#             token_micro_bsz=120,
-#             buffer_size=8,
-#         )
-
-#         data_loader = DataLoader(
-#             dataloader=torch_dataloader,
-#             collate_fn=collator,
-#             num_micro_batch=4,
-#             batching_queue=batching_queue,
-#         )
-
-#         # Iterate and check
-#         batches = list(iter(data_loader))
-#         micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
-#         assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
-#         assert len(batches) > 0
@@ -184,6 +184,40 @@ def test_qwen3_nothink_rendering_remote(num_samples: int):
     assert v1_inputs["input_ids"][: len(prefix)] == prefix
 
 
+def test_process_sft_samples():
+    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
+    renderer = Renderer(template="chatml", processor=tokenizer)
+    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+
+    samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
+    model_inputs = renderer.process_samples(samples)
+    assert len(model_inputs) == 1
+    assert model_inputs[0]["input_ids"] == hf_inputs
+    assert model_inputs[0]["extra_info"] == "test"
+    assert model_inputs[0]["_dataset_name"] == "default"
+
+
+def test_process_dpo_samples():
+    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
+    renderer = Renderer(template="chatml", processor=tokenizer)
+    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+
+    samples = [
+        {
+            "chosen_messages": V1_MESSAGES,
+            "rejected_messages": V1_MESSAGES,
+            "extra_info": "test",
+            "_dataset_name": "default",
+        }
+    ]
+    model_inputs = renderer.process_samples(samples)
+    assert len(model_inputs) == 1
+    assert model_inputs[0]["input_ids"] == hf_inputs * 2
+    assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
+    assert model_inputs[0]["extra_info"] == "test"
+    assert model_inputs[0]["_dataset_name"] == "default"
+
+
 if __name__ == "__main__":
     test_chatml_rendering()
     test_chatml_parse()
@@ -191,3 +225,5 @@ if __name__ == "__main__":
     test_qwen3_nothink_rendering()
     test_qwen3_nothink_parse()
     test_qwen3_nothink_rendering_remote(16)
+    test_process_sft_samples()
+    test_process_dpo_samples()
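Note: illustrative only. `test_process_dpo_samples` shows that a DPO sample is rendered as the chosen and rejected conversations concatenated back to back, with `token_type_ids` marking the halves (0 = chosen, 1 = rejected). Downstream code can split the packed sequence again, e.g.:

```python
import torch

# Dummy ids standing in for a rendered chosen+rejected pair.
input_ids = torch.tensor([11, 12, 13, 21, 22, 23])
token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1])

chosen_ids = input_ids[token_type_ids == 0]
rejected_ids = input_ids[token_type_ids == 1]
assert chosen_ids.tolist() == [11, 12, 13]
assert rejected_ids.tolist() == [21, 22, 23]
```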
@@ -21,7 +21,7 @@ from llamafactory.v1.core.model_engine import ModelEngine
 def test_init_on_meta():
     _, model_args, *_ = get_args(
         dict(
-            model="llamafactory/tiny-random-qwen2.5",
+            model="llamafactory/tiny-random-qwen3",
             init_config={"name": "init_on_meta"},
         )
     )
@@ -32,7 +32,7 @@ def test_init_on_meta():
 def test_init_on_rank0():
     _, model_args, *_ = get_args(
         dict(
-            model="llamafactory/tiny-random-qwen2.5",
+            model="llamafactory/tiny-random-qwen3",
             init_config={"name": "init_on_rank0"},
         )
     )
@@ -46,7 +46,7 @@ def test_init_on_rank0():
 def test_init_on_default():
     _, model_args, *_ = get_args(
         dict(
-            model="llamafactory/tiny-random-qwen2.5",
+            model="llamafactory/tiny-random-qwen3",
             init_config={"name": "init_on_default"},
         )
     )
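Note: illustrative only, not the plugin's implementation. An "init on meta" style strategy generally allocates parameters on PyTorch's meta device, so no real memory or random initialization happens until weights are materialized or loaded; a minimal sketch of that idea:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("llamafactory/tiny-random-qwen3")
with torch.device("meta"):
    # Parameters are created on the meta device: shapes and dtypes only,
    # no storage is allocated until the weights are loaded later.
    model = AutoModelForCausalLM.from_config(config)

assert all(p.device.type == "meta" for p in model.parameters())
```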
@@ -43,7 +43,7 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
     reload_kernels()
     from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
 
-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
     original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
     original_swiglu_forward = model.model.layers[0].mlp.forward
     model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
@@ -62,7 +62,7 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
     reload_kernels()
     from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
 
-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
 
     original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
     original_swiglu_forward = model.model.layers[0].mlp.forward
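Note: the asserts in these kernel tests check that the plugin swapped a module's bound `forward` for the kernel function (compared via `__code__`). A minimal sketch of that patching pattern, with a toy RMSNorm standing in for the real NPU kernel:

```python
import types

import torch
from torch import nn


def fused_rms_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for an accelerator kernel; plain RMSNorm math for illustration.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return self.weight * hidden_states * torch.rsqrt(variance + 1e-6)


class TinyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + 1e-6)


layer = TinyRMSNorm()
layer.forward = types.MethodType(fused_rms_norm_forward, layer)  # patch in place
assert layer.forward.__code__ == fused_rms_norm_forward.__code__
```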
@@ -1,112 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue
-
-
-def create_sample(length: int):
-    """Helper to create a mock sample with a specific token length."""
-    return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)}
-
-
-class TestDynamicBatchSizeBuffer:
-    def test_append_and_token_count(self):
-        buffer = DynamicBatchSizeBuffer()
-        buffer.append(create_sample(10))
-        buffer.append(create_sample(20))
-
-        assert len(buffer) == 2
-        assert buffer.total_token_count == 30
-
-    def test_get_samples_within_budget(self):
-        buffer = DynamicBatchSizeBuffer()
-        buffer.append(create_sample(10))
-        buffer.append(create_sample(10))
-        buffer.append(create_sample(50))  # This one is large
-
-        # Request 25 tokens. Should get the first two (20 tokens total)
-        samples = buffer.get_samples(max_tokens_per_iteration=25)
-        assert len(samples) == 2
-
-    def test_force_return_first_sample(self):
-        buffer = DynamicBatchSizeBuffer()
-        buffer.append(create_sample(100))
-
-        # Even though budget is 50, force=True (default) should return the 100-token sample
-        samples = buffer.get_samples(max_tokens_per_iteration=50, force=True)
-        assert len(samples) == 1
-        assert len(samples[0]["input_ids"]) == 100
-
-    def test_flush_removes_used_samples(self):
-        buffer = DynamicBatchSizeBuffer()
-        buffer.append(create_sample(10))
-        buffer.append(create_sample(20))
-
-        # Take the first sample
-        buffer.get_samples(max_tokens_per_iteration=15)
-        buffer.flush()
-
-        assert len(buffer) == 1
-        assert buffer.total_token_count == 20
-        # The remaining sample should now be at the start
-        remaining = buffer.get_samples(max_tokens_per_iteration=50)
-        assert len(remaining[0]["input_ids"]) == 20
-
-
-class TestTextBatchingQueue:
-    def test_is_full_filled(self):
-        queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2)
-
-        queue.put_item(create_sample(10))
-        assert not queue.is_full_filled()  # Only 1 sample, buffer_size=2
-
-        queue.put_item(create_sample(10))
-        assert not queue.is_full_filled()  # 2 samples, but only 20 tokens (min 100)
-
-        queue.put_item(create_sample(90))
-        assert queue.is_full_filled()  # Meets both conditions
-
-    def test_warmup_logic(self):
-        # token_micro_bsz=1000, starts at 200, reaches 1000 at step 10
-        queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200)
-
-        # Step 0: should be init value
-        assert queue.get_cur_token_micro_bsz() == 200
-
-        # Step 5: halfway through warmup (200 + (800 * 5/10)) = 600
-        queue._step = 5
-        assert queue.get_cur_token_micro_bsz() == 600
-
-        # Step 11: past warmup
-        queue._step = 11
-        assert queue.get_cur_token_micro_bsz() == 1000
-
-    def test_get_micro_batch_integration(self):
-        queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1)
-        queue.put_item(create_sample(20))
-        queue.put_item(create_sample(20))
-        queue.put_item(create_sample(20))
-
-        # At step 0 (warmup not triggered as bsz_warmup_steps is -1 default),
-        # it should take samples up to 50 tokens.
-        batch = queue.get_micro_batch(step=0)
-
-        assert len(batch) == 2
-        assert queue.empty() is False
-
-        batch_2 = queue.get_micro_batch(step=1)
-        assert len(batch_2) == 1
-        assert queue.empty() is True
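Note: the deleted unit tests above document the token-budget semantics the old queue exposed (take samples until the token budget is hit, force out an oversized first sample, flush what was consumed). The sketch below is illustrative only and matches those assertions; it is not the removed `DynamicBatchSizeBuffer` implementation.

```python
class TokenBudgetBuffer:
    """Hand out samples under a per-iteration token budget; flush commits the take."""

    def __init__(self) -> None:
        self._samples: list[dict] = []
        self._used = 0  # samples handed out by the last get_samples call

    def append(self, sample: dict) -> None:
        self._samples.append(sample)

    def __len__(self) -> int:
        return len(self._samples)

    @property
    def total_token_count(self) -> int:
        return sum(len(s["input_ids"]) for s in self._samples)

    def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict]:
        taken, budget = [], max_tokens_per_iteration
        for sample in self._samples:
            length = len(sample["input_ids"])
            # Always release at least one sample when force=True, even if it
            # alone exceeds the budget (mirrors the force-return test above).
            if length <= budget or (force and not taken):
                taken.append(sample)
                budget -= length
            else:
                break
        self._used = len(taken)
        return taken

    def flush(self) -> None:
        # Drop the samples returned by the last get_samples call.
        self._samples = self._samples[self._used :]
        self._used = 0
```

Keeping `flush` separate from `get_samples` lets a caller peek at the next micro-batch without committing to it, which is the behavior the flush test exercised.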