integrate Flash Attention 3. +9% tok_per_sec for d12 out of the box, even with ctx as low as 2048, nice. also, ready to tune sliding windows, which should be huge

Andrej Karpathy
2026-01-11 20:33:19 +00:00
parent 201d705957
commit 2ff7d51252
6 changed files with 177 additions and 143 deletions

View File

@@ -4,6 +4,39 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-01-11: Flash Attention 3 Integration
Replaced PyTorch's `scaled_dot_product_attention` (which dispatches to FA2) with Flash Attention 3 for training and inference.
### Changes Made
**1. FA3 via `kernels` package**
- Official FA3 is "beta" and requires building from source (painful)
- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
- Loads pre-built wheels, works out of the box on H100
**2. Simplified attention code**
- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
- Removed 3 separate FA2 code paths (training, single-token, chunk inference)
- GQA handled automatically when n_kv_heads < n_heads
**3. Rewrote KVCache for FA3**
- Old format: `(num_layers, 2, B, H, T, D)` combined tensor
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)`
- FA3 updates the cache in-place during `flash_attn_with_kvcache`
- Position tracked via a `cache_seqlens` tensor (int32, per batch element)
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()` (call pattern sketched below)
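In condensed form, the per-layer call pattern now looks like this (a sketch condensed from the `gpt.py` diff below; `q`, `k`, `v` are this layer's projections in FA3's `(B, T, H, D)` layout and `T` is the number of new tokens):

```python
# Condensed sketch of the FA3 call pattern (see the full gpt.py diff below).
if kv_cache is None:
    # Training: plain causal attention; GQA is handled automatically when n_kv_heads < n_heads
    y = flash_attn.flash_attn_func(q, k, v, causal=True)
else:
    # Inference: FA3 writes k, v into the pre-allocated cache in-place and attends over it
    k_cache, v_cache = kv_cache.get_layer_cache(layer_idx)   # (B, max_T, H, D) views
    y = flash_attn.flash_attn_with_kvcache(
        q, k_cache, v_cache,
        k=k, v=v,                              # new keys/values inserted at cache_seqlens
        cache_seqlens=kv_cache.cache_seqlens,  # int32, current length per batch element
        causal=True,
    )
    if layer_idx == kv_cache.n_layers - 1:     # last layer advances the shared position
        kv_cache.advance(T)
```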
### Results
- **~9% improvement in tok/sec** during training, out of the box
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
- FA3 supports sliding-window attention via `window_size=(left, 0)`, which is huge and should give further improvements. The window size is ready to tune (sketched below), but we keep full context for now.
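Tuning the window later would presumably just mean passing a window size through to the same call. A hypothetical sketch; `window_left` below is an illustrative value, nothing in the code sets it yet:

```python
# Hypothetical sliding-window config sketch; window_left is illustrative, not in the repo.
window_left = 1024  # each query attends to at most the previous 1024 tokens
y = flash_attn.flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(window_left, 0),  # (left, right) window; current code keeps full context
)
```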
---
## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)
Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
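A minimal sketch of the general shape of this idea, assuming the modded-nanogpt-style formulation where each layer re-mixes the running residual stream with the initial embedding `x0` via two learnable scalars; the names and exact parameterization here are illustrative, not the repo's actual code:

```python
import torch
import torch.nn as nn

class ResidualMix(nn.Module):
    """Illustrative only: learnable per-layer scalars that blend the running
    residual stream x with the initial token embedding x0 before each block."""
    def __init__(self):
        super().__init__()
        self.resid_lambda = nn.Parameter(torch.tensor(1.0))  # weight on the residual stream
        self.x0_lambda = nn.Parameter(torch.tensor(0.0))     # weight on the layer-0 embedding x0

    def forward(self, x, x0):
        return self.resid_lambda * x + self.x0_lambda * x0
```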

View File

@@ -82,83 +82,54 @@ def use_calculator(expr):
# -----------------------------------------------------------------------------
class KVCache:
"""
Works hand-in-hand with the GPT model to maintain the KV cache.
Note that the .pos advances automatically after the last layer of the Transformer inserts.
KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
Key differences from FA2-style cache:
- Tensors are (B, T, H, D) not (B, H, T, D)
- FA3 updates the cache in-place during flash_attn_with_kvcache
- Position tracked per batch element via cache_seqlens tensor
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0 # current position in time in the cache
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
self.batch_size = batch_size
self.max_seq_len = seq_len
self.n_layers = num_layers
self.n_heads = num_heads
self.head_dim = head_dim
# Pre-allocate cache tensors: (n_layers, B, T, H, D)
self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# Current sequence length per batch element (FA3 needs int32)
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
def reset(self):
self.pos = 0
"""Reset cache to empty state."""
self.cache_seqlens.zero_()
def get_pos(self):
return self.pos
"""Get current position (assumes all batch elements at same position)."""
return self.cache_seqlens[0].item()
def get_layer_cache(self, layer_idx):
"""Return (k_cache, v_cache) views for a specific layer."""
return self.k_cache[layer_idx], self.v_cache[layer_idx]
def advance(self, num_tokens):
"""Advance the cache position by num_tokens."""
self.cache_seqlens += num_tokens
def prefill(self, other):
"""
Prefill given another KV cache. Optionally expand along batch dim.
This is used when we do batch 1 prefill and then want to generate
multiple samples in parallel from there.
Copy cached KV from another cache into this one.
Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
"""
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
# 3) copy the data over
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
# 4) update the pos
self.pos = other.pos
def insert_kv(self, layer_idx, k, v):
# Lazy initialize the cache here because we need to know the dtype/device
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
# Insert new keys/values to the cache and return the full cache so far
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
# Return the full cached keys/values up to current position (as a view)
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
# Increment pos after the last layer of the Transformer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_view
assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
assert self.max_seq_len >= other.max_seq_len
other_pos = other.get_pos()
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
self.cache_seqlens.fill_(other_pos)
# -----------------------------------------------------------------------------
@torch.inference_mode()
@@ -219,6 +190,7 @@ class Engine:
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
device=device,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@@ -230,6 +202,7 @@ class Engine:
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
device=device,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)

View File

@@ -9,9 +9,9 @@ Notable features:
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Flash Attention 3 integration
"""
import math
from functools import partial
from dataclasses import dataclass
@@ -23,6 +23,14 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
@dataclass
class GPTConfig:
sequence_len: int = 1024
@@ -65,44 +73,36 @@ class CausalSelfAttention(nn.Module):
B, T, C = x.size()
# Project the input to get queries, keys, and values
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
if kv_cache is None:
# Training: simple causal attention
y = flash_attn.flash_attn_func(q, k, v, causal=True)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
)
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
return y

View File

@@ -8,6 +8,7 @@ dependencies = [
"datasets>=4.0.0",
"fastapi>=0.117.1",
"ipykernel>=7.1.0",
"kernels>=0.11.7",
"matplotlib>=3.10.8",
"psutil>=7.1.0",
"python-dotenv>=1.2.1",

View File

@@ -39,13 +39,9 @@ class MockModel:
def forward(self, ids, kv_cache=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
# Simulate what a real transformer does: insert k,v into the cache for each layer
# With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
if kv_cache is not None:
head_dim = self.config.n_embd // self.config.n_head
for layer_idx in range(self.config.n_layer):
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
kv_cache.insert_kv(layer_idx, k, v)
kv_cache.advance(T)
# Uniform logits -> equal probability for all tokens
logits = torch.zeros(B, T, self.vocab_size)
return logits
@@ -85,16 +81,11 @@ class ByteTokenizer:
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
def test_kv_cache_resize():
"""
The KV cache was not resized correctly, more information here:
https://github.com/karpathy/nanochat/pull/186
This test reproduces the issue and will be merged alongside the fix.
"""
def test_kv_cache_basic():
"""Test basic KVCache functionality for FA3."""
batch_size = 2
num_heads = 3
seq_len = 4
seq_len = 64
head_dim = 5
num_layers = 6
@@ -103,45 +94,64 @@ def test_kv_cache_resize():
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers
num_layers=num_layers,
device="cpu",
)
# Insert a single token with a distinct fill value to all layers
def insert_token(token_idx):
for layer_idx in range(num_layers):
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
kv_cache.insert_kv(layer_idx, k, v)
# Check initial state
assert kv_cache.get_pos() == 0
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
# Insert 4 tokens (fills the initial seq_len=4)
for i in range(4):
insert_token(i)
# Test advance
kv_cache.advance(10)
assert kv_cache.get_pos() == 10
# Record the original state of the cache
original_cache = kv_cache.kv_cache.clone()
original_seq_len = original_cache.shape[4]
kv_cache.advance(5)
assert kv_cache.get_pos() == 15
# Insert the 5th token, which will trigger a resize
insert_token(4)
# Verify that the cache actually resized
new_seq_len = kv_cache.kv_cache.shape[4]
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
# Test reset
kv_cache.reset()
assert kv_cache.get_pos() == 0
# Verify that the original 4 tokens are still intact after resize
for layer_idx in range(num_layers):
for token_idx in range(4):
# Check that resized cache matches expected values
expected_k = float(token_idx)
expected_v = float(token_idx * 100)
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
# And that the original cache matches resized cache
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
# Test get_layer_cache returns correct views
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
def test_kv_cache_prefill():
"""Test KVCache.prefill() copies data correctly."""
batch_size = 1
num_heads = 4
head_dim = 8
num_layers = 2
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
src_cache.v_cache[0, 0, :16, :, :] = 2.0
src_cache.advance(16)
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Prefill
dst_cache.prefill(src_cache)
# Check position was copied
assert dst_cache.get_pos() == 16
# Check data was copied
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
def test_multi_sample_first_token_diversity():

uv.lock generated
View File

@@ -1089,6 +1089,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" },
]
[[package]]
name = "kernels"
version = "0.11.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d6/c8/2d4fea16366d34069af6d4c4f61218f55e5d0daea5d4c24d58849e9fd626/kernels-0.11.7.tar.gz", hash = "sha256:99c3aa518965518902f4dc26053d6051f06abc904ae33d9486c28674a2ea0fa5", size = 50282, upload-time = "2026-01-08T15:41:57.383Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/49/e62183353374ec71306ef354781233ac8d12fdfd1cf3d47c875055a99603/kernels-0.11.7-py3-none-any.whl", hash = "sha256:1421791b1e501fcb0a7f0a4d763c5385591756d9d6ed12ed8baa1e0d71bcd21a", size = 46501, upload-time = "2026-01-08T15:41:55.784Z" },
]
[[package]]
name = "kiwisolver"
version = "1.4.9"
@@ -1478,6 +1493,7 @@ dependencies = [
{ name = "datasets" },
{ name = "fastapi" },
{ name = "ipykernel" },
{ name = "kernels" },
{ name = "matplotlib" },
{ name = "psutil" },
{ name = "python-dotenv" },
@@ -1518,6 +1534,7 @@ requires-dist = [
{ name = "datasets", specifier = ">=4.0.0" },
{ name = "fastapi", specifier = ">=0.117.1" },
{ name = "ipykernel", specifier = ">=7.1.0" },
{ name = "kernels", specifier = ">=0.11.7" },
{ name = "matplotlib", specifier = ">=3.10.8" },
{ name = "psutil", specifier = ">=7.1.0" },
{ name = "python-dotenv", specifier = ">=1.2.1" },