implement flash attention 3 fallback to pytorch sdpa by touching as few lines of code as possible in main files and keeping all implementation to a single file. add tests. add helpful warning messages for the user.

This commit is contained in:
Andrej Karpathy
2026-01-16 17:37:51 +00:00
parent 50413d2d67
commit 8203efa919
3 changed files with 354 additions and 9 deletions

View File

@@ -23,13 +23,8 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
@dataclass
class GPTConfig:
@@ -87,8 +82,7 @@ class CausalSelfAttention(nn.Module):
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window

View File

@@ -27,6 +27,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_model
print_banner()
@@ -86,6 +87,18 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
if HAS_FA3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without FA3")
if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
print0("!" * 80)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)

View File

@@ -0,0 +1,338 @@
"""
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
Run: python -m pytest tests/test_attention_fallback.py -v -s
Note on test structure:
Tests are split into two classes due to dtype/device constraints:
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
and verify they produce identical results. These require a Hopper GPU (FA3 only
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
"""
import torch
import pytest
import nanochat.flash_attention as fa_module
from nanochat.flash_attention import flash_attn, HAS_FA3
from nanochat.engine import KVCache
def set_impl(impl):
"""Set the implementation override ('fa3', 'sdpa', or None for auto)."""
fa_module._override_impl = impl
def run_both_impls(fn):
"""Run a function with both FA3 and SDPA, return both outputs."""
set_impl('fa3')
out_fa3 = fn()
set_impl('sdpa')
out_sdpa = fn()
set_impl(None) # reset
return out_fa3, out_sdpa
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
"""Assert two tensors are close, with helpful error message."""
max_diff = (t1 - t2).abs().max().item()
mean_diff = (t1 - t2).abs().mean().item()
assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \
f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
return max_diff, mean_diff
# =============================================================================
# FA3 vs SDPA comparison tests (require Hopper GPU)
# =============================================================================
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
class TestFA3VsSDPA:
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
DEVICE = "cuda"
DTYPE = torch.bfloat16
def test_basic_causal(self):
"""Basic causal attention."""
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_full_context(self):
"""Full context (window_size=-1)."""
B, T, H, D = 2, 128, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_sliding_window(self):
"""Sliding window attention."""
B, T, H, D = 2, 128, 4, 32
window = 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_gqa(self):
"""Group Query Attention (fewer KV heads than Q heads)."""
B, T, D = 2, 64, 32
n_heads = 8
n_kv_heads = 2
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_larger_model(self):
"""Larger dimensions closer to real model."""
B, T, H, D = 4, 256, 12, 64
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_prefill(self):
"""Test prefill (inserting multiple tokens into empty cache)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_single_token(self):
"""Test single token generation (cache already has content)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_cache[:, :T_prefill, :, :] = k_init
v_cache[:, :T_prefill, :, :] = v_init
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_backward_gradients_match(self):
"""Verify gradients are similar between FA3 and SDPA."""
B, T, H, D = 2, 32, 4, 16
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
q = q_data.clone().requires_grad_(True)
k = k_data.clone().requires_grad_(True)
v = v_data.clone().requires_grad_(True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
set_impl('fa3')
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
set_impl('sdpa')
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
set_impl(None)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
# =============================================================================
# SDPA-only tests (run on any device)
# =============================================================================
class TestSDPAOnly:
"""Test SDPA fallback works correctly. Runs on any device."""
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def test_basic_forward(self):
"""Test SDPA forward pass produces valid output."""
set_impl('sdpa')
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
assert y.shape == (B, T, H, D)
assert not torch.isnan(y).any(), "Output contains NaN"
set_impl(None)
def test_backward(self):
"""Test gradients flow through SDPA."""
set_impl('sdpa')
B, T, H, D = 2, 32, 4, 16
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
assert q.grad is not None, "No gradient for q"
assert k.grad is not None, "No gradient for k"
assert v.grad is not None, "No gradient for v"
assert not torch.isnan(q.grad).any(), "NaN in q gradient"
set_impl(None)
def test_kvcache(self):
"""Test SDPA with KV cache."""
set_impl('sdpa')
B, T_max, H, D = 2, 64, 4, 32
n_layers = 1
cache = KVCache(
batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
)
k_cache, v_cache = cache.get_layer_cache(0)
# Prefill
T_prefill = 16
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v,
cache_seqlens=cache.cache_seqlens,
causal=True, window_size=(T_max, 0)
)
cache.advance(T_prefill)
assert y.shape == (B, T_prefill, H, D)
assert cache.get_pos() == T_prefill
# Generate single token
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
y_single = flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache.cache_seqlens,
causal=True, window_size=(T_max, 0)
)
cache.advance(1)
assert y_single.shape == (B, 1, H, D)
assert cache.get_pos() == T_prefill + 1
set_impl(None)
# =============================================================================
# Override mechanism tests
# =============================================================================
class TestOverrideMechanism:
"""Test that the override mechanism works correctly."""
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
def test_override_fa3(self):
"""Test that override='fa3' uses FA3."""
set_impl('fa3')
assert fa_module._use_fa3() == True
set_impl(None)
def test_override_sdpa(self):
"""Test that override='sdpa' uses SDPA."""
set_impl('sdpa')
assert fa_module._use_fa3() == False
set_impl(None)
def test_override_auto(self):
"""Test that override=None uses auto-detection."""
set_impl(None)
assert fa_module._use_fa3() == HAS_FA3
if __name__ == "__main__":
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name()}")
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
print(f"HAS_FA3: {HAS_FA3}")
print()
pytest.main([__file__, "-v", "-s"])