diff --git a/dev/LOG.md b/dev/LOG.md index ee1e82e..f2322de 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,39 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-01-11: Flash Attention 3 Integration + +Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference. + +### Changes Made + +**1. FA3 via `kernels` package** +- Official FA3 is "beta" and requires building from source (painful) +- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')` +- Loads pre-built wheels, works out of the box on H100 + +**2. Simplified attention code** +- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed +- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)` +- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call +- Removed 3 separate FA2 code paths (training, single-token, chunk inference) +- GQA handled automatically when n_kv_heads < n_heads + +**3. Rewrote KVCache for FA3** +- Old format: `(num_layers, 2, B, H, T, D)` combined tensor +- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)` +- FA3 updates cache in-place during `flash_attn_with_kvcache` +- Position tracked via `cache_seqlens` tensor (int32, per batch element) +- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()` + +### Results + +- **~9% improvement in tok/sec** during training out of the box +- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048) +- FA3 supports sliding window via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune but keeping full context for now. + +--- + ## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas) Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections. diff --git a/nanochat/engine.py b/nanochat/engine.py index d4367fb..53fdec5 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -82,83 +82,54 @@ def use_calculator(expr): # ----------------------------------------------------------------------------- class KVCache: """ - Works hand-in-hand with the GPT model to maintain the KV cache. - Note that the .pos advances automatically after the last layer of the Transformer inserts. + KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API. + + Key differences from FA2-style cache: + - Tensors are (B, T, H, D) not (B, H, T, D) + - FA3 updates the cache in-place during flash_attn_with_kvcache + - Position tracked per batch element via cache_seqlens tensor """ - def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers): - # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer. 
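To make the log entry above concrete, here is a minimal sketch of the new call pattern: fetch the pre-built FA3 kernel from the Hub via the `kernels` package and call it directly on `(B, T, H, D)` tensors. The Hub repo name and `causal=True` usage mirror the diff; the shapes, dtype, and standalone setup are illustrative and assume an H100-class GPU.

```python
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # silence Hub progress bars, as in gpt.py

import torch
from kernels import get_kernel

# Pre-built FA3 wheels from the Hub -- no source build needed
flash_attn = get_kernel("varunneal/flash-attention-3").flash_attn_interface

# Illustrative sizes: batch 2, 128 tokens, 8 heads of dim 64, bf16
B, T, H, D = 2, 128, 8, 64
q = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)

# FA3 consumes the (B, T, H, D) projection layout directly -- no transpose to (B, H, T, D)
y = flash_attn.flash_attn_func(q, k, v, causal=True)
print(y.shape)  # (B, T, H, D), ready to be flattened back to (B, T, C)
```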
- self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim) - self.kv_cache = None - self.pos = 0 # current position in time in the cache + def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16): + self.batch_size = batch_size + self.max_seq_len = seq_len + self.n_layers = num_layers + self.n_heads = num_heads + self.head_dim = head_dim + # Pre-allocate cache tensors: (n_layers, B, T, H, D) + self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + # Current sequence length per batch element (FA3 needs int32) + self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) def reset(self): - self.pos = 0 + """Reset cache to empty state.""" + self.cache_seqlens.zero_() def get_pos(self): - return self.pos + """Get current position (assumes all batch elements at same position).""" + return self.cache_seqlens[0].item() + + def get_layer_cache(self, layer_idx): + """Return (k_cache, v_cache) views for a specific layer.""" + return self.k_cache[layer_idx], self.v_cache[layer_idx] + + def advance(self, num_tokens): + """Advance the cache position by num_tokens.""" + self.cache_seqlens += num_tokens def prefill(self, other): """ - Prefill given another KV cache. Optionally expand along batch dim. - This is used when we do batch 1 prefill and then want to generate - multiple samples in parallel from there. + Copy cached KV from another cache into this one. + Used when we do batch=1 prefill and then want to generate multiple samples in parallel. """ - # 1) validate the shapes - assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" - assert other.kv_cache is not None, "Cannot prefill with a None KV cache" - - # Extract dimensions explicitly - self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape - other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape - - # Validate dimensions - assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}" - assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}" - assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}" - assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}" - - # Batch size can be expanded (other can be 1, self can be larger) - assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)" - - # Sequence length: self must be longer than other - assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}" - - # 2) initialize the cache - dtype, device = other.kv_cache.dtype, other.kv_cache.device - self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device) - # 3) copy the data over - self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache - # 4) update the pos - self.pos = other.pos - - def insert_kv(self, layer_idx, k, v): - # Lazy initialize the cache here because we need to know the dtype/device - if self.kv_cache is None: - self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device) - # Insert new keys/values to the cache and return the full cache so far - B, H, T_add, D = k.size() - t0, t1 = self.pos, self.pos + T_add - # Dynamically grow the cache if needed - if t1 > 
self.kv_cache.size(4): - t_needed = t1 + 1024 # as much as we need plus buffer of 1024 - t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024 - additional_shape = list(self.kv_cache.shape) - additional_shape[4] = t_needed - self.kv_cache.size(4) - additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device) - self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() - self.kv_shape = self.kv_cache.shape - # Insert k, v into the cache - self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k - self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v - # Return the full cached keys/values up to current position (as a view) - key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :] - value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :] - # Increment pos after the last layer of the Transformer processes - if layer_idx == self.kv_cache.size(0) - 1: - self.pos = t1 - return key_view, value_view - + assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache" + assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim + assert self.max_seq_len >= other.max_seq_len + other_pos = other.get_pos() + self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :] + self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :] + self.cache_seqlens.fill_(other_pos) # ----------------------------------------------------------------------------- @torch.inference_mode() @@ -219,6 +190,7 @@ class Engine: kv_cache_prefill = KVCache( batch_size=1, seq_len=len(tokens), + device=device, **kv_model_kwargs, ) ids = torch.tensor([tokens], dtype=torch.long, device=device) @@ -230,6 +202,7 @@ class Engine: kv_cache_decode = KVCache( batch_size=num_samples, seq_len=kv_length_hint, + device=device, **kv_model_kwargs, ) kv_cache_decode.prefill(kv_cache_prefill) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 6f4556a..f22ec07 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -9,9 +9,9 @@ Notable features: - no learnable params in rmsnorm - no bias in linear layers - Group-Query Attention (GQA) support for more efficient inference +- Flash Attention 3 integration """ -import math from functools import partial from dataclasses import dataclass @@ -23,6 +23,14 @@ from nanochat.common import get_dist_info, print0 from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW +# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar) +import os +os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" +# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain. +# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal) +from kernels import get_kernel +flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface + @dataclass class GPTConfig: sequence_len: int = 1024 @@ -65,44 +73,36 @@ class CausalSelfAttention(nn.Module): B, T, C = x.size() # Project the input to get queries, keys, and values + # Shape: (B, T, H, D) - FA3's native layout, no transpose needed! 
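The rewritten `KVCache.prefill` above copies a batch=1 prefill cache into a larger decode batch with a single slice assignment, relying on PyTorch broadcasting to expand the batch dimension. A tiny CPU-only sketch of just that mechanism (sizes are illustrative, tensor names are hypothetical):

```python
import torch

n_layers, H, D = 2, 4, 8        # layers, heads, head_dim (illustrative)
t_prefill, num_samples = 16, 3  # prompt length, decode batch size

# "other": batch=1 prefill cache, filled up to t_prefill
src_k = torch.full((n_layers, 1, 32, H, D), 1.0)
# "self": batch=num_samples decode cache with a longer max sequence length
dst_k = torch.zeros(n_layers, num_samples, 64, H, D)

# The same assignment prefill() performs: the RHS batch dim of 1 broadcasts to num_samples
dst_k[:, :, :t_prefill, :, :] = src_k[:, :, :t_prefill, :, :]

assert (dst_k[:, :, :t_prefill] == 1.0).all()  # every decode row received the prompt's KV
assert (dst_k[:, :, t_prefill:] == 0.0).all()  # positions past the prompt remain empty
```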
q = self.c_q(x).view(B, T, self.n_head, self.head_dim) k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) # Apply Rotary Embeddings to queries and keys to get relative positional encoding cos, sin = cos_sin - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) q, k = norm(q), norm(k) # QK norm - q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D) - # Apply KV cache: insert current k,v into cache, get the full view so far - if kv_cache is not None: - k, v = kv_cache.insert_kv(self.layer_idx, k, v) - Tq = q.size(2) # number of queries in this forward pass - Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) - - # Attention: queries attend to keys/values autoregressively. A few cases to handle: - enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired - if kv_cache is None or Tq == Tk: - # During training (no KV cache), attend as usual with causal attention - # And even if there is KV cache, we can still use this simple version when Tq == Tk - y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) - elif Tq == 1: - # During inference but with a single query in this forward pass: - # The query has to attend to all the keys/values in the cache - y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) + # Attention with Flash Attention 3 + # FA3 handles GQA automatically when n_kv_heads < n_heads + if kv_cache is None: + # Training: simple causal attention + y = flash_attn.flash_attn_func(q, k, v, causal=True) else: - # During inference AND we have a chunk of queries in this forward pass: - # First, each query attends to all the cached keys/values (i.e. 
full prefix) - attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask - prefix_len = Tk - Tq - attn_mask[:, :prefix_len] = True - # Then, causal attention within this chunk - attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + # Inference: use flash_attn_with_kvcache which handles cache management + k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) + y = flash_attn.flash_attn_with_kvcache( + q, k_cache, v_cache, + k=k, v=v, + cache_seqlens=kv_cache.cache_seqlens, + causal=True, + ) + # Advance position after last layer processes + if self.layer_idx == kv_cache.n_layers - 1: + kv_cache.advance(T) - # Re-assemble the heads side by side and project back to residual stream - y = y.transpose(1, 2).contiguous().view(B, T, -1) + # Re-assemble the heads and project back to residual stream + y = y.contiguous().view(B, T, -1) y = self.c_proj(y) return y diff --git a/pyproject.toml b/pyproject.toml index 0931ca6..87a967f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "datasets>=4.0.0", "fastapi>=0.117.1", "ipykernel>=7.1.0", + "kernels>=0.11.7", "matplotlib>=3.10.8", "psutil>=7.1.0", "python-dotenv>=1.2.1", diff --git a/tests/test_engine.py b/tests/test_engine.py index 683f89b..9351e5a 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -39,13 +39,9 @@ class MockModel: def forward(self, ids, kv_cache=None): """Return uniform logits so sampling is spread across vocab.""" B, T = ids.shape - # Simulate what a real transformer does: insert k,v into the cache for each layer + # With FA3, flash_attn_with_kvcache updates cache in-place and we advance position if kv_cache is not None: - head_dim = self.config.n_embd // self.config.n_head - for layer_idx in range(self.config.n_layer): - k = torch.zeros(B, self.config.n_kv_head, T, head_dim) - v = torch.zeros(B, self.config.n_kv_head, T, head_dim) - kv_cache.insert_kv(layer_idx, k, v) + kv_cache.advance(T) # Uniform logits -> equal probability for all tokens logits = torch.zeros(B, T, self.vocab_size) return logits @@ -85,16 +81,11 @@ class ByteTokenizer: byte_tokens = [t for t in tokens if t < 256] return bytes(byte_tokens).decode("utf-8", errors="replace") -def test_kv_cache_resize(): - """ - The KV cache was not resized correctly, more information here: - https://github.com/karpathy/nanochat/pull/186 - This test reproduces the issue and will be merged alongside the fix. 
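Putting the inference pieces above together, the decode path reduces to one FA3 call per layer against in-place cache views, with the position advanced once after the last layer. A condensed sketch: the random tensors stand in for real q/k/v projections; the kernel, argument names, and layout follow the diff, everything else is illustrative and assumes a CUDA GPU.

```python
import torch
from kernels import get_kernel

flash_attn = get_kernel("varunneal/flash-attention-3").flash_attn_interface

# Illustrative sizes: 2 layers, batch 1, cache capacity 64, 4 heads of dim 32
n_layers, B, T_max, H, D = 2, 1, 64, 4, 32
dev, dt = "cuda", torch.bfloat16

k_cache = torch.zeros(n_layers, B, T_max, H, D, device=dev, dtype=dt)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=dev)  # per-row position, int32

def forward_step(T):
    """One forward pass over T new tokens, mirroring the attention call in the diff."""
    for layer_idx in range(n_layers):
        q = torch.randn(B, T, H, D, device=dev, dtype=dt)  # stand-in projections
        k = torch.randn(B, T, H, D, device=dev, dtype=dt)
        v = torch.randn(B, T, H, D, device=dev, dtype=dt)
        y = flash_attn.flash_attn_with_kvcache(
            q, k_cache[layer_idx], v_cache[layer_idx],  # per-layer views, updated in-place
            k=k, v=v,                                    # new tokens appended at cache_seqlens
            cache_seqlens=cache_seqlens,
            causal=True,
        )
    cache_seqlens += T  # advance once, after the last layer has inserted its K/V

forward_step(16)  # prefill a 16-token prompt
forward_step(1)   # then decode one token at a time
print(int(cache_seqlens[0]))  # 17
```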
- """ - +def test_kv_cache_basic(): + """Test basic KVCache functionality for FA3.""" batch_size = 2 num_heads = 3 - seq_len = 4 + seq_len = 64 head_dim = 5 num_layers = 6 @@ -103,45 +94,64 @@ def test_kv_cache_resize(): num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, - num_layers=num_layers + num_layers=num_layers, + device="cpu", ) - # Insert a single token with a distinct fill value to all layers - def insert_token(token_idx): - for layer_idx in range(num_layers): - k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32) - v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32) - kv_cache.insert_kv(layer_idx, k, v) + # Check initial state + assert kv_cache.get_pos() == 0 + assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim) + assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim) - # Insert 4 tokens (fills the initial seq_len=4) - for i in range(4): - insert_token(i) + # Test advance + kv_cache.advance(10) + assert kv_cache.get_pos() == 10 - # Record the original state of the cache - original_cache = kv_cache.kv_cache.clone() - original_seq_len = original_cache.shape[4] + kv_cache.advance(5) + assert kv_cache.get_pos() == 15 - # Insert the 5th token, which will trigger a resize - insert_token(4) - # Verify that the cache actually resized - new_seq_len = kv_cache.kv_cache.shape[4] - assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}" + # Test reset + kv_cache.reset() + assert kv_cache.get_pos() == 0 - # Verify that the original 4 tokens are still intact after resize - for layer_idx in range(num_layers): - for token_idx in range(4): - # Check that resized cache matches expected values - expected_k = float(token_idx) - expected_v = float(token_idx * 100) - actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :] - actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :] - assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}" - assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}" - # And that the original cache matches resized cache - original_k = original_cache[layer_idx, 0, :, :, token_idx, :] - original_v = original_cache[layer_idx, 1, :, :, token_idx, :] - assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original" - assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original" + # Test get_layer_cache returns correct views + k_layer0, v_layer0 = kv_cache.get_layer_cache(0) + assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim) + assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim) + + +def test_kv_cache_prefill(): + """Test KVCache.prefill() copies data correctly.""" + batch_size = 1 + num_heads = 4 + head_dim = 8 + num_layers = 2 + + # Create source cache and advance it + src_cache = KVCache( + batch_size=batch_size, num_heads=num_heads, seq_len=32, + head_dim=head_dim, num_layers=num_layers, device="cpu", + ) + # Write some data to source cache + src_cache.k_cache[0, 0, :16, :, :] = 1.0 + src_cache.v_cache[0, 0, :16, :, :] = 2.0 + src_cache.advance(16) + + # Create destination cache with larger seq_len + dst_cache = KVCache( + batch_size=batch_size, num_heads=num_heads, 
seq_len=64, + head_dim=head_dim, num_layers=num_layers, device="cpu", + ) + + # Prefill + dst_cache.prefill(src_cache) + + # Check position was copied + assert dst_cache.get_pos() == 16 + + # Check data was copied + assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all() + assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all() def test_multi_sample_first_token_diversity(): diff --git a/uv.lock b/uv.lock index 63b2c01..b168a2f 100644 --- a/uv.lock +++ b/uv.lock @@ -1089,6 +1089,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, ] +[[package]] +name = "kernels" +version = "0.11.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/c8/2d4fea16366d34069af6d4c4f61218f55e5d0daea5d4c24d58849e9fd626/kernels-0.11.7.tar.gz", hash = "sha256:99c3aa518965518902f4dc26053d6051f06abc904ae33d9486c28674a2ea0fa5", size = 50282, upload-time = "2026-01-08T15:41:57.383Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/49/e62183353374ec71306ef354781233ac8d12fdfd1cf3d47c875055a99603/kernels-0.11.7-py3-none-any.whl", hash = "sha256:1421791b1e501fcb0a7f0a4d763c5385591756d9d6ed12ed8baa1e0d71bcd21a", size = 46501, upload-time = "2026-01-08T15:41:55.784Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.9" @@ -1478,6 +1493,7 @@ dependencies = [ { name = "datasets" }, { name = "fastapi" }, { name = "ipykernel" }, + { name = "kernels" }, { name = "matplotlib" }, { name = "psutil" }, { name = "python-dotenv" }, @@ -1518,6 +1534,7 @@ requires-dist = [ { name = "datasets", specifier = ">=4.0.0" }, { name = "fastapi", specifier = ">=0.117.1" }, { name = "ipykernel", specifier = ">=7.1.0" }, + { name = "kernels", specifier = ">=0.11.7" }, { name = "matplotlib", specifier = ">=3.10.8" }, { name = "psutil", specifier = ">=7.1.0" }, { name = "python-dotenv", specifier = ">=1.2.1" },
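Finally, the sliding-window option mentioned in the log entry (not enabled by this change) would slot into the same training-time call. A hedged sketch of what turning it on might look like, following the `window_size=(left, 0)` convention stated in the log; the window width of 1024 is an arbitrary illustrative value:

```python
import torch
from kernels import get_kernel

flash_attn = get_kernel("varunneal/flash-attention-3").flash_attn_interface

B, T, H, D = 2, 2048, 8, 64
q = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)

# Full-context causal attention: what this change ships
y_full = flash_attn.flash_attn_func(q, k, v, causal=True)

# Hypothetical sliding-window variant: each query attends to at most the last 1024 keys
y_local = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))
```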