revert engram, not seeing an improvement at larger scale
@@ -4,6 +4,12 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
+---
+
+## 2026-01-28: Reverted Bigram Hash Embeddings
+
+Removed bigram embeddings (engram-lite) from the codebase. At the larger d25 scale the improvement was tiny and disappeared entirely when measured against wall-clock time. It also bloated VRAM usage. The extra parameters and complexity aren't justified.
+
 ---
 
 ## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
 
 Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
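To put a rough number on the VRAM point above: the `BigramEmbed` table in the hunk below allocates `vocab_size * table_multiplier` rows of width `n_embd`. A back-of-the-envelope sketch, using placeholder values for the vocab size and model width (not the actual d25 config):

```python
# Rough size of the reverted bigram table. The vocab size and model width are
# placeholder assumptions for illustration, not the real d25 configuration.
vocab_size = 65536          # assumed tokenizer vocab
n_embd = 1280               # assumed model width
table_multiplier = 5        # default in BigramEmbed below

bigram_rows = vocab_size * table_multiplier      # rows in the hashed embedding table
bigram_params = bigram_rows * n_embd             # extra embedding parameters
bf16_bytes = 2 * bigram_params                   # the table is cast to bfloat16

print(f"extra params: {bigram_params / 1e6:.0f}M")   # ~419M with these numbers
print(f"bf16 weights: {bf16_bytes / 1e9:.2f} GB")    # ~0.84 GB, before AdamW state
```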
@@ -45,41 +45,6 @@ def norm(x):
     return F.rms_norm(x, (x.size(-1),))
 
 
-class BigramEmbed(nn.Module):
-    """
-    Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
-    Following modded-nanogpt's approach: single hash, no gating.
-
-    For each position t, hashes (token[t-1], token[t]) to an index in a large
-    embedding table. This provides O(1) lookup for local 2-gram patterns,
-    offloading static pattern reconstruction from the transformer layers.
-
-    Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
-    Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
-    """
-    def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
-        super().__init__()
-        self.bigram_vocab_size = vocab_size * table_multiplier
-        self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
-
-    def forward(self, idx: torch.Tensor) -> torch.Tensor:
-        """
-        idx: (B, T) token ids
-        Returns: (B, T, embed_dim) bigram embeddings
-        """
-        # Hash (prev_token, curr_token) -> index
-        # Position 0 gets a reserved index (no valid bigram)
-        rand_int_1 = 36313
-        rand_int_2 = 27191
-        mod = self.bigram_vocab_size - 1
-
-        h = torch.empty_like(idx, dtype=torch.long)
-        h[:, 0] = mod  # reserved index for position 0
-        h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
-
-        return self.embed(h)
-
-
 def has_ve(layer_idx, n_layer):
     """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
     return layer_idx % 2 == (n_layer - 1) % 2
@@ -204,13 +169,9 @@ class GPT(nn.Module):
         # Per-layer learnable scalars (inspired by modded-nanogpt)
         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
-        # bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
         # Separate parameters so they can have different optimizer treatment
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        # Bigram hash embeddings: O(1) lookup for local 2-gram patterns
-        self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
         # Value embeddings (ResFormer-style): alternating layers, last layer always included
         head_dim = config.n_embd // config.n_head
         kv_dim = config.n_kv_head * head_dim
@@ -259,10 +220,6 @@ class GPT(nn.Module):
         # Per-layer scalars
         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
         self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
-        self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
-
-        # Bigram embeddings: zero init so it starts as identity
-        nn.init.zeros_(self.bigram_embed.embed.weight)
 
         # Value embeddings (init like c_v: uniform with same std)
         for ve in self.value_embeds.values():
@@ -283,7 +240,6 @@ class GPT(nn.Module):
         self.transformer.wte.to(dtype=torch.bfloat16)
         for ve in self.value_embeds.values():
             ve.to(dtype=torch.bfloat16)
-        self.bigram_embed.to(dtype=torch.bfloat16)
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -349,9 +305,8 @@ class GPT(nn.Module):
         nparams = sum(p.numel() for p in self.parameters())
         # Exclude non-matmul params: embeddings and per-layer scalars
         value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
-        bigram_embed_numel = self.bigram_embed.embed.weight.numel()
-        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
-                           self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
+        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
+                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
         # Sum attention FLOPs per layer, accounting for sliding window
         attn_flops = 0
@@ -376,16 +331,14 @@ class GPT(nn.Module):
         """
         # Count each group separately (mirrors the grouping in setup_optimizers)
         wte = sum(p.numel() for p in self.transformer.wte.parameters())
-        bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
         value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
         lm_head = sum(p.numel() for p in self.lm_head.parameters())
         transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
-        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel()
-        total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
+        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
+        total = wte + value_embeds + lm_head + transformer_matrices + scalars
         assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
         return {
             'wte': wte,
-            'bigram_embed': bigram_embed,
             'value_embeds': value_embeds,
             'lm_head': lm_head,
             'transformer_matrices': transformer_matrices,
@@ -403,9 +356,7 @@ class GPT(nn.Module):
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
-        bigram_embed_params = list(self.bigram_embed.parameters())
-        bigram_lambda_params = [self.bigram_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
         # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -414,10 +365,8 @@ class GPT(nn.Module):
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
             dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
-            dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
             dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
             dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
-            dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
         ]
         adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
         AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@@ -446,11 +395,10 @@ class GPT(nn.Module):
 
         # Forward the trunk of the Transformer
         x = self.transformer.wte(idx) # embed current token
-        x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
        x = norm(x)
         x0 = x # save initial normalized embedding for x0 residual
         for i, block in enumerate(self.transformer.h):
-            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
+            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
             ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
             x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
         x = norm(x)
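
For context on what these hunks undo, here is a toy-scale, self-contained sketch of the reverted path (condensed from the code above, not the nanochat module itself): hash each (prev, cur) token pair to a row of a large embedding table once, then blend that lookup into the residual stream at every layer with a small learned scalar.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy-scale illustration of the reverted engram-lite path (not the nanochat code itself).
B, T, V, D, L = 2, 8, 1000, 64, 4              # toy batch / seq / vocab / width / depth
bigram_table = nn.Embedding(V * 5, D)          # table_multiplier = 5, as in BigramEmbed

def bigram_lookup(idx: torch.Tensor) -> torch.Tensor:
    # Hash (prev_token, curr_token) -> table index; position 0 gets a reserved slot.
    mod = bigram_table.num_embeddings - 1
    h = torch.empty_like(idx)
    h[:, 0] = mod
    h[:, 1:] = (36313 * idx[:, 1:] ^ 27191 * idx[:, :-1]) % mod
    return bigram_table(h)

idx = torch.randint(0, V, (B, T))              # token ids
wte = nn.Embedding(V, D)
x = x0 = F.rms_norm(wte(idx), (D,))            # normalized token embeddings
x0_bigram = bigram_lookup(idx)                 # computed once, reused at every layer
resid_l = torch.ones(L)                        # 1.0 at init, as in init_weights
x0_l = torch.full((L,), 0.1)                   # 0.1 at init
bigram_l = torch.full((L,), 0.1)               # 0.1 at init (the reverted lambdas)
for i in range(L):
    # the removed forward-pass line, with the transformer block itself omitted:
    x = resid_l[i] * x + x0_l[i] * x0 + bigram_l[i] * x0_bigram
x = F.rms_norm(x, (D,))
```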