Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-01-30 06:12:04 +00:00)

[v1] add batch generator (#9744)
@@ -13,7 +13,7 @@
# limitations under the License.

from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments

@@ -21,6 +21,7 @@ from .training_args import TrainingArguments

__all__ = [
    "BatchingStrategy",
    "DataArguments",
    "InputArgument",
    "ModelArguments",
@@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
    VLLM = "vllm"


@unique
class BatchingStrategy(StrEnum):
    NORMAL = "normal"
    PADDING_FREE = "padding_free"
    DYNAMIC_BATCHING = "dynamic_batching"
    DYNAMIC_PADDING_FREE = "dynamic_padding_free"


def _convert_str_dict(data: dict) -> dict:
    """Parse string representation inside the dictionary.
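Since BatchingStrategy derives from StrEnum, its members compare equal to plain strings, which is what lets the new test below pass batching_strategy="normal" straight through. A minimal standalone sketch (illustrative only, not part of this diff):

import enum

@enum.unique
class BatchingStrategy(enum.StrEnum):
    NORMAL = "normal"
    PADDING_FREE = "padding_free"
    DYNAMIC_BATCHING = "dynamic_batching"
    DYNAMIC_PADDING_FREE = "dynamic_padding_free"

# StrEnum members behave like strings, so YAML/CLI values map directly onto them.
assert BatchingStrategy("padding_free") is BatchingStrategy.PADDING_FREE
assert BatchingStrategy.NORMAL == "normal"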
@@ -22,7 +22,3 @@ class DataArguments:
        default=None,
        metadata={"help": "Path to the dataset."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Cutoff length for the dataset."},
    )
@@ -16,7 +16,7 @@ import os
from dataclasses import dataclass, field
from uuid import uuid4

from .arg_utils import PluginConfig, get_plugin_config
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config


@dataclass
@@ -29,18 +29,30 @@ class TrainingArguments:
        default=1,
        metadata={"help": "Micro batch size for training."},
    )
    global_batch_size: int = field(
        default=1,
        metadata={"help": "Global batch size for training."},
    global_batch_size: int | None = field(
        default=None,
        metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
    )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for training."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length for training."},
    )
    bf16: bool = field(
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
    batching_strategy: BatchingStrategy = field(
        default=BatchingStrategy.NORMAL,
        metadata={"help": "Batching strategy for training."},
    )
    batching_workers: int = field(
        default=16,
        metadata={"help": "Number of workers for batching."},
    )
    dist_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Distribution configuration for training."},
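The net effect on configuration: cutoff_len moves from DataArguments to TrainingArguments, batching_strategy and batching_workers are new, and global_batch_size becomes optional. A hedged construction example mirroring the new test added at the bottom of this diff (values are illustrative):

from llamafactory.v1.config import TrainingArguments

training_args = TrainingArguments(
    micro_batch_size=4,
    global_batch_size=8,   # optional; None falls back to DP size * micro batch size
    cutoff_len=2048,
    batching_workers=0,
    batching_strategy="normal",  # a plain string works because BatchingStrategy is a StrEnum
)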
@@ -29,7 +29,6 @@ Train Phase:

from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, TorchDataset
from .utils.data_collator import DataCollator
from .utils.rendering import Renderer


@@ -45,7 +44,6 @@ class BaseTrainer:
        self.model = model
        self.renderer = renderer
        self.dataset = dataset
        self.data_collator = DataCollator()
        self.optimizer = None
        self.lr_scheduler = None
@@ -82,14 +82,17 @@ class DataEngine(Dataset):

    def _load_dataset(self) -> None:
        """Load datasets according to dataset info."""
        is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
        self.streaming = any(is_streaming)
        if all(is_streaming) != any(is_streaming):
            raise ValueError("All datasets must be streaming or non-streaming.")

        for dataset_name, dataset_info in self.dataset_infos.items():
            split = dataset_info.get("split", "train")
            streaming = dataset_info.get("streaming", False)
            self.streaming |= streaming
            if dataset_info.get("source", "hf_hub") == "hf_hub":
                from datasets import load_dataset

                self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
                self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
            else:  # data loader plugin
                from ..plugins.data_plugins.loader import DataLoaderPlugin

@@ -98,8 +101,7 @@ class DataEngine(Dataset):
    def _build_data_index(self) -> None:
        """Build dataset index."""
        for dataset_name, dataset in self.datasets.items():
            streaming = self.dataset_infos[dataset_name].get("streaming", False)
            if streaming:
            if self.streaming:
                data_index = [(dataset_name, -1) for _ in range(1000)]
            else:
                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
@@ -185,8 +187,8 @@ class DataEngine(Dataset):

if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
    python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
    python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
    python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
    """
    from ..config.arg_parser import get_args
src/llamafactory/v1/core/utils/batching.py (new file, 244 lines)
@@ -0,0 +1,244 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Batching utils supports stateful dataloader.
|
||||
|
||||
1. Init stateful dataloader (tokenize)
|
||||
2. Add to buffer
|
||||
3. Yield batch indexes (micro batch * grad acc)
|
||||
a) non pack + non dynamic
|
||||
b) non pack + dynamic
|
||||
c) pack + non dynamic
|
||||
d) pack + dynamic
|
||||
"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.data import default_collate
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...accelerator.interface import DistributedInterface
|
||||
from ...config import BatchingStrategy
|
||||
from ...utils import logging
|
||||
from ...utils.helper import pad_and_truncate
|
||||
from ...utils.types import BatchInput, ModelInput, TorchDataset
|
||||
from .rendering import Renderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def default_collate_fn(
    buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
) -> tuple[list[ModelInput], int, list[BatchInput]]:
    batch_size = micro_batch_size * num_micro_batch
    if len(buffer) < batch_size:
        return buffer, buffer_tokens, None

    samples = buffer[:batch_size]
    buffer = buffer[batch_size:]
    buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)

    batch = []
    for i in range(num_micro_batch):
        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

    return buffer, buffer_tokens, batch
|
||||
|
||||
|
||||
class BatchGenerator(Iterator):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: TorchDataset,
|
||||
renderer: Renderer,
|
||||
micro_batch_size: int = 1,
|
||||
global_batch_size: int | None = None,
|
||||
cutoff_len: int = 2048,
|
||||
batching_workers: int = 0,
|
||||
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
||||
pin_memory: bool = True,
|
||||
drop_last: bool = True,
|
||||
) -> None:
|
||||
self.dataset = dataset
|
||||
self.renderer = renderer
|
||||
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.global_batch_size = global_batch_size
|
||||
self.cutoff_len = cutoff_len
|
||||
self.batching_workers = batching_workers
|
||||
self.batching_strategy = batching_strategy
|
||||
self.pin_memory = pin_memory
|
||||
self.drop_last = drop_last
|
||||
# TODO: support length and infinity
|
||||
|
||||
dp_size = DistributedInterface().get_world_size("dp")
|
||||
|
||||
if self.global_batch_size is None:
|
||||
self.global_batch_size = dp_size * micro_batch_size
|
||||
self.num_micro_batch = 1
|
||||
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
|
||||
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
"Global batch size must be divisible by DP size and micro batch size. "
|
||||
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
|
||||
)
|
||||
|
||||
if not self.drop_last:
|
||||
raise ValueError("Drop last must be True.")
|
||||
|
||||
self._init_data_provider()
|
||||
|
||||
self._is_resuming: bool = False
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._buffer: list[ModelInput] = []
|
||||
self._buffer_tokens: int = 0
|
||||
self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len
|
||||
|
||||
logger.info_rank0(
|
||||
f"Init unified data loader with global batch size {self.global_batch_size}, "
|
||||
f"micro batch size {self.micro_batch_size}, "
|
||||
f"num micro batch {self.num_micro_batch}, "
|
||||
f"cutoff len {self.cutoff_len}, "
|
||||
f"batching workers {self.batching_workers}, "
|
||||
f"batching strategy {self.batching_strategy}."
|
||||
)
|
||||
|
||||
def _init_data_provider(self) -> None:
|
||||
if len(self.dataset) != -1:
|
||||
sampler = StatefulDistributedSampler(
|
||||
self.dataset,
|
||||
num_replicas=DistributedInterface().get_world_size("dp"),
|
||||
rank=DistributedInterface().get_rank("dp"),
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
sampler=sampler,
|
||||
num_workers=self.batching_workers,
|
||||
collate_fn=self.renderer.process_samples,
|
||||
pin_memory=self.pin_memory,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length()
|
||||
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def __iter__(self):
|
||||
if not self._is_resuming:
|
||||
self._buffer.clear()
|
||||
self._buffer_tokens = 0
|
||||
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._is_resuming = False
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
batch = self._next_batch()
|
||||
if batch is None:
|
||||
raise StopIteration
|
||||
|
||||
return batch
|
||||
|
||||
def _next_batch(self) -> list[BatchInput] | None:
|
||||
while self._buffer_tokens < self._max_buffer_tokens:
|
||||
try:
|
||||
samples: list[ModelInput] = next(self._data_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
|
||||
self._buffer.extend(samples)
|
||||
self._buffer_tokens += num_tokens
|
||||
|
||||
return self._build_batch()
|
||||
|
||||
def _build_batch(self) -> list[BatchInput] | None:
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._buffer, self._buffer_tokens, batch = default_collate_fn(
|
||||
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
|
||||
)
|
||||
return batch
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
|
||||
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
|
||||
)
|
||||
return batch
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"buffer": self._buffer,
|
||||
"buffer_tokens": self._buffer_tokens,
|
||||
"data_provider": self._data_provider.state_dict(),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||
self._buffer = state["buffer"]
|
||||
self._buffer_tokens = state["buffer_tokens"]
|
||||
self._data_provider.load_state_dict(state["data_provider"])
|
||||
self._is_resuming = True
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
if hasattr(self._data_provider.sampler, "set_epoch"):
|
||||
self._data_provider.sampler.set_epoch(epoch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.core.utils.batching \
|
||||
--model llamafactory/tiny-random-qwen2.5 \
|
||||
--dataset data/v1_sft_demo.yaml \
|
||||
--micro_batch_size 2 \
|
||||
--global_batch_size 4 \
|
||||
--batching_workers 0
|
||||
"""
|
||||
from ...config.arg_parser import get_args
|
||||
from ..data_engine import DataEngine
|
||||
from ..model_engine import ModelEngine
|
||||
|
||||
data_args, model_args, training_args, _ = get_args()
|
||||
data_engine = DataEngine(data_args=data_args)
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
batch_generator = BatchGenerator(
|
||||
data_engine,
|
||||
model_engine.renderer,
|
||||
micro_batch_size=training_args.micro_batch_size,
|
||||
global_batch_size=training_args.global_batch_size,
|
||||
cutoff_len=training_args.cutoff_len,
|
||||
batching_workers=training_args.batching_workers,
|
||||
batching_strategy=training_args.batching_strategy,
|
||||
)
|
||||
for batch in batch_generator:
|
||||
print(batch)
|
||||
print(len(batch))
|
||||
print(batch[0]["input_ids"].shape)
|
||||
break
|
||||
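The state_dict / load_state_dict / set_epoch methods above are what make the generator resumable mid-epoch. A minimal resume sketch, assuming a batch_generator constructed as in the __main__ block and a checkpoint path of your own choosing:

import torch

# Save the generator state next to the model/optimizer checkpoint.
torch.save({"batch_generator": batch_generator.state_dict()}, "ckpt.pt")

# On restart, restore the sample buffer and the StatefulDataLoader position; the
# next iteration then continues from where it stopped instead of replaying the epoch.
state = torch.load("ckpt.pt", weights_only=False)
batch_generator.load_state_dict(state["batch_generator"])
batch_generator.set_epoch(0)  # keep sampler shuffling consistent across ranks
for batch in batch_generator:
    ...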
@@ -1,277 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from collections.abc import Generator, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...utils.batching_queue import BaseBatchingQueue
|
||||
from ...utils.logging import get_logger
|
||||
from ...utils.types import Processor, TorchDataset
|
||||
from .data_collator import DataCollator
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# base dataloader
|
||||
class DistributedDataloader(StatefulDataLoader):
|
||||
"""Base Distributed DataLoader."""
|
||||
|
||||
dataset: "TorchDataset"
|
||||
sampler: "StatefulDistributedSampler"
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
|
||||
self.sampler.set_epoch(epoch)
|
||||
elif hasattr(self.dataset, "set_epoch"):
|
||||
self.dataset.set_epoch(epoch)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDataLoader:
|
||||
"""Default DataLoader."""
|
||||
|
||||
processor: Processor
|
||||
|
||||
def __init__(self, dataset: TorchDataset) -> None:
|
||||
self.dataset = dataset
|
||||
# guidelines: fetch until we get a fixed batch size.
|
||||
# save state_dict for buffer.
|
||||
# resume with state
|
||||
|
||||
# 1. Init stateful dataloader (tokenize)
|
||||
# 2. Add to buffer (2 * max seq len per device)
|
||||
# 3. Yield batch indexes (micro batch * grad acc)
|
||||
# a ) non pack + non dynamic
|
||||
# b ) non pack + dynamic
|
||||
# c ) pack + non dynamic
|
||||
# d ) pack + dynamic
|
||||
|
||||
def init_dataloader(self) -> None:
|
||||
### init dataloader
|
||||
pass
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
pass
|
||||
|
||||
def __next__(self) -> any:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataLoader:
|
||||
"""Default DataLoader."""
|
||||
|
||||
processor: "Processor"
|
||||
dataloader: "DistributedDataloader"
|
||||
batching_queue: "BaseBatchingQueue"
|
||||
collate_fn: "DataCollator"
|
||||
num_micro_batch: int = 1
|
||||
length: int = 0
|
||||
drop_last: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataloader: any,
|
||||
collate_fn: "DataCollator",
|
||||
num_micro_batch: int = 1,
|
||||
length: int = 0,
|
||||
drop_last: bool = True,
|
||||
batching_queue: Optional["BaseBatchingQueue"] = None,
|
||||
) -> None:
|
||||
self.batching_queue = batching_queue
|
||||
self.num_micro_batch = num_micro_batch
|
||||
self.step = 0
|
||||
self._collate_fn = collate_fn
|
||||
self._dataloader = dataloader
|
||||
self._drop_last = drop_last
|
||||
self._data_iter: Iterator
|
||||
self._resume = False
|
||||
self._batch_data_iter: Generator
|
||||
|
||||
if length > 0:
|
||||
self._length = length
|
||||
elif length == -1:
|
||||
self._length = sys.maxsize
|
||||
else:
|
||||
self._length = len(self._dataloader)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
if not self._resume:
|
||||
self.step = 0
|
||||
self._data_iter = iter(self._dataloader)
|
||||
self._batch_data_iter = self.batch_data_generator()
|
||||
self._resume = False
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
|
||||
|
||||
def origin_batch_data_generator(self):
|
||||
"""Standard pass-through generator if do not use batching queue."""
|
||||
while True:
|
||||
if self._length > 0 and self.step >= self._length:
|
||||
return
|
||||
|
||||
try:
|
||||
batch = []
|
||||
data = next(self._data_iter)
|
||||
# split data into micro batches
|
||||
for i in range(0, len(data), self.num_micro_batch):
|
||||
micro_batch = data[i : i + self.num_micro_batch]
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
yield batch
|
||||
self.step += 1
|
||||
except StopIteration:
|
||||
if self.step < self._length:
|
||||
# Restart iterator to fill the requested length
|
||||
self._data_iter = iter(self._dataloader)
|
||||
try:
|
||||
batch = []
|
||||
data = next(self._data_iter)
|
||||
for i in range(0, len(data), self.num_micro_batch):
|
||||
micro_batch = data[i : i + self.num_micro_batch]
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
yield batch
|
||||
self.step += 1
|
||||
except StopIteration:
|
||||
return
|
||||
else:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
|
||||
raise
|
||||
|
||||
def batch_data_generator(self):
|
||||
if self.batching_queue is None:
|
||||
yield from self.origin_batch_data_generator()
|
||||
return
|
||||
|
||||
batch = []
|
||||
|
||||
while True:
|
||||
if self._length and self.step >= self._length:
|
||||
return
|
||||
|
||||
if self.batching_queue.is_full_filled():
|
||||
micro_batch = self.batching_queue.get_micro_batch(self.step)
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
if len(batch) == self.num_micro_batch:
|
||||
yield batch
|
||||
self.step += 1
|
||||
batch = []
|
||||
|
||||
try:
|
||||
processing_item = next(self._data_iter)
|
||||
except Exception as e:
|
||||
if isinstance(e, StopIteration):
|
||||
if self.step < self._length:
|
||||
# call iter until reach length
|
||||
self._data_iter = iter(self._dataloader)
|
||||
processing_item = next(self._data_iter)
|
||||
elif not self._drop_last and not self.batching_queue.empty():
|
||||
while not self.batching_queue.empty():
|
||||
micro_batch = self.batching_queue.get_micro_batch(self.step)
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
if len(batch) == self.num_micro_batch:
|
||||
yield batch
|
||||
self.step += 1
|
||||
batch = []
|
||||
|
||||
while len(batch) < self.num_micro_batch:
|
||||
padding_batch = copy.deepcopy(micro_batch)
|
||||
padding_batch["is_padded"] = True
|
||||
batch.append(padding_batch)
|
||||
yield batch
|
||||
self.step += 1
|
||||
return
|
||||
else:
|
||||
return
|
||||
else:
|
||||
logger.error(f"DataLoader iter data exception: {e}")
|
||||
raise
|
||||
|
||||
# put processing_item to buffer
|
||||
if isinstance(processing_item, dict):
|
||||
processing_item = [processing_item]
|
||||
|
||||
for item in processing_item:
|
||||
self.batching_queue.put_item(item)
|
||||
|
||||
def state_dict(self):
|
||||
# save state
|
||||
state = self.__dict__.copy()
|
||||
# remove internal fields
|
||||
for k in list(state.keys()):
|
||||
if k.startswith("_"):
|
||||
del state[k]
|
||||
|
||||
# save dataloader state
|
||||
if hasattr(self._dataloader, "state_dict"):
|
||||
state["dataloader_state"] = self._dataloader.state_dict()
|
||||
elif hasattr(self._dataloader, "__getstate__"):
|
||||
state["dataloader_state"] = self._dataloader.__getstate__()
|
||||
|
||||
batching_strategy = getattr(self, "batching_strategy", None)
|
||||
if batching_strategy and hasattr(batching_strategy, "state_dict"):
|
||||
state["batching_strategy_state"] = batching_strategy.state_dict()
|
||||
if "batching_strategy" in state:
|
||||
del state["batching_strategy"]
|
||||
|
||||
return copy.deepcopy(state)
|
||||
|
||||
def load_state_dict(self, state: dict[str, any]):
|
||||
if state["num_micro_batch"] != self.num_micro_batch:
|
||||
logger.warning(
|
||||
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
|
||||
)
|
||||
del state["num_micro_batch"]
|
||||
self.__dict__.update(state)
|
||||
self._resume = True
|
||||
|
||||
if hasattr(self._dataloader, "load_state_dict"):
|
||||
self._dataloader.load_state_dict(state["dataloader_state"])
|
||||
elif hasattr(self._dataloader, "__getstate__"):
|
||||
self._dataloader.__setstate__(state["dataloader_state"])
|
||||
|
||||
if "batching_strategy_state" in state:
|
||||
batching_strategy = getattr(self, "batching_strategy", None)
|
||||
if batching_strategy:
|
||||
batching_strategy.load_state_dict(state["batching_strategy_state"])
|
||||
del state["batching_strategy_state"]
|
||||
|
||||
self._data_iter = iter(self._dataloader)
|
||||
self._batch_data_iter = self.batch_data_generator()
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
if hasattr(self._dataloader, "set_epoch"):
|
||||
self._dataloader.set_epoch(epoch)
|
||||
@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rendering utils.

How to use:
    renderer = Renderer(template, processor)
    renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
    renderer.parse_message(text: str) -> Message
    renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
"""

import numpy as np

from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
from ...utils.types import Message, ModelInput, Processor, Sample
|
||||
|
||||
|
||||
def render_chatml_messages(
|
||||
@@ -64,7 +74,7 @@ def render_chatml_messages(
|
||||
|
||||
|
||||
def parse_chatml_message(generated_text: str) -> Message:
|
||||
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
|
||||
"""Parse a message in ChatML format.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in ChatML format.
|
||||
@@ -83,6 +93,16 @@ class Renderer:
|
||||
def render_messages(
|
||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||
) -> ModelInput:
|
||||
"""Apply template to messages and convert them to model input.
|
||||
|
||||
Args:
|
||||
messages (list[Message]): The messages to render.
|
||||
tools (str | None, optional): The tools to use. Defaults to None.
|
||||
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ModelInput: The rendered model input.
|
||||
"""
|
||||
if self.template == "chatml":
|
||||
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
||||
else:
|
||||
@@ -91,9 +111,59 @@ class Renderer:
|
||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||
|
||||
def parse_message(self, generated_text: str) -> Message:
|
||||
"""Parse a message in the template format.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in the template format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
if self.template == "chatml":
|
||||
return parse_chatml_message(generated_text)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).parse_message(generated_text)
|
||||
|
||||
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
|
||||
"""Process samples to model input.
|
||||
|
||||
Args:
|
||||
samples (list[Sample]): The samples to process.
|
||||
|
||||
Returns:
|
||||
list[ModelInput]: The processed model inputs.
|
||||
"""
|
||||
model_inputs = []
|
||||
for sample in samples:
|
||||
if "messages" in sample:
|
||||
model_input = self.render_messages(sample["messages"], sample.get("tools"))
|
||||
elif "chosen_messages" in sample and "rejected_messages" in sample:
|
||||
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
|
||||
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
|
||||
chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
|
||||
rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
|
||||
model_input = ModelInput(
|
||||
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
|
||||
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
|
||||
labels=chosen_input["labels"] + rejected_input["labels"],
|
||||
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
|
||||
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
|
||||
)
|
||||
if "position_ids" in chosen_input:
|
||||
model_input["position_ids"] = np.concatenate(
|
||||
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
|
||||
)
|
||||
else:
|
||||
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
|
||||
|
||||
if "extra_info" in sample:
|
||||
model_input["extra_info"] = sample["extra_info"]
|
||||
|
||||
if "_dataset_name" in sample:
|
||||
model_input["_dataset_name"] = sample["_dataset_name"]
|
||||
|
||||
model_inputs.append(model_input)
|
||||
|
||||
return model_inputs
|
||||
|
||||
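A hedged usage sketch for the new process_samples path, modeled on the tests appended at the end of this diff; the exact message schema (typed content parts with a "value" field) is inferred from the converters elsewhere in this commit and may differ from the real Sample type:

from transformers import AutoTokenizer

from llamafactory.v1.core.utils.rendering import Renderer

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)

messages = [{"role": "user", "content": [{"type": "text", "value": "hi"}], "loss_weight": 0.0}]

# An SFT sample yields one ModelInput; a DPO sample concatenates chosen and rejected,
# with token_type_ids marking chosen tokens as 0 and rejected tokens as 1.
samples = [
    {"messages": messages},
    {"chosen_messages": messages, "rejected_messages": messages},
]
model_inputs = renderer.process_samples(samples)
print(len(model_inputs), model_inputs[1]["token_type_ids"][:8])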
@@ -32,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
|
||||
|
||||
|
||||
SharegptMessage = TypedDict(
|
||||
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
|
||||
"SharegptMessage",
|
||||
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
|
||||
)
|
||||
|
||||
|
||||
@@ -118,15 +119,8 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"observation": "tool",
|
||||
"function_call": "assistant",
|
||||
}
|
||||
sample = {}
|
||||
messages = []
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
tools = []
|
||||
|
||||
for message in raw_sample.get("conversations", []):
|
||||
tag = message["from"]
|
||||
if tag not in tag_mapping:
|
||||
@@ -157,10 +151,17 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
}
|
||||
)
|
||||
|
||||
sample["messages"] = messages
|
||||
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
return {"messages": messages, "tools": json.dumps(tools)}
|
||||
else:
|
||||
return {"messages": messages}
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
sample["tools"] = json.dumps(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
@DataConverterPlugin("pair").register()
|
||||
@@ -179,6 +180,24 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
def process_message(raw_messages: list[OpenaiMessage]):
|
||||
messages = []
|
||||
for message in raw_messages:
|
||||
if message["role"] == "tool":
|
||||
try:
|
||||
tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
|
||||
continue
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
@@ -189,7 +208,16 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
|
||||
return messages
|
||||
|
||||
chosen_messages = process_message(raw_sample.get("chosen", []))
|
||||
rejected_messages = process_message(raw_sample.get("rejected", []))
|
||||
sample = {}
|
||||
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
|
||||
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
|
||||
|
||||
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
sample["tools"] = json.dumps(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
|
||||
return sample
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
@@ -51,12 +50,16 @@ def _update_model_input(
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||
def render_qwen_messages(
|
||||
def render_qwen3_nothink_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Render messages in the Qwen3 nothink template format.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
|
||||
"""
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
@@ -179,7 +182,15 @@ def render_qwen_messages(
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||
def parse_qwen_message(generated_text: str) -> Message:
|
||||
def parse_qwen3_nothink_message(generated_text: str) -> Message:
|
||||
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in the Qwen3 nothink template format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
|
||||
src/llamafactory/v1/plugins/trainer_plugins/batching.py (new file, 19 lines)
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils.plugin import BasePlugin
|
||||
|
||||
|
||||
class BatchingPlugin(BasePlugin):
|
||||
pass
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config.arg_parser import get_args
|
||||
from ..config import InputArgument, get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
@@ -24,15 +24,15 @@ class SFTTrainer(BaseTrainer):
|
||||
pass
|
||||
|
||||
|
||||
def run_sft(user_args):
|
||||
model_args, data_args, training_args, _ = get_args(user_args)
|
||||
def run_sft(args: InputArgument = None):
|
||||
model_args, data_args, training_args, _ = get_args(args)
|
||||
DistributedInterface(training_args.dist_config)
|
||||
data_engine = DataEngine(data_args)
|
||||
model_engine = ModelEngine(model_args)
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model_engine.model,
|
||||
processor=model_engine.processor,
|
||||
renderer=model_engine.renderer,
|
||||
dataset=data_engine,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Bytedance's VeOmni library.
|
||||
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class DynamicBatchSizeBuffer:
|
||||
"""A buffer to store samples for dynamic batch size."""
|
||||
|
||||
def __init__(self):
|
||||
self._buffer: list[dict[str, any]] = []
|
||||
self._buffer_sample_lengths: list[int] = []
|
||||
self._deleted_indices: set[int] = set()
|
||||
self._current_index: int = 0
|
||||
self._total_token_count: int = 0
|
||||
|
||||
def append(self, item: dict[str, any]) -> None:
|
||||
"""Append a sample to the buffer.
|
||||
|
||||
Args:
|
||||
item: A sample to append to the buffer.
|
||||
The sample should be a dict with the following keys:
|
||||
- input_ids: torch.Tensor of shape (seq_len, )
|
||||
- attention_mask: torch.Tensor of shape (seq_len, )
|
||||
"""
|
||||
self._buffer.append(item)
|
||||
sample_length = int(item["attention_mask"].sum().item())
|
||||
self._buffer_sample_lengths.append(sample_length)
|
||||
self._total_token_count += sample_length
|
||||
|
||||
def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]:
|
||||
"""Get samples from the buffer that fit within the token budget.
|
||||
|
||||
Args:
|
||||
max_tokens_per_iteration: Maximum number of tokens to retrieve.
|
||||
force: If True, the first available sample will be returned even
|
||||
if it exceeds the token budget.
|
||||
|
||||
Returns:
|
||||
A list of samples that fit within the token budget.
|
||||
|
||||
Raises:
|
||||
AssertionError: If no samples are found (should not happen in normal operation).
|
||||
"""
|
||||
cum_seq_len = 0
|
||||
samples = []
|
||||
|
||||
while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
|
||||
if self._current_index in self._deleted_indices:
|
||||
self._current_index += 1
|
||||
continue
|
||||
|
||||
seq_len = self._buffer_sample_lengths[self._current_index]
|
||||
remaining_tokens = max_tokens_per_iteration - cum_seq_len
|
||||
|
||||
# Check if we can add this sample
|
||||
can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
|
||||
|
||||
if can_add:
|
||||
cum_seq_len += seq_len
|
||||
samples.append(self._buffer[self._current_index])
|
||||
self._deleted_indices.add(self._current_index)
|
||||
|
||||
self._current_index += 1
|
||||
|
||||
assert len(samples) > 0, "No samples found in buffer"
|
||||
return samples
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of samples in the buffer."""
|
||||
return len(self._buffer)
|
||||
|
||||
@property
|
||||
def total_token_count(self) -> int:
|
||||
"""Return the total number of tokens in the buffer."""
|
||||
return self._total_token_count
|
||||
|
||||
def flush(self) -> None:
|
||||
tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
|
||||
self._total_token_count -= tokens_to_remove
|
||||
|
||||
buffer_length = len(self._buffer)
|
||||
self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
|
||||
self._buffer_sample_lengths = [
|
||||
self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
|
||||
]
|
||||
|
||||
self._current_index = 0
|
||||
self._deleted_indices.clear()
|
||||
|
||||
|
||||
class BaseBatchingQueue(ABC):
|
||||
"""Base class for batching queue."""
|
||||
|
||||
@abstractmethod
|
||||
def is_full_filled(self) -> bool:
|
||||
raise NotImplementedError("Subclasses must implement `is_full_filled`")
|
||||
|
||||
@abstractmethod
|
||||
def put_item(self, item: dict[str, any]) -> None:
|
||||
raise NotImplementedError("Subclasses must implement `put_item`")
|
||||
|
||||
@abstractmethod
|
||||
def get_micro_batch(self, step: int) -> list[dict[str, any]]:
|
||||
raise NotImplementedError("Subclasses must implement `get_micro_batch`")
|
||||
|
||||
@abstractmethod
|
||||
def empty(self) -> bool:
|
||||
raise NotImplementedError("Subclasses must implement `empty`")
|
||||
|
||||
|
||||
class IdentityPacker:
|
||||
def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
|
||||
self.token_micro_bsz = token_micro_bsz
|
||||
self.bsz_warmup_steps = bsz_warmup_steps
|
||||
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
|
||||
|
||||
def __call__(self, samples):
|
||||
return samples
|
||||
|
||||
def get_token_num_to_request(self, cur_step, warmup):
|
||||
return (
|
||||
(self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
|
||||
+ self.bsz_warmup_init_mbtoken
|
||||
if warmup
|
||||
else self.token_micro_bsz
|
||||
)
|
||||
|
||||
|
||||
class TextBatchingQueue(BaseBatchingQueue):
|
||||
"""Batching text queue for text data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_micro_bsz,
|
||||
buffer_size: int = 500,
|
||||
bsz_warmup_steps: int = -1,
|
||||
bsz_warmup_init_mbtoken: int = 200,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._step = 0
|
||||
self.token_micro_bsz = token_micro_bsz
|
||||
self.bsz_warmup_steps = bsz_warmup_steps
|
||||
self.buffer_size = buffer_size # minimum samples in buffer
|
||||
self.buffer = DynamicBatchSizeBuffer()
|
||||
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args
|
||||
assert self.bsz_warmup_init_mbtoken >= 0
|
||||
|
||||
self.packer = IdentityPacker(
|
||||
token_micro_bsz=token_micro_bsz,
|
||||
bsz_warmup_steps=bsz_warmup_steps,
|
||||
bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
|
||||
)
|
||||
|
||||
def is_full_filled(self) -> bool:
|
||||
return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz
|
||||
|
||||
def put_item(self, item: dict[str, any]):
|
||||
if len(item["input_ids"]) == 1:
|
||||
print("WARNING: EMPTY STRING.")
|
||||
return
|
||||
self.buffer.append(item)
|
||||
|
||||
def get_token_num_to_request(self):
|
||||
if self.packer is not None:
|
||||
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
|
||||
return self.packer.get_token_num_to_request(self._step, warmup=warmup)
|
||||
else:
|
||||
return self.get_cur_token_micro_bsz()
|
||||
|
||||
def get_cur_token_micro_bsz(self):
|
||||
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
|
||||
if warmup:
|
||||
return (
|
||||
self.token_micro_bsz - self.bsz_warmup_init_mbtoken
|
||||
) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
|
||||
else:
|
||||
return self.token_micro_bsz
|
||||
|
||||
def get_micro_batch(self, step) -> any:
|
||||
"""Get a micro batch from the buffer according to the current step.
|
||||
|
||||
Args:
|
||||
step: the current step.
|
||||
|
||||
Returns:
|
||||
data: a list of samples.
|
||||
"""
|
||||
self._step = step
|
||||
n_token_per_iter = self.get_token_num_to_request()
|
||||
cur_token_micro_bsz = self.get_cur_token_micro_bsz()
|
||||
assert cur_token_micro_bsz % n_token_per_iter == 0, (
|
||||
"The token num to get for each request should be divisible by token micro bsz."
|
||||
)
|
||||
n_iter = int(cur_token_micro_bsz // n_token_per_iter)
|
||||
data = []
|
||||
for _ in range(n_iter):
|
||||
samples = self.buffer.get_samples(n_token_per_iter)
|
||||
if self.packer:
|
||||
samples = self.packer(samples) # maybe packed into one sample, but wrapped in list.
|
||||
data.extend(samples)
|
||||
self.buffer.flush() # remove the selected samples.
|
||||
return data
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.buffer) == 0
|
||||
@@ -32,8 +32,8 @@ class DtypeRegistry:
|
||||
class DtypeInterface:
|
||||
"""Type of precision used."""
|
||||
|
||||
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
|
||||
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
|
||||
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface().current_device)
|
||||
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface().current_device)
|
||||
_is_fp32_available = True
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -12,9 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .types import Processor
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def is_tokenizer(processor: Processor) -> bool:
|
||||
"""Check if processor is tokenizer.
|
||||
|
||||
Args:
|
||||
processor: Processor.
|
||||
|
||||
Returns:
|
||||
Whether processor is tokenizer.
|
||||
"""
|
||||
return not hasattr(processor, "tokenizer")
|
||||
|
||||
|
||||
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
|
||||
@@ -27,3 +42,34 @@ def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
|
||||
Tokenizer.
|
||||
"""
|
||||
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||
|
||||
|
||||
def _pad_and_truncate(tensor: Tensor, max_seqlen: int, pad_value: int = 0) -> Tensor:
    if tensor.shape[-1] >= max_seqlen:
        return tensor[..., :max_seqlen]

    pad_shape = list(tensor.shape)
    pad_shape[-1] = max_seqlen - tensor.shape[-1]
    pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, pad_tensor], dim=-1)


def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
    max_length = min(max(len(sample["input_ids"]) for sample in samples), max_seqlen)
    padded_samples = []
    for sample in samples:
        padded_sample = {}
        for key, value in sample.items():
            if "label" in key:
                pad_value = IGNORE_INDEX
            else:
                pad_value = 0

            if not isinstance(value, str):
                padded_sample[key] = _pad_and_truncate(torch.tensor(value), max_length, pad_value)
            else:
                padded_sample[key] = value

        padded_samples.append(padded_sample)

    return padded_samples
|
||||
|
||||
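A small worked example of the helper above (a sketch; the sample values are made up):

from llamafactory.v1.utils.helper import pad_and_truncate

samples = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
# max_length = min(3, 4) = 3, so every field is padded or truncated to 3 tokens;
# "labels" is padded with IGNORE_INDEX (conventionally -100), everything else with 0.
padded = pad_and_truncate(samples, max_seqlen=4)
print(padded[1]["input_ids"])  # tensor([4, 5, 0])
print(padded[1]["labels"])     # tensor([4, 5, -100])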
@@ -144,3 +144,20 @@ class ModelInput(TypedDict, total=False):
|
||||
"""Loss weight for each token, default to 1.0."""
|
||||
position_ids: NotRequired[list[int] | list[list[int]]]
|
||||
"""Position ids for the model (optional)."""
|
||||
token_type_ids: NotRequired[list[int]]
|
||||
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
|
||||
|
||||
|
||||
class BatchInput(TypedDict, total=False):
|
||||
input_ids: Tensor
|
||||
"""Input ids for the model."""
|
||||
attention_mask: Tensor
|
||||
"""Attention mask for the model."""
|
||||
labels: Tensor
|
||||
"""Labels for the model."""
|
||||
loss_weights: Tensor
|
||||
"""Loss weight for each token, default to 1.0."""
|
||||
position_ids: NotRequired[Tensor]
|
||||
"""Position ids for the model (optional)."""
|
||||
token_type_ids: NotRequired[Tensor]
|
||||
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# change if test fails or cache is outdated
|
||||
0.9.5.103
|
||||
0.9.5.104
|
||||
|
||||
@@ -56,4 +56,5 @@ def test_all_device():
|
||||
@pytest.mark.require_distributed(2)
|
||||
def test_multi_device():
|
||||
master_port = find_available_port()
|
||||
mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)
|
||||
world_size = 2
|
||||
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
|
||||
|
||||
@@ -14,28 +14,24 @@
|
||||
|
||||
import torch
|
||||
|
||||
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
|
||||
from llamafactory.v1.config.model_args import ModelArguments
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
|
||||
|
||||
def test_tiny_qwen():
|
||||
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
|
||||
|
||||
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
|
||||
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
|
||||
model_engine = ModelEngine(model_args)
|
||||
assert isinstance(model_engine.processor, Qwen2TokenizerFast)
|
||||
assert isinstance(model_engine.model_config, Qwen2Config)
|
||||
assert isinstance(model_engine.model, Qwen2ForCausalLM)
|
||||
assert "Qwen2Tokenizer" in model_engine.processor.__class__.__name__
|
||||
assert "Qwen3Config" in model_engine.model_config.__class__.__name__
|
||||
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
|
||||
assert model_engine.model.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_tiny_qwen_with_kernel_plugin():
|
||||
from transformers import Qwen2ForCausalLM
|
||||
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
|
||||
|
||||
model_args = ModelArguments(
|
||||
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
|
||||
model="llamafactory/tiny-random-qwen3", kernel_config={"name": "auto", "include_kernels": "auto"}
|
||||
)
|
||||
model_engine = ModelEngine(model_args)
|
||||
# test enable apply kernel plugin
|
||||
@@ -44,7 +40,7 @@ def test_tiny_qwen_with_kernel_plugin():
|
||||
else:
|
||||
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
||||
|
||||
assert isinstance(model_engine.model, Qwen2ForCausalLM)
|
||||
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
tests_v1/core/utils/test_batching.py (new file, 49 lines)
@@ -0,0 +1,49 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator


def test_normal_batching():
    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
    data_engine = DataEngine(data_args=data_args)
    model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
    model_engine = ModelEngine(model_args=model_args)
    training_args = TrainingArguments(
        micro_batch_size=4,
        global_batch_size=8,
        cutoff_len=10,
        batching_workers=0,
        batching_strategy="normal",
    )
    batch_generator = BatchGenerator(
        data_engine,
        model_engine.renderer,
        micro_batch_size=training_args.micro_batch_size,
        global_batch_size=training_args.global_batch_size,
        cutoff_len=training_args.cutoff_len,
        batching_workers=training_args.batching_workers,
        batching_strategy=training_args.batching_strategy,
    )
    assert len(batch_generator) == len(data_engine) // training_args.global_batch_size
    batch = next(iter(batch_generator))
    assert len(batch) == 2
    assert batch[0]["input_ids"].shape == (4, 10)


if __name__ == "__main__":
    test_normal_batching()
|
||||
@@ -1,171 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
|
||||
|
||||
Tests the 4 scenarios:
|
||||
a) non pack + non dynamic.
|
||||
b) non pack + dynamic.
|
||||
c) pack + non dynamic.
|
||||
d) pack + dynamic.
|
||||
"""
|
||||
|
||||
# import torch
|
||||
# from torch.utils.data import DataLoader as TorchDataLoader
# from torch.utils.data import Dataset
# from transformers import AutoTokenizer

# from llamafactory.v1.config.data_args import DataArguments
# from llamafactory.v1.core.data_engine import DataEngine
# from llamafactory.v1.core.utils.data_collator import DefaultCollator
# from llamafactory.v1.core.utils.data_loader import DataLoader
# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
# from llamafactory.v1.utils.batching_queue import TextBatchingQueue


# class TensorDataset(Dataset):
#     """Wrapper dataset that converts DataEngine samples to tensor format."""

#     def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
#         self.data_engine = data_engine
#         self.processor = processor
#         self.template = template
#         self.max_samples = max_samples or len(data_engine)
#         self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor

#     def __len__(self):
#         return min(self.max_samples, len(self.data_engine))

#     def __getitem__(self, idx):
#         # Get sample from DataEngine
#         sample = self.data_engine[idx]

#         # Extract messages from sample
#         # DataEngine returns samples with format like {"messages": [...], ...}
#         # For llamafactory/v1-sft-demo, the format should have "messages" field
#         messages = None
#         if "messages" in sample:
#             messages = sample["messages"]
#         elif "conversations" in sample:
#             messages = sample["conversations"]
#         elif "conversation" in sample:
#             messages = sample["conversation"]
#         else:
#             # Try to find message-like fields (skip _dataset_name)
#             for key, value in sample.items():
#                 if key.startswith("_"):
#                     continue
#                 if isinstance(value, list) and len(value) > 0:
#                     # Check if it looks like a message list
#                     if isinstance(value[0], dict) and "role" in value[0]:
#                         messages = value
#                         break

#         if messages is None:
#             raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")

#         # Encode messages using template
#         encoded = self.template.encode_messages(self.tokenizer, messages)

#         # Convert to tensors
#         return {
#             "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
#             "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
#             "labels": torch.tensor(encoded["labels"], dtype=torch.long),
#         }


# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
#     """Create a real dataset using DataEngine."""
#     data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
#     data_engine = DataEngine(data_args)

#     # Create processor and template
#     processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
#     template = QwenTemplate()

#     # Create tensor dataset
#     raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)

#     # Create torch DataLoader
#     torch_dataloader = TorchDataLoader(
#         raw_data_dataset,
#         batch_size=batch_size,
#         shuffle=False,
#         collate_fn=lambda x: x,
#     )

#     return torch_dataloader, processor, template


# class TestDataLoaderNonPackNonDynamic:
#     """Test case a) non pack + non dynamic."""

#     def test_basic_functionality(self):
#         """Test DataLoader without packing and without dynamic batching."""
#         # Create real dataset
#         torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)

#         # Create collator (non-packing)
#         collator = DefaultCollator(processor=processor, template=template)

#         # Create DataLoader without batching_queue (non-dynamic)
#         data_loader = DataLoader(
#             dataloader=torch_dataloader,
#             collate_fn=collator,
#             num_micro_batch=1,
#             batching_queue=None,
#         )

#         # Iterate and check results
#         batches = list(iter(data_loader))
#         assert len(batches) > 0

#         # Check first batch
#         one_batch = batches[0]
#         micro_batches = one_batch[0]
#         assert "input_ids" in micro_batches
#         assert "attention_mask" in micro_batches
#         assert "labels" in micro_batches
#         assert micro_batches["input_ids"].shape[0] == 1  # batch_size=1
#         assert micro_batches["input_ids"].ndim == 2  # [batch_size, seq_len]


# class TestDataLoaderNonPackDynamic:
#     """Test case b) non pack + dynamic."""

#     def test_basic_functionality(self):
#         """Test DataLoader without packing but with dynamic batching."""
#         # Create real dataset
#         torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
#         collator = DefaultCollator(processor=processor, template=template)

#         # Create batching queue for dynamic batching
#         batching_queue = TextBatchingQueue(
#             token_micro_bsz=120,
#             buffer_size=8,
#         )

#         data_loader = DataLoader(
#             dataloader=torch_dataloader,
#             collate_fn=collator,
#             num_micro_batch=4,
#             batching_queue=batching_queue,
#         )

#         # Iterate and check
#         batches = list(iter(data_loader))
#         micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
#         assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
#         assert len(batches) > 0
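
The commented-out dynamic-batching test above checks a single invariant: the summed attention mask of every micro batch stays within the configured token budget (token_micro_bsz=120). As a reading aid, here is a minimal, self-contained sketch of that greedy token-budget grouping; the helper name and logic are illustrative only and are not the actual TextBatchingQueue/DataLoader implementation.

# Illustrative sketch: greedily group variable-length samples so that each
# micro batch stays within a token budget, mirroring the assertion
# `all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)` above.
def group_by_token_budget(lengths: list[int], token_budget: int) -> list[list[int]]:
    micro_batches: list[list[int]] = []
    current: list[int] = []
    used = 0
    for length in lengths:
        if current and used + length > token_budget:
            micro_batches.append(current)
            current, used = [], 0
        current.append(length)
        used += length
    if current:
        micro_batches.append(current)
    return micro_batches


# Every micro batch fits the 120-token budget used by the test above.
assert all(sum(mb) <= 120 for mb in group_by_token_budget([30, 50, 45, 80, 20, 110], 120))
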
@@ -184,6 +184,40 @@ def test_qwen3_nothink_rendering_remote(num_samples: int):
    assert v1_inputs["input_ids"][: len(prefix)] == prefix


def test_process_sft_samples():
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)

    samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
    model_inputs = renderer.process_samples(samples)
    assert len(model_inputs) == 1
    assert model_inputs[0]["input_ids"] == hf_inputs
    assert model_inputs[0]["extra_info"] == "test"
    assert model_inputs[0]["_dataset_name"] == "default"


def test_process_dpo_samples():
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)

    samples = [
        {
            "chosen_messages": V1_MESSAGES,
            "rejected_messages": V1_MESSAGES,
            "extra_info": "test",
            "_dataset_name": "default",
        }
    ]
    model_inputs = renderer.process_samples(samples)
    assert len(model_inputs) == 1
    assert model_inputs[0]["input_ids"] == hf_inputs * 2
    assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
    assert model_inputs[0]["extra_info"] == "test"
    assert model_inputs[0]["_dataset_name"] == "default"


if __name__ == "__main__":
    test_chatml_rendering()
    test_chatml_parse()
@@ -191,3 +225,5 @@ if __name__ == "__main__":
    test_qwen3_nothink_rendering()
    test_qwen3_nothink_parse()
    test_qwen3_nothink_rendering_remote(16)
    test_process_sft_samples()
    test_process_dpo_samples()

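test_process_dpo_samples above pins down how a preference pair is flattened: the chosen and rejected encodings are concatenated into one sequence, with token_type_ids marking the chosen half with 0 and the rejected half with 1. A small sketch of that invariant follows; pack_preference_pair is a hypothetical helper for illustration, not part of the Renderer API.

# Hypothetical helper illustrating the pairing invariant asserted above; the
# real work happens inside Renderer.process_samples.
def pack_preference_pair(chosen_ids: list[int], rejected_ids: list[int]) -> dict[str, list[int]]:
    return {
        "input_ids": chosen_ids + rejected_ids,
        "token_type_ids": [0] * len(chosen_ids) + [1] * len(rejected_ids),
    }


packed = pack_preference_pair([1, 2, 3], [1, 2, 3])
assert packed["input_ids"] == [1, 2, 3] * 2
assert packed["token_type_ids"] == [0, 0, 0, 1, 1, 1]
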
@@ -21,7 +21,7 @@ from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_meta"},
        )
    )
@@ -32,7 +32,7 @@ def test_init_on_meta():
def test_init_on_rank0():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_rank0"},
        )
    )
@@ -46,7 +46,7 @@ def test_init_on_rank0():
def test_init_on_default():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_default"},
        )
    )

@@ -43,7 +43,7 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward
    model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
@@ -62,7 +62,7 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")

    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward

@@ -1,112 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue


def create_sample(length: int):
    """Helper to create a mock sample with a specific token length."""
    return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)}


class TestDynamicBatchSizeBuffer:
    def test_append_and_token_count(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(20))

        assert len(buffer) == 2
        assert buffer.total_token_count == 30

    def test_get_samples_within_budget(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(10))
        buffer.append(create_sample(50))  # This one is large

        # Request 25 tokens. Should get the first two (20 tokens total)
        samples = buffer.get_samples(max_tokens_per_iteration=25)
        assert len(samples) == 2

    def test_force_return_first_sample(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(100))

        # Even though budget is 50, force=True (default) should return the 100-token sample
        samples = buffer.get_samples(max_tokens_per_iteration=50, force=True)
        assert len(samples) == 1
        assert len(samples[0]["input_ids"]) == 100

    def test_flush_removes_used_samples(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(20))

        # Take the first sample
        buffer.get_samples(max_tokens_per_iteration=15)
        buffer.flush()

        assert len(buffer) == 1
        assert buffer.total_token_count == 20
        # The remaining sample should now be at the start
        remaining = buffer.get_samples(max_tokens_per_iteration=50)
        assert len(remaining[0]["input_ids"]) == 20

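The four tests above fully describe the contract they exercise: appended samples are counted in tokens, get_samples returns as many samples as fit the per-iteration budget (always at least the first one when force is set), and flush drops the samples already handed out. The following is a minimal in-memory sketch that satisfies exactly those assertions; it is not the DynamicBatchSizeBuffer from llamafactory.v1.utils.batching_queue, only an illustration of the expected behavior.

# Illustrative buffer that satisfies the assertions of the tests above.
class SketchDynamicBuffer:
    def __init__(self):
        self._samples = []
        self._used = 0  # samples already handed out but not yet flushed
        self.total_token_count = 0

    def __len__(self):
        return len(self._samples)

    def append(self, sample):
        self._samples.append(sample)
        self.total_token_count += len(sample["input_ids"])

    def get_samples(self, max_tokens_per_iteration, force=True):
        taken, budget = [], max_tokens_per_iteration
        for sample in self._samples[self._used:]:
            length = len(sample["input_ids"])
            # Stop once the next sample would exceed the budget, unless force
            # requires returning at least the first pending sample.
            if length > budget and not (force and not taken):
                break
            taken.append(sample)
            budget -= length
        self._used += len(taken)
        return taken

    def flush(self):
        # Drop the samples that were already returned by get_samples.
        for sample in self._samples[: self._used]:
            self.total_token_count -= len(sample["input_ids"])
        self._samples = self._samples[self._used:]
        self._used = 0


buffer = SketchDynamicBuffer()
buffer.append({"input_ids": [0] * 10})
buffer.append({"input_ids": [0] * 20})
assert buffer.total_token_count == 30
assert len(buffer.get_samples(max_tokens_per_iteration=15)) == 1
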
class TestTextBatchingQueue:
    def test_is_full_filled(self):
        queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2)

        queue.put_item(create_sample(10))
        assert not queue.is_full_filled()  # Only 1 sample, buffer_size=2

        queue.put_item(create_sample(10))
        assert not queue.is_full_filled()  # 2 samples, but only 20 tokens (min 100)

        queue.put_item(create_sample(90))
        assert queue.is_full_filled()  # Meets both conditions

    def test_warmup_logic(self):
        # token_micro_bsz=1000, starts at 200, reaches 1000 at step 10
        queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200)

        # Step 0: should be init value
        assert queue.get_cur_token_micro_bsz() == 200

        # Step 5: halfway through warmup (200 + (800 * 5/10)) = 600
        queue._step = 5
        assert queue.get_cur_token_micro_bsz() == 600

        # Step 11: past warmup
        queue._step = 11
        assert queue.get_cur_token_micro_bsz() == 1000

    def test_get_micro_batch_integration(self):
        queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1)
        queue.put_item(create_sample(20))
        queue.put_item(create_sample(20))
        queue.put_item(create_sample(20))

        # At step 0 (warmup not triggered as bsz_warmup_steps is -1 default),
        # it should take samples up to 50 tokens.
        batch = queue.get_micro_batch(step=0)

        assert len(batch) == 2
        assert queue.empty() is False

        batch_2 = queue.get_micro_batch(step=1)
        assert len(batch_2) == 1
        assert queue.empty() is True
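
test_warmup_logic above implies a linear ramp of the per-micro-batch token budget from bsz_warmup_init_mbtoken up to token_micro_bsz over bsz_warmup_steps steps. A short sketch of that schedule, reproducing the values asserted in the test (the function below is illustrative, not the actual TextBatchingQueue.get_cur_token_micro_bsz implementation):

# Illustrative linear warmup of the token budget; matches the values asserted
# in test_warmup_logic (200 at step 0, 600 at step 5, 1000 past step 10).
def warmup_token_budget(step: int, token_micro_bsz: int = 1000, warmup_steps: int = 10, init_mbtoken: int = 200) -> int:
    if warmup_steps <= 0 or step >= warmup_steps:
        return token_micro_bsz
    return int(init_mbtoken + (token_micro_bsz - init_mbtoken) * step / warmup_steps)


assert warmup_token_budget(0) == 200
assert warmup_token_budget(5) == 600
assert warmup_token_budget(11) == 1000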