Merge branch 'engram'

This commit is contained in:
Andrej Karpathy
2026-01-27 22:33:16 +00:00
5 changed files with 404 additions and 43 deletions

View File

@@ -4,6 +4,140 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2506.08046) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
### Background
The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
### What We Tried
**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash the bigram to retrieve a memory embedding, then gate it with the hidden state
e = embed(hash(prev_token, curr_token))  # retrieved N-gram memory
q = RMSNorm(h)                           # hidden state as query
k = RMSNorm(W_k @ e)                     # projected embedding as key
v = W_v @ e
alpha = sigmoid(dot(q, k) / d)           # scalar gate per position
output = alpha * v                       # gated memory, added back to the residual stream
```
- Injected after block 1 (the paper found early injection optimal)
- **Result:** Slight improvement, but quite a bit of added complexity.
**2. Early-layer only injection**
- Only inject the bigram signal in the first 4 layers (where the paper claims static-pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.
**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone; the trigrams dilute table capacity away from the more frequent 2-gram patterns.
**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.
TLDR: The winning approach follows modded-nanogpt's "engram-lite": simply add the following module and feed its output into the residual stream (gated by a per-layer learnable λ) before every single block:
```python
class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)
    def forward(self, idx):
        # hash (prev_token, curr_token) -> index; position 0 gets a reserved slot
        h = torch.empty_like(idx)
        h[:, 0] = self.table_size - 1
        h[:, 1:] = (36313 * idx[:, 1:] ^ 27191 * idx[:, :-1]) % (self.table_size - 1)
        return self.embed(h)
```
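In the forward pass, the bigram embedding is computed once from the token ids and blended into the residual stream before every block, alongside the existing x0 skip connection (abridged from the `GPT.forward` changes in this commit):
```python
x = self.transformer.wte(idx)        # embed current token
x0_bigram = self.bigram_embed(idx)   # embed current bigram (via hash lookup)
x = norm(x)
x0 = x                               # save the initial normalized embedding for the x0 residual
for i, block in enumerate(self.transformer.h):
    # per-layer learned scalars gate the residual stream, the x0 skip, and the bigram skip
    x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
    ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
    x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x)
```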
As for optimal hyperparameters:
- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried applying `norm()` to the bigram embeddings (mirroring the token embeddings); this was slightly worse.
- **Init:** Zero-init the embedding table so the module starts as a no-op (a small noisy init was worse).
- **Optimizer:** AdamW with the same LR as the token embeddings (see the sketch below).
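On the optimizer side, the new parameters slot into the existing AdamW groups: the bigram table gets the token-embedding LR and the bigram lambdas get the same treatment as the x0 lambdas (abridged from `setup_optimizers()` in this commit):
```python
adam_groups = [
    dict(params=lm_head_params,       lr=unembedding_lr * dmodel_lr_scale),
    dict(params=embedding_params,     lr=embedding_lr * dmodel_lr_scale),
    dict(params=value_embeds_params,  lr=embedding_lr * dmodel_lr_scale),   # same LR as token embedding
    dict(params=bigram_embed_params,  lr=embedding_lr * dmodel_lr_scale),   # same LR as token embedding
    dict(params=resid_params,         lr=scalar_lr * 0.01),                 # residual scalars are more sensitive
    dict(params=x0_params,            lr=scalar_lr, betas=(0.96, 0.95)),    # higher beta1 for x0 scalars
    dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)),    # same treatment as x0 lambdas
]
```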
### Key Learnings
1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."
2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.
3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.
4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.
### Parameters Added
For a d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = ~126M params
- Per-layer lambdas: 12 scalars (negligible)
If you're keeping track, we now have *a lot* of parameters, a significant fraction of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, a d12 model now has:
```
Parameter counts:
wte : 25,165,824
bigram_embed : 125,829,120
value_embeds : 150,994,944
lm_head : 25,165,824
transformer_matrices : 84,935,808
scalars : 36
total : 412,091,556
```
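These counts follow directly from the d12 config; a minimal sketch of the arithmetic (assuming vocab_size=32768, n_embd=768, value embeddings on 6 of the 12 layers, and a value-embedding dim equal to n_embd):
```python
vocab_size, n_embd, table_multiplier = 32768, 768, 5

wte          = vocab_size * n_embd                      # 25,165,824
bigram_embed = vocab_size * table_multiplier * n_embd   # 125,829,120
value_embeds = 6 * vocab_size * n_embd                  # 150,994,944 (6 of 12 layers get a VE table)
lm_head      = vocab_size * n_embd                      # 25,165,824
embeddings_and_head = wte + bigram_embed + value_embeds + lm_head
print(f"{embeddings_and_head:,}")                       # 327,155,712; the remaining ~85M are the transformer matmuls
```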
In other words, only about a quarter of the parameters are now weight projections; the vast majority sit in embedding tables.
Still, on all axes (steps, wall clock time, FLOPs), this somewhat parameter-bloated architecture beats the baseline and will now become the default.
After adding engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs over the range 1e18..1e19, logarithmically spaced across 4 settings (1e18, 2e18, 5e18, 1e19), and looked at a number of ways of defining the effective parameter count for the purposes of the scaling laws. The results looked like this:
```
Kaplan-style (all projections including lm_head and no embeddings)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 110,678,115 1,241,505,403 11.2 0.8972
2e+18 167,797,457 1,785,336,422 10.7 0.8616
5e+18 250,650,865 2,642,234,152 10.8 0.8293
1e+19 381,758,347 3,806,871,243 10.3 0.7999
N \propto C^0.54, D \propto C^0.49
Chinchilla-style (all parameters, period.)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 416,320,605 1,232,157,011 3.0 0.8974
2e+18 560,239,841 1,763,669,281 3.2 0.8616
5e+18 741,495,903 2,629,909,368 3.6 0.8291
1e+19 988,644,331 3,884,841,895 4.0 0.7999
N \propto C^0.37, D \propto C^0.50
Transformer-only-style (only the projections inside the transformer)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 80,259,665 1,315,639,547 17.2 0.8966
2e+18 131,488,566 1,864,134,141 14.5 0.8622
5e+18 220,985,474 2,595,328,843 12.1 0.8302
1e+19 401,213,504 3,328,704,512 8.5 0.7994
N \propto C^0.70, D \propto C^0.41
```
Clearly, the Kaplan-style ratios are the most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can use a single fixed tokens:params ratio for compute-optimal models. This turns out to be ~10.5, which now becomes the new default.
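For concreteness, here is how the new default translates into a training horizon, mirroring the `--target-param-data-ratio` path in `base_train.py` (the d12 parameter counts are taken from the table above):
```python
# "Scaling" params = transformer matrices + lm_head (the Kaplan-style count above)
num_scaling_params = 84_935_808 + 25_165_824            # d12: ~110M
target_param_data_ratio = 10.5                          # new default tokens:params ratio
total_batch_size = 524288                               # tokens per optimization step

target_tokens  = int(target_param_data_ratio * num_scaling_params)  # ~1.16B tokens
num_iterations = target_tokens // total_batch_size                  # ~2,205 steps
```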
---
## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep
Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.

View File

@@ -15,14 +15,16 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"%matplotlib inline\n",
"import os\n", "import os\n",
"import pandas as pd\n", "import pandas as pd\n",
"import numpy as np\n", "import numpy as np\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
"# Load results\n", "# Load results\n",
"tag = \"jan26\"\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n", "base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n", "results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
"\n", "\n",
"df = pd.read_csv(results_path)\n", "df = pd.read_csv(results_path)\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n", "flops_budgets = sorted(df['flops_budget'].unique())\n",
@@ -31,6 +33,99 @@
"df" "df"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# FILTERING: Remove incomplete or problematic runs\n",
"# =============================================================================\n",
"\n",
"print(f\"Before filtering: {len(df)} runs\")\n",
"\n",
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
"\n",
"# Optional: exclude specific flops budgets that aren't done yet\n",
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
"\n",
"# Optional: exclude specific depths\n",
"# exclude_depths = [18, 20]\n",
"# df = df[~df['depth'].isin(exclude_depths)]\n",
"\n",
"print(f\"After filtering: {len(df)} runs\")\n",
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
"\n",
"# Update flops_budgets list after filtering\n",
"flops_budgets = sorted(df['flops_budget'].unique())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Effective Parameter Count\n",
"\n",
"Different scaling law papers use different conventions for counting parameters:\n",
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
"\n",
"Our CSV now has granular counts:\n",
"- `params_wte` - token embedding (lookup table)\n",
"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
"- `params_value_embeds` - value embeddings (lookup table)\n",
"- `params_lm_head` - unembedding projection (matmul)\n",
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
"\n",
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
"# =============================================================================\n",
"\n",
"def compute_effective_params(row):\n",
" \"\"\"\n",
" Compute the 'effective' parameter count for scaling law analysis.\n",
"\n",
" Modify this function to experiment with different conventions:\n",
" - Chinchilla-style: include everything\n",
" - Kaplan-style: exclude embeddings\n",
" - Matmul-only: just transformer + lm_head (the actual compute)\n",
" - etc.\n",
" \"\"\"\n",
" # Option 1: Chinchilla-style (all params)\n",
" # return row['params_total']\n",
"\n",
" # Option 2: Kaplan-style (exclude embeddings)\n",
" return row['params_transformer'] + row['params_lm_head']\n",
"\n",
" # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
" # return row['params_transformer']\n",
"\n",
"\n",
"# Compute derived columns\n",
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
"\n",
"# Show parameter breakdown for first few rows\n",
"print(\"Parameter breakdown (first row per flops budget):\")\n",
"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
" 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
"df.groupby('flops_budget').first()[param_cols]"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@@ -54,11 +149,11 @@
"optimal_by_bpb = []\n", "optimal_by_bpb = []\n",
"\n", "\n",
"for flops, color in zip(flops_budgets, colors):\n", "for flops, color in zip(flops_budgets, colors):\n",
" subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n", " subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
" ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n", " ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"\n", "\n",
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n", " # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
" log_params = np.log10(subset['num_scaling_params'])\n", " log_params = np.log10(subset['effective_params'])\n",
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n", " coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
" a, b, c = coeffs\n", " a, b, c = coeffs\n",
"\n", "\n",
@@ -83,13 +178,13 @@
" # Fallback to raw minimum if quadratic doesn't have minimum\n", " # Fallback to raw minimum if quadratic doesn't have minimum\n",
" best_idx = subset['val_bpb'].idxmin()\n", " best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n", " best = subset.loc[best_idx]\n",
" ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n", " ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2)\n", " zorder=5, edgecolors='black', linewidths=2)\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n", " optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n", " 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n", "\n",
"ax.set_xscale('log')\n", "ax.set_xscale('log')\n",
"ax.set_xlabel('Parameters')\n", "ax.set_xlabel('Effective Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n", "ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n", "ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n", "ax.legend(title='FLOPs', loc='upper right')\n",
@@ -138,10 +233,61 @@
"\n", "\n",
"# Print the optimal points (from quadratic fits)\n", "# Print the optimal points (from quadratic fits)\n",
"print(\"\\nOptimal configurations (from quadratic fits):\")\n", "print(\"\\nOptimal configurations (from quadratic fits):\")\n",
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n", "print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(\"-\" * 65)\n", "print(\"-\" * 65)\n",
"for _, row in opt_df.iterrows():\n", "for _, row in opt_df.iterrows():\n",
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n" " print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Optimal Ratio Summary (from power law fits)\n",
"# =============================================================================\n",
"\n",
"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
"\n",
"if len(opt_df) >= 2:\n",
" log_f = np.log10(opt_df['flops'])\n",
" log_p = np.log10(opt_df['params'])\n",
" log_t = np.log10(opt_df['tokens'])\n",
"\n",
" # Fit power laws\n",
" slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
" slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
"\n",
" # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
" ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
" log_ref = np.log10(ref_flops)\n",
"\n",
" # Predicted optimal N and D at reference compute\n",
" pred_log_n = intercept_n + slope_n * log_ref\n",
" pred_log_d = intercept_d + slope_d * log_ref\n",
" optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
"\n",
" # Also compute from the fitted optimals directly (mean and std)\n",
" mean_ratio = opt_df['ratio'].mean()\n",
" std_ratio = opt_df['ratio'].std()\n",
"\n",
" print(\"=\" * 60)\n",
" print(\"OPTIMAL RATIO SUMMARY\")\n",
" print(\"=\" * 60)\n",
" print(f\"\\nPower law exponents:\")\n",
" print(f\" N ∝ C^{slope_n:.3f}\")\n",
" print(f\" D ∝ C^{slope_d:.3f}\")\n",
" print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
" print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
" print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
" print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
" print(f\" Chinchilla reference: 20\")\n",
" print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
"else:\n",
" print(\"Need at least 2 flops budgets to compute power law fits\")"
] ]
}, },
{ {

View File

@@ -45,6 +45,41 @@ def norm(x):
return F.rms_norm(x, (x.size(-1),)) return F.rms_norm(x, (x.size(-1),))
class BigramEmbed(nn.Module):
"""
Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
Following modded-nanogpt's approach: single hash, no gating.
For each position t, hashes (token[t-1], token[t]) to an index in a large
embedding table. This provides O(1) lookup for local 2-gram patterns,
offloading static pattern reconstruction from the transformer layers.
Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
"""
def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
super().__init__()
self.bigram_vocab_size = vocab_size * table_multiplier
self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
def forward(self, idx: torch.Tensor) -> torch.Tensor:
"""
idx: (B, T) token ids
Returns: (B, T, embed_dim) bigram embeddings
"""
# Hash (prev_token, curr_token) -> index
# Position 0 gets a reserved index (no valid bigram)
rand_int_1 = 36313
rand_int_2 = 27191
mod = self.bigram_vocab_size - 1
h = torch.empty_like(idx, dtype=torch.long)
h[:, 0] = mod # reserved index for position 0
h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
return self.embed(h)
def has_ve(layer_idx, n_layer): def has_ve(layer_idx, n_layer):
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included).""" """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
return layer_idx % 2 == (n_layer - 1) % 2 return layer_idx % 2 == (n_layer - 1) % 2
@@ -169,9 +204,13 @@ class GPT(nn.Module):
# Per-layer learnable scalars (inspired by modded-nanogpt) # Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
# Separate parameters so they can have different optimizer treatment # Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Bigram hash embeddings: O(1) lookup for local 2-gram patterns
self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
# Value embeddings (ResFormer-style): alternating layers, last layer always included # Value embeddings (ResFormer-style): alternating layers, last layer always included
head_dim = config.n_embd // config.n_head head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim kv_dim = config.n_kv_head * head_dim
@@ -219,7 +258,11 @@ class GPT(nn.Module):
# Per-layer scalars # Per-layer scalars
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
# Bigram embeddings: zero init so it starts as identity
nn.init.zeros_(self.bigram_embed.embed.weight)
# Value embeddings (init like c_v: uniform with same std) # Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values(): for ve in self.value_embeds.values():
@@ -240,6 +283,7 @@ class GPT(nn.Module):
self.transformer.wte.to(dtype=torch.bfloat16) self.transformer.wte.to(dtype=torch.bfloat16)
for ve in self.value_embeds.values(): for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16) ve.to(dtype=torch.bfloat16)
self.bigram_embed.to(dtype=torch.bfloat16)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -305,8 +349,9 @@ class GPT(nn.Module):
nparams = sum(p.numel() for p in self.parameters()) nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars # Exclude non-matmul params: embeddings and per-layer scalars
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel = self.bigram_embed.embed.weight.numel()
self.resid_lambdas.numel() + self.x0_lambdas.numel()) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window # Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0 attn_flops = 0
@@ -319,15 +364,34 @@ class GPT(nn.Module):
def num_scaling_params(self): def num_scaling_params(self):
""" """
Return all of the parameters, same as Chinchilla paper. Return detailed parameter counts for scaling law analysis.
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws. Different papers use different conventions:
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla). - Kaplan et al. excluded embedding parameters
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law. - Chinchilla included all parameters
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good). Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad) Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
Returns a dict with counts for each parameter group, so downstream analysis
can experiment with which combination gives the cleanest scaling laws.
""" """
nparams = sum(p.numel() for p in self.parameters()) # Count each group separately (mirrors the grouping in setup_optimizers)
return nparams wte = sum(p.numel() for p in self.transformer.wte.parameters())
bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel()
total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'bigram_embed': bigram_embed,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'scalars': scalars,
'total': total,
}
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
model_dim = self.config.n_embd model_dim = self.config.n_embd
@@ -339,7 +403,9 @@ class GPT(nn.Module):
lm_head_params = list(self.lm_head.parameters()) lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas] resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas] x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) bigram_embed_params = list(self.bigram_embed.parameters())
bigram_lambda_params = [self.bigram_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5 dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -348,8 +414,10 @@ class GPT(nn.Module):
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
] ]
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@@ -377,11 +445,12 @@ class GPT(nn.Module):
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer # Forward the trunk of the Transformer
x = self.transformer.wte(idx) x = self.transformer.wte(idx) # embed current token
x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
x = norm(x) x = norm(x)
x0 = x # save initial normalized embedding for x0 residual x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h): for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x) x = norm(x)

View File

@@ -1,13 +1,14 @@
#!/bin/bash #!/bin/bash
LABEL="jan16" LABEL="jan26"
FLOPS_BUDGETS=( FLOPS_BUDGETS=(
1e18 1e18
3e18 2.15e18
6e18 4.64e18
1e19
) )
DEPTHS=(6 7 8 9 10 11 12 13 14) DEPTHS=(8 10 12 14 16 18 20)
NPROC_PER_NODE="${NPROC_PER_NODE:-8}" NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
@@ -23,7 +24,7 @@ RESULTS_FILE="$RESULTS_DIR/results.csv"
# Write CSV header only if file doesn't exist # Write CSV header only if file doesn't exist
if [ ! -f "$RESULTS_FILE" ]; then if [ ! -f "$RESULTS_FILE" ]; then
echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE" echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
fi fi
log() { log() {
@@ -83,13 +84,19 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
# Extract training stats from the log # Extract training stats from the log
LOG_FILE="$RESULTS_DIR/${TAG}_train.log" LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',') # Extract detailed parameter counts (for scaling law analysis with different conventions)
PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',') NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
# Calculate tokens trained (iterations * batch_size, default 524288) # Calculate tokens trained (iterations * batch_size, default 524288)
TOKENS_TRAINED=$((NUM_ITERS * 524288)) TOKENS_TRAINED=$((NUM_ITERS * 524288))
# Param:data ratio (using scaling params per Kaplan et al.)
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
# Model dim # Model dim
MODEL_DIM=$((d * 64)) MODEL_DIM=$((d * 64))
# Val BPB from final eval # Val BPB from final eval
@@ -102,10 +109,10 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
CORE_SCORE="0.0" CORE_SCORE="0.0"
fi fi
log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE" log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
# Append to CSV # Append to CSV
echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE" echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
done done
done done

View File

@@ -1,11 +1,11 @@
""" """
Train model. From root directory of the project, run as: Train model. From root directory of the project, run as:
python -m scripts.base_train.py python -m scripts.base_train
or distributed as: or distributed as:
torchrun --nproc_per_node=8 -m scripts.base_train.py torchrun --nproc_per_node=8 -m scripts.base_train
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
@@ -47,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding
# Training horizon (only one used, in order of precedence) # Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization # Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
@@ -178,9 +178,14 @@ if resuming:
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
num_scaling_params = orig_model.num_scaling_params() # Detailed parameter counts
print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})") param_counts = orig_model.num_scaling_params()
print0(f"Parameter counts:")
for key, value in param_counts.items():
print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
num_flops_per_token = model.estimate_flops() num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
@@ -195,14 +200,14 @@ elif args.target_flops > 0:
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0: elif args.target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.) # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
target_tokens = args.target_param_data_ratio * num_scaling_params target_tokens = int(args.target_param_data_ratio * num_scaling_params)
num_iterations = target_tokens // args.total_batch_size num_iterations = target_tokens // args.total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else: else:
raise ValueError("No training horizon specified") raise ValueError("No training horizon specified")
total_tokens = args.total_batch_size * num_iterations total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}") print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -445,7 +450,7 @@ get_report().log(section="Base model training", data=[
"Number of FLOPs per token": f"{num_flops_per_token:e}", "Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations, "Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens, "Number of training tokens": total_tokens,
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params, "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
"DDP world size": ddp_world_size, "DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio, "warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio, "warmdown_ratio": args.warmdown_ratio,