document my struggle with fp8 integration yesterday, it's not working like i thought it would and i suffered. one day i will return to continue the fight.
@@ -4,6 +4,67 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-01-13: FP8 Training for lm_head
Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
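As background for the scale constants that show up below: activations and weights go through E4M3 and gradients through E5M2, and the magic numbers 448 and 57344 are just the dynamic-range maxima of those two formats. Quick check, assuming a PyTorch build that ships the FP8 dtypes:

```python
import torch

# Dynamic-range maxima of the two FP8 formats referenced throughout this entry.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0   (E4M3, used for x and w)
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0 (E5M2, used for grads)
```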
### Implementation Approaches Tried
**1. Dynamic Scaling (failed)**
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales (sketched after this list)
- Problem: `.item()` calls cause graph breaks with torch.compile
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile
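
A minimal sketch of the dynamic-scaling step described above (illustrative only; the function name and details are mine, not the branch's code):

```python
import torch

def quantize_e4m3_dynamic(t: torch.Tensor):
    """Per-tensor dynamic scaling: derive the scale from the current amax."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    amax = t.abs().max()
    # Calling amax.item() to get a Python float is what breaks the graph under
    # torch.compile; keeping the scale as a tensor avoids the break, but it then
    # has to be threaded into torch._scaled_mm as a tensor argument.
    scale = fp8_max / amax.clamp(min=1e-12)
    t_f8 = (t * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return t_f8, scale.reciprocal()  # inverse scale, used to dequantize the output
```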
**2. Static Scaling (partial success)**
- Pre-set scales at init time like modded-nanogpt: `x_scale = 10/448`, `w_scale = 0.1/448` (see the sketch after this list)
- `grad_scale` is computed dynamically from the batch size (safe since it's just `1/(B*T)/57344`, which follows from the gradient expression of cross entropy). modded-nanogpt probably has a bug here: they set `grad_scale = 0.75/448`, but grads are in E5M2, so it should probably be `1/57344` (1 being the amax of any individual element of the cross-entropy gradient, with no normalization by B,T because they use sum reduction, not mean reduction).
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
- This works correctly - no NaNs, proper gradients
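
To make the static recipe concrete, a rough sketch of the scaled forward matmul (see `nanochat/fp8_static.py` on the branch for the real thing; this omits the `torch.library.custom_op` wrapper and the custom backward, and `torch._scaled_mm` is a private API whose signature has shifted across PyTorch releases, so the call below follows the ~2.4+ form):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
X_SCALE = 10.0 / E4M3_MAX   # static assumption: |x| stays under ~10
W_SCALE = 0.1 / E4M3_MAX    # static assumption: |w| stays under ~0.1

def fp8_lm_head_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Static-scale FP8 matmul sketch: x is (B*T, d), w is (vocab_size, d)."""
    x_f8 = (x / X_SCALE).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    w_f8 = (w / W_SCALE).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return torch._scaled_mm(
        x_f8,                 # row-major
        w_f8.t(),             # _scaled_mm wants the second operand column-major
        scale_a=torch.tensor(X_SCALE, device=x.device),  # dequant scales get folded
        scale_b=torch.tensor(W_SCALE, device=x.device),  # back into the output
        out_dtype=torch.bfloat16,
    )
```

The gradient path would do the analogous thing in E5M2 with the `1/(B*T)/57344` scale discussed above.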
### Results (d12)
| Metric | BF16 Baseline | FP8 lm_head |
|--------|---------------|-------------|
| GPU Memory | 34 GB | 36 GB |
| tok/sec | baseline | ~1% faster |
### The Memory Mystery
FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see a 2 GB *increase*. Suspected causes:
- `torch.compile` on inner kernels creating extra buffers/specializations
- `torch._scaled_mm` internal workspace allocations
- Custom op registration machinery overhead

Tried saving the original weight `w` (just a reference to the parameter) instead of `w_f8` for the backward pass, re-quantizing it on the spot during backward. It didn't help; the memory bump was still there.
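One generic way to chase the bump is to compare peak allocations around a single step of each variant (a sketch, not code from the branch):

```python
import torch

def peak_gib(step_fn) -> float:
    """Run one training step and report the peak CUDA memory it allocated (GiB)."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    step_fn()  # forward + backward + optimizer step, with the FP8 or BF16 lm_head
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30
```

Diffing `torch.cuda.memory_summary()` between the two variants at the same point would help attribute the extra 2 GB to compile buffers vs `torch._scaled_mm` workspaces vs the custom op machinery.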
### Microbenchmark vs Reality
A raw matmul microbenchmark showed promise (timing harness sketched below the numbers):
- BF16 matmul: 16.95 ms
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)
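
For reference, the kind of CUDA-event harness behind numbers like these (the shapes below are illustrative, not the exact d12 ones):

```python
import torch

def bench_ms(fn, warmup: int = 10, iters: int = 50) -> float:
    """Mean milliseconds per call, timed with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# BF16 baseline for an lm_head-shaped matmul (illustrative shapes).
x = torch.randn(32768, 768, device="cuda", dtype=torch.bfloat16)
w = torch.randn(32768, 768, device="cuda", dtype=torch.bfloat16)
print(f"bf16: {bench_ms(lambda: x @ w.t()):.2f} ms")  # swap in the FP8 path to compare
```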
But in full training, the ~1% tok/sec improvement doesn't justify the 2 GB memory increase, the added code complexity, or the need to tune scale factors for both `x` and `w`.
### Code Artifacts
See the branch `fp8_attempt_fail` for:
- `nanochat/fp8_static.py` - Static scaling implementation (working)
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.
### Open Questions
- Why does the custom op approach use more memory than vanilla BF16?
- Why is the bump in tok/sec so low? We should see ~1.6X speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the explanation: our vocab_size is only 32K, so the final layer isn't a huge part of the profile. But the expected speedup is still not fully realized (see the arithmetic sketch below).
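
To put a number on the Amdahl point: if the lm_head matmuls are a fraction `f` of step time and FP8 speeds them up by a factor `s`, the whole step speeds up by `1 / ((1 - f) + f / s)`. The `f` below is an assumed illustrative value, not a measured profile:

```python
def overall_speedup(f: float, s: float) -> float:
    """Amdahl's law: fraction f of the step is sped up by factor s."""
    return 1.0 / ((1.0 - f) + f / s)

# Assumed, not profiled: if the lm_head matmuls are ~10% of step time and FP8
# makes them 1.6x faster, the full step is only ~4% faster, and any overhead
# around the custom op eats into that.
print(overall_speedup(0.10, 1.6))  # ~1.039
```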
**Conclusion:** Negative result for now. The implementation works correctly but provides a marginal speedup with *increased* memory usage. I don't yet understand the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO: study in more detail how this is implemented in other libraries, e.g. torchao.
---
## 2026-01-12: Multi-Token Prediction (MTP)
Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.
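A generic sketch of what a weighted multi-horizon loss looks like (the logits layout, horizon count, and weights here are illustrative assumptions, not the ported modded-nanogpt code):

```python
import torch
import torch.nn.functional as F

def mtp_loss(logits: torch.Tensor, targets: torch.Tensor, weights=(1.0, 0.5)) -> torch.Tensor:
    """Weighted cross entropy over several future offsets.

    logits: (B, T, n, V), one slice of predictions per future offset
    targets: (B, T), where targets[:, t] is the token at position t + 1
    """
    B, T, n, V = logits.shape
    total, denom = 0.0, 0.0
    for k, w_k in enumerate(weights[:n]):
        valid = T - k  # positions that still have a target k extra steps ahead
        lk = logits[:, :valid, k, :].reshape(-1, V)
        tk = targets[:, k:k + valid].reshape(-1)
        total = total + w_k * F.cross_entropy(lk, tk)
        denom += w_k
    return total / denom
```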