Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30)

Compare commits: 8630d32be4...c88bbf8133 (3 commits: c88bbf8133, c8d93beed2, 59e36cc727)
dev/LOG.md (+134 lines)

@@ -4,6 +4,140 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026

---
## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2506.08046) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
### Background
The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
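To make the lookup concrete, here is a toy sketch (ours, not from the paper) of the O(1) bigram hash that all of the experiments below build on:

```python
# Toy sketch: a bigram is resolved by a single hash lookup into a learned
# table, instead of being reconstructed by attention over several layers.
def bigram_hash(prev_id: int, curr_id: int, table_size: int) -> int:
    return ((36313 * curr_id) ^ (27191 * prev_id)) % table_size

slot = bigram_hash(prev_id=1234, curr_id=5678, table_size=32768 * 5)
# `slot` indexes an embedding row that can memorize this exact token pair.
```
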
### What We Tried

**1. Full Engram module with context-aware gating (paper design)**

```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h)            # hidden state as query
k = RMSNorm(W_k @ e)      # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d)   # scalar gate per position
output = α * v
```

- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.
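
For reference, a runnable PyTorch sketch of this gated module (our naming and shapes, reusing the same multiplicative hash; not the paper's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EngramGate(nn.Module):
    """Sketch of the paper-style gated bigram memory (names are ours)."""
    def __init__(self, vocab_size: int, dim: int, table_multiplier: int = 5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, dim)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.w_v = nn.Linear(dim, dim, bias=False)

    def forward(self, idx: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # idx: (B, T) token ids; h: (B, T, dim) hidden state after an early block
        hidx = torch.full_like(idx, self.table_size - 1)  # reserved slot for position 0
        hidx[:, 1:] = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        e = self.embed(hidx)                          # retrieved bigram memory
        q = F.rms_norm(h, (h.size(-1),))              # hidden state as query
        k = F.rms_norm(self.w_k(e), (e.size(-1),))    # projected embedding as key
        v = self.w_v(e)
        gate = torch.sigmoid((q * k).sum(-1, keepdim=True) / h.size(-1) ** 0.5)
        return gate * v  # caller adds this to the residual stream
```
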
**2. Early-layer only injection**
- Only inject the bigram signal in the first 4 layers (where the paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.

**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating the embeddings
- **Result:** No improvement over bigrams alone; it just dilutes table capacity away from the more frequent 2-gram patterns.

**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.

TLDR: The winning approach follows modded-nanogpt's "engram-lite": simply add the following module and feed its output into the residual branch (gated by a per-layer learnable λ) before every single block:

```python
class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)

    def forward(self, idx):
        # hash (prev, curr) token pairs; position 0 gets the reserved last index
        h = torch.full_like(idx, self.table_size - 1)
        h[:, 1:] = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        return self.embed(h)
```

As for optimal hyperparameters:

- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0); see the wiring sketch below.
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings); this was slightly worse.
- **Init:** Zero-init embedding, so it starts as an identity (a small noisy init was worse).
- **Optimizer:** AdamW with the same LR as the token embeddings.
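
For concreteness, the per-layer injection looks roughly like this inside the model's forward pass (a sketch mirroring the gpt.py hunk further down; the arguments stand in for the real modules and parameters):

```python
def forward_trunk(idx, wte, bigram_embed, blocks, norm,
                  resid_lambdas, x0_lambdas, bigram_lambdas):
    x = norm(wte(idx))             # token embeddings
    x0 = x                         # saved for the x0 skip connection
    x0_bigram = bigram_embed(idx)  # hashed bigram embeddings, same shape as x
    for i, block in enumerate(blocks):
        # three learned per-layer scalars blend the residual, input, and bigram streams
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram
        x = block(x)
    return x
```
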
### Key Learnings

1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."

2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.

3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.

4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.

### Parameters Added

For a d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = 125,829,120 (~126M) params
- Per-layer lambdas: 12 scalars (negligible)

If you're keeping track, we now have *a lot* of parameters, a significant fraction of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:

```
Parameter counts:
wte                     : 25,165,824
bigram_embed            : 125,829,120
value_embeds            : 150,994,944
lm_head                 : 25,165,824
transformer_matrices    : 84,935,808
scalars                 : 36
total                   : 412,091,556
```
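
As a sanity check on the embedding rows above, a quick script of ours (assumes vocab_size=32768, n_embd=768, and that value embeddings have width n_embd for d12):

```python
vocab_size, n_embd, n_layer = 32768, 768, 12
wte = vocab_size * n_embd                  # 25,165,824
bigram = 5 * vocab_size * n_embd           # 125,829,120
# value embeddings live on alternating layers, last layer always included:
n_ve = sum(1 for i in range(n_layer) if i % 2 == (n_layer - 1) % 2)  # 6 layers
value_embeds = n_ve * vocab_size * n_embd  # 150,994,944
print(f"{wte:,} {bigram:,} {value_embeds:,}")  # matches the table above
```
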
In other words, only about a quarter of the parameters are now weight projections (transformer matrices + lm_head ≈ 110M of 412M) and the vast majority sit in embedding tables.

Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.

After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2.15e18, 4.64e18, 1e19; the tables below round these to one significant digit). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:

```
Kaplan-style (all projections including lm_head and no embeddings)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params      Tokens          Ratio      Val BPB
-----------------------------------------------------------------
1e+18        110,678,115     1,241,505,403   11.2       0.8972
2e+18        167,797,457     1,785,336,422   10.7       0.8616
5e+18        250,650,865     2,642,234,152   10.8       0.8293
1e+19        381,758,347     3,806,871,243   10.3       0.7999

N ∝ C^0.54, D ∝ C^0.49

Chinchilla-style (all parameters, period.)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params      Tokens          Ratio      Val BPB
-----------------------------------------------------------------
1e+18        416,320,605     1,232,157,011   3.0        0.8974
2e+18        560,239,841     1,763,669,281   3.2        0.8616
5e+18        741,495,903     2,629,909,368   3.6        0.8291
1e+19        988,644,331     3,884,841,895   4.0        0.7999

N ∝ C^0.37, D ∝ C^0.50

Transformer-only-style (only the projections inside the transformer)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params      Tokens          Ratio      Val BPB
-----------------------------------------------------------------
1e+18        80,259,665      1,315,639,547   17.2       0.8966
2e+18        131,488,566     1,864,134,141   14.5       0.8622
5e+18        220,985,474     2,595,328,843   12.1       0.8302
1e+19        401,213,504     3,328,704,512   8.5        0.7994

N ∝ C^0.70, D ∝ C^0.41
```

Clearly, the Kaplan-style ratios are the most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can use a single fixed tokens:params ratio for compute-optimal models. This turns out to be ~10.5, which now becomes the new default.
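
Concretely, here is how the 10.5 default turns into a training horizon (a small worked example of ours, mirroring the base_train.py logic further below, using the 1e18 row above):

```python
num_scaling_params = 110_678_115   # transformer matrices + lm_head at the 1e18 point
target_param_data_ratio = 10.5
total_batch_size = 524288          # tokens per optimization step
target_tokens = int(target_param_data_ratio * num_scaling_params)  # 1,162,120,207
num_iterations = target_tokens // total_batch_size                 # 2,216 steps
```
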
---

## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep

Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()`: separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.
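
The per-component control looks roughly like this (a sketch of ours; the real param groups, including the later bigram additions, are in the setup_optimizers hunk below):

```python
import torch

def make_adamw(lm_head_params, embedding_params, value_embeds_params, resid_params, x0_params,
               unembedding_lr=0.004, embedding_lr=0.2, scalar_lr=0.5):
    # Per-component AdamW groups; matrix params are handled separately by Muon.
    param_groups = [
        dict(params=lm_head_params,      lr=unembedding_lr),
        dict(params=embedding_params,    lr=embedding_lr),
        dict(params=value_embeds_params, lr=embedding_lr),       # same LR as token embedding
        dict(params=resid_params,        lr=scalar_lr * 0.01),   # residual scalars are sensitive
        dict(params=x0_params,           lr=scalar_lr, betas=(0.96, 0.95)),  # higher beta1
    ]
    return torch.optim.AdamW(param_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
```
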
@@ -15,14 +15,16 @@
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load results\n",
+ "tag = \"jan26\"\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
- "results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
+ "results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
"\n",
"df = pd.read_csv(results_path)\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n",
@@ -31,6 +33,99 @@
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# FILTERING: Remove incomplete or problematic runs\n",
"# =============================================================================\n",
"\n",
"print(f\"Before filtering: {len(df)} runs\")\n",
"\n",
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
"\n",
"# Optional: exclude specific flops budgets that aren't done yet\n",
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
"\n",
"# Optional: exclude specific depths\n",
"# exclude_depths = [18, 20]\n",
"# df = df[~df['depth'].isin(exclude_depths)]\n",
"\n",
"print(f\"After filtering: {len(df)} runs\")\n",
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
"\n",
"# Update flops_budgets list after filtering\n",
"flops_budgets = sorted(df['flops_budget'].unique())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Effective Parameter Count\n",
"\n",
"Different scaling law papers use different conventions for counting parameters:\n",
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
"\n",
"Our CSV now has granular counts:\n",
"- `params_wte` - token embedding (lookup table)\n",
"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
"- `params_value_embeds` - value embeddings (lookup table)\n",
"- `params_lm_head` - unembedding projection (matmul)\n",
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
"\n",
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
"# =============================================================================\n",
"\n",
"def compute_effective_params(row):\n",
"    \"\"\"\n",
"    Compute the 'effective' parameter count for scaling law analysis.\n",
"\n",
"    Modify this function to experiment with different conventions:\n",
"    - Chinchilla-style: include everything\n",
"    - Kaplan-style: exclude embeddings\n",
"    - Matmul-only: just transformer + lm_head (the actual compute)\n",
"    - etc.\n",
"    \"\"\"\n",
"    # Option 1: Chinchilla-style (all params)\n",
"    # return row['params_total']\n",
"\n",
"    # Option 2: Kaplan-style (exclude embeddings)\n",
"    return row['params_transformer'] + row['params_lm_head']\n",
"\n",
"    # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
"    # return row['params_transformer']\n",
"\n",
"\n",
"# Compute derived columns\n",
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
"\n",
"# Show parameter breakdown for first few rows\n",
"print(\"Parameter breakdown (first row per flops budget):\")\n",
"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
"              'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
"df.groupby('flops_budget').first()[param_cols]"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -54,11 +149,11 @@
"optimal_by_bpb = []\n",
"\n",
"for flops, color in zip(flops_budgets, colors):\n",
- "    subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
- "    ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
+ "    subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
+ "    ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"\n",
"    # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
- "    log_params = np.log10(subset['num_scaling_params'])\n",
+ "    log_params = np.log10(subset['effective_params'])\n",
"    coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
"    a, b, c = coeffs\n",
"\n",
@@ -83,13 +178,13 @@
"        # Fallback to raw minimum if quadratic doesn't have minimum\n",
"        best_idx = subset['val_bpb'].idxmin()\n",
"        best = subset.loc[best_idx]\n",
- "        ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
+ "        ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
"                   zorder=5, edgecolors='black', linewidths=2)\n",
- "        optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
+ "        optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
"                               'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n",
"ax.set_xscale('log')\n",
- "ax.set_xlabel('Parameters')\n",
+ "ax.set_xlabel('Effective Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n",
@@ -138,10 +233,61 @@
"\n",
"# Print the optimal points (from quadratic fits)\n",
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
- "print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
+ "print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(\"-\" * 65)\n",
"for _, row in opt_df.iterrows():\n",
- "    print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
+ "    print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Optimal Ratio Summary (from power law fits)\n",
"# =============================================================================\n",
"\n",
"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
"\n",
"if len(opt_df) >= 2:\n",
"    log_f = np.log10(opt_df['flops'])\n",
"    log_p = np.log10(opt_df['params'])\n",
"    log_t = np.log10(opt_df['tokens'])\n",
"\n",
"    # Fit power laws\n",
"    slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
"    slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
"\n",
"    # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
"    ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
"    log_ref = np.log10(ref_flops)\n",
"\n",
"    # Predicted optimal N and D at reference compute\n",
"    pred_log_n = intercept_n + slope_n * log_ref\n",
"    pred_log_d = intercept_d + slope_d * log_ref\n",
"    optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
"\n",
"    # Also compute from the fitted optimals directly (mean and std)\n",
"    mean_ratio = opt_df['ratio'].mean()\n",
"    std_ratio = opt_df['ratio'].std()\n",
"\n",
"    print(\"=\" * 60)\n",
"    print(\"OPTIMAL RATIO SUMMARY\")\n",
"    print(\"=\" * 60)\n",
"    print(f\"\\nPower law exponents:\")\n",
"    print(f\"  N ∝ C^{slope_n:.3f}\")\n",
"    print(f\"  D ∝ C^{slope_d:.3f}\")\n",
"    print(f\"  Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
"    print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
"    print(f\"  From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
"    print(f\"  Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
"    print(f\"  Chinchilla reference: 20\")\n",
"    print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
"else:\n",
"    print(\"Need at least 2 flops budgets to compute power law fits\")"
]
},
{
@@ -45,6 +45,41 @@ def norm(x):
    return F.rms_norm(x, (x.size(-1),))


class BigramEmbed(nn.Module):
    """
    Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
    Following modded-nanogpt's approach: single hash, no gating.

    For each position t, hashes (token[t-1], token[t]) to an index in a large
    embedding table. This provides O(1) lookup for local 2-gram patterns,
    offloading static pattern reconstruction from the transformer layers.

    Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
    Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
    """
    def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
        super().__init__()
        self.bigram_vocab_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """
        idx: (B, T) token ids
        Returns: (B, T, embed_dim) bigram embeddings
        """
        # Hash (prev_token, curr_token) -> index
        # Position 0 gets a reserved index (no valid bigram)
        rand_int_1 = 36313
        rand_int_2 = 27191
        mod = self.bigram_vocab_size - 1

        h = torch.empty_like(idx, dtype=torch.long)
        h[:, 0] = mod # reserved index for position 0
        h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod

        return self.embed(h)


def has_ve(layer_idx, n_layer):
    """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
    return layer_idx % 2 == (n_layer - 1) % 2
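
A quick standalone shape check of the module above (hypothetical usage of ours, not part of the commit; assumes the BigramEmbed class is in scope):

```python
import torch
m = BigramEmbed(vocab_size=32768, embed_dim=768)
idx = torch.randint(0, 32768, (2, 16))
print(m(idx).shape)  # torch.Size([2, 16, 768]); position 0 used the reserved index
```
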
@@ -169,9 +204,13 @@ class GPT(nn.Module):
        # Per-layer learnable scalars (inspired by modded-nanogpt)
        # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
        # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
+       # bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
        # Separate parameters so they can have different optimizer treatment
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
+       self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
+       # Bigram hash embeddings: O(1) lookup for local 2-gram patterns
+       self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
        # Value embeddings (ResFormer-style): alternating layers, last layer always included
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
@@ -219,7 +258,11 @@ class GPT(nn.Module):

        # Per-layer scalars
        self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
-       self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
+       self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
+       self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings

+       # Bigram embeddings: zero init so it starts as identity
+       nn.init.zeros_(self.bigram_embed.embed.weight)

        # Value embeddings (init like c_v: uniform with same std)
        for ve in self.value_embeds.values():
@@ -240,6 +283,7 @@ class GPT(nn.Module):
        self.transformer.wte.to(dtype=torch.bfloat16)
        for ve in self.value_embeds.values():
            ve.to(dtype=torch.bfloat16)
+       self.bigram_embed.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -305,8 +349,9 @@ class GPT(nn.Module):
        nparams = sum(p.numel() for p in self.parameters())
        # Exclude non-matmul params: embeddings and per-layer scalars
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
-       nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
-                          self.resid_lambdas.numel() + self.x0_lambdas.numel())
+       bigram_embed_numel = self.bigram_embed.embed.weight.numel()
+       nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
+                          self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
        h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
        # Sum attention FLOPs per layer, accounting for sliding window
        attn_flops = 0
@@ -319,15 +364,34 @@ class GPT(nn.Module):

    def num_scaling_params(self):
        """
-       Return all of the parameters, same as Chinchilla paper.
-       Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
-       But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
-       My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
-       Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
-       Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
+       Return detailed parameter counts for scaling law analysis.
+       Different papers use different conventions:
+       - Kaplan et al. excluded embedding parameters
+       - Chinchilla included all parameters
+       Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
+       Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
+
+       Returns a dict with counts for each parameter group, so downstream analysis
+       can experiment with which combination gives the cleanest scaling laws.
        """
-       nparams = sum(p.numel() for p in self.parameters())
-       return nparams
+       # Count each group separately (mirrors the grouping in setup_optimizers)
+       wte = sum(p.numel() for p in self.transformer.wte.parameters())
+       bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
+       value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
+       lm_head = sum(p.numel() for p in self.lm_head.parameters())
+       transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
+       scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel()
+       total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
+       assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
+       return {
+           'wte': wte,
+           'bigram_embed': bigram_embed,
+           'value_embeds': value_embeds,
+           'lm_head': lm_head,
+           'transformer_matrices': transformer_matrices,
+           'scalars': scalars,
+           'total': total,
+       }

    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        model_dim = self.config.n_embd
@@ -339,7 +403,9 @@ class GPT(nn.Module):
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
-       assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
+       bigram_embed_params = list(self.bigram_embed.parameters())
+       bigram_lambda_params = [self.bigram_lambdas]
+       assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
        # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -348,8 +414,10 @@ class GPT(nn.Module):
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
            dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
+           dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
            dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
            dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
+           dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
        ]
        adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@@ -377,11 +445,12 @@ class GPT(nn.Module):
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

        # Forward the trunk of the Transformer
-       x = self.transformer.wte(idx)
+       x = self.transformer.wte(idx) # embed current token
+       x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
        x = norm(x)
        x0 = x # save initial normalized embedding for x0 residual
        for i, block in enumerate(self.transformer.h):
-           x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
+           x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
        x = norm(x)
@@ -1,13 +1,14 @@
#!/bin/bash

-LABEL="jan16"
+LABEL="jan26"

FLOPS_BUDGETS=(
    1e18
-   3e18
-   6e18
+   2.15e18
+   4.64e18
    1e19
)
-DEPTHS=(6 7 8 9 10 11 12 13 14)
+DEPTHS=(8 10 12 14 16 18 20)

NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
@@ -23,7 +24,7 @@ RESULTS_FILE="$RESULTS_DIR/results.csv"

# Write CSV header only if file doesn't exist
if [ ! -f "$RESULTS_FILE" ]; then
-   echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
+   echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
fi

log() {
@@ -83,13 +84,19 @@ for flops in "${FLOPS_BUDGETS[@]}"; do

    # Extract training stats from the log
    LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
-   NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
-   NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')

+   # Extract detailed parameter counts (for scaling law analysis with different conventions)
+   PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
+   PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
+   PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
+   PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
+   PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
+   PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
+   PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')

    NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
    # Calculate tokens trained (iterations * batch_size, default 524288)
    TOKENS_TRAINED=$((NUM_ITERS * 524288))
-   # Param:data ratio (using scaling params per Kaplan et al.)
-   PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
    # Model dim
    MODEL_DIM=$((d * 64))
    # Val BPB from final eval
@@ -102,10 +109,10 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
        CORE_SCORE="0.0"
    fi

-   log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
+   log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"

    # Append to CSV
-   echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
+   echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
done
done
@@ -1,11 +1,11 @@
"""
Train model. From root directory of the project, run as:

-python -m scripts.base_train.py
+python -m scripts.base_train

or distributed as:

-torchrun --nproc_per_node=8 -m scripts.base_train.py
+torchrun --nproc_per_node=8 -m scripts.base_train

If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
@@ -47,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
@@ -178,9 +178,14 @@ if resuming:

orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
-num_params = sum(p.numel() for p in model.parameters())
-num_scaling_params = orig_model.num_scaling_params()
-print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")

+# Detailed parameter counts
+param_counts = orig_model.num_scaling_params()
+print0(f"Parameter counts:")
+for key, value in param_counts.items():
+    print0(f"{key:24s}: {value:,}")
+num_params = param_counts['total']
+num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
@@ -195,14 +200,14 @@ elif args.target_flops > 0:
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
-   target_tokens = args.target_param_data_ratio * num_scaling_params
+   target_tokens = int(args.target_param_data_ratio * num_scaling_params)
    num_iterations = target_tokens // args.total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
-print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
+print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# -----------------------------------------------------------------------------
@@ -445,7 +450,7 @@ get_report().log(section="Base model training", data=[
    "Number of FLOPs per token": f"{num_flops_per_token:e}",
    "Calculated number of iterations": num_iterations,
    "Number of training tokens": total_tokens,
-   "Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
+   "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
    "DDP world size": ddp_world_size,
    "warmup_ratio": args.warmup_ratio,
    "warmdown_ratio": args.warmdown_ratio,