Compare commits

...

52 Commits

Author SHA1 Message Date
Andrej Karpathy
41bb2eac32 Combine AdamW and Muon into single MuonAdamW optimizer, cleaner, ty @chrisjmccormick for idea/help 2026-01-29 00:52:08 +00:00
Andrej Karpathy
64a651a63c include .claude is ok 2026-01-29 00:35:02 +00:00
Andrej Karpathy
65df0de42b add arxiv reading skill 2026-01-29 00:34:24 +00:00
Andrej Karpathy
74554be3b5 revert engram, not seeing an improvement at larger scale 2026-01-28 20:07:39 +00:00
Sofie Van Landeghem
d5418ea5a1 Fix link to DeepSeek Engram paper (#470)
* Fix link to DeepSeek Engram paper in LOG.md

Updated link to the DeepSeek Engram paper in the log.

* remove www
2026-01-28 08:31:44 -08:00
Andrej Karpathy
c88bbf8133 Merge branch 'engram' 2026-01-27 22:33:16 +00:00
Andrej Karpathy
c8d93beed2 add engram-lite, add log, tune scaling laws analysis scripts 2026-01-27 22:31:17 +00:00
Andrej Karpathy
8630d32be4 quick fix to not OOM main speedrun script 2026-01-26 22:31:42 +00:00
Andrej Karpathy
59e36cc727 first version of engram following modded nanogpt style 2026-01-25 18:59:51 +00:00
Andrej Karpathy
85b3e95e09 320 experiments just to tune the adam beta1 of x0 a little bit up from 0.8 to 0.96 2026-01-25 00:04:02 +00:00
xiayan0118
6a477eedbd fix: pass device_type to compute_init in engine.__main__ (#451)
When running engine.py directly on non-GPU devices (CPU, MPS),
compute_init() needs the device_type parameter to initialize correctly.
This fixes failures on machines without CUDA support.
2026-01-19 17:19:51 -08:00
Andrej Karpathy
63bb5831e2 something i've wanted to do for a while - move all .sh runs to their own directory so they don't pollute root dir 2026-01-18 15:27:41 +00:00
Andrej Karpathy
a91743c168 Merge branch 've' 2026-01-18 15:14:39 +00:00
Andrej Karpathy
d58fcd9d73 log for jan 17 2026-01-18 03:01:17 +00:00
Andrej Karpathy
babde18ce1 small tweaks 2026-01-18 03:00:38 +00:00
Andrej Karpathy
cf5c9e5b8e resolve a crash for odd depths because FA3 needs head_dim % 8 == 0 2026-01-18 00:07:08 +00:00
Andrej Karpathy
413e91aa0f optimal ratio is now around 4 2026-01-17 23:51:09 +00:00
Andrej Karpathy
e7ed2082b8 update the default GPTConfig kwargs otherwise they are confusing 2026-01-17 21:16:46 +00:00
karpathy
f9a7e0f111 update the CPU/MPS script to give reasonable results. The model can at least answer that Paris is the capital of France and knows that the sky is blue, for about 40 minutes of training on my macbook. Also fixed a bug that existed due to KVCache bfloat16 dtype assumption 2026-01-17 12:27:30 -08:00
Andrej Karpathy
f5425245f9 more GPU types from PR 147 thanks @Qubitium 2026-01-17 03:22:20 +00:00
Andrej Karpathy
2955650327 add detection of device to report more correct mfu for bf16 2026-01-17 03:16:14 +00:00
Yury Kirpichev
77a46902e4 Fix WANDB_RUN parameter passing in runcpu.sh (#407)
- Add --run=$WANDB_RUN to base_train, mid_train, and chat_sft calls
- Ensures wandb logging works when WANDB_RUN environment variable is set
- Matches the behavior in speedrun.sh

Co-authored-by: svlandeg <svlandeg@github.com>
2026-01-16 18:59:44 -08:00
Barış Özmen
bbc4413c58 Add high value engine tests for core invariants (33 LoC) (#396)
* test: add engine generation tests for expected invariants

- test_seed_reproducibility
- test_temperature_zero_determinism
- test_max_tokens_respected
- test_num_samples_count

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix temperature test

* add test for seed variation in sampling

Add test for seed variation in sampling with temperature > 0.

* Rename test for clarity

* Shorten assert msg

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2026-01-16 18:59:12 -08:00
Nitish Pandey
f42ae9e901 fix condition to perform bpb evaluation (#324)
Co-authored-by: svlandeg <svlandeg@github.com>
2026-01-16 18:56:43 -08:00
Yamahammer
e1dafc510f Reduce token waste in BOS bestfit by cropping shortest doc (#445)
When no document fits the remaining row space, crop the shortest
document in the buffer instead of the first. This minimizes
discarded tokens.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 18:50:34 -08:00
Andrej Karpathy
6460dc6382 tweaks to readme a bit 2026-01-17 02:28:31 +00:00
Andrej Karpathy
1933e85046 brief update to log 2026-01-17 00:25:50 +00:00
Andrej Karpathy
3b95d4fd39 allow label for scaling laws script 2026-01-17 00:23:30 +00:00
Andrej Karpathy
e85db6b4a4 alternating design 2026-01-16 23:52:12 +00:00
Andrej Karpathy
9a88194c3f simply one VE per layer, works best 2026-01-16 22:08:52 +00:00
Andrej Karpathy
0b58d70e99 full ve version works very well 2026-01-16 21:16:47 +00:00
Andrej Karpathy
e3f58b838e ranked version 2026-01-16 20:59:42 +00:00
Andrej Karpathy
184d4c12b1 also add to log about the FA3 changes 2026-01-16 18:25:04 +00:00
Andrej Karpathy
b62a5bc44a naturally i failed to include the actual code in the previous commit facepalm 2026-01-16 17:39:41 +00:00
Andrej Karpathy
8203efa919 implement flash attention 3 fallback to pytorch sdpa by touching as few lines of code as possible in main files and keeping all implementation to a single file. add tests. add helpful warning messages for the user. 2026-01-16 17:37:51 +00:00
Haoyu Wang
50413d2d67 typo in comments: change "GAPO" to "DAPO" 2026-01-15 22:03:42 -08:00
Andrej Karpathy
fbf2bbea25 update log with a bunch of attempts 2026-01-16 02:21:17 +00:00
Andrej Karpathy
747ed4491f add negative result on olmo3 pretraining mix 2026-01-16 00:44:01 +00:00
Andrej Karpathy
7d1700c521 add zstd lib 2026-01-16 00:44:01 +00:00
Sofie Van Landeghem
d4ea28d4e2 Fix args in readme (#438)
* fix commands in readme, using new arg format

* fix typo

* add required -i flag to chat_eval example runs
2026-01-15 16:26:38 -08:00
Andrej Karpathy
bdcc030ffa oops legacy spurious line now 2026-01-15 23:32:20 +00:00
Andrej Karpathy
22a71aa3d3 fuse adamw into a single torch compiled kernel similar to muon. it's about 1.7X faster, but overall it's so tiny that it's not making a major dent 2026-01-15 23:30:44 +00:00
Andrej Karpathy
255f8b9af6 cleanly separate cpu and gpu sections 2026-01-15 23:30:11 +00:00
Andrej Karpathy
6bb92403d5 changes and optimizations to muon, making it more efficient and simpler/cleaner a bit 2026-01-15 03:20:48 +00:00
Andrej Karpathy
3142ca1a28 minor helpful message 2026-01-15 03:20:21 +00:00
Andrej Karpathy
7312ec9898 fix buggy midtrain and update all kwargs to be idiomatic. that is, argparse uses dashes variables use underscores. the underscores are just a remnant of the previous Configurator object. This is the right way 2026-01-13 22:45:27 +00:00
Andrej Karpathy
3b50b77ed3 fix base_loss to report correct loss by switching the dataloader to the new default 2026-01-13 22:09:36 +00:00
Andrej Karpathy
f92efce169 add negative result about not allowing attention across BOS tokens. A lot more code complexity for basically no gain in performance 2026-01-13 21:33:54 +00:00
Andrej Karpathy
43c29dd9d5 Big DataLoader refactor: BOS-aligned dataloaders with epoch tracking for pre/mid-training
The new DataLoader ensures that every token sequence in train/val batches has a BOS token
at the beginning. Therefore, no token streams start abruptly in the middle of a document,
which could be confusing for the model. Note that this changes the loss scale because there
are fewer confusing tokens in the train/val batches. The main downside is that we now waste
about 35% of tokens due to cropping. This is ok because we have a lot of data. See dev/LOG.md
entry for this change for a lot more information.
2026-01-13 20:05:47 +00:00
Andrej Karpathy
23985413aa adjust the comment on the regex pattern per recent experiment see dev/LOG.md 2026-01-13 17:50:39 +00:00
Andrej Karpathy
64b48d0e5c validated that \p{N}{1,2} is the correct number of digits to group up to in the regex pattern of the GPT-4 tokenizer (2 down from 3), leading to the best val_bpb for 32K vocabs 2026-01-13 17:45:06 +00:00
Andrej Karpathy
238353c998 document my struggle with fp8 integration yesterday, it's not working like i thought it would and i suffered. one day i will return to continue the fight. 2026-01-13 17:14:29 +00:00
33 changed files with 2683 additions and 880 deletions

View File

@@ -0,0 +1,40 @@
---
name: read-arxiv-paper
description: Use this skill when asked to read an arxiv paper given an arxiv URL
---
You will be given a URL of an arxiv paper, for example:
https://www.arxiv.org/abs/2601.07372
### Part 1: Normalize the URL
The goal is to fetch the TeX source of the paper (not the PDF!). The URL always looks like this:
https://www.arxiv.org/src/2601.07372
Notice the /src/ in the url. Once you have the URL:
### Part 2: Download the paper source
Fetch the url to a local .tar.gz file. A good location is `~/.cache/nanochat/knowledge/{arxiv_id}.tar.gz`.
(If the file already exists, there is no need to re-download it).
### Part 3: Unpack the file in that folder
Unpack the contents into `~/.cache/nanochat/knowledge/{arxiv_id}` directory.
### Part 4: Locate the entrypoint
Every latex source usually has an entrypoint, such as `main.tex` or something like that.
### Part 5: Read the paper
Once you've found the entrypoint, read its contents and then recurse through all other relevant source files to read the paper.
### Part 6: Report
Once you've read the paper, produce a summary of it as a markdown file at `./knowledge/summary_{tag}.md`. Note that 1) you should use the local knowledge directory here (it's easier for me to open and reference), not `~/.cache`, and 2) you should generate some reasonable `tag`, e.g. `conditional_memory` or whatever seems appropriate given the paper. Make sure the tag doesn't exist yet so you're not overwriting files.
As for the summary itself, remember that you're processing this paper within the context of the nanochat repository, so most often we will be interested in how to apply the paper and its lessons to the nanochat project. Therefore, you should feel free to "remind yourself" of the related nanochat code by reading the relevant parts, and then explicitly make the connection of how this paper might relate to nanochat, or what things we might be inspired by or want to try.

1
.gitignore vendored
View File

@@ -9,6 +9,5 @@ eval_bundle/
.env
# Local setup
.claude
CLAUDE.md
wandb/

View File

@@ -4,28 +4,29 @@
> The best ChatGPT that $100 can buy.
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
## Talk to it
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](runs/speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
## Updates
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](miniseries.sh).
- (Jan 16 2026) The repo is in active development, I am currently fleshing out the pretraining stage.
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](runs/miniseries.sh).
## Talk to it
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](runs/run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
## Quick start
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](runs/speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
```bash
bash speedrun.sh
bash runs/speedrun.sh
```
Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
```bash
screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
```
See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
@@ -72,7 +73,7 @@ Total wall clock time: 3h51m
Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet.
That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
That said, to give a sense, the example changes needed for the [speedrun.sh](runs/speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
```bash
...
@@ -82,10 +83,10 @@ That said, to give a sense, the example changes needed for the [speedrun.sh](spe
python -m nanochat.dataset -n 450 &
...
# use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device-batch-size=16
...
# make sure to use the same later during midtraining:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
```
That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
@@ -99,7 +100,7 @@ And a bit more about computing environments that will run nanochat:
## Running on CPU / MPS
nanochat can be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
nanochat can be run on CPU or on MPS (if you're on Macbook) in principle, and will automatically try to detect what device is best to run on. The script [runcpu.sh](runs/runcpu.sh) shows a very simple example that will exercise the code paths but basically produce garbage results. Unless you know what you're doing, I basically don't recommend using this script right now and hope to tune it a bit more in the future.
## Customization
@@ -109,15 +110,9 @@ Additionally, to add new abilities to nanochat, see [Guide: counting r in strawb
## Questions
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
```bash
files-to-prompt . -e py -e md -e html -e toml -e sh --cxml > packaged.txt
```
This includes all py, html, toml, sh files and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
You can also come to the [#nanochat Discord channel](https://discord.com/channels/1020383067459821711/1427295580895314031) to ask questions, or use the Discussions.
## Tests
@@ -137,11 +132,9 @@ python -m pytest tests/test_engine.py -v -s
│ ├── gen_synthetic_data.py # Example synthetic data for identity
│ ├── generate_logo.html
│ ├── nanochat.png
│ ├── repackage_data_reference.py # Pretraining data shard generation
│ └── runcpu.sh # Small example of how to run on CPU/MPS
│ └── repackage_data_reference.py # Pretraining data shard generation
├── nanochat
│ ├── __init__.py # empty
│ ├── adamw.py # Distributed AdamW optimizer
│ ├── checkpoint_manager.py # Save/Load model checkpoints
│ ├── common.py # Misc small utilities, quality of life
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
@@ -152,12 +145,17 @@ python -m pytest tests/test_engine.py -v -s
│ ├── gpt.py # The GPT nn.Module Transformer
│ ├── logo.svg
│ ├── loss_eval.py # Evaluate bits per byte (instead of loss)
│ ├── muon.py # Distributed Muon optimizer
│ ├── optim.py # AdamW + Muon optimizer, 1GPU and distributed
│ ├── report.py # Utilities for writing the nanochat Report
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
│ └── ui.html # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh # Train the ~$800 nanochat d32
├── runs
│ ├── miniseries.sh # Miniseries training script
│ ├── run1000.sh # Train the ~$800 nanochat d32
│ ├── runcpu.sh # Small example of how to run on CPU/MPS
│ ├── scaling_laws.sh # Scaling laws experiments
│ └── speedrun.sh # Train the ~$100 nanochat d20
├── scripts
│ ├── base_eval.py # Base model: calculate CORE score
│ ├── base_loss.py # Base model: calculate bits per byte, sample
@@ -170,7 +168,6 @@ python -m pytest tests/test_engine.py -v -s
│ ├── mid_train.py # Chat model: midtraining
│ ├── tok_eval.py # Tokenizer: evaluate compression rate
│ └── tok_train.py # Tokenizer: train it
├── speedrun.sh # Train the ~$100 nanochat d20
├── tasks
│ ├── arc.py # Multiple choice science questions
│ ├── common.py # TaskMixture | TaskSequence

View File

@@ -4,6 +4,485 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-01-28: Reverted Bigram Hash Embeddings
Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall clock time. It also bloated the VRAM used. The extra parameters and complexity aren't justified.
---
## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
### Background
The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
### What We Tried
**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h) # hidden state as query
k = RMSNorm(W_k @ e) # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / d) # scalar gate per position
output = α * v
```
- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.
**2. Early-layer only injection**
- Only inject bigram signal in first 4 layers (where paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.
**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone. Dilutes capacity from more frequent 2-gram patterns.
**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.
TLDR: The winning approach follows modded-nanogpt's "engram-lite", simply adding the following module and feeding its output into the residual branch (gated by a per-layer learnable λ) before every single block:
```python
import torch.nn as nn

class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)

    def forward(self, idx):
        # hash each (prev, curr) token pair into the table; output covers positions 1..T-1
        h = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        return self.embed(h)
```
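For context, here is a minimal sketch of how this plugs into the forward pass, mirroring the x0-style injection line above (the function and variable names are illustrative, not the exact nanochat identifiers):
```python
def forward_with_bigram(blocks, x0, x0_bigram, resid_lambdas, x0_lambdas, bigram_lambdas):
    # x0: token embeddings; x0_bigram: bigram embeddings aligned to x0
    # (position 0, which has no predecessor, would need padding in practice)
    x = x0
    for i, block in enumerate(blocks):
        # per-layer learnable scalars gate the residual stream, x0, and the bigram signal
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram
        x = block(x)
    return x
```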
As for optimal hyperparameters:
- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings), this was slightly worse.
- **Init:** Zero-init embedding, so starts as identity (tried small noisy init, it's worse)
- **Optimizer:** AdamW with same LR as token embeddings
### Key Learnings
1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."
2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.
3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.
4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.
### Parameters Added
For d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = ~126M params
- Per-layer lambdas: 12 scalars (negligible)
If you're keeping track, we now have *a lot* of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:
```
Parameter counts:
wte : 25,165,824
bigram_embed : 125,829,120
value_embeds : 150,994,944
lm_head : 25,165,824
transformer_matrices : 84,935,808
scalars : 36
total : 412,091,556
```
In other words, only about a quarter of parameters are now weight projections and the vast majority is embedding tables.
Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.
After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:
```
Kaplan-style (all projections including lm_head and no embeddings)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 110,678,115 1,241,505,403 11.2 0.8972
2e+18 167,797,457 1,785,336,422 10.7 0.8616
5e+18 250,650,865 2,642,234,152 10.8 0.8293
1e+19 381,758,347 3,806,871,243 10.3 0.7999
N \propto C^0.54, D \propto C^0.49
Chinchilla-style (all parameters, period.)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 416,320,605 1,232,157,011 3.0 0.8974
2e+18 560,239,841 1,763,669,281 3.2 0.8616
5e+18 741,495,903 2,629,909,368 3.6 0.8291
1e+19 988,644,331 3,884,841,895 4.0 0.7999
N \propto C^0.37, D \propto C^0.50
Transformer-only-style (only the projections inside the transformer)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 80,259,665 1,315,639,547 17.2 0.8966
2e+18 131,488,566 1,864,134,141 14.5 0.8622
5e+18 220,985,474 2,595,328,843 12.1 0.8302
1e+19 401,213,504 3,328,704,512 8.5 0.7994
N \propto C^0.70, D \propto C^0.41
```
Clearly, the Kaplan-style ratios are most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can have a single fixed ratio of tokens:params for compute optimal models. This turns out to be about ~10.5, which now becomes the new default.
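As a quick sanity check of what this implies, a back-of-envelope sketch using the standard C ≈ 6·N·D FLOPs approximation (an assumption here; the table above uses the Kaplan-style effective parameter count):
```python
def compute_optimal(C, ratio=10.5):
    # C ≈ 6 * N * D with D = ratio * N  =>  N = sqrt(C / (6 * ratio))
    N = (C / (6 * ratio)) ** 0.5
    D = ratio * N
    return N, D

N, D = compute_optimal(1e19)
print(f"N ≈ {N/1e6:.0f}M effective params, D ≈ {D/1e9:.1f}B tokens")
# roughly 400M params and ~4.2B tokens, in the same ballpark as the 1e19 row above
```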
---
## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep
Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.
### What We Swept
- Learning rates for all 6 parameter groups
- Beta1/beta2 for all 5 AdamW groups
- Muon momentum (start/end), weight decay
- Hundreds of combinations (2-way, 3-way, 4-way, etc.)
### The Journey
**At d12**, found two independent improvement routes:
- **Route A:** emb_lr↑ (0.3→0.4), weight_decay↑ (0.1→0.15), matrix_lr↑ (0.02→0.025)
- **Route B:** x0_lr↓ (0.5→0.2), x0_beta1↑ (0.8→0.9+)
Both gave ~0.002 improvement, but combining them caused conflicts. Fine-tuning found wd=0.13, matrix_lr=0.027, emb_lr=0.38 helped slightly. Best d12 config: Route A + x0_beta1=0.95.
**At d16**, Route B became competitive with Route A. The routes still conflicted when combined.
**At d20** (target scale), everything changed:
- Fine-tuned values from d12 **actively hurt** performance
- Routes no longer conflicted
- Just `x0_beta1=0.96` alone captured nearly all the gains
### Final x0_beta1 Sweep at d20
| x0_beta1 | val/bpb | Δ vs baseline |
|----------|---------|---------------|
| **0.96** | **0.7971** | **-0.0007** |
| 0.94 | 0.7972 | -0.0006 |
| 0.90 | 0.7972 | -0.0006 |
| 0.97 | 0.7977 | -0.0001 |
| 0.98 | 0.8011 | +0.0033 💀 |
Flat plateau from 0.90-0.96, then sharp cliff at 0.97+.
### Key Learnings
1. **Hyperparameters are scale-dependent.** What works at d12 doesn't transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20.
2. **Improvement magnitude shrinks with scale.** ~0.002 at d12 → ~0.0007 at d20. The baseline is already better-tuned for larger models.
3. **Sharp cliffs exist.** x0_beta1=0.98 is catastrophic while 0.96 is optimal.
4. **Don't over-tune on small proxies.** Validate at target scale before shipping.
### Final Recommendation
For production d20 runs, add one flag:
```
--x0-lambdas-beta1=0.96
```
Skip everything else discovered at smaller scales.
---
## 2026-01-18: More various experiments
- Tried Muon custom kernels for XXT and all the others. The improvement was there for targeted tests (~20%) but washed out completely to noise in an actual training run, especially because the Muon compute is split across all the workers. Abandoned due to complexity bloat.
- Fuse Q,K,V,O nn.Linear layers into a single QKVO Linear layer. ~Zero impact
- Tried the `sa_lambdas` that gate QKV and O. Slightly confused because of the use of rmsnorm, which erases the effect of any scalar multiplier. Helped a tiny bit (~1e-4 of loss), abandoned to control complexity.
---
## 2026-01-17: Various experiments
Modded-nanogpt uses [Value Embeddings](https://arxiv.org/abs/2410.17897) (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:
- VEs at every layer, at alternating layers, U shaped, front and back. Alternating layers worked best, i.e. we end up with *a lot* more VEs than modded-nanogpt, at every other layer. It works better.
- Many parameter-sharing ideas to reduce the new parameter count; nothing here worked. All failed.
- Many other ideas to reduce parameter count; the LLM hates all of them: low-rank decompositions, projections. All failed.
- Gated yes or no and how much. Gate helps.
Long story short, the models *love* Value Embeddings. They are a way to add a huge amount of capacity (parameters) to the model at almost zero FLOPs cost, because these embeddings are simply added to the Values tensor. Any attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fails. The model wants many of them, with all the capacity, and doing so wins across all x axes of steps, flops and wall clock. I re-ran the scaling laws and, because the models are now very parameter bloated, the optimal ratio has halved from 8 to 4! That is way below Chinchilla's 20 at this point.
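To make the mechanism concrete, a rough sketch of the value-embedding idea as described (illustrative only; the real implementations differ in details like per-head handling and how the gate is parameterized):
```python
import torch
import torch.nn as nn

class ValueEmbed(nn.Module):
    """Adds a token-indexed embedding to the attention Values at near-zero extra FLOPs."""
    def __init__(self, vocab_size, n_embd):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, n_embd)
        self.gate = nn.Parameter(torch.zeros(1))  # learned gate ("gate helps")

    def forward(self, v, idx):
        # v: (B, T, n_embd) values before splitting into heads; idx: (B, T) token ids
        return v + self.gate * self.embed(idx)
```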
Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:
- Aspect ratio of 128 is worse than 64; I tried a sweep fixing FLOPs == 1e18 and 64 outperforms. The LLM prefers to be slightly thinner and longer.
- Head dim definitely prefers to be 128 instead of 64, i.e. fewer bigger heads
- Bunch of other random stuff like that.
Keeping all of this work on a private branch for now but hope to push shortly.
---
## 2026-01-17: Modded-nanogpt Ideas Sweep (Continued)
Continued testing ideas from modded-nanogpt.
| Idea | Result | Notes |
|------|--------|-------|
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |
---
## 2026-01-16: Flash Attention 3 Fallback to SDPA
Added automatic fallback from Flash Attention 3 to PyTorch's `scaled_dot_product_attention` (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).
### Implementation
Created `nanochat/flash_attention.py` - a unified interface that:
- Detects FA3 availability at import time (requires sm90+ / Hopper)
- Exports a `flash_attn` object matching FA3's API exactly (`flash_attn.flash_attn_func`, `flash_attn.flash_attn_with_kvcache`)
- Automatically routes to FA3 or SDPA based on hardware
- Handles tensor layout differences: FA3 uses (B, T, H, D), SDPA uses (B, H, T, D)
- Implements sliding window attention via explicit masks for SDPA
- Manages KV cache manually for SDPA (FA3 does it in-place)
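A rough sketch of the layout handling behind the SDPA path (illustrative only, not the actual `nanochat/flash_attention.py`; sliding-window masks and KV cache management are omitted):
```python
import torch
import torch.nn.functional as F

def sdpa_fallback(q, k, v, causal=True):
    """SDPA stand-in for flash_attn_func: inputs arrive in the FA3 layout (B, T, H, D)."""
    # SDPA expects (B, H, T, D): transpose in, compute, transpose back out
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)

# the FA3 availability check gates which path is used (Hopper is sm90+)
HAS_HOPPER = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
```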
### Changes to Existing Files
Changes to existing code were intentionally kept extremely minimal.
**gpt.py**: Only the import line and a comment changed
**engine.py**: Zero changes needed
**base_train.py**: Added status print and warnings:
- Prints whether FA3 or SDPA fallback is being used
- Warns about efficiency loss without FA3
- Warns about sliding window support if `--window-pattern` is not "L"
### Testing
Tests are split into two classes due to dtype/device constraints:
1. **TestFA3VsSDPA**: Comparison tests requiring Hopper GPU + bfloat16. Run both implementations on identical inputs and verify outputs match (max diff typically 0, at most ~0.004 for sliding window).
2. **TestSDPAOnly**: SDPA-only tests that run on any device with appropriate dtype. Verify forward pass, backward pass, and KV cache work correctly.
Added `_override_impl` mechanism for testing - can force 'fa3' or 'sdpa' to directly compare implementations.
### Notes
- SDPA fallback is significantly slower than FA3, especially because it lacks sliding window attention support
- Recommend `--window-pattern L` (full context) when using SDPA fallback
---
## 2026-01-16: Modded-nanogpt Ideas Sweep (Mostly Negative)
Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. All of these did not help:
| Idea | Result | Notes |
|------|--------|-------|
| Half-truncated RoPE | No improvement | Only first half of head dims get RoPE (base 1024, linspace). Second half "stationary". |
| Asymmetric softcap | Slightly worse | `23 * sigmoid((x+5)/7.5)` vs our symmetric `15 * tanh(x/15)`. May only help with FP8. |
| Smear gate | Negligible | Blend each token with predecessor via learned gate. Tiny improvement not worth n_embd² params. |
| Backout | No improvement | Save activations at ~60% through network, subtract scaled version at end. |
| Skip connection | Slightly worse | Save at layer ~25%, add at layer ~50%. Also +2GB memory from storing activations. |
Value Embeddings do show promise. I need a more elaborate exploration of a few related ideas, which I leave for tomorrow.
---
## 2026-01-15: Olmo pretraining mix (Negative result)
I attempted to train on the Olmo 3 pretraining dataset [allenai/dolma3_mix-6T](https://huggingface.co/datasets/allenai/dolma3_mix-6T) instead of FineWeb-edu. I ran into a number of [errors and issues](https://huggingface.co/datasets/allenai/dolma3_mix-6T/discussions/2) trying to both download and process the dataset and then noticed some quality issues (e.g. some documents seem to be extremely short, like "5".). I managed to work around these with some sensible hacks (e.g. reject documents less than 100 characters in length) and tried to process the dataset exactly as FineWeb, re-trained the tokenizer and trained a d16 model. The CORE score decreased from 15.5 to 13.8, i.e. the result is quite a bit worse.
I am still looking to try the [DCLM dataset](https://arxiv.org/abs/2406.11794), which according to the paper should be better than FineWeb-edu. I do have some concerns that the same group both prepared the DCLM dataset *and* introduced the CORE score, so I'm a bit hesitant in case there was some overfitting to a CORE-score-adjacent data distribution.
Classifying as negative result and reverting back to FineWeb-edu for now.
---
## 2026-01-13: Varlen Attention (Negative Result)
Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.
### Background
With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.
### Approach: Compute cu_seqlens from inputs
- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()`
- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region
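For illustration, a minimal sketch of the `cu_seqlens` construction described above, computed outside the compiled region and padded to a fixed size (the helper name and padding convention are assumptions; the varlen kernel call itself is omitted):
```python
import torch

def build_cu_seqlens(inputs, bos_token_id, max_docs):
    flat = inputs.view(-1)
    # every row starts with BOS, so BOS positions mark the document starts
    starts = (flat == bos_token_id).nonzero(as_tuple=True)[0].to(torch.int32)
    end = torch.tensor([flat.numel()], dtype=torch.int32, device=flat.device)
    cu_seqlens = torch.cat([starts, end])
    # pad to a fixed length by repeating the final boundary, keeping shapes static for torch.compile
    pad = end.repeat(max(0, max_docs + 1 - cu_seqlens.numel()))
    return torch.cat([cu_seqlens, pad])
```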
### Final Results (d16)
| Metric | Baseline | Varlen |
|--------|----------|--------|
| val_bpb | 0.85427 | 0.85407 |
| MFU | ~same | ~same |
| tok/sec | ~same | ~same |
Essentially identical. The 0.0002 bpb improvement is almost noise.
### Conclusion
Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.
---
## 2026-01-13: BOS-Aligned Dataloader with Bin Packing
Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.
### Problem Statement
The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.
### Approach 1: Greedy-Crop BOS (Simple)
Each row is built independently:
- Start with a document (which has BOS prepended)
- Pack more documents until row is full
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
- 100% utilization (no padding), but wastes cropped tokens
### Waste Analysis
Measured token waste empirically on real data (T=2048):
- **39.4% of tokens are cropped** (discarded when docs don't fit)
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row
### Bin Packing Algorithms Explored
| Algorithm | Util% | Crop% | Pad% | Notes |
|-----------|-------|-------|------|-------|
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |
### BestFit-Crop Algorithm
A middle ground that maintains 100% utilization while reducing cropping:
1. Buffer N documents
2. For each row, greedily pick the **largest doc that fits entirely**
3. Repeat until nothing fits
4. When nothing fits, crop a doc to fill remaining space exactly
This avoids "unlucky" crops by searching the buffer for better-fitting documents.
**Results (T=2048):**
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
- Still achieves 100% utilization (no padding, every token trains)
- Slightly more rows than baseline (uses more documents per batch)
### Decision: Keep Two Implementations
1. **Keep the original implementation**: it is very simple, efficient, and has 100% token utilization in the batch (no padding with ignore tokens), but it creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present.
2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex, but it still keeps 100% token utilization in the batch (no padding), at the cost of discarding tokens when documents get cropped. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. One more subtle effect is that it does skew the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents will be discarded. However, this doesn't seem to impact actual downstream performance.
### Midtraining
The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.
### NOTE: loss scale
Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute loss value, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back, find the BOS token, and use the full context of that document to make predictions. Therefore the loss appears lower, but this is "fake" to some extent; the expectation is that the vast majority of relative comparisons would come out the same before and after this change.
---
## 2026-01-13: Number Token Split Pattern
Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I had only guessed earlier and left a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.
**Results (d12, vocab=32K):**
| Pattern | val_bpb |
|---------|---------|
| `\p{N}{1,1}` | 0.969 |
| `\p{N}{1,2}` | **0.965** |
| `\p{N}{1,3}` | 0.972 |
**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default.
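To make the effect concrete, a small illustration of how the digit group width chunks a number (uses the third-party `regex` module, which supports `\p{N}`; the real split pattern has more alternatives around this group):
```python
import regex as re

for width in (1, 2, 3):
    pat = re.compile(r"\p{N}{1,%d}" % width)
    print(width, pat.findall("123456"))
# 1 ['1', '2', '3', '4', '5', '6']
# 2 ['12', '34', '56']
# 3 ['123', '456']
```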
---
## 2026-01-13: FP8 Training for lm_head
Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
### Implementation Approaches Tried
**1. Dynamic Scaling (failed)**
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales
- Problem: `.item()` calls cause graph breaks with torch.compile
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile
**2. Static Scaling (partial success)**
- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448`
- `grad_scale` computed dynamically from batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here: they set `grad_scale = 0.75/448`, but grads are in E5M2, so this should probably be `1/57344` (1 being the amax of any individual element of the cross entropy loss, with no normalization by B,T because they use sum reduction, not mean reduction).
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
- This works correctly - no NaNs, proper gradients
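For reference, a minimal sketch of what static-scale E4M3 quantization amounts to (illustrative only; the actual `fp8_static.py` wraps this in a `torch.library.custom_op` and does the matmul with `torch._scaled_mm`):
```python
import torch

E4M3_MAX = 448.0
# static scales from above, chosen at init and never recomputed per step
x_scale, w_scale = 10.0 / E4M3_MAX, 0.1 / E4M3_MAX

def quantize_e4m3(t, scale):
    # divide by the scale so the expected dynamic range maps into [-448, 448], then cast
    return (t / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

# x_f8 = quantize_e4m3(x, x_scale); w_f8 = quantize_e4m3(w, w_scale)
# the fp8 matmul then runs on x_f8/w_f8 via torch._scaled_mm (exact signature varies by PyTorch version)
```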
### Results (d12)
| Metric | BF16 Baseline | FP8 lm_head |
|--------|---------------|-------------|
| GPU Memory | 34 GB | 36 GB |
| tok/sec | baseline | ~1% faster |
### The Memory Mystery
FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see 2GB *increase*. Suspected causes:
- `torch.compile` on inner kernels creating extra buffers/specializations
- `torch._scaled_mm` internal workspace allocations
- Custom op registration machinery overhead
Tried saving original weight `w` (just a reference to parameter) instead of `w_f8` in backward, then re-quantizing on the spot during backward - didn't help. Still saw bump.
### Microbenchmark vs Reality
Raw microbenchmark showed promise:
- BF16 matmul: 16.95 ms
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)
But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase and the added code complexity and the need to tune scale factors for both x and w.
### Code Artifacts
See the branch `fp8_attempt_fail` for:
- `nanochat/fp8_static.py` - Static scaling implementation (working)
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.
### Open Questions
- Why does the custom op approach use more memory than vanilla BF16?
- Why is the bump in tok/sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the answer because our vocab_size is only 32K, so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized.
**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.
---
## 2026-01-12: Multi-Token Prediction (MTP)
Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.

View File

@@ -1,74 +0,0 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/cpu_demo_run.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/
# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
# wipe the report
python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
--depth=4 \
--max_seq_len=1024 \
--device_batch_size=1 \
--total_batch_size=1024 \
--eval_every=50 \
--eval_tokens=4096 \
--core_metric_every=50 \
--core_metric_max_per_task=12 \
--sample_every=50 \
--num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=16
# midtraining
python -m scripts.mid_train \
--max_seq_len=1024 \
--device_batch_size=1 \
--eval_every=50 \
--eval_tokens=4096 \
--total_batch_size=1024 \
--num_iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device_batch_size=1 \
--target_examples_per_step=4 \
--num_iterations=100 \
--eval_steps=4 \
--eval_metrics_max_problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"
# Chat Web
# python -m scripts.chat_web
python -m nanochat.report generate

View File

@@ -15,14 +15,16 @@
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load results\n",
"tag = \"jan26\"\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
"results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
"\n",
"df = pd.read_csv(results_path)\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n",
@@ -31,6 +33,99 @@
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# FILTERING: Remove incomplete or problematic runs\n",
"# =============================================================================\n",
"\n",
"print(f\"Before filtering: {len(df)} runs\")\n",
"\n",
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
"\n",
"# Optional: exclude specific flops budgets that aren't done yet\n",
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
"\n",
"# Optional: exclude specific depths\n",
"# exclude_depths = [18, 20]\n",
"# df = df[~df['depth'].isin(exclude_depths)]\n",
"\n",
"print(f\"After filtering: {len(df)} runs\")\n",
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
"\n",
"# Update flops_budgets list after filtering\n",
"flops_budgets = sorted(df['flops_budget'].unique())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Effective Parameter Count\n",
"\n",
"Different scaling law papers use different conventions for counting parameters:\n",
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
"\n",
"Our CSV now has granular counts:\n",
"- `params_wte` - token embedding (lookup table)\n",
"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
"- `params_value_embeds` - value embeddings (lookup table)\n",
"- `params_lm_head` - unembedding projection (matmul)\n",
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
"\n",
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
"# =============================================================================\n",
"\n",
"def compute_effective_params(row):\n",
" \"\"\"\n",
" Compute the 'effective' parameter count for scaling law analysis.\n",
"\n",
" Modify this function to experiment with different conventions:\n",
" - Chinchilla-style: include everything\n",
" - Kaplan-style: exclude embeddings\n",
" - Matmul-only: just transformer + lm_head (the actual compute)\n",
" - etc.\n",
" \"\"\"\n",
" # Option 1: Chinchilla-style (all params)\n",
" # return row['params_total']\n",
"\n",
" # Option 2: Kaplan-style (exclude embeddings)\n",
" return row['params_transformer'] + row['params_lm_head']\n",
"\n",
" # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
" # return row['params_transformer']\n",
"\n",
"\n",
"# Compute derived columns\n",
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
"\n",
"# Show parameter breakdown for first few rows\n",
"print(\"Parameter breakdown (first row per flops budget):\")\n",
"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
" 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
"df.groupby('flops_budget').first()[param_cols]"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -54,11 +149,11 @@
"optimal_by_bpb = []\n",
"\n",
"for flops, color in zip(flops_budgets, colors):\n",
" subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
" ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
" subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
" ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"\n",
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
" log_params = np.log10(subset['num_scaling_params'])\n",
" log_params = np.log10(subset['effective_params'])\n",
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
" a, b, c = coeffs\n",
"\n",
@@ -83,13 +178,13 @@
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
" best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n",
" ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
" ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2)\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n",
"ax.set_xscale('log')\n",
"ax.set_xlabel('Parameters')\n",
"ax.set_xlabel('Effective Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n",
@@ -138,10 +233,61 @@
"\n",
"# Print the optimal points (from quadratic fits)\n",
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(\"-\" * 65)\n",
"for _, row in opt_df.iterrows():\n",
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Optimal Ratio Summary (from power law fits)\n",
"# =============================================================================\n",
"\n",
"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
"\n",
"if len(opt_df) >= 2:\n",
" log_f = np.log10(opt_df['flops'])\n",
" log_p = np.log10(opt_df['params'])\n",
" log_t = np.log10(opt_df['tokens'])\n",
"\n",
" # Fit power laws\n",
" slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
" slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
"\n",
" # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
" ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
" log_ref = np.log10(ref_flops)\n",
"\n",
" # Predicted optimal N and D at reference compute\n",
" pred_log_n = intercept_n + slope_n * log_ref\n",
" pred_log_d = intercept_d + slope_d * log_ref\n",
" optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
"\n",
" # Also compute from the fitted optimals directly (mean and std)\n",
" mean_ratio = opt_df['ratio'].mean()\n",
" std_ratio = opt_df['ratio'].std()\n",
"\n",
" print(\"=\" * 60)\n",
" print(\"OPTIMAL RATIO SUMMARY\")\n",
" print(\"=\" * 60)\n",
" print(f\"\\nPower law exponents:\")\n",
" print(f\" N ∝ C^{slope_n:.3f}\")\n",
" print(f\" D ∝ C^{slope_d:.3f}\")\n",
" print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
" print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
" print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
" print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
" print(f\" Chinchilla reference: 20\")\n",
" print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
"else:\n",
" print(\"Need at least 2 flops budgets to compute power law fits\")"
]
},
{

View File

@@ -1,95 +0,0 @@
"""
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
Not a general optimizer! But works for our specific use.
"""
import torch
import torch.distributed as dist
from torch import Tensor
class DistAdamW(torch.optim.Optimizer):
"""
Distributed AdamW optimizer.
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
"""
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(param_groups, defaults)
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
reduce_futures: list[torch.Future] = []
gather_futures: list[torch.Future] = []
grad_slices = []
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
for group in self.param_groups:
params: list[Tensor] = group["params"]
for p in params:
grad = p.grad
# Small params: use all_reduce (no scatter/gather needed)
if p.numel() < 1024:
is_small.append(True)
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad)
else:
is_small.append(False)
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
beta1, beta2 = group['betas']
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for p in params:
reduce_futures[idx].wait()
g_slice = grad_slices[idx]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
# For small params, operate on full param; for large, operate on slice
if is_small[idx]:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
t = state['step']
# weight decay
if wd != 0:
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
p_slice.mul_(1 - eff_weight_decay)
# update running averages
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
# bias corrections
bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t
# compute step
denom = (exp_avg_sq / bias2).sqrt().add_(eps)
step_size = lr / bias1
update = exp_avg.div(denom).mul_(step_size)
p_slice.add_(other=update, alpha=-1.0)
# Only large params need all_gather
if not is_small[idx]:
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
idx += 1
if gather_futures:
torch.futures.collect_all(gather_futures).wait()
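Editor's sketch (not part of this diff): the ZeRO-2 sharding arithmetic the removed DistAdamW relies on. The world size and parameter shape below are assumptions chosen to match sizes that appear elsewhere in this compare.

import torch

world_size = 8
p = torch.zeros(32768, 768)              # e.g. a token-embedding-sized matrix (assumed shape)
assert p.shape[0] % world_size == 0      # mirrors the assert in DistAdamW.step()
rank_size = p.shape[0] // world_size     # 4096 rows owned by each rank
# Each rank keeps exp_avg and exp_avg_sq only for its row slice, so Adam state per rank
# is 2 * 4096 * 768 = 6,291,456 elements instead of 2 * 32768 * 768 = 50,331,648.
state_elems_per_rank = 2 * rank_size * p.shape[1]
# Tensors with numel() < 1024 skip sharding entirely and are simply all_reduced.
print(rank_size, state_elems_per_rank)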

View File

@@ -25,6 +25,7 @@ def _patch_missing_config_keys(model_config_kwargs):
# Old models were trained with full context (no sliding window)
if "window_pattern" not in model_config_kwargs:
model_config_kwargs["window_pattern"] = "L"
log0(f"Patching missing window_pattern in model config to 'L'")
def _patch_missing_keys(model_data, model_config):
"""Add default values for new parameters that may be missing in old checkpoints."""
@@ -32,9 +33,11 @@ def _patch_missing_keys(model_data, model_config):
# resid_lambdas defaults to 1.0 (identity scaling)
if "resid_lambdas" not in model_data:
model_data["resid_lambdas"] = torch.ones(n_layer)
log0(f"Patching missing resid_lambdas in model data to 1.0")
# x0_lambdas defaults to 0.0 (disabled)
if "x0_lambdas" not in model_data:
model_data["x0_lambdas"] = torch.zeros(n_layer)
log0(f"Patching missing x0_lambdas in model data to 0.0")
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
@@ -108,7 +111,7 @@ def build_model(checkpoint_dir, step, device, phase):
# Load the Tokenizer
tokenizer = get_tokenizer()
# Sanity check: compatibility between model and tokenizer
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
return model, tokenizer, meta_data

View File

@@ -200,3 +200,77 @@ class DummyWandb:
pass
def finish(self):
pass
# hardcoded BF16 peak flops for various GPUs
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
# and PR: https://github.com/karpathy/nanochat/pull/147
def get_peak_flops(device_name: str) -> float:
name = device_name.lower()
# --- NVIDIA Blackwell ---
if "gb200" in name or "grace blackwell" in name:
return 2.5e15
if "b200" in name:
return 2.25e15
if "b100" in name:
return 1.8e15
# --- NVIDIA Hopper (H100/H200/H800) ---
if "h200" in name:
if "nvl" in name or "pcie" in name:
return 836e12
return 989e12 # H200 SXM
if "h100" in name:
if "nvl" in name:
return 835e12
if "pcie" in name:
return 756e12
return 989e12 # H100 SXM
if "h800" in name:
if "nvl" in name:
return 989e12
return 756e12 # H800 PCIe
# --- NVIDIA Ampere data center ---
if "a100" in name or "a800" in name:
return 312e12
if "a40" in name:
return 149.7e12
if "a30" in name:
return 165e12
# --- NVIDIA Ada data center ---
if "l40s" in name or "l40-s" in name or "l40 s" in name:
return 362e12
if "l4" in name:
return 121e12
# --- AMD CDNA accelerators ---
if "mi355" in name:
return 2.5e15
if "mi325" in name or "mi300x" in name:
return 1.3074e15
if "mi300a" in name:
return 980.6e12
if "mi250x" in name:
return 383e12
if "mi250" in name:
return 362.1e12
# --- Intel ---
if "data center gpu max 1550" in name:
# Ponte Vecchio (PVC) - dynamic based on compute units
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
return 512 * max_comp_units * 1300 * 10**6
# --- Consumer RTX (for hobbyists) ---
if "5090" in name:
return 209.5e12
if "4090" in name:
return 165.2e12
if "3090" in name:
return 71e12
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
return float('inf')
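Editor's sketch (not part of this diff): how a peak-flops lookup like the one above is typically turned into an MFU number. The helper name and the flops/throughput figures are illustrative assumptions, not nanochat's actual reporting code.

def estimated_mfu(flops_per_token: float, tokens_per_sec: float, device_name: str) -> float:
    # Model FLOPs Utilization: achieved flops divided by the hardware's bf16 peak.
    # With the float('inf') fallback above, an unknown GPU reports 0% instead of a wrong guess.
    peak = get_peak_flops(device_name)
    return flops_per_token * tokens_per_sec / peak

# e.g. 4e8 flops/token at 1e6 tokens/s on an H100 SXM (989e12 peak) -> ~0.40, i.e. ~40% MFU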

View File

@@ -1,4 +1,25 @@
from collections import deque
"""
Distributed dataloaders for pretraining.
Two implementations are provided:
1. Original (tokenizing_distributed_data_loader):
- Streams tokens into a flat buffer, reshapes to (B, T)
- Rows may start mid-document (no guaranteed BOS at position 0)
- 100% token utilization, simple and efficient
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
- Every row starts with BOS token
- Documents packed using best-fit algorithm to minimize cropping
- When no document fits remaining space, crops a document to fill exactly
- 100% utilization (no padding), ~35% tokens cropped at T=2048
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
there are fewer "confusing" tokens in the train/val batches as every token can
now attend back to the BOS token and see the full context of its document.
(2) is the new default if you have enough data.
Fall back to (1) if you have very limited data AND long documents.
"""
import torch
import pyarrow.parquet as pq
@@ -6,86 +27,173 @@ import pyarrow.parquet as pq
from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
"""
Infinite iterator over document batches (list of text strings) from parquet files.
Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
where text_batch is a list of document strings, indices track position for resumption,
and epoch counts how many times we've cycled through the dataset (starts at 1).
"""
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
first_pass = True
pq_idx = resume_pq_idx
epoch = resume_epoch
while True: # iterate infinitely (multi-epoch)
pq_idx = resume_pq_idx if first_pass else 0
while pq_idx < len(parquet_paths):
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
base_idx = resume_rg_idx // ddp_world_size
base_idx += 1 # advance by 1 so we don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
if rg_idx >= pf.num_row_groups:
pq_idx += 1
continue
resume_rg_idx = None # only do this once
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist()
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
rg_idx += ddp_world_size
pq_idx += 1
first_pass = False
epoch += 1
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
This implementation became a bit more complex because we wish to support approximate resume training.
Instead of turning this into a Class, we opt to return the state_dict with every batch,
and then the caller can pass in a state_dict to resume training from a desired point.
Note that this resumption is atm only *approximate* for simplicity.
We won't repeat the same documents but we might skip a few.
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
This is the original dataloader that streams tokens into a flat buffer and reshapes.
Rows may start mid-document (no guaranteed BOS at position 0).
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
Supports approximate resume via state_dict.
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
# infinite iterator over document batches (list of text strings)
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
first_pass = True
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
pq_idx = resume_pq_idx if first_pass else 0
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
if rg_idx >= pf.num_row_groups:
pq_idx += 1
continue
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
first_pass = False
batches = document_batches()
# Now emit batches of tokens.
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
needed_tokens = B * T + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
token_buffer = []
pq_idx, rg_idx, epoch = 0, 0, 1
while True:
# Accumulate enough tokens for one iteration before yielding.
# Accumulate enough tokens
while len(token_buffer) < needed_tokens:
doc_batch, (pq_idx, rg_idx) = next(batches)
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
use_cuda_optimizations = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
yield inputs, targets, state_dict
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (the extra +1 provides the target for the last position)
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
# Package tokens into inputs and targets, yield
use_cuda = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
def tokenizing_distributed_data_loader(*args, **kwargs):
# helper function that only emits the inputs/targets and not the state_dict
"""Helper that omits state_dict from yields."""
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
yield inputs, targets
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
tokenizer, B, T, split,
tokenizer_threads=4, tokenizer_batch_size=128,
device="cuda", resume_state_dict=None,
buffer_size=1000
):
"""
BOS-aligned dataloader with Best-Fit Cropping.
Reduces token waste compared to simple greedy cropping by searching a buffer
for documents that fit well, while maintaining 100% utilization (no padding).
Algorithm for each row:
1. From buffered docs, pick the LARGEST doc that fits entirely
2. Repeat until no doc fits
3. When nothing fits, crop a doc to fill remaining space exactly
Key properties:
- Every row starts with BOS
- 100% utilization (no padding, every token is trained on)
- Approximately 35% of all tokens are discarded due to cropping
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
row_capacity = T + 1
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
bos_token = tokenizer.get_bos_token_id()
doc_buffer = []
pq_idx, rg_idx, epoch = 0, 0, 1
def refill_buffer():
nonlocal pq_idx, rg_idx, epoch
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
doc_buffer.append(tokens)
while True:
rows = []
for _ in range(B):
row = []
while len(row) < row_capacity:
# Ensure buffer has documents
while len(doc_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
# Find largest doc that fits entirely
best_idx = -1
best_len = 0
for i, doc in enumerate(doc_buffer):
doc_len = len(doc)
if doc_len <= remaining and doc_len > best_len:
best_idx = i
best_len = doc_len
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
row.extend(doc)
else:
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
doc = doc_buffer.pop(shortest_idx)
row.extend(doc[:remaining])
rows.append(row[:row_capacity])
use_cuda = device == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
"""Helper that omits state_dict from yields."""
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
yield inputs, targets
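Editor's sketch (not part of this diff) of the best-fit packing rule described in the BOS-aligned loader's docstring, applied to toy document lengths only; row_capacity plays the role of T + 1, and the real loader additionally refills its buffer from the tokenizer as it goes.

def pack_one_row(doc_lens, row_capacity):
    """Best-fit on lengths only: pick the largest doc that fits, else crop the shortest."""
    row_used, placed = 0, []
    while row_used < row_capacity:
        remaining = row_capacity - row_used
        fitting = [l for l in doc_lens if l <= remaining]
        if fitting:
            pick = max(fitting)           # largest doc that fits entirely
            doc_lens.remove(pick)
            placed.append(pick)
            row_used += pick
        else:
            shortest = min(doc_lens)      # nothing fits: crop the shortest doc
            doc_lens.remove(shortest)
            placed.append(remaining)      # cropped to fill the row exactly
            row_used = row_capacity
    return placed

# pack_one_row([1500, 700, 400, 300, 90], row_capacity=2049) -> [1500, 400, 90, 59]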

View File

@@ -90,7 +90,7 @@ class KVCache:
- Position tracked per batch element via cache_seqlens tensor
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
self.batch_size = batch_size
self.max_seq_len = seq_len
self.n_layers = num_layers
@@ -172,6 +172,13 @@ class Engine:
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
device = self.model.get_device()
# NOTE: setting the dtype here and in this way is an ugly hack.
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
# In particular, the KVCache should allocate its tensors lazily
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
rng = torch.Generator(device=device)
rng.manual_seed(seed)
@@ -191,6 +198,7 @@ class Engine:
batch_size=1,
seq_len=len(tokens),
device=device,
dtype=dtype,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@@ -203,6 +211,7 @@ class Engine:
batch_size=num_samples,
seq_len=kv_length_hint,
device=device,
dtype=dtype,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
@@ -297,8 +306,8 @@ if __name__ == "__main__":
"""
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer

nanochat/flash_attention.py (new file, 178 lines)
View File

@@ -0,0 +1,178 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
# Training (no KV cache)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
# Inference (with KV cache)
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
"""
import torch
import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper+ GPU)."""
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
if major < 9: # Hopper is sm90
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
return None
_fa3 = _load_flash_attention_3()
HAS_FA3 = _fa3 is not None
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
_override_impl = None
def _use_fa3():
"""Determine whether to use FA3 based on availability and override."""
if _override_impl == 'fa3':
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
return True
if _override_impl == 'sdpa':
return False
return HAS_FA3 # auto
# =============================================================================
# SDPA helpers
# =============================================================================
def _sdpa_attention(q, k, v, window_size, enable_gqa):
"""
SDPA attention with sliding window support.
q, k, v are (B, H, T, D) format.
"""
Tq = q.size(2)
Tk = k.size(2)
window = window_size[0]
# Full context, same length
if (window < 0 or window >= Tq) and Tq == Tk:
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
# Single token generation
if Tq == 1:
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask
device = q.device
if Tq == Tk:
# Causal + sliding window
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
if window > 0 and window < Tq:
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = mask & ((row_idx - col_idx) <= window)
else:
# Chunk inference: attend to prefix + causal within chunk
prefix_len = Tk - Tq
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
mask[:, :prefix_len] = True
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
"""
Flash Attention for training (no KV cache).
Args:
q, k, v: Tensors of shape (B, T, H, D)
causal: Whether to use causal masking
window_size: (left, right) sliding window. -1 means unlimited.
Returns:
Output tensor of shape (B, T, H, D)
"""
if _use_fa3():
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
enable_gqa = q.size(1) != k.size(1)
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
return y.transpose(1, 2) # back to (B, T, H, D)
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
causal=False, window_size=(-1, -1)):
"""
Flash Attention with KV cache for inference.
FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
Args:
q: Queries, shape (B, T_new, H, D)
k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
cache_seqlens: Current position in cache, shape (B,) int32
causal: Whether to use causal masking
window_size: (left, right) sliding window. -1 means unlimited.
Returns:
Output tensor of shape (B, T_new, H, D)
"""
if _use_fa3():
return _fa3.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size
)
# SDPA fallback: manually manage KV cache
B, T_new, H, D = q.shape
pos = cache_seqlens[0].item() # assume uniform position across batch
# Insert new k, v into cache (in-place, matching FA3 behavior)
if k is not None and v is not None:
k_cache[:, pos:pos+T_new, :, :] = k
v_cache[:, pos:pos+T_new, :, :] = v
# Get full cache up to current position + new tokens
end_pos = pos + T_new
k_full = k_cache[:, :end_pos, :, :]
v_full = v_cache[:, :end_pos, :, :]
# Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
q_sdpa = q.transpose(1, 2)
k_sdpa = k_full.transpose(1, 2)
v_sdpa = v_full.transpose(1, 2)
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
# =============================================================================
# Export: flash_attn module interface (drop-in replacement for FA3)
# =============================================================================
from types import SimpleNamespace
flash_attn = SimpleNamespace(
flash_attn_func=flash_attn_func,
flash_attn_with_kvcache=flash_attn_with_kvcache,
)
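A quick usage sketch (editor's addition, not part of this diff) exercising the SDPA fallback path of the new module on CPU; shapes follow the (B, T, H, D) convention documented above, and the nanochat import path is assumed.

import torch
from nanochat.flash_attention import flash_attn, HAS_FA3

B, T, H, D = 2, 16, 4, 32
q = torch.randn(B, T, H, D)
k = torch.randn(B, T, H, D)
v = torch.randn(B, T, H, D)
# On CPU HAS_FA3 is False, so this routes through _sdpa_attention with a causal mask.
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, 0))
assert y.shape == (B, T, H, D)
print("FA3 available:", HAS_FA3)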

View File

@@ -20,21 +20,15 @@ import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
sequence_len: int = 2048
vocab_size: int = 32768
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
@@ -42,7 +36,7 @@ class GPTConfig:
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "L"
window_pattern: str = "SSSL"
def norm(x):
@@ -50,6 +44,10 @@ def norm(x):
return F.rms_norm(x, (x.size(-1),))
def has_ve(layer_idx, n_layer):
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
return layer_idx % 2 == (n_layer - 1) % 2
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
@@ -72,8 +70,10 @@ class CausalSelfAttention(nn.Module):
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
def forward(self, x, cos_sin, window_size, kv_cache):
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
@@ -82,13 +82,18 @@ class CausalSelfAttention(nn.Module):
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
if ve is not None:
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
v = v + gate.unsqueeze(-1) * ve
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
@@ -132,8 +137,8 @@ class Block(nn.Module):
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
def forward(self, x, ve, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
x = x + self.mlp(norm(x))
return x
@@ -166,6 +171,10 @@ class GPT(nn.Module):
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Value embeddings (ResFormer-style): alternating layers, last layer always included
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@@ -176,6 +185,7 @@ class GPT(nn.Module):
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
@torch.no_grad()
def init_weights(self):
"""
Initialize the full model in this one function for maximum clarity.
@@ -207,18 +217,28 @@ class GPT(nn.Module):
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
with torch.no_grad():
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
torch.nn.init.uniform_(ve.weight, -s, s)
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
for block in self.transformer.h:
if block.attn.ve_gate is not None:
torch.nn.init.zeros_(block.attn.ve_gate.weight)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
@@ -283,7 +303,9 @@ class GPT(nn.Module):
"""
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@@ -296,49 +318,72 @@ class GPT(nn.Module):
def num_scaling_params(self):
"""
Return all of the parameters, same as Chinchilla paper.
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
"""
nparams = sum(p.numel() for p in self.parameters())
return nparams
Return detailed parameter counts for scaling law analysis.
Different papers use different conventions:
- Kaplan et al. excluded embedding parameters
- Chinchilla included all parameters
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
Returns a dict with counts for each parameter group, so downstream analysis
can experiment with which combination gives the cleanest scaling laws.
"""
# Count each group separately (mirrors the grouping in setup_optimizers)
wte = sum(p.numel() for p in self.transformer.wte.parameters())
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'scalars': scalars,
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
dict(params=x0_params, lr=scalar_lr),
# Build param_groups with all required fields explicit
param_groups = [
# AdamW groups (embeddings, lm_head, scalars)
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
]
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
MuonFactory = DistMuon if ddp else Muon
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
# Combine the two optimizers into one list
optimizers = [adamw_optimizer, muon_optimizer]
for opt in optimizers:
for group in opt.param_groups:
group["initial_lr"] = group["lr"]
return optimizers
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
))
Factory = DistMuonAdamW if ddp else MuonAdamW
optimizer = Factory(param_groups)
for group in optimizer.param_groups:
group["initial_lr"] = group["lr"]
return optimizer
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
@@ -352,12 +397,13 @@ class GPT(nn.Module):
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
x = self.transformer.wte(idx) # embed current token
x = norm(x)
x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x)
# Forward the lm_head (compute logits)
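Editor's note (not part of this diff): a tiny standalone check of the value-embedding gate added in this hunk, using the zero init from init_weights() and the ve_gate_channels=32 / n_embd=768 defaults visible above, to show that the mix starts out neutral.

import torch

n_kv_head, gate_channels = 6, 32
ve_gate = torch.nn.Linear(gate_channels, n_kv_head, bias=False)
torch.nn.init.zeros_(ve_gate.weight)                  # zero init, as in init_weights()
x = torch.randn(2, 8, 768)                            # (B, T, n_embd)
gate = 2 * torch.sigmoid(ve_gate(x[..., :gate_channels]))
print(gate.min().item(), gate.max().item())           # exactly 1.0 at init: v = v + 1.0 * ve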

View File

@@ -1,295 +0,0 @@
"""
Muon optimizer adapted (simplified) from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
"""
import torch
from torch import Tensor
import torch.distributed as dist
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile
def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
"""
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Alternative to Newton-Schulz iteration with potentially better convergence properties.
"""
assert G.ndim >= 2
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1 (with 2% safety factor)
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
# Perform the iterations (cap at available coefficients)
for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
@torch.compile
def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
"""
NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
https://arxiv.org/pdf/2510.05491
Normalizes updates based on a running estimate of per-row (or per-column) variance.
The reduction dimension is determined by the shape of second_momentum_buffer.
"""
# Determine reduction dimension from buffer shape
red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
# Compute per-row/col mean of squared values
v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = v.size(red_dim)
# Compute current norm
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
# Update second momentum buffer (EMA of variance)
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
# Compute scaling factor from second momentum
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
# Final scale preserves overall norm while adjusting per-row/col
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
return v.mul(final_scale.to(v.dtype))
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- This optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
params: list[Tensor] = [*params]
param_groups = []
for size in {p.numel() for p in params}:
group = dict(params=[p for p in params if p.numel() == size])
param_groups.append(group)
super().__init__(param_groups, defaults)
@torch.no_grad()
def step(self):
for group in self.param_groups:
params: list[Tensor] = group["params"]
for p in params:
g = p.grad
assert g is not None
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
# Variance reduction (NorMuon-style)
if group["beta2"] is not None:
if "second_momentum_buffer" not in state:
# Buffer shape determines reduction dim: reduce along larger dimension
if p.size(-2) >= p.size(-1):
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
else:
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
# Parameter update with cautious weight decay
effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
wd = group["weight_decay"]
if wd != 0:
mask = (g * p) >= 0
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
else:
p.sub_(effective_lr * g)
class DistMuon(torch.optim.Optimizer):
"""
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
- reduce_scatter(AVG) for gradient averaging
- all_gather to replicate updated weights
Notes:
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
params like embeddings or scalars.
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
consolidate states beforehand.
Args:
params: iterable of Tensors
lr: learning rate
momentum: momentum coefficient in [0,1)
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of Newton-Schulz iterations for the orthogonalization
beta2: decay rate for second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
device, dtype = group_params[0].device, group_params[0].dtype
assert all(p.device == device for p in group_params)
assert all(p.dtype == dtype for p in group_params)
if rank == 0:
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
super().__init__(param_groups, defaults)
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
# Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = []
for group in self.param_groups:
params = group["params"]
zero_buffer = group["zero_buffer"]
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank
# each rank stacks up its chunk of world_size params into a list
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
# pad rs_input with the zero buffer to complete the group
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
# the output buffer gets strided across the group based on the rank
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
# reduce scatter the gradients within this group of world_size params
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
all_reduce_futures.append(work)
# Now each rank computes the update and gathers
future_idx = 0
all_gather_futures = []
for group in self.param_groups:
params = group["params"]
zero_buffer = group["zero_buffer"]
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank # calculate the index of the param that this rank owns
# Wait for the reduce scatter to complete
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
future_idx += 1
# Owner computes the Muon update, result is in its param
if owner_idx < len(params):
p = params[owner_idx]
g = p.grad # now averaged across ranks
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
# Variance reduction (NorMuon-style)
if group["beta2"] is not None:
if "second_momentum_buffer" not in state:
# Buffer shape determines reduction dim: reduce along larger dimension
if p.size(-2) >= p.size(-1):
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
else:
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
# Parameter update with cautious weight decay
effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
wd = group["weight_decay"]
if wd != 0:
mask = (g * p) >= 0
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
else:
p.sub_(effective_lr * g)
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)
# Wait for all work to finish
torch.futures.collect_all(all_gather_futures).wait()

nanochat/optim.py (new file, 528 lines)
View File

@@ -0,0 +1,528 @@
"""
A nice and efficient mixed AdamW/Muon Combined Optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
Adapted from: https://github.com/KellerJordan/modded-nanogpt
Further contributions from @karpathy and @chrisjmccormick.
"""
import torch
import torch.distributed as dist
from torch import Tensor
# -----------------------------------------------------------------------------
"""
Good old AdamW optimizer, fused kernel.
https://arxiv.org/abs/1711.05101
"""
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
p: Tensor, # (32768, 768) - parameter tensor
grad: Tensor, # (32768, 768) - gradient, same shape as p
exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
step_t: Tensor, # () - 0-D CPU tensor, step count
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
beta1_t: Tensor, # () - 0-D CPU tensor, beta1
beta2_t: Tensor, # () - 0-D CPU tensor, beta2
eps_t: Tensor, # () - 0-D CPU tensor, epsilon
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
) -> None:
"""
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
All in one compiled graph to eliminate Python overhead between ops.
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
"""
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
exp_avg.lerp_(grad, 1 - beta1_t)
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
# Bias corrections
bias1 = 1 - beta1_t ** step_t
bias2 = 1 - beta2_t ** step_t
# Compute update and apply
denom = (exp_avg_sq / bias2).sqrt() + eps_t
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)
# -----------------------------------------------------------------------------
"""
Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Some of the changes in the nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
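A minimal single-matrix sketch of what the iteration computes, using the coefficients above (shapes are illustrative, and it runs in float32 for simplicity where the fused kernel below applies the same math in bfloat16 to a stacked batch of gradients):
import torch

G = torch.randn(256, 1024)              # a random "gradient" matrix (wide: rows < cols)
X = G / (G.norm() * 1.02 + 1e-6)        # same normalization as in muon_step_fused
for a, b, c in polar_express_coeffs:    # the coefficients defined above
    A = X @ X.mT                        # wide-matrix branch
    B = b * A + c * (A @ A)
    X = a * X + B @ X
print(torch.linalg.svdvals(X)[:5])      # singular values should now be approximately 1, i.e. X ~ U V^T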
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
red_dim: int, # -1 or -2 - reduction dimension for variance
) -> None:
"""
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
All in one compiled graph to eliminate Python overhead between ops.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
"""
# Nesterov momentum
momentum = momentum_t.to(stacked_grads.dtype)
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
if g.size(-2) > g.size(-1): # Tall matrix
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X.mT @ X
B = b * A + c * (A @ A)
X = a * X + X @ B
else: # Wide matrix (original math)
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
g = X
# Variance reduction
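# (Roughly Adafactor-style: keep a factored EMA of the per-row/col mean square of the update,
# scale each row/col by the rsqrt of its EMA, then rescale so the overall Frobenius norm of the
# update is unchanged; only the relative row/col scaling adapts.)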
beta2 = beta2_t.to(g.dtype)
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = g.size(red_dim)
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
g = g * final_scale.to(g.dtype)
# Cautious weight decay + parameter update
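# ("Cautious" here means the decay is applied only where the orthogonalized update is already
# pushing the weight toward zero, so the decay term never opposes the update direction.)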
lr = lr_t.to(g.dtype)
wd = wd_t.to(g.dtype)
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
class MuonAdamW(torch.optim.Optimizer):
"""
Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
AdamW - Fused AdamW optimizer step.
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
# AdamW tensors
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Muon tensors
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _step_adamw(self, group: dict) -> None:
"""
AdamW update for each param in the group individually.
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
"""
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
# Fill 0-D tensors with current values
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
def _step_muon(self, group: dict) -> None:
"""
Muon update for all params in the group (stacked for efficiency).
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
"""
params: list[Tensor] = group['params']
if not params:
return
# Get or create group-level buffers (stored in first param's state for convenience)
p = params[0]
state = self.state[p]
num_params = len(params)
shape, device, dtype = p.shape, p.device, p.dtype
# Momentum for every individual parameter
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Stack grads and params (NOTE: this assumes all params have the same shape)
stacked_grads = torch.stack([p.grad for p in params])
stacked_params = torch.stack(params)
# Fill all the 0-D tensors with current values
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
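# Scale the lr by sqrt(max(1, rows/cols)) so tall matrices get a proportionally larger step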
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._muon_wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._muon_momentum_t,
self._muon_lr_t,
self._muon_wd_t,
self._muon_beta2_t,
group["ns_steps"],
red_dim,
)
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
@torch.no_grad()
def step(self):
for group in self.param_groups:
if group['kind'] == 'adamw':
self._step_adamw(group)
elif group['kind'] == 'muon':
self._step_muon(group)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# -----------------------------------------------------------------------------
# Distributed version of the MuonAdamW optimizer.
# Used for training on multiple GPUs.
class DistMuonAdamW(torch.optim.Optimizer):
"""
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
See MuonAdamW for the algorithmic details of each optimizer. This class adds
distributed communication to enable multi-GPU training without PyTorch DDP.
Design Goals:
- Overlap communication with computation (async ops)
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
- Batch small tensors into single comm ops where possible
Communication Pattern (3-phase async):
We use a 3-phase structure to maximize overlap between communication and compute:
Phase 1: Launch all async reduce ops
- Kick off all reduce_scatter/all_reduce operations
- Don't wait - let them run in background while we continue
Phase 2: Wait for reduces, compute updates, launch gathers
- For each group: wait for its reduce, compute the update, launch gather
- By processing groups in order, earlier gathers run while later computes happen
Phase 3: Wait for gathers, copy back
- Wait for all gathers to complete
- Copy updated params back to original tensors (Muon only)
AdamW Communication (ZeRO-2 style):
- Small params (<1024 elements): all_reduce gradients, update full param on each rank.
Optimizer state is replicated but these params are tiny (scalars, biases).
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
exp_avg_sq) is sharded - each rank only stores state for its slice.
Requires param.shape[0] divisible by world_size.
Muon Communication (stacked + chunked):
- All params in a Muon group must have the same shape (caller's responsibility).
- Stack all K params into a single (K, *shape) tensor for efficient comm.
- Divide K params across N ranks: each rank "owns" ceil(K/N) params.
- reduce_scatter the stacked grads so each rank gets its chunk.
- Each rank computes Muon update only for params it owns.
- all_gather the updated params back to all ranks.
- Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
- Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
then ignore the padding when copying back.
Buffer Reuse:
- For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
same buffer as the output for all_gather (stacked_params). This saves memory
since we don't need both buffers simultaneously.
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
param_infos = {}
for p in group['params']:
grad = p.grad
if p.numel() < 1024:
# Small params: all_reduce (no scatter/gather needed)
future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
else:
# Large params: reduce_scatter
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
return dict(param_infos=param_infos)
def _reduce_muon(self, group: dict, world_size: int) -> dict:
"""Launch async reduce op for Muon group. Returns info dict."""
params = group['params']
chunk_size = (len(params) + world_size - 1) // world_size
padded_num_params = chunk_size * world_size
p = params[0]
shape, device, dtype = p.shape, p.device, p.dtype
# Stack grads and zero-pad to padded_num_params
grad_stack = torch.stack([p.grad for p in params])
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
stacked_grads[:len(params)].copy_(grad_stack)
if len(params) < padded_num_params:
stacked_grads[len(params):].zero_()
# Reduce_scatter to get this rank's chunk
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
param_infos = info['param_infos']
for p in group['params']:
pinfo = param_infos[p]
pinfo['future'].wait()
grad_slice = pinfo['grad_slice']
state = self.state[p]
# For small params, operate on full param; for large, operate on slice
if pinfo['is_small']:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
state['step'] += 1
# Fill 0-D tensors and run fused kernel
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
adamw_step_fused(
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
# Large params need all_gather
if not pinfo['is_small']:
future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
gather_list.append(dict(future=future, params=None))
def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
"""Wait for reduce, compute Muon updates, launch gather."""
info['future'].wait()
params = group['params']
chunk_size = info['chunk_size']
grad_chunk = info['grad_chunk']
p = params[0]
shape, device, dtype = p.shape, p.device, p.dtype
# How many params does this rank own?
start_idx = rank * chunk_size
num_owned = min(chunk_size, max(0, len(params) - start_idx))
# Get or create group-level state
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
if "second_momentum_buffer" not in state:
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Build output buffer for all_gather
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
if num_owned > 0:
owned_params = [params[start_idx + i] for i in range(num_owned)]
stacked_owned = torch.stack(owned_params)
# Fill 0-D tensors and run fused kernel
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"])
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._muon_wd_t.fill_(group["weight_decay"])
muon_step_fused(
grad_chunk[:num_owned], stacked_owned,
state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
group["ns_steps"], red_dim,
)
updated_params[:num_owned].copy_(stacked_owned)
if num_owned < chunk_size:
updated_params[num_owned:].zero_()
# Reuse stacked_grads buffer for all_gather output
stacked_params = info["stacked_grads"]
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
def _finish_gathers(self, gather_list: list) -> None:
"""Wait for all gathers and copy Muon params back."""
for info in gather_list:
info["future"].wait()
if info["params"] is not None:
# Muon: copy from stacked buffer back to individual params
torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Phase 1: launch all async reduce ops
reduce_infos: list[dict] = []
for group in self.param_groups:
if group['kind'] == 'adamw':
reduce_infos.append(self._reduce_adamw(group, world_size))
elif group['kind'] == 'muon':
reduce_infos.append(self._reduce_muon(group, world_size))
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# Phase 2: wait for reduces, compute updates, launch gathers
gather_list: list[dict] = []
for group, info in zip(self.param_groups, reduce_infos):
if group['kind'] == 'adamw':
self._compute_adamw(group, info, gather_list, rank, world_size)
elif group['kind'] == 'muon':
self._compute_muon(group, info, gather_list, rank)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# Phase 3: wait for gathers, copy back
self._finish_gathers(gather_list)
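For orientation, a minimal single-GPU usage sketch with hypothetical toy parameters (in the repo the param groups are built by the model's setup_optimizer, not by hand; this is just a smoke test):
import torch
import torch.nn as nn
from nanochat.optim import MuonAdamW

# Two same-shaped matrices for the Muon group, one embedding for the AdamW group (toy sizes)
matrices = [nn.Parameter(torch.randn(64, 64) * 0.02) for _ in range(2)]
embedding = nn.Parameter(torch.randn(1000, 64) * 0.02)
optimizer = MuonAdamW([
    {"params": [embedding], "kind": "adamw", "lr": 0.3, "betas": (0.8, 0.95), "eps": 1e-10, "weight_decay": 0.0},
    {"params": matrices, "kind": "muon", "lr": 0.02, "momentum": 0.95, "ns_steps": 5, "beta2": 0.95, "weight_decay": 0.0},
])
for p in matrices + [embedding]:
    p.grad = torch.randn_like(p)  # stand-in for a real backward()
optimizer.step()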

View File

@@ -26,7 +26,7 @@ SPECIAL_TOKENS = [
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I haven't validated that this is actually a good idea, TODO.
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# -----------------------------------------------------------------------------

View File

@@ -23,6 +23,7 @@ dependencies = [
"transformers>=4.57.3",
"uvicorn>=0.36.0",
"wandb>=0.21.3",
"zstandard>=0.25.0",
]
[dependency-groups]

View File

@@ -17,9 +17,10 @@ if [ -z "$SKIP_SETUP" ]; then
uv sync --extra gpu
source .venv/bin/activate
# Tokenizer
python -m nanochat.dataset -n 240
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768
# Tokenizer, download 1000 shards for pretraining
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
python -m nanochat.dataset -n 1000
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
else
source .venv/bin/activate
fi
@@ -57,16 +58,15 @@ for d in "${DEPTHS[@]}"; do
START_TIME=$(date +%s)
# Train the model with natural horizon (target_param_data_ratio default)
# No --target_flops, let it use the default ratio from base_train
# No --target-flops, let it use the default ratio from base_train
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target_param_data_ratio=8 \
--run="${WANDB_RUN}_d${d}" \
--model_tag="${TAG}" \
--core_metric_every=999999 \
--core_metric_max_per_task=-1 \
--sample_every=-1 \
--save_every=-1 \
--model-tag="${TAG}" \
--core-metric-every=999999 \
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)

View File

@@ -20,18 +20,18 @@ curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-publ
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
python -m nanochat.dataset -n 16
# start downloading the rest of the shards for a total of 800 (see below why 800)
python -m nanochat.dataset -n 800 &
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
python -m nanochat.dataset -n 1200 &
# todo: download the rest of it
python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
python -m scripts.tok_eval
# Documenting my process for determining the hyperparameters for this run1000.sh script:
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
# 1) I guessed the model size for this to be about depth=32
# 2) Determine the device_batch_size that fits:
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
# So the training script was running ok and showed:
# Vocab size: 65,536
@@ -62,7 +62,9 @@ python -m scripts.tok_eval
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
# So ~38B tokens * ~4.8 chars/token = ~185B chars.
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
# For safety, I bumped that up to 800 shards.
# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
# => why up above I used -n 1200 when pre-downloading dataset shards.
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard.
@@ -71,13 +73,13 @@ python -m scripts.tok_eval
# Number of processes/GPUs to use
NPROC_PER_NODE=8
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# sft

70
runs/runcpu.sh Executable file
View File

@@ -0,0 +1,70 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# This script was last updated/tuned on Jan 17, 2026.
# Run as:
# bash runs/runcpu.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as an educational/fun demo, not something you should expect to work well.
# (This is why I hide this script away in runs/)
# You may also want to run this script manually, one command at a time, copy-pasting the commands into your terminal.
# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
python -m nanochat.dataset -n 8
python -m scripts.tok_train --max-chars=2000000000
python -m scripts.tok_eval
# train a small 6-layer model
# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
python -m scripts.base_train \
--depth=6 \
--head-dim=64 \
--window-pattern=L \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=100 \
--eval-tokens=524288 \
--core-metric-every=-1 \
--sample-every=100 \
--num-iterations=5000 \
--run=$WANDB_RUN
python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384
python -m scripts.base_eval --max-per-task=16
# midtraining (~10 minutes on my MacBook Pro M3 Max)
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
python -m scripts.mid_train \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=200 \
--eval-tokens=524288 \
--num-iterations=1500 \
--run=$WANDB_RUN
# (it's ~ok to skip SFT)
# Chat with the model over CLI
# The model should be able to answer that the capital of France is Paris.
# It might even know that the color of the sky is blue.
# Sometimes the model likes it if you first say Hi before you ask it questions.
# python -m scripts.chat_cli -i mid -p "What is the capital of France?"
# Chat with the model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web -i mid

View File

@@ -1,26 +1,30 @@
#!/bin/bash
LABEL="jan26"
FLOPS_BUDGETS=(
1e18
3e18
6e18
2.15e18
4.64e18
1e19
)
DEPTHS=(8 10 12 14 16 18 20)
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
WANDB_RUN="${WANDB_RUN:-scaling}"
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
source .venv/bin/activate
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}"
mkdir -p "$RESULTS_DIR"
RESULTS_FILE="$RESULTS_DIR/results.csv"
# Write CSV header only if file doesn't exist
if [ ! -f "$RESULTS_FILE" ]; then
echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
fi
log() {
@@ -64,15 +68,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
# CORE eval happens once at the end (999999 ensures only final step)
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target_flops=$flops \
--target_param_data_ratio=-1 \
--target-flops=$flops \
--target-param-data-ratio=-1 \
--run="${WANDB_RUN}_${TAG}" \
--model_tag="${TAG}" \
--eval_tokens=$EVAL_TOKENS \
--core_metric_every=999999 \
--core_metric_max_per_task=-1 \
--sample_every=-1 \
--save_every=-1 \
--model-tag="${TAG}" \
--eval-tokens=$EVAL_TOKENS \
--core-metric-every=999999 \
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)
@@ -80,13 +84,19 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
# Extract training stats from the log
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
# Extract detailed parameter counts (for scaling law analysis with different conventions)
PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
# Calculate tokens trained (iterations * batch_size, default 524288)
TOKENS_TRAINED=$((NUM_ITERS * 524288))
# Param:data ratio (using scaling params per Kaplan et al.)
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
# Model dim
MODEL_DIM=$((d * 64))
# Val BPB from final eval
@@ -99,10 +109,10 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
CORE_SCORE="0.0"
fi
log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
# Append to CSV
echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
done
done

View File

@@ -55,11 +55,11 @@ python -m nanochat.report reset
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
# See comment below for why 370 is the right number here
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
# train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
python -m scripts.tok_train
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
@@ -70,7 +70,9 @@ python -m scripts.tok_eval
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
# so 240 / (1 - 0.35) = 370 shards are needed.
# At ~100MB/shard, this downloads ~37GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
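The shard arithmetic above, written out as a quick check (numbers taken from the comments, not computed by the script):
import math
params = 561e6                      # model parameters for the d20 model
tokens = 20 * params                # Chinchilla: ~20 tokens per parameter
chars = tokens * 4.8                # ~4.8 chars/token for this tokenizer
shards = math.ceil(chars / 250e6)   # ~250M chars per shard -> 216 shards, rounded up to 240 for safety
print(shards, math.ceil(240 / (1 - 0.35)))  # 216, 370 once ~35% cropping waste is included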
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
@@ -79,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
NPROC_PER_NODE=8
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks

View File

@@ -7,14 +7,14 @@ Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
To evaluate a HuggingFace model:
python -m scripts.base_loss --hf_path openai-community/gpt2
python -m scripts.base_loss --hf-path openai-community/gpt2
"""
import argparse
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
@@ -61,12 +61,12 @@ def get_hf_token_bytes(tokenizer, device="cpu"):
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model_step", type=int, default=None, help="model step to load")
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
args = parser.parse_args()
# Load the base model and the tokenizer
@@ -97,14 +97,14 @@ assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible
steps = args.split_tokens // tokens_per_step
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")
bpb_results[split_name] = bpb
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
# Master process also samples from the model (only for nanochat models)
# Master process also samples from the model for some basic knowledge-eliciting prompts (only for nanochat models)
samples = []
if ddp_rank == 0 and args.hf_path is None:
prompts = [
@@ -122,9 +122,23 @@ if ddp_rank == 0 and args.hf_path is None:
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0("-" * 80)
print0(sample_str)
samples.append(sample_str)
# Draw some unconditioned samples from the model (only for nanochat models)
unconditioned_samples = []
if ddp_rank == 0 and args.hf_path is None:
engine = Engine(model, tokenizer)
tokens = tokenizer("", prepend="<|bos|>")
with autocast_ctx:
gen_samples, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)  # separate name so the conditioned `samples` above are not clobbered
for sample in gen_samples:
sample_str = tokenizer.decode(sample)
print0("-" * 80)
print0(sample_str)
unconditioned_samples.append(sample_str)
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
@@ -134,6 +148,7 @@ get_report().log(section="Base model loss", data=[
"val bpb": bpb_results["val"],
},
{f"sample {i}": sample for i, sample in enumerate(samples)},
{f"unconditioned sample {i}": sample for i, sample in enumerate(unconditioned_samples)},
])
# Cleanup

View File

@@ -1,14 +1,14 @@
"""
Train model. From root directory of the project, run as:
python -m scripts.base_train.py
python -m scripts.base_train
or distributed as:
torchrun --nproc_per_node=8 -m scripts.base_train.py
torchrun --nproc_per_node=8 -m scripts.base_train
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
"""
import os
@@ -21,12 +21,13 @@ import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_model
print_banner()
@@ -36,40 +37,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
@@ -81,11 +82,29 @@ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
gpu_device_name = torch.cuda.get_device_name(0)
gpu_peak_flops = get_peak_flops(gpu_device_name)
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
if HAS_FA3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without FA3")
if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
print0("!" * 80)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
@@ -93,21 +112,19 @@ vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
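# Example: --depth=11 with the defaults gives base_dim = 11*64 = 704, which gets nudged up to
# model_dim = 768 so that num_heads = 768/128 = 6 with head_dim exactly 128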
num_layers = args.depth
model_dim = args.depth * args.aspect_ratio
def find_num_heads(model_dim, target_head_dim):
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
ideal = max(1, round(model_dim / target_head_dim))
for offset in range(model_dim):
for candidate in [ideal + offset, ideal - offset]:
if candidate > 0 and model_dim % candidate == 0:
return candidate
return 1
num_heads = find_num_heads(model_dim, args.head_dim)
base_dim = args.depth * args.aspect_ratio
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
num_heads = model_dim // args.head_dim
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
head_dim = model_dim // num_heads
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
print0(f"num_heads: {num_heads}")
print0(f"head_dim: {head_dim}")
print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
@@ -161,9 +178,14 @@ if resuming:
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
num_scaling_params = orig_model.num_scaling_params()
print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
# Detailed parameter counts
param_counts = orig_model.num_scaling_params()
print0(f"Parameter counts:")
for key, value in param_counts.items():
print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
@@ -178,20 +200,20 @@ elif args.target_flops > 0:
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
target_tokens = args.target_param_data_ratio * num_scaling_params
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
num_iterations = target_tokens // args.total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
adam_betas = (args.adam_beta1, args.adam_beta2)
optimizers = model.setup_optimizers(
optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
matrix_lr=args.matrix_lr * batch_lr_scale,
@@ -199,19 +221,16 @@ optimizers = model.setup_optimizers(
adam_betas=adam_betas,
scalar_lr=args.scalar_lr * batch_lr_scale,
)
adamw_optimizer, muon_optimizer = optimizers
if resuming:
for opt, dat in zip(optimizers, optimizer_data):
opt.load_state_dict(dat)
del optimizer_data # free up the memory
optimizer.load_state_dict(optimizer_data)
del optimizer_data
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@@ -323,7 +342,7 @@ while True:
checkpoint_dir,
step,
orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states
optimizer.state_dict(), # optimizer state
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
@@ -357,33 +376,31 @@ while True:
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizers
# step the optimizer
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
for opt in optimizers:
opt.step()
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
optimizer.step()
model.zero_grad(set_to_none=True)
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
# logging
# logging (CPU action only)
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
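# (MFU: achieved FLOP/s as a percentage of the detected GPUs' combined peak bf16 throughput)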
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
# Calculate ETA based on average time per step (excluding first 10 steps)
@@ -395,7 +412,8 @@ while True:
eta_str = f" | eta: {eta_seconds/60:.1f}m"
else:
eta_str = ""
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
epoch = dataloader_state_dict["epoch"]
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,
@@ -406,6 +424,7 @@ while True:
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": epoch,
}
wandb_run.log(log_data)
@@ -427,7 +446,7 @@ get_report().log(section="Base model training", data=[
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
"DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,

View File

@@ -4,8 +4,8 @@ All the generic code lives here, and all the evaluation-specific
code lives in nanochat directory and is imported from here.
Example runs:
python -m scripts.chat_eval -a ARC-Easy
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
python -m scripts.chat_eval -i mid -a ARC-Easy
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i mid -a ARC-Easy
"""
import argparse

View File

@@ -6,7 +6,7 @@ simpler and more similar to just REINFORCE:
1) Delete trust region, so there is no KL regularization to a reference model
2) We are on policy, so there's no need for PPO ratio+clip.
3) We use GAPO style normalization that is token-level, not sequence-level.
3) We use DAPO style normalization that is token-level, not sequence-level.
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
1 GPU:
@@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
# Generation
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@@ -201,7 +201,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
# Training loop
# Init the optimizer
optimizers = model.setup_optimizers(
optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
@@ -209,10 +209,9 @@ optimizers = model.setup_optimizers(
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
# Learning rate scheduler: simple rampdown to zero over num_steps
def get_lr_multiplier(it):
@@ -305,11 +304,9 @@ for step in range(num_steps):
# Update the model parameters
lrm = get_lr_multiplier(step)
for opt in optimizers: # first set the learning rate
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
for opt in optimizers: # then step the optimizers
opt.step()
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
optimizer.step()
model.zero_grad(set_to_none=True)
wandb_run.log({
"step": step,

View File

@@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
# Batch sizes
parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@@ -150,17 +150,16 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_bat
# -----------------------------------------------------------------------------
# Initialize the Optimizer
optimizers = model.setup_optimizers(
optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
# -----------------------------------------------------------------------------
# Training loop
@@ -230,13 +229,11 @@ for step in range(num_iterations):
# learning rate scheduler
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
# step the optimizers
for opt in optimizers:
opt.step()
# step the optimizer
optimizer.step()
model.zero_grad(set_to_none=True)
# logging
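The pattern these hunks repeat, now against a single combined optimizer instead of the old (AdamW, Muon) pair: scale each param group's base LR by init_lr_frac once, remember it as initial_lr, then each step recompute lr from initial_lr times the schedule multiplier. A condensed sketch with a stand-in optimizer (hypothetical names, not the scripts' exact code):

```python
# Condensed sketch of the single-optimizer LR handling (stand-in optimizer, hypothetical values).
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.004)  # stand-in for model.setup_optimizer(...)
init_lr_frac = 0.02

# one-time setup: scale down the base LR and remember it
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * init_lr_frac
    group["initial_lr"] = group["lr"]

# each step: rescale from initial_lr so multipliers never compound across steps
def apply_lr(lrm):
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm

apply_lr(0.5)  # e.g. halfway through a linear rampdown to zero
```

Recomputing from initial_lr (rather than multiplying lr in place) keeps the schedule idempotent no matter how often it is applied.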

View File

@@ -6,11 +6,10 @@ python -m scripts.mid_train
Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
"""
import argparse
from collections import deque
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
@@ -37,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@@ -80,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
@@ -94,14 +93,12 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
@@ -125,49 +122,100 @@ val_dataset = TaskMixture([
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split):
global last_step, approx_progress
current_epoch = 1 # track epoch for logging
def mid_data_generator_bos_bestfit(split, buffer_size=100):
"""
BOS-aligned dataloader for midtraining with bestfit-crop packing.
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm to minimize cropping.
This matches the BOS-aligned approach used in pretraining.
"""
global last_step, approx_progress, current_epoch
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
needed_tokens = args.device_batch_size * args.max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
while True:
# Accumulate enough tokens for one iteration before yielding
while len(token_buffer) < needed_tokens:
row_capacity = args.max_seq_len + 1 # +1 for target at last position
# Conversation buffer: list of token lists
conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering
epoch = 1
it = 0 # iteration counter
def refill_buffer():
nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation)
token_buffer.extend(ids)
conv_buffer.append(ids)
cursor += ddp_world_size
if cursor >= dataset_size:
cursor -= dataset_size # wrap around for another epoch
if split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
cursor = cursor % dataset_size
epoch += 1
# Note: last_step is now triggered based on consumption, not fetching
while True:
rows = []
for _ in range(args.device_batch_size):
row = []
while len(row) < row_capacity:
# Ensure buffer has conversations
while len(conv_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
# Find largest conversation that fits entirely
best_idx = -1
best_len = 0
for i, conv in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len:
best_idx = i
best_len = conv_len
if best_idx >= 0:
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
row.extend(conv)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - crop first conversation to fill remaining
conv = conv_buffer.pop(0)
row.extend(conv[:remaining])
consumed += ddp_world_size # Track actual consumption
rows.append(row[:row_capacity])
# Stopping condition to respect num_iterations, if given
it += 1
if 0 < args.num_iterations <= it and split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
last_step = True
# Update progress tracking (based on consumed, not cursor, to account for buffering)
if split == "train":
current_epoch = epoch
if args.num_iterations > 0:
approx_progress = it / args.num_iterations # calculate progress from the max number of iterations
approx_progress = it / args.num_iterations
else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
approx_progress = consumed / dataset_size
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
if consumed >= dataset_size:
last_step = True
# Build tensors
use_cuda = device_type == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
yield inputs, targets
train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
train_loader = mid_data_generator_bos_bestfit("train")
build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
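The core of the new generator above is the bestfit-crop packing step: each row starts at a conversation boundary (BOS), the largest buffered conversation that still fits is appended whole, and only when nothing fits is the oldest conversation cropped to fill the remainder. A self-contained sketch of just that packing logic, simplified to omit the buffer refill, DDP sharding, and epoch bookkeeping:

```python
# Self-contained sketch of bestfit-crop row packing (simplified; not the generator's full logic).
def pack_row(conv_buffer, row_capacity):
    """conv_buffer: list of token-id lists (mutated in place). Returns one packed row."""
    row = []
    while len(row) < row_capacity and conv_buffer:
        remaining = row_capacity - len(row)
        # pick the largest conversation that fits entirely
        fitting = [(len(c), i) for i, c in enumerate(conv_buffer) if len(c) <= remaining]
        if fitting:
            _, best_idx = max(fitting)
            row.extend(conv_buffer.pop(best_idx))
        else:
            row.extend(conv_buffer.pop(0)[:remaining])  # crop the oldest to fill the gap
    return row[:row_capacity]

# toy usage: capacity 8 packs the length-5 and length-3 conversations into one row
print(pack_row([[1] * 5, [2] * 3, [3] * 4], 8))
```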
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
@@ -199,7 +247,7 @@ while True:
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
@@ -224,7 +272,7 @@ while True:
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
optimizer.state_dict(),
{
"step": step,
"val_bpb": val_bpb, # loss at last step
@@ -256,16 +304,14 @@ while True:
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizers
# step the optimizer
lrm = get_lr_multiplier(progress)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
for opt in optimizers:
opt.step()
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
optimizer.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
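With the combined MuonAdamW optimizer there is a single param_groups list, so per-group metadata (the 'kind' key) decides which groups receive the Muon momentum warmup while all groups share the LR multiplier. A tiny sketch of that dispatch pattern with hypothetical group values:

```python
# Tiny sketch: tag param groups with a 'kind' and update only the Muon groups (hypothetical values).
param_groups = [
    {"kind": "adamw", "lr": 0.004, "initial_lr": 0.004},
    {"kind": "muon",  "lr": 0.02,  "initial_lr": 0.02, "momentum": 0.85},
]

lrm, muon_momentum = 0.5, 0.95
for group in param_groups:
    group["lr"] = group["initial_lr"] * lrm
    if group["kind"] == "muon":
        group["momentum"] = muon_momentum  # momentum warmup applies only to the Muon groups
```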
@@ -285,7 +331,7 @@ while True:
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,
@@ -296,6 +342,7 @@ while True:
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": current_epoch,
})
# print a few more stats

View File

@@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")

View File

@@ -25,7 +25,7 @@ class CustomJSON(Task):
print("-" * 80)
print(f"Warning: File {filepath} does not exist")
print("HINT (Oct 21 2025)")
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
print("Quick fix: simply run the following command to download the file and you're done:")
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")

View File

@@ -0,0 +1,338 @@
"""
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
Run: python -m pytest tests/test_attention_fallback.py -v -s
Note on test structure:
Tests are split into two classes due to dtype/device constraints:
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
and verify they produce identical results. These require a Hopper GPU (FA3 only
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
"""
import torch
import pytest
import nanochat.flash_attention as fa_module
from nanochat.flash_attention import flash_attn, HAS_FA3
from nanochat.engine import KVCache
def set_impl(impl):
"""Set the implementation override ('fa3', 'sdpa', or None for auto)."""
fa_module._override_impl = impl
def run_both_impls(fn):
"""Run a function with both FA3 and SDPA, return both outputs."""
set_impl('fa3')
out_fa3 = fn()
set_impl('sdpa')
out_sdpa = fn()
set_impl(None) # reset
return out_fa3, out_sdpa
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
"""Assert two tensors are close, with helpful error message."""
max_diff = (t1 - t2).abs().max().item()
mean_diff = (t1 - t2).abs().mean().item()
assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \
f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
return max_diff, mean_diff
# =============================================================================
# FA3 vs SDPA comparison tests (require Hopper GPU)
# =============================================================================
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
class TestFA3VsSDPA:
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
DEVICE = "cuda"
DTYPE = torch.bfloat16
def test_basic_causal(self):
"""Basic causal attention."""
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_full_context(self):
"""Full context (window_size=-1)."""
B, T, H, D = 2, 128, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_sliding_window(self):
"""Sliding window attention."""
B, T, H, D = 2, 128, 4, 32
window = 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_gqa(self):
"""Group Query Attention (fewer KV heads than Q heads)."""
B, T, D = 2, 64, 32
n_heads = 8
n_kv_heads = 2
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_larger_model(self):
"""Larger dimensions closer to real model."""
B, T, H, D = 4, 256, 12, 64
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_prefill(self):
"""Test prefill (inserting multiple tokens into empty cache)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_single_token(self):
"""Test single token generation (cache already has content)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_cache[:, :T_prefill, :, :] = k_init
v_cache[:, :T_prefill, :, :] = v_init
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_backward_gradients_match(self):
"""Verify gradients are similar between FA3 and SDPA."""
B, T, H, D = 2, 32, 4, 16
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
q = q_data.clone().requires_grad_(True)
k = k_data.clone().requires_grad_(True)
v = v_data.clone().requires_grad_(True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
set_impl('fa3')
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
set_impl('sdpa')
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
set_impl(None)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
# =============================================================================
# SDPA-only tests (run on any device)
# =============================================================================
class TestSDPAOnly:
"""Test SDPA fallback works correctly. Runs on any device."""
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def test_basic_forward(self):
"""Test SDPA forward pass produces valid output."""
set_impl('sdpa')
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
assert y.shape == (B, T, H, D)
assert not torch.isnan(y).any(), "Output contains NaN"
set_impl(None)
def test_backward(self):
"""Test gradients flow through SDPA."""
set_impl('sdpa')
B, T, H, D = 2, 32, 4, 16
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
assert q.grad is not None, "No gradient for q"
assert k.grad is not None, "No gradient for k"
assert v.grad is not None, "No gradient for v"
assert not torch.isnan(q.grad).any(), "NaN in q gradient"
set_impl(None)
def test_kvcache(self):
"""Test SDPA with KV cache."""
set_impl('sdpa')
B, T_max, H, D = 2, 64, 4, 32
n_layers = 1
cache = KVCache(
batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
)
k_cache, v_cache = cache.get_layer_cache(0)
# Prefill
T_prefill = 16
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v,
cache_seqlens=cache.cache_seqlens,
causal=True, window_size=(T_max, 0)
)
cache.advance(T_prefill)
assert y.shape == (B, T_prefill, H, D)
assert cache.get_pos() == T_prefill
# Generate single token
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
y_single = flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache.cache_seqlens,
causal=True, window_size=(T_max, 0)
)
cache.advance(1)
assert y_single.shape == (B, 1, H, D)
assert cache.get_pos() == T_prefill + 1
set_impl(None)
# =============================================================================
# Override mechanism tests
# =============================================================================
class TestOverrideMechanism:
"""Test that the override mechanism works correctly."""
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
def test_override_fa3(self):
"""Test that override='fa3' uses FA3."""
set_impl('fa3')
assert fa_module._use_fa3() == True
set_impl(None)
def test_override_sdpa(self):
"""Test that override='sdpa' uses SDPA."""
set_impl('sdpa')
assert fa_module._use_fa3() == False
set_impl(None)
def test_override_auto(self):
"""Test that override=None uses auto-detection."""
set_impl(None)
assert fa_module._use_fa3() == HAS_FA3
if __name__ == "__main__":
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name()}")
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
print(f"HAS_FA3: {HAS_FA3}")
print()
pytest.main([__file__, "-v", "-s"])
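For intuition about what the SDPA fallback has to reproduce: flash-attn's causal window_size=(w, 0) lets token i attend to positions j with j <= i and i - j <= w. A minimal sketch of building such a mask for torch's scaled_dot_product_attention; this is illustrative of the fallback idea the tests above exercise, not the repo's flash_attention module (and it ignores GQA/KV-cache handling):

```python
# Illustrative sketch of a causal sliding-window mask for SDPA (not the repo's implementation).
import torch
import torch.nn.functional as F

def sdpa_sliding_window(q, k, v, window):
    # q, k, v: (B, T, H, D) as passed to flash_attn_func; SDPA expects (B, H, T, D)
    B, T, H, D = q.shape
    i = torch.arange(T, device=q.device)
    # token t may attend to tokens in [t - window, t]
    mask = (i[None, :] <= i[:, None]) & (i[:, None] - i[None, :] <= window)  # (T, T), True = attend
    y = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=mask,
    )
    return y.transpose(1, 2)  # back to (B, T, H, D)
```

With window >= T this reduces to the plain causal mask, which is why window_size=(T, 0) and window_size=(-1, -1) agree in the comparison tests.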

View File

@@ -96,6 +96,7 @@ def test_kv_cache_basic():
head_dim=head_dim,
num_layers=num_layers,
device="cpu",
dtype=torch.float32,
)
# Check initial state
@@ -130,7 +131,7 @@ def test_kv_cache_prefill():
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
@@ -140,7 +141,7 @@ def test_kv_cache_prefill():
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
)
# Prefill
@@ -195,3 +196,72 @@ def test_multi_sample_first_token_diversity():
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
f"unless tokens are being broadcast instead of independently sampled."
)
def test_seed_reproducibility():
"""Same seed must produce identical output."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
for seed in [1, 42, 123, 999]:
r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."
def test_temperature_zero_determinism():
"""Temperature=0 is deterministic regardless of seed."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
def test_max_tokens_respected():
"""Generation stops at max_tokens limit."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
for max_tokens in [1, 4, 16, 64]:
results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
num_generated_tokens = len(results[0]) - len(prompt)
assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
def test_num_samples_count():
"""num_samples=N produces exactly N sequences."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
for num_samples in [1, 4, 16, 64]:
results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
def test_different_seeds_introduce_variation_when_temperature_nonzero():
"""With temperature > 0, different seeds should introduce sampling variation."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
outputs = set()
for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
results, _ = engine.generate_batch(
prompt,
temperature=1.0,
max_tokens=5,
seed=seed,
)
outputs.add(tuple(results[0]))
# Sanity check: sampling actually introduces variation
assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
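The invariants above (same seed reproduces output, temperature=0 ignores the seed, different seeds vary) follow from seeded sampling with a greedy fallback. A hedged sketch of such a sampler, using only standard torch calls; the function and its signature are illustrative, not the Engine's actual API:

```python
# Hypothetical sampler sketch (not the repo's Engine) showing why the invariants above hold.
import torch

def sample_next_token(logits, temperature=1.0, top_k=50, seed=None):
    if temperature == 0.0:
        return int(torch.argmax(logits, dim=-1))            # deterministic, seed-independent
    logits = logits / temperature
    if top_k > 0:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[..., -1]] = -float("inf")          # keep only the top-k logits
    probs = torch.softmax(logits, dim=-1)
    gen = torch.Generator().manual_seed(seed) if seed is not None else None
    return int(torch.multinomial(probs, num_samples=1, generator=gen))
```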

uv.lock generated
View File

@@ -1513,6 +1513,7 @@ dependencies = [
{ name = "transformers" },
{ name = "uvicorn" },
{ name = "wandb" },
{ name = "zstandard" },
]
[package.optional-dependencies]
@@ -1551,6 +1552,7 @@ requires-dist = [
{ name = "transformers", specifier = ">=4.57.3" },
{ name = "uvicorn", specifier = ">=0.36.0" },
{ name = "wandb", specifier = ">=0.21.3" },
{ name = "zstandard", specifier = ">=0.25.0" },
]
provides-extras = ["cpu", "gpu"]
@@ -3619,3 +3621,93 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
]
[[package]]
name = "zstandard"
version = "0.25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/fd/aa/3e0508d5a5dd96529cdc5a97011299056e14c6505b678fd58938792794b1/zstandard-0.25.0.tar.gz", hash = "sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b", size = 711513, upload-time = "2025-09-14T22:15:54.002Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/56/7a/28efd1d371f1acd037ac64ed1c5e2b41514a6cc937dd6ab6a13ab9f0702f/zstandard-0.25.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e59fdc271772f6686e01e1b3b74537259800f57e24280be3f29c8a0deb1904dd", size = 795256, upload-time = "2025-09-14T22:15:56.415Z" },
{ url = "https://files.pythonhosted.org/packages/96/34/ef34ef77f1ee38fc8e4f9775217a613b452916e633c4f1d98f31db52c4a5/zstandard-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4d441506e9b372386a5271c64125f72d5df6d2a8e8a2a45a0ae09b03cb781ef7", size = 640565, upload-time = "2025-09-14T22:15:58.177Z" },
{ url = "https://files.pythonhosted.org/packages/9d/1b/4fdb2c12eb58f31f28c4d28e8dc36611dd7205df8452e63f52fb6261d13e/zstandard-0.25.0-cp310-cp310-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:ab85470ab54c2cb96e176f40342d9ed41e58ca5733be6a893b730e7af9c40550", size = 5345306, upload-time = "2025-09-14T22:16:00.165Z" },
{ url = "https://files.pythonhosted.org/packages/73/28/a44bdece01bca027b079f0e00be3b6bd89a4df180071da59a3dd7381665b/zstandard-0.25.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e05ab82ea7753354bb054b92e2f288afb750e6b439ff6ca78af52939ebbc476d", size = 5055561, upload-time = "2025-09-14T22:16:02.22Z" },
{ url = "https://files.pythonhosted.org/packages/e9/74/68341185a4f32b274e0fc3410d5ad0750497e1acc20bd0f5b5f64ce17785/zstandard-0.25.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:78228d8a6a1c177a96b94f7e2e8d012c55f9c760761980da16ae7546a15a8e9b", size = 5402214, upload-time = "2025-09-14T22:16:04.109Z" },
{ url = "https://files.pythonhosted.org/packages/8b/67/f92e64e748fd6aaffe01e2b75a083c0c4fd27abe1c8747fee4555fcee7dd/zstandard-0.25.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:2b6bd67528ee8b5c5f10255735abc21aa106931f0dbaf297c7be0c886353c3d0", size = 5449703, upload-time = "2025-09-14T22:16:06.312Z" },
{ url = "https://files.pythonhosted.org/packages/fd/e5/6d36f92a197c3c17729a2125e29c169f460538a7d939a27eaaa6dcfcba8e/zstandard-0.25.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4b6d83057e713ff235a12e73916b6d356e3084fd3d14ced499d84240f3eecee0", size = 5556583, upload-time = "2025-09-14T22:16:08.457Z" },
{ url = "https://files.pythonhosted.org/packages/d7/83/41939e60d8d7ebfe2b747be022d0806953799140a702b90ffe214d557638/zstandard-0.25.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9174f4ed06f790a6869b41cba05b43eeb9a35f8993c4422ab853b705e8112bbd", size = 5045332, upload-time = "2025-09-14T22:16:10.444Z" },
{ url = "https://files.pythonhosted.org/packages/b3/87/d3ee185e3d1aa0133399893697ae91f221fda79deb61adbe998a7235c43f/zstandard-0.25.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:25f8f3cd45087d089aef5ba3848cd9efe3ad41163d3400862fb42f81a3a46701", size = 5572283, upload-time = "2025-09-14T22:16:12.128Z" },
{ url = "https://files.pythonhosted.org/packages/0a/1d/58635ae6104df96671076ac7d4ae7816838ce7debd94aecf83e30b7121b0/zstandard-0.25.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3756b3e9da9b83da1796f8809dd57cb024f838b9eeafde28f3cb472012797ac1", size = 4959754, upload-time = "2025-09-14T22:16:14.225Z" },
{ url = "https://files.pythonhosted.org/packages/75/d6/57e9cb0a9983e9a229dd8fd2e6e96593ef2aa82a3907188436f22b111ccd/zstandard-0.25.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:81dad8d145d8fd981b2962b686b2241d3a1ea07733e76a2f15435dfb7fb60150", size = 5266477, upload-time = "2025-09-14T22:16:16.343Z" },
{ url = "https://files.pythonhosted.org/packages/d1/a9/ee891e5edf33a6ebce0a028726f0bbd8567effe20fe3d5808c42323e8542/zstandard-0.25.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a5a419712cf88862a45a23def0ae063686db3d324cec7edbe40509d1a79a0aab", size = 5440914, upload-time = "2025-09-14T22:16:18.453Z" },
{ url = "https://files.pythonhosted.org/packages/58/08/a8522c28c08031a9521f27abc6f78dbdee7312a7463dd2cfc658b813323b/zstandard-0.25.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e7360eae90809efd19b886e59a09dad07da4ca9ba096752e61a2e03c8aca188e", size = 5819847, upload-time = "2025-09-14T22:16:20.559Z" },
{ url = "https://files.pythonhosted.org/packages/6f/11/4c91411805c3f7b6f31c60e78ce347ca48f6f16d552fc659af6ec3b73202/zstandard-0.25.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:75ffc32a569fb049499e63ce68c743155477610532da1eb38e7f24bf7cd29e74", size = 5363131, upload-time = "2025-09-14T22:16:22.206Z" },
{ url = "https://files.pythonhosted.org/packages/ef/d6/8c4bd38a3b24c4c7676a7a3d8de85d6ee7a983602a734b9f9cdefb04a5d6/zstandard-0.25.0-cp310-cp310-win32.whl", hash = "sha256:106281ae350e494f4ac8a80470e66d1fe27e497052c8d9c3b95dc4cf1ade81aa", size = 436469, upload-time = "2025-09-14T22:16:25.002Z" },
{ url = "https://files.pythonhosted.org/packages/93/90/96d50ad417a8ace5f841b3228e93d1bb13e6ad356737f42e2dde30d8bd68/zstandard-0.25.0-cp310-cp310-win_amd64.whl", hash = "sha256:ea9d54cc3d8064260114a0bbf3479fc4a98b21dffc89b3459edd506b69262f6e", size = 506100, upload-time = "2025-09-14T22:16:23.569Z" },
{ url = "https://files.pythonhosted.org/packages/2a/83/c3ca27c363d104980f1c9cee1101cc8ba724ac8c28a033ede6aab89585b1/zstandard-0.25.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:933b65d7680ea337180733cf9e87293cc5500cc0eb3fc8769f4d3c88d724ec5c", size = 795254, upload-time = "2025-09-14T22:16:26.137Z" },
{ url = "https://files.pythonhosted.org/packages/ac/4d/e66465c5411a7cf4866aeadc7d108081d8ceba9bc7abe6b14aa21c671ec3/zstandard-0.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3f79487c687b1fc69f19e487cd949bf3aae653d181dfb5fde3bf6d18894706f", size = 640559, upload-time = "2025-09-14T22:16:27.973Z" },
{ url = "https://files.pythonhosted.org/packages/12/56/354fe655905f290d3b147b33fe946b0f27e791e4b50a5f004c802cb3eb7b/zstandard-0.25.0-cp311-cp311-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:0bbc9a0c65ce0eea3c34a691e3c4b6889f5f3909ba4822ab385fab9057099431", size = 5348020, upload-time = "2025-09-14T22:16:29.523Z" },
{ url = "https://files.pythonhosted.org/packages/3b/13/2b7ed68bd85e69a2069bcc72141d378f22cae5a0f3b353a2c8f50ef30c1b/zstandard-0.25.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:01582723b3ccd6939ab7b3a78622c573799d5d8737b534b86d0e06ac18dbde4a", size = 5058126, upload-time = "2025-09-14T22:16:31.811Z" },
{ url = "https://files.pythonhosted.org/packages/c9/dd/fdaf0674f4b10d92cb120ccff58bbb6626bf8368f00ebfd2a41ba4a0dc99/zstandard-0.25.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:5f1ad7bf88535edcf30038f6919abe087f606f62c00a87d7e33e7fc57cb69fcc", size = 5405390, upload-time = "2025-09-14T22:16:33.486Z" },
{ url = "https://files.pythonhosted.org/packages/0f/67/354d1555575bc2490435f90d67ca4dd65238ff2f119f30f72d5cde09c2ad/zstandard-0.25.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:06acb75eebeedb77b69048031282737717a63e71e4ae3f77cc0c3b9508320df6", size = 5452914, upload-time = "2025-09-14T22:16:35.277Z" },
{ url = "https://files.pythonhosted.org/packages/bb/1f/e9cfd801a3f9190bf3e759c422bbfd2247db9d7f3d54a56ecde70137791a/zstandard-0.25.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9300d02ea7c6506f00e627e287e0492a5eb0371ec1670ae852fefffa6164b072", size = 5559635, upload-time = "2025-09-14T22:16:37.141Z" },
{ url = "https://files.pythonhosted.org/packages/21/88/5ba550f797ca953a52d708c8e4f380959e7e3280af029e38fbf47b55916e/zstandard-0.25.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bfd06b1c5584b657a2892a6014c2f4c20e0db0208c159148fa78c65f7e0b0277", size = 5048277, upload-time = "2025-09-14T22:16:38.807Z" },
{ url = "https://files.pythonhosted.org/packages/46/c0/ca3e533b4fa03112facbe7fbe7779cb1ebec215688e5df576fe5429172e0/zstandard-0.25.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f373da2c1757bb7f1acaf09369cdc1d51d84131e50d5fa9863982fd626466313", size = 5574377, upload-time = "2025-09-14T22:16:40.523Z" },
{ url = "https://files.pythonhosted.org/packages/12/9b/3fb626390113f272abd0799fd677ea33d5fc3ec185e62e6be534493c4b60/zstandard-0.25.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6c0e5a65158a7946e7a7affa6418878ef97ab66636f13353b8502d7ea03c8097", size = 4961493, upload-time = "2025-09-14T22:16:43.3Z" },
{ url = "https://files.pythonhosted.org/packages/cb/d3/23094a6b6a4b1343b27ae68249daa17ae0651fcfec9ed4de09d14b940285/zstandard-0.25.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c8e167d5adf59476fa3e37bee730890e389410c354771a62e3c076c86f9f7778", size = 5269018, upload-time = "2025-09-14T22:16:45.292Z" },
{ url = "https://files.pythonhosted.org/packages/8c/a7/bb5a0c1c0f3f4b5e9d5b55198e39de91e04ba7c205cc46fcb0f95f0383c1/zstandard-0.25.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:98750a309eb2f020da61e727de7d7ba3c57c97cf6213f6f6277bb7fb42a8e065", size = 5443672, upload-time = "2025-09-14T22:16:47.076Z" },
{ url = "https://files.pythonhosted.org/packages/27/22/503347aa08d073993f25109c36c8d9f029c7d5949198050962cb568dfa5e/zstandard-0.25.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:22a086cff1b6ceca18a8dd6096ec631e430e93a8e70a9ca5efa7561a00f826fa", size = 5822753, upload-time = "2025-09-14T22:16:49.316Z" },
{ url = "https://files.pythonhosted.org/packages/e2/be/94267dc6ee64f0f8ba2b2ae7c7a2df934a816baaa7291db9e1aa77394c3c/zstandard-0.25.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:72d35d7aa0bba323965da807a462b0966c91608ef3a48ba761678cb20ce5d8b7", size = 5366047, upload-time = "2025-09-14T22:16:51.328Z" },
{ url = "https://files.pythonhosted.org/packages/7b/a3/732893eab0a3a7aecff8b99052fecf9f605cf0fb5fb6d0290e36beee47a4/zstandard-0.25.0-cp311-cp311-win32.whl", hash = "sha256:f5aeea11ded7320a84dcdd62a3d95b5186834224a9e55b92ccae35d21a8b63d4", size = 436484, upload-time = "2025-09-14T22:16:55.005Z" },
{ url = "https://files.pythonhosted.org/packages/43/a3/c6155f5c1cce691cb80dfd38627046e50af3ee9ddc5d0b45b9b063bfb8c9/zstandard-0.25.0-cp311-cp311-win_amd64.whl", hash = "sha256:daab68faadb847063d0c56f361a289c4f268706b598afbf9ad113cbe5c38b6b2", size = 506183, upload-time = "2025-09-14T22:16:52.753Z" },
{ url = "https://files.pythonhosted.org/packages/8c/3e/8945ab86a0820cc0e0cdbf38086a92868a9172020fdab8a03ac19662b0e5/zstandard-0.25.0-cp311-cp311-win_arm64.whl", hash = "sha256:22a06c5df3751bb7dc67406f5374734ccee8ed37fc5981bf1ad7041831fa1137", size = 462533, upload-time = "2025-09-14T22:16:53.878Z" },
{ url = "https://files.pythonhosted.org/packages/82/fc/f26eb6ef91ae723a03e16eddb198abcfce2bc5a42e224d44cc8b6765e57e/zstandard-0.25.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b", size = 795738, upload-time = "2025-09-14T22:16:56.237Z" },
{ url = "https://files.pythonhosted.org/packages/aa/1c/d920d64b22f8dd028a8b90e2d756e431a5d86194caa78e3819c7bf53b4b3/zstandard-0.25.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00", size = 640436, upload-time = "2025-09-14T22:16:57.774Z" },
{ url = "https://files.pythonhosted.org/packages/53/6c/288c3f0bd9fcfe9ca41e2c2fbfd17b2097f6af57b62a81161941f09afa76/zstandard-0.25.0-cp312-cp312-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64", size = 5343019, upload-time = "2025-09-14T22:16:59.302Z" },
{ url = "https://files.pythonhosted.org/packages/1e/15/efef5a2f204a64bdb5571e6161d49f7ef0fffdbca953a615efbec045f60f/zstandard-0.25.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea", size = 5063012, upload-time = "2025-09-14T22:17:01.156Z" },
{ url = "https://files.pythonhosted.org/packages/b7/37/a6ce629ffdb43959e92e87ebdaeebb5ac81c944b6a75c9c47e300f85abdf/zstandard-0.25.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb", size = 5394148, upload-time = "2025-09-14T22:17:03.091Z" },
{ url = "https://files.pythonhosted.org/packages/e3/79/2bf870b3abeb5c070fe2d670a5a8d1057a8270f125ef7676d29ea900f496/zstandard-0.25.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a", size = 5451652, upload-time = "2025-09-14T22:17:04.979Z" },
{ url = "https://files.pythonhosted.org/packages/53/60/7be26e610767316c028a2cbedb9a3beabdbe33e2182c373f71a1c0b88f36/zstandard-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902", size = 5546993, upload-time = "2025-09-14T22:17:06.781Z" },
{ url = "https://files.pythonhosted.org/packages/85/c7/3483ad9ff0662623f3648479b0380d2de5510abf00990468c286c6b04017/zstandard-0.25.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f", size = 5046806, upload-time = "2025-09-14T22:17:08.415Z" },
{ url = "https://files.pythonhosted.org/packages/08/b3/206883dd25b8d1591a1caa44b54c2aad84badccf2f1de9e2d60a446f9a25/zstandard-0.25.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b", size = 5576659, upload-time = "2025-09-14T22:17:10.164Z" },
{ url = "https://files.pythonhosted.org/packages/9d/31/76c0779101453e6c117b0ff22565865c54f48f8bd807df2b00c2c404b8e0/zstandard-0.25.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6", size = 4953933, upload-time = "2025-09-14T22:17:11.857Z" },
{ url = "https://files.pythonhosted.org/packages/18/e1/97680c664a1bf9a247a280a053d98e251424af51f1b196c6d52f117c9720/zstandard-0.25.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91", size = 5268008, upload-time = "2025-09-14T22:17:13.627Z" },
{ url = "https://files.pythonhosted.org/packages/1e/73/316e4010de585ac798e154e88fd81bb16afc5c5cb1a72eeb16dd37e8024a/zstandard-0.25.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708", size = 5433517, upload-time = "2025-09-14T22:17:16.103Z" },
{ url = "https://files.pythonhosted.org/packages/5b/60/dd0f8cfa8129c5a0ce3ea6b7f70be5b33d2618013a161e1ff26c2b39787c/zstandard-0.25.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512", size = 5814292, upload-time = "2025-09-14T22:17:17.827Z" },
{ url = "https://files.pythonhosted.org/packages/fc/5f/75aafd4b9d11b5407b641b8e41a57864097663699f23e9ad4dbb91dc6bfe/zstandard-0.25.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa", size = 5360237, upload-time = "2025-09-14T22:17:19.954Z" },
{ url = "https://files.pythonhosted.org/packages/ff/8d/0309daffea4fcac7981021dbf21cdb2e3427a9e76bafbcdbdf5392ff99a4/zstandard-0.25.0-cp312-cp312-win32.whl", hash = "sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd", size = 436922, upload-time = "2025-09-14T22:17:24.398Z" },
{ url = "https://files.pythonhosted.org/packages/79/3b/fa54d9015f945330510cb5d0b0501e8253c127cca7ebe8ba46a965df18c5/zstandard-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01", size = 506276, upload-time = "2025-09-14T22:17:21.429Z" },
{ url = "https://files.pythonhosted.org/packages/ea/6b/8b51697e5319b1f9ac71087b0af9a40d8a6288ff8025c36486e0c12abcc4/zstandard-0.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9", size = 462679, upload-time = "2025-09-14T22:17:23.147Z" },
{ url = "https://files.pythonhosted.org/packages/35/0b/8df9c4ad06af91d39e94fa96cc010a24ac4ef1378d3efab9223cc8593d40/zstandard-0.25.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94", size = 795735, upload-time = "2025-09-14T22:17:26.042Z" },
{ url = "https://files.pythonhosted.org/packages/3f/06/9ae96a3e5dcfd119377ba33d4c42a7d89da1efabd5cb3e366b156c45ff4d/zstandard-0.25.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1", size = 640440, upload-time = "2025-09-14T22:17:27.366Z" },
{ url = "https://files.pythonhosted.org/packages/d9/14/933d27204c2bd404229c69f445862454dcc101cd69ef8c6068f15aaec12c/zstandard-0.25.0-cp313-cp313-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f", size = 5343070, upload-time = "2025-09-14T22:17:28.896Z" },
{ url = "https://files.pythonhosted.org/packages/6d/db/ddb11011826ed7db9d0e485d13df79b58586bfdec56e5c84a928a9a78c1c/zstandard-0.25.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bfc4e20784722098822e3eee42b8e576b379ed72cca4a7cb856ae733e62192ea", size = 5063001, upload-time = "2025-09-14T22:17:31.044Z" },
{ url = "https://files.pythonhosted.org/packages/db/00/87466ea3f99599d02a5238498b87bf84a6348290c19571051839ca943777/zstandard-0.25.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:457ed498fc58cdc12fc48f7950e02740d4f7ae9493dd4ab2168a47c93c31298e", size = 5394120, upload-time = "2025-09-14T22:17:32.711Z" },
{ url = "https://files.pythonhosted.org/packages/2b/95/fc5531d9c618a679a20ff6c29e2b3ef1d1f4ad66c5e161ae6ff847d102a9/zstandard-0.25.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:fd7a5004eb1980d3cefe26b2685bcb0b17989901a70a1040d1ac86f1d898c551", size = 5451230, upload-time = "2025-09-14T22:17:34.41Z" },
{ url = "https://files.pythonhosted.org/packages/63/4b/e3678b4e776db00f9f7b2fe58e547e8928ef32727d7a1ff01dea010f3f13/zstandard-0.25.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8e735494da3db08694d26480f1493ad2cf86e99bdd53e8e9771b2752a5c0246a", size = 5547173, upload-time = "2025-09-14T22:17:36.084Z" },
{ url = "https://files.pythonhosted.org/packages/4e/d5/ba05ed95c6b8ec30bd468dfeab20589f2cf709b5c940483e31d991f2ca58/zstandard-0.25.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3a39c94ad7866160a4a46d772e43311a743c316942037671beb264e395bdd611", size = 5046736, upload-time = "2025-09-14T22:17:37.891Z" },
{ url = "https://files.pythonhosted.org/packages/50/d5/870aa06b3a76c73eced65c044b92286a3c4e00554005ff51962deef28e28/zstandard-0.25.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:172de1f06947577d3a3005416977cce6168f2261284c02080e7ad0185faeced3", size = 5576368, upload-time = "2025-09-14T22:17:40.206Z" },
{ url = "https://files.pythonhosted.org/packages/5d/35/398dc2ffc89d304d59bc12f0fdd931b4ce455bddf7038a0a67733a25f550/zstandard-0.25.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3c83b0188c852a47cd13ef3bf9209fb0a77fa5374958b8c53aaa699398c6bd7b", size = 4954022, upload-time = "2025-09-14T22:17:41.879Z" },
{ url = "https://files.pythonhosted.org/packages/9a/5c/36ba1e5507d56d2213202ec2b05e8541734af5f2ce378c5d1ceaf4d88dc4/zstandard-0.25.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1673b7199bbe763365b81a4f3252b8e80f44c9e323fc42940dc8843bfeaf9851", size = 5267889, upload-time = "2025-09-14T22:17:43.577Z" },
{ url = "https://files.pythonhosted.org/packages/70/e8/2ec6b6fb7358b2ec0113ae202647ca7c0e9d15b61c005ae5225ad0995df5/zstandard-0.25.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0be7622c37c183406f3dbf0cba104118eb16a4ea7359eeb5752f0794882fc250", size = 5433952, upload-time = "2025-09-14T22:17:45.271Z" },
{ url = "https://files.pythonhosted.org/packages/7b/01/b5f4d4dbc59ef193e870495c6f1275f5b2928e01ff5a81fecb22a06e22fb/zstandard-0.25.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:5f5e4c2a23ca271c218ac025bd7d635597048b366d6f31f420aaeb715239fc98", size = 5814054, upload-time = "2025-09-14T22:17:47.08Z" },
{ url = "https://files.pythonhosted.org/packages/b2/e5/fbd822d5c6f427cf158316d012c5a12f233473c2f9c5fe5ab1ae5d21f3d8/zstandard-0.25.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f187a0bb61b35119d1926aee039524d1f93aaf38a9916b8c4b78ac8514a0aaf", size = 5360113, upload-time = "2025-09-14T22:17:48.893Z" },
{ url = "https://files.pythonhosted.org/packages/8e/e0/69a553d2047f9a2c7347caa225bb3a63b6d7704ad74610cb7823baa08ed7/zstandard-0.25.0-cp313-cp313-win32.whl", hash = "sha256:7030defa83eef3e51ff26f0b7bfb229f0204b66fe18e04359ce3474ac33cbc09", size = 436936, upload-time = "2025-09-14T22:17:52.658Z" },
{ url = "https://files.pythonhosted.org/packages/d9/82/b9c06c870f3bd8767c201f1edbdf9e8dc34be5b0fbc5682c4f80fe948475/zstandard-0.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:1f830a0dac88719af0ae43b8b2d6aef487d437036468ef3c2ea59c51f9d55fd5", size = 506232, upload-time = "2025-09-14T22:17:50.402Z" },
{ url = "https://files.pythonhosted.org/packages/d4/57/60c3c01243bb81d381c9916e2a6d9e149ab8627c0c7d7abb2d73384b3c0c/zstandard-0.25.0-cp313-cp313-win_arm64.whl", hash = "sha256:85304a43f4d513f5464ceb938aa02c1e78c2943b29f44a750b48b25ac999a049", size = 462671, upload-time = "2025-09-14T22:17:51.533Z" },
{ url = "https://files.pythonhosted.org/packages/3d/5c/f8923b595b55fe49e30612987ad8bf053aef555c14f05bb659dd5dbe3e8a/zstandard-0.25.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e29f0cf06974c899b2c188ef7f783607dbef36da4c242eb6c82dcd8b512855e3", size = 795887, upload-time = "2025-09-14T22:17:54.198Z" },
{ url = "https://files.pythonhosted.org/packages/8d/09/d0a2a14fc3439c5f874042dca72a79c70a532090b7ba0003be73fee37ae2/zstandard-0.25.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:05df5136bc5a011f33cd25bc9f506e7426c0c9b3f9954f056831ce68f3b6689f", size = 640658, upload-time = "2025-09-14T22:17:55.423Z" },
{ url = "https://files.pythonhosted.org/packages/5d/7c/8b6b71b1ddd517f68ffb55e10834388d4f793c49c6b83effaaa05785b0b4/zstandard-0.25.0-cp314-cp314-manylinux2010_i686.manylinux_2_12_i686.manylinux_2_28_i686.whl", hash = "sha256:f604efd28f239cc21b3adb53eb061e2a205dc164be408e553b41ba2ffe0ca15c", size = 5379849, upload-time = "2025-09-14T22:17:57.372Z" },
{ url = "https://files.pythonhosted.org/packages/a4/86/a48e56320d0a17189ab7a42645387334fba2200e904ee47fc5a26c1fd8ca/zstandard-0.25.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223415140608d0f0da010499eaa8ccdb9af210a543fac54bce15babbcfc78439", size = 5058095, upload-time = "2025-09-14T22:17:59.498Z" },
{ url = "https://files.pythonhosted.org/packages/f8/ad/eb659984ee2c0a779f9d06dbfe45e2dc39d99ff40a319895df2d3d9a48e5/zstandard-0.25.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e54296a283f3ab5a26fc9b8b5d4978ea0532f37b231644f367aa588930aa043", size = 5551751, upload-time = "2025-09-14T22:18:01.618Z" },
{ url = "https://files.pythonhosted.org/packages/61/b3/b637faea43677eb7bd42ab204dfb7053bd5c4582bfe6b1baefa80ac0c47b/zstandard-0.25.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ca54090275939dc8ec5dea2d2afb400e0f83444b2fc24e07df7fdef677110859", size = 6364818, upload-time = "2025-09-14T22:18:03.769Z" },
{ url = "https://files.pythonhosted.org/packages/31/dc/cc50210e11e465c975462439a492516a73300ab8caa8f5e0902544fd748b/zstandard-0.25.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e09bb6252b6476d8d56100e8147b803befa9a12cea144bbe629dd508800d1ad0", size = 5560402, upload-time = "2025-09-14T22:18:05.954Z" },
{ url = "https://files.pythonhosted.org/packages/c9/ae/56523ae9c142f0c08efd5e868a6da613ae76614eca1305259c3bf6a0ed43/zstandard-0.25.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:a9ec8c642d1ec73287ae3e726792dd86c96f5681eb8df274a757bf62b750eae7", size = 4955108, upload-time = "2025-09-14T22:18:07.68Z" },
{ url = "https://files.pythonhosted.org/packages/98/cf/c899f2d6df0840d5e384cf4c4121458c72802e8bda19691f3b16619f51e9/zstandard-0.25.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a4089a10e598eae6393756b036e0f419e8c1d60f44a831520f9af41c14216cf2", size = 5269248, upload-time = "2025-09-14T22:18:09.753Z" },
{ url = "https://files.pythonhosted.org/packages/1b/c0/59e912a531d91e1c192d3085fc0f6fb2852753c301a812d856d857ea03c6/zstandard-0.25.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:f67e8f1a324a900e75b5e28ffb152bcac9fbed1cc7b43f99cd90f395c4375344", size = 5430330, upload-time = "2025-09-14T22:18:11.966Z" },
{ url = "https://files.pythonhosted.org/packages/a0/1d/7e31db1240de2df22a58e2ea9a93fc6e38cc29353e660c0272b6735d6669/zstandard-0.25.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:9654dbc012d8b06fc3d19cc825af3f7bf8ae242226df5f83936cb39f5fdc846c", size = 5811123, upload-time = "2025-09-14T22:18:13.907Z" },
{ url = "https://files.pythonhosted.org/packages/f6/49/fac46df5ad353d50535e118d6983069df68ca5908d4d65b8c466150a4ff1/zstandard-0.25.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4203ce3b31aec23012d3a4cf4a2ed64d12fea5269c49aed5e4c3611b938e4088", size = 5359591, upload-time = "2025-09-14T22:18:16.465Z" },
{ url = "https://files.pythonhosted.org/packages/c2/38/f249a2050ad1eea0bb364046153942e34abba95dd5520af199aed86fbb49/zstandard-0.25.0-cp314-cp314-win32.whl", hash = "sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12", size = 444513, upload-time = "2025-09-14T22:18:20.61Z" },
{ url = "https://files.pythonhosted.org/packages/3a/43/241f9615bcf8ba8903b3f0432da069e857fc4fd1783bd26183db53c4804b/zstandard-0.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2", size = 516118, upload-time = "2025-09-14T22:18:17.849Z" },
{ url = "https://files.pythonhosted.org/packages/f0/ef/da163ce2450ed4febf6467d77ccb4cd52c4c30ab45624bad26ca0a27260c/zstandard-0.25.0-cp314-cp314-win_arm64.whl", hash = "sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d", size = 476940, upload-time = "2025-09-14T22:18:19.088Z" },
]