mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 08:33:38 +00:00
[v1&WIP] dataloader init (#9645)
This commit is contained in:
220
src/llamafactory/v1/utils/batching_queue.py
Normal file
220
src/llamafactory/v1/utils/batching_queue.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Bytedance's VeOmni library.
|
||||
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class DynamicBatchSizeBuffer:
|
||||
"""A buffer to store samples for dynamic batch size."""
|
||||
|
||||
def __init__(self):
|
||||
self._buffer: list[dict[str, any]] = []
|
||||
self._buffer_sample_lengths: list[int] = []
|
||||
self._deleted_indices: set[int] = set()
|
||||
self._current_index: int = 0
|
||||
self._total_token_count: int = 0
|
||||
|
||||
def append(self, item: dict[str, any]) -> None:
|
||||
"""Append a sample to the buffer.
|
||||
|
||||
Args:
|
||||
item: A sample to append to the buffer.
|
||||
The sample should be a dict with the following keys:
|
||||
- input_ids: torch.Tensor of shape (seq_len, )
|
||||
- attention_mask: torch.Tensor of shape (seq_len, )
|
||||
"""
|
||||
self._buffer.append(item)
|
||||
sample_length = int(item["attention_mask"].sum().item())
|
||||
self._buffer_sample_lengths.append(sample_length)
|
||||
self._total_token_count += sample_length
|
||||
|
||||
def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]:
|
||||
"""Get samples from the buffer that fit within the token budget.
|
||||
|
||||
Args:
|
||||
max_tokens_per_iteration: Maximum number of tokens to retrieve.
|
||||
force: If True, the first available sample will be returned even
|
||||
if it exceeds the token budget.
|
||||
|
||||
Returns:
|
||||
A list of samples that fit within the token budget.
|
||||
|
||||
Raises:
|
||||
AssertionError: If no samples are found (should not happen in normal operation).
|
||||
"""
|
||||
cum_seq_len = 0
|
||||
samples = []
|
||||
|
||||
while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
|
||||
if self._current_index in self._deleted_indices:
|
||||
self._current_index += 1
|
||||
continue
|
||||
|
||||
seq_len = self._buffer_sample_lengths[self._current_index]
|
||||
remaining_tokens = max_tokens_per_iteration - cum_seq_len
|
||||
|
||||
# Check if we can add this sample
|
||||
can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
|
||||
|
||||
if can_add:
|
||||
cum_seq_len += seq_len
|
||||
samples.append(self._buffer[self._current_index])
|
||||
self._deleted_indices.add(self._current_index)
|
||||
|
||||
self._current_index += 1
|
||||
|
||||
assert len(samples) > 0, "No samples found in buffer"
|
||||
return samples
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of samples in the buffer."""
|
||||
return len(self._buffer)
|
||||
|
||||
@property
|
||||
def total_token_count(self) -> int:
|
||||
"""Return the total number of tokens in the buffer."""
|
||||
return self._total_token_count
|
||||
|
||||
def flush(self) -> None:
|
||||
tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
|
||||
self._total_token_count -= tokens_to_remove
|
||||
|
||||
buffer_length = len(self._buffer)
|
||||
self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
|
||||
self._buffer_sample_lengths = [
|
||||
self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
|
||||
]
|
||||
|
||||
self._current_index = 0
|
||||
self._deleted_indices.clear()
|
||||
|
||||
|
||||
class BaseBatchingQueue(ABC):
|
||||
"""Base class for batching queue."""
|
||||
|
||||
@abstractmethod
|
||||
def is_full_filled(self) -> bool:
|
||||
raise NotImplementedError("Subclasses must implement `is_full_filled`")
|
||||
|
||||
@abstractmethod
|
||||
def put_item(self, item: dict[str, any]) -> None:
|
||||
raise NotImplementedError("Subclasses must implement `put_item`")
|
||||
|
||||
@abstractmethod
|
||||
def get_micro_batch(self, step: int) -> list[dict[str, any]]:
|
||||
raise NotImplementedError("Subclasses must implement `get_micro_batch`")
|
||||
|
||||
@abstractmethod
|
||||
def empty(self) -> bool:
|
||||
raise NotImplementedError("Subclasses must implement `empty`")
|
||||
|
||||
|
||||
class IdentityPacker:
|
||||
def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
|
||||
self.token_micro_bsz = token_micro_bsz
|
||||
self.bsz_warmup_steps = bsz_warmup_steps
|
||||
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
|
||||
|
||||
def __call__(self, samples):
|
||||
return samples
|
||||
|
||||
def get_token_num_to_request(self, cur_step, warmup):
|
||||
return (
|
||||
(self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
|
||||
+ self.bsz_warmup_init_mbtoken
|
||||
if warmup
|
||||
else self.token_micro_bsz
|
||||
)
|
||||
|
||||
|
||||
class TextBatchingQueue(BaseBatchingQueue):
|
||||
"""Batching text queue for text data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_micro_bsz,
|
||||
buffer_size: int = 500,
|
||||
bsz_warmup_steps: int = -1,
|
||||
bsz_warmup_init_mbtoken: int = 200,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._step = 0
|
||||
self.token_micro_bsz = token_micro_bsz
|
||||
self.bsz_warmup_steps = bsz_warmup_steps
|
||||
self.buffer_size = buffer_size # minimum samples in buffer
|
||||
self.buffer = DynamicBatchSizeBuffer()
|
||||
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args
|
||||
assert self.bsz_warmup_init_mbtoken >= 0
|
||||
|
||||
self.packer = IdentityPacker(
|
||||
token_micro_bsz=token_micro_bsz,
|
||||
bsz_warmup_steps=bsz_warmup_steps,
|
||||
bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
|
||||
)
|
||||
|
||||
def is_full_filled(self) -> bool:
|
||||
return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz
|
||||
|
||||
def put_item(self, item: dict[str, any]):
|
||||
if len(item["input_ids"]) == 1:
|
||||
print("WARNING: EMPTY STRING.")
|
||||
return
|
||||
self.buffer.append(item)
|
||||
|
||||
def get_token_num_to_request(self):
|
||||
if self.packer is not None:
|
||||
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
|
||||
return self.packer.get_token_num_to_request(self._step, warmup=warmup)
|
||||
else:
|
||||
return self.get_cur_token_micro_bsz()
|
||||
|
||||
def get_cur_token_micro_bsz(self):
|
||||
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
|
||||
if warmup:
|
||||
return (
|
||||
self.token_micro_bsz - self.bsz_warmup_init_mbtoken
|
||||
) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
|
||||
else:
|
||||
return self.token_micro_bsz
|
||||
|
||||
def get_micro_batch(self, step) -> any:
|
||||
"""Get a micro batch from the buffer according to the current step.
|
||||
|
||||
Args:
|
||||
step: the current step.
|
||||
|
||||
Returns:
|
||||
data: a list of samples.
|
||||
"""
|
||||
self._step = step
|
||||
n_token_per_iter = self.get_token_num_to_request()
|
||||
cur_token_micro_bsz = self.get_cur_token_micro_bsz()
|
||||
assert cur_token_micro_bsz % n_token_per_iter == 0, (
|
||||
"The token num to get for each request should be divisible by token micro bsz."
|
||||
)
|
||||
n_iter = int(cur_token_micro_bsz // n_token_per_iter)
|
||||
data = []
|
||||
for _ in range(n_iter):
|
||||
samples = self.buffer.get_samples(n_token_per_iter)
|
||||
if self.packer:
|
||||
samples = self.packer(samples) # maybe packed into one sample, but wrapped in list.
|
||||
data.extend(samples)
|
||||
self.buffer.flush() # remove the selected samples.
|
||||
return data
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.buffer) == 0
|
||||
Reference in New Issue
Block a user