Add learnable lambdas that gate the residual connection and a skip connection to the input embeddings, solid bump to val_bpb

dev/LOG.md (+56 lines)
@@ -4,6 +4,62 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026

---

## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)

Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.

### Changes Made

**1. x0_lambdas (x0 residual connections)**
- Save the initial normalized embedding as `x0` after `norm(wte(idx))`
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` (see the sketch after this list)
- Zero-initialized, so disabled at start; the model learns which layers benefit from the shortcut
- Provides a direct path from the embedding to deep layers, helps preserve token information
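
A minimal sketch of where the two scalars enter the forward pass, assuming the blend is applied at the top of each layer before the usual residual block (the function and argument names are illustrative, not the actual nanochat code):

```python
def forward_with_lambdas(x_emb, blocks, norm, resid_lambdas, x0_lambdas):
    """x_emb: normalized token embeddings, i.e. norm(wte(idx)).
    resid_lambdas / x0_lambdas: one learnable scalar per layer."""
    x0 = x_emb  # keep the initial embedding around for the shortcut
    x = x_emb
    for i, block in enumerate(blocks):
        # scale the residual stream and mix the initial embedding back in
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0
        x = x + block(norm(x))  # then the usual residual block on top
    return x
```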

**2. resid_lambdas (residual stream scaling)**
- Per-layer multiplicative scaling of the residual stream
- Initialized to 1.0 (neutral, standard transformer behavior); initialization sketched below
- Allows the model to learn to amplify/dampen the residual at each layer
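
A sketch of how the two scalar vectors might be registered and initialized; the `ResidualScalars` wrapper is hypothetical (presumably the scalars live directly on the model), but the attribute names and init values follow the log:

```python
import torch
import torch.nn as nn

class ResidualScalars(nn.Module):
    """Hypothetical container for the two per-layer scalar vectors."""
    def __init__(self, n_layer):
        super().__init__()
        # allocate shapes only; real values are written in init_weights()
        # (see the meta-device note further down for why)
        self.resid_lambdas = nn.Parameter(torch.empty(n_layer))
        self.x0_lambdas = nn.Parameter(torch.empty(n_layer))

    def init_weights(self):
        nn.init.ones_(self.resid_lambdas)   # 1.0: neutral, standard transformer behavior
        nn.init.zeros_(self.x0_lambdas)     # 0.0: x0 shortcut disabled at the start
```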

**3. DistAdamW small parameter handling**
- Added support for parameters with < 1024 elements (like the scalar lambdas)
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather` (see the sketch after this list)
- Fixes a crash when a param's shape isn't divisible by world_size
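
A sketch of the size-based branch in a sharded optimizer's gradient sync, assuming a ZeRO-1-style setup where large gradients are reduce-scattered across ranks; this is illustrative, not the actual DistAdamW code, and only the 1024 threshold is taken from the log:

```python
import torch
import torch.distributed as dist

SMALL_PARAM_THRESHOLD = 1024  # params below this size are not sharded across ranks

def sync_grad(param: torch.Tensor, world_size: int) -> torch.Tensor:
    """Return the gradient (or gradient shard) this rank should apply its update to."""
    if param.numel() < SMALL_PARAM_THRESHOLD:
        # Tiny tensors (e.g. the per-layer scalar lambdas) may not be divisible
        # by world_size, so every rank keeps the full gradient and just averages it.
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
        return param.grad
    # Large tensors are sharded: each rank reduces its own slice, updates it,
    # and the updated slices are reassembled with an all_gather afterwards.
    shard = torch.empty(param.numel() // world_size,
                        dtype=param.grad.dtype, device=param.grad.device)
    dist.reduce_scatter_tensor(shard, param.grad.flatten(), op=dist.ReduceOp.AVG)
    return shard
```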

### Key Finding: Different LR Sensitivity

The two scalar types need very different learning rates:
- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.

Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets the full `scalar_lr`.
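
The LR split might translate into optimizer param groups roughly like this (a sketch with plain AdamW; in the actual run the scalars presumably go through the DistAdamW path described above, and the attribute names follow the log):

```python
import torch

def scalar_param_groups(model, scalar_lr=0.5):
    """Additive x0 scalars get the full LR; multiplicative resid scalars get 1% of it."""
    return [
        {"params": [model.x0_lambdas],    "lr": scalar_lr},         # forgiving, additive
        {"params": [model.resid_lambdas], "lr": scalar_lr * 0.01},  # compounds through layers
    ]

# usage sketch: torch.optim.AdamW(scalar_param_groups(model) + other_param_groups)
```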

### Experiment Results

Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:

| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|-------|---------------------|----------------|--------------|---------|
| d8    | 1.0885              | 0.20           | 1.0782       | -0.0103 |
| d12   | 0.9770              | 0.60           | 0.9693       | -0.0077 |
| d16   | 0.9059              | 0.20           | 0.9002       | -0.0057 |
| d20   | 0.8565              | 0.10           | 0.8526       | -0.0039 |

**Observations:**
- Consistent improvement across all model sizes
- Optimal LR varies by depth; the default of 0.5 is reasonable, but 0.6 is better for d12
- Adding resid_lambdas (with the 0.01x LR) gives a small additional improvement over x0 alone

### Meta Device Footgun

Important lesson: `__init__` runs in a meta-device context, so any tensor values set there are fake. Actual values must be initialized in `init_weights()`. Added a docstring warning to `__init__`.
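
A self-contained illustration of the footgun (a toy module, not the nanochat GPT), assuming the model is built under a meta-device context and materialized with `to_empty()` afterwards:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # NOTE: under a meta device, "zeros" allocates no storage and holds no values
        self.lam = nn.Parameter(torch.zeros(4))

    def init_weights(self):
        nn.init.zeros_(self.lam)  # the real initialization has to live here

with torch.device("meta"):
    m = Toy()                 # parameters are meta tensors: shape/dtype only, no data

m = m.to_empty(device="cpu")  # materialize storage; contents are uninitialized garbage
m.init_weights()              # only now do the parameters actually hold zeros
print(m.lam)
```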

### Summary

Added `--scalar_lr` (default 0.5) controlling the learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.

---

## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay

Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.