revert engram, not seeing an improvement at larger scale

Andrej Karpathy
2026-01-28 20:07:39 +00:00
parent d5418ea5a1
commit 74554be3b5
2 changed files with 12 additions and 58 deletions

View File

@@ -4,6 +4,12 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 ---
+## 2026-01-28: Reverted Bigram Hash Embeddings
+Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall-clock time, and it bloated VRAM usage. The extra parameters and complexity aren't justified.
+---
 ## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
 Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).

View File

@@ -45,41 +45,6 @@ def norm(x):
     return F.rms_norm(x, (x.size(-1),))
 
-class BigramEmbed(nn.Module):
-    """
-    Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
-    Following modded-nanogpt's approach: single hash, no gating.
-    For each position t, hashes (token[t-1], token[t]) to an index in a large
-    embedding table. This provides O(1) lookup for local 2-gram patterns,
-    offloading static pattern reconstruction from the transformer layers.
-    Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
-    Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
-    """
-    def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
-        super().__init__()
-        self.bigram_vocab_size = vocab_size * table_multiplier
-        self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
-    def forward(self, idx: torch.Tensor) -> torch.Tensor:
-        """
-        idx: (B, T) token ids
-        Returns: (B, T, embed_dim) bigram embeddings
-        """
-        # Hash (prev_token, curr_token) -> index
-        # Position 0 gets a reserved index (no valid bigram)
-        rand_int_1 = 36313
-        rand_int_2 = 27191
-        mod = self.bigram_vocab_size - 1
-        h = torch.empty_like(idx, dtype=torch.long)
-        h[:, 0] = mod # reserved index for position 0
-        h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
-        return self.embed(h)
-
 def has_ve(layer_idx, n_layer):
     """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
     return layer_idx % 2 == (n_layer - 1) % 2
@@ -204,13 +169,9 @@ class GPT(nn.Module):
         # Per-layer learnable scalars (inspired by modded-nanogpt)
         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
-        # bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
         # Separate parameters so they can have different optimizer treatment
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        # Bigram hash embeddings: O(1) lookup for local 2-gram patterns
-        self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
         # Value embeddings (ResFormer-style): alternating layers, last layer always included
         head_dim = config.n_embd // config.n_head
         kv_dim = config.n_kv_head * head_dim
@@ -259,10 +220,6 @@ class GPT(nn.Module):
         # Per-layer scalars
         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
         self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
-        self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
-        # Bigram embeddings: zero init so it starts as identity
-        nn.init.zeros_(self.bigram_embed.embed.weight)
         # Value embeddings (init like c_v: uniform with same std)
         for ve in self.value_embeds.values():
@@ -283,7 +240,6 @@ class GPT(nn.Module):
         self.transformer.wte.to(dtype=torch.bfloat16)
         for ve in self.value_embeds.values():
             ve.to(dtype=torch.bfloat16)
-        self.bigram_embed.to(dtype=torch.bfloat16)
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -349,9 +305,8 @@ class GPT(nn.Module):
         nparams = sum(p.numel() for p in self.parameters())
         # Exclude non-matmul params: embeddings and per-layer scalars
         value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
-        bigram_embed_numel = self.bigram_embed.embed.weight.numel()
-        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
-                           self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
+        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
+                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
         # Sum attention FLOPs per layer, accounting for sliding window
         attn_flops = 0
@@ -376,16 +331,14 @@ class GPT(nn.Module):
""" """
# Count each group separately (mirrors the grouping in setup_optimizers) # Count each group separately (mirrors the grouping in setup_optimizers)
wte = sum(p.numel() for p in self.transformer.wte.parameters()) wte = sum(p.numel() for p in self.transformer.wte.parameters())
bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel() scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return { return {
'wte': wte, 'wte': wte,
'bigram_embed': bigram_embed,
'value_embeds': value_embeds, 'value_embeds': value_embeds,
'lm_head': lm_head, 'lm_head': lm_head,
'transformer_matrices': transformer_matrices, 'transformer_matrices': transformer_matrices,
@@ -403,9 +356,7 @@ class GPT(nn.Module):
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
-        bigram_embed_params = list(self.bigram_embed.parameters())
-        bigram_lambda_params = [self.bigram_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
         # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -414,10 +365,8 @@ class GPT(nn.Module):
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
             dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
-            dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
             dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
             dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
-            dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
         ]
         adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
         AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@@ -446,11 +395,10 @@ class GPT(nn.Module):
         # Forward the trunk of the Transformer
         x = self.transformer.wte(idx) # embed current token
-        x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
         x = norm(x)
         x0 = x # save initial normalized embedding for x0 residual
         for i, block in enumerate(self.transformer.h):
-            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
+            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
             ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
             x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
         x = norm(x)
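
For reference, here is a self-contained sketch of the reverted engram-lite path, assembled from the removed hunks above. The `__main__` demo at the bottom (the toy batch/vocab/width values and the per-layer blend loop outside a real GPT) is illustrative only and not part of the original code.

```python
import torch
import torch.nn as nn


class BigramEmbed(nn.Module):
    """Hash (prev_token, curr_token) pairs into rows of a large embedding table."""

    def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
        super().__init__()
        self.bigram_vocab_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (B, T) token ids -> (B, T, embed_dim) bigram embeddings
        rand_int_1 = 36313
        rand_int_2 = 27191
        mod = self.bigram_vocab_size - 1
        h = torch.empty_like(idx, dtype=torch.long)
        h[:, 0] = mod  # reserved index: position 0 has no preceding token
        h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
        return self.embed(h)


if __name__ == "__main__":
    # Toy stand-in for the removed forward-pass blend:
    # x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram
    B, T, V, D, n_layer = 2, 8, 1000, 32, 2
    idx = torch.randint(0, V, (B, T))
    bigram_embed = BigramEmbed(V, D)
    nn.init.zeros_(bigram_embed.embed.weight)   # zero init so the path starts as a no-op
    bigram_lambdas = nn.Parameter(torch.full((n_layer,), 0.1))
    x = torch.randn(B, T, D)                    # stand-in for normalized token embeddings
    x0_bigram = bigram_embed(idx)
    for i in range(n_layer):
        x = x + bigram_lambdas[i] * x0_bigram   # per-layer bigram skip connection
    print(x.shape)                              # torch.Size([2, 8, 32])
```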