Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30 04:22:02 +00:00)
integrate Flash Attention 3. +9% tok_per_sec for d12 even with ctx as low as 2048, out of the box, nice. also, ready to tune windows, huge
dev/LOG.md +33
@@ -4,6 +4,39 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026

---

## 2026-01-11: Flash Attention 3 Integration

Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.

### Changes Made

**1. FA3 via `kernels` package**
- Official FA3 is "beta" and requires building from source (painful)
- Using the `kernels` package from the HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
- Loads pre-built wheels and works out of the box on H100 (see the loading sketch below)
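
A minimal loading sketch, mirroring the import block added to the model code further down in this diff:

```python
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # keep the HF Hub download quiet

from kernels import get_kernel

# Pulls a pre-built FA3 wheel from the HuggingFace Hub on first use, no source build needed
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
```
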
**2. Simplified attention code**
- FA3 uses the `(B, T, H, D)` layout, matching our projection output directly - no transpose needed
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
- Removed 3 separate FA2 code paths (training, single-token decode, chunked inference)
- GQA handled automatically when n_kv_heads < n_heads (see the call sketch after this list)
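
A rough sketch of the two call paths with illustrative shapes (not the real d12 config); FA3 broadcasts the smaller number of KV heads for GQA:

```python
import torch
from kernels import get_kernel

flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface

B, T, H, Hkv, D = 4, 2048, 8, 4, 64  # illustrative sizes only
bf16 = dict(device="cuda", dtype=torch.bfloat16)
q = torch.randn(B, T, H, D, **bf16)
k = torch.randn(B, T, Hkv, D, **bf16)  # fewer KV heads -> GQA, handled by FA3
v = torch.randn(B, T, Hkv, D, **bf16)

# Training path: causal attention directly on the (B, T, H, D) layout, no transposes
y = flash_attn.flash_attn_func(q, k, v, causal=True)

# Inference path: one call appends the new k/v into a pre-allocated cache in-place
# and attends over everything cached so far
max_T = 4096
k_cache = torch.zeros(B, max_T, Hkv, D, **bf16)
v_cache = torch.zeros(B, max_T, Hkv, D, **bf16)
cache_seqlens = torch.zeros(B, dtype=torch.int32, device="cuda")  # current length per row
q1 = torch.randn(B, 1, H, D, **bf16)   # one new token per sequence
k1 = torch.randn(B, 1, Hkv, D, **bf16)
v1 = torch.randn(B, 1, Hkv, D, **bf16)
y1 = flash_attn.flash_attn_with_kvcache(
    q1, k_cache, v_cache, k=k1, v=v1,
    cache_seqlens=cache_seqlens, causal=True,
)
cache_seqlens += 1  # the caller advances the position after the step
```
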
**3. Rewrote KVCache for FA3**
- Old format: one combined tensor of shape `(num_layers, 2, B, H, T, D)`
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)` (shape sketch below)
- FA3 updates the cache in-place during `flash_attn_with_kvcache`
- Position tracked via a `cache_seqlens` tensor (int32, per batch element)
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()`
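
For concreteness, a shape-only sketch of the layout change (illustrative sizes; the real buffers are allocated in `KVCache.__init__` in the diff below):

```python
import torch

L, B, T, H, D = 12, 2, 256, 8, 64  # illustrative sizes

# Old FA2-style cache: one combined tensor, K/V packed along dim 1, heads before time
old_kv = torch.empty(L, 2, B, H, T, D)

# New FA3-style cache: separate K and V, time before heads, matching FA3's (B, T, H, D)
k_cache = torch.zeros(L, B, T, H, D, dtype=torch.bfloat16)
v_cache = torch.zeros(L, B, T, H, D, dtype=torch.bfloat16)
cache_seqlens = torch.zeros(B, dtype=torch.int32)  # filled length per batch row

# flash_attn_with_kvcache takes per-layer views and writes new tokens into them in-place
k_l, v_l = k_cache[0], v_cache[0]  # (B, T, H, D) views for layer 0
```
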
### Results

- **~9% improvement in tok/sec** during training, out of the box
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
- FA3 supports sliding-window attention via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune, but keeping full context for now (see the sketch after this list).
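
A hedged sketch of what the window tuning would look like, reusing `flash_attn`, `q`, `k`, `v` from the call sketch above; the left-window value is purely illustrative, not a tuned setting:

```python
# Full context (current default): every query attends to the whole causal prefix
y_full = flash_attn.flash_attn_func(q, k, v, causal=True)

# Sliding window: window_size=(left, 0) limits each query to the previous `left` tokens
y_win = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))
```
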
---

## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)

Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.

@@ -82,83 +82,54 @@ def use_calculator(expr):
 # -----------------------------------------------------------------------------
 class KVCache:
     """
-    Works hand-in-hand with the GPT model to maintain the KV cache.
-    Note that the .pos advances automatically after the last layer of the Transformer inserts.
+    KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
+
+    Key differences from FA2-style cache:
+    - Tensors are (B, T, H, D) not (B, H, T, D)
+    - FA3 updates the cache in-place during flash_attn_with_kvcache
+    - Position tracked per batch element via cache_seqlens tensor
     """
 
-    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
-        # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
-        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
-        self.kv_cache = None
-        self.pos = 0 # current position in time in the cache
+    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
+        self.batch_size = batch_size
+        self.max_seq_len = seq_len
+        self.n_layers = num_layers
+        self.n_heads = num_heads
+        self.head_dim = head_dim
+        # Pre-allocate cache tensors: (n_layers, B, T, H, D)
+        self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        # Current sequence length per batch element (FA3 needs int32)
+        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
 
     def reset(self):
-        self.pos = 0
+        """Reset cache to empty state."""
+        self.cache_seqlens.zero_()
 
     def get_pos(self):
-        return self.pos
+        """Get current position (assumes all batch elements at same position)."""
+        return self.cache_seqlens[0].item()
+
+    def get_layer_cache(self, layer_idx):
+        """Return (k_cache, v_cache) views for a specific layer."""
+        return self.k_cache[layer_idx], self.v_cache[layer_idx]
+
+    def advance(self, num_tokens):
+        """Advance the cache position by num_tokens."""
+        self.cache_seqlens += num_tokens
 
     def prefill(self, other):
         """
-        Prefill given another KV cache. Optionally expand along batch dim.
-        This is used when we do batch 1 prefill and then want to generate
-        multiple samples in parallel from there.
+        Copy cached KV from another cache into this one.
+        Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
         """
-        # 1) validate the shapes
-        assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
-        assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
-        # Extract dimensions explicitly
-        self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
-        other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
-        # Validate dimensions
-        assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
-        assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
-        assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
-        assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
-
-        # Batch size can be expanded (other can be 1, self can be larger)
-        assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
-
-        # Sequence length: self must be longer than other
-        assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
-
-        # 2) initialize the cache
-        dtype, device = other.kv_cache.dtype, other.kv_cache.device
-        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
-        # 3) copy the data over
-        self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
-        # 4) update the pos
-        self.pos = other.pos
-
-    def insert_kv(self, layer_idx, k, v):
-        # Lazy initialize the cache here because we need to know the dtype/device
-        if self.kv_cache is None:
-            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
-        # Insert new keys/values to the cache and return the full cache so far
-        B, H, T_add, D = k.size()
-        t0, t1 = self.pos, self.pos + T_add
-        # Dynamically grow the cache if needed
-        if t1 > self.kv_cache.size(4):
-            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
-            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
-            additional_shape = list(self.kv_cache.shape)
-            additional_shape[4] = t_needed - self.kv_cache.size(4)
-            additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
-            self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
-            self.kv_shape = self.kv_cache.shape
-        # Insert k, v into the cache
-        self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
-        self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
-        # Return the full cached keys/values up to current position (as a view)
-        key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
-        value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
-        # Increment pos after the last layer of the Transformer processes
-        if layer_idx == self.kv_cache.size(0) - 1:
-            self.pos = t1
-        return key_view, value_view
+        assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
+        assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
+        assert self.max_seq_len >= other.max_seq_len
+        other_pos = other.get_pos()
+        self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
+        self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
+        self.cache_seqlens.fill_(other_pos)
 
 
 # -----------------------------------------------------------------------------
 @torch.inference_mode()
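
For orientation, a sketch of how the Engine below wires this up (a batch=1 prefill cache, then a wider decode cache); the `nanochat.engine` import path and the tiny sizes are assumptions for illustration:

```python
import torch
from nanochat.engine import KVCache  # assumed import path for this class

device = "cuda" if torch.cuda.is_available() else "cpu"
common = dict(num_heads=4, head_dim=64, num_layers=2, device=device)

# 1) Prefill the prompt with batch size 1; during the forward pass each attention layer
#    writes into its (B, T, H, D) views in-place and the last layer calls advance(T).
prefill_cache = KVCache(batch_size=1, seq_len=128, **common)

# 2) Expand to several samples: copy the prompt's KV into a wider, longer cache.
decode_cache = KVCache(batch_size=4, seq_len=1024, **common)
decode_cache.prefill(prefill_cache)

# 3) Decode: per layer, hand FA3 the layer's views, then advance once per generated token.
k_l, v_l = decode_cache.get_layer_cache(0)  # (B, T, H, D) views for layer 0
decode_cache.advance(1)
print(decode_cache.get_pos(), decode_cache.cache_seqlens.dtype)  # -> 1 torch.int32
```
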
@@ -219,6 +190,7 @@ class Engine:
         kv_cache_prefill = KVCache(
             batch_size=1,
             seq_len=len(tokens),
+            device=device,
             **kv_model_kwargs,
         )
         ids = torch.tensor([tokens], dtype=torch.long, device=device)
@@ -230,6 +202,7 @@ class Engine:
         kv_cache_decode = KVCache(
             batch_size=num_samples,
             seq_len=kv_length_hint,
+            device=device,
             **kv_model_kwargs,
         )
         kv_cache_decode.prefill(kv_cache_prefill)

@@ -9,9 +9,9 @@ Notable features:
 - no learnable params in rmsnorm
 - no bias in linear layers
 - Group-Query Attention (GQA) support for more efficient inference
+- Flash Attention 3 integration
 """
 
-import math
 from functools import partial
 from dataclasses import dataclass
 
@@ -23,6 +23,14 @@ from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
 
+# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
+import os
+os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
+# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
+# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
+from kernels import get_kernel
+flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+
 @dataclass
 class GPTConfig:
     sequence_len: int = 1024
@@ -65,44 +73,36 @@ class CausalSelfAttention(nn.Module):
         B, T, C = x.size()
 
         # Project the input to get queries, keys, and values
+        # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
         q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
         k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
         v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
 
         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
         cos, sin = cos_sin
-        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
+        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
         q, k = norm(q), norm(k) # QK norm
-        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
 
-        # Apply KV cache: insert current k,v into cache, get the full view so far
-        if kv_cache is not None:
-            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
-        Tq = q.size(2) # number of queries in this forward pass
-        Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
-
-        # Attention: queries attend to keys/values autoregressively. A few cases to handle:
-        enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
-        if kv_cache is None or Tq == Tk:
-            # During training (no KV cache), attend as usual with causal attention
-            # And even if there is KV cache, we can still use this simple version when Tq == Tk
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
-        elif Tq == 1:
-            # During inference but with a single query in this forward pass:
-            # The query has to attend to all the keys/values in the cache
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
+        # Attention with Flash Attention 3
+        # FA3 handles GQA automatically when n_kv_heads < n_heads
+        if kv_cache is None:
+            # Training: simple causal attention
+            y = flash_attn.flash_attn_func(q, k, v, causal=True)
         else:
-            # During inference AND we have a chunk of queries in this forward pass:
-            # First, each query attends to all the cached keys/values (i.e. full prefix)
-            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
-            prefix_len = Tk - Tq
-            attn_mask[:, :prefix_len] = True
-            # Then, causal attention within this chunk
-            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
-            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
+            # Inference: use flash_attn_with_kvcache which handles cache management
+            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
+            y = flash_attn.flash_attn_with_kvcache(
+                q, k_cache, v_cache,
+                k=k, v=v,
+                cache_seqlens=kv_cache.cache_seqlens,
+                causal=True,
+            )
+            # Advance position after last layer processes
+            if self.layer_idx == kv_cache.n_layers - 1:
+                kv_cache.advance(T)
 
-        # Re-assemble the heads side by side and project back to residual stream
-        y = y.transpose(1, 2).contiguous().view(B, T, -1)
+        # Re-assemble the heads and project back to residual stream
+        y = y.contiguous().view(B, T, -1)
         y = self.c_proj(y)
         return y
 
@@ -8,6 +8,7 @@ dependencies = [
     "datasets>=4.0.0",
     "fastapi>=0.117.1",
     "ipykernel>=7.1.0",
+    "kernels>=0.11.7",
     "matplotlib>=3.10.8",
     "psutil>=7.1.0",
     "python-dotenv>=1.2.1",
@@ -39,13 +39,9 @@ class MockModel:
     def forward(self, ids, kv_cache=None):
         """Return uniform logits so sampling is spread across vocab."""
         B, T = ids.shape
-        # Simulate what a real transformer does: insert k,v into the cache for each layer
+        # With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
         if kv_cache is not None:
-            head_dim = self.config.n_embd // self.config.n_head
-            for layer_idx in range(self.config.n_layer):
-                k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
-                v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
-                kv_cache.insert_kv(layer_idx, k, v)
+            kv_cache.advance(T)
         # Uniform logits -> equal probability for all tokens
         logits = torch.zeros(B, T, self.vocab_size)
         return logits
@@ -85,16 +81,11 @@ class ByteTokenizer:
         byte_tokens = [t for t in tokens if t < 256]
         return bytes(byte_tokens).decode("utf-8", errors="replace")
 
-def test_kv_cache_resize():
-    """
-    The KV cache was not resized correctly, more information here:
-    https://github.com/karpathy/nanochat/pull/186
-    This test reproduces the issue and will be merged alongside the fix.
-    """
-
+def test_kv_cache_basic():
+    """Test basic KVCache functionality for FA3."""
     batch_size = 2
     num_heads = 3
-    seq_len = 4
+    seq_len = 64
     head_dim = 5
     num_layers = 6
 
@@ -103,45 +94,64 @@ def test_kv_cache_resize():
         num_heads=num_heads,
         seq_len=seq_len,
         head_dim=head_dim,
-        num_layers=num_layers
+        num_layers=num_layers,
+        device="cpu",
     )
 
-    # Insert a single token with a distinct fill value to all layers
-    def insert_token(token_idx):
-        for layer_idx in range(num_layers):
-            k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
-            v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
-            kv_cache.insert_kv(layer_idx, k, v)
-
-    # Insert 4 tokens (fills the initial seq_len=4)
-    for i in range(4):
-        insert_token(i)
-
-    # Record the original state of the cache
-    original_cache = kv_cache.kv_cache.clone()
-    original_seq_len = original_cache.shape[4]
-
-    # Insert the 5th token, which will trigger a resize
-    insert_token(4)
-    # Verify that the cache actually resized
-    new_seq_len = kv_cache.kv_cache.shape[4]
-    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
-
-    # Verify that the original 4 tokens are still intact after resize
-    for layer_idx in range(num_layers):
-        for token_idx in range(4):
-            # Check that resized cache matches expected values
-            expected_k = float(token_idx)
-            expected_v = float(token_idx * 100)
-            actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
-            actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
-            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
-            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
-            # And that the original cache matches resized cache
-            original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
-            original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
-            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
-            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
+    # Check initial state
+    assert kv_cache.get_pos() == 0
+    assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
+    assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
+
+    # Test advance
+    kv_cache.advance(10)
+    assert kv_cache.get_pos() == 10
+    kv_cache.advance(5)
+    assert kv_cache.get_pos() == 15
+
+    # Test reset
+    kv_cache.reset()
+    assert kv_cache.get_pos() == 0
+
+    # Test get_layer_cache returns correct views
+    k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
+    assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
+    assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
+
+
+def test_kv_cache_prefill():
+    """Test KVCache.prefill() copies data correctly."""
+    batch_size = 1
+    num_heads = 4
+    head_dim = 8
+    num_layers = 2
+
+    # Create source cache and advance it
+    src_cache = KVCache(
+        batch_size=batch_size, num_heads=num_heads, seq_len=32,
+        head_dim=head_dim, num_layers=num_layers, device="cpu",
+    )
+    # Write some data to source cache
+    src_cache.k_cache[0, 0, :16, :, :] = 1.0
+    src_cache.v_cache[0, 0, :16, :, :] = 2.0
+    src_cache.advance(16)
+
+    # Create destination cache with larger seq_len
+    dst_cache = KVCache(
+        batch_size=batch_size, num_heads=num_heads, seq_len=64,
+        head_dim=head_dim, num_layers=num_layers, device="cpu",
+    )
+
+    # Prefill
+    dst_cache.prefill(src_cache)
+
+    # Check position was copied
+    assert dst_cache.get_pos() == 16
+
+    # Check data was copied
+    assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
+    assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
 
 
 def test_multi_sample_first_token_diversity():
uv.lock +17 (generated)
@@ -1089,6 +1089,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" },
 ]
 
+[[package]]
+name = "kernels"
+version = "0.11.7"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "huggingface-hub" },
+    { name = "packaging" },
+    { name = "pyyaml" },
+    { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d6/c8/2d4fea16366d34069af6d4c4f61218f55e5d0daea5d4c24d58849e9fd626/kernels-0.11.7.tar.gz", hash = "sha256:99c3aa518965518902f4dc26053d6051f06abc904ae33d9486c28674a2ea0fa5", size = 50282, upload-time = "2026-01-08T15:41:57.383Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ab/49/e62183353374ec71306ef354781233ac8d12fdfd1cf3d47c875055a99603/kernels-0.11.7-py3-none-any.whl", hash = "sha256:1421791b1e501fcb0a7f0a4d763c5385591756d9d6ed12ed8baa1e0d71bcd21a", size = 46501, upload-time = "2026-01-08T15:41:55.784Z" },
+]
+
 [[package]]
 name = "kiwisolver"
 version = "1.4.9"
@@ -1478,6 +1493,7 @@ dependencies = [
     { name = "datasets" },
     { name = "fastapi" },
     { name = "ipykernel" },
+    { name = "kernels" },
     { name = "matplotlib" },
     { name = "psutil" },
     { name = "python-dotenv" },
@@ -1518,6 +1534,7 @@ requires-dist = [
     { name = "datasets", specifier = ">=4.0.0" },
     { name = "fastapi", specifier = ">=0.117.1" },
     { name = "ipykernel", specifier = ">=7.1.0" },
+    { name = "kernels", specifier = ">=0.11.7" },
     { name = "matplotlib", specifier = ">=3.10.8" },
     { name = "psutil", specifier = ">=7.1.0" },
     { name = "python-dotenv", specifier = ">=1.2.1" },