diff --git a/dev/LOG.md b/dev/LOG.md
index 72d1207..2f26165 100644
--- a/dev/LOG.md
+++ b/dev/LOG.md
@@ -4,6 +4,12 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
 ---
 
+## 2026-01-28: Reverted Bigram Hash Embeddings
+
+Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall clock time. It also bloated VRAM usage. The extra parameters and complexity aren't justified.
+
+---
+
 ## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
 
 Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index c55e893..672af71 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -45,41 +45,6 @@ def norm(x):
     return F.rms_norm(x, (x.size(-1),))
 
 
-class BigramEmbed(nn.Module):
-    """
-    Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
-    Following modded-nanogpt's approach: single hash, no gating.
-
-    For each position t, hashes (token[t-1], token[t]) to an index in a large
-    embedding table. This provides O(1) lookup for local 2-gram patterns,
-    offloading static pattern reconstruction from the transformer layers.
-
-    Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
-    Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
-    """
-    def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
-        super().__init__()
-        self.bigram_vocab_size = vocab_size * table_multiplier
-        self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
-
-    def forward(self, idx: torch.Tensor) -> torch.Tensor:
-        """
-        idx: (B, T) token ids
-        Returns: (B, T, embed_dim) bigram embeddings
-        """
-        # Hash (prev_token, curr_token) -> index
-        # Position 0 gets a reserved index (no valid bigram)
-        rand_int_1 = 36313
-        rand_int_2 = 27191
-        mod = self.bigram_vocab_size - 1
-
-        h = torch.empty_like(idx, dtype=torch.long)
-        h[:, 0] = mod # reserved index for position 0
-        h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
-
-        return self.embed(h)
-
-
 def has_ve(layer_idx, n_layer):
     """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
     return layer_idx % 2 == (n_layer - 1) % 2
@@ -204,13 +169,9 @@ class GPT(nn.Module):
         # Per-layer learnable scalars (inspired by modded-nanogpt)
         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
-        # bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
         # Separate parameters so they can have different optimizer treatment
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        # Bigram hash embeddings: O(1) lookup for local 2-gram patterns
-        self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
         # Value embeddings (ResFormer-style): alternating layers, last layer always included
         head_dim = config.n_embd // config.n_head
         kv_dim = config.n_kv_head * head_dim
@@ -259,10 +220,6 @@
         # Per-layer scalars
         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
         self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
-        self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
-
-        # Bigram embeddings: zero init so it starts as identity
-        nn.init.zeros_(self.bigram_embed.embed.weight)
 
         # Value embeddings (init like c_v: uniform with same std)
         for ve in self.value_embeds.values():
@@ -283,7 +240,6 @@
         self.transformer.wte.to(dtype=torch.bfloat16)
         for ve in self.value_embeds.values():
             ve.to(dtype=torch.bfloat16)
-        self.bigram_embed.to(dtype=torch.bfloat16)
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -349,9 +305,8 @@
         nparams = sum(p.numel() for p in self.parameters())
         # Exclude non-matmul params: embeddings and per-layer scalars
         value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
-        bigram_embed_numel = self.bigram_embed.embed.weight.numel()
-        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
-                           self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
+        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
+                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
         # Sum attention FLOPs per layer, accounting for sliding window
         attn_flops = 0
@@ -376,16 +331,14 @@
         """
         # Count each group separately (mirrors the grouping in setup_optimizers)
         wte = sum(p.numel() for p in self.transformer.wte.parameters())
-        bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
         value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
         lm_head = sum(p.numel() for p in self.lm_head.parameters())
         transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
-        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel()
-        total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
+        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
+        total = wte + value_embeds + lm_head + transformer_matrices + scalars
         assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
         return {
             'wte': wte,
-            'bigram_embed': bigram_embed,
             'value_embeds': value_embeds,
             'lm_head': lm_head,
             'transformer_matrices': transformer_matrices,
@@ -403,9 +356,7 @@
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
-        bigram_embed_params = list(self.bigram_embed.parameters())
-        bigram_lambda_params = [self.bigram_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
         # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -414,10 +365,8 @@
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
             dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
-            dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
             dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
             dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
-            dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
         ]
         adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
         AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@@ -446,11 +395,10 @@
 
         # Forward the trunk of the Transformer
        x = self.transformer.wte(idx) # embed current token
-        x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
         x = norm(x)
         x0 = x # save initial normalized embedding for x0 residual
         for i, block in enumerate(self.transformer.h):
-            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
+            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
             ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
             x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
             x = norm(x)
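For reference, a minimal standalone sketch of what the reverted module computed. The hash constants, table sizing (`vocab_size * 5`), and zero init mirror the deleted `BigramEmbed`; the `vocab_size`, `embed_dim`, and batch shape below are illustrative placeholders rather than nanochat's actual config.

```python
import torch
import torch.nn as nn

# Sketch of the reverted bigram hash embedding (mirrors the deleted BigramEmbed).
# vocab_size / embed_dim / batch shape are illustrative, not nanochat's real config.
vocab_size, embed_dim = 50304, 768
bigram_vocab_size = vocab_size * 5          # table_multiplier = 5
embed = nn.Embedding(bigram_vocab_size, embed_dim)
nn.init.zeros_(embed.weight)                # zero init so the blend starts as a no-op

idx = torch.randint(0, vocab_size, (2, 8))  # (B, T) token ids
mod = bigram_vocab_size - 1
h = torch.empty_like(idx)                   # token ids are already int64
h[:, 0] = mod                               # reserved index: position 0 has no preceding token
h[:, 1:] = (36313 * idx[:, 1:] ^ 27191 * idx[:, :-1]) % mod
x0_bigram = embed(h)                        # (B, T, embed_dim), was blended in per layer via bigram_lambdas
print(x0_bigram.shape)                      # torch.Size([2, 8, 768])
```

Since the table was zero-initialized and `bigram_lambdas` started at 0.1, the bigram path contributed nothing at step 0 and only mattered insofar as training pushed those weights away from zero, which per the log entry above didn't pay off at d25 scale.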