Big DataLoader refactor: BOS-aligned dataloaders with epoch tracking for pre/mid-training

The new DataLoader ensures that every token sequence in train/val batches has a BOS token
at the beginning. Therefore, no token streams start abruptly in the middle of a document,
which could be confusing for the model. Note that this changes the loss scale because there
are fewer confusing tokens in the train/val batches. The main downside is that we now waste
about 35% of tokens due to cropping. This is ok because we have a lot of data. See the dev/LOG.md
entry for this change for a lot more information.
Andrej Karpathy
2026-01-13 20:05:47 +00:00
parent 23985413aa
commit 43c29dd9d5
7 changed files with 330 additions and 106 deletions

View File

@@ -4,6 +4,70 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-01-13: BOS-Aligned Dataloader with Bin Packing
Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.
### Problem Statement
The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.
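As a toy illustration of why rows end up without BOS (made-up token ids; the real loader also carries one extra token for the targets, omitted here for brevity):

```python
# Three toy "documents", each tokenized with a BOS id (here 0) prepended,
# streamed into one flat buffer and reshaped into B rows of T tokens.
docs = [[0, 11, 12, 13], [0, 21, 22], [0, 31, 32, 33, 34]]
flat = [t for doc in docs for t in doc]

B, T = 3, 4
rows = [flat[i * T:(i + 1) * T] for i in range(B)]
print(rows)
# [[0, 11, 12, 13], [0, 21, 22, 0], [31, 32, 33, 34]]
# The third row starts at token 31, i.e. mid-document, with no BOS at position 0.
```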
### Approach 1: Greedy-Crop BOS (Simple)
Each row is built independently:
- Start with a document (which has BOS prepended)
- Pack more documents until row is full
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
- 100% utilization (no padding), but wastes cropped tokens
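A minimal sketch of the greedy-crop row builder (illustrative only, not the repo's actual API; assumes an iterator of already-tokenized documents, each with BOS prepended):

```python
def greedy_crop_row(doc_iter, row_capacity):
    """Build one row of exactly row_capacity tokens by greedy packing + cropping."""
    row = []
    while len(row) < row_capacity:
        doc = next(doc_iter)                 # doc already starts with BOS
        remaining = row_capacity - len(row)
        row.extend(doc[:remaining])          # if the doc doesn't fit, its tail is discarded
    return row
```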
### Waste Analysis
Measured token waste empirically on real data (T=2048):
- **39.4% of tokens are cropped** (discarded when docs don't fit)
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row
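The two waste numbers can be estimated from tokenized document lengths roughly as follows (a sketch; the 39.4%/22.9% figures above come from measuring real data, not from this toy code):

```python
def crop_waste_stats(doc_lengths, T):
    """Fraction of tokens cropped by greedy-crop packing, and the theoretical floor."""
    cap = T + 1                                    # row holds T inputs + 1 extra target token
    total = sum(doc_lengths)
    # Floor: the part of any document beyond cap can never fit into a single row.
    floor = sum(max(0, n - cap) for n in doc_lengths)
    cropped, filled = 0, 0
    for n in doc_lengths:                          # documents arrive in stream order
        space = cap - filled
        if n <= space:
            filled = 0 if filled + n == cap else filled + n
        else:
            cropped += n - space                   # "unlucky" crop: the tail is discarded
            filled = 0                             # row is full, start a new one
    return cropped / total, floor / total
```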
### Bin Packing Algorithms Explored
| Algorithm | Util% | Crop% | Pad% | Notes |
|-----------|-------|-------|------|-------|
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |
### BestFit-Crop Algorithm
A middle ground that maintains 100% utilization while reducing cropping:
1. Buffer N documents
2. For each row, greedily pick the **largest doc that fits entirely**
3. Repeat until nothing fits
4. When nothing fits, crop a doc to fill remaining space exactly
This avoids "unlucky" crops by searching the buffer for better-fitting documents.
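A compact sketch of building one BestFit-Crop row (illustrative; the actual implementation in the diff below additionally streams documents and refills the buffer):

```python
def bestfit_crop_row(doc_buffer, row_capacity):
    """Fill one row exactly, preferring the largest buffered doc that fits entirely."""
    row = []
    while len(row) < row_capacity:
        remaining = row_capacity - len(row)
        fitting = [i for i, d in enumerate(doc_buffer) if len(d) <= remaining]
        if fitting:
            best = max(fitting, key=lambda i: len(doc_buffer[i]))
            row.extend(doc_buffer.pop(best))           # use the whole document
        else:
            row.extend(doc_buffer.pop(0)[:remaining])  # crop: the tail is discarded
    return row
```

For example (toy numbers), with row_capacity=10 and buffered doc lengths [7, 8, 3, 2], best-fit builds one full row from 8+2 and another from 7+3 with nothing cropped, whereas packing in arrival order would take 7 and then crop the 8-token doc down to 3, discarding 5 tokens.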
**Results (T=2048):**
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
- Still achieves 100% utilization (no padding, every token trains)
- Slightly more rows than baseline (uses more documents per batch)
### Decision: Keep Two Implementations
1. **Original streaming loader (kept as fallback)**: very simple, efficient, and has 100% token utilization in the batch (no padding with ignore tokens), but it creates slightly more confusing token streams for the LLM because rows can start abruptly in the middle of a document with no context during training. Note that this never happens at test time, where BOS is always present.
2. **`_bos_bestfit` (BestFit-Crop, new default)**: slightly more complex, but still keeps 100% token utilization in the batch (no padding), at the cost of discarding the tails of documents that don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because, for most models we care about, we have plenty of data without having to go to multiple epochs. One more subtle effect is that it skews the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents are discarded. However, this doesn't seem to impact actual downstream performance.
### Midtraining
The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.
### NOTE: loss scale
Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute loss value, because there are a lot fewer "confusing" tokens in the train/val batches: every token can look back, find the BOS token, and use the full context of its document to make predictions. The loss therefore appears lower, but this is "fake" to some extent; the expectation is that the vast majority of relative comparisons made so far would come out the same before and after this change.
---
## 2026-01-13: Number Token Split Pattern
Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I only guessed earlier and had a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.

View File

@@ -17,8 +17,9 @@ if [ -z "$SKIP_SETUP" ]; then
 uv sync --extra gpu
 source .venv/bin/activate
-# Tokenizer
-python -m nanochat.dataset -n 240
+# Tokenizer, download 1000 shards for pretraining
+# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
+python -m nanochat.dataset -n 1000
 python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768
 else
 source .venv/bin/activate

View File

@@ -1,4 +1,25 @@
-from collections import deque
+"""
+Distributed dataloaders for pretraining.
+Two implementations are provided:
+1. Original (tokenizing_distributed_data_loader):
+   - Streams tokens into a flat buffer, reshapes to (B, T)
+   - Rows may start mid-document (no guaranteed BOS at position 0)
+   - 100% token utilization, simple and efficient
+2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
+   - Every row starts with BOS token
+   - Documents packed using best-fit algorithm to minimize cropping
+   - When no document fits remaining space, crops a document to fill exactly
+   - 100% utilization (no padding), ~35% tokens cropped at T=2048
+The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
+there are fewer "confusing" tokens in the train/val batches as every token can
+now attend back to the BOS token and sees the full context of the document.
+(2) is the new default if you have enough data.
+Fallback to (1) if you have very limited data AND long documents.
+"""
 import torch
 import pyarrow.parquet as pq
@@ -6,86 +27,172 @@ import pyarrow.parquet as pq
 from nanochat.common import get_dist_info
 from nanochat.dataset import list_parquet_files
-def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+def _document_batches(split, resume_state_dict, tokenizer_batch_size):
 """
-Stream pretraining text from parquet files, tokenize, yield training batches.
-This implementation became a bit more complex because we wish to support approximate resume training.
-Instead of turning this into a Class, we opt to return the state_dict with every batch,
-and then the caller can pass in a state_dict to resume training from a desired point.
-Note that this resumption is atm only *approximate* for simplicity.
-We won't repeat the same documents but we might skip a few.
-The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
-Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
+Infinite iterator over document batches (list of text strings) from parquet files.
+Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
+where text_batch is a list of document strings, indices track position for resumption,
+and epoch counts how many times we've cycled through the dataset (starts at 1).
 """
-assert split in ["train", "val"], "split must be 'train' or 'val'"
-# infinite iterator over document batches (list of text strings)
 ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-def document_batches():
 parquet_paths = list_parquet_files()
 assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
 parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
 resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
 resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
 first_pass = True
-pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
+pq_idx = resume_pq_idx
+epoch = resume_epoch
 while True: # iterate infinitely (multi-epoch)
 pq_idx = resume_pq_idx if first_pass else 0
-while pq_idx < len(parquet_paths): # iterate over all parquet files
+while pq_idx < len(parquet_paths):
 filepath = parquet_paths[pq_idx]
 pf = pq.ParquetFile(filepath)
 # Start from resume point if resuming on same file, otherwise from DDP rank
-# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
 if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
-base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
-base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
+base_idx = resume_rg_idx // ddp_world_size
+base_idx += 1 # advance by 1 so we don't repeat data after resuming
 rg_idx = base_idx * ddp_world_size + ddp_rank
 if rg_idx >= pf.num_row_groups:
 pq_idx += 1
 continue
-resume_rg_idx = None # set to None as we only want to do this a single time
+resume_rg_idx = None # only do this once
 else:
 rg_idx = ddp_rank
 while rg_idx < pf.num_row_groups:
 rg = pf.read_row_group(rg_idx)
-batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
-# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
+batch = rg.column('text').to_pylist()
 for i in range(0, len(batch), tokenizer_batch_size):
-yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
+yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
-rg_idx += ddp_world_size # advance to the next row group (in DDP)
+rg_idx += ddp_world_size
-pq_idx += 1 # advance to the next parquet file
+pq_idx += 1
 first_pass = False
-batches = document_batches()
-# Now emit batches of tokens.
-needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
+epoch += 1
+def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+"""
+Stream pretraining text from parquet files, tokenize, yield training batches.
+This is the original dataloader that streams tokens into a flat buffer and reshapes.
+Rows may start mid-document (no guaranteed BOS at position 0).
+Supports approximate resume via state_dict.
+"""
+assert split in ["train", "val"], "split must be 'train' or 'val'"
+batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
+needed_tokens = B * T + 1 # +1 for target at last position
 bos_token = tokenizer.get_bos_token_id()
-# scratch buffer holds the tokens for one iteration
-token_buffer = deque() # we stream tokens on the right and pop from the left
+token_buffer = []
+pq_idx, rg_idx, epoch = 0, 0, 1
 while True:
-# Accumulate enough tokens for one iteration before yielding.
+# Accumulate enough tokens
 while len(token_buffer) < needed_tokens:
-doc_batch, (pq_idx, rg_idx) = next(batches)
+doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
 token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
 for tokens in token_lists:
 token_buffer.extend(tokens)
-# Move tokens from the deque into the scratch buffer
-tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
-use_cuda_optimizations = device == "cuda"
-scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
-# Create the inputs/targets as 1D tensors
-inputs_cpu = scratch[:-1]
-targets_cpu = scratch[1:]
-# Reshape to 2D and move to GPU async
-inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
-targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
-state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
-yield inputs, targets, state_dict
+tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
+token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
+# Package tokens into inputs and targets, yield
+use_cuda = device == "cuda"
+scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
+inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
+targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
+yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
 def tokenizing_distributed_data_loader(*args, **kwargs):
-# helper function that only emits the inputs/targets and not the state_dict
+"""Helper that omits state_dict from yields."""
 for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
 yield inputs, targets
+def tokenizing_distributed_data_loader_with_state_bos_bestfit(
+tokenizer, B, T, split,
+tokenizer_threads=4, tokenizer_batch_size=128,
+device="cuda", resume_state_dict=None,
+buffer_size=1000
+):
+"""
+BOS-aligned dataloader with Best-Fit Cropping.
+Reduces token waste compared to simple greedy cropping by searching a buffer
+for documents that fit well, while maintaining 100% utilization (no padding).
+Algorithm for each row:
+1. From buffered docs, pick the LARGEST doc that fits entirely
+2. Repeat until no doc fits
+3. When nothing fits, crop a doc to fill remaining space exactly
+Key properties:
+- Every row starts with BOS
+- 100% utilization (no padding, every token is trained on)
+- Approximately 35% of all tokens are discarded due to cropping
+"""
+assert split in ["train", "val"], "split must be 'train' or 'val'"
+row_capacity = T + 1
+batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
+bos_token = tokenizer.get_bos_token_id()
+doc_buffer = []
+pq_idx, rg_idx, epoch = 0, 0, 1
+def refill_buffer():
+nonlocal pq_idx, rg_idx, epoch
+doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
+token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
+for tokens in token_lists:
+doc_buffer.append(tokens)
+while True:
+rows = []
+for _ in range(B):
+row = []
+while len(row) < row_capacity:
+# Ensure buffer has documents
+while len(doc_buffer) < buffer_size:
+refill_buffer()
+remaining = row_capacity - len(row)
+# Find largest doc that fits entirely
+best_idx = -1
+best_len = 0
+for i, doc in enumerate(doc_buffer):
+doc_len = len(doc)
+if doc_len <= remaining and doc_len > best_len:
+best_idx = i
+best_len = doc_len
+if best_idx >= 0:
+doc = doc_buffer.pop(best_idx)
+row.extend(doc)
+else:
+# No doc fits - crop first doc to fill remaining
+doc = doc_buffer.pop(0)
+row.extend(doc[:remaining])
+rows.append(row[:row_capacity])
+use_cuda = device == "cuda"
+batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
+inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
+targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
+yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
+def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
+"""Helper that omits state_dict from yields."""
+for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
+yield inputs, targets

View File

@@ -20,8 +20,8 @@ curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-publ
 # train tokenizer on ~4B characters and kick off download of the rest for pretraining
 python -m nanochat.dataset -n 16
-# start downloading the rest of the shards for a total of 800 (see below why 800)
-python -m nanochat.dataset -n 800 &
+# start downloading the rest of the shards for a total of 1200 (see below why 1200)
+python -m nanochat.dataset -n 1200 &
 # todo: download the rest of it
 python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536
 python -m scripts.tok_eval
@@ -62,7 +62,9 @@ python -m scripts.tok_eval
 # The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
 # So ~38B tokens * ~4.8 chars/token = ~185B chars.
 # Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
-# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
+# For safety, I bumped that up to 800 shards.
+# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
+# => why up above I used -n 1200 when pre-downloading dataset shards.
 # If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
 # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
 # start to overfit hard.
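To double-check the shard arithmetic in the comments above (a rough back-of-the-envelope sketch; the 4.8 chars/token, ~250M chars/shard and 35% crop-waste figures are the ones quoted in the script):

```python
tokens_needed = 38e9                      # ~38B tokens for this run
chars_needed = tokens_needed * 4.8        # ~4.8 chars/token  -> ~1.8e11 (~185B) chars
shards = chars_needed / 250e6             # ~250M chars/shard -> ~730, bumped to 800 for safety
shards_with_crop = 800 / (1 - 0.35)       # ~35% of tokens lost to cropping -> ~1230
print(round(shards), round(shards_with_crop))  # 730 1231, which the script rounds to the -n 1200 used above
```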

View File

@@ -21,7 +21,7 @@ import wandb
 import torch
 from nanochat.gpt import GPT, GPTConfig
-from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
+from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
@@ -210,8 +210,8 @@ if resuming:
 # Initialize the DataLoaders for train/val
 tokens_dir = os.path.join(base_dir, "tokenized_data")
 dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
-train_loader = tokenizing_distributed_data_loader_with_state(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
-build_val_loader = lambda: tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
+train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
+build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
 x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
 # -----------------------------------------------------------------------------
@@ -395,7 +395,8 @@ while True:
 eta_str = f" | eta: {eta_seconds/60:.1f}m"
 else:
 eta_str = ""
-print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
+epoch = dataloader_state_dict["epoch"]
+print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
 if step % 100 == 0:
 log_data = {
 "step": step,
@@ -406,6 +407,7 @@ while True:
 "train/dt": dt,
 "train/tok_per_sec": tok_per_sec,
 "train/mfu": mfu,
+"train/epoch": epoch,
 }
 wandb_run.log(log_data)

View File

@@ -10,7 +10,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_
""" """
import argparse import argparse
from collections import deque
import os import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time import time
@@ -125,49 +124,95 @@ val_dataset = TaskMixture([
# these two global variables and update them from within the data generator. # these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split): current_epoch = 1 # track epoch for logging
global last_step, approx_progress def mid_data_generator_bos_bestfit(split, buffer_size=100):
"""
BOS-aligned dataloader for midtraining with bestfit-crop packing.
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm to minimize cropping.
This matches the BOS-aligned approach used in pretraining.
"""
global last_step, approx_progress, current_epoch
assert split in {"train", "val"}, "split must be 'train' or 'val'" assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset) dataset_size = len(dataset)
assert dataset_size > 0 assert dataset_size > 0
needed_tokens = args.device_batch_size * args.max_seq_len + 1 # to form one training batch of inputs,targets row_capacity = args.max_seq_len + 1 # +1 for target at last position
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU: # Conversation buffer: list of token lists
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda")) conv_buffer = []
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents cursor = ddp_rank # Each rank processes different conversations
epoch = 1
it = 0 # iteration counter it = 0 # iteration counter
while True:
# Accumulate enough tokens for one iteration before yielding def refill_buffer():
while len(token_buffer) < needed_tokens: nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor] conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation) ids, _ = tokenizer.render_conversation(conversation)
token_buffer.extend(ids) conv_buffer.append(ids)
cursor += ddp_world_size cursor += ddp_world_size
if cursor >= dataset_size: if cursor >= dataset_size:
cursor -= dataset_size # wrap around for another epoch cursor = cursor % dataset_size
epoch += 1
if split == "train": if split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop last_step = True # toggle last_step to True, which will terminate the training loop
while True:
rows = []
for _ in range(args.device_batch_size):
row = []
while len(row) < row_capacity:
# Ensure buffer has conversations
while len(conv_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
# Find largest conversation that fits entirely
best_idx = -1
best_len = 0
for i, conv in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len:
best_idx = i
best_len = conv_len
if best_idx >= 0:
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
row.extend(conv)
else:
# No conversation fits - crop first conversation to fill remaining
conv = conv_buffer.pop(0)
row.extend(conv[:remaining])
rows.append(row[:row_capacity])
# Stopping condition to respect num_iterations, if given # Stopping condition to respect num_iterations, if given
it += 1 it += 1
if 0 < args.num_iterations <= it and split == "train": if 0 < args.num_iterations <= it and split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop last_step = True
# Build up inputs/targets and yield
for i in range(needed_tokens): # Update progress tracking
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
if split == "train": if split == "train":
current_epoch = epoch
if args.num_iterations > 0: if args.num_iterations > 0:
approx_progress = it / args.num_iterations # calculate progress from the max number of iterations approx_progress = it / args.num_iterations
else: else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset approx_progress = cursor / dataset_size
# Build tensors
use_cuda = device_type == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
yield inputs, targets yield inputs, targets
train_loader = mid_data_generator("train") train_loader = mid_data_generator_bos_bestfit("train")
build_val_loader = lambda: mid_data_generator("val") build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler # Learning rate scheduler
@@ -285,7 +330,7 @@ while True:
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10: if step > 10:
total_training_time += dt # only count the time after the first 10 steps total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0: if step % 10 == 0:
wandb_run.log({ wandb_run.log({
"step": step, "step": step,
@@ -296,6 +341,7 @@ while True:
"train/dt": dt, "train/dt": dt,
"train/tok_per_sec": tok_per_sec, "train/tok_per_sec": tok_per_sec,
"train/mfu": mfu, "train/mfu": mfu,
"train/epoch": current_epoch,
}) })
# print a few more stats # print a few more stats

View File

@@ -55,8 +55,8 @@ python -m nanochat.report reset
 # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
 python -m nanochat.dataset -n 8
 # Immediately also kick off downloading more shards in the background while tokenizer trains
-# See comment below for why 240 is the right number here
-python -m nanochat.dataset -n 240 &
+# See comment below for why 370 is the right number here
+python -m nanochat.dataset -n 370 &
 DATASET_DOWNLOAD_PID=$!
 # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
 python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
@@ -70,7 +70,9 @@ python -m scripts.tok_eval
 # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
 # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
 # At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
-# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
+# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
+# so 240 / (1 - 0.35) = 370 shards are needed.
+# At ~100MB/shard, this downloads ~37GB of data to disk.
 # (The total number of shards available in the entire dataset is 1822.)
 echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID