Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30 04:22:02 +00:00)

Compare commits: 52 commits, fp8_attemp ... 41bb2eac32
Commits in this compare:
41bb2eac32, 64a651a63c, 65df0de42b, 74554be3b5, d5418ea5a1, c88bbf8133, c8d93beed2, 8630d32be4, 59e36cc727, 85b3e95e09, 6a477eedbd, 63bb5831e2, a91743c168, d58fcd9d73, babde18ce1, cf5c9e5b8e, 413e91aa0f, e7ed2082b8, f9a7e0f111, f5425245f9, 2955650327, 77a46902e4, bbc4413c58, f42ae9e901, e1dafc510f, 6460dc6382, 1933e85046, 3b95d4fd39, e85db6b4a4, 9a88194c3f, 0b58d70e99, e3f58b838e, 184d4c12b1, b62a5bc44a, 8203efa919, 50413d2d67, fbf2bbea25, 747ed4491f, 7d1700c521, d4ea28d4e2, bdcc030ffa, 22a71aa3d3, 255f8b9af6, 6bb92403d5, 3142ca1a28, 7312ec9898, 3b50b77ed3, f92efce169, 43c29dd9d5, 23985413aa, 64b48d0e5c, 238353c998
.claude/skills/read-arxiv-paper/SKILL.md (new file, 40 lines)
@@ -0,0 +1,40 @@
---
name: read-arxiv-paper
description: Use this skill when asked to read an arxiv paper given an arxiv URL
---

You will be given a URL of an arxiv paper, for example:

https://www.arxiv.org/abs/2601.07372

### Part 1: Normalize the URL

The goal is to fetch the TeX source of the paper (not the PDF!). The URL always looks like this:

https://www.arxiv.org/src/2601.07372

Notice the /src/ in the URL. Once you have the URL:

### Part 2: Download the paper source

Fetch the URL to a local .tar.gz file. A good location is `~/.cache/nanochat/knowledge/{arxiv_id}.tar.gz`.

(If the file already exists, there is no need to re-download it.)

### Part 3: Unpack the file in that folder

Unpack the contents into the `~/.cache/nanochat/knowledge/{arxiv_id}` directory.
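Putting Parts 1-3 together, a minimal sketch (standard library only; the function name and the exact handling are illustrative, not part of the skill, and note that some arxiv submissions ship a single gzipped `.tex` rather than a tarball):

```python
import os
import tarfile
import urllib.request

def fetch_arxiv_source(abs_url: str) -> str:
    """Download and unpack the TeX source for an arxiv /abs/ URL (illustrative sketch)."""
    arxiv_id = abs_url.rstrip("/").split("/")[-1]          # e.g. "2601.07372"
    src_url = f"https://www.arxiv.org/src/{arxiv_id}"      # Part 1: /abs/ -> /src/
    base = os.path.expanduser("~/.cache/nanochat/knowledge")
    tar_path = os.path.join(base, f"{arxiv_id}.tar.gz")
    out_dir = os.path.join(base, arxiv_id)
    os.makedirs(out_dir, exist_ok=True)
    if not os.path.exists(tar_path):                       # Part 2: download only once
        urllib.request.urlretrieve(src_url, tar_path)
    with tarfile.open(tar_path, "r:gz") as tf:             # Part 3: unpack into the folder
        tf.extractall(out_dir)
    return out_dir
```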
### Part 4: Locate the entrypoint

Every LaTeX source usually has an entrypoint, such as `main.tex` or something like that.

### Part 5: Read the paper

Once you've found the entrypoint, read its contents and then recurse through all other relevant source files to read the paper.

### Part 6: Report

Once you've read the paper, produce a summary of it in a markdown file at `./knowledge/summary_{tag}.md`. Note that 1) you should use the local knowledge directory here (it's easier for me to open and reference), not the one in `~/.cache`, and 2) generate some reasonable `tag`, e.g. `conditional_memory` or whatever seems appropriate given the paper. Make sure that the tag doesn't exist yet so you're not overwriting files.

As for the summary itself, remember that you're processing this paper within the context of the nanochat repository, so most often we will be interested in how to apply the paper and its lessons to the nanochat project. Therefore, you should feel free to "remind yourself" of the related nanochat code by reading the relevant parts, and then explicitly make the connection of how this paper might relate to nanochat, or what we might be inspired to try.
.gitignore (vendored, 1 line changed)
@@ -9,6 +9,5 @@ eval_bundle/
.env

# Local setup
.claude
CLAUDE.md
wandb/
README.md (51 lines changed)
@@ -4,28 +4,29 @@

> The best ChatGPT that $100 can buy.

This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.

## Talk to it

To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](runs/speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.

## Updates

- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](miniseries.sh).
- (Jan 16 2026) The repo is in active development, I am currently fleshing out the pretraining stage.
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](runs/miniseries.sh).

## Talk to it

To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](runs/run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...

## Quick start

The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](runs/speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:

```bash
bash speedrun.sh
bash runs/speedrun.sh
```

Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):

```bash
screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
```

See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
@@ -72,7 +73,7 @@ Total wall clock time: 3h51m

Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet.

That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
That said, to give a sense, the example changes needed for the [speedrun.sh](runs/speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:

```bash
...
@@ -82,10 +83,10 @@ That said, to give a sense, the example changes needed for the [speedrun.sh](spe
python -m nanochat.dataset -n 450 &
...
# use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device-batch-size=16
...
# make sure to use the same later during midtraining:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
```

That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
@@ -99,7 +100,7 @@ And a bit more about computing environments that will run nanochat:

## Running on CPU / MPS

nanochat can be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
nanochat can be run on CPU or on MPS (if you're on Macbook) in principle, and will automatically try to detect what device is best to run on. The script [runcpu.sh](runs/runcpu.sh) shows a very simple example that will exercise the code paths but basically produce garbage results. Unless you know what you're doing, I basically don't recommend using this script right now and hope to tune it a bit more in the future.

## Customization

@@ -109,15 +110,9 @@ Additionally, to add new abilities to nanochat, see [Guide: counting r in strawb

## Questions

nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.

```bash
files-to-prompt . -e py -e md -e html -e toml -e sh --cxml > packaged.txt
```

This includes all py, html, toml, sh files and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.

Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
You can also come to the [#nanochat Discord channel](https://discord.com/channels/1020383067459821711/1427295580895314031) to ask questions, or use the Discussions.

## Tests
@@ -137,11 +132,9 @@ python -m pytest tests/test_engine.py -v -s
│   ├── gen_synthetic_data.py # Example synthetic data for identity
│   ├── generate_logo.html
│   ├── nanochat.png
│   ├── repackage_data_reference.py # Pretraining data shard generation
│   └── runcpu.sh # Small example of how to run on CPU/MPS
│   └── repackage_data_reference.py # Pretraining data shard generation
├── nanochat
│   ├── __init__.py # empty
│   ├── adamw.py # Distributed AdamW optimizer
│   ├── checkpoint_manager.py # Save/Load model checkpoints
│   ├── common.py # Misc small utilities, quality of life
│   ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
@@ -152,12 +145,17 @@ python -m pytest tests/test_engine.py -v -s
│   ├── gpt.py # The GPT nn.Module Transformer
│   ├── logo.svg
│   ├── loss_eval.py # Evaluate bits per byte (instead of loss)
│   ├── muon.py # Distributed Muon optimizer
│   ├── optim.py # AdamW + Muon optimizer, 1GPU and distributed
│   ├── report.py # Utilities for writing the nanochat Report
│   ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
│   └── ui.html # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh # Train the ~$800 nanochat d32
├── runs
│   ├── miniseries.sh # Miniseries training script
│   ├── run1000.sh # Train the ~$800 nanochat d32
│   ├── runcpu.sh # Small example of how to run on CPU/MPS
│   ├── scaling_laws.sh # Scaling laws experiments
│   └── speedrun.sh # Train the ~$100 nanochat d20
├── scripts
│   ├── base_eval.py # Base model: calculate CORE score
│   ├── base_loss.py # Base model: calculate bits per byte, sample
@@ -170,7 +168,6 @@ python -m pytest tests/test_engine.py -v -s
│   ├── mid_train.py # Chat model: midtraining
│   ├── tok_eval.py # Tokenizer: evaluate compression rate
│   └── tok_train.py # Tokenizer: train it
├── speedrun.sh # Train the ~$100 nanochat d20
├── tasks
│   ├── arc.py # Multiple choice science questions
│   ├── common.py # TaskMixture | TaskSequence
dev/LOG.md (479 lines changed)
@@ -4,6 +4,485 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026

---

## 2026-01-28: Reverted Bigram Hash Embeddings

Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall clock time. It also bloated the VRAM used. The extra parameters and complexity aren't justified.

---

## 2026-01-27: Bigram Hash Embeddings (Engram-lite)

Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).

### Background

The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.

### What We Tried

**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h)            # hidden state as query
k = RMSNorm(W_k @ e)      # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d)   # scalar gate per position
output = α * v
```
- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.

**2. Early-layer only injection**
- Only inject bigram signal in first 4 layers (where paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.

**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone. Dilutes capacity from more frequent 2-gram patterns.

**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.

TLDR: the winning approach follows modded-nanogpt's "engram-lite", simply adding the following module and feeding its output into the residual branch (gated by a per-layer learnable \lambda) before every single block:

```python
import torch.nn as nn

class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)
        nn.init.zeros_(self.embed.weight)  # zero-init so the module starts as a no-op

    def forward(self, idx):
        # hash each (previous, current) token pair into the table
        h = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        return self.embed(h)
```
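And a sketch of how its output is injected, mirroring the x0-style lambdas (the handling of position 0, which has no previous token, is an assumption here for illustration):

```python
import torch.nn.functional as F

def forward_with_bigrams(idx, wte, bigram_embed, blocks,
                         resid_lambdas, x0_lambdas, bigram_lambdas):
    # idx: (B, T) token ids; wte/bigram_embed: embedding modules; blocks: transformer blocks
    x0 = wte(idx)                       # token embeddings
    x = x0
    xb = bigram_embed(idx)              # (B, T-1, n_embd): each position needs a previous token
    xb = F.pad(xb, (0, 0, 1, 0))        # assumption: zero-pad position 0
    for i, block in enumerate(blocks):
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * xb
        x = block(x)
    return x
```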
As for optimal hyperparameters:

- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings), this was slightly worse.
- **Init:** Zero-init embedding, so starts as identity (tried small noisy init, it's worse)
- **Optimizer:** AdamW with same LR as token embeddings

### Key Learnings

1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."

2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.

3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.

4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.

### Parameters Added

For d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = ~126M params
- Per-layer lambdas: 12 scalars (negligible)

If you're keeping track, we now have *a lot* of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:

```
Parameter counts:
wte                  : 25,165,824
bigram_embed         : 125,829,120
value_embeds         : 150,994,944
lm_head              : 25,165,824
transformer_matrices : 84,935,808
scalars              : 36
total                : 412,091,556
```
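For reference, the embedding rows above decompose cleanly (a quick check, assuming vocab_size=32768 and n_embd=768 for d12):

```python
vocab_size, n_embd = 32768, 768
assert vocab_size * n_embd == 25_165_824        # wte (and lm_head has the same shape)
assert 5 * vocab_size * n_embd == 125_829_120   # bigram table with table_multiplier=5
assert 6 * vocab_size * n_embd == 150_994_944   # six value-embedding tables (VEs at every other layer of d12)
```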
In other words, only about a quarter of parameters are now weight projections and the vast majority is embedding tables.

Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.

After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:

```
Kaplan-style (all projections including lm_head and no embeddings)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params       Tokens           Ratio      Val BPB
-----------------------------------------------------------------
1e+18        110,678,115      1,241,505,403    11.2       0.8972
2e+18        167,797,457      1,785,336,422    10.7       0.8616
5e+18        250,650,865      2,642,234,152    10.8       0.8293
1e+19        381,758,347      3,806,871,243    10.3       0.7999

N \propto C^0.54, D \propto C^0.49

Chinchilla-style (all parameters, period.)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params       Tokens           Ratio      Val BPB
-----------------------------------------------------------------
1e+18        416,320,605      1,232,157,011    3.0        0.8974
2e+18        560,239,841      1,763,669,281    3.2        0.8616
5e+18        741,495,903      2,629,909,368    3.6        0.8291
1e+19        988,644,331      3,884,841,895    4.0        0.7999

N \propto C^0.37, D \propto C^0.50

Transformer-only-style (only the projections inside the transformer)

Optimal configurations (from quadratic fits):
FLOPs        Eff Params       Tokens           Ratio      Val BPB
-----------------------------------------------------------------
1e+18        80,259,665       1,315,639,547    17.2       0.8966
2e+18        131,488,566      1,864,134,141    14.5       0.8622
5e+18        220,985,474      2,595,328,843    12.1       0.8302
1e+19        401,213,504      3,328,704,512    8.5        0.7994

N \propto C^0.70, D \propto C^0.41
```

Clearly, the Kaplan-style ratios are most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can have a single fixed ratio of tokens:params for compute optimal models. This turns out to be about ~10.5, which now becomes the new default.
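For intuition on how a fixed ratio gets used, a back-of-the-envelope sketch (assuming the standard C ≈ 6·N·D FLOPs approximation with N counted Kaplan-style; the repo's own scripts may account for FLOPs slightly differently):

```python
def optimal_size(flops_budget: float, ratio: float = 10.5):
    # C = 6 * N * D and D = ratio * N  =>  N = sqrt(C / (6 * ratio))
    n_params = (flops_budget / (6 * ratio)) ** 0.5
    n_tokens = ratio * n_params
    return n_params, n_tokens

n, d = optimal_size(1e19)
print(f"N ≈ {n:.2e} params, D ≈ {d:.2e} tokens")  # ~4.0e8 params, ~4.2e9 tokens at 1e19 FLOPs
```

This lands close to the fitted 1e19 row in the Kaplan-style table above.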
---

## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep

Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.

### What We Swept
- Learning rates for all 6 parameter groups
- Beta1/beta2 for all 5 AdamW groups
- Muon momentum (start/end), weight decay
- Hundreds of combinations (2-way, 3-way, 4-way, etc.)

### The Journey

**At d12**, found two independent improvement routes:
- **Route A:** emb_lr↑ (0.3→0.4), weight_decay↑ (0.1→0.15), matrix_lr↑ (0.02→0.025)
- **Route B:** x0_lr↓ (0.5→0.2), x0_beta1↑ (0.8→0.9+)

Both gave ~0.002 improvement, but combining them caused conflicts. Fine-tuning found wd=0.13, matrix_lr=0.027, emb_lr=0.38 helped slightly. Best d12 config: Route A + x0_beta1=0.95.

**At d16**, Route B became competitive with Route A. The routes still conflicted when combined.

**At d20** (target scale), everything changed:
- Fine-tuned values from d12 **actively hurt** performance
- Routes no longer conflicted
- Just `x0_beta1=0.96` alone captured nearly all the gains

### Final x0_beta1 Sweep at d20

| x0_beta1 | val/bpb | Δ vs baseline |
|----------|---------|---------------|
| **0.96** | **0.7971** | **-0.0007** |
| 0.94 | 0.7972 | -0.0006 |
| 0.90 | 0.7972 | -0.0006 |
| 0.97 | 0.7977 | -0.0001 |
| 0.98 | 0.8011 | +0.0033 💀 |

Flat plateau from 0.90-0.96, then sharp cliff at 0.97+.

### Key Learnings

1. **Hyperparameters are scale-dependent.** What works at d12 doesn't transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20.

2. **Improvement magnitude shrinks with scale.** ~0.002 at d12 → ~0.0007 at d20. The baseline is already better-tuned for larger models.

3. **Sharp cliffs exist.** x0_beta1=0.98 is catastrophic while 0.96 is optimal.

4. **Don't over-tune on small proxies.** Validate at target scale before shipping.

### Final Recommendation

For production d20 runs, add one flag:
```
--x0-lambdas-beta1=0.96
```

Skip everything else discovered at smaller scales.
---

## 2026-01-18: More various experiments

- Tried Muon custom kernels for XXT and all the others. The improvement was there for targeted tests (~20%) but washed out completely to noise in an actual training run, especially because the Muon compute is split across all the workers. Abandoned due to complexity bloat.
- Fused Q,K,V,O nn.Linear layers into a single QKVO Linear layer. ~Zero impact.
- Tried the `sa_lambdas` that gate QKV and O. Slightly confused because of the use of rmsnorm, which erases the effect of any scalar multiplier. Helped a tiny bit (~1e-4 of loss), abandoned to control complexity.

---

## 2026-01-17: Various experiments

Modded-nanogpt uses [Value Embeddings](https://arxiv.org/abs/2410.17897) (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:

- VEs at every layer, at alternating layers, U-shaped, front and back. Alternating layers worked best, i.e. we end up with *a lot* more VEs than modded-nanogpt, at every other layer. It works better.
- Many parameter-sharing ideas to reduce the new parameter count; nothing here worked. All failed.
- Many ideas to reduce parameter count, the LLM hates all of them: low rank decompositions, projections. All failed.
- Gated yes or no and how much. Gate helps.

Long story short is that the models *love* Value Embeddings. It is a way to add a huge amount of capacity (parameters) to the model at almost zero cost of FLOPs, because these embeddings are simply added to the Values tensor. Any attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fails. The model wants many of them, with all the capacity, and doing so wins across all x axes of steps, flops and wall clock. I re-ran the scaling laws and, because the models are now very parameter bloated, the optimal ratio has halved from 8 to 4! Way lower than Chinchilla's 20 at this point.

Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:

- Aspect ratio of 128 is worse than 64, I tried a sweep fixing FLOPs == 1e18 and 64 outperforms. The LLM prefers to be slightly thinner and longer.
- Head dim definitely prefers to be 128 instead of 64, i.e. fewer bigger heads.
- Bunch of other random stuff like that.

Keeping all of this work on a private branch for now but hope to push shortly.

---

## 2026-01-17: Modded-nanogpt Ideas Sweep (Continued)

Continued testing ideas from modded-nanogpt.

| Idea | Result | Notes |
|------|--------|-------|
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |

---

## 2026-01-16: Flash Attention 3 Fallback to SDPA

Added automatic fallback from Flash Attention 3 to PyTorch's `scaled_dot_product_attention` (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).

### Implementation

Created `nanochat/flash_attention.py` - a unified interface that:
- Detects FA3 availability at import time (requires sm90+ / Hopper)
- Exports a `flash_attn` object matching FA3's API exactly (`flash_attn.flash_attn_func`, `flash_attn.flash_attn_with_kvcache`)
- Automatically routes to FA3 or SDPA based on hardware (the routing idea is sketched after this list)
- Handles tensor layout differences: FA3 uses (B, T, H, D), SDPA uses (B, H, T, D)
- Implements sliding window attention via explicit masks for SDPA
- Manages KV cache manually for SDPA (FA3 does it in-place)
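A minimal sketch of the routing idea (not the actual `nanochat/flash_attention.py`; the FA3 import path and return convention vary by version, and this ignores sliding windows and the KV cache path):

```python
import torch.nn.functional as F

try:
    from flash_attn_interface import flash_attn_func  # FA3; import name is version-dependent
    HAS_FA3 = True
except ImportError:
    HAS_FA3 = False

def attention(q, k, v, causal=True):
    # q, k, v arrive as (B, T, H, D), the layout FA3 expects
    if HAS_FA3:
        out = flash_attn_func(q, k, v, causal=causal)
        return out[0] if isinstance(out, tuple) else out  # some FA3 versions also return lse
    # SDPA expects (B, H, T, D), so transpose around the call
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    y = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return y.transpose(1, 2)
```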
### Changes to Existing Files

Changes to existing code were intentionally kept extremely minimal.

**gpt.py**: Only the import line changed and a comment

**engine.py**: Zero changes needed

**base_train.py**: Added status print and warnings:
- Prints whether FA3 or SDPA fallback is being used
- Warns about efficiency loss without FA3
- Warns about sliding window support if `--window-pattern` is not "L"

### Testing

Tests are split into two classes due to dtype/device constraints:

1. **TestFA3VsSDPA**: Comparison tests requiring Hopper GPU + bfloat16. Run both implementations on identical inputs and verify outputs match (max diff typically 0, at most ~0.004 for sliding window).

2. **TestSDPAOnly**: SDPA-only tests that run on any device with appropriate dtype. Verify forward pass, backward pass, and KV cache work correctly.

Added `_override_impl` mechanism for testing - can force 'fa3' or 'sdpa' to directly compare implementations.

### Notes

- The SDPA fallback is significantly slower than FA3, especially because it lacks sliding window attention support
- Recommend `--window-pattern L` (full context) when using the SDPA fallback

---

## 2026-01-16: Modded-nanogpt Ideas Sweep (Mostly Negative)

Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. None of these helped:

| Idea | Result | Notes |
|------|--------|-------|
| Half-truncated RoPE | No improvement | Only first half of head dims get RoPE (base 1024, linspace). Second half "stationary". |
| Asymmetric softcap | Slightly worse | `23 * sigmoid((x+5)/7.5)` vs our symmetric `15 * tanh(x/15)`. May only help with FP8. |
| Smear gate | Negligible | Blend each token with predecessor via learned gate. Tiny improvement not worth n_embd² params. |
| Backout | No improvement | Save activations at ~60% through network, subtract scaled version at end. |
| Skip connection | Slightly worse | Save at layer ~25%, add at layer ~50%. Also +2GB memory from storing activations. |

Value Embeddings do show promise. I need a more elaborate exploration of a few related ideas, which I leave for tomorrow.

---

## 2026-01-15: Olmo pretraining mix (Negative result)

I attempted to train on the Olmo 3 pretraining dataset [allenai/dolma3_mix-6T](https://huggingface.co/datasets/allenai/dolma3_mix-6T) instead of FineWeb-edu. I ran into a number of [errors and issues](https://huggingface.co/datasets/allenai/dolma3_mix-6T/discussions/2) trying to both download and process the dataset, and then noticed some quality issues (e.g. some documents seem to be extremely short, like "5"). I managed to work around these with some sensible hacks (e.g. reject documents less than 100 characters in length), processed the dataset exactly as FineWeb, re-trained the tokenizer and trained a d16 model. The CORE score decreased from 15.5 to 13.8, i.e. the result is quite a bit worse.

I am still looking to try the [DCLM dataset](https://arxiv.org/abs/2406.11794), which according to the paper should be better than FineWeb-edu. I do have some concerns that the same group both prepared the DCLM dataset *and* introduced the CORE score, so I'm a bit hesitant in case there was some overfitting to a CORE-score-adjacent data distribution.

Classifying this as a negative result and reverting back to FineWeb-edu for now.

---

## 2026-01-13: Varlen Attention (Negative Result)

Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.

### Background

With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.

### Approach: Compute cu_seqlens from inputs

- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()` (see the sketch after this list)
- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region
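A sketch of the cu_seqlens construction (illustrative; as noted above, the real version pads to a fixed size and keeps this computation outside the compiled region):

```python
import torch

def cu_seqlens_from_bos(inputs: torch.Tensor, bos_token_id: int) -> torch.Tensor:
    # inputs: (B, T); assumes every row starts with BOS, so the first boundary is 0
    flat = inputs.view(-1)
    bos_pos = (flat == bos_token_id).nonzero(as_tuple=True)[0]
    end = torch.tensor([flat.numel()], device=flat.device)
    # cumulative sequence boundaries [0, ..., B*T], int32 as flash_attn_varlen_func expects
    return torch.cat([bos_pos, end]).to(torch.int32)
```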
### Final Results (d16)

| Metric | Baseline | Varlen |
|--------|----------|--------|
| val_bpb | 0.85427 | 0.85407 |
| MFU | ~same | ~same |
| tok/sec | ~same | ~same |

Essentially identical. The 0.0002 bpb improvement is almost noise.

### Conclusion

Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.

---

## 2026-01-13: BOS-Aligned Dataloader with Bin Packing

Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.

### Problem Statement

The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.

### Approach 1: Greedy-Crop BOS (Simple)

Each row is built independently:
- Start with a document (which has BOS prepended)
- Pack more documents until row is full
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
- 100% utilization (no padding), but wastes cropped tokens

### Waste Analysis

Measured token waste empirically on real data (T=2048):
- **39.4% of tokens are cropped** (discarded when docs don't fit)
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row

### Bin Packing Algorithms Explored

| Algorithm | Util% | Crop% | Pad% | Notes |
|-----------|-------|-------|------|-------|
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |

### BestFit-Crop Algorithm

A middle ground that maintains 100% utilization while reducing cropping:

1. Buffer N documents
2. For each row, greedily pick the **largest doc that fits entirely**
3. Repeat until nothing fits
4. When nothing fits, crop a doc to fill remaining space exactly

This avoids "unlucky" crops by searching the buffer for better-fitting documents.
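A sketch of the per-row packing step (illustrative; `docs` is a buffer of token lists that each start with BOS, and this assumes the buffer is never exhausted mid-row):

```python
def pack_row(docs: list[list[int]], T: int) -> list[int]:
    row: list[int] = []
    while len(row) < T:
        space = T - len(row)
        fits = [d for d in docs if len(d) <= space]
        if fits:
            best = max(fits, key=len)      # largest buffered doc that fits entirely
            docs.remove(best)
            row.extend(best)
        else:
            doc = docs.pop(0)              # nothing fits: crop a doc to fill the row exactly
            row.extend(doc[:space])        # its tail is discarded (the crop waste)
    return row
```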
**Results (T=2048):**
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
- Still achieves 100% utilization (no padding, every token trains)
- Slightly more rows than baseline (uses more documents per batch)

### Decision: Keep Two Implementations

1. Keep the original implementation, which is very simple, efficient and has 100% token utilization in the batch (no padding with ignore tokens), but creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present.

2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex but still keeps 100% token utilization in the batch (no padding), at the cost of discarding documents when they don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. One more subtle effect is that it does skew the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents will be discarded. However, this doesn't seem to impact actual downstream performance.

### Midtraining

The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.

### NOTE: loss scale

Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute value of the loss, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back and find the BOS token and have the full context of that document to make predictions. Therefore, the loss appears lower but this is "fake" to some extent, and the expectation is that the vast majority of relative comparisons would agree before and after this change.

---

## 2026-01-13: Number Token Split Pattern

Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I only guessed earlier and had a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.

**Results (d12, vocab=32K):**
| Pattern | val_bpb |
|---------|---------|
| `\p{N}{1,1}` | 0.969 |
| `\p{N}{1,2}` | **0.965** |
| `\p{N}{1,3}` | 0.972 |

**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default.
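A quick illustration of just the digit-grouping part (the real `SPLIT_PATTERN` is a larger GPT-4-style regex; this uses the third-party `regex` module, which supports `\p{N}`):

```python
import regex

text = "pi is 3.14159 and the year is 2026"
for k in (1, 2, 3):
    print(k, regex.findall(r"\p{N}{1,%d}" % k, text))
# 1 ['3', '1', '4', '1', '5', '9', '2', '0', '2', '6']
# 2 ['3', '14', '15', '9', '20', '26']
# 3 ['3', '141', '59', '202', '6']
```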
---

## 2026-01-13: FP8 Training for lm_head

Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.

### Implementation Approaches Tried

**1. Dynamic Scaling (failed)**
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales
- Problem: `.item()` calls cause graph breaks with torch.compile
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile

**2. Static Scaling (partial success)**
- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448` (sketched after this list)
- `grad_scale` computed dynamically from batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here because they set `grad_scale = 0.75/448`, but grads are in E5M2 so this should probably be `1/57344`, 1 being the amax of any individual element of the cross entropy gradient, and no normalization by B,T because they use sum reduction, not mean reduction.
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
- This works correctly - no NaNs, proper gradients
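For reference, the static-scale arithmetic looks roughly like this (illustrative only; the real path hands the float8 tensors plus their scales to `torch._scaled_mm` rather than materializing anything, and the batch shape here is made up):

```python
import torch

E4M3_MAX = 448.0     # largest magnitude representable in float8_e4m3fn (x and w)
E5M2_MAX = 57344.0   # largest magnitude representable in float8_e5m2 (gradients)

def quant_e4m3(t: torch.Tensor, scale: float) -> torch.Tensor:
    # scale maps the assumed amax of t onto the fp8 range, e.g. x_scale = 10/448
    return (t / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

x_scale = 10.0 / E4M3_MAX   # static assumption: activations have amax ~ 10
w_scale = 0.1 / E4M3_MAX    # static assumption: lm_head weights have amax ~ 0.1

# With mean-reduced cross entropy, each element of dL/dlogits is bounded by 1/(B*T),
# so the E5M2 gradient scale can be set per batch shape:
B, T = 32, 2048             # made-up batch shape
grad_scale = (1.0 / (B * T)) / E5M2_MAX
```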
### Results (d12)

| Metric | BF16 Baseline | FP8 lm_head |
|--------|---------------|-------------|
| GPU Memory | 34 GB | 36 GB |
| tok/sec | baseline | ~1% faster |

### The Memory Mystery

FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see a 2GB *increase*. Suspected causes:
- `torch.compile` on inner kernels creating extra buffers/specializations
- `torch._scaled_mm` internal workspace allocations
- Custom op registration machinery overhead

Tried saving the original weight `w` (just a reference to the parameter) instead of `w_f8` in backward, then re-quantizing on the spot during backward - didn't help. Still saw the bump.

### Microbenchmark vs Reality

Raw microbenchmark showed promise:
- BF16 matmul: 16.95 ms
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)

But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase, the added code complexity, and the need to tune scale factors for both x and w.

### Code Artifacts

See the branch `fp8_attempt_fail` for:

- `nanochat/fp8_static.py` - Static scaling implementation (working)
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.

### Open Questions

- Why does the custom op approach use more memory than vanilla BF16?
- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the explanation because our vocab_size is only 32K so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized.

**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.

---

## 2026-01-12: Multi-Token Prediction (MTP)

Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.
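A minimal sketch of what a weighted multi-token loss looks like (illustrative; the number of lookahead heads, the head layout, and the 0.5 weight decay are assumptions, not the ported settings):

```python
import torch
import torch.nn.functional as F

def mtp_loss(logits: torch.Tensor, tokens: torch.Tensor, n: int = 2) -> torch.Tensor:
    # logits: (B, T, n, V), one prediction head per lookahead offset; tokens: (B, T + n)
    B, T, _, V = logits.shape
    losses = []
    for i in range(n):
        target = tokens[:, 1 + i : 1 + i + T]   # the next-(i+1) token at each position
        losses.append(F.cross_entropy(logits[:, :, i].reshape(-1, V), target.reshape(-1)))
    weights = [0.5 ** i for i in range(n)]      # assumption: geometric down-weighting
    total = sum(weights)
    return sum(w / total * l for w, l in zip(weights, losses))
```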
@@ -1,74 +0,0 @@
#!/bin/bash

# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/cpu_demo_run.sh

# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/

# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
    WANDB_RUN=dummy
fi

# wipe the report
python -m nanochat.report reset

# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval

# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
    --depth=4 \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --total_batch_size=1024 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --core_metric_every=50 \
    --core_metric_max_per_task=12 \
    --sample_every=50 \
    --num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=16

# midtraining
python -m scripts.mid_train \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --total_batch_size=1024 \
    --num_iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20

# SFT
python -m scripts.chat_sft \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --num_iterations=100 \
    --eval_steps=4 \
    --eval_metrics_max_problems=16

# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

# Chat Web
# python -m scripts.chat_web

python -m nanochat.report generate
@@ -15,14 +15,16 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"# Load results\n",
|
||||
"tag = \"jan26\"\n",
|
||||
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
|
||||
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
|
||||
"results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
|
||||
"\n",
|
||||
"df = pd.read_csv(results_path)\n",
|
||||
"flops_budgets = sorted(df['flops_budget'].unique())\n",
|
||||
@@ -31,6 +33,99 @@
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# =============================================================================\n",
|
||||
"# FILTERING: Remove incomplete or problematic runs\n",
|
||||
"# =============================================================================\n",
|
||||
"\n",
|
||||
"print(f\"Before filtering: {len(df)} runs\")\n",
|
||||
"\n",
|
||||
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
|
||||
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
|
||||
"\n",
|
||||
"# Optional: exclude specific flops budgets that aren't done yet\n",
|
||||
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
|
||||
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
|
||||
"\n",
|
||||
"# Optional: exclude specific depths\n",
|
||||
"# exclude_depths = [18, 20]\n",
|
||||
"# df = df[~df['depth'].isin(exclude_depths)]\n",
|
||||
"\n",
|
||||
"print(f\"After filtering: {len(df)} runs\")\n",
|
||||
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
|
||||
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
|
||||
"\n",
|
||||
"# Update flops_budgets list after filtering\n",
|
||||
"flops_budgets = sorted(df['flops_budget'].unique())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Effective Parameter Count\n",
|
||||
"\n",
|
||||
"Different scaling law papers use different conventions for counting parameters:\n",
|
||||
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
|
||||
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
|
||||
"\n",
|
||||
"Our CSV now has granular counts:\n",
|
||||
"- `params_wte` - token embedding (lookup table)\n",
|
||||
"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
|
||||
"- `params_value_embeds` - value embeddings (lookup table)\n",
|
||||
"- `params_lm_head` - unembedding projection (matmul)\n",
|
||||
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
|
||||
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
|
||||
"\n",
|
||||
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# =============================================================================\n",
|
||||
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
|
||||
"# =============================================================================\n",
|
||||
"\n",
|
||||
"def compute_effective_params(row):\n",
|
||||
" \"\"\"\n",
|
||||
" Compute the 'effective' parameter count for scaling law analysis.\n",
|
||||
"\n",
|
||||
" Modify this function to experiment with different conventions:\n",
|
||||
" - Chinchilla-style: include everything\n",
|
||||
" - Kaplan-style: exclude embeddings\n",
|
||||
" - Matmul-only: just transformer + lm_head (the actual compute)\n",
|
||||
" - etc.\n",
|
||||
" \"\"\"\n",
|
||||
" # Option 1: Chinchilla-style (all params)\n",
|
||||
" # return row['params_total']\n",
|
||||
"\n",
|
||||
" # Option 2: Kaplan-style (exclude embeddings)\n",
|
||||
" return row['params_transformer'] + row['params_lm_head']\n",
|
||||
"\n",
|
||||
" # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
|
||||
" # return row['params_transformer']\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Compute derived columns\n",
|
||||
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
|
||||
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
|
||||
"\n",
|
||||
"# Show parameter breakdown for first few rows\n",
|
||||
"print(\"Parameter breakdown (first row per flops budget):\")\n",
|
||||
"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
|
||||
" 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
|
||||
"df.groupby('flops_budget').first()[param_cols]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -54,11 +149,11 @@
|
||||
"optimal_by_bpb = []\n",
|
||||
"\n",
|
||||
"for flops, color in zip(flops_budgets, colors):\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
|
||||
" ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
|
||||
" ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
|
||||
"\n",
|
||||
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
|
||||
" log_params = np.log10(subset['num_scaling_params'])\n",
|
||||
" log_params = np.log10(subset['effective_params'])\n",
|
||||
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
|
||||
" a, b, c = coeffs\n",
|
||||
"\n",
|
||||
@@ -83,13 +178,13 @@
|
||||
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
|
||||
" best_idx = subset['val_bpb'].idxmin()\n",
|
||||
" best = subset.loc[best_idx]\n",
|
||||
" ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
|
||||
" ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
|
||||
" zorder=5, edgecolors='black', linewidths=2)\n",
|
||||
" optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
|
||||
" optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
|
||||
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
|
||||
"\n",
|
||||
"ax.set_xscale('log')\n",
|
||||
"ax.set_xlabel('Parameters')\n",
|
||||
"ax.set_xlabel('Effective Parameters')\n",
|
||||
"ax.set_ylabel('Validation Loss (bpb)')\n",
|
||||
"ax.set_title('IsoFLOP Curves')\n",
|
||||
"ax.legend(title='FLOPs', loc='upper right')\n",
|
||||
@@ -138,10 +233,61 @@
|
||||
"\n",
|
||||
"# Print the optimal points (from quadratic fits)\n",
|
||||
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
|
||||
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
|
||||
"print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
|
||||
"print(\"-\" * 65)\n",
|
||||
"for _, row in opt_df.iterrows():\n",
|
||||
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
|
||||
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# =============================================================================\n",
|
||||
"# Optimal Ratio Summary (from power law fits)\n",
|
||||
"# =============================================================================\n",
|
||||
"\n",
|
||||
"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
|
||||
"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
|
||||
"\n",
|
||||
"if len(opt_df) >= 2:\n",
|
||||
" log_f = np.log10(opt_df['flops'])\n",
|
||||
" log_p = np.log10(opt_df['params'])\n",
|
||||
" log_t = np.log10(opt_df['tokens'])\n",
|
||||
"\n",
|
||||
" # Fit power laws\n",
|
||||
" slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
|
||||
" slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
|
||||
"\n",
|
||||
" # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
|
||||
" ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
|
||||
" log_ref = np.log10(ref_flops)\n",
|
||||
"\n",
|
||||
" # Predicted optimal N and D at reference compute\n",
|
||||
" pred_log_n = intercept_n + slope_n * log_ref\n",
|
||||
" pred_log_d = intercept_d + slope_d * log_ref\n",
|
||||
" optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
|
||||
"\n",
|
||||
" # Also compute from the fitted optimals directly (mean and std)\n",
|
||||
" mean_ratio = opt_df['ratio'].mean()\n",
|
||||
" std_ratio = opt_df['ratio'].std()\n",
|
||||
"\n",
|
||||
" print(\"=\" * 60)\n",
|
||||
" print(\"OPTIMAL RATIO SUMMARY\")\n",
|
||||
" print(\"=\" * 60)\n",
|
||||
" print(f\"\\nPower law exponents:\")\n",
|
||||
" print(f\" N ∝ C^{slope_n:.3f}\")\n",
|
||||
" print(f\" D ∝ C^{slope_d:.3f}\")\n",
|
||||
" print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
|
||||
" print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
|
||||
" print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
|
||||
" print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
|
||||
" print(f\" Chinchilla reference: 20\")\n",
|
||||
" print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"Need at least 2 flops budgets to compute power law fits\")"
|
||||
]
|
||||
},
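For reference, the optimum read off each quadratic fit above is just the vertex of the parabola in log-space: with val_bpb ≈ a·(log10 N)^2 + b·(log10 N) + c and a > 0, the compute-optimal size is N* = 10^(-b/(2a)). A minimal sketch with made-up numbers (the values below are illustrative only, not from any actual run):

import numpy as np

# hypothetical points along one isoFLOP curve: (effective params, validation bpb)
log_params = np.log10(np.array([3e8, 6e8, 1.2e9, 2.4e9]))
val_bpb = np.array([0.92, 0.88, 0.87, 0.895])

a, b, c = np.polyfit(log_params, val_bpb, 2)  # highest-degree coefficient first
if a > 0:                                     # parabola opens upward -> interior minimum exists
    n_opt = 10 ** (-b / (2 * a))              # compute-optimal effective parameter count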
|
||||
{
nanochat/adamw.py (95 lines removed)
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
|
||||
Not a general optimizer! But works for our specific use.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed AdamW optimizer.
|
||||
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
|
||||
"""
|
||||
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
reduce_futures: list[torch.Future] = []
|
||||
gather_futures: list[torch.Future] = []
|
||||
grad_slices = []
|
||||
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
|
||||
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
t = state['step']
|
||||
# weight decay
|
||||
if wd != 0:
|
||||
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
|
||||
p_slice.mul_(1 - eff_weight_decay)
|
||||
# update running averages
|
||||
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
|
||||
# bias corrections
|
||||
bias1 = 1 - beta1 ** t
|
||||
bias2 = 1 - beta2 ** t
|
||||
# compute step
|
||||
denom = (exp_avg_sq / bias2).sqrt().add_(eps)
|
||||
step_size = lr / bias1
|
||||
update = exp_avg.div(denom).mul_(step_size)
|
||||
p_slice.add_(other=update, alpha=-1.0)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
@@ -25,6 +25,7 @@ def _patch_missing_config_keys(model_config_kwargs):
|
||||
# Old models were trained with full context (no sliding window)
|
||||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
@@ -32,9 +33,11 @@ def _patch_missing_keys(model_data, model_config):
|
||||
# resid_lambdas defaults to 1.0 (identity scaling)
|
||||
if "resid_lambdas" not in model_data:
|
||||
model_data["resid_lambdas"] = torch.ones(n_layer)
|
||||
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
||||
# x0_lambdas defaults to 0.0 (disabled)
|
||||
if "x0_lambdas" not in model_data:
|
||||
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
||||
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
@@ -108,7 +111,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
||||
# Load the Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
# Sanity check: compatibility between model and tokenizer
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
|
||||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
|
||||
@@ -200,3 +200,77 @@ class DummyWandb:
|
||||
pass
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
# hardcoded BF16 peak flops for various GPUs
|
||||
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
|
||||
# and PR: https://github.com/karpathy/nanochat/pull/147
|
||||
def get_peak_flops(device_name: str) -> float:
|
||||
name = device_name.lower()
|
||||
|
||||
# --- NVIDIA Blackwell ---
|
||||
if "gb200" in name or "grace blackwell" in name:
|
||||
return 2.5e15
|
||||
if "b200" in name:
|
||||
return 2.25e15
|
||||
if "b100" in name:
|
||||
return 1.8e15
|
||||
|
||||
# --- NVIDIA Hopper (H100/H200/H800) ---
|
||||
if "h200" in name:
|
||||
if "nvl" in name or "pcie" in name:
|
||||
return 836e12
|
||||
return 989e12 # H200 SXM
|
||||
if "h100" in name:
|
||||
if "nvl" in name:
|
||||
return 835e12
|
||||
if "pcie" in name:
|
||||
return 756e12
|
||||
return 989e12 # H100 SXM
|
||||
if "h800" in name:
|
||||
if "nvl" in name:
|
||||
return 989e12
|
||||
return 756e12 # H800 PCIe
|
||||
|
||||
# --- NVIDIA Ampere data center ---
|
||||
if "a100" in name or "a800" in name:
|
||||
return 312e12
|
||||
if "a40" in name:
|
||||
return 149.7e12
|
||||
if "a30" in name:
|
||||
return 165e12
|
||||
|
||||
# --- NVIDIA Ada data center ---
|
||||
if "l40s" in name or "l40-s" in name or "l40 s" in name:
|
||||
return 362e12
|
||||
if "l4" in name:
|
||||
return 121e12
|
||||
|
||||
# --- AMD CDNA accelerators ---
|
||||
if "mi355" in name:
|
||||
return 2.5e15
|
||||
if "mi325" in name or "mi300x" in name:
|
||||
return 1.3074e15
|
||||
if "mi300a" in name:
|
||||
return 980.6e12
|
||||
if "mi250x" in name:
|
||||
return 383e12
|
||||
if "mi250" in name:
|
||||
return 362.1e12
|
||||
|
||||
# --- Intel ---
|
||||
if "data center gpu max 1550" in name:
|
||||
# Ponte Vecchio (PVC) - dynamic based on compute units
|
||||
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
|
||||
return 512 * max_comp_units * 1300 * 10**6
|
||||
|
||||
# --- Consumer RTX (for hobbyists) ---
|
||||
if "5090" in name:
|
||||
return 209.5e12
|
||||
if "4090" in name:
|
||||
return 165.2e12
|
||||
if "3090" in name:
|
||||
return 71e12
|
||||
|
||||
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
|
||||
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
|
||||
return float('inf')
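A sketch of how a training loop might consume get_peak_flops for MFU reporting; the import path and the measured-throughput value are assumptions for illustration, not taken from this diff:

import torch
from nanochat.common import get_peak_flops  # assumed location, next to DummyWandb in this hunk

device_name = torch.cuda.get_device_name(0)
peak_flops = get_peak_flops(device_name)
achieved_flops_per_sec = 400e12  # hypothetical measurement: flops/token * tokens/sec
mfu = achieved_flops_per_sec / peak_flops  # dividing by float('inf') gives 0.0 for unknown GPUs
print(f"MFU: {100 * mfu:.1f}%")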
|
||||
|
||||
@@ -1,4 +1,25 @@
|
||||
from collections import deque
|
||||
"""
|
||||
Distributed dataloaders for pretraining.
|
||||
|
||||
Two implementations are provided:
|
||||
|
||||
1. Original (tokenizing_distributed_data_loader):
|
||||
- Streams tokens into a flat buffer, reshapes to (B, T)
|
||||
- Rows may start mid-document (no guaranteed BOS at position 0)
|
||||
- 100% token utilization, simple and efficient
|
||||
|
||||
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
|
||||
- Every row starts with BOS token
|
||||
- Documents packed using best-fit algorithm to minimize cropping
|
||||
- When no document fits remaining space, crops a document to fill exactly
|
||||
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
||||
|
||||
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
||||
there are fewer "confusing" tokens in the train/val batches as every token can
|
||||
now attend back to the BOS token and see the full context of the document.
|
||||
(2) is the new default if you have enough data.
|
||||
Fall back to (1) if you have very limited data AND long documents.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import pyarrow.parquet as pq
|
||||
@@ -6,86 +27,173 @@ import pyarrow.parquet as pq
|
||||
from nanochat.common import get_dist_info
|
||||
from nanochat.dataset import list_parquet_files
|
||||
|
||||
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
||||
"""
|
||||
Infinite iterator over document batches (list of text strings) from parquet files.
|
||||
|
||||
Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
|
||||
where text_batch is a list of document strings, indices track position for resumption,
|
||||
and epoch counts how many times we've cycled through the dataset (starts at 1).
|
||||
"""
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
|
||||
first_pass = True
|
||||
pq_idx = resume_pq_idx
|
||||
epoch = resume_epoch
|
||||
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
pq_idx = resume_pq_idx if first_pass else 0
|
||||
while pq_idx < len(parquet_paths):
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
|
||||
base_idx = resume_rg_idx // ddp_world_size
|
||||
base_idx += 1 # advance by 1 so we don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
if rg_idx >= pf.num_row_groups:
|
||||
pq_idx += 1
|
||||
continue
|
||||
resume_rg_idx = None # only do this once
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist()
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
|
||||
rg_idx += ddp_world_size
|
||||
pq_idx += 1
|
||||
first_pass = False
|
||||
epoch += 1
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
"""
|
||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
||||
|
||||
This implementation became a bit more complex because we wish to support approximate resume training.
|
||||
Instead of turning this into a Class, we opt to return the state_dict with every batch,
|
||||
and then the caller can pass in a state_dict to resume training from a desired point.
|
||||
Note that this resumption is atm only *approximate* for simplicity.
|
||||
We won't repeat the same documents but we might skip a few.
|
||||
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
|
||||
This is the original dataloader that streams tokens into a flat buffer and reshapes.
|
||||
Rows may start mid-document (no guaranteed BOS at position 0).
|
||||
|
||||
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
|
||||
Supports approximate resume via state_dict.
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
# infinite iterator over document batches (list of text strings)
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
def document_batches():
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
first_pass = True
|
||||
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
pq_idx = resume_pq_idx if first_pass else 0
|
||||
while pq_idx < len(parquet_paths): # iterate over all parquet files
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
|
||||
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
|
||||
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
|
||||
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
if rg_idx >= pf.num_row_groups:
|
||||
pq_idx += 1
|
||||
continue
|
||||
resume_rg_idx = None # set to None as we only want to do this a single time
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
|
||||
rg_idx += ddp_world_size # advance to the next row group (in DDP)
|
||||
pq_idx += 1 # advance to the next parquet file
|
||||
first_pass = False
|
||||
batches = document_batches()
|
||||
|
||||
# Now emit batches of tokens.
|
||||
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
needed_tokens = B * T + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
token_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding.
|
||||
|
||||
# Accumulate enough tokens
|
||||
while len(token_buffer) < needed_tokens:
|
||||
doc_batch, (pq_idx, rg_idx) = next(batches)
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
token_buffer.extend(tokens)
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
|
||||
use_cuda_optimizations = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1]
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
|
||||
yield inputs, targets, state_dict
|
||||
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
|
||||
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
|
||||
|
||||
# Package tokens into inputs and targets, yield
|
||||
use_cuda = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
# helper function that only emits the inputs/targets and not the state_dict
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
yield inputs, targets
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
tokenizer, B, T, split,
|
||||
tokenizer_threads=4, tokenizer_batch_size=128,
|
||||
device="cuda", resume_state_dict=None,
|
||||
buffer_size=1000
|
||||
):
|
||||
"""
|
||||
BOS-aligned dataloader with Best-Fit Cropping.
|
||||
|
||||
Reduces token waste compared to simple greedy cropping by searching a buffer
|
||||
for documents that fit well, while maintaining 100% utilization (no padding).
|
||||
|
||||
Algorithm for each row:
|
||||
1. From buffered docs, pick the LARGEST doc that fits entirely
|
||||
2. Repeat until no doc fits
|
||||
3. When nothing fits, crop a doc to fill remaining space exactly
|
||||
|
||||
Key properties:
|
||||
- Every row starts with BOS
|
||||
- 100% utilization (no padding, every token is trained on)
|
||||
- Approximately 35% of all tokens are discarded due to cropping
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
row_capacity = T + 1
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
doc_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal pq_idx, rg_idx, epoch
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(B):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has documents
|
||||
while len(doc_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
row.extend(doc)
|
||||
else:
|
||||
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
||||
doc = doc_buffer.pop(shortest_idx)
|
||||
row.extend(doc[:remaining])
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
use_cuda = device == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
|
||||
yield inputs, targets
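The inner packing loop above can be condensed into a single helper to make the best-fit rule easier to see; this is a paraphrase of the code above for reading purposes, not an additional API in the repo:

def pick_next(doc_buffer, remaining):
    # Prefer the largest buffered document that fits entirely in the remaining row space.
    fits = [i for i, d in enumerate(doc_buffer) if len(d) <= remaining]
    if fits:
        i = max(fits, key=lambda i: len(doc_buffer[i]))
        return doc_buffer.pop(i)                # whole document, nothing discarded
    # Otherwise crop the shortest document to fill the row exactly (its tail tokens are dropped).
    i = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
    return doc_buffer.pop(i)[:remaining]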
|
||||
|
||||
@@ -90,7 +90,7 @@ class KVCache:
|
||||
- Position tracked per batch element via cache_seqlens tensor
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
|
||||
self.batch_size = batch_size
|
||||
self.max_seq_len = seq_len
|
||||
self.n_layers = num_layers
|
||||
@@ -172,6 +172,13 @@ class Engine:
|
||||
"""Same as generate, but does single prefill and then clones the KV cache."""
|
||||
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
|
||||
device = self.model.get_device()
|
||||
# NOTE: setting the dtype here and in this way is an ugly hack.
|
||||
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
|
||||
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
|
||||
# As a quick hack, we're making the generate() function inherit and know about this repo-wide assumption.
|
||||
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
|
||||
# In particular, the KVCache should allocate its tensors lazily
|
||||
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
|
||||
@@ -191,6 +198,7 @@ class Engine:
|
||||
batch_size=1,
|
||||
seq_len=len(tokens),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
||||
@@ -203,6 +211,7 @@ class Engine:
|
||||
batch_size=num_samples,
|
||||
seq_len=kv_length_hint,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
kv_cache_decode.prefill(kv_cache_prefill)
|
||||
@@ -297,8 +306,8 @@ if __name__ == "__main__":
|
||||
"""
|
||||
import time
|
||||
# init compute
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# load the model and tokenizer
|
||||
|
||||
nanochat/flash_attention.py (new file, 178 lines)
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
# Training (no KV cache)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
|
||||
# Inference (with KV cache)
|
||||
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Detection: Try to load FA3 on Hopper+ GPUs
|
||||
# =============================================================================
|
||||
def _load_flash_attention_3():
|
||||
"""Try to load Flash Attention 3 (requires Hopper+ GPU)."""
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 9: # Hopper is sm90
|
||||
return None
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
from kernels import get_kernel
|
||||
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
_fa3 = _load_flash_attention_3()
|
||||
HAS_FA3 = _fa3 is not None
|
||||
|
||||
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
||||
_override_impl = None
|
||||
|
||||
|
||||
def _use_fa3():
|
||||
"""Determine whether to use FA3 based on availability and override."""
|
||||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
return True
|
||||
if _override_impl == 'sdpa':
|
||||
return False
|
||||
return HAS_FA3 # auto
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
"""
|
||||
SDPA attention with sliding window support.
|
||||
q, k, v are (B, H, T, D) format.
|
||||
"""
|
||||
Tq = q.size(2)
|
||||
Tk = k.size(2)
|
||||
window = window_size[0]
|
||||
|
||||
# Full context, same length
|
||||
if (window < 0 or window >= Tq) and Tq == Tk:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
|
||||
# Single token generation
|
||||
if Tq == 1:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
|
||||
# Need explicit mask
|
||||
device = q.device
|
||||
if Tq == Tk:
|
||||
# Causal + sliding window
|
||||
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
|
||||
if window > 0 and window < Tq:
|
||||
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
|
||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
else:
|
||||
# Chunk inference: attend to prefix + causal within chunk
|
||||
prefix_len = Tk - Tq
|
||||
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
|
||||
mask[:, :prefix_len] = True
|
||||
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
|
||||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Public API: Same interface as FA3
|
||||
# =============================================================================
|
||||
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
||||
"""
|
||||
Flash Attention for training (no KV cache).
|
||||
|
||||
Args:
|
||||
q, k, v: Tensors of shape (B, T, H, D)
|
||||
causal: Whether to use causal masking
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
enable_gqa = q.size(1) != k.size(1)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
|
||||
return y.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
|
||||
causal=False, window_size=(-1, -1)):
|
||||
"""
|
||||
Flash Attention with KV cache for inference.
|
||||
|
||||
FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
|
||||
|
||||
Args:
|
||||
q: Queries, shape (B, T_new, H, D)
|
||||
k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
|
||||
k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
|
||||
cache_seqlens: Current position in cache, shape (B,) int32
|
||||
causal: Whether to use causal masking
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
)
|
||||
|
||||
# SDPA fallback: manually manage KV cache
|
||||
B, T_new, H, D = q.shape
|
||||
pos = cache_seqlens[0].item() # assume uniform position across batch
|
||||
|
||||
# Insert new k, v into cache (in-place, matching FA3 behavior)
|
||||
if k is not None and v is not None:
|
||||
k_cache[:, pos:pos+T_new, :, :] = k
|
||||
v_cache[:, pos:pos+T_new, :, :] = v
|
||||
|
||||
# Get full cache up to current position + new tokens
|
||||
end_pos = pos + T_new
|
||||
k_full = k_cache[:, :end_pos, :, :]
|
||||
v_full = v_cache[:, :end_pos, :, :]
|
||||
|
||||
# Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
|
||||
q_sdpa = q.transpose(1, 2)
|
||||
k_sdpa = k_full.transpose(1, 2)
|
||||
v_sdpa = v_full.transpose(1, 2)
|
||||
|
||||
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
|
||||
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
|
||||
|
||||
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Export: flash_attn module interface (drop-in replacement for FA3)
|
||||
# =============================================================================
|
||||
from types import SimpleNamespace
|
||||
flash_attn = SimpleNamespace(
|
||||
flash_attn_func=flash_attn_func,
|
||||
flash_attn_with_kvcache=flash_attn_with_kvcache,
|
||||
)
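A quick smoke test of the SDPA fallback path on a machine without FA3 (shapes follow the (B, T, H, D) convention documented above; assumes a recent PyTorch with the enable_gqa argument):

import torch
from nanochat.flash_attention import flash_attn

B, T, H, D = 2, 16, 4, 32
q = torch.randn(B, T, H, D)
k = torch.randn(B, T, H, D)
v = torch.randn(B, T, H, D)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(8, 0))  # sliding window of 8
assert y.shape == (B, T, H, D)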
|
||||
nanochat/gpt.py (158 lines changed)
@@ -20,21 +20,15 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||
|
||||
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
|
||||
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
|
||||
from kernels import get_kernel
|
||||
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
sequence_len: int = 1024
|
||||
vocab_size: int = 50304
|
||||
sequence_len: int = 2048
|
||||
vocab_size: int = 32768
|
||||
n_layer: int = 12
|
||||
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
@@ -42,7 +36,7 @@ class GPTConfig:
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "L"
|
||||
window_pattern: str = "SSSL"
|
||||
|
||||
|
||||
def norm(x):
|
||||
@@ -50,6 +44,10 @@ def norm(x):
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
|
||||
return layer_idx % 2 == (n_layer - 1) % 2
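Concretely, for a 12-layer model this predicate selects the layers whose parity matches the last layer, i.e. the odd-indexed ones:

n_layer = 12
ve_layers = [i for i in range(n_layer) if i % 2 == (n_layer - 1) % 2]
# -> [1, 3, 5, 7, 9, 11]; the last layer (index 11) is always included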
|
||||
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
@@ -72,8 +70,10 @@ class CausalSelfAttention(nn.Module):
|
||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, cos_sin, window_size, kv_cache):
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Project the input to get queries, keys, and values
|
||||
@@ -82,13 +82,18 @@ class CausalSelfAttention(nn.Module):
|
||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
|
||||
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
||||
if ve is not None:
|
||||
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
||||
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
|
||||
v = v + gate.unsqueeze(-1) * ve
|
||||
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
|
||||
# Attention with Flash Attention 3
|
||||
# FA3 handles GQA automatically when n_kv_heads < n_heads
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
if kv_cache is None:
|
||||
# Training: causal attention with optional sliding window
|
||||
@@ -132,8 +137,8 @@ class Block(nn.Module):
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
@@ -166,6 +171,10 @@ class GPT(nn.Module):
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
@@ -176,6 +185,7 @@ class GPT(nn.Module):
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(self):
|
||||
"""
|
||||
Initialize the full model in this one function for maximum clarity.
|
||||
@@ -207,18 +217,28 @@ class GPT(nn.Module):
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
|
||||
# Per-layer scalars
|
||||
with torch.no_grad():
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
|
||||
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
|
||||
for block in self.transformer.h:
|
||||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
|
||||
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
@@ -283,7 +303,9 @@ class GPT(nn.Module):
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
@@ -296,49 +318,72 @@ class GPT(nn.Module):
|
||||
|
||||
def num_scaling_params(self):
|
||||
"""
|
||||
Return all of the parameters, same as Chinchilla paper.
|
||||
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
|
||||
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
|
||||
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
|
||||
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
|
||||
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
return nparams
|
||||
Return detailed parameter counts for scaling law analysis.
|
||||
Different papers use different conventions:
|
||||
- Kaplan et al. excluded embedding parameters
|
||||
- Chinchilla included all parameters
|
||||
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
|
||||
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
Returns a dict with counts for each parameter group, so downstream analysis
|
||||
can experiment with which combination gives the cleanest scaling laws.
|
||||
"""
|
||||
# Count each group separately (mirrors the grouping in setup_optimizers)
|
||||
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
||||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
'wte': wte,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
'scalars': scalars,
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
|
||||
|
||||
# Separate out all parameters into groups
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
|
||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||
dict(params=x0_params, lr=scalar_lr),
|
||||
|
||||
# Build param_groups with all required fields explicit
|
||||
param_groups = [
|
||||
# AdamW groups (embeddings, lm_head, scalars)
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
]
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
|
||||
MuonFactory = DistMuon if ddp else Muon
|
||||
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
||||
# Combine the two optimizers into one list
|
||||
optimizers = [adamw_optimizer, muon_optimizer]
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizers
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
))
|
||||
|
||||
Factory = DistMuonAdamW if ddp else MuonAdamW
|
||||
optimizer = Factory(param_groups)
|
||||
for group in optimizer.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
@@ -352,12 +397,13 @@ class GPT(nn.Module):
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx)
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x = norm(x)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
|
||||
nanochat/muon.py (295 lines removed)
@@ -1,295 +0,0 @@
|
||||
"""
|
||||
Muon optimizer adapted (simplified) from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
"""
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
|
||||
"""
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Alternative to Newton-Schulz iteration with potentially better convergence properties.
|
||||
"""
|
||||
assert G.ndim >= 2
|
||||
X = G.bfloat16()
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
|
||||
# Ensure spectral norm is at most 1 (with 2% safety factor)
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
|
||||
# Perform the iterations (cap at available coefficients)
|
||||
for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
return X
|
||||
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
||||
"""
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
X = G.bfloat16()
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
|
||||
# Ensure spectral norm is at most 1
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
A = X @ X.mT
|
||||
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
return X
|
||||
|
||||
|
||||
@torch.compile
|
||||
def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
|
||||
"""
|
||||
NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
|
||||
https://arxiv.org/pdf/2510.05491
|
||||
|
||||
Normalizes updates based on a running estimate of per-row (or per-column) variance.
|
||||
The reduction dimension is determined by the shape of second_momentum_buffer.
|
||||
"""
|
||||
# Determine reduction dimension from buffer shape
|
||||
red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
|
||||
|
||||
# Compute per-row/col mean of squared values
|
||||
v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = v.size(red_dim)
|
||||
|
||||
# Compute current norm
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
|
||||
# Update second momentum buffer (EMA of variance)
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
|
||||
# Compute scaling factor from second momentum
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
|
||||
# Final scale preserves overall norm while adjusting per-row/col
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
return v.mul(final_scale.to(v.dtype))
|
||||
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
params: list[Tensor] = [*params]
|
||||
param_groups = []
|
||||
for size in {p.numel() for p in params}:
|
||||
group = dict(params=[p for p in params if p.numel() == size])
|
||||
param_groups.append(group)
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
g = p.grad
|
||||
assert g is not None
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
||||
# Variance reduction (NorMuon-style)
|
||||
if group["beta2"] is not None:
|
||||
if "second_momentum_buffer" not in state:
|
||||
# Buffer shape determines reduction dim: reduce along larger dimension
|
||||
if p.size(-2) >= p.size(-1):
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
||||
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
||||
# Parameter update with cautious weight decay
|
||||
effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
|
||||
wd = group["weight_decay"]
|
||||
if wd != 0:
|
||||
mask = (g * p) >= 0
|
||||
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
||||
else:
|
||||
p.sub_(effective_lr * g)
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
|
||||
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
||||
- reduce_scatter(AVG) for gradient averaging
|
||||
- all_gather to replicate updated weights
|
||||
|
||||
Notes:
|
||||
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
|
||||
params like embeddings or scalars.
|
||||
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
|
||||
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
|
||||
consolidate states beforehand.
|
||||
|
||||
Args:
|
||||
params: iterable of Tensors
|
||||
lr: learning rate
|
||||
momentum: momentum coefficient in [0,1)
|
||||
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
||||
ns_steps: number of Newton-Schulz iterations for the orthogonalization
|
||||
beta2: decay rate for second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
params = list(params)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
if rank == 0:
|
||||
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
|
||||
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# Kick off all the reduce scatter operations to average up the gradients across all ranks
|
||||
all_reduce_futures = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank
|
||||
# each rank stacks up its chunk of world_size params into a list
|
||||
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
|
||||
# pad rs_input with the zero buffer to complete the group
|
||||
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
|
||||
# the output buffer gets strided across the group based on the rank
|
||||
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
|
||||
# reduce scatter the gradients within this group of world_size params
|
||||
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
all_reduce_futures.append(work)
|
||||
|
||||
# Now each rank computes the update and gathers
|
||||
future_idx = 0
|
||||
all_gather_futures = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank # calculate the index of the param that this rank owns
|
||||
# Wait for the reduce scatter to complete
|
||||
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
|
||||
future_idx += 1
|
||||
# Owner computes the Muon update, result is in its param
|
||||
if owner_idx < len(params):
|
||||
p = params[owner_idx]
|
||||
g = p.grad # now averaged across ranks
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1.0 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
||||
# Variance reduction (NorMuon-style)
|
||||
if group["beta2"] is not None:
|
||||
if "second_momentum_buffer" not in state:
|
||||
# Buffer shape determines reduction dim: reduce along larger dimension
|
||||
if p.size(-2) >= p.size(-1):
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
||||
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
||||
# Parameter update with cautious weight decay
|
||||
effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
||||
wd = group["weight_decay"]
|
||||
if wd != 0:
|
||||
mask = (g * p) >= 0
|
||||
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
||||
else:
|
||||
p.sub_(effective_lr * g)
|
||||
# Replicate updated parameters to all ranks
|
||||
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
||||
ag_output = params[base_i:base_i + world_size]
|
||||
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
|
||||
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
|
||||
all_gather_futures.append(work)
|
||||
|
||||
# Wait for all work to finish
|
||||
torch.futures.collect_all(all_gather_futures).wait()
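# A minimal sketch (not part of the optimizer) of the block-cyclic ownership used in
# DistMuon.step above: params are walked in groups of world_size, the param at index
# base_i + rank is updated by `rank`, and short final groups are padded with the zero
# buffer so the reduce_scatter / all_gather group size stays fixed.
def owned_indices(num_params: int, rank: int, world_size: int) -> list[int]:
    owned = []
    for base_i in range(0, num_params, world_size):
        owner_idx = base_i + rank
        if owner_idx < num_params:  # otherwise this rank only contributes/receives padding
            owned.append(owner_idx)
    return owned

# e.g. 10 params across 4 ranks: rank 0 owns [0, 4, 8], rank 3 owns [3, 7] and pads once
assert owned_indices(10, 0, 4) == [0, 4, 8]
assert owned_indices(10, 3, 4) == [3, 7]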
|
||||
nanochat/optim.py (new file, 528 lines)
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
A nice and efficient combined AdamW/Muon optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed training.

Adapted from: https://github.com/KellerJordan/modded-nanogpt
Further contributions from @karpathy and @chrisjmccormick.
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
Good old AdamW optimizer, fused kernel.
|
||||
https://arxiv.org/abs/1711.05101
|
||||
"""
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor, # (32768, 768) - parameter tensor
|
||||
grad: Tensor, # (32768, 768) - gradient, same shape as p
|
||||
exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
|
||||
exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
|
||||
step_t: Tensor, # () - 0-D CPU tensor, step count
|
||||
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
|
||||
beta1_t: Tensor, # () - 0-D CPU tensor, beta1
|
||||
beta2_t: Tensor, # () - 0-D CPU tensor, beta2
|
||||
eps_t: Tensor, # () - 0-D CPU tensor, epsilon
|
||||
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
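# Why the 0-D CPU tensors: torch.compile(dynamic=False) specializes on Python scalars, so a
# changing float (e.g. a decaying learning rate) would trigger a recompile for every new value.
# Mutating a 0-D tensor in place keeps a single compiled graph. A minimal illustration
# (variable names here are only for the sketch):
_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
for _lr in (0.02, 0.01, 0.005):
    _lr_t.fill_(_lr)  # same tensor object every call -> no recompilation
    # adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, _lr_t, beta1_t, beta2_t, eps_t, wd_t)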
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, we use an alternative to the Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in the nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
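# A quick eager sanity check (illustrative, not part of the optimizer): applying the wide-matrix
# iteration with these coefficients should push every singular value of a random matrix into a
# band around 1, rather than exactly to 1, as discussed in the comment block above.
def _polar_express_ref(G: Tensor, steps: int = 5) -> Tensor:
    assert G.size(-2) <= G.size(-1), "this reference handles the wide orientation only"
    X = G / (G.norm() * 1.02 + 1e-6)
    for a, b, c in polar_express_coeffs[:steps]:
        A = X @ X.mT
        X = a * X + (b * A + c * (A @ A)) @ X
    return X

_S = torch.linalg.svdvals(_polar_express_ref(torch.randn(256, 512)))
print(f"singular values after Polar Express: min {_S.min():.2f}, max {_S.max():.2f}")  # both near 1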
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
|
||||
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
|
||||
momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
|
||||
second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
|
||||
momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
|
||||
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
|
||||
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
|
||||
beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
|
||||
ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
|
||||
red_dim: int, # -1 or -2 - reduction dimension for variance
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
if g.size(-2) > g.size(-1): # Tall matrix
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X.mT @ X
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + X @ B
|
||||
else: # Wide matrix (original math)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
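# Cautious weight decay in isolation (illustrative): the decay term is masked so a weight is only
# shrunk where the update being subtracted is already pushing it toward zero (g and p share a sign),
# i.e. weight decay never fights the direction of the optimizer step.
_p = torch.tensor([1.0, -1.0, 1.0])
_g = torch.tensor([0.5, 0.5, -0.5])               # update that will be subtracted
_mask = (_g * _p) >= 0                            # tensor([ True, False, False])
_lr, _wd = 0.1, 0.2
print(_p - (_lr * _g + _lr * _wd * _p * _mask))   # only the first weight is also decayed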
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Single GPU version of the MuonAdamW optimizer.
|
||||
# Used mostly for reference, debugging and testing.
|
||||
|
||||
class MuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
|
||||
|
||||
AdamW - Fused AdamW optimizer step.
|
||||
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
Arguments:
|
||||
param_groups: List of dicts, each containing:
|
||||
- 'params': List of parameters
|
||||
- 'kind': 'adamw' or 'muon'
|
||||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
super().__init__(param_groups, defaults={})
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
# AdamW tensors
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# Muon tensors
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
def _step_adamw(self, group: dict) -> None:
|
||||
"""
|
||||
AdamW update for each param in the group individually.
|
||||
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
|
||||
"""
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
state = self.state[p]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p, grad, exp_avg, exp_avg_sq,
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
|
||||
def _step_muon(self, group: dict) -> None:
|
||||
"""
|
||||
Muon update for all params in the group (stacked for efficiency).
|
||||
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
|
||||
"""
|
||||
params: list[Tensor] = group['params']
|
||||
if not params:
|
||||
return
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
p = params[0]
|
||||
state = self.state[p]
|
||||
num_params = len(params)
|
||||
shape, device, dtype = p.shape, p.device, p.dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Stack grads and params (NOTE: this assumes all params have the same shape)
|
||||
stacked_grads = torch.stack([p.grad for p in params])
|
||||
stacked_params = torch.stack(params)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._muon_momentum_t,
|
||||
self._muon_lr_t,
|
||||
self._muon_wd_t,
|
||||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
if group['kind'] == 'adamw':
|
||||
self._step_adamw(group)
|
||||
elif group['kind'] == 'muon':
|
||||
self._step_muon(group)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
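# A minimal usage sketch of MuonAdamW (illustrative only; the group format follows the class
# docstring above, and the hyperparameter values simply mirror the base_train defaults shown
# elsewhere in this diff). Note that every param in a Muon group must share one shape, since
# the group is stacked before the fused kernel.
_matrices = [torch.nn.Parameter(torch.randn(64, 128)) for _ in range(4)]
_embedding = torch.nn.Parameter(torch.randn(1000, 128))
_opt = MuonAdamW([
    dict(params=_matrices, kind="muon", lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.2),
    dict(params=[_embedding], kind="adamw", lr=0.3, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
])
for _p in _matrices + [_embedding]:
    _p.grad = torch.randn_like(_p)
_opt.step()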
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Distributed version of the MuonAdamW optimizer.
|
||||
# Used for training on multiple GPUs.
|
||||
|
||||
class DistMuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
|
||||
|
||||
See MuonAdamW for the algorithmic details of each optimizer. This class adds
|
||||
distributed communication to enable multi-GPU training without PyTorch DDP.
|
||||
|
||||
Design Goals:
|
||||
- Overlap communication with computation (async ops)
|
||||
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
|
||||
- Batch small tensors into single comm ops where possible
|
||||
|
||||
Communication Pattern (3-phase async):
|
||||
We use a 3-phase structure to maximize overlap between communication and compute:
|
||||
|
||||
Phase 1: Launch all async reduce ops
|
||||
- Kick off all reduce_scatter/all_reduce operations
|
||||
- Don't wait - let them run in background while we continue
|
||||
|
||||
Phase 2: Wait for reduces, compute updates, launch gathers
|
||||
- For each group: wait for its reduce, compute the update, launch gather
|
||||
- By processing groups in order, earlier gathers run while later computes happen
|
||||
|
||||
Phase 3: Wait for gathers, copy back
|
||||
- Wait for all gathers to complete
|
||||
- Copy updated params back to original tensors (Muon only)
|
||||
|
||||
AdamW Communication (ZeRO-2 style):
|
||||
- Small params (<1024 elements): all_reduce gradients, update full param on each rank.
|
||||
Optimizer state is replicated but these params are tiny (scalars, biases).
|
||||
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
|
||||
only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
|
||||
exp_avg_sq) is sharded - each rank only stores state for its slice.
|
||||
Requires param.shape[0] divisible by world_size.
|
||||
|
||||
Muon Communication (stacked + chunked):
|
||||
- All params in a Muon group must have the same shape (caller's responsibility).
|
||||
- Stack all K params into a single (K, *shape) tensor for efficient comm.
|
||||
- Divide K params across N ranks: each rank "owns" ceil(K/N) params.
|
||||
- reduce_scatter the stacked grads so each rank gets its chunk.
|
||||
- Each rank computes Muon update only for params it owns.
|
||||
- all_gather the updated params back to all ranks.
|
||||
- Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
|
||||
- Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
|
||||
then ignore the padding when copying back.
|
||||
|
||||
Buffer Reuse:
|
||||
- For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
|
||||
same buffer as the output for all_gather (stacked_params). This saves memory
|
||||
since we don't need both buffers simultaneously.
|
||||
|
||||
Arguments:
|
||||
param_groups: List of dicts, each containing:
|
||||
- 'params': List of parameters
|
||||
- 'kind': 'adamw' or 'muon'
|
||||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
super().__init__(param_groups, defaults={})
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
|
||||
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
|
||||
param_infos = {}
|
||||
for p in group['params']:
|
||||
grad = p.grad
|
||||
if p.numel() < 1024:
|
||||
# Small params: all_reduce (no scatter/gather needed)
|
||||
future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
|
||||
else:
|
||||
# Large params: reduce_scatter
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
|
||||
return dict(param_infos=param_infos)
|
||||
|
||||
def _reduce_muon(self, group: dict, world_size: int) -> dict:
|
||||
"""Launch async reduce op for Muon group. Returns info dict."""
|
||||
params = group['params']
|
||||
chunk_size = (len(params) + world_size - 1) // world_size
|
||||
padded_num_params = chunk_size * world_size
|
||||
p = params[0]
|
||||
shape, device, dtype = p.shape, p.device, p.dtype
|
||||
|
||||
# Stack grads and zero-pad to padded_num_params
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Reduce_scatter to get this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
|
||||
return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
|
||||
|
||||
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
|
||||
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
|
||||
param_infos = info['param_infos']
|
||||
for p in group['params']:
|
||||
pinfo = param_infos[p]
|
||||
pinfo['future'].wait()
|
||||
grad_slice = pinfo['grad_slice']
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if pinfo['is_small']:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors and run fused kernel
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
adamw_step_fused(
|
||||
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
|
||||
# Large params need all_gather
|
||||
if not pinfo['is_small']:
|
||||
future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
|
||||
gather_list.append(dict(future=future, params=None))
|
||||
|
||||
def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
|
||||
"""Wait for reduce, compute Muon updates, launch gather."""
|
||||
info['future'].wait()
|
||||
params = group['params']
|
||||
chunk_size = info['chunk_size']
|
||||
grad_chunk = info['grad_chunk']
|
||||
p = params[0]
|
||||
shape, device, dtype = p.shape, p.device, p.dtype
|
||||
|
||||
# How many params does this rank own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build output buffer for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned = torch.stack(owned_params)
|
||||
|
||||
# Fill 0-D tensors and run fused kernel
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"])
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
muon_step_fused(
|
||||
grad_chunk[:num_owned], stacked_owned,
|
||||
state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
|
||||
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
|
||||
group["ns_steps"], red_dim,
|
||||
)
|
||||
updated_params[:num_owned].copy_(stacked_owned)
|
||||
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
|
||||
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
|
||||
|
||||
def _finish_gathers(self, gather_list: list) -> None:
|
||||
"""Wait for all gathers and copy Muon params back."""
|
||||
for info in gather_list:
|
||||
info["future"].wait()
|
||||
if info["params"] is not None:
|
||||
# Muon: copy from stacked buffer back to individual params
|
||||
torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Phase 1: launch all async reduce ops
|
||||
reduce_infos: list[dict] = []
|
||||
for group in self.param_groups:
|
||||
if group['kind'] == 'adamw':
|
||||
reduce_infos.append(self._reduce_adamw(group, world_size))
|
||||
elif group['kind'] == 'muon':
|
||||
reduce_infos.append(self._reduce_muon(group, world_size))
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
# Phase 2: wait for reduces, compute updates, launch gathers
|
||||
gather_list: list[dict] = []
|
||||
for group, info in zip(self.param_groups, reduce_infos):
|
||||
if group['kind'] == 'adamw':
|
||||
self._compute_adamw(group, info, gather_list, rank, world_size)
|
||||
elif group['kind'] == 'muon':
|
||||
self._compute_muon(group, info, gather_list, rank)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
# Phase 3: wait for gathers, copy back
|
||||
self._finish_gathers(gather_list)
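# A small sketch of the chunked ownership + padding used above (contrast with the older
# block-cyclic DistMuon): rank r owns the contiguous slice [r * chunk_size, (r + 1) * chunk_size)
# of the stacked params, and the stack is zero-padded to chunk_size * world_size so that
# reduce_scatter_tensor / all_gather_into_tensor see equal-sized chunks on every rank.
# The helper below is illustrative, not part of the optimizer.
def _chunk_bounds(num_params: int, rank: int, world_size: int) -> tuple[int, int]:
    chunk_size = (num_params + world_size - 1) // world_size  # ceil(K / N)
    start = rank * chunk_size
    num_owned = min(chunk_size, max(0, num_params - start))
    return start, start + num_owned

# e.g. 10 params across 4 ranks -> chunk_size 3: ranks own [0,3), [3,6), [6,9), [9,10),
# and the stacked comm buffer is padded from 10 to 12 entries.
assert [_chunk_bounds(10, r, 4) for r in range(4)] == [(0, 3), (3, 6), (6, 9), (9, 10)]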
|
||||
@@ -26,7 +26,7 @@ SPECIAL_TOKENS = [
|
||||
|
||||
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
|
||||
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
|
||||
# I haven't validated that this is actually a good idea, TODO.
|
||||
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
|
||||
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
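# A quick way to see the effect of \p{N}{1,2} (illustrative; assumes the third-party `regex`
# package, since the pattern uses \p{...} classes and possessive quantifiers that the stdlib
# `re` module does not fully support): digit runs are pretokenized in chunks of at most two.
import regex
print(regex.findall(SPLIT_PATTERN, "It cost 2048 dollars"))
# -> ['It', ' cost', ' ', '20', '48', ' dollars']   (with the GPT-4 style {1,3}: '2048' -> '204', '8')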
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -23,6 +23,7 @@ dependencies = [
|
||||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
"zstandard>=0.25.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -17,9 +17,10 @@ if [ -z "$SKIP_SETUP" ]; then
|
||||
uv sync --extra gpu
|
||||
source .venv/bin/activate
|
||||
|
||||
# Tokenizer
|
||||
python -m nanochat.dataset -n 240
|
||||
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768
|
||||
# Tokenizer, download 1000 shards for pretraining
|
||||
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
|
||||
python -m nanochat.dataset -n 1000
|
||||
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
|
||||
else
|
||||
source .venv/bin/activate
|
||||
fi
|
||||
@@ -57,16 +58,15 @@ for d in "${DEPTHS[@]}"; do
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
# Train the model with natural horizon (target_param_data_ratio default)
|
||||
# No --target_flops, let it use the default ratio from base_train
|
||||
# No --target-flops, let it use the default ratio from base_train
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--target_param_data_ratio=8 \
|
||||
--run="${WANDB_RUN}_d${d}" \
|
||||
--model_tag="${TAG}" \
|
||||
--core_metric_every=999999 \
|
||||
--core_metric_max_per_task=-1 \
|
||||
--sample_every=-1 \
|
||||
--save_every=-1 \
|
||||
--model-tag="${TAG}" \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
@@ -20,18 +20,18 @@ curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-publ
|
||||
|
||||
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
|
||||
python -m nanochat.dataset -n 16
|
||||
# start downloading the rest of the shards for a total of 800 (see below why 800)
|
||||
python -m nanochat.dataset -n 800 &
|
||||
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
|
||||
python -m nanochat.dataset -n 1200 &
|
||||
# todo: download the rest of it
|
||||
python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536
|
||||
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# Documenting my process for determining the hyperparameters for this run1000.sh script:
|
||||
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
|
||||
# 1) I guessed the model size for this to be about depth=32
|
||||
# 2) Determine the device_batch_size that fits:
|
||||
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
|
||||
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
|
||||
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
|
||||
# So the training script was running ok and showed:
|
||||
# Vocab size: 65,536
|
||||
@@ -62,7 +62,9 @@ python -m scripts.tok_eval
|
||||
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
|
||||
# So ~38B tokens * ~4.8 chars/token = ~185B chars.
|
||||
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
|
||||
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
|
||||
# For safety, I bumped that up to 800 shards.
|
||||
# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
|
||||
# => why up above I used -n 1200 when pre-downloading dataset shards.
|
||||
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
|
||||
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
|
||||
# start to overfit hard.
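# Quick sanity check of the shard arithmetic above (approximate figures from the comments;
# purely illustrative, not part of the run):
python -c "print(38e9 * 4.8 / 250e6, 800 / (1 - 0.35))"   # ~730 and ~1230 -> rounded to 800 and 1200 above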
|
||||
@@ -71,13 +73,13 @@ python -m scripts.tok_eval
|
||||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# midtrain
|
||||
# NOTE: ensure that we use the same device_batch_size here as the base training script.
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# sft
|
||||
runs/runcpu.sh (new executable file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
|
||||
# This script was last updated/tuned on Jan 17, 2026.
|
||||
|
||||
# Run as:
# bash runs/runcpu.sh

# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your MacBook.
# Think of this run as an educational/fun demo, not something you should expect to work well.
# (This is why I keep this script off to the side in runs/.)
# You may also want to run this script manually, one command at a time, copy-pasting commands into your terminal.
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra cpu
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
|
||||
python -m nanochat.dataset -n 8
|
||||
python -m scripts.tok_train --max-chars=2000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a small 4 layer model
|
||||
# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
|
||||
# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
|
||||
python -m scripts.base_train \
|
||||
--depth=6 \
|
||||
--head-dim=64 \
|
||||
--window-pattern=L \
|
||||
--max-seq-len=512 \
|
||||
--device-batch-size=32 \
|
||||
--total-batch-size=16384 \
|
||||
--eval-every=100 \
|
||||
--eval-tokens=524288 \
|
||||
--core-metric-every=-1 \
|
||||
--sample-every=100 \
|
||||
--num-iterations=5000 \
|
||||
--run=$WANDB_RUN
|
||||
python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384
|
||||
python -m scripts.base_eval --max-per-task=16
|
||||
|
||||
# midtraining (~10 minutes on my MacBook Pro M3 Max)
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
python -m scripts.mid_train \
|
||||
--max-seq-len=512 \
|
||||
--device-batch-size=32 \
|
||||
--total-batch-size=16384 \
|
||||
--eval-every=200 \
|
||||
--eval-tokens=524288 \
|
||||
--num-iterations=1500 \
|
||||
--run=$WANDB_RUN
|
||||
|
||||
# (it's ~ok to skip SFT)
|
||||
|
||||
# Chat with the model over CLI
|
||||
# The model should be able to say that it is Paris.
|
||||
# It might even know that the color of the sky is blue.
|
||||
# Sometimes the model likes it if you first say Hi before you ask it questions.
|
||||
# python -m scripts.chat_cli -i mid -p "What is the capital of France?"
|
||||
|
||||
# Chat with the model over a pretty WebUI ChatGPT style
|
||||
# python -m scripts.chat_web -i mid
|
||||
@@ -1,26 +1,30 @@
|
||||
#!/bin/bash
|
||||
|
||||
LABEL="jan26"
|
||||
|
||||
FLOPS_BUDGETS=(
|
||||
1e18
|
||||
3e18
|
||||
6e18
|
||||
2.15e18
|
||||
4.64e18
|
||||
1e19
|
||||
)
|
||||
DEPTHS=(8 10 12 14 16 18 20)
|
||||
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
WANDB_RUN="${WANDB_RUN:-scaling}"
|
||||
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
|
||||
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
|
||||
source .venv/bin/activate
|
||||
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}"
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
# Write CSV header only if file doesn't exist
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
fi
|
||||
|
||||
log() {
|
||||
@@ -64,15 +68,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
||||
# CORE eval happens once at the end (999999 ensures only final step)
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--target_flops=$flops \
|
||||
--target_param_data_ratio=-1 \
|
||||
--target-flops=$flops \
|
||||
--target-param-data-ratio=-1 \
|
||||
--run="${WANDB_RUN}_${TAG}" \
|
||||
--model_tag="${TAG}" \
|
||||
--eval_tokens=$EVAL_TOKENS \
|
||||
--core_metric_every=999999 \
|
||||
--core_metric_max_per_task=-1 \
|
||||
--sample_every=-1 \
|
||||
--save_every=-1 \
|
||||
--model-tag="${TAG}" \
|
||||
--eval-tokens=$EVAL_TOKENS \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
@@ -80,13 +84,19 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
||||
|
||||
# Extract training stats from the log
|
||||
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
|
||||
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
|
||||
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
|
||||
|
||||
# Extract detailed parameter counts (for scaling law analysis with different conventions)
|
||||
PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
|
||||
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
||||
# Calculate tokens trained (iterations * batch_size, default 524288)
|
||||
TOKENS_TRAINED=$((NUM_ITERS * 524288))
|
||||
# Param:data ratio (using scaling params per Kaplan et al.)
|
||||
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
|
||||
# Model dim
|
||||
MODEL_DIM=$((d * 64))
|
||||
# Val BPB from final eval
|
||||
@@ -99,10 +109,10 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
||||
CORE_SCORE="0.0"
|
||||
fi
|
||||
|
||||
log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
|
||||
log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
|
||||
|
||||
# Append to CSV
|
||||
echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
@@ -55,11 +55,11 @@ python -m nanochat.report reset
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
python -m nanochat.dataset -n 240 &
|
||||
# See comment below for why 370 is the right number here
|
||||
python -m nanochat.dataset -n 370 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
|
||||
# train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
|
||||
python -m scripts.tok_train
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
|
||||
@@ -70,7 +70,9 @@ python -m scripts.tok_eval
|
||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
|
||||
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
|
||||
# so 240 / (1 - 0.35) = 370 shards are needed.
|
||||
# At ~100MB/shard, this downloads ~37GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
@@ -79,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
@@ -7,14 +7,14 @@ Example run as:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
|
||||
To evaluate a HuggingFace model:
|
||||
python -m scripts.base_loss --hf_path openai-community/gpt2
|
||||
python -m scripts.base_loss --hf-path openai-community/gpt2
|
||||
"""
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
|
||||
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
@@ -61,12 +61,12 @@ def get_hf_token_bytes(tokenizer, device="cpu"):
|
||||
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
|
||||
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load")
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
@@ -97,14 +97,14 @@ assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible
|
||||
steps = args.split_tokens // tokens_per_step
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
|
||||
|
||||
# Master process also samples from the model (only for nanochat models)
|
||||
# Master process also samples from the model for some basic knowledge-eliciting prompts (only for nanochat models)
|
||||
samples = []
|
||||
if ddp_rank == 0 and args.hf_path is None:
|
||||
prompts = [
|
||||
@@ -122,9 +122,23 @@ if ddp_rank == 0 and args.hf_path is None:
|
||||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample_str = tokenizer.decode(sample[0])
|
||||
print0("-" * 80)
|
||||
print0(sample_str)
|
||||
samples.append(sample_str)
|
||||
|
||||
# Draw some unconditioned samples from the model (only for nanochat models)
|
||||
unconditioned_samples = []
|
||||
if ddp_rank == 0 and args.hf_path is None:
|
||||
engine = Engine(model, tokenizer)
|
||||
tokens = tokenizer("", prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
samples, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
|
||||
for sample in samples:
|
||||
sample_str = tokenizer.decode(sample)
|
||||
print0("-" * 80)
|
||||
print0(sample_str)
|
||||
unconditioned_samples.append(sample_str)
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model loss", data=[
|
||||
@@ -134,6 +148,7 @@ get_report().log(section="Base model loss", data=[
|
||||
"val bpb": bpb_results["val"],
|
||||
},
|
||||
{f"sample {i}": sample for i, sample in enumerate(samples)},
|
||||
{f"unconditioned sample {i}": sample for i, sample in enumerate(unconditioned_samples)},
|
||||
])
|
||||
|
||||
# Cleanup
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""
|
||||
Train model. From root directory of the project, run as:
|
||||
|
||||
python -m scripts.base_train.py
|
||||
python -m scripts.base_train
|
||||
|
||||
or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 -m scripts.base_train.py
|
||||
torchrun --nproc_per_node=8 -m scripts.base_train
|
||||
|
||||
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
||||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
|
||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -21,12 +21,13 @@ import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from scripts.base_eval import evaluate_model
|
||||
print_banner()
|
||||
|
||||
@@ -36,40 +37,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
# Optimization
|
||||
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -81,11 +82,29 @@ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
gpu_device_name = torch.cuda.get_device_name(0)
|
||||
gpu_peak_flops = get_peak_flops(gpu_device_name)
|
||||
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||
else:
|
||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if HAS_FA3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
print0("!" * 80)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
@@ -93,21 +112,19 @@ vocab_size = tokenizer.get_vocab_size()
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
|
||||
num_layers = args.depth
|
||||
model_dim = args.depth * args.aspect_ratio
|
||||
def find_num_heads(model_dim, target_head_dim):
|
||||
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
|
||||
ideal = max(1, round(model_dim / target_head_dim))
|
||||
for offset in range(model_dim):
|
||||
for candidate in [ideal + offset, ideal - offset]:
|
||||
if candidate > 0 and model_dim % candidate == 0:
|
||||
return candidate
|
||||
return 1
|
||||
num_heads = find_num_heads(model_dim, args.head_dim)
|
||||
base_dim = args.depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
head_dim = model_dim // num_heads
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"head_dim: {head_dim}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
@@ -161,9 +178,14 @@ if resuming:
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving the raw model state_dict and for inference/evaluation (where input shapes may change)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
num_scaling_params = orig_model.num_scaling_params()
|
||||
print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
|
||||
|
||||
# Detailed parameter counts
|
||||
param_counts = orig_model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print0(f"{key:24s}: {value:,}")
|
||||
num_params = param_counts['total']
|
||||
num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
@@ -178,20 +200,20 @@ elif args.target_flops > 0:
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif args.target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
|
||||
target_tokens = args.target_param_data_ratio * num_scaling_params
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
|
||||
num_iterations = target_tokens // args.total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = args.total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
@@ -199,19 +221,16 @@ optimizers = model.setup_optimizers(
|
||||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
for opt, dat in zip(optimizers, optimizer_data):
|
||||
opt.load_state_dict(dat)
|
||||
del optimizer_data # free up the memory
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -323,7 +342,7 @@ while True:
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
optimizer.state_dict(), # optimizer state
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
@@ -357,33 +376,31 @@ while True:
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizers
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
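# Minimal sketch of the per-group scheduling pattern above (a toy SGD, not the
# nanochat MuonAdamW; the 'kind' key is simply an extra field stored on each
# param group, which PyTorch optimizers preserve):
import torch
_toy_opt = torch.optim.SGD([
    {"params": [torch.nn.Parameter(torch.zeros(4, 4))], "lr": 0.02, "initial_lr": 0.02, "kind": "muon", "momentum": 0.9, "weight_decay": 0.0},
    {"params": [torch.nn.Parameter(torch.zeros(8))], "lr": 0.30, "initial_lr": 0.30, "kind": "adamw", "momentum": 0.0, "weight_decay": 0.0},
])
for _g in _toy_opt.param_groups:
    _g["lr"] = _g["initial_lr"] * 0.5   # the lr multiplier applies to every group
    if _g["kind"] == "muon":
        _g["momentum"] = 0.95           # momentum/weight decay schedules only touch Muon groups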
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
# logging (CPU action only)
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
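# Why the debiasing works (toy check, assuming the EMA starts at 0.0): with a
# constant loss of 2.0 the raw EMA starts at 0.2, but dividing by 1 - beta**(t+1)
# recovers 2.0 from the very first step.
_ema, _beta = 0.0, 0.9
for _t in range(3):
    _ema = _beta * _ema + (1 - _beta) * 2.0
    assert abs(_ema / (1 - _beta ** (_t + 1)) - 2.0) < 1e-9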
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
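# Quick numeric check of the MFU formula (illustrative numbers): 3.2e14 achieved
# FLOPs/s on one GPU with a 989e12 BF16 peak is about 32.4% MFU.
assert abs(100 * 3.2e14 / (989e12 * 1) - 32.36) < 0.01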
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
@@ -395,7 +412,8 @@ while True:
|
||||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
epoch = dataloader_state_dict["epoch"]
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
@@ -406,6 +424,7 @@ while True:
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
|
||||
@@ -427,7 +446,7 @@ get_report().log(section="Base model training", data=[
|
||||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
|
||||
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"warmdown_ratio": args.warmdown_ratio,
|
||||
|
||||
@@ -4,8 +4,8 @@ All the generic code lives here, and all the evaluation-specific
|
||||
code lives in nanochat directory and is imported from here.
|
||||
|
||||
Example runs:
|
||||
python -m scripts.chat_eval -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
python -m scripts.chat_eval -i mid -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i mid -a ARC-Easy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
@@ -6,7 +6,7 @@ simpler and more similar to just REINFORCE:
|
||||
|
||||
1) Delete trust region, so there is no KL regularization to a reference model
|
||||
2) We are on policy, so there's no need for PPO ratio+clip.
|
||||
3) We use GAPO style normalization that is token-level, not sequence-level.
|
||||
3) We use DAPO style normalization that is token-level, not sequence-level.
|
||||
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
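A minimal sketch of points 3) and 4) (toy tensors, not the script's variables): each
sample's advantage is its reward minus the mean reward of its group of samples for
the same question, and there is no division by the standard deviation:

    import torch
    rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])  # hypothetical per-sample rewards
    advantages = rewards - rewards.mean()          # (r - mu); a z-score would also divide by the std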
|
||||
|
||||
1 GPU:
|
||||
@@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
|
||||
# Batch sizes / sampling
|
||||
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
|
||||
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
|
||||
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
|
||||
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
|
||||
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
|
||||
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
|
||||
# Generation
|
||||
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
|
||||
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
|
||||
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
|
||||
# Evaluation / checkpointing
|
||||
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
|
||||
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
|
||||
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
|
||||
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
|
||||
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
|
||||
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -201,7 +201,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
||||
# Training loop
|
||||
|
||||
# Init the optimizer
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
@@ -209,10 +209,9 @@ optimizers = model.setup_optimizers(
|
||||
)
|
||||
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
def get_lr_multiplier(it):
|
||||
@@ -305,11 +304,9 @@ for step in range(num_steps):
|
||||
|
||||
# Update the model parameters
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers: # first set the learning rate
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for opt in optimizers: # then step the optimizers
|
||||
opt.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
||||
@@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
|
||||
parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
|
||||
parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
|
||||
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
|
||||
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
|
||||
parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
|
||||
parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
|
||||
parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
|
||||
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
|
||||
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
|
||||
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
|
||||
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -150,17 +150,16 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_bat
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
@@ -230,13 +229,11 @@ for step in range(num_iterations):
|
||||
|
||||
# learning rate scheduler
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
|
||||
# step the optimizers
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
# step the optimizer
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# logging
|
||||
|
||||
@@ -6,11 +6,10 @@ python -m scripts.mid_train
|
||||
|
||||
Or torchrun for training:
|
||||
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
@@ -37,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
# Output
|
||||
parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -80,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
@@ -94,14 +93,12 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
@@ -125,49 +122,100 @@ val_dataset = TaskMixture([
|
||||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the training dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
def mid_data_generator(split):
|
||||
global last_step, approx_progress
|
||||
current_epoch = 1 # track epoch for logging
|
||||
def mid_data_generator_bos_bestfit(split, buffer_size=100):
|
||||
"""
|
||||
BOS-aligned dataloader for midtraining with bestfit-crop packing.
|
||||
|
||||
Each row in the batch starts with BOS (beginning of a conversation).
|
||||
Conversations are packed with a best-fit algorithm to minimize cropping.
|
||||
This matches the BOS-aligned approach used in pretraining.
|
||||
"""
|
||||
global last_step, approx_progress, current_epoch
|
||||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
needed_tokens = args.device_batch_size * args.max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||
|
||||
# Conversation buffer: list of token lists
|
||||
conv_buffer = []
|
||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||
epoch = 1
|
||||
it = 0 # iteration counter
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal cursor, epoch
|
||||
while len(conv_buffer) < buffer_size:
|
||||
conversation = dataset[cursor]
|
||||
ids, _ = tokenizer.render_conversation(conversation)
|
||||
token_buffer.extend(ids)
|
||||
conv_buffer.append(ids)
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
cursor = cursor % dataset_size
|
||||
epoch += 1
|
||||
# Note: last_step is now triggered based on consumption, not fetching
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(args.device_batch_size):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has conversations
|
||||
while len(conv_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest conversation that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, conv in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len <= remaining and conv_len > best_len:
|
||||
best_idx = i
|
||||
best_len = conv_len
|
||||
|
||||
if best_idx >= 0:
|
||||
# Found a conversation that fits - use it entirely
|
||||
conv = conv_buffer.pop(best_idx)
|
||||
row.extend(conv)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - crop first conversation to fill remaining
|
||||
conv = conv_buffer.pop(0)
|
||||
row.extend(conv[:remaining])
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if 0 < args.num_iterations <= it and split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
inputs = inputs_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
last_step = True
|
||||
|
||||
# Update progress tracking (based on consumed, not cursor, to account for buffering)
|
||||
if split == "train":
|
||||
current_epoch = epoch
|
||||
if args.num_iterations > 0:
|
||||
approx_progress = it / args.num_iterations # calculate progress from the max number of iterations
|
||||
approx_progress = it / args.num_iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
approx_progress = consumed / dataset_size
|
||||
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
|
||||
if consumed >= dataset_size:
|
||||
last_step = True
|
||||
|
||||
# Build tensors
|
||||
use_cuda = device_type == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
build_val_loader = lambda: mid_data_generator("val")
|
||||
train_loader = mid_data_generator_bos_bestfit("train")
|
||||
build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
# Learning rate scheduler
|
||||
@@ -199,7 +247,7 @@ while True:
|
||||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
|
||||
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
@@ -224,7 +272,7 @@ while True:
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
optimizer.state_dict(),
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
@@ -256,16 +304,14 @@ while True:
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizers
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(progress)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
@@ -285,7 +331,7 @@ while True:
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
@@ -296,6 +342,7 @@ while True:
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": current_epoch,
|
||||
})
|
||||
|
||||
# print a few more stats
|
||||
|
||||
@@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
|
||||
# Parse command line arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
||||
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
|
||||
parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 2B)')
|
||||
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
|
||||
@@ -25,7 +25,7 @@ class CustomJSON(Task):
|
||||
print("-" * 80)
|
||||
print(f"Warning: File {filepath} does not exist")
|
||||
print("HINT (Oct 21 2025)")
|
||||
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
|
||||
print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
|
||||
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
|
||||
print("Quick fix: simply run the following command to download the file and you're done:")
|
||||
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
|
||||
|
||||
tests/test_attention_fallback.py (new file, 338 lines)
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
|
||||
|
||||
Run: python -m pytest tests/test_attention_fallback.py -v -s
|
||||
|
||||
Note on test structure:
|
||||
Tests are split into two classes due to dtype/device constraints:
|
||||
|
||||
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
|
||||
and verify they produce identical results. These require a Hopper GPU (FA3 only
|
||||
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
|
||||
|
||||
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
|
||||
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
|
||||
"""
|
||||
import torch
|
||||
import pytest
|
||||
import nanochat.flash_attention as fa_module
|
||||
from nanochat.flash_attention import flash_attn, HAS_FA3
|
||||
from nanochat.engine import KVCache
|
||||
|
||||
|
||||
def set_impl(impl):
|
||||
"""Set the implementation override ('fa3', 'sdpa', or None for auto)."""
|
||||
fa_module._override_impl = impl
|
||||
|
||||
|
||||
def run_both_impls(fn):
|
||||
"""Run a function with both FA3 and SDPA, return both outputs."""
|
||||
set_impl('fa3')
|
||||
out_fa3 = fn()
|
||||
set_impl('sdpa')
|
||||
out_sdpa = fn()
|
||||
set_impl(None) # reset
|
||||
return out_fa3, out_sdpa
|
||||
|
||||
|
||||
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
|
||||
"""Assert two tensors are close, with helpful error message."""
|
||||
max_diff = (t1 - t2).abs().max().item()
|
||||
mean_diff = (t1 - t2).abs().mean().item()
|
||||
assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \
|
||||
f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
|
||||
return max_diff, mean_diff
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FA3 vs SDPA comparison tests (require Hopper GPU)
|
||||
# =============================================================================
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
|
||||
class TestFA3VsSDPA:
|
||||
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
def test_basic_causal(self):
|
||||
"""Basic causal attention."""
|
||||
B, T, H, D = 2, 64, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
|
||||
print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_full_context(self):
|
||||
"""Full context (window_size=-1)."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
|
||||
print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_sliding_window(self):
|
||||
"""Sliding window attention."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
window = 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
|
||||
print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_gqa(self):
|
||||
"""Group Query Attention (fewer KV heads than Q heads)."""
|
||||
B, T, D = 2, 64, 32
|
||||
n_heads = 8
|
||||
n_kv_heads = 2
|
||||
|
||||
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
|
||||
print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_larger_model(self):
|
||||
"""Larger dimensions closer to real model."""
|
||||
B, T, H, D = 4, 256, 12, 64
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
|
||||
print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_kvcache_prefill(self):
|
||||
"""Test prefill (inserting multiple tokens into empty cache)."""
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
T_prefill = 16
|
||||
|
||||
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
|
||||
return flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
|
||||
print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_kvcache_single_token(self):
|
||||
"""Test single token generation (cache already has content)."""
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
T_prefill = 16
|
||||
|
||||
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_cache[:, :T_prefill, :, :] = k_init
|
||||
v_cache[:, :T_prefill, :, :] = v_init
|
||||
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
|
||||
return flash_attn.flash_attn_with_kvcache(
|
||||
q_single, k_cache, v_cache, k=k_single, v=v_single,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
|
||||
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_backward_gradients_match(self):
|
||||
"""Verify gradients are similar between FA3 and SDPA."""
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
|
||||
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
q = q_data.clone().requires_grad_(True)
|
||||
k = k_data.clone().requires_grad_(True)
|
||||
v = v_data.clone().requires_grad_(True)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
|
||||
|
||||
set_impl('fa3')
|
||||
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
|
||||
set_impl('sdpa')
|
||||
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
|
||||
set_impl(None)
|
||||
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
|
||||
print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
|
||||
print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
|
||||
print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
|
||||
print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
|
||||
# =============================================================================
# SDPA-only tests (run on any device)
# =============================================================================
class TestSDPAOnly:
    """Test SDPA fallback works correctly. Runs on any device."""

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    def test_basic_forward(self):
        """Test SDPA forward pass produces valid output."""
        set_impl('sdpa')
        B, T, H, D = 2, 64, 4, 32
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)

        y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))

        assert y.shape == (B, T, H, D)
        assert not torch.isnan(y).any(), "Output contains NaN"
        set_impl(None)

    def test_backward(self):
        """Test gradients flow through SDPA."""
        set_impl('sdpa')
        B, T, H, D = 2, 32, 4, 16
        q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
        k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
        v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)

        y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
        loss = y.sum()
        loss.backward()

        assert q.grad is not None, "No gradient for q"
        assert k.grad is not None, "No gradient for k"
        assert v.grad is not None, "No gradient for v"
        assert not torch.isnan(q.grad).any(), "NaN in q gradient"
        set_impl(None)

    def test_kvcache(self):
        """Test SDPA with KV cache."""
        set_impl('sdpa')
        B, T_max, H, D = 2, 64, 4, 32
        n_layers = 1

        cache = KVCache(
            batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
            num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
        )
        k_cache, v_cache = cache.get_layer_cache(0)

        # Prefill
        T_prefill = 16
        q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)

        y = flash_attn.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v,
            cache_seqlens=cache.cache_seqlens,
            causal=True, window_size=(T_max, 0)
        )
        cache.advance(T_prefill)

        assert y.shape == (B, T_prefill, H, D)
        assert cache.get_pos() == T_prefill

        # Generate single token
        q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)

        y_single = flash_attn.flash_attn_with_kvcache(
            q_single, k_cache, v_cache, k=k_single, v=v_single,
            cache_seqlens=cache.cache_seqlens,
            causal=True, window_size=(T_max, 0)
        )
        cache.advance(1)

        assert y_single.shape == (B, 1, H, D)
        assert cache.get_pos() == T_prefill + 1
        set_impl(None)


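# test_kvcache above (and the engine tests further down) exercise the KVCache
# class from nanochat's engine. The sketch below is inferred purely from the
# calls made in these tests -- per-layer (B, T_max, H, D) key/value buffers, an
# int32 cache_seqlens vector, and advance()/get_pos() bookkeeping. Treat it as
# an illustration of the interface, not the actual nanochat implementation.

class _KVCacheSketch:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        shape = (num_layers, batch_size, seq_len, num_heads, head_dim)
        self.k_cache = torch.zeros(shape, device=device, dtype=dtype)
        self.v_cache = torch.zeros(shape, device=device, dtype=dtype)
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

    def get_layer_cache(self, layer_idx):
        # Views into the per-layer buffers, shaped (B, T_max, H, D) as
        # flash_attn_with_kvcache expects.
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def advance(self, num_tokens):
        self.cache_seqlens += num_tokens

    def get_pos(self):
        return int(self.cache_seqlens[0].item())
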
# =============================================================================
# Override mechanism tests
# =============================================================================
class TestOverrideMechanism:
    """Test that the override mechanism works correctly."""

    @pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
    def test_override_fa3(self):
        """Test that override='fa3' uses FA3."""
        set_impl('fa3')
        assert fa_module._use_fa3() == True
        set_impl(None)

    def test_override_sdpa(self):
        """Test that override='sdpa' uses SDPA."""
        set_impl('sdpa')
        assert fa_module._use_fa3() == False
        set_impl(None)

    def test_override_auto(self):
        """Test that override=None uses auto-detection."""
        set_impl(None)
        assert fa_module._use_fa3() == HAS_FA3


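# TestOverrideMechanism pins down the contract of the override hook in the flash
# attention adapter module (referenced above as fa_module): set_impl() forces an
# implementation, and _use_fa3() consults that override before falling back to
# availability detection. A minimal sketch of that contract; the real nanochat
# module may structure this differently.

_impl_override = None  # None (auto), 'fa3', or 'sdpa'

def _sketch_set_impl(impl):
    global _impl_override
    assert impl in (None, 'fa3', 'sdpa')
    _impl_override = impl

def _sketch_use_fa3():
    if _impl_override is not None:
        return _impl_override == 'fa3'
    return HAS_FA3  # auto-detect: use FA3 whenever the kernel imported successfully
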
if __name__ == "__main__":
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name()}")
        major, minor = torch.cuda.get_device_capability()
        print(f"Compute capability: {major}.{minor}")
    print(f"HAS_FA3: {HAS_FA3}")
    print()

    pytest.main([__file__, "-v", "-s"])

@@ -96,6 +96,7 @@ def test_kv_cache_basic():
         head_dim=head_dim,
         num_layers=num_layers,
         device="cpu",
+        dtype=torch.float32,
     )

     # Check initial state
@@ -130,7 +131,7 @@ def test_kv_cache_prefill():
     # Create source cache and advance it
     src_cache = KVCache(
         batch_size=batch_size, num_heads=num_heads, seq_len=32,
-        head_dim=head_dim, num_layers=num_layers, device="cpu",
+        head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
     )
     # Write some data to source cache
     src_cache.k_cache[0, 0, :16, :, :] = 1.0
@@ -140,7 +141,7 @@ def test_kv_cache_prefill():
     # Create destination cache with larger seq_len
     dst_cache = KVCache(
         batch_size=batch_size, num_heads=num_heads, seq_len=64,
-        head_dim=head_dim, num_layers=num_layers, device="cpu",
+        head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
     )

     # Prefill
@@ -195,3 +196,72 @@ def test_multi_sample_first_token_diversity():
         f"With uniform logits, this is statistically impossible (~10^-36 probability) "
         f"unless tokens are being broadcast instead of independently sampled."
     )
+
+
+def test_seed_reproducibility():
+    """Same seed must produce identical output."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"
+
+    for seed in [1, 42, 123, 999]:
+        r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."
+
+
+def test_temperature_zero_determinism():
+    """Temperature=0 is deterministic regardless of seed."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
+    r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
+    r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
+    assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
+
+
+def test_max_tokens_respected():
+    """Generation stops at max_tokens limit."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    for max_tokens in [1, 4, 16, 64]:
+        results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
+        num_generated_tokens = len(results[0]) - len(prompt)
+        assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
+
+
+def test_num_samples_count():
+    """num_samples=N produces exactly N sequences."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    for num_samples in [1, 4, 16, 64]:
+        results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
+        assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
+
+
+def test_different_seeds_introduce_variation_when_temperature_nonzero():
+    """With temperature > 0, different seeds should introduce sampling variation."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"
+
+    outputs = set()
+
+    for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
+        results, _ = engine.generate_batch(
+            prompt,
+            temperature=1.0,
+            max_tokens=5,
+            seed=seed,
+        )
+        outputs.add(tuple(results[0]))
+
+    # Sanity check: sampling actually introduces variation
+    assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
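
The new tests above drive Engine.generate_batch through a MockModel and a ByteTokenizer (both presumably defined near the top of tests/test_engine.py, outside this diff). The two properties they pin down are generic sampling facts, illustrated by the self-contained sketch below; the vocabulary size and seed values are arbitrary and this is not nanochat's actual Engine code:

import torch

logits = torch.randn(262)  # stand-in next-token logits (~byte vocab plus special tokens such as <bos>=261)

def pick_token(logits, temperature, seed):
    if temperature == 0.0:
        return int(torch.argmax(logits))  # greedy: deterministic regardless of seed
    g = torch.Generator().manual_seed(seed)
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, 1, generator=g))

assert pick_token(logits, 0.0, seed=1) == pick_token(logits, 0.0, seed=999)  # temperature 0: seed-independent
assert pick_token(logits, 1.0, seed=42) == pick_token(logits, 1.0, seed=42)  # same seed: reproducible
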

92
uv.lock
generated
@@ -1513,6 +1513,7 @@ dependencies = [
{ name = "transformers" },
{ name = "uvicorn" },
{ name = "wandb" },
{ name = "zstandard" },
]
[package.optional-dependencies]
@@ -1551,6 +1552,7 @@ requires-dist = [
{ name = "transformers", specifier = ">=4.57.3" },
{ name = "uvicorn", specifier = ">=0.36.0" },
{ name = "wandb", specifier = ">=0.21.3" },
{ name = "zstandard", specifier = ">=0.25.0" },
]
provides-extras = ["cpu", "gpu"]
@@ -3619,3 +3621,93 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
]
[[package]]
name = "zstandard"
version = "0.25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/fd/aa/3e0508d5a5dd96529cdc5a97011299056e14c6505b678fd58938792794b1/zstandard-0.25.0.tar.gz", hash = "sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b", size = 711513, upload-time = "2025-09-14T22:15:54.002Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/56/7a/28efd1d371f1acd037ac64ed1c5e2b41514a6cc937dd6ab6a13ab9f0702f/zstandard-0.25.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e59fdc271772f6686e01e1b3b74537259800f57e24280be3f29c8a0deb1904dd", size = 795256, upload-time = "2025-09-14T22:15:56.415Z" },
{ url = "https://files.pythonhosted.org/packages/96/34/ef34ef77f1ee38fc8e4f9775217a613b452916e633c4f1d98f31db52c4a5/zstandard-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4d441506e9b372386a5271c64125f72d5df6d2a8e8a2a45a0ae09b03cb781ef7", size = 640565, upload-time = "2025-09-14T22:15:58.177Z" },
{ url = "https://files.pythonhosted.org/packages/9d/1b/4fdb2c12eb58f31f28c4d28e8dc36611dd7205df8452e63f52fb6261d13e/zstandard-0.25.0-cp310-cp310-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:ab85470ab54c2cb96e176f40342d9ed41e58ca5733be6a893b730e7af9c40550", size = 5345306, upload-time = "2025-09-14T22:16:00.165Z" },
{ url = "https://files.pythonhosted.org/packages/73/28/a44bdece01bca027b079f0e00be3b6bd89a4df180071da59a3dd7381665b/zstandard-0.25.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e05ab82ea7753354bb054b92e2f288afb750e6b439ff6ca78af52939ebbc476d", size = 5055561, upload-time = "2025-09-14T22:16:02.22Z" },
{ url = "https://files.pythonhosted.org/packages/e9/74/68341185a4f32b274e0fc3410d5ad0750497e1acc20bd0f5b5f64ce17785/zstandard-0.25.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:78228d8a6a1c177a96b94f7e2e8d012c55f9c760761980da16ae7546a15a8e9b", size = 5402214, upload-time = "2025-09-14T22:16:04.109Z" },
{ url = "https://files.pythonhosted.org/packages/8b/67/f92e64e748fd6aaffe01e2b75a083c0c4fd27abe1c8747fee4555fcee7dd/zstandard-0.25.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:2b6bd67528ee8b5c5f10255735abc21aa106931f0dbaf297c7be0c886353c3d0", size = 5449703, upload-time = "2025-09-14T22:16:06.312Z" },
{ url = "https://files.pythonhosted.org/packages/fd/e5/6d36f92a197c3c17729a2125e29c169f460538a7d939a27eaaa6dcfcba8e/zstandard-0.25.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4b6d83057e713ff235a12e73916b6d356e3084fd3d14ced499d84240f3eecee0", size = 5556583, upload-time = "2025-09-14T22:16:08.457Z" },
{ url = "https://files.pythonhosted.org/packages/d7/83/41939e60d8d7ebfe2b747be022d0806953799140a702b90ffe214d557638/zstandard-0.25.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9174f4ed06f790a6869b41cba05b43eeb9a35f8993c4422ab853b705e8112bbd", size = 5045332, upload-time = "2025-09-14T22:16:10.444Z" },
{ url = "https://files.pythonhosted.org/packages/b3/87/d3ee185e3d1aa0133399893697ae91f221fda79deb61adbe998a7235c43f/zstandard-0.25.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:25f8f3cd45087d089aef5ba3848cd9efe3ad41163d3400862fb42f81a3a46701", size = 5572283, upload-time = "2025-09-14T22:16:12.128Z" },
{ url = "https://files.pythonhosted.org/packages/0a/1d/58635ae6104df96671076ac7d4ae7816838ce7debd94aecf83e30b7121b0/zstandard-0.25.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3756b3e9da9b83da1796f8809dd57cb024f838b9eeafde28f3cb472012797ac1", size = 4959754, upload-time = "2025-09-14T22:16:14.225Z" },
{ url = "https://files.pythonhosted.org/packages/75/d6/57e9cb0a9983e9a229dd8fd2e6e96593ef2aa82a3907188436f22b111ccd/zstandard-0.25.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:81dad8d145d8fd981b2962b686b2241d3a1ea07733e76a2f15435dfb7fb60150", size = 5266477, upload-time = "2025-09-14T22:16:16.343Z" },
{ url = "https://files.pythonhosted.org/packages/d1/a9/ee891e5edf33a6ebce0a028726f0bbd8567effe20fe3d5808c42323e8542/zstandard-0.25.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a5a419712cf88862a45a23def0ae063686db3d324cec7edbe40509d1a79a0aab", size = 5440914, upload-time = "2025-09-14T22:16:18.453Z" },
{ url = "https://files.pythonhosted.org/packages/58/08/a8522c28c08031a9521f27abc6f78dbdee7312a7463dd2cfc658b813323b/zstandard-0.25.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e7360eae90809efd19b886e59a09dad07da4ca9ba096752e61a2e03c8aca188e", size = 5819847, upload-time = "2025-09-14T22:16:20.559Z" },
{ url = "https://files.pythonhosted.org/packages/6f/11/4c91411805c3f7b6f31c60e78ce347ca48f6f16d552fc659af6ec3b73202/zstandard-0.25.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:75ffc32a569fb049499e63ce68c743155477610532da1eb38e7f24bf7cd29e74", size = 5363131, upload-time = "2025-09-14T22:16:22.206Z" },
{ url = "https://files.pythonhosted.org/packages/ef/d6/8c4bd38a3b24c4c7676a7a3d8de85d6ee7a983602a734b9f9cdefb04a5d6/zstandard-0.25.0-cp310-cp310-win32.whl", hash = "sha256:106281ae350e494f4ac8a80470e66d1fe27e497052c8d9c3b95dc4cf1ade81aa", size = 436469, upload-time = "2025-09-14T22:16:25.002Z" },
{ url = "https://files.pythonhosted.org/packages/93/90/96d50ad417a8ace5f841b3228e93d1bb13e6ad356737f42e2dde30d8bd68/zstandard-0.25.0-cp310-cp310-win_amd64.whl", hash = "sha256:ea9d54cc3d8064260114a0bbf3479fc4a98b21dffc89b3459edd506b69262f6e", size = 506100, upload-time = "2025-09-14T22:16:23.569Z" },
{ url = "https://files.pythonhosted.org/packages/2a/83/c3ca27c363d104980f1c9cee1101cc8ba724ac8c28a033ede6aab89585b1/zstandard-0.25.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:933b65d7680ea337180733cf9e87293cc5500cc0eb3fc8769f4d3c88d724ec5c", size = 795254, upload-time = "2025-09-14T22:16:26.137Z" },
{ url = "https://files.pythonhosted.org/packages/ac/4d/e66465c5411a7cf4866aeadc7d108081d8ceba9bc7abe6b14aa21c671ec3/zstandard-0.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3f79487c687b1fc69f19e487cd949bf3aae653d181dfb5fde3bf6d18894706f", size = 640559, upload-time = "2025-09-14T22:16:27.973Z" },
{ url = "https://files.pythonhosted.org/packages/12/56/354fe655905f290d3b147b33fe946b0f27e791e4b50a5f004c802cb3eb7b/zstandard-0.25.0-cp311-cp311-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:0bbc9a0c65ce0eea3c34a691e3c4b6889f5f3909ba4822ab385fab9057099431", size = 5348020, upload-time = "2025-09-14T22:16:29.523Z" },
{ url = "https://files.pythonhosted.org/packages/3b/13/2b7ed68bd85e69a2069bcc72141d378f22cae5a0f3b353a2c8f50ef30c1b/zstandard-0.25.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:01582723b3ccd6939ab7b3a78622c573799d5d8737b534b86d0e06ac18dbde4a", size = 5058126, upload-time = "2025-09-14T22:16:31.811Z" },
{ url = "https://files.pythonhosted.org/packages/c9/dd/fdaf0674f4b10d92cb120ccff58bbb6626bf8368f00ebfd2a41ba4a0dc99/zstandard-0.25.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:5f1ad7bf88535edcf30038f6919abe087f606f62c00a87d7e33e7fc57cb69fcc", size = 5405390, upload-time = "2025-09-14T22:16:33.486Z" },
{ url = "https://files.pythonhosted.org/packages/0f/67/354d1555575bc2490435f90d67ca4dd65238ff2f119f30f72d5cde09c2ad/zstandard-0.25.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:06acb75eebeedb77b69048031282737717a63e71e4ae3f77cc0c3b9508320df6", size = 5452914, upload-time = "2025-09-14T22:16:35.277Z" },
{ url = "https://files.pythonhosted.org/packages/bb/1f/e9cfd801a3f9190bf3e759c422bbfd2247db9d7f3d54a56ecde70137791a/zstandard-0.25.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9300d02ea7c6506f00e627e287e0492a5eb0371ec1670ae852fefffa6164b072", size = 5559635, upload-time = "2025-09-14T22:16:37.141Z" },
{ url = "https://files.pythonhosted.org/packages/21/88/5ba550f797ca953a52d708c8e4f380959e7e3280af029e38fbf47b55916e/zstandard-0.25.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bfd06b1c5584b657a2892a6014c2f4c20e0db0208c159148fa78c65f7e0b0277", size = 5048277, upload-time = "2025-09-14T22:16:38.807Z" },
{ url = "https://files.pythonhosted.org/packages/46/c0/ca3e533b4fa03112facbe7fbe7779cb1ebec215688e5df576fe5429172e0/zstandard-0.25.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f373da2c1757bb7f1acaf09369cdc1d51d84131e50d5fa9863982fd626466313", size = 5574377, upload-time = "2025-09-14T22:16:40.523Z" },
{ url = "https://files.pythonhosted.org/packages/12/9b/3fb626390113f272abd0799fd677ea33d5fc3ec185e62e6be534493c4b60/zstandard-0.25.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6c0e5a65158a7946e7a7affa6418878ef97ab66636f13353b8502d7ea03c8097", size = 4961493, upload-time = "2025-09-14T22:16:43.3Z" },
{ url = "https://files.pythonhosted.org/packages/cb/d3/23094a6b6a4b1343b27ae68249daa17ae0651fcfec9ed4de09d14b940285/zstandard-0.25.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c8e167d5adf59476fa3e37bee730890e389410c354771a62e3c076c86f9f7778", size = 5269018, upload-time = "2025-09-14T22:16:45.292Z" },
{ url = "https://files.pythonhosted.org/packages/8c/a7/bb5a0c1c0f3f4b5e9d5b55198e39de91e04ba7c205cc46fcb0f95f0383c1/zstandard-0.25.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:98750a309eb2f020da61e727de7d7ba3c57c97cf6213f6f6277bb7fb42a8e065", size = 5443672, upload-time = "2025-09-14T22:16:47.076Z" },
{ url = "https://files.pythonhosted.org/packages/27/22/503347aa08d073993f25109c36c8d9f029c7d5949198050962cb568dfa5e/zstandard-0.25.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:22a086cff1b6ceca18a8dd6096ec631e430e93a8e70a9ca5efa7561a00f826fa", size = 5822753, upload-time = "2025-09-14T22:16:49.316Z" },
{ url = "https://files.pythonhosted.org/packages/e2/be/94267dc6ee64f0f8ba2b2ae7c7a2df934a816baaa7291db9e1aa77394c3c/zstandard-0.25.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:72d35d7aa0bba323965da807a462b0966c91608ef3a48ba761678cb20ce5d8b7", size = 5366047, upload-time = "2025-09-14T22:16:51.328Z" },
{ url = "https://files.pythonhosted.org/packages/7b/a3/732893eab0a3a7aecff8b99052fecf9f605cf0fb5fb6d0290e36beee47a4/zstandard-0.25.0-cp311-cp311-win32.whl", hash = "sha256:f5aeea11ded7320a84dcdd62a3d95b5186834224a9e55b92ccae35d21a8b63d4", size = 436484, upload-time = "2025-09-14T22:16:55.005Z" },
{ url = "https://files.pythonhosted.org/packages/43/a3/c6155f5c1cce691cb80dfd38627046e50af3ee9ddc5d0b45b9b063bfb8c9/zstandard-0.25.0-cp311-cp311-win_amd64.whl", hash = "sha256:daab68faadb847063d0c56f361a289c4f268706b598afbf9ad113cbe5c38b6b2", size = 506183, upload-time = "2025-09-14T22:16:52.753Z" },
{ url = "https://files.pythonhosted.org/packages/8c/3e/8945ab86a0820cc0e0cdbf38086a92868a9172020fdab8a03ac19662b0e5/zstandard-0.25.0-cp311-cp311-win_arm64.whl", hash = "sha256:22a06c5df3751bb7dc67406f5374734ccee8ed37fc5981bf1ad7041831fa1137", size = 462533, upload-time = "2025-09-14T22:16:53.878Z" },
{ url = "https://files.pythonhosted.org/packages/82/fc/f26eb6ef91ae723a03e16eddb198abcfce2bc5a42e224d44cc8b6765e57e/zstandard-0.25.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b", size = 795738, upload-time = "2025-09-14T22:16:56.237Z" },
{ url = "https://files.pythonhosted.org/packages/aa/1c/d920d64b22f8dd028a8b90e2d756e431a5d86194caa78e3819c7bf53b4b3/zstandard-0.25.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00", size = 640436, upload-time = "2025-09-14T22:16:57.774Z" },
{ url = "https://files.pythonhosted.org/packages/53/6c/288c3f0bd9fcfe9ca41e2c2fbfd17b2097f6af57b62a81161941f09afa76/zstandard-0.25.0-cp312-cp312-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64", size = 5343019, upload-time = "2025-09-14T22:16:59.302Z" },
{ url = "https://files.pythonhosted.org/packages/1e/15/efef5a2f204a64bdb5571e6161d49f7ef0fffdbca953a615efbec045f60f/zstandard-0.25.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea", size = 5063012, upload-time = "2025-09-14T22:17:01.156Z" },
{ url = "https://files.pythonhosted.org/packages/b7/37/a6ce629ffdb43959e92e87ebdaeebb5ac81c944b6a75c9c47e300f85abdf/zstandard-0.25.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb", size = 5394148, upload-time = "2025-09-14T22:17:03.091Z" },
{ url = "https://files.pythonhosted.org/packages/e3/79/2bf870b3abeb5c070fe2d670a5a8d1057a8270f125ef7676d29ea900f496/zstandard-0.25.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a", size = 5451652, upload-time = "2025-09-14T22:17:04.979Z" },
{ url = "https://files.pythonhosted.org/packages/53/60/7be26e610767316c028a2cbedb9a3beabdbe33e2182c373f71a1c0b88f36/zstandard-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902", size = 5546993, upload-time = "2025-09-14T22:17:06.781Z" },
{ url = "https://files.pythonhosted.org/packages/85/c7/3483ad9ff0662623f3648479b0380d2de5510abf00990468c286c6b04017/zstandard-0.25.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f", size = 5046806, upload-time = "2025-09-14T22:17:08.415Z" },
{ url = "https://files.pythonhosted.org/packages/08/b3/206883dd25b8d1591a1caa44b54c2aad84badccf2f1de9e2d60a446f9a25/zstandard-0.25.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b", size = 5576659, upload-time = "2025-09-14T22:17:10.164Z" },
{ url = "https://files.pythonhosted.org/packages/9d/31/76c0779101453e6c117b0ff22565865c54f48f8bd807df2b00c2c404b8e0/zstandard-0.25.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6", size = 4953933, upload-time = "2025-09-14T22:17:11.857Z" },
{ url = "https://files.pythonhosted.org/packages/18/e1/97680c664a1bf9a247a280a053d98e251424af51f1b196c6d52f117c9720/zstandard-0.25.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91", size = 5268008, upload-time = "2025-09-14T22:17:13.627Z" },
{ url = "https://files.pythonhosted.org/packages/1e/73/316e4010de585ac798e154e88fd81bb16afc5c5cb1a72eeb16dd37e8024a/zstandard-0.25.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708", size = 5433517, upload-time = "2025-09-14T22:17:16.103Z" },
{ url = "https://files.pythonhosted.org/packages/5b/60/dd0f8cfa8129c5a0ce3ea6b7f70be5b33d2618013a161e1ff26c2b39787c/zstandard-0.25.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512", size = 5814292, upload-time = "2025-09-14T22:17:17.827Z" },
{ url = "https://files.pythonhosted.org/packages/fc/5f/75aafd4b9d11b5407b641b8e41a57864097663699f23e9ad4dbb91dc6bfe/zstandard-0.25.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa", size = 5360237, upload-time = "2025-09-14T22:17:19.954Z" },
{ url = "https://files.pythonhosted.org/packages/ff/8d/0309daffea4fcac7981021dbf21cdb2e3427a9e76bafbcdbdf5392ff99a4/zstandard-0.25.0-cp312-cp312-win32.whl", hash = "sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd", size = 436922, upload-time = "2025-09-14T22:17:24.398Z" },
{ url = "https://files.pythonhosted.org/packages/79/3b/fa54d9015f945330510cb5d0b0501e8253c127cca7ebe8ba46a965df18c5/zstandard-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01", size = 506276, upload-time = "2025-09-14T22:17:21.429Z" },
{ url = "https://files.pythonhosted.org/packages/ea/6b/8b51697e5319b1f9ac71087b0af9a40d8a6288ff8025c36486e0c12abcc4/zstandard-0.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9", size = 462679, upload-time = "2025-09-14T22:17:23.147Z" },
{ url = "https://files.pythonhosted.org/packages/35/0b/8df9c4ad06af91d39e94fa96cc010a24ac4ef1378d3efab9223cc8593d40/zstandard-0.25.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94", size = 795735, upload-time = "2025-09-14T22:17:26.042Z" },
{ url = "https://files.pythonhosted.org/packages/3f/06/9ae96a3e5dcfd119377ba33d4c42a7d89da1efabd5cb3e366b156c45ff4d/zstandard-0.25.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1", size = 640440, upload-time = "2025-09-14T22:17:27.366Z" },
{ url = "https://files.pythonhosted.org/packages/d9/14/933d27204c2bd404229c69f445862454dcc101cd69ef8c6068f15aaec12c/zstandard-0.25.0-cp313-cp313-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f", size = 5343070, upload-time = "2025-09-14T22:17:28.896Z" },
{ url = "https://files.pythonhosted.org/packages/6d/db/ddb11011826ed7db9d0e485d13df79b58586bfdec56e5c84a928a9a78c1c/zstandard-0.25.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bfc4e20784722098822e3eee42b8e576b379ed72cca4a7cb856ae733e62192ea", size = 5063001, upload-time = "2025-09-14T22:17:31.044Z" },
{ url = "https://files.pythonhosted.org/packages/db/00/87466ea3f99599d02a5238498b87bf84a6348290c19571051839ca943777/zstandard-0.25.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:457ed498fc58cdc12fc48f7950e02740d4f7ae9493dd4ab2168a47c93c31298e", size = 5394120, upload-time = "2025-09-14T22:17:32.711Z" },
{ url = "https://files.pythonhosted.org/packages/2b/95/fc5531d9c618a679a20ff6c29e2b3ef1d1f4ad66c5e161ae6ff847d102a9/zstandard-0.25.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:fd7a5004eb1980d3cefe26b2685bcb0b17989901a70a1040d1ac86f1d898c551", size = 5451230, upload-time = "2025-09-14T22:17:34.41Z" },
{ url = "https://files.pythonhosted.org/packages/63/4b/e3678b4e776db00f9f7b2fe58e547e8928ef32727d7a1ff01dea010f3f13/zstandard-0.25.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8e735494da3db08694d26480f1493ad2cf86e99bdd53e8e9771b2752a5c0246a", size = 5547173, upload-time = "2025-09-14T22:17:36.084Z" },
{ url = "https://files.pythonhosted.org/packages/4e/d5/ba05ed95c6b8ec30bd468dfeab20589f2cf709b5c940483e31d991f2ca58/zstandard-0.25.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3a39c94ad7866160a4a46d772e43311a743c316942037671beb264e395bdd611", size = 5046736, upload-time = "2025-09-14T22:17:37.891Z" },
{ url = "https://files.pythonhosted.org/packages/50/d5/870aa06b3a76c73eced65c044b92286a3c4e00554005ff51962deef28e28/zstandard-0.25.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:172de1f06947577d3a3005416977cce6168f2261284c02080e7ad0185faeced3", size = 5576368, upload-time = "2025-09-14T22:17:40.206Z" },
{ url = "https://files.pythonhosted.org/packages/5d/35/398dc2ffc89d304d59bc12f0fdd931b4ce455bddf7038a0a67733a25f550/zstandard-0.25.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3c83b0188c852a47cd13ef3bf9209fb0a77fa5374958b8c53aaa699398c6bd7b", size = 4954022, upload-time = "2025-09-14T22:17:41.879Z" },
{ url = "https://files.pythonhosted.org/packages/9a/5c/36ba1e5507d56d2213202ec2b05e8541734af5f2ce378c5d1ceaf4d88dc4/zstandard-0.25.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1673b7199bbe763365b81a4f3252b8e80f44c9e323fc42940dc8843bfeaf9851", size = 5267889, upload-time = "2025-09-14T22:17:43.577Z" },
{ url = "https://files.pythonhosted.org/packages/70/e8/2ec6b6fb7358b2ec0113ae202647ca7c0e9d15b61c005ae5225ad0995df5/zstandard-0.25.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0be7622c37c183406f3dbf0cba104118eb16a4ea7359eeb5752f0794882fc250", size = 5433952, upload-time = "2025-09-14T22:17:45.271Z" },
{ url = "https://files.pythonhosted.org/packages/7b/01/b5f4d4dbc59ef193e870495c6f1275f5b2928e01ff5a81fecb22a06e22fb/zstandard-0.25.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:5f5e4c2a23ca271c218ac025bd7d635597048b366d6f31f420aaeb715239fc98", size = 5814054, upload-time = "2025-09-14T22:17:47.08Z" },
{ url = "https://files.pythonhosted.org/packages/b2/e5/fbd822d5c6f427cf158316d012c5a12f233473c2f9c5fe5ab1ae5d21f3d8/zstandard-0.25.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f187a0bb61b35119d1926aee039524d1f93aaf38a9916b8c4b78ac8514a0aaf", size = 5360113, upload-time = "2025-09-14T22:17:48.893Z" },
{ url = "https://files.pythonhosted.org/packages/8e/e0/69a553d2047f9a2c7347caa225bb3a63b6d7704ad74610cb7823baa08ed7/zstandard-0.25.0-cp313-cp313-win32.whl", hash = "sha256:7030defa83eef3e51ff26f0b7bfb229f0204b66fe18e04359ce3474ac33cbc09", size = 436936, upload-time = "2025-09-14T22:17:52.658Z" },
{ url = "https://files.pythonhosted.org/packages/d9/82/b9c06c870f3bd8767c201f1edbdf9e8dc34be5b0fbc5682c4f80fe948475/zstandard-0.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:1f830a0dac88719af0ae43b8b2d6aef487d437036468ef3c2ea59c51f9d55fd5", size = 506232, upload-time = "2025-09-14T22:17:50.402Z" },
{ url = "https://files.pythonhosted.org/packages/d4/57/60c3c01243bb81d381c9916e2a6d9e149ab8627c0c7d7abb2d73384b3c0c/zstandard-0.25.0-cp313-cp313-win_arm64.whl", hash = "sha256:85304a43f4d513f5464ceb938aa02c1e78c2943b29f44a750b48b25ac999a049", size = 462671, upload-time = "2025-09-14T22:17:51.533Z" },
{ url = "https://files.pythonhosted.org/packages/3d/5c/f8923b595b55fe49e30612987ad8bf053aef555c14f05bb659dd5dbe3e8a/zstandard-0.25.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e29f0cf06974c899b2c188ef7f783607dbef36da4c242eb6c82dcd8b512855e3", size = 795887, upload-time = "2025-09-14T22:17:54.198Z" },
{ url = "https://files.pythonhosted.org/packages/8d/09/d0a2a14fc3439c5f874042dca72a79c70a532090b7ba0003be73fee37ae2/zstandard-0.25.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:05df5136bc5a011f33cd25bc9f506e7426c0c9b3f9954f056831ce68f3b6689f", size = 640658, upload-time = "2025-09-14T22:17:55.423Z" },
{ url = "https://files.pythonhosted.org/packages/5d/7c/8b6b71b1ddd517f68ffb55e10834388d4f793c49c6b83effaaa05785b0b4/zstandard-0.25.0-cp314-cp314-manylinux2010_i686.manylinux_2_12_i686.manylinux_2_28_i686.whl", hash = "sha256:f604efd28f239cc21b3adb53eb061e2a205dc164be408e553b41ba2ffe0ca15c", size = 5379849, upload-time = "2025-09-14T22:17:57.372Z" },
{ url = "https://files.pythonhosted.org/packages/a4/86/a48e56320d0a17189ab7a42645387334fba2200e904ee47fc5a26c1fd8ca/zstandard-0.25.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223415140608d0f0da010499eaa8ccdb9af210a543fac54bce15babbcfc78439", size = 5058095, upload-time = "2025-09-14T22:17:59.498Z" },
{ url = "https://files.pythonhosted.org/packages/f8/ad/eb659984ee2c0a779f9d06dbfe45e2dc39d99ff40a319895df2d3d9a48e5/zstandard-0.25.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e54296a283f3ab5a26fc9b8b5d4978ea0532f37b231644f367aa588930aa043", size = 5551751, upload-time = "2025-09-14T22:18:01.618Z" },
{ url = "https://files.pythonhosted.org/packages/61/b3/b637faea43677eb7bd42ab204dfb7053bd5c4582bfe6b1baefa80ac0c47b/zstandard-0.25.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ca54090275939dc8ec5dea2d2afb400e0f83444b2fc24e07df7fdef677110859", size = 6364818, upload-time = "2025-09-14T22:18:03.769Z" },
{ url = "https://files.pythonhosted.org/packages/31/dc/cc50210e11e465c975462439a492516a73300ab8caa8f5e0902544fd748b/zstandard-0.25.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e09bb6252b6476d8d56100e8147b803befa9a12cea144bbe629dd508800d1ad0", size = 5560402, upload-time = "2025-09-14T22:18:05.954Z" },
{ url = "https://files.pythonhosted.org/packages/c9/ae/56523ae9c142f0c08efd5e868a6da613ae76614eca1305259c3bf6a0ed43/zstandard-0.25.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:a9ec8c642d1ec73287ae3e726792dd86c96f5681eb8df274a757bf62b750eae7", size = 4955108, upload-time = "2025-09-14T22:18:07.68Z" },
{ url = "https://files.pythonhosted.org/packages/98/cf/c899f2d6df0840d5e384cf4c4121458c72802e8bda19691f3b16619f51e9/zstandard-0.25.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a4089a10e598eae6393756b036e0f419e8c1d60f44a831520f9af41c14216cf2", size = 5269248, upload-time = "2025-09-14T22:18:09.753Z" },
{ url = "https://files.pythonhosted.org/packages/1b/c0/59e912a531d91e1c192d3085fc0f6fb2852753c301a812d856d857ea03c6/zstandard-0.25.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:f67e8f1a324a900e75b5e28ffb152bcac9fbed1cc7b43f99cd90f395c4375344", size = 5430330, upload-time = "2025-09-14T22:18:11.966Z" },
{ url = "https://files.pythonhosted.org/packages/a0/1d/7e31db1240de2df22a58e2ea9a93fc6e38cc29353e660c0272b6735d6669/zstandard-0.25.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:9654dbc012d8b06fc3d19cc825af3f7bf8ae242226df5f83936cb39f5fdc846c", size = 5811123, upload-time = "2025-09-14T22:18:13.907Z" },
{ url = "https://files.pythonhosted.org/packages/f6/49/fac46df5ad353d50535e118d6983069df68ca5908d4d65b8c466150a4ff1/zstandard-0.25.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4203ce3b31aec23012d3a4cf4a2ed64d12fea5269c49aed5e4c3611b938e4088", size = 5359591, upload-time = "2025-09-14T22:18:16.465Z" },
{ url = "https://files.pythonhosted.org/packages/c2/38/f249a2050ad1eea0bb364046153942e34abba95dd5520af199aed86fbb49/zstandard-0.25.0-cp314-cp314-win32.whl", hash = "sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12", size = 444513, upload-time = "2025-09-14T22:18:20.61Z" },
{ url = "https://files.pythonhosted.org/packages/3a/43/241f9615bcf8ba8903b3f0432da069e857fc4fd1783bd26183db53c4804b/zstandard-0.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2", size = 516118, upload-time = "2025-09-14T22:18:17.849Z" },
{ url = "https://files.pythonhosted.org/packages/f0/ef/da163ce2450ed4febf6467d77ccb4cd52c4c30ab45624bad26ca0a27260c/zstandard-0.25.0-cp314-cp314-win_arm64.whl", hash = "sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d", size = 476940, upload-time = "2025-09-14T22:18:19.088Z" },
]