Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-30 04:22:02 +00:00

Compare commits: 220 commits, cpu-mps-de ... d5418ea5a1
.gitignore (vendored): 10 changed lines

```
@@ -1,6 +1,14 @@
.venv/
__pycache__/
*.pyc
rustbpe/target/
dev-ignore/
report.md
eval_bundle/

# Secrets
.env

# Local setup
.claude
CLAUDE.md
wandb/
```
README.md: 105 changed lines
````diff
@@ -4,24 +4,29 @@
 
 > The best ChatGPT that $100 can buy.
 
-This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
+This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](runs/speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
 
+## Updates
+
+- (Jan 16 2026) The repo is in active development, I am currently fleshing out the pretraining stage.
+- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](runs/miniseries.sh).
+
 ## Talk to it
 
-To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. This model has 1.9 billion parameters, it was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh), and the total cost of training was ~$800 (about 33 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of moden Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
+To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](runs/run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
 
 ## Quick start
 
-The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
+The fastest way to feel the magic is to run the speedrun script [speedrun.sh](runs/speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
 
 ```bash
-bash speedrun.sh
+bash runs/speedrun.sh
 ```
 
 Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
 
 ```bash
-screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
+screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
 ```
 
 See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
````
````diff
@@ -68,7 +73,7 @@ Total wall clock time: 3h51m
 
 Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet.
 
-That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
+That said, to give a sense, the example changes needed for the [speedrun.sh](runs/speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
 
 ```bash
 ...
@@ -78,13 +83,13 @@ That said, to give a sense, the example changes needed for the [speedrun.sh](spe
 python -m nanochat.dataset -n 450 &
 ...
 # use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device-batch-size=16
 ...
 # make sure to use the same later during midtraining:
-torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
+torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
 ```
 
-That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensates by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
+That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
````
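
To make the `device_batch_size` / gradient accumulation trade-off above concrete, here is a minimal sketch of the bookkeeping the training scripts perform (variable names and the total batch size are illustrative, not the exact ones in `scripts/base_train.py`):

```python
# Total tokens per optimizer step is held fixed; shrinking the per-device batch
# just increases the number of sequential micro-batches (gradient accumulation).
total_batch_size = 524288          # target tokens per optimizer step (example value)
sequence_len = 2048                # tokens per row
world_size = 8                     # number of GPUs (8XH100 node)

def grad_accum_steps(device_batch_size: int) -> int:
    tokens_per_micro_step = device_batch_size * sequence_len * world_size
    assert total_batch_size % tokens_per_micro_step == 0
    return total_batch_size // tokens_per_micro_step

print(grad_accum_steps(32))  # 1 accumulation step
print(grad_accum_steps(16))  # halving the device batch size -> 2 accumulation steps
```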
````diff
 And a bit more about computing environments that will run nanochat:
 
@@ -95,35 +100,94 @@ And a bit more about computing environments that will run nanochat:
 
 ## Running on CPU / MPS
 
-nanochat cn be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
+nanochat can be run on CPU or on MPS (if you're on Macbook) in principle, and will automatically try to detect what device is best to run on. The script [runcpu.sh](runs/runcpu.sh) shows a very simple example that will exercise the code paths but basically produce garbage results. Unless you know what you're doing, I basically don't recommend using this script right now and hope to tune it a bit more in the future.
 
 ## Customization
 
 To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
 
 Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
 
 ## Questions
 
-nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
+I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
 
 ```bash
 files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml > packaged.txt
 ```
 
 This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
 
-Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
+You can also come to the [#nanochat Discord channel](https://discord.com/channels/1020383067459821711/1427295580895314031) to ask questions, or use the Discussions.
 
 ## Tests
 
 I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as:
 
 ```bash
 python -m pytest tests/test_rustbpe.py -v -s
 python -m pytest tests/test_engine.py -v -s
 ```
 
 ## File structure
 
 ```
 .
 ├── LICENSE
 ├── README.md
 ├── dev
 │   ├── gen_synthetic_data.py # Example synthetic data for identity
 │   ├── generate_logo.html
 │   ├── nanochat.png
 │   └── repackage_data_reference.py # Pretraining data shard generation
 ├── nanochat
 │   ├── __init__.py # empty
 │   ├── adamw.py # Distributed AdamW optimizer
 │   ├── checkpoint_manager.py # Save/Load model checkpoints
 │   ├── common.py # Misc small utilities, quality of life
 │   ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
 │   ├── dataloader.py # Tokenizing Distributed Data Loader
 │   ├── dataset.py # Download/read utils for pretraining data
 │   ├── engine.py # Efficient model inference with KV Cache
 │   ├── execution.py # Allows the LLM to execute Python code as tool
 │   ├── gpt.py # The GPT nn.Module Transformer
 │   ├── logo.svg
 │   ├── loss_eval.py # Evaluate bits per byte (instead of loss)
 │   ├── muon.py # Distributed Muon optimizer
 │   ├── report.py # Utilities for writing the nanochat Report
 │   ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
 │   └── ui.html # HTML/CSS/JS for nanochat frontend
 ├── pyproject.toml
 ├── runs
 │   ├── miniseries.sh # Miniseries training script
 │   ├── run1000.sh # Train the ~$800 nanochat d32
 │   ├── runcpu.sh # Small example of how to run on CPU/MPS
 │   ├── scaling_laws.sh # Scaling laws experiments
 │   └── speedrun.sh # Train the ~$100 nanochat d20
 ├── scripts
 │   ├── base_eval.py # Base model: calculate CORE score
 │   ├── base_loss.py # Base model: calculate bits per byte, sample
 │   ├── base_train.py # Base model: train
 │   ├── chat_cli.py # Chat model (SFT/Mid): talk to over CLI
 │   ├── chat_eval.py # Chat model (SFT/Mid): eval tasks
 │   ├── chat_rl.py # Chat model (SFT/Mid): reinforcement learning
 │   ├── chat_sft.py # Chat model: train SFT
 │   ├── chat_web.py # Chat model (SFT/Mid): talk to over WebUI
 │   ├── mid_train.py # Chat model: midtraining
 │   ├── tok_eval.py # Tokenizer: evaluate compression rate
 │   └── tok_train.py # Tokenizer: train it
 ├── tasks
 │   ├── arc.py # Multiple choice science questions
 │   ├── common.py # TaskMixture | TaskSequence
 │   ├── customjson.py # Make Task from arbitrary jsonl convos
 │   ├── gsm8k.py # 8K Grade School Math questions
 │   ├── humaneval.py # Misnomer; Simple Python coding task
 │   ├── mmlu.py # Multiple choice questions, broad topics
 │   ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF
 │   └── spellingbee.py # Task teaching model to spell/count letters
 ├── tests
 │   └── test_engine.py
 └── uv.lock
 ```
 
 ## Contributing
 
-nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
+nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
 
 Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.
 
 ## Acknowledgements
 
@@ -132,6 +196,7 @@ nanochat is nowhere finished. The goal is to improve the state of the art in mic
 - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
 - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
 - Thank you to chief LLM whisperer 🧙♂️ Alec Radford for advice/guidance.
 - Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat.
 
 ## Cite
````
dev/LOG.md: new file, 684 lines

# Experiment Log

A running summary documenting some experiments and findings. Started ~Jan 7 2026.

---

## 2026-01-27: Bigram Hash Embeddings (Engram-lite)

Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).

### Background

The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.

### What We Tried

**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h)           # hidden state as query
k = RMSNorm(W_k @ e)     # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d)  # scalar gate per position
output = α * v
```
- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.

**2. Early-layer only injection**
- Only inject bigram signal in first 4 layers (where paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.

**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone. Dilutes capacity from more frequent 2-gram patterns.

**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.

TLDR: the winning approach follows modded-nanogpt's "engram-lite", simply adding the following module and feeding its output into the residual branch (gated by a per-layer learnable λ) before every single block:

```python
class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)

    def forward(self, idx):
        # hash each (prev, curr) token pair into the table; covers positions 1..T-1
        h = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        return self.embed(h)
```

As for optimal hyperparameters:

- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings), this was slightly worse.
- **Init:** Zero-init embedding, so it starts as a no-op (tried small noisy init, it's worse)
- **Optimizer:** AdamW with same LR as token embeddings
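
For context, a minimal sketch of how such a module could be wired into the forward pass (names like `bigram_lambdas` follow the bullets above; this is an illustration, not the exact nanochat code):

```python
import torch

def forward_sketch(blocks, wte, bigram_embed, resid_lambdas, x0_lambdas, bigram_lambdas, idx, norm):
    x = norm(wte(idx))                # (B, T, C) normalized token embeddings
    x0 = x                            # keep the initial embedding around
    bg = torch.zeros_like(x)          # bigram features; position 0 has no previous token
    bg[:, 1:] = bigram_embed(idx)     # (B, T-1, C) hashed (prev, curr) lookups
    for i, block in enumerate(blocks):
        # engram-lite: re-inject x0 and the bigram lookup before every block
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * bg
        x = block(x)
    return x
```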

### Key Learnings

1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."

2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.

3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.

4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.

### Parameters Added

For d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = ~126M params
- Per-layer lambdas: 12 scalars (negligible)

If you're keeping track, we now have *a lot* of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:

```
Parameter counts:
wte                  : 25,165,824
bigram_embed         : 125,829,120
value_embeds         : 150,994,944
lm_head              : 25,165,824
transformer_matrices : 84,935,808
scalars              : 36
total                : 412,091,556
```

In other words, only about a quarter of parameters are now weight projections and the vast majority is embedding tables.

Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.

After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:

```
Kaplan-style (all projections including lm_head and no embeddings)

Optimal configurations (from quadratic fits):
FLOPs     Eff Params     Tokens           Ratio   Val BPB
-----------------------------------------------------------------
1e+18     110,678,115    1,241,505,403    11.2    0.8972
2e+18     167,797,457    1,785,336,422    10.7    0.8616
5e+18     250,650,865    2,642,234,152    10.8    0.8293
1e+19     381,758,347    3,806,871,243    10.3    0.7999

N ∝ C^0.54, D ∝ C^0.49

Chinchilla-style (all parameters, period.)

Optimal configurations (from quadratic fits):
FLOPs     Eff Params     Tokens           Ratio   Val BPB
-----------------------------------------------------------------
1e+18     416,320,605    1,232,157,011    3.0     0.8974
2e+18     560,239,841    1,763,669,281    3.2     0.8616
5e+18     741,495,903    2,629,909,368    3.6     0.8291
1e+19     988,644,331    3,884,841,895    4.0     0.7999

N ∝ C^0.37, D ∝ C^0.50

Transformer-only-style (only the projections inside the transformer)

Optimal configurations (from quadratic fits):
FLOPs     Eff Params     Tokens           Ratio   Val BPB
-----------------------------------------------------------------
1e+18     80,259,665     1,315,639,547    17.2    0.8966
2e+18     131,488,566    1,864,134,141    14.5    0.8622
5e+18     220,985,474    2,595,328,843    12.1    0.8302
1e+19     401,213,504    3,328,704,512    8.5     0.7994

N ∝ C^0.70, D ∝ C^0.41
```

Clearly, the Kaplan-style ratios are most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can have a single fixed ratio of tokens:params for compute optimal models. This turns out to be about ~10.5, which now becomes the new default.
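
As a worked example of using that ~10.5 tokens:params ratio, under the usual C ≈ 6·N·D approximation for training FLOPs (the helper below is illustrative, not a function in the repo):

```python
def compute_optimal(flops_budget: float, ratio: float = 10.5):
    """Given C ≈ 6*N*D and D = ratio*N, solve for effective params N and tokens D."""
    n_params = (flops_budget / (6.0 * ratio)) ** 0.5
    n_tokens = ratio * n_params
    return n_params, n_tokens

n, d = compute_optimal(1e18)
print(f"params ≈ {n/1e6:.0f}M, tokens ≈ {d/1e9:.2f}B")  # ≈ 126M params, ≈ 1.32B tokens
```

This lands in the same ballpark as the 1e18 row of the Kaplan-style table above (110M effective params, 1.24B tokens).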

---

## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep

Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.

### What We Swept
- Learning rates for all 6 parameter groups
- Beta1/beta2 for all 5 AdamW groups
- Muon momentum (start/end), weight decay
- Hundreds of combinations (2-way, 3-way, 4-way, etc.)

### The Journey

**At d12**, found two independent improvement routes:
- **Route A:** emb_lr↑ (0.3→0.4), weight_decay↑ (0.1→0.15), matrix_lr↑ (0.02→0.025)
- **Route B:** x0_lr↓ (0.5→0.2), x0_beta1↑ (0.8→0.9+)

Both gave ~0.002 improvement, but combining them caused conflicts. Fine-tuning found wd=0.13, matrix_lr=0.027, emb_lr=0.38 helped slightly. Best d12 config: Route A + x0_beta1=0.95.

**At d16**, Route B became competitive with Route A. The routes still conflicted when combined.

**At d20** (target scale), everything changed:
- Fine-tuned values from d12 **actively hurt** performance
- Routes no longer conflicted
- Just `x0_beta1=0.96` alone captured nearly all the gains

### Final x0_beta1 Sweep at d20

| x0_beta1 | val/bpb | Δ vs baseline |
|----------|---------|---------------|
| **0.96** | **0.7971** | **-0.0007** |
| 0.94 | 0.7972 | -0.0006 |
| 0.90 | 0.7972 | -0.0006 |
| 0.97 | 0.7977 | -0.0001 |
| 0.98 | 0.8011 | +0.0033 💀 |

Flat plateau from 0.90-0.96, then sharp cliff at 0.97+.

### Key Learnings

1. **Hyperparameters are scale-dependent.** What works at d12 doesn't transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20.

2. **Improvement magnitude shrinks with scale.** ~0.002 at d12 → ~0.0007 at d20. The baseline is already better-tuned for larger models.

3. **Sharp cliffs exist.** x0_beta1=0.98 is catastrophic while 0.96 is optimal.

4. **Don't over-tune on small proxies.** Validate at target scale before shipping.

### Final Recommendation

For production d20 runs, add one flag:
```
--x0-lambdas-beta1=0.96
```

Skip everything else discovered at smaller scales.
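
To illustrate what a per-group override like this amounts to inside `setup_optimizers()`-style code (group names and the learning rates below are illustrative defaults, not the exact nanochat values; in the real setup most matrix params go to Muon, not AdamW):

```python
import torch

def setup_adamw_groups(model, x0_lambdas_beta1: float = 0.96):
    # give the x0_lambdas scalars their own (beta1, beta2); everything else
    # handled by AdamW keeps the default betas
    x0_params = [p for n, p in model.named_parameters() if "x0_lambdas" in n]
    other_params = [p for n, p in model.named_parameters() if "x0_lambdas" not in n]
    param_groups = [
        {"params": other_params, "lr": 0.3, "betas": (0.9, 0.95)},
        {"params": x0_params, "lr": 0.5, "betas": (x0_lambdas_beta1, 0.95)},
    ]
    return torch.optim.AdamW(param_groups, weight_decay=0.0)
```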

---

## 2026-01-18: More various experiments

- Tried Muon custom kernels for XXT and all the others. The improvement was there for targeted tests (~20%) but washed out completely to noise in an actual training run, especially because the Muon compute is split across all the workers. Abandoned due to complexity bloat.
- Fused Q,K,V,O nn.Linear layers into a single QKVO Linear layer. ~Zero impact.
- Tried the `sa_lambdas` that gate QKV and O. Slightly confused because of the use of rmsnorm, which erases the effect of any scalar multiplier. Helped a tiny bit (~1e-4 of loss), abandoned to control complexity.

---

## 2026-01-17: Various experiments

Modded-nanogpt uses [Value Embeddings](https://arxiv.org/abs/2410.17897) (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:

- VEs at every layer, at alternating layers, U shaped, front and back. Alternating layers worked best, i.e. we end up with *a lot* more VEs than modded-nanogpt, at every other layer. It works better.
- Many parameter sharing ideas to reduce the new parameter count; nothing here worked. All failed.
- Many ideas to reduce parameter count, the LLM hates all of them: low rank decompositions, projections. All failed.
- Gated yes or no and how much. Gate helps.

Long story short is that the models *love* Value Embeddings. It is a way to add a huge amount of capacity (parameters) to the model at almost zero cost of FLOPs, because these embeddings are simply added to the Values tensor. Any attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fails. The model wants many of them, with all the capacity, and doing so wins across all x axes of steps, flops and wall clock. I re-ran the scaling laws and, because the models are now very parameter bloated, the optimal ratio has halved from 8 to 4! Way below Chinchilla's 20 at this point.

Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:

- Aspect ratio of 128 is worse than 64, I tried a sweep fixing FLOPs == 1e18 and 64 outperforms. The LLM prefers to be slightly thinner and longer.
- Head dim definitely prefers to be 128 instead of 64, i.e. fewer bigger heads
- Bunch of other random stuff like that.

Keeping all of this work on a private branch for now but hope to push shortly.

---

## 2026-01-17: Modded-nanogpt Ideas Sweep (Continued)

Continued testing ideas from modded-nanogpt.

| Idea | Result | Notes |
|------|--------|-------|
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |

---

## 2026-01-16: Flash Attention 3 Fallback to SDPA

Added automatic fallback from Flash Attention 3 to PyTorch's `scaled_dot_product_attention` (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).

### Implementation

Created `nanochat/flash_attention.py` - a unified interface that:
- Detects FA3 availability at import time (requires sm90+ / Hopper)
- Exports a `flash_attn` object matching FA3's API exactly (`flash_attn.flash_attn_func`, `flash_attn.flash_attn_with_kvcache`)
- Automatically routes to FA3 or SDPA based on hardware
- Handles tensor layout differences: FA3 uses (B, T, H, D), SDPA uses (B, H, T, D)
- Implements sliding window attention via explicit masks for SDPA
- Manages KV cache manually for SDPA (FA3 does it in-place)
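
A rough sketch of what the SDPA side of such a shim looks like (the real `nanochat/flash_attention.py` also handles KV caching and sliding windows; this just shows the layout juggling for the training path):

```python
import torch
import torch.nn.functional as F

def sdpa_flash_attn_func(q, k, v, causal=True):
    """FA3-style signature: q, k, v are (B, T, H, D); returns (B, T, H, D)."""
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (B, H, T, D) for SDPA
    if k.size(1) != q.size(1):
        # naive GQA handling: replicate kv heads to match the query heads
        rep = q.size(1) // k.size(1)
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)  # back to (B, T, H, D)
```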

### Changes to Existing Files

Changes to existing code were intentionally kept extremely minimal.

**gpt.py**: Only the import line changed and a comment

**engine.py**: Zero changes needed

**base_train.py**: Added status print and warnings:
- Prints whether FA3 or SDPA fallback is being used
- Warns about efficiency loss without FA3
- Warns about sliding window support if `--window-pattern` is not "L"

### Testing

Tests are split into two classes due to dtype/device constraints:

1. **TestFA3VsSDPA**: Comparison tests requiring Hopper GPU + bfloat16. Run both implementations on identical inputs and verify outputs match (max diff typically 0, at most ~0.004 for sliding window).

2. **TestSDPAOnly**: SDPA-only tests that run on any device with appropriate dtype. Verify forward pass, backward pass, and KV cache work correctly.

Added `_override_impl` mechanism for testing - can force 'fa3' or 'sdpa' to directly compare implementations.

### Notes

- SDPA fallback is significantly slower than FA3, especially in that it lacks the sliding window attention support
- Recommend `--window-pattern L` (full context) when using SDPA fallback

---

## 2026-01-16: Modded-nanogpt Ideas Sweep (Mostly Negative)

Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. None of these helped:

| Idea | Result | Notes |
|------|--------|-------|
| Half-truncated RoPE | No improvement | Only first half of head dims get RoPE (base 1024, linspace). Second half "stationary". |
| Asymmetric softcap | Slightly worse | `23 * sigmoid((x+5)/7.5)` vs our symmetric `15 * tanh(x/15)`. May only help with FP8. |
| Smear gate | Negligible | Blend each token with predecessor via learned gate. Tiny improvement not worth n_embd² params. |
| Backout | No improvement | Save activations at ~60% through network, subtract scaled version at end. |
| Skip connection | Slightly worse | Save at layer ~25%, add at layer ~50%. Also +2GB memory from storing activations. |

Value Embeddings do show promise. I need a more elaborate exploration of a few related ideas, which I leave for tomorrow.

---

## 2026-01-15: Olmo pretraining mix (Negative result)

I attempted to train on the Olmo 3 pretraining dataset [allenai/dolma3_mix-6T](https://huggingface.co/datasets/allenai/dolma3_mix-6T) instead of FineWeb-edu. I ran into a number of [errors and issues](https://huggingface.co/datasets/allenai/dolma3_mix-6T/discussions/2) trying to both download and process the dataset and then noticed some quality issues (e.g. some documents seem to be extremely short, like "5"). I managed to work around these with some sensible hacks (e.g. reject documents less than 100 characters in length) and tried to process the dataset exactly as FineWeb, re-trained the tokenizer and trained a d16 model. The CORE score decreased from 15.5 to 13.8, i.e. the result is quite a bit worse.

I am still looking to try the [DCLM dataset](https://arxiv.org/abs/2406.11794), which according to the paper should be better than FineWeb-edu. I do have some concerns that the same group both prepared the DCLM dataset *and* introduced the CORE score, so I'm a bit hesitant in case there was some overfitting to a CORE-score-adjacent data distribution.

Classifying as negative result and reverting back to FineWeb-edu for now.

---

## 2026-01-13: Varlen Attention (Negative Result)

Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.

### Background

With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.

### Approach: Compute cu_seqlens from inputs

- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()`
- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region
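
A minimal sketch of that cu_seqlens construction, including the fixed-size padding trick (illustrative only; the padding convention of repeating the total length, which yields zero-length trailing segments, is an assumption and this experiment never landed in master):

```python
import torch

def build_cu_seqlens(inputs: torch.Tensor, bos_token_id: int, max_docs: int) -> torch.Tensor:
    """inputs: (B, T) token ids packed so every document (and every row) starts with BOS.
    Returns int32 cumulative sequence lengths [0, len_1, len_1+len_2, ..., B*T],
    padded to a fixed size (max_docs + 1) so torch.compile sees static shapes."""
    flat = inputs.view(-1)
    starts = (flat == bos_token_id).nonzero(as_tuple=True)[0].to(torch.int32)
    total = torch.tensor([flat.numel()], dtype=torch.int32, device=flat.device)
    cu = torch.cat([starts, total])
    assert cu.numel() <= max_docs + 1, "increase max_docs"
    pad = total.expand(max_docs + 1 - cu.numel())  # repeat total length as padding
    return torch.cat([cu, pad])
```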

### Final Results (d16)

| Metric | Baseline | Varlen |
|--------|----------|--------|
| val_bpb | 0.85427 | 0.85407 |
| MFU | ~same | ~same |
| tok/sec | ~same | ~same |

Essentially identical. The 0.0002 bpb improvement is almost noise.

### Conclusion

Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.

---

## 2026-01-13: BOS-Aligned Dataloader with Bin Packing

Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.

### Problem Statement

The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.

### Approach 1: Greedy-Crop BOS (Simple)

Each row is built independently:
- Start with a document (which has BOS prepended)
- Pack more documents until row is full
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
- 100% utilization (no padding), but wastes cropped tokens

### Waste Analysis

Measured token waste empirically on real data (T=2048):
- **39.4% of tokens are cropped** (discarded when docs don't fit)
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row

### Bin Packing Algorithms Explored

| Algorithm | Util% | Crop% | Pad% | Notes |
|-----------|-------|-------|------|-------|
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |

### BestFit-Crop Algorithm

A middle ground that maintains 100% utilization while reducing cropping:

1. Buffer N documents
2. For each row, greedily pick the **largest doc that fits entirely**
3. Repeat until nothing fits
4. When nothing fits, crop a doc to fill remaining space exactly

This avoids "unlucky" crops by searching the buffer for better-fitting documents.
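
A compact sketch of that packing policy (list-based and unoptimized, assumes the buffer is kept topped up; the real dataloader streams tokens and works with a bounded buffer):

```python
def bestfit_crop_row(buffer: list[list[int]], row_len: int) -> list[int]:
    """Fill one training row of exactly row_len tokens from a buffer of
    tokenized documents (each starting with BOS). Mutates the buffer."""
    row: list[int] = []
    while len(row) < row_len:
        space = row_len - len(row)
        # pick the largest document that still fits entirely
        fitting = [d for d in buffer if len(d) <= space]
        if fitting:
            doc = max(fitting, key=len)
            buffer.remove(doc)
            row.extend(doc)
        else:
            # nothing fits: crop some document to fill the row exactly
            doc = buffer.pop(0)
            row.extend(doc[:space])  # the tail of doc is discarded
    return row
```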

**Results (T=2048):**
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
- Still achieves 100% utilization (no padding, every token trains)
- Slightly more rows than baseline (uses more documents per batch)

### Decision: Keep Two Implementations

1. Keep the original implementation, which is very simple, efficient and has 100% token utilization in the batch (no padding with ignore tokens), but creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present.

2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex but still keeps 100% token utilization in the batch (no padding), but at the cost of discarding documents when they don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. One more subtle effect is that it does skew the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents will be discarded. However, this doesn't seem to impact actual downstream performance.

### Midtraining

The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.

### NOTE: loss scale

Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute value of the loss, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back and find the BOS token and have the full context of that document to make predictions. Therefore, the loss appears lower but this is "fake" to some extent, and the expectation is that the vast majority of relative comparisons would come out the same before and after this change.

---

## 2026-01-13: Number Token Split Pattern

Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I had only guessed earlier and left a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.

**Results (d12, vocab=32K):**
| Pattern | val_bpb |
|---------|---------|
| `\p{N}{1,1}` | 0.969 |
| `\p{N}{1,2}` | **0.965** |
| `\p{N}{1,3}` | 0.972 |

**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default.
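
To see what the digit-grouping part of the split pattern does in isolation (using the third-party `regex` module, which supports `\p{N}`; the snippet is a standalone illustration, not the full `SPLIT_PATTERN`):

```python
import regex  # pip install regex

text = "pi is 3.14159 and the year is 2026"
for quantifier in ("{1,1}", "{1,2}", "{1,3}"):
    pieces = regex.findall(r"\p{N}" + quantifier + r"|\p{L}+|[^\s\p{L}\p{N}]+", text)
    print(quantifier, pieces)
# {1,2} splits "14159" into ["14", "15", "9"] and "2026" into ["20", "26"],
# so long digit runs become a small set of reusable 1-2 digit chunks for the BPE.
```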

---

## 2026-01-13: FP8 Training for lm_head

Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.

### Implementation Approaches Tried

**1. Dynamic Scaling (failed)**
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales
- Problem: `.item()` calls cause graph breaks with torch.compile
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile

**2. Static Scaling (partial success)**
- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448`
- `grad_scale` computed dynamically from batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here because they set `grad_scale = 0.75/448`, but grads are in E5M2 so this should probably be `1/57344`: 1 is the amax of any individual element of the cross entropy loss, and there is no normalization by B,T because they use sum reduction, not mean reduction.
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
- This works correctly - no NaNs, proper gradients

### Results (d12)

| Metric | BF16 Baseline | FP8 lm_head |
|--------|---------------|-------------|
| GPU Memory | 34 GB | 36 GB |
| tok/sec | baseline | ~1% faster |

### The Memory Mystery

FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see a 2GB *increase*. Suspected causes:
- `torch.compile` on inner kernels creating extra buffers/specializations
- `torch._scaled_mm` internal workspace allocations
- Custom op registration machinery overhead

Tried saving the original weight `w` (just a reference to the parameter) instead of `w_f8` in backward, then re-quantizing on the spot during backward - didn't help. Still saw the bump.

### Microbenchmark vs Reality

Raw microbenchmark showed promise:
- BF16 matmul: 16.95 ms
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)

But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase, the added code complexity, and the need to tune scale factors for both x and w.

### Code Artifacts

See the branch `fp8_attempt_fail` for:

- `nanochat/fp8_static.py` - Static scaling implementation (working)
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.

### Open Questions

- Why does the custom op approach use more memory than vanilla BF16?
- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the explanation because our vocab_size is only 32K so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized.

**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.

---

## 2026-01-12: Multi-Token Prediction (MTP)

Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.

### Implementation

- Instead of calling the loss `n_predict` times, uses a fancy batched computation using `unfold` + `gather` + cross-entropy decomposition (`CE = logsumexp - logits[target]`)
- Schedule anneals from 3-token to 1-token prediction:
  - 0-33%: `[1.0, 0.5, 0.25→0]` (3rd token fades)
  - 33-67%: `[1.0, 0.5→0]` (2nd token fades)
  - 67-100%: `[1.0]` (standard next-token)
- Weights normalized to sum to 1
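
A small sketch of that annealing schedule (the linear interpolation details are my reading of the bullets above, not copied from the branch):

```python
def mtp_weights(progress: float) -> list[float]:
    """progress in [0, 1] over training; returns per-offset loss weights summing to 1."""
    if progress < 1 / 3:
        t = progress / (1 / 3)             # 3rd-token weight fades 0.25 -> 0
        w = [1.0, 0.5, 0.25 * (1 - t)]
    elif progress < 2 / 3:
        t = (progress - 1 / 3) / (1 / 3)   # 2nd-token weight fades 0.5 -> 0
        w = [1.0, 0.5 * (1 - t)]
    else:
        w = [1.0]                          # standard next-token prediction
    total = sum(w)
    return [x / total for x in w]

print(mtp_weights(0.0))  # [0.571..., 0.286..., 0.143...]
print(mtp_weights(0.9))  # [1.0]
```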

### Results (d12)

| Metric | Baseline | MTP |
|--------|----------|-----|
| GPU Memory | 34 GB | 47 GB |
| MFU | 41% | 40% |
| val/bpb (per step) | baseline | same/slightly worse |
| val/bpb (wall clock) | baseline | noticeably worse |

**Conclusion:** Negative result for nanochat. The extra memory and compute overhead from predicting multiple tokens doesn't pay off, in fact the results get worse. The auxiliary loss signal may help in other settings (larger models, different architectures?), but for our setup it's pure overhead at the moment.

---

## 2026-01-11: Sliding Window Attention

Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.

**Pattern string configuration:**
- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
- Final layer always forced to L (full context) regardless of pattern
- Short window = `sequence_len // 2`
- Long window = `sequence_len` (full context)
- All previous models so far have been simply `L`, and checkpoint loading is modified accordingly to fill in this param for old models, see `_patch_missing_config_keys`
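
A minimal sketch of the tiling rule those bullets describe (illustrative, not the exact code in `gpt.py`):

```python
def layer_windows(window_pattern: str, n_layer: int, sequence_len: int) -> list[int]:
    """Tile e.g. 'SSSL' across layers and map S/L to window sizes in tokens."""
    pattern = (window_pattern * n_layer)[:n_layer]
    pattern = pattern[:-1] + "L"  # final layer always gets full context
    size = {"S": sequence_len // 2, "L": sequence_len}
    return [size[c] for c in pattern]

print(layer_windows("SSSL", 20, 2048))  # mostly 1024, every 4th layer (and the last) 2048
```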

Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.

---

## 2026-01-11: Flash Attention 3 Integration

Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.

### Changes Made

**1. FA3 via `kernels` package**
- Official FA3 is "beta" and requires building from source (painful)
- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
- Loads pre-built wheels, works out of the box on H100

**2. Simplified attention code**
- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
- Removed 3 separate FA2 code paths (training, single-token, chunk inference)
- GQA handled automatically when n_kv_heads < n_heads

**3. Rewrote KVCache for FA3**
- Old format: `(num_layers, 2, B, H, T, D)` combined tensor
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)`
- FA3 updates cache in-place during `flash_attn_with_kvcache`
- Position tracked via `cache_seqlens` tensor (int32, per batch element)
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()`
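
A skeletal sketch of a cache laid out this way (method names follow the bullets above; the shapes and seqlen bookkeeping are the interesting part, everything else is elided and the class is hypothetical, not the repo's KVCache):

```python
import torch

class KVCacheSketch:
    def __init__(self, num_layers, batch_size, max_seq_len, n_kv_heads, head_dim,
                 device="cuda", dtype=torch.bfloat16):
        shape = (num_layers, batch_size, max_seq_len, n_kv_heads, head_dim)
        self.k_cache = torch.zeros(shape, device=device, dtype=dtype)
        self.v_cache = torch.zeros(shape, device=device, dtype=dtype)
        # current length per batch element; FA3 expects int32
        self.cache_seqlens = torch.zeros(batch_size, device=device, dtype=torch.int32)

    def get_layer_cache(self, layer_idx):
        # FA3's flash_attn_with_kvcache writes the new K/V into these tensors in place
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def advance(self, num_tokens: int):
        self.cache_seqlens += num_tokens

    def reset(self):
        self.cache_seqlens.zero_()
```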

### Results

- **~9% improvement in tok/sec** during training out of the box
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
- FA3 supports sliding window via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.

---

## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)

Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.

### Changes Made

**1. x0_lambdas (x0 residual connections)**
- Save initial normalized embedding as `x0` after `norm(wte(idx))`
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
- Provides direct path from embedding to deep layers, helps preserve token information

**2. resid_lambdas (residual stream scaling)**
- Per-layer multiplicative scaling of the residual stream
- Initialized to 1.0 (neutral, standard transformer behavior)
- Allows model to learn to amplify/dampen residual at each layer

**3. DistAdamW small parameter handling**
- Added support for parameters with < 1024 elements (like the scalar lambdas)
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
- Fixes crash when param shape isn't divisible by world_size
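
A sketch of the two scalar vectors and where they enter the forward pass (illustrative names and module structure; the bigram term from the later engram-lite work is omitted here):

```python
import torch
import torch.nn as nn

class ResidualScalars(nn.Module):
    def __init__(self, n_layer: int):
        super().__init__()
        self.resid_lambdas = nn.Parameter(torch.ones(n_layer))   # multiplicative, neutral at 1.0
        self.x0_lambdas = nn.Parameter(torch.zeros(n_layer))     # additive shortcut, disabled at 0.0

    def blend(self, i: int, x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
        # applied right before block i: rescale the residual stream and mix the
        # initial (normalized) token embedding back in
        return self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
```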
### Key Finding: Different LR Sensitivity
|
||||
|
||||
The two scalar types need very different learning rates:
|
||||
- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
|
||||
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.
|
||||
|
||||
Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
|
||||
|
||||
### Experiment Results
|
||||
|
||||
Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:
|
||||
|
||||
| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|
||||
|-------|---------------------|----------------|--------------|-------|
|
||||
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
|
||||
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
|
||||
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
|
||||
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |
|
||||
|
||||
**Observations:**
|
||||
- Consistent improvement across all model sizes
|
||||
- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
|
||||
- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone
|
||||
|
||||
### Meta Device Footgun
|
||||
|
||||
Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
|
||||
|
||||
### Summary

Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.

---
## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay

Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.

### Changes Made

**1. Polar Express Orthogonalization**
- Replaced Newton-Schulz iteration with the "Polar Express Sign Method" from [arxiv.org/pdf/2505.16932](https://arxiv.org/pdf/2505.16932)
- Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
**2. Variance Reduction (NorMuon-style)**
- Added a low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
- Maintains a `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
- Normalizes updates based on a running per-row/col variance estimate (beta2=0.95); see the sketch after this list
- Memory overhead: ~1/max(rows, cols) per param, negligible
- **Result:** Led to a very small improvement; kept and enabled by default.
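
A rough sketch of this normalization under the stated shapes and beta2 (the exact formula in the repo may differ, e.g. in how the overall update scale is preserved):

```python
import torch

def normalize_update(update, state, beta2=0.95, eps=1e-8):
    """NorMuon-style: keep a per-row (or per-column) second-moment estimate and rescale the update."""
    rows, cols = update.shape
    dim = 1 if rows <= cols else 0  # reduce over the larger dimension, store the smaller one
    if "second_momentum_buffer" not in state:
        state["second_momentum_buffer"] = torch.zeros(
            (rows, 1) if dim == 1 else (1, cols), device=update.device, dtype=update.dtype
        )
    buf = state["second_momentum_buffer"]
    # running estimate of the per-row/col mean squared update
    buf.mul_(beta2).add_(update.square().mean(dim=dim, keepdim=True), alpha=1 - beta2)
    return update / (buf.sqrt() + eps)
```
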
**3. Cautious Weight Decay**
- Only decays weights where `update * weight >= 0` (same sign), from [arxiv.org/abs/2411.16085](https://arxiv.org/abs/2411.16085)
- Standard WD always pulls toward zero; cautious WD skips the decay when the gradient is pushing the weight away from zero (see the sketch after this list)
- **Implementation note:** Had to inline the logic rather than use a separate `@torch.compile` function. Passing changing float values (like `weight_decay` during scheduling) as function arguments triggers recompilation; reading from `group["weight_decay"]` inside the step avoids this.
- **Result:** Solid improvements; in particular, the cautious version was better than standard weight decay.
- Now defaults to ON for Muon via the `weight_decay` param. AdamW still has no weight decay (hardcoded to 0); might try to re-tune this later.
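
A sketch of the cautious decay rule as stated above (assumes `p` is the parameter tensor and `update` is the Muon update that will be subtracted; in the repo this logic is inlined into the compiled step rather than a separate function):

```python
import torch

@torch.no_grad()
def cautious_weight_decay_(p, update, lr, weight_decay):
    """Decay only the entries where update * weight >= 0 (same sign); leave the rest alone."""
    mask = ((update * p) >= 0).to(p.dtype)   # 1.0 where we decay, 0.0 where we skip
    p.mul_(1.0 - lr * weight_decay * mask)

# inside the optimizer step, roughly:
#   cautious_weight_decay_(p, update, lr=group["lr"], weight_decay=group["weight_decay"])
#   p.add_(update, alpha=-group["lr"])
```
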
**4. Weight decay schedule**
- Added a linear schedule for weight decay, on by default, from 1.0 to 0.0 (i.e. start with the maximum weight decay at the beginning of training, then ramp to 0 by the end). This worked better than a static setting in experiments. (modded-nanogpt has the same schedule, but it is implemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule.) A minimal sketch follows below.
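
Expressed directly on the weight-decay value (rather than by multiplying through the learning rate), a minimal sketch of the schedule:

```python
def weight_decay_multiplier(step, num_iterations):
    """Linear ramp from 1.0 at the start of training down to 0.0 at the end."""
    return 1.0 - step / max(num_iterations, 1)

# each step: group["weight_decay"] = base_weight_decay * weight_decay_multiplier(step, num_iterations)
```
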
### Weight Decay Scaling Experiments

Swept weight decay values at d8, d12, d16, d20 to find optimal values and a scaling law.

**Optimal Values Found:**

| Depth | Width (channels) | Optimal WD |
|-------|------------------|------------|
| d8 | 512 | ~0.40 |
| d12 | 768 | ~0.22 |
| d16 | 1024 | ~0.10 |
| d20 | 1280 | ~0.08 |
**Scaling Law:**
- Fit power law: `WD = k / channels^α` in log-log space
- Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²

**Practical Formula:**

```
WD_target = WD_reference × (d_reference / d_target)²
```

Example: If the d12 optimum is 0.22, then the d20 optimum ≈ 0.22 × (12/20)² ≈ 0.08
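
The same rule as a tiny helper, using the fact (visible in the table above) that width grows linearly with depth, so the depth ratio stands in for the width ratio:

```python
def scale_weight_decay(depth, wd_ref=0.22, depth_ref=12):
    """WD ∝ 1/width², and width ∝ depth in these configs, so scale by the squared depth ratio."""
    return wd_ref * (depth_ref / depth) ** 2

# scale_weight_decay(20) -> 0.22 * (12/20)**2 ≈ 0.079, matching the ~0.08 found at d20
```
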
**Reference:** The Moonlight paper uses a fixed WD=0.1 for their 15B MoE model. Our experiments indicated that the optimal WD changes with depth, so we go with the empirical scaling law instead.

### Summary
Muon now uses Polar Express orthogonalization, Adafactor-style variance reduction, and cautious weight decay with a schedule that ramps linearly to zero. All of these changes follow the modded-nanogpt repo, but each was also validated piece by piece to yield improvements in nanochat, with the exception of the Polar Express change, which was in the noise. Weight decay is on by default and configurable with `--weight_decay`, using simply 0.2 with ∝ 1/width² scaling. The meaning of the `--weight_decay` kwarg therefore changes as of this commit: it used to configure standard weight decay in AdamW, and now it is used exclusively by Muon (AdamW is hardcoded to 0.0) and is scaled based on depth.

---
## 2026-01-08: exp_grad_clip - Gradient Clipping

**Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.

**Results:**
- No benefit at any scale tested (d12, d20)
- All variants within noise (~0.9827 val_bpb)
- Grad norm never exceeds 1.0 naturally, so clipping is always inactive
- Clipping adds ~2% time overhead from the all-reduce
**Bug Found:** The original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync happens in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with a proper distributed all-reduce (sketched below).
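
For reference, a sketch of the corrected global-norm computation (the actual clipping code has since been deleted, so this is illustrative only):

```python
import torch
import torch.distributed as dist

def global_grad_norm(params):
    """L2 norm of all gradients across all ranks: sum local squared norms, then all-reduce."""
    local_sq = torch.zeros((), device=params[0].device)
    for p in params:
        if p.grad is not None:
            local_sq += p.grad.float().square().sum()
    if dist.is_initialized():
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)  # every rank now sees the same total
    return local_sq.sqrt()

# clipping (now removed) would then scale every grad by min(1, max_norm / global_norm)
```
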
**Observation:** modded-nanogpt does not appear to clip either right now.

**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves MFU a bit because we don't have to calculate and sync grad norms.
dev/estimate_gpt3_core.ipynb (new file, 2190 lines; diff suppressed because it is too large)
@@ -17,15 +17,14 @@ prompt:
|
||||
2. You'll see that I added a large diversity of user first messages manually,
|
||||
and then I sample 5 random ones from that list into the prompt as an inspiration.
|
||||
This is really important to do because DIVERSITY CONTROL is key. If you don't
|
||||
manually inject diversity, the LLM might generate extrremely similar and repeptitive
|
||||
manually inject diversity, the LLM might generate extremely similar and repetitive
|
||||
conversations and things won't work well. Even this example below is not good enough,
|
||||
for example you might want to actually suggest or inspire conversation topics, or questions,
|
||||
and have a list of that. Basically, this is the KEY creative part to get right. Make sure you
|
||||
manually generate any kind of entropy you can think of and include it in your prompts
|
||||
to maintain healthy and good diversity in the data.
|
||||
|
||||
NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
|
||||
(obviously you can tune this arbitrarily to your liking)
|
||||
NOTE: You need OPENROUTER_API_KEY set in .env or as an environment variable.
|
||||
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
|
||||
"""
|
||||
import requests
|
||||
@@ -34,10 +33,12 @@ import os
|
||||
import copy
|
||||
import random
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
api_key = open("openroutertoken.txt").read().strip()
|
||||
load_dotenv()
|
||||
api_key = os.environ["OPENROUTER_API_KEY"]
|
||||
|
||||
url = "https://openrouter.ai/api/v1/chat/completions"
|
||||
headers = {
|
||||
@@ -45,7 +46,7 @@ headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
readme = open("README.md").read().strip()
|
||||
readme = open("README.md", "r", encoding="utf-8").read().strip()
|
||||
prompt = r"""
|
||||
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
|
||||
|
||||
|
||||
BIN dev/nanochat.png (binary file not shown; Before: 19 KiB, After: 1.3 KiB)
@@ -1,84 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
|
||||
# Run as:
|
||||
# bash dev/cpu_demo_run.sh
|
||||
|
||||
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
|
||||
# Think of this run as educational/fun demo, not something you should expect to work well.
|
||||
# This is also why I hide this script away in dev/
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||
unzip -q eval_bundle.zip
|
||||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
|
||||
# wipe the report
|
||||
python -m nanochat.report reset
|
||||
|
||||
# train tokenizer on ~1B characters
|
||||
python -m nanochat.dataset -n 4
|
||||
python -m scripts.tok_train --max_chars=1000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a very small 4 layer model on the CPU
|
||||
# each optimization step processes a single sequence of 1024 tokens
|
||||
# we only run 50 steps of optimization (bump this to get better results)
|
||||
python -m scripts.base_train \
|
||||
--depth=4 \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--total_batch_size=1024 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--core_metric_every=50 \
|
||||
--core_metric_max_per_task=12 \
|
||||
--sample_every=50 \
|
||||
--num_iterations=50
|
||||
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
|
||||
python -m scripts.base_eval --max-per-task=5
|
||||
|
||||
# midtraining
|
||||
python -m scripts.mid_train \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--total_batch_size=1024 \
|
||||
--num_iterations=100
|
||||
# eval results will be terrible, this is just to execute the code paths.
|
||||
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
|
||||
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
|
||||
|
||||
# SFT
|
||||
python -m scripts.chat_sft \
|
||||
--device_batch_size=1 \
|
||||
--target_examples_per_step=4 \
|
||||
--num_iterations=100 \
|
||||
--eval_steps=4 \
|
||||
--eval_metrics_max_problems=16
|
||||
|
||||
# Chat CLI
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# Chat Web
|
||||
# python -m scripts.chat_web
|
||||
|
||||
python -m nanochat.report generate
|
||||
dev/scaling_analysis.ipynb (new file, 373 lines)
@@ -0,0 +1,373 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Scaling Laws Analysis\n",
|
||||
"\n",
|
||||
"Analyze results from `scaling_laws.sh` to find the optimal param:data ratio for nanochat."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"# Load results\n",
|
||||
"tag = \"jan26\"\n",
|
||||
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
|
||||
"results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
|
||||
"\n",
|
||||
"df = pd.read_csv(results_path)\n",
|
||||
"flops_budgets = sorted(df['flops_budget'].unique())\n",
|
||||
"print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n",
|
||||
"print(f\"Columns: {list(df.columns)}\")\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# =============================================================================\n",
|
||||
"# FILTERING: Remove incomplete or problematic runs\n",
|
||||
"# =============================================================================\n",
|
||||
"\n",
|
||||
"print(f\"Before filtering: {len(df)} runs\")\n",
|
||||
"\n",
|
||||
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
|
||||
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
|
||||
"\n",
|
||||
"# Optional: exclude specific flops budgets that aren't done yet\n",
|
||||
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
|
||||
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
|
||||
"\n",
|
||||
"# Optional: exclude specific depths\n",
|
||||
"# exclude_depths = [18, 20]\n",
|
||||
"# df = df[~df['depth'].isin(exclude_depths)]\n",
|
||||
"\n",
|
||||
"print(f\"After filtering: {len(df)} runs\")\n",
|
||||
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
|
||||
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
|
||||
"\n",
|
||||
"# Update flops_budgets list after filtering\n",
|
||||
"flops_budgets = sorted(df['flops_budget'].unique())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Effective Parameter Count\n",
|
||||
"\n",
|
||||
"Different scaling law papers use different conventions for counting parameters:\n",
|
||||
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
|
||||
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
|
||||
"\n",
|
||||
"Our CSV now has granular counts:\n",
|
||||
"- `params_wte` - token embedding (lookup table)\n",
|
||||
"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
|
||||
"- `params_value_embeds` - value embeddings (lookup table)\n",
|
||||
"- `params_lm_head` - unembedding projection (matmul)\n",
|
||||
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
|
||||
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
|
||||
"\n",
|
||||
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# =============================================================================\n",
|
||||
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
|
||||
"# =============================================================================\n",
|
||||
"\n",
|
||||
"def compute_effective_params(row):\n",
|
||||
" \"\"\"\n",
|
||||
" Compute the 'effective' parameter count for scaling law analysis.\n",
|
||||
"\n",
|
||||
" Modify this function to experiment with different conventions:\n",
|
||||
" - Chinchilla-style: include everything\n",
|
||||
" - Kaplan-style: exclude embeddings\n",
|
||||
" - Matmul-only: just transformer + lm_head (the actual compute)\n",
|
||||
" - etc.\n",
|
||||
" \"\"\"\n",
|
||||
" # Option 1: Chinchilla-style (all params)\n",
|
||||
" # return row['params_total']\n",
|
||||
"\n",
|
||||
" # Option 2: Kaplan-style (exclude embeddings)\n",
|
||||
" return row['params_transformer'] + row['params_lm_head']\n",
|
||||
"\n",
|
||||
" # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
|
||||
" # return row['params_transformer']\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Compute derived columns\n",
|
||||
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
|
||||
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
|
||||
"\n",
|
||||
"# Show parameter breakdown for first few rows\n",
|
||||
"print(\"Parameter breakdown (first row per flops budget):\")\n",
|
||||
"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
|
||||
" 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
|
||||
"df.groupby('flops_budget').first()[param_cols]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## IsoFLOP Curves (à la Chinchilla)\n",
|
||||
"\n",
|
||||
"For each compute budget, plot loss vs model size. Looking for the U-shape valley that reveals the optimal model size for each FLOPs budget."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
|
||||
"\n",
|
||||
"# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n",
|
||||
"ax = axes[0]\n",
|
||||
"colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
|
||||
"optimal_by_bpb = []\n",
|
||||
"\n",
|
||||
"for flops, color in zip(flops_budgets, colors):\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
|
||||
" ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
|
||||
"\n",
|
||||
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
|
||||
" log_params = np.log10(subset['effective_params'])\n",
|
||||
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
|
||||
" a, b, c = coeffs\n",
|
||||
"\n",
|
||||
" # Plot fitted curve (dashed)\n",
|
||||
" log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n",
|
||||
" fit_y = a * log_fit_x**2 + b * log_fit_x + c\n",
|
||||
" ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n",
|
||||
"\n",
|
||||
" # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)\n",
|
||||
" if a > 0: # parabola opens upward (has a minimum)\n",
|
||||
" log_opt = -b / (2 * a)\n",
|
||||
" opt_params = 10**log_opt\n",
|
||||
" opt_bpb = a * log_opt**2 + b * log_opt + c\n",
|
||||
" # Mark the fitted optimal\n",
|
||||
" ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n",
|
||||
" zorder=5, edgecolors='black', linewidths=2, marker='*')\n",
|
||||
" # Interpolate tokens and ratio from actual data (don't use C≈6ND approximation)\n",
|
||||
" opt_tokens = np.interp(np.log10(opt_params), log_params, subset['tokens_trained'])\n",
|
||||
" opt_ratio = np.interp(np.log10(opt_params), log_params, subset['param_data_ratio'])\n",
|
||||
" optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n",
|
||||
" else:\n",
|
||||
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
|
||||
" best_idx = subset['val_bpb'].idxmin()\n",
|
||||
" best = subset.loc[best_idx]\n",
|
||||
" ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
|
||||
" zorder=5, edgecolors='black', linewidths=2)\n",
|
||||
" optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
|
||||
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
|
||||
"\n",
|
||||
"ax.set_xscale('log')\n",
|
||||
"ax.set_xlabel('Effective Parameters')\n",
|
||||
"ax.set_ylabel('Validation Loss (bpb)')\n",
|
||||
"ax.set_title('IsoFLOP Curves')\n",
|
||||
"ax.legend(title='FLOPs', loc='upper right')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"opt_df = pd.DataFrame(optimal_by_bpb)\n",
|
||||
"\n",
|
||||
"# Plot 2: Optimal model size vs compute (power law)\n",
|
||||
"ax = axes[1]\n",
|
||||
"ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')\n",
|
||||
"ax.set_xlabel('FLOPs')\n",
|
||||
"ax.set_ylabel('Optimal Parameters')\n",
|
||||
"ax.set_title('Optimal Model Size')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"# Fit and show power law\n",
|
||||
"if len(opt_df) >= 2:\n",
|
||||
" log_f = np.log10(opt_df['flops'])\n",
|
||||
" log_p = np.log10(opt_df['params'])\n",
|
||||
" slope, intercept = np.polyfit(log_f, log_p, 1)\n",
|
||||
" fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n",
|
||||
" fit_p = 10**(intercept + slope * np.log10(fit_f))\n",
|
||||
" ax.plot(fit_f, fit_p, 'r--', alpha=0.7, label=f'N ∝ C^{slope:.2f}')\n",
|
||||
" ax.legend()\n",
|
||||
"\n",
|
||||
"# Plot 3: Optimal tokens vs compute (power law)\n",
|
||||
"ax = axes[2]\n",
|
||||
"ax.loglog(opt_df['flops'], opt_df['tokens'], 'o', markersize=10, color='#e74c3c')\n",
|
||||
"ax.set_xlabel('FLOPs')\n",
|
||||
"ax.set_ylabel('Optimal Tokens')\n",
|
||||
"ax.set_title('Optimal Training Tokens')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"# Fit and show power law\n",
|
||||
"if len(opt_df) >= 2:\n",
|
||||
" log_f = np.log10(opt_df['flops'])\n",
|
||||
" log_t = np.log10(opt_df['tokens'])\n",
|
||||
" slope, intercept = np.polyfit(log_f, log_t, 1)\n",
|
||||
" fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n",
|
||||
" fit_t = 10**(intercept + slope * np.log10(fit_f))\n",
|
||||
" ax.plot(fit_f, fit_t, 'r--', alpha=0.7, label=f'D ∝ C^{slope:.2f}')\n",
|
||||
" ax.legend()\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"# Print the optimal points (from quadratic fits)\n",
|
||||
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
|
||||
"print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
|
||||
"print(\"-\" * 65)\n",
|
||||
"for _, row in opt_df.iterrows():\n",
|
||||
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# =============================================================================\n",
|
||||
"# Optimal Ratio Summary (from power law fits)\n",
|
||||
"# =============================================================================\n",
|
||||
"\n",
|
||||
"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
|
||||
"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
|
||||
"\n",
|
||||
"if len(opt_df) >= 2:\n",
|
||||
" log_f = np.log10(opt_df['flops'])\n",
|
||||
" log_p = np.log10(opt_df['params'])\n",
|
||||
" log_t = np.log10(opt_df['tokens'])\n",
|
||||
"\n",
|
||||
" # Fit power laws\n",
|
||||
" slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
|
||||
" slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
|
||||
"\n",
|
||||
" # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
|
||||
" ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
|
||||
" log_ref = np.log10(ref_flops)\n",
|
||||
"\n",
|
||||
" # Predicted optimal N and D at reference compute\n",
|
||||
" pred_log_n = intercept_n + slope_n * log_ref\n",
|
||||
" pred_log_d = intercept_d + slope_d * log_ref\n",
|
||||
" optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
|
||||
"\n",
|
||||
" # Also compute from the fitted optimals directly (mean and std)\n",
|
||||
" mean_ratio = opt_df['ratio'].mean()\n",
|
||||
" std_ratio = opt_df['ratio'].std()\n",
|
||||
"\n",
|
||||
" print(\"=\" * 60)\n",
|
||||
" print(\"OPTIMAL RATIO SUMMARY\")\n",
|
||||
" print(\"=\" * 60)\n",
|
||||
" print(f\"\\nPower law exponents:\")\n",
|
||||
" print(f\" N ∝ C^{slope_n:.3f}\")\n",
|
||||
" print(f\" D ∝ C^{slope_d:.3f}\")\n",
|
||||
" print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
|
||||
" print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
|
||||
" print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
|
||||
" print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
|
||||
" print(f\" Chinchilla reference: 20\")\n",
|
||||
" print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"Need at least 2 flops budgets to compute power law fits\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Val BPB vs Depth and Ratio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
||||
"\n",
|
||||
"# Plot 1: Val BPB vs Depth\n",
|
||||
"ax = axes[0]\n",
|
||||
"for flops in flops_budgets:\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('depth')\n",
|
||||
" ax.plot(subset['depth'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n",
|
||||
" # Mark the best (lowest)\n",
|
||||
" best_idx = subset['val_bpb'].idxmin()\n",
|
||||
" best = subset.loc[best_idx]\n",
|
||||
" ax.scatter([best['depth']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n",
|
||||
"\n",
|
||||
"ax.set_xlabel('Depth')\n",
|
||||
"ax.set_ylabel('Val BPB (lower is better)')\n",
|
||||
"ax.set_title('Validation BPB vs Model Depth')\n",
|
||||
"ax.legend(title='FLOPs')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"# Plot 2: Val BPB vs Param:Data Ratio\n",
|
||||
"ax = axes[1]\n",
|
||||
"for flops in flops_budgets:\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('param_data_ratio')\n",
|
||||
" ax.plot(subset['param_data_ratio'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n",
|
||||
" best_idx = subset['val_bpb'].idxmin()\n",
|
||||
" best = subset.loc[best_idx]\n",
|
||||
" ax.scatter([best['param_data_ratio']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n",
|
||||
"\n",
|
||||
"ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')\n",
|
||||
"ax.set_xlabel('Param:Data Ratio (tokens/param)')\n",
|
||||
"ax.set_ylabel('Val BPB (lower is better)')\n",
|
||||
"ax.set_title('Val BPB vs Param:Data Ratio')\n",
|
||||
"ax.legend(title='FLOPs')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -1,11 +1,42 @@
|
||||
"""
|
||||
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
|
||||
Not a general optimizer! But works for our specific use.
|
||||
Distributed AdamW optimizer with a fused step function.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
exp_avg: Tensor,
|
||||
exp_avg_sq: Tensor,
|
||||
step_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
beta1_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
eps_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
@@ -14,25 +45,51 @@ class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
for group in param_groups:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.compile
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
reduce_scatter_futures: list[torch.Future] = []
|
||||
all_reduce_futures: list[torch.Future] = []
|
||||
reduce_futures: list[torch.Future] = []
|
||||
gather_futures: list[torch.Future] = []
|
||||
grad_slices = []
|
||||
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
|
||||
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
|
||||
for base_i in range(len(params)):
|
||||
grad = params[base_i].grad
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
@@ -40,38 +97,47 @@ class DistAdamW(torch.optim.Optimizer):
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for base in range(len(params)):
|
||||
reduce_scatter_futures[idx].wait()
|
||||
p = params[base]
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
g_slice = grad_slices[idx]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
t = state['step']
|
||||
# weight decay
|
||||
if wd != 0:
|
||||
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
|
||||
p_slice.mul_(1 - eff_weight_decay)
|
||||
# update running averages
|
||||
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
|
||||
# bias corrections
|
||||
bias1 = 1 - beta1 ** t
|
||||
bias2 = 1 - beta2 ** t
|
||||
# compute step
|
||||
denom = exp_avg_sq.sqrt().add_(eps)
|
||||
step_size = lr * (torch.sqrt(bias2) / bias1)
|
||||
update = exp_avg.div(denom).mul_(step_size)
|
||||
p_slice.add_(other=update, alpha=-1.0)
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._step_t.fill_(state['step'])
|
||||
self._lr_t.fill_(lr)
|
||||
self._beta1_t.fill_(beta1)
|
||||
self._beta2_t.fill_(beta2)
|
||||
self._eps_t.fill_(eps)
|
||||
self._wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
torch.futures.collect_all(all_reduce_futures).wait()
|
||||
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
|
||||
@@ -20,37 +20,56 @@ def log0(message):
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
|
||||
assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Save the model state (parameters)
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
torch.save(model_data, model_path)
|
||||
log0(f"Saved model file to: {model_path}")
|
||||
# Save the optimizer state (useful for SFT or any other fine-tuning)
|
||||
def _patch_missing_config_keys(model_config_kwargs):
|
||||
"""Add default values for new config keys missing in old checkpoints."""
|
||||
# Old models were trained with full context (no sliding window)
|
||||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
n_layer = model_config.n_layer
|
||||
# resid_lambdas defaults to 1.0 (identity scaling)
|
||||
if "resid_lambdas" not in model_data:
|
||||
model_data["resid_lambdas"] = torch.ones(n_layer)
|
||||
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
||||
# x0_lambdas defaults to 0.0 (disabled)
|
||||
if "x0_lambdas" not in model_data:
|
||||
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
||||
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Save the model state parameters
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
torch.save(model_data, model_path)
|
||||
logger.info(f"Saved model parameters to: {model_path}")
|
||||
# Save the metadata dict as json
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
logger.info(f"Saved metadata to: {meta_path}")
|
||||
# Note that optimizer state is sharded across ranks, so each rank must save its own.
|
||||
if optimizer_data is not None:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
torch.save(optimizer_data, optimizer_path)
|
||||
log0(f"Saved optimizer file to: {optimizer_path}")
|
||||
# Save the metadata dict as json
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
log0(f"Saved metadata file to: {meta_path}")
|
||||
logger.info(f"Saved optimizer state to: {optimizer_path}")
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
|
||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "r") as f:
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
meta_data = json.load(f)
|
||||
return model_data, optimizer_data, meta_data
|
||||
|
||||
@@ -65,11 +84,19 @@ def build_model(checkpoint_dir, step, device, phase):
|
||||
"""
|
||||
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
|
||||
if device.type in {"cpu", "mps"}:
|
||||
# Convert bfloat16 tensors to float for CPU inference
|
||||
model_data = {
|
||||
k: v.float() if v.dtype == torch.bfloat16 else v
|
||||
for k, v in model_data.items()
|
||||
}
|
||||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
|
||||
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
|
||||
model_config_kwargs = meta_data["model_config"]
|
||||
_patch_missing_config_keys(model_config_kwargs)
|
||||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
_patch_missing_keys(model_data, model_config)
|
||||
with torch.device("meta"):
|
||||
model = GPT(model_config)
|
||||
# Load the model state
|
||||
@@ -84,15 +111,15 @@ def build_model(checkpoint_dir, step, device, phase):
|
||||
# Load the Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
# Sanity check: compatibility between model and tokenizer
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
|
||||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
def find_largest_model(checkpoint_dir):
|
||||
def find_largest_model(checkpoints_dir):
|
||||
# attempt to guess the model tag: take the biggest model available
|
||||
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
|
||||
model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
|
||||
if not model_tags:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
|
||||
# 1) normally all model tags are of the form d<number>, try that first:
|
||||
candidates = []
|
||||
for model_tag in model_tags:
|
||||
@@ -104,7 +131,7 @@ def find_largest_model(checkpoint_dir):
|
||||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return candidates[0][1]
|
||||
# 2) if that failed, take the most recently updated model:
|
||||
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
|
||||
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
|
||||
return model_tags[0]
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ Common utilities for nanochat.
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import urllib.request
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from filelock import FileLock
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom formatter that adds colors to log messages."""
|
||||
@@ -56,6 +58,42 @@ def get_base_dir():
|
||||
os.makedirs(nanochat_dir, exist_ok=True)
|
||||
return nanochat_dir
|
||||
|
||||
def download_file_with_lock(url, filename, postprocess_fn=None):
|
||||
"""
|
||||
Downloads a file from a URL to a local path in the base directory.
|
||||
Uses a lock file to prevent concurrent downloads among multiple ranks.
|
||||
"""
|
||||
base_dir = get_base_dir()
|
||||
file_path = os.path.join(base_dir, filename)
|
||||
lock_path = file_path + ".lock"
|
||||
|
||||
if os.path.exists(file_path):
|
||||
return file_path
|
||||
|
||||
with FileLock(lock_path):
|
||||
# Only a single rank can acquire this lock
|
||||
# All other ranks block until it is released
|
||||
|
||||
# Recheck after acquiring lock
|
||||
if os.path.exists(file_path):
|
||||
return file_path
|
||||
|
||||
# Download the content as bytes
|
||||
print(f"Downloading {url}...")
|
||||
with urllib.request.urlopen(url) as response:
|
||||
content = response.read() # bytes
|
||||
|
||||
# Write to local file
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(content)
|
||||
print(f"Downloaded to {file_path}")
|
||||
|
||||
# Run the postprocess function if provided
|
||||
if postprocess_fn is not None:
|
||||
postprocess_fn(file_path)
|
||||
|
||||
return file_path
|
||||
|
||||
def print0(s="",**kwargs):
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
@@ -64,23 +102,35 @@ def print0(s="",**kwargs):
|
||||
def print_banner():
|
||||
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
|
||||
banner = """
|
||||
█████ █████
|
||||
░░███ ░░███
|
||||
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
|
||||
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░
|
||||
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
|
||||
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
|
||||
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████
|
||||
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
|
||||
"""
|
||||
█████ █████
|
||||
░░███ ░░███
|
||||
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
|
||||
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
|
||||
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
|
||||
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
|
||||
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
|
||||
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
|
||||
"""
|
||||
print0(banner)
|
||||
|
||||
def is_ddp():
|
||||
# TODO is there a proper way
|
||||
return int(os.environ.get('RANK', -1)) != -1
|
||||
def is_ddp_requested() -> bool:
|
||||
"""
|
||||
True if launched by torchrun (env present), even before init.
|
||||
Used to decide whether we *should* initialize a PG.
|
||||
"""
|
||||
return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
|
||||
|
||||
def is_ddp_initialized() -> bool:
|
||||
"""
|
||||
True if torch.distributed is available and the process group is initialized.
|
||||
Used at cleanup to avoid destroying a non-existent PG.
|
||||
"""
|
||||
return dist.is_available() and dist.is_initialized()
|
||||
|
||||
def get_dist_info():
|
||||
if is_ddp():
|
||||
if is_ddp_requested():
|
||||
# We rely on torchrun's env to decide if we SHOULD init.
|
||||
# (Initialization itself happens in compute init.)
|
||||
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
|
||||
ddp_rank = int(os.environ['RANK'])
|
||||
ddp_local_rank = int(os.environ['LOCAL_RANK'])
|
||||
@@ -110,6 +160,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
|
||||
# Reproducibility
|
||||
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
|
||||
# The only place where global rng might be used is nn.Module initialization of the model weights.
|
||||
torch.manual_seed(42)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed(42)
|
||||
@@ -118,13 +170,13 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
|
||||
# Precision
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if ddp and device_type == "cuda":
|
||||
is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if is_ddp_requested and device_type == "cuda":
|
||||
device = torch.device("cuda", ddp_local_rank)
|
||||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
else:
|
||||
@@ -133,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
if ddp_rank == 0:
|
||||
logger.info(f"Distributed world size: {ddp_world_size}")
|
||||
|
||||
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
|
||||
return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
|
||||
|
||||
def compute_cleanup():
|
||||
"""Companion function to compute_init, to clean things up before script exit"""
|
||||
if is_ddp():
|
||||
if is_ddp_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
class DummyWandb:
|
||||
@@ -148,3 +200,77 @@ class DummyWandb:
|
||||
pass
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
# hardcoded BF16 peak flops for various GPUs
|
||||
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
|
||||
# and PR: https://github.com/karpathy/nanochat/pull/147
|
||||
def get_peak_flops(device_name: str) -> float:
|
||||
name = device_name.lower()
|
||||
|
||||
# --- NVIDIA Blackwell ---
|
||||
if "gb200" in name or "grace blackwell" in name:
|
||||
return 2.5e15
|
||||
if "b200" in name:
|
||||
return 2.25e15
|
||||
if "b100" in name:
|
||||
return 1.8e15
|
||||
|
||||
# --- NVIDIA Hopper (H100/H200/H800) ---
|
||||
if "h200" in name:
|
||||
if "nvl" in name or "pcie" in name:
|
||||
return 836e12
|
||||
return 989e12 # H200 SXM
|
||||
if "h100" in name:
|
||||
if "nvl" in name:
|
||||
return 835e12
|
||||
if "pcie" in name:
|
||||
return 756e12
|
||||
return 989e12 # H100 SXM
|
||||
if "h800" in name:
|
||||
if "nvl" in name:
|
||||
return 989e12
|
||||
return 756e12 # H800 PCIe
|
||||
|
||||
# --- NVIDIA Ampere data center ---
|
||||
if "a100" in name or "a800" in name:
|
||||
return 312e12
|
||||
if "a40" in name:
|
||||
return 149.7e12
|
||||
if "a30" in name:
|
||||
return 165e12
|
||||
|
||||
# --- NVIDIA Ada data center ---
|
||||
if "l40s" in name or "l40-s" in name or "l40 s" in name:
|
||||
return 362e12
|
||||
if "l4" in name:
|
||||
return 121e12
|
||||
|
||||
# --- AMD CDNA accelerators ---
|
||||
if "mi355" in name:
|
||||
return 2.5e15
|
||||
if "mi325" in name or "mi300x" in name:
|
||||
return 1.3074e15
|
||||
if "mi300a" in name:
|
||||
return 980.6e12
|
||||
if "mi250x" in name:
|
||||
return 383e12
|
||||
if "mi250" in name:
|
||||
return 362.1e12
|
||||
|
||||
# --- Intel ---
|
||||
if "data center gpu max 1550" in name:
|
||||
# Ponte Vecchio (PVC) - dynamic based on compute units
|
||||
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
|
||||
return 512 * max_comp_units * 1300 * 10**6
|
||||
|
||||
# --- Consumer RTX (for hobbyists) ---
|
||||
if "5090" in name:
|
||||
return 209.5e12
|
||||
if "4090" in name:
|
||||
return 165.2e12
|
||||
if "3090" in name:
|
||||
return 71e12
|
||||
|
||||
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
|
||||
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
|
||||
return float('inf')
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
"""
|
||||
Poor Man's Configurator. Probably a terrible idea. Example usage:
|
||||
$ python train.py config/override_file.py --batch_size=32
|
||||
this will first run config/override_file.py, then override batch_size to 32
|
||||
|
||||
The code in this file will be run as follows from e.g. train.py:
|
||||
>>> exec(open('configurator.py').read())
|
||||
|
||||
So it's not a Python module, it's just shuttling this code away from train.py
|
||||
The code in this script then overrides the globals()
|
||||
|
||||
I know people are not going to love this, I just really dislike configuration
|
||||
complexity and having to prepend config. to every single variable. If someone
|
||||
comes up with a better simple Python solution I am all ears.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from ast import literal_eval
|
||||
|
||||
def print0(s="",**kwargs):
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
print(s, **kwargs)
|
||||
|
||||
for arg in sys.argv[1:]:
|
||||
if '=' not in arg:
|
||||
# assume it's the name of a config file
|
||||
assert not arg.startswith('--')
|
||||
config_file = arg
|
||||
print0(f"Overriding config with {config_file}:")
|
||||
with open(config_file) as f:
|
||||
print0(f.read())
|
||||
exec(open(config_file).read())
|
||||
else:
|
||||
# assume it's a --key=value argument
|
||||
assert arg.startswith('--')
|
||||
key, val = arg.split('=')
|
||||
key = key[2:]
|
||||
if key in globals():
|
||||
try:
|
||||
# attempt to eval it it (e.g. if bool, number, or etc)
|
||||
attempt = literal_eval(val)
|
||||
except (SyntaxError, ValueError):
|
||||
# if that goes wrong, just use the string
|
||||
attempt = val
|
||||
# ensure the types match ok
|
||||
if globals()[key] is not None:
|
||||
attempt_type = type(attempt)
|
||||
default_type = type(globals()[key])
|
||||
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
|
||||
# cross fingers
|
||||
print0(f"Overriding: {key} = {attempt}")
|
||||
globals()[key] = attempt
|
||||
else:
|
||||
raise ValueError(f"Unknown config key: {key}")
|
||||
@@ -1,48 +1,199 @@
|
||||
from collections import deque
|
||||
"""
|
||||
Distributed dataloaders for pretraining.
|
||||
|
||||
Two implementations are provided:
|
||||
|
||||
1. Original (tokenizing_distributed_data_loader):
|
||||
- Streams tokens into a flat buffer, reshapes to (B, T)
|
||||
- Rows may start mid-document (no guaranteed BOS at position 0)
|
||||
- 100% token utilization, simple and efficient
|
||||
|
||||
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
|
||||
- Every row starts with BOS token
|
||||
- Documents packed using best-fit algorithm to minimize cropping
|
||||
- When no document fits remaining space, crops a document to fill exactly
|
||||
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
||||
|
||||
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
||||
there are fewer "confusing" tokens in the train/val batches as every token can
|
||||
now attend back to the BOS token and sees the full context of the document.
|
||||
(2) is the new default if you have enough data.
|
||||
Fallback to (1) if you have very limited data AND long documents.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from nanochat.common import get_dist_info
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
from nanochat.dataset import list_parquet_files
|
||||
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
|
||||
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
||||
"""
|
||||
Infinite iterator over document batches (list of text strings) from parquet files.
|
||||
|
||||
Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
|
||||
where text_batch is a list of document strings, indices track position for resumption,
|
||||
and epoch counts how many times we've cycled through the dataset (starts at 1).
|
||||
"""
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
|
||||
# get the tokenizer and the bos token
|
||||
tokenizer = get_tokenizer()
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
|
||||
# infinite iterator over document batches
|
||||
def document_batches():
|
||||
while True:
|
||||
# batch will iterate in group size of the parquet files, usually e.g. 1024 rows
|
||||
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
|
||||
# for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
|
||||
first_pass = True
|
||||
pq_idx = resume_pq_idx
|
||||
epoch = resume_epoch
|
||||
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
pq_idx = resume_pq_idx if first_pass else 0
|
||||
while pq_idx < len(parquet_paths):
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
|
||||
base_idx = resume_rg_idx // ddp_world_size
|
||||
base_idx += 1 # advance by 1 so we don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
if rg_idx >= pf.num_row_groups:
|
||||
pq_idx += 1
|
||||
continue
|
||||
resume_rg_idx = None # only do this once
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist()
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size]
|
||||
batches = document_batches()
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
|
||||
rg_idx += ddp_world_size
|
||||
pq_idx += 1
|
||||
first_pass = False
|
||||
epoch += 1
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
"""
|
||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
||||
|
||||
This is the original dataloader that streams tokens into a flat buffer and reshapes.
|
||||
Rows may start mid-document (no guaranteed BOS at position 0).
|
||||
|
||||
Supports approximate resume via state_dict.
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
needed_tokens = B * T + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
token_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
batch_index = 0
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding.
|
||||
|
||||
# Accumulate enough tokens
|
||||
while len(token_buffer) < needed_tokens:
|
||||
doc_batch = next(batches)
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
token_buffer.extend(tokens)
|
||||
batch_index += 1
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
|
||||
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
|
||||
|
||||
# Package tokens into inputs and targets, yield
|
||||
use_cuda = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
yield inputs, targets
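
# Illustrative sketch (not part of the diff): how the flat token stream above becomes
# one (B, T) batch. We read B*T+1 tokens so the final target exists, then inputs and
# targets are the same window shifted by one position.
def _demo_reshape(token_stream, B=2, T=4):
    scratch = torch.tensor(token_stream[: B * T + 1], dtype=torch.long)
    inputs = scratch[:-1].view(B, T)
    targets = scratch[1:].view(B, T)
    return inputs, targets
# e.g. tokens 0..8 -> inputs [[0,1,2,3],[4,5,6,7]], targets [[1,2,3,4],[5,6,7,8]]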
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
tokenizer, B, T, split,
|
||||
tokenizer_threads=4, tokenizer_batch_size=128,
|
||||
device="cuda", resume_state_dict=None,
|
||||
buffer_size=1000
|
||||
):
|
||||
"""
|
||||
BOS-aligned dataloader with Best-Fit Cropping.
|
||||
|
||||
Reduces token waste compared to simple greedy cropping by searching a buffer
|
||||
for documents that fit well, while maintaining 100% utilization (no padding).
|
||||
|
||||
Algorithm for each row:
|
||||
1. From buffered docs, pick the LARGEST doc that fits entirely
|
||||
2. Repeat until no doc fits
|
||||
3. When nothing fits, crop a doc to fill remaining space exactly
|
||||
|
||||
Key properties:
|
||||
- Every row starts with BOS
|
||||
- 100% utilization (no padding, every token is trained on)
|
||||
- Approximately 35% of all tokens are discarded due to cropping
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
row_capacity = T + 1
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
doc_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal pq_idx, rg_idx, epoch
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(B):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has documents
|
||||
while len(doc_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
row.extend(doc)
|
||||
else:
|
||||
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
||||
doc = doc_buffer.pop(shortest_idx)
|
||||
row.extend(doc[:remaining])
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
use_cuda = device == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
|
||||
yield inputs, targets
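
# Illustrative sketch (not part of the diff above): the best-fit cropping rule in
# isolation, on hypothetical document lengths. For one row, repeatedly place the
# largest buffered document that still fits; when nothing fits, crop the shortest
# buffered document to fill the remainder exactly (this is where the ~35% of
# tokens are lost).
def pack_one_row_bestfit(doc_lengths, row_capacity):
    """Return (placed_lengths, num_cropped_tokens) for a single packed row."""
    buffer = list(doc_lengths)
    placed, used, cropped = [], 0, 0
    while used < row_capacity and buffer:
        remaining = row_capacity - used
        fitting = [n for n in buffer if n <= remaining]
        if fitting:
            pick = max(fitting)            # largest doc that fits entirely
            buffer.remove(pick)
            placed.append(pick)
            used += pick
        else:
            shortest = min(buffer)         # nothing fits: crop the shortest doc
            buffer.remove(shortest)
            cropped += shortest - remaining
            placed.append(remaining)
            used += remaining
    return placed, cropped

# Example: row_capacity=2049 (T+1) and docs of lengths [1500, 600, 300, 2000]
# places the 2000-token doc first, then crops the 300-token doc to the last
# 49 positions, discarding 251 tokens.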
|
||||
|
||||
@@ -17,8 +17,9 @@ import signal
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from collections import deque
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from contextlib import nullcontext
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Calculator tool helpers
|
||||
@@ -37,92 +38,98 @@ def eval_with_timeout(formula, max_time=3):
|
||||
with timeout(max_time, formula):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", SyntaxWarning)
|
||||
return eval(formula)
|
||||
return eval(formula, {"__builtins__": {}}, {})
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
# print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
|
||||
return None
|
||||
|
||||
def use_calculator(expr):
|
||||
"""Evaluate a math expression safely."""
|
||||
"""
|
||||
Evaluate a Python expression safely.
|
||||
Supports both math expressions and string operations like .count()
|
||||
"""
|
||||
# Remove commas from numbers
|
||||
expr = expr.replace(",", "")
|
||||
if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
|
||||
|
||||
# Check if it's a pure math expression (old behavior)
|
||||
if all([x in "0123456789*+-/.() " for x in expr]):
|
||||
if "**" in expr: # disallow power operator
|
||||
return None
|
||||
return eval_with_timeout(expr)
|
||||
|
||||
# Check if it's a string operation we support
|
||||
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
|
||||
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
|
||||
if not all([x in allowed_chars for x in expr]):
|
||||
return None
|
||||
if "**" in expr: # for now disallow power operator, could be very expensive
|
||||
|
||||
# Disallow dangerous patterns
|
||||
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
|
||||
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
|
||||
'getattr', 'setattr', 'delattr', 'hasattr']
|
||||
expr_lower = expr.lower()
|
||||
if any(pattern in expr_lower for pattern in dangerous_patterns):
|
||||
return None
|
||||
|
||||
# Only allow .count() method for now (can expand later)
|
||||
if '.count(' not in expr:
|
||||
return None
|
||||
|
||||
# Evaluate with timeout
|
||||
return eval_with_timeout(expr)
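
# Illustrative usage sketch (not part of the diff): plain arithmetic still works,
# simple string .count() expressions are now allowed, and anything touching
# dunders or builtins is rejected by the checks above.
def _calculator_demo():
    assert use_calculator("1,234 + 8") == 1242                       # commas stripped, math path
    assert use_calculator("'strawberry'.count('r')") == 3            # string .count() path
    assert use_calculator("__import__('os').system('ls')") is None   # blocked by the filters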
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class KVCache:
|
||||
"""
|
||||
Works hand-in-hand with the GPT model to maintain the KV cache.
|
||||
Note that the .pos advances automatically after the last layer of the Transformer inserts.
|
||||
KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
|
||||
|
||||
Key differences from FA2-style cache:
|
||||
- Tensors are (B, T, H, D) not (B, H, T, D)
|
||||
- FA3 updates the cache in-place during flash_attn_with_kvcache
|
||||
- Position tracked per batch element via cache_seqlens tensor
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
|
||||
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
|
||||
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
||||
self.kv_cache = None
|
||||
self.pos = 0 # current position in time in the cache
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
|
||||
self.batch_size = batch_size
|
||||
self.max_seq_len = seq_len
|
||||
self.n_layers = num_layers
|
||||
self.n_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
# Pre-allocate cache tensors: (n_layers, B, T, H, D)
|
||||
self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
# Current sequence length per batch element (FA3 needs int32)
|
||||
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
def reset(self):
|
||||
self.pos = 0
|
||||
"""Reset cache to empty state."""
|
||||
self.cache_seqlens.zero_()
|
||||
|
||||
def get_pos(self):
|
||||
return self.pos
|
||||
"""Get current position (assumes all batch elements at same position)."""
|
||||
return self.cache_seqlens[0].item()
|
||||
|
||||
def get_layer_cache(self, layer_idx):
|
||||
"""Return (k_cache, v_cache) views for a specific layer."""
|
||||
return self.k_cache[layer_idx], self.v_cache[layer_idx]
|
||||
|
||||
def advance(self, num_tokens):
|
||||
"""Advance the cache position by num_tokens."""
|
||||
self.cache_seqlens += num_tokens
|
||||
|
||||
def prefill(self, other):
|
||||
"""
|
||||
Prefill given another KV cache. Optionally expand along batch dim.
|
||||
This is used when we do batch 1 prefill and then want to generate
|
||||
multiple samples in parallel from there.
|
||||
Copy cached KV from another cache into this one.
|
||||
Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
|
||||
"""
|
||||
# 1) validate the shapes
|
||||
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
|
||||
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
|
||||
for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
|
||||
if ix in [0, 1, 3, 5]:
|
||||
# num_layers, batch_size, num_heads, head_dim must match
|
||||
assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
|
||||
elif ix == 2:
|
||||
# batch_size can be expanded
|
||||
assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
|
||||
elif ix == 4:
|
||||
# seq_len: self must be longer than other
|
||||
assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
|
||||
# 2) initialize the cache
|
||||
dtype, device = other.kv_cache.dtype, other.kv_cache.device
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
|
||||
# 3) copy the data over
|
||||
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
|
||||
# 4) update the pos
|
||||
self.pos = other.pos
|
||||
|
||||
def insert_kv(self, layer_idx, k, v):
|
||||
# Lazy initialize the cache here because we need to know the dtype/device
|
||||
if self.kv_cache is None:
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
|
||||
# Insert new keys/values to the cache and return the full cache so far
|
||||
B, H, T_add, D = k.size()
|
||||
t0, t1 = self.pos, self.pos + T_add
|
||||
# Dynamically grow the cache if needed
|
||||
if t1 > self.kv_cache.size(4):
|
||||
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
||||
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
||||
current_shape = list(self.kv_cache.shape)
|
||||
current_shape[4] = t_needed
|
||||
self.kv_cache.resize_(current_shape)
|
||||
# Insert k, v into the cache
|
||||
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
|
||||
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
|
||||
# Return the full cached keys/values up to current position (as a view)
|
||||
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
|
||||
value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
|
||||
# Increment pos after the last layer of the Transformer processes
|
||||
if layer_idx == self.kv_cache.size(0) - 1:
|
||||
self.pos = t1
|
||||
return key_view, value_view
|
||||
|
||||
assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
|
||||
assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
|
||||
assert self.max_seq_len >= other.max_seq_len
|
||||
other_pos = other.get_pos()
|
||||
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
|
||||
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
|
||||
self.cache_seqlens.fill_(other_pos)
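
# Illustrative sketch (not part of the diff): driving the FA3-style KVCache above
# during one decode step. Assumes `from nanochat.flash_attention import flash_attn`
# and is schematic only -- the real model computes fresh q/k/v per layer, whereas
# here the same tensors are reused just to show the cache API.
def _kvcache_decode_step(kv_cache, q, k_new, v_new):
    # q: (B, T_new, H, D); k_new, v_new: (B, T_new, H_kv, D)
    for layer_idx in range(kv_cache.n_layers):
        k_cache, v_cache = kv_cache.get_layer_cache(layer_idx)
        y = flash_attn.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k_new, v=v_new,
            cache_seqlens=kv_cache.cache_seqlens, causal=True,
        )
    kv_cache.advance(q.size(1))  # bump cache_seqlens once, after the last layer
    return y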
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@torch.inference_mode()
|
||||
@@ -131,7 +138,7 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
|
||||
assert temperature >= 0.0, "temperature must be non-negative"
|
||||
if temperature == 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
if top_k is not None:
|
||||
if top_k is not None and top_k > 0:
|
||||
k = min(top_k, logits.size(-1))
|
||||
vals, idx = torch.topk(logits, k, dim=-1)
|
||||
vals = vals / temperature
|
||||
@@ -165,6 +172,13 @@ class Engine:
|
||||
"""Same as generate, but does single prefill and then clones the KV cache."""
|
||||
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
|
||||
device = self.model.get_device()
|
||||
# NOTE: setting the dtype here and in this way is an ugly hack.
|
||||
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
|
||||
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
|
||||
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
|
||||
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
|
||||
# In particular, the KVCache should allocate its tensors lazily
|
||||
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
|
||||
@@ -183,19 +197,21 @@ class Engine:
|
||||
kv_cache_prefill = KVCache(
|
||||
batch_size=1,
|
||||
seq_len=len(tokens),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
|
||||
logits = logits[:, -1, :]
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
|
||||
|
||||
# 2) Replicate the KV cache for each sample/row
|
||||
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
|
||||
kv_cache_decode = KVCache(
|
||||
batch_size=num_samples,
|
||||
seq_len=kv_length_hint,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
kv_cache_decode.prefill(kv_cache_prefill)
|
||||
@@ -206,7 +222,6 @@ class Engine:
|
||||
|
||||
# 4) Main generation loop
|
||||
num_generated = 0
|
||||
first_iteration = True
|
||||
while True:
|
||||
# Stop condition: we've reached max tokens
|
||||
if max_tokens is not None and num_generated >= max_tokens:
|
||||
@@ -215,18 +230,9 @@ class Engine:
|
||||
if all(state.completed for state in row_states):
|
||||
break
|
||||
|
||||
# Get sampled tokens - either from prefill or from forward pass
|
||||
if first_iteration:
|
||||
# Use the tokens we already sampled from prefill
|
||||
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
|
||||
# TODO: we should sample a token for each row instead of broadcasting
|
||||
first_iteration = False
|
||||
else:
|
||||
# Forward the model and get the next token for each row
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
|
||||
logits = logits[:, -1, :] # (B, vocab_size) at last time step
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
# Sample the next token for each row
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
|
||||
# Process each row: choose the next token, update state, optional tool use
|
||||
token_column = [] # contains the next token id along each row
|
||||
@@ -263,8 +269,10 @@ class Engine:
|
||||
# Yield the token column
|
||||
yield token_column, token_masks
|
||||
num_generated += 1
|
||||
# Prepare ids for next iteration
|
||||
|
||||
# Prepare logits for next iteration
|
||||
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
|
||||
|
||||
def generate_batch(self, tokens, num_samples=1, **kwargs):
|
||||
"""
|
||||
@@ -298,7 +306,10 @@ if __name__ == "__main__":
|
||||
"""
|
||||
import time
|
||||
# init compute
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
bos_token_id = tokenizer.get_bos_token_id()
|
||||
@@ -311,10 +322,11 @@ if __name__ == "__main__":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
stream = model.generate(prompt_tokens, **kwargs)
|
||||
for token in stream:
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token in stream:
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
@@ -326,11 +338,12 @@ if __name__ == "__main__":
|
||||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
|
||||
@@ -127,8 +127,6 @@ def chdir(root):
|
||||
os.chdir(root)
|
||||
try:
|
||||
yield
|
||||
except BaseException as exc:
|
||||
raise exc
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
|
||||
178  nanochat/flash_attention.py  Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
# Training (no KV cache)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
|
||||
# Inference (with KV cache)
|
||||
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Detection: Try to load FA3 on Hopper+ GPUs
|
||||
# =============================================================================
|
||||
def _load_flash_attention_3():
|
||||
"""Try to load Flash Attention 3 (requires Hopper+ GPU)."""
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 9: # Hopper is sm90
|
||||
return None
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
from kernels import get_kernel
|
||||
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
_fa3 = _load_flash_attention_3()
|
||||
HAS_FA3 = _fa3 is not None
|
||||
|
||||
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
||||
_override_impl = None
|
||||
|
||||
|
||||
def _use_fa3():
|
||||
"""Determine whether to use FA3 based on availability and override."""
|
||||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
return True
|
||||
if _override_impl == 'sdpa':
|
||||
return False
|
||||
return HAS_FA3 # auto
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
"""
|
||||
SDPA attention with sliding window support.
|
||||
q, k, v are (B, H, T, D) format.
|
||||
"""
|
||||
Tq = q.size(2)
|
||||
Tk = k.size(2)
|
||||
window = window_size[0]
|
||||
|
||||
# Full context, same length
|
||||
if (window < 0 or window >= Tq) and Tq == Tk:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
|
||||
# Single token generation
|
||||
if Tq == 1:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
|
||||
# Need explicit mask
|
||||
device = q.device
|
||||
if Tq == Tk:
|
||||
# Causal + sliding window
|
||||
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
|
||||
if window > 0 and window < Tq:
|
||||
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
|
||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
else:
|
||||
# Chunk inference: attend to prefix + causal within chunk
|
||||
prefix_len = Tk - Tq
|
||||
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
|
||||
mask[:, :prefix_len] = True
|
||||
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
|
||||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
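
# Illustrative sketch (not part of the new file): the "chunk inference" mask built
# above, for a hypothetical 2-token chunk appended after a 3-token cached prefix
# (Tq=2, Tk=5). Each query sees the full prefix plus a causal pattern in the chunk.
def _demo_chunk_mask(Tq=2, Tk=5):
    prefix_len = Tk - Tq
    mask = torch.zeros(Tq, Tk, dtype=torch.bool)
    mask[:, :prefix_len] = True
    mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, dtype=torch.bool))
    # -> [[True, True, True, True, False],
    #     [True, True, True, True, True ]]
    return mask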
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Public API: Same interface as FA3
|
||||
# =============================================================================
|
||||
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
||||
"""
|
||||
Flash Attention for training (no KV cache).
|
||||
|
||||
Args:
|
||||
q, k, v: Tensors of shape (B, T, H, D)
|
||||
causal: Whether to use causal masking
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
enable_gqa = q.size(1) != k.size(1)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
|
||||
return y.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
|
||||
causal=False, window_size=(-1, -1)):
|
||||
"""
|
||||
Flash Attention with KV cache for inference.
|
||||
|
||||
FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
|
||||
|
||||
Args:
|
||||
q: Queries, shape (B, T_new, H, D)
|
||||
k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
|
||||
k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
|
||||
cache_seqlens: Current position in cache, shape (B,) int32
|
||||
causal: Whether to use causal masking
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
)
|
||||
|
||||
# SDPA fallback: manually manage KV cache
|
||||
B, T_new, H, D = q.shape
|
||||
pos = cache_seqlens[0].item() # assume uniform position across batch
|
||||
|
||||
# Insert new k, v into cache (in-place, matching FA3 behavior)
|
||||
if k is not None and v is not None:
|
||||
k_cache[:, pos:pos+T_new, :, :] = k
|
||||
v_cache[:, pos:pos+T_new, :, :] = v
|
||||
|
||||
# Get full cache up to current position + new tokens
|
||||
end_pos = pos + T_new
|
||||
k_full = k_cache[:, :end_pos, :, :]
|
||||
v_full = v_cache[:, :end_pos, :, :]
|
||||
|
||||
# Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
|
||||
q_sdpa = q.transpose(1, 2)
|
||||
k_sdpa = k_full.transpose(1, 2)
|
||||
v_sdpa = v_full.transpose(1, 2)
|
||||
|
||||
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
|
||||
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
|
||||
|
||||
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Export: flash_attn module interface (drop-in replacement for FA3)
|
||||
# =============================================================================
|
||||
from types import SimpleNamespace
|
||||
flash_attn = SimpleNamespace(
|
||||
flash_attn_func=flash_attn_func,
|
||||
flash_attn_with_kvcache=flash_attn_with_kvcache,
|
||||
)
|
||||
394  nanochat/gpt.py
@@ -8,10 +8,10 @@ Notable features:
|
||||
- norm after token embedding
|
||||
- no learnable params in rmsnorm
|
||||
- no bias in linear layers
|
||||
- Multi-Query Attention (MQA) support for more efficient inference
|
||||
- Group-Query Attention (GQA) support for more efficient inference
|
||||
- Flash Attention 3 integration
|
||||
"""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -23,14 +23,21 @@ from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
sequence_len: int = 1024
|
||||
vocab_size: int = 50304
|
||||
sequence_len: int = 2048
|
||||
vocab_size: int = 32768
|
||||
n_layer: int = 12
|
||||
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (MQA)
|
||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
|
||||
|
||||
def norm(x):
|
||||
@@ -38,28 +45,52 @@ def norm(x):
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
class BigramEmbed(nn.Module):
|
||||
"""
|
||||
Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
|
||||
Following modded-nanogpt's approach: single hash, no gating.
|
||||
|
||||
For each position t, hashes (token[t-1], token[t]) to an index in a large
|
||||
embedding table. This provides O(1) lookup for local 2-gram patterns,
|
||||
offloading static pattern reconstruction from the transformer layers.
|
||||
|
||||
Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
|
||||
Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
|
||||
"""
|
||||
def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
|
||||
super().__init__()
|
||||
self.bigram_vocab_size = vocab_size * table_multiplier
|
||||
self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
|
||||
|
||||
def forward(self, idx: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
idx: (B, T) token ids
|
||||
Returns: (B, T, embed_dim) bigram embeddings
|
||||
"""
|
||||
# Hash (prev_token, curr_token) -> index
|
||||
# Position 0 gets a reserved index (no valid bigram)
|
||||
rand_int_1 = 36313
|
||||
rand_int_2 = 27191
|
||||
mod = self.bigram_vocab_size - 1
|
||||
|
||||
h = torch.empty_like(idx, dtype=torch.long)
|
||||
h[:, 0] = mod # reserved index for position 0
|
||||
h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
|
||||
|
||||
return self.embed(h)
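
# Illustrative check (not part of the diff): the same hash on a pair of hypothetical
# token ids, with vocab_size=32768 and table_multiplier=5 (so mod = 163839).
def _demo_bigram_hash(prev_id=17, curr_id=42, vocab_size=32768, table_multiplier=5):
    mod = vocab_size * table_multiplier - 1
    return (36313 * curr_id ^ 27191 * prev_id) % mod  # index into the bigram table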
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
|
||||
return layer_idx % 2 == (n_layer - 1) % 2
|
||||
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
||||
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
|
||||
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
||||
y2 = x1 * (-sin) + x2 * cos
|
||||
out = torch.cat([y1, y2], 3) # re-assemble
|
||||
out = out.to(x.dtype) # ensure input/output dtypes match
|
||||
return out
|
||||
|
||||
|
||||
def repeat_kv(x, n_rep):
|
||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||
if n_rep == 1:
|
||||
return x
|
||||
bs, n_kv_heads, slen, head_dim = x.shape
|
||||
return (
|
||||
x[:, :, None, :, :]
|
||||
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
||||
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
||||
)
|
||||
|
||||
return torch.cat([y1, y2], 3)
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
@@ -75,53 +106,50 @@ class CausalSelfAttention(nn.Module):
|
||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Project the input to get queries, keys, and values
|
||||
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
|
||||
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
|
||||
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
||||
if ve is not None:
|
||||
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
||||
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
|
||||
v = v + gate.unsqueeze(-1) * ve
|
||||
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
||||
|
||||
# Apply KV cache: insert current k,v into cache, get the full view so far
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
|
||||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Apply MQA: replicate the key/value heads for each query head
|
||||
nrep = self.n_head // self.n_kv_head
|
||||
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
if kv_cache is None:
|
||||
# Training: causal attention with optional sliding window
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||
prefix_len = Tk - Tq
|
||||
if prefix_len > 0: # can't be negative but could be zero
|
||||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
# Inference: use flash_attn_with_kvcache which handles cache management
|
||||
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache,
|
||||
k=k, v=v,
|
||||
cache_seqlens=kv_cache.cache_seqlens,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
# Advance position after last layer processes
|
||||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
return y
|
||||
|
||||
@@ -145,24 +173,51 @@ class Block(nn.Module):
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, pad_vocab_size_to=64):
|
||||
"""
|
||||
NOTE a major footgun: this __init__ function runs in meta device context (!!)
|
||||
Therefore, any calculations inside here are shapes and dtypes only, no actual data.
|
||||
=> We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# Compute per-layer window sizes for sliding window attention
|
||||
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
|
||||
self.window_sizes = self._compute_window_sizes(config)
|
||||
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
|
||||
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
|
||||
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
|
||||
if padded_vocab_size != config.vocab_size:
|
||||
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
|
||||
self.transformer = nn.ModuleDict({
|
||||
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's fake
|
||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
# bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Bigram hash embeddings: O(1) lookup for local 2-gram patterns
|
||||
self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them, but assert fail if we ever reach that amount.
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
# In the future we can dynamically grow the cache, for now it's fine.
|
||||
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
||||
head_dim = config.n_embd // config.n_head
|
||||
@@ -170,36 +225,68 @@ class GPT(nn.Module):
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
# zero out classifier weights
|
||||
torch.nn.init.zeros_(self.lm_head.weight)
|
||||
# zero out c_proj weights in all blocks
|
||||
"""
|
||||
Initialize the full model in this one function for maximum clarity.
|
||||
|
||||
wte (embedding): normal, std=1.0
|
||||
lm_head: normal, std=0.001
|
||||
for each block:
|
||||
attn.c_q: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_k: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_v: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_proj: zeros
|
||||
mlp.c_fc: uniform, std=1/sqrt(n_embd)
|
||||
mlp.c_proj: zeros
|
||||
"""
|
||||
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
|
||||
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||
n_embd = self.config.n_embd
|
||||
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
|
||||
for block in self.transformer.h:
|
||||
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
||||
# init the rotary embeddings
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
|
||||
self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
|
||||
|
||||
# Bigram embeddings: zero init so it starts as identity
|
||||
nn.init.zeros_(self.bigram_embed.embed.weight)
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
|
||||
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
|
||||
for block in self.transformer.h:
|
||||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
|
||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
self.bigram_embed.to(dtype=torch.bfloat16)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
# https://arxiv.org/pdf/2310.17813
|
||||
fan_out = module.weight.size(0)
|
||||
fan_in = module.weight.size(1)
|
||||
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
||||
|
||||
# TODO: bump base theta more, e.g. 100K is more common more recently
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
# autodetect the device from model embeddings
|
||||
if device is None:
|
||||
device = self.transformer.wte.weight.device
|
||||
@@ -215,39 +302,128 @@ class GPT(nn.Module):
|
||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
def _compute_window_sizes(self, config):
|
||||
"""
|
||||
Compute per-layer window sizes for sliding window attention.
|
||||
|
||||
Returns list of (left, right) tuples for FA3's window_size parameter:
|
||||
- left: how many tokens before current position to attend to (-1 = unlimited)
|
||||
- right: how many tokens after current position to attend to (0 for causal)
|
||||
|
||||
Pattern string is tiled across layers. Final layer always gets L (full context).
|
||||
Characters: L=long (full context), S=short (half context)
|
||||
"""
|
||||
pattern = config.window_pattern.upper()
|
||||
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||
# Map characters to window sizes
|
||||
long_window = config.sequence_len
|
||||
short_window = long_window // 2
|
||||
char_to_window = {
|
||||
"L": (long_window, 0),
|
||||
"S": (short_window, 0),
|
||||
}
|
||||
# Tile pattern across layers
|
||||
window_sizes = []
|
||||
for layer_idx in range(config.n_layer):
|
||||
char = pattern[layer_idx % len(pattern)]
|
||||
window_sizes.append(char_to_window[char])
|
||||
# Final layer always gets full context
|
||||
window_sizes[-1] = (long_window, 0)
|
||||
return window_sizes
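
# Illustrative check (hypothetical config values, not part of the diff): with
# sequence_len=2048 and window_pattern="SSSL" tiled over 6 layers, the per-layer
# windows come out as S S S L S S, with the final layer forced back to full context.
def _demo_window_sizes(sequence_len=2048, pattern="SSSL", n_layer=6):
    char_to_window = {"L": (sequence_len, 0), "S": (sequence_len // 2, 0)}
    windows = [char_to_window[pattern[i % len(pattern)]] for i in range(n_layer)]
    windows[-1] = (sequence_len, 0)  # final layer always long
    return windows
    # -> [(1024, 0), (1024, 0), (1024, 0), (2048, 0), (1024, 0), (2048, 0)]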
|
||||
|
||||
def get_device(self):
|
||||
return self.transformer.wte.weight.device
|
||||
|
||||
def estimate_flops(self):
|
||||
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
|
||||
"""
|
||||
Return the estimated FLOPs per token for the model (forward + backward).
|
||||
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
|
||||
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
|
||||
On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
|
||||
With sliding windows, effective_seq_len varies per layer (capped by window size).
|
||||
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
|
||||
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
|
||||
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
|
||||
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
nparams_embedding = self.transformer.wte.weight.numel()
|
||||
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
bigram_embed_numel = self.bigram_embed.embed.weight.numel()
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
for window_size in self.window_sizes:
|
||||
window = window_size[0] # (left, right) tuple, we use left
|
||||
effective_seq = t if window < 0 else min(window, t)
|
||||
attn_flops += 12 * h * q * effective_seq
|
||||
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
|
||||
return num_flops_per_token
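
# Worked example with hypothetical numbers (not from the diff): suppose the counted
# matmul parameters are 100M, n_head=6, head_dim=128, sequence_len=2048, and 6 layers
# use SSSL-style windows [1024, 1024, 1024, 2048, 1024, 2048]:
#   dense term : 6 * 100e6                           = 600.0 MFLOPs/token
#   attn term  : 12 * 6 * 128 * (4*1024 + 2*2048)    = 9216 * 8192 ~= 75.5 MFLOPs/token
#   total      : ~675.5 MFLOPs per token (forward + backward)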
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
|
||||
def num_scaling_params(self):
|
||||
"""
|
||||
Return detailed parameter counts for scaling law analysis.
|
||||
Different papers use different conventions:
|
||||
- Kaplan et al. excluded embedding parameters
|
||||
- Chinchilla included all parameters
|
||||
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
|
||||
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
|
||||
|
||||
Returns a dict with counts for each parameter group, so downstream analysis
|
||||
can experiment with which combination gives the cleanest scaling laws.
|
||||
"""
|
||||
# Count each group separately (mirrors the grouping in setup_optimizers)
|
||||
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
||||
bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
|
||||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel()
|
||||
total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
'wte': wte,
|
||||
'bigram_embed': bigram_embed,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
'scalars': scalars,
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||
# Separate out all parameters into groups
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
|
||||
# Create the AdamW optimizer for the embedding and lm_head
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
bigram_embed_params = list(self.bigram_embed.parameters())
|
||||
bigram_lambda_params = [self.bigram_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
|
||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
if rank == 0:
|
||||
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||
dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||
dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
|
||||
dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
|
||||
]
|
||||
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
|
||||
MuonFactory = DistMuon if ddp else Muon
|
||||
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
||||
# Combine the two optimizers into one list
|
||||
@@ -260,7 +436,7 @@ class GPT(nn.Module):
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
@@ -269,26 +445,30 @@ class GPT(nn.Module):
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx)
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
|
||||
x = norm(x)
|
||||
for block in self.transformer.h:
|
||||
x = block(x, cos_sin, kv_cache)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
softcap = 15
|
||||
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
||||
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
|
||||
logits = logits[..., :self.config.vocab_size] # slice to remove padding
|
||||
logits = logits.float() # switch to fp32 for logit softcap and loss computation
|
||||
logits = softcap * torch.tanh(logits / softcap) # squash the logits
|
||||
|
||||
if targets is not None:
|
||||
# training mode: compute and return the loss
|
||||
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
||||
logits = self.lm_head(x)
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
# training: given the targets, compute and return the loss
|
||||
# TODO experiment with chunked cross-entropy?
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
else:
|
||||
# inference mode: compute and return the logits
|
||||
logits = self.lm_head(x)
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
# inference: just return the logits directly
|
||||
return logits
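As a quick standalone check (not part of this diff), here is what the tanh softcap above does numerically; the inputs are made up for illustration:

import torch
softcap = 15
z = torch.tensor([0.5, 5.0, 50.0])
print(softcap * torch.tanh(z / softcap))  # ~[0.50, 4.82, 14.96]: near-identity for small logits, saturating toward +/-15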
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@@ -9,9 +9,9 @@ import torch.distributed as dist
|
||||
def evaluate_bpb(model, batches, steps, token_bytes):
|
||||
"""
|
||||
Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
|
||||
which is a tokenization vocab size-indepedent metric, meaning you are still comparing
|
||||
which is a tokenization vocab size-independent metric, meaning you are still comparing
|
||||
apples:apples if you change the vocab size. The way this works is that instead of just
|
||||
calculating the average loss as usual, you calculate the sum loss, and indepependently
|
||||
calculating the average loss as usual, you calculate the sum loss, and independently
|
||||
also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
|
||||
the number of bytes that the target tokens represent.
|
||||
|
||||
@@ -59,5 +59,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
||||
# move both to cpu, calculate bpb and return
|
||||
total_nats = total_nats.item()
|
||||
total_bytes = total_bytes.item()
|
||||
if total_bytes == 0:
|
||||
return float('inf')
|
||||
bpb = total_nats / (math.log(2) * total_bytes)
|
||||
return bpb
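A tiny standalone sketch of the bpb formula above (the numbers are made up, only the units matter):

import math
sum_nats = 1200.0    # summed cross-entropy over all target tokens, in nats
sum_bytes = 2000     # summed UTF-8 byte length of those same target tokens
bpb = sum_nats / (math.log(2) * sum_bytes)
print(bpb)           # ~0.87 bits per byte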
|
||||
|
||||
401 nanochat/muon.py
@@ -1,39 +1,96 @@
|
||||
"""
|
||||
Muon optimizer from Keller et al.
|
||||
Also a lot of borrowing of ideas from modded-nanogpt.
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, we use an alternative to the Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in the nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
||||
"""
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
X = G.bfloat16()
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
# Ensure spectral norm is at most 1
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
return X
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
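For intuition, a tiny standalone example of the cautious mask used in the update above (weight decay is only applied where the orthogonalized update and the weight share a sign); the values are made up:

import torch
g = torch.tensor([0.1, -0.2, 0.3])   # orthogonalized update
w = torch.tensor([1.0, 2.0, -3.0])   # current weights
mask = (g * w) >= 0
print(mask)  # tensor([ True, False, False]) -> only the first weight gets decayed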
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
@@ -54,74 +111,112 @@ class Muon(torch.optim.Optimizer):
|
||||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
params: list[Tensor] = [*params]
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
# Group by shape so we can stack tensors
|
||||
shapes = sorted({p.shape for p in params})
|
||||
param_groups = []
|
||||
for size in {p.numel() for p in params}:
|
||||
group = dict(params=[p for p in params if p.numel() == size])
|
||||
param_groups.append(group)
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
param_groups.append(dict(params=group_params))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
g = p.grad
|
||||
assert g is not None
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
|
||||
if not params:
|
||||
continue
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
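A hedged usage sketch for the single-process Muon above; the module shapes and the toy loss are illustrative, not taken from the training script:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
matrix_params = [p for p in model.parameters() if p.ndim == 2]   # Muon expects 2D parameters only
opt = Muon(matrix_params, lr=0.02, momentum=0.95, weight_decay=0.0)
loss = model(torch.randn(8, 1024)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()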
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
|
||||
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
||||
- reduce_scatter(AVG) for gradient averaging
|
||||
- all_gather to replicate updated weights
|
||||
|
||||
Notes:
|
||||
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
|
||||
params like embeddings or scalars.
|
||||
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
|
||||
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
|
||||
consolidate states beforehand.
|
||||
|
||||
Args:
|
||||
params: iterable of Tensors
|
||||
lr: learning rate
|
||||
momentum: momentum coefficient in [0,1)
|
||||
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
||||
ns_steps: number of Newton–Schulz iterations for the orthogonalization
|
||||
Distributed version of the Muon optimizer.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
nesterov: bool = True, ns_steps: int = 5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
params = list(params)
|
||||
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
|
||||
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
@@ -131,57 +226,127 @@ class DistMuon(torch.optim.Optimizer):
|
||||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# Kick off all the reduce scatter operations to average up the gradients across all ranks
|
||||
all_reduce_futures = []
|
||||
# First pass: stack grads and kick off reduce_scatter for each group
|
||||
group_infos = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank
|
||||
# each rank stacks up its chunk of world_size params into a list
|
||||
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
|
||||
# pad rs_input with the zero buffer to complete the group
|
||||
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
|
||||
# the output buffer gets strided across the group based on the rank
|
||||
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
|
||||
# reduce scatter the gradients within this group of world_size params
|
||||
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
all_reduce_futures.append(work)
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Now each rank computes the update and gathers
|
||||
future_idx = 0
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
# Async reduce_scatter on the stacked tensor
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
group_infos.append(dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
))
|
||||
|
||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
||||
all_gather_futures = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank # calculate the index of the param that this rank owns
|
||||
# Wait for the reduce scatter to complete
|
||||
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
|
||||
future_idx += 1
|
||||
# Owner computes the Muon update, result is in its param
|
||||
if owner_idx < len(params):
|
||||
p = params[owner_idx]
|
||||
g = p.grad # now averaged across ranks
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1.0 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
||||
p.add_(g, alpha=-group["lr"] * scale)
|
||||
# Replicate updated parameters to all ranks
|
||||
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
||||
ag_output = params[base_i:base_i + world_size]
|
||||
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
|
||||
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
|
||||
all_gather_futures.append(work)
|
||||
for group, info in zip(self.param_groups, group_infos):
|
||||
info["reduce_future"].wait()
|
||||
|
||||
# Wait for all work to finish
|
||||
torch.futures.collect_all(all_gather_futures).wait()
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
all_gather_futures.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
for info in all_gather_futures:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
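To make the block ownership above concrete, a toy standalone calculation with assumed values, mirroring the chunk_size / start_idx arithmetic in the code:

world_size, n_params = 8, 11
chunk_size = (n_params + world_size - 1) // world_size   # ceil(11/8) = 2
for rank in range(world_size):
    start_idx = rank * chunk_size
    num_owned = min(chunk_size, max(0, n_params - start_idx))
    print(rank, start_idx, num_owned)   # rank 5 owns just 1 param, ranks 6-7 own 0 (their slots are zero-padded)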
|
||||
|
||||
@@ -16,8 +16,11 @@ def run_command(cmd):
|
||||
"""Run a shell command and return output, or None if it fails."""
|
||||
try:
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
# Return stdout if we got output (even if some files in xargs failed)
|
||||
if result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
if result.returncode == 0:
|
||||
return ""
|
||||
return None
|
||||
except:
|
||||
return None
|
||||
@@ -160,17 +163,28 @@ Generated: {timestamp}
|
||||
|
||||
"""
|
||||
|
||||
# bloat metrics: package all of the source code and assess its weight
|
||||
packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
|
||||
num_chars = len(packaged)
|
||||
num_lines = len(packaged.split('\n'))
|
||||
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
|
||||
num_tokens = num_chars // 4 # assume approximately 4 chars per token
|
||||
# bloat metrics: count lines/chars in git-tracked source files only
|
||||
extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
|
||||
git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
|
||||
files_output = run_command(f"git ls-files -- {git_patterns}")
|
||||
file_list = [f for f in (files_output or '').split('\n') if f]
|
||||
num_files = len(file_list)
|
||||
num_lines = 0
|
||||
num_chars = 0
|
||||
if num_files > 0:
|
||||
wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
|
||||
if wc_output:
|
||||
total_line = wc_output.strip().split('\n')[-1]
|
||||
parts = total_line.split()
|
||||
if len(parts) >= 2:
|
||||
num_lines = int(parts[0])
|
||||
num_chars = int(parts[1])
|
||||
num_tokens = num_chars // 4 # assume approximately 4 chars per token
|
||||
|
||||
# count dependencies via uv.lock
|
||||
uv_lock_lines = 0
|
||||
if os.path.exists('uv.lock'):
|
||||
with open('uv.lock', 'r') as f:
|
||||
with open('uv.lock', 'r', encoding='utf-8') as f:
|
||||
uv_lock_lines = len(f.readlines())
|
||||
|
||||
header += f"""
|
||||
@@ -241,7 +255,7 @@ class Report:
|
||||
slug = slugify(section)
|
||||
file_name = f"{slug}.md"
|
||||
file_path = os.path.join(self.report_dir, file_name)
|
||||
with open(file_path, "w") as f:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"## {section}\n")
|
||||
f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
for item in data:
|
||||
@@ -272,11 +286,11 @@ class Report:
|
||||
final_metrics = {} # the most important final metrics we'll add as table at the end
|
||||
start_time = None
|
||||
end_time = None
|
||||
with open(report_file, "w") as out_file:
|
||||
with open(report_file, "w", encoding="utf-8") as out_file:
|
||||
# write the header first
|
||||
header_file = os.path.join(report_dir, "header.md")
|
||||
if os.path.exists(header_file):
|
||||
with open(header_file, "r") as f:
|
||||
with open(header_file, "r", encoding="utf-8") as f:
|
||||
header_content = f.read()
|
||||
out_file.write(header_content)
|
||||
start_time = extract_timestamp(header_content, "Run started:")
|
||||
@@ -293,7 +307,7 @@ class Report:
|
||||
if not os.path.exists(section_file):
|
||||
print(f"Warning: {section_file} does not exist, skipping")
|
||||
continue
|
||||
with open(section_file, "r") as in_file:
|
||||
with open(section_file, "r", encoding="utf-8") as in_file:
|
||||
section = in_file.read()
|
||||
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
|
||||
if "rl" not in file_name:
|
||||
@@ -373,7 +387,7 @@ class Report:
|
||||
header_file = os.path.join(self.report_dir, "header.md")
|
||||
header = generate_header()
|
||||
start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(header_file, "w") as f:
|
||||
with open(header_file, "w", encoding="utf-8") as f:
|
||||
f.write(header)
|
||||
f.write(f"Run started: {start_time}\n\n---\n\n")
|
||||
print(f"Reset report and wrote header to {header_file}")
|
||||
|
||||
@@ -26,7 +26,7 @@ SPECIAL_TOKENS = [
|
||||
|
||||
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
|
||||
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
|
||||
# I haven't validated that this is actually a good idea, TODO.
|
||||
# I verified that 2 is the sweet spot for a vocab size of 32K: 1 is a bit worse, 3 was worse still.
|
||||
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
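A quick standalone illustration of the \p{N}{1,2} choice, assuming the `regex` package (already a project dependency); long digit runs get chopped into pieces of at most two digits before BPE:

import regex
pat = regex.compile(SPLIT_PATTERN)
print(pat.findall("pi is 3.14159"))   # digits come out in 1-2 digit chunks, e.g. ..., '3', '.', '14', '15', '9'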
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -103,9 +103,10 @@ class HuggingFaceTokenizer:
|
||||
def id_to_token(self, id):
|
||||
return self.tokenizer.id_to_token(id)
|
||||
|
||||
def _encode_one(self, text, prepend=None, append=None):
|
||||
def _encode_one(self, text, prepend=None, append=None, num_threads=None):
|
||||
# encode a single string
|
||||
# prepend/append can be either the string of a special token or a token id directly.
|
||||
# num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
|
||||
assert isinstance(text, str)
|
||||
ids = []
|
||||
if prepend is not None:
|
||||
@@ -122,7 +123,14 @@ class HuggingFaceTokenizer:
|
||||
return self.tokenizer.token_to_id(text)
|
||||
|
||||
def get_bos_token_id(self):
|
||||
# Different HuggingFace models use different BOS tokens and there is little consistency
|
||||
# 1) attempt to find a <|bos|> token
|
||||
bos = self.encode_special("<|bos|>")
|
||||
# 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
|
||||
if bos is None:
|
||||
bos = self.encode_special("<|endoftext|>")
|
||||
# 3) if these fail, it's better to crash than to silently return None
|
||||
assert bos is not None, "Failed to find BOS token in tokenizer"
|
||||
return bos
|
||||
|
||||
def encode(self, text, *args, **kwargs):
|
||||
@@ -341,16 +349,19 @@ class RustBPETokenizer:
|
||||
mask = mask[:max_tokens]
|
||||
return ids, mask
|
||||
|
||||
def visualize_tokenization(self, ids, mask):
|
||||
def visualize_tokenization(self, ids, mask, with_token_id=False):
|
||||
"""Small helper function useful in debugging: visualize the tokenization of render_conversation"""
|
||||
RED = '\033[91m'
|
||||
GREEN = '\033[92m'
|
||||
RESET = '\033[0m'
|
||||
GRAY = '\033[90m'
|
||||
tokens = []
|
||||
for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
|
||||
token_str = self.decode([token_id])
|
||||
color = GREEN if mask_val == 1 else RED
|
||||
tokens.append(f"{color}{token_str}{RESET}")
|
||||
if with_token_id:
|
||||
tokens.append(f"{GRAY}({token_id}){RESET}")
|
||||
return '|'.join(tokens)
|
||||
|
||||
def render_for_completion(self, conversation):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
|
||||
<title>NanoChat</title>
|
||||
<link rel="icon" type="image/svg+xml" href="/logo.svg">
|
||||
<style>
|
||||
@@ -14,11 +14,16 @@
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
html, body{
|
||||
height: 100%;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
|
||||
background-color: #ffffff;
|
||||
color: #111827;
|
||||
min-height: 100vh;
|
||||
min-height: 100dvh;
|
||||
margin: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
@@ -107,7 +112,6 @@
|
||||
.message.assistant .message-content {
|
||||
background: transparent;
|
||||
border: none;
|
||||
padding: 0.25rem 0;
|
||||
cursor: pointer;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem;
|
||||
@@ -144,6 +148,7 @@
|
||||
.input-container {
|
||||
background-color: #ffffff;
|
||||
padding: 1rem;
|
||||
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
|
||||
}
|
||||
|
||||
.input-wrapper {
|
||||
|
||||
@@ -7,31 +7,27 @@ requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"datasets>=4.0.0",
|
||||
"fastapi>=0.117.1",
|
||||
"files-to-prompt>=0.6",
|
||||
"numpy==1.26.4",
|
||||
"ipykernel>=7.1.0",
|
||||
"kernels>=0.11.7",
|
||||
"matplotlib>=3.10.8",
|
||||
"psutil>=7.1.0",
|
||||
"python-dotenv>=1.2.1",
|
||||
"regex>=2025.9.1",
|
||||
"rustbpe>=0.1.0",
|
||||
"scipy>=1.15.3",
|
||||
"setuptools>=80.9.0",
|
||||
"tabulate>=0.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torch>=2.8.0",
|
||||
"torch>=2.9.0",
|
||||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
"zstandard>=0.25.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.maturin]
|
||||
module-name = "rustbpe"
|
||||
bindings = "pyo3"
|
||||
python-source = "."
|
||||
manifest-path = "rustbpe/Cargo.toml"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"maturin>=1.9.4",
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
|
||||
@@ -44,11 +40,11 @@ python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
|
||||
# target torch to cuda 12.8
|
||||
# target torch to cuda 12.8 or CPU
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "gpu" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
@@ -59,4 +55,20 @@ explicit = true
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
explicit = true
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch>=2.9.1",
|
||||
]
|
||||
gpu = [
|
||||
"torch>=2.9.1",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
|
||||
102 runs/miniseries.sh (new file)
@@ -0,0 +1,102 @@
|
||||
#!/bin/bash
|
||||
|
||||
# See speedrun.sh for more comments
|
||||
# Usage: ./miniseries.sh [series_name]
|
||||
# Example: ./miniseries.sh jan11
|
||||
# Default series name is today's date (e.g., jan11)
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# Setup (skip with SKIP_SETUP=1)
|
||||
if [ -z "$SKIP_SETUP" ]; then
|
||||
# uv
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra gpu
|
||||
source .venv/bin/activate
|
||||
|
||||
# Tokenizer, download 1000 shards for pretraining
|
||||
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
|
||||
python -m nanochat.dataset -n 1000
|
||||
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
|
||||
else
|
||||
source .venv/bin/activate
|
||||
fi
|
||||
|
||||
# Series name: from arg, env var, or default to today's date (e.g., jan11)
|
||||
SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}"
|
||||
# Depths to train (the "miniseries")
|
||||
DEPTHS=(10 11 12 13 14 15 16 17 18 19 20)
|
||||
# Hardware
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
# Logging
|
||||
WANDB_RUN="${WANDB_RUN:-${SERIES_NAME}_miniseries}"
|
||||
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/${SERIES_NAME}_miniseries_results"
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
# Write CSV header only if file doesn't exist
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
fi
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
log "=============================================="
|
||||
log "${SERIES_NAME} Miniseries Training"
|
||||
log "=============================================="
|
||||
|
||||
for d in "${DEPTHS[@]}"; do
|
||||
log "Training d=$d..."
|
||||
|
||||
TAG="${SERIES_NAME}_miniseries_d${d}"
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
# Train the model with natural horizon (target_param_data_ratio default)
|
||||
# No --target-flops, let it use the default ratio from base_train
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--run="${WANDB_RUN}_d${d}" \
|
||||
--model-tag="${TAG}" \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
TRAIN_TIME=$((END_TIME - START_TIME))
|
||||
|
||||
# Extract stats from log
|
||||
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
|
||||
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
|
||||
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
|
||||
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
||||
TOKENS_TRAINED=$((NUM_ITERS * 524288))
|
||||
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
|
||||
MODEL_DIM=$((d * 64))
|
||||
VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$')
|
||||
CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}')
|
||||
|
||||
if [ -z "$CORE_SCORE" ]; then
|
||||
CORE_SCORE="0.0"
|
||||
fi
|
||||
|
||||
log " d=$d: params=$NUM_PARAMS, scaling=$NUM_SCALING_PARAMS, ratio=$PARAM_DATA_RATIO, bpb=$VAL_BPB, CORE=$CORE_SCORE, time=${TRAIN_TIME}s"
|
||||
|
||||
# Append to CSV
|
||||
echo "$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
done
|
||||
|
||||
log "=============================================="
|
||||
log "${SERIES_NAME} Miniseries Complete!"
|
||||
log "=============================================="
|
||||
log "Results saved to: $RESULTS_FILE"
|
||||
echo ""
|
||||
echo "Results:"
|
||||
column -t -s',' "$RESULTS_FILE"
|
||||
@@ -10,38 +10,28 @@ export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync
|
||||
uv sync --extra gpu
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
python -m nanochat.report reset
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||
unzip -q eval_bundle.zip
|
||||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
|
||||
python -m nanochat.dataset -n 16
|
||||
# start downloading the rest of the shards for a total of 800 (see below why 800)
|
||||
python -m nanochat.dataset -n 800 &
|
||||
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
|
||||
python -m nanochat.dataset -n 1200 &
|
||||
# todo: download the rest of it
|
||||
python -m scripts.tok_train --max_chars=4000000000
|
||||
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# Documenting my process for determining the hyperparameters for this run1000.sh script:
|
||||
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
|
||||
# 1) I guessed the model size for this to be about depth=32
|
||||
# 2) Determine the device_batch_size that fits:
|
||||
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
|
||||
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
|
||||
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
|
||||
# So the training script was running ok and showed:
|
||||
# Vocab size: 65,536
|
||||
@@ -72,23 +62,29 @@ python -m scripts.tok_eval
|
||||
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
|
||||
# So ~38B tokens # ~4.8 chars/token = ~185B chars.
|
||||
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
|
||||
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
|
||||
# For safety, I bumped that up to 800 shards.
|
||||
# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
|
||||
# => why up above I used -n 1200 when pre-downloading dataset shards.
|
||||
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
|
||||
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
|
||||
# start to overfit hard.
|
||||
# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
|
||||
|
||||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# midtrain
|
||||
# NOTE: ensure that we use the same device_batch_size here as the base training script.
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# sft
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||
|
||||
# generate final report
|
||||
python -m nanochat.report generate
|
||||
70 runs/runcpu.sh (new executable file)
@@ -0,0 +1,70 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
|
||||
# This script was last updated/tuned on Jan 17, 2026.
|
||||
|
||||
# Run as:
|
||||
# bash runs/runcpu.sh
|
||||
|
||||
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your MacBook.
|
||||
# Think of this run as an educational/fun demo, not something you should expect to work well.
|
||||
# (This is why I keep this script tucked away in runs/)
|
||||
# You may also want to run this script manually and one by one, copy pasting commands into your terminal.
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra cpu
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
|
||||
python -m nanochat.dataset -n 8
|
||||
python -m scripts.tok_train --max-chars=2000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a small 4 layer model
|
||||
# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
|
||||
# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
|
||||
python -m scripts.base_train \
|
||||
--depth=6 \
|
||||
--head-dim=64 \
|
||||
--window-pattern=L \
|
||||
--max-seq-len=512 \
|
||||
--device-batch-size=32 \
|
||||
--total-batch-size=16384 \
|
||||
--eval-every=100 \
|
||||
--eval-tokens=524288 \
|
||||
--core-metric-every=-1 \
|
||||
--sample-every=100 \
|
||||
--num-iterations=5000 \
|
||||
--run=$WANDB_RUN
|
||||
python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384
|
||||
python -m scripts.base_eval --max-per-task=16
|
||||
|
||||
# midtraining (~10 minutes on my MacBook Pro M3 Max)
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
python -m scripts.mid_train \
|
||||
--max-seq-len=512 \
|
||||
--device-batch-size=32 \
|
||||
--total-batch-size=16384 \
|
||||
--eval-every=200 \
|
||||
--eval-tokens=524288 \
|
||||
--num-iterations=1500 \
|
||||
--run=$WANDB_RUN
|
||||
|
||||
# (it's ~ok to skip SFT)
|
||||
|
||||
# Chat with the model over CLI
|
||||
# The model should be able to say that it is Paris.
|
||||
# It might even know that the color of the sky is blue.
|
||||
# Sometimes the model likes it if you first say Hi before you ask it questions.
|
||||
# python -m scripts.chat_cli -i mid -p "What is the capital of France?"
|
||||
|
||||
# Chat with the model over a pretty WebUI ChatGPT style
|
||||
# python -m scripts.chat_web -i mid
|
||||
125 runs/scaling_laws.sh (new file)
@@ -0,0 +1,125 @@
|
||||
#!/bin/bash
|
||||
|
||||
LABEL="jan26"
|
||||
|
||||
FLOPS_BUDGETS=(
|
||||
1e18
|
||||
2.15e18
|
||||
4.64e18
|
||||
1e19
|
||||
)
|
||||
DEPTHS=(8 10 12 14 16 18 20)
|
||||
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
|
||||
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
|
||||
source .venv/bin/activate
|
||||
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}"
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
# Write CSV header only if file doesn't exist
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
fi
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
# Check if a run already exists in results
|
||||
run_exists() {
|
||||
local flops=$1
|
||||
local depth=$2
|
||||
grep -q "^${flops},${depth}," "$RESULTS_FILE" 2>/dev/null
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Main Loop
|
||||
# =============================================================================
|
||||
|
||||
for flops in "${FLOPS_BUDGETS[@]}"; do
|
||||
log "=============================================="
|
||||
log "Compute budget: $flops FLOPs"
|
||||
log "=============================================="
|
||||
|
||||
for d in "${DEPTHS[@]}"; do
|
||||
|
||||
# Skip if already completed
|
||||
if run_exists "$flops" "$d"; then
|
||||
log "Skipping d=$d at $flops FLOPs (already in results)"
|
||||
continue
|
||||
fi
|
||||
|
||||
log "Training d=$d at $flops FLOPs..."
|
||||
|
||||
# Unique tag for this run
|
||||
TAG="scaling_${flops}_d${d}"
|
||||
|
||||
# Record start time
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
# Train the model with fixed flops budget
|
||||
# The script will auto-calculate num_iterations to hit target_flops
|
||||
# CORE eval happens once at the end (999999 ensures only final step)
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--target-flops=$flops \
|
||||
--target-param-data-ratio=-1 \
|
||||
--run="${WANDB_RUN}_${TAG}" \
|
||||
--model-tag="${TAG}" \
|
||||
--eval-tokens=$EVAL_TOKENS \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
TRAIN_TIME=$((END_TIME - START_TIME))
|
||||
|
||||
# Extract training stats from the log
|
||||
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
# Extract detailed parameter counts (for scaling law analysis with different conventions)
|
||||
PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
|
||||
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
||||
# Calculate tokens trained (iterations * batch_size, default 524288)
|
||||
TOKENS_TRAINED=$((NUM_ITERS * 524288))
|
||||
# Model dim
|
||||
MODEL_DIM=$((d * 64))
|
||||
# Val BPB from final eval
|
||||
VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$')
|
||||
|
||||
# Extract CORE score from training log (evaluated on final step)
|
||||
CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}')
|
||||
if [ -z "$CORE_SCORE" ]; then
|
||||
log "WARNING: Could not extract CORE score for d=$d"
|
||||
CORE_SCORE="0.0"
|
||||
fi
|
||||
|
||||
log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
|
||||
|
||||
# Append to CSV
|
||||
echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
log "=============================================="
|
||||
log "Scaling Laws Sweep Complete"
|
||||
log "=============================================="
|
||||
log "Results saved to: $RESULTS_FILE"
|
||||
echo ""
|
||||
echo "Results:"
|
||||
column -t -s',' "$RESULTS_FILE"
|
||||
@@ -23,7 +23,7 @@ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# create a .venv local virtual environment (if it doesn't exist)
|
||||
[ -d ".venv" ] || uv venv
|
||||
# install the repo dependencies
|
||||
uv sync
|
||||
uv sync --extra gpu
|
||||
# activate venv so that `python` uses the project's venv instead of system python
|
||||
source .venv/bin/activate
|
||||
|
||||
@@ -48,13 +48,6 @@ python -m nanochat.report reset
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
|
||||
# Install Rust / Cargo
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
|
||||
# Build the rustbpe Tokenizer
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
# each data shard is ~250M chars
|
||||
@@ -62,59 +55,55 @@ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
python -m nanochat.dataset -n 240 &
|
||||
# See comment below for why 370 is the right number here
|
||||
python -m nanochat.dataset -n 370 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
python -m scripts.tok_train --max_chars=2000000000
|
||||
# train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
|
||||
python -m scripts.tok_train
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
||||
# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
|
||||
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||
unzip -q eval_bundle.zip
|
||||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
|
||||
# The d20 model is 561M parameters.
|
||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
|
||||
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
|
||||
# so 240 / (1 - 0.35) = 370 shards are needed.
|
||||
# At ~100MB/shard, this downloads ~37GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)

# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
# see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid

# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation: each row is trained as its own sequence)

# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft

# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
@@ -127,9 +116,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# (optional)

# run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K

# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
rustbpe/Cargo.lock (generated, 458 lines deleted)
@@ -1,458 +0,0 @@
[The auto-generated Cargo lockfile for the rustbpe crate was deleted here. It pinned the dependencies declared in rustbpe/Cargo.toml below (ahash, compact_str, dary_heap, fancy-regex, indexmap, log, pyo3, pyo3-log, rayon) plus their transitive dependencies.]
@@ -1,15 +0,0 @@
[package]
name = "rustbpe"
version = "0.1.0"
edition = "2024"

[dependencies]
dary_heap = "0.3"
indexmap = "2.2"
fancy-regex = "0.16.1"
log = "0.4.28"
pyo3 = { version = "0.23.3", features = ["extension-module"] }
pyo3-log = "0.12.4"
ahash = "0.8.12"
rayon = "1.11.0"
compact_str = "0.9.0"
@@ -1,5 +0,0 @@
# rustbpe

> The missing tiktoken training code

A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library, which does both training and inference, but only in inefficient Python. Basically what I really want is non-fancy, super simple, but still relatively efficient training code for a GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), which then exports the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here; I just stopped a bit early because, unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat.
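To make that workflow concrete, here is a rough usage sketch. The `rustbpe.Tokenizer` methods follow the bindings in the deleted src/lib.rs below; the corpus file path is hypothetical, and the `tiktoken.Encoding` call is assumed from tiktoken's public API.

```python
import rustbpe
import tiktoken

def text_iter():
    # hypothetical corpus file; rustbpe streams the iterator in buffered, parallel batches
    with open("data.txt", "r", encoding="utf-8") as f:
        for line in f:
            yield line

tok = rustbpe.Tokenizer()
tok.train_from_iterator(text_iter(), vocab_size=65536)

# Export the trained vocab and do inference with tiktoken.
mergeable_ranks = {bytes(b): rank for b, rank in tok.get_mergeable_ranks()}
enc = tiktoken.Encoding(
    name="rustbpe_demo",
    pat_str=tok.get_pattern(),
    mergeable_ranks=mergeable_ranks,
    special_tokens={},
)
assert enc.decode(enc.encode("hello world")) == "hello world"
```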
@@ -1,475 +0,0 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap as StdHashMap;
|
||||
|
||||
use dary_heap::OctonaryHeap;
|
||||
use fancy_regex::Regex;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use ahash::{AHashMap, AHashSet};
|
||||
use compact_str::CompactString;
|
||||
use rayon::prelude::*;
|
||||
|
||||
// Default GPT-4 style regex pattern for splitting text
|
||||
const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
|
||||
|
||||
type Pair = (u32, u32);
|
||||
|
||||
/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
|
||||
#[pyclass]
|
||||
pub struct Tokenizer {
|
||||
/// Maps pairs of token IDs to their merged token ID
|
||||
pub merges: StdHashMap<Pair, u32>,
|
||||
/// The regex pattern used for text splitting
|
||||
pub pattern: String,
|
||||
/// Compiled regex for efficiency
|
||||
compiled_pattern: Regex,
|
||||
}
|
||||
|
||||
// ------------------------ internal helpers ------------------------
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Word {
|
||||
ids: Vec<u32>,
|
||||
}
|
||||
|
||||
impl Word {
|
||||
#[inline]
|
||||
fn new(ids: Vec<u32>) -> Self {
|
||||
Self { ids }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
|
||||
self.ids.windows(2).map(|w| (w[0], w[1]))
|
||||
}
|
||||
|
||||
/// Merge all non-overlapping occurrences of pair -> new_id.
|
||||
/// Returns a small Vec of local pair-count deltas for THIS word only:
|
||||
/// -1 for removed pairs, +1 for newly created pairs.
|
||||
///
|
||||
/// NOTE: this version deliberately avoids a HashMap in the hot loop.
|
||||
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
|
||||
let (a, b) = pair;
|
||||
let n = self.ids.len();
|
||||
if n < 2 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut out: Vec<u32> = Vec::with_capacity(n);
|
||||
let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
|
||||
|
||||
let mut i = 0;
|
||||
while i < n {
|
||||
if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
|
||||
let left = out.last().copied();
|
||||
let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };
|
||||
|
||||
// remove old pairs
|
||||
if let Some(x) = left {
|
||||
deltas.push(((x, a), -1));
|
||||
deltas.push(((x, new_id), 1));
|
||||
}
|
||||
deltas.push(((a, b), -1));
|
||||
if let Some(y) = right {
|
||||
deltas.push(((b, y), -1));
|
||||
deltas.push(((new_id, y), 1));
|
||||
}
|
||||
|
||||
// write merged token
|
||||
out.push(new_id);
|
||||
i += 2; // skip 'a' and 'b'
|
||||
} else {
|
||||
out.push(self.ids[i]);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
self.ids = out;
|
||||
deltas
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq)]
|
||||
struct MergeJob {
|
||||
pair: Pair,
|
||||
count: u64,
|
||||
/// set of word indices where this pair may occur and needs processing
|
||||
pos: AHashSet<usize>,
|
||||
}
|
||||
|
||||
impl PartialEq for MergeJob {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.count == other.count && self.pair == other.pair
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for MergeJob {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for MergeJob {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
// Max-heap by count; tie-break to ascending pair order (deterministic)
|
||||
if self.count != other.count {
|
||||
self.count.cmp(&other.count)
|
||||
} else {
|
||||
// ascending order on the pair when counts tie
|
||||
other.pair.cmp(&self.pair)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn count_pairs_parallel(
|
||||
words: &[Word],
|
||||
counts: &[i32],
|
||||
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
|
||||
words
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(i, w)| {
|
||||
let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
|
||||
let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||
if w.ids.len() >= 2 && counts[i] != 0 {
|
||||
for (a, b) in w.pairs() {
|
||||
*local_pc.entry((a, b)).or_default() += counts[i];
|
||||
local_wtu.entry((a, b)).or_default().insert(i);
|
||||
}
|
||||
}
|
||||
(local_pc, local_wtu)
|
||||
})
|
||||
.reduce(
|
||||
|| (AHashMap::new(), AHashMap::new()),
|
||||
|(mut acc_pc, mut acc_wtu), (pc, wtu)| {
|
||||
for (k, v) in pc {
|
||||
*acc_pc.entry(k).or_default() += v;
|
||||
}
|
||||
for (k, s) in wtu {
|
||||
acc_wtu.entry(k).or_default().extend(s);
|
||||
}
|
||||
(acc_pc, acc_wtu)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// ------------------------ END helpers ------------------------
|
||||
|
||||
impl Tokenizer {
|
||||
|
||||
/// Core incremental BPE training given unique words and their counts.
|
||||
/// `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
|
||||
/// `counts`: same length as `words`, count per chunk.
|
||||
fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
|
||||
assert!(vocab_size >= 256, "vocab_size must be at least 256");
|
||||
let num_merges = vocab_size - 256;
|
||||
log::info!("Starting BPE training: {} merges to compute", num_merges);
|
||||
self.merges.clear();
|
||||
|
||||
// ---- Initial pair_counts and where_to_update (parallel) ----
|
||||
log::info!("Computing initial pair counts from {} unique sequences", words.len());
|
||||
let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
|
||||
|
||||
// ---- Build heap ----
|
||||
log::info!("Building heap with {} unique pairs", pair_counts.len());
|
||||
let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
|
||||
for (pair, pos) in where_to_update.drain() {
|
||||
let c = *pair_counts.get(&pair).unwrap_or(&0);
|
||||
if c > 0 {
|
||||
heap.push(MergeJob {
|
||||
pair,
|
||||
count: c as u64,
|
||||
pos,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Merge loop ----
|
||||
log::info!("Starting merge loop");
|
||||
let mut merges_done = 0u32;
|
||||
let mut last_log_percent = 0u32;
|
||||
|
||||
while merges_done < num_merges {
|
||||
let Some(mut top) = heap.pop() else { break; };
|
||||
|
||||
// Lazy refresh
|
||||
let current = *pair_counts.get(&top.pair).unwrap_or(&0);
|
||||
if top.count != current as u64 {
|
||||
top.count = current as u64;
|
||||
if top.count > 0 {
|
||||
heap.push(top);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if top.count == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
// Record merge
|
||||
let new_id = 256 + merges_done;
|
||||
self.merges.insert(top.pair, new_id);
|
||||
|
||||
// Merge this pair in all words where it occurs
|
||||
let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||
for &word_idx in &top.pos {
|
||||
// Apply merge to this word and collect pair-count deltas
|
||||
let changes = words[word_idx].merge_pair(top.pair, new_id);
|
||||
// Update global pair counts based on this word's count
|
||||
for (pair, delta) in changes {
|
||||
let delta_total = delta * counts[word_idx];
|
||||
if delta_total != 0 {
|
||||
*pair_counts.entry(pair).or_default() += delta_total;
|
||||
if delta > 0 {
|
||||
local_pos_updates.entry(pair).or_default().insert(word_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add the updated pair counts back to the heap
|
||||
for (pair, pos) in local_pos_updates {
|
||||
let cnt = *pair_counts.get(&pair).unwrap_or(&0);
|
||||
if cnt > 0 {
|
||||
heap.push(MergeJob {
|
||||
pair,
|
||||
count: cnt as u64,
|
||||
pos,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
merges_done += 1;
|
||||
|
||||
// Log progress every 1%
|
||||
let current_percent = (merges_done * 100) / num_merges;
|
||||
if current_percent > last_log_percent {
|
||||
log::info!(
|
||||
"Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
|
||||
current_percent, merges_done, num_merges, top.pair, new_id, top.count
|
||||
);
|
||||
last_log_percent = current_percent;
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Finished training: {} merges completed", merges_done);
|
||||
}
|
||||
}
|
||||
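The merge loop above relies on a lazy-refresh heap: stale entries are never removed eagerly; instead, when an entry is popped its stored count is compared against the live count, and if they disagree it is pushed back with the fresh count. A minimal Python sketch of just that pattern, using heapq with negated counts to emulate a max-heap (the pair data is made up for illustration):

```python
import heapq

pair_counts = {("a", "b"): 5, ("b", "c"): 3, ("c", "d"): 9}
# Max-heap via negated counts; entries may go stale as pair_counts changes.
heap = [(-c, pair) for pair, c in pair_counts.items()]
heapq.heapify(heap)

def pop_best():
    while heap:
        neg_count, pair = heapq.heappop(heap)
        current = pair_counts.get(pair, 0)
        if -neg_count != current:       # stale entry: lazily refresh and retry
            if current > 0:
                heapq.heappush(heap, (-current, pair))
            continue
        if current == 0:
            return None
        return pair, current            # entry is up to date
    return None

pair_counts[("c", "d")] = 4             # simulate counts changing after heapify
print(pop_best())                       # (('a', 'b'), 5): ('c', 'd') was stale, re-pushed with count 4
```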
|
||||
/// Public methods for the Tokenizer class that will be exposed to Python.
|
||||
#[pymethods]
|
||||
impl Tokenizer {
|
||||
/// Create a new Tokenizer
|
||||
#[new]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
merges: StdHashMap::new(),
|
||||
pattern: String::new(),
|
||||
compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Train from a streaming iterator (parallel ingestion).
|
||||
/// We refill a Rust Vec<String> buffer under the GIL, then release the GIL
|
||||
/// to do the heavy splitting and counting **in parallel** with rayon.
|
||||
#[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
|
||||
#[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
|
||||
pub fn train_from_iterator(
|
||||
&mut self,
|
||||
py: pyo3::Python<'_>,
|
||||
iterator: &pyo3::Bound<'_, pyo3::PyAny>,
|
||||
vocab_size: u32,
|
||||
buffer_size: usize,
|
||||
pattern: Option<String>,
|
||||
) -> PyResult<()> {
|
||||
// Use provided pattern or default to GPT-4 pattern
|
||||
let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
|
||||
|
||||
// Update the stored pattern and compile it
|
||||
self.pattern = pattern_str.clone();
|
||||
self.compiled_pattern = Regex::new(&pattern_str)
|
||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
|
||||
|
||||
// Prepare a true Python iterator object
|
||||
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
|
||||
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
|
||||
};
|
||||
|
||||
// Global chunk counts
|
||||
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
|
||||
|
||||
// Temporary buffer we refill under the GIL
|
||||
let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
|
||||
|
||||
log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
|
||||
let mut total_sequences = 0u64;
|
||||
|
||||
// Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
|
||||
// Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
|
||||
let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
|
||||
pyo3::Python::with_gil(|py| {
|
||||
buf.clear();
|
||||
let it = py_iter.bind(py);
|
||||
loop {
|
||||
if buf.len() >= buffer_size {
|
||||
return Ok(false);
|
||||
}
|
||||
// next(it)
|
||||
let next_obj = unsafe {
|
||||
pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
|
||||
};
|
||||
match next_obj {
|
||||
Some(obj) => {
|
||||
let s: String = obj.extract()?;
|
||||
buf.push(s);
|
||||
}
|
||||
None => {
|
||||
if pyo3::PyErr::occurred(py) {
|
||||
return Err(pyo3::PyErr::fetch(py));
|
||||
} else {
|
||||
return Ok(true); // exhausted
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Stream ingestion loop: refill under GIL, process without GIL (parallel)
|
||||
loop {
|
||||
let exhausted = refill(&mut buf)?;
|
||||
if buf.is_empty() && exhausted {
|
||||
break;
|
||||
}
|
||||
|
||||
total_sequences += buf.len() as u64;
|
||||
|
||||
let pattern = self.compiled_pattern.clone();
|
||||
let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
|
||||
buf.par_iter()
|
||||
.map(|s| {
|
||||
let mut m: AHashMap<CompactString, i32> = AHashMap::new();
|
||||
for mat in pattern.find_iter(s) {
|
||||
let piece = mat.expect("regex match failed").as_str();
|
||||
*m.entry(CompactString::from(piece)).or_default() += 1;
|
||||
}
|
||||
m
|
||||
})
|
||||
.reduce(
|
||||
|| AHashMap::new(),
|
||||
|mut a, b| {
|
||||
for (k, v) in b {
|
||||
*a.entry(k).or_default() += v;
|
||||
}
|
||||
a
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
// Merge local into global (single-threaded)
|
||||
for (k, v) in local {
|
||||
*counts.entry(k).or_default() += v;
|
||||
}
|
||||
|
||||
if exhausted {
|
||||
break;
|
||||
}
|
||||
}
|
||||
log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
|
||||
|
||||
// Materialize words & counts
|
||||
let mut words = Vec::with_capacity(counts.len());
|
||||
let mut cvec = Vec::with_capacity(counts.len());
|
||||
for (chunk, c) in counts.into_iter() {
|
||||
words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
|
||||
cvec.push(c);
|
||||
}
|
||||
|
||||
self.train_core_incremental(words, cvec, vocab_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Return the regex pattern
|
||||
pub fn get_pattern(&self) -> String {
|
||||
self.pattern.clone()
|
||||
}
|
||||
|
||||
/// Return the mergeable ranks (token bytes -> token id / rank)
|
||||
pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
|
||||
let mut mergeable_ranks = Vec::new();
|
||||
|
||||
// Build vocabulary incrementally from low to high token IDs
|
||||
let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
|
||||
|
||||
for (i, bytes) in token_bytes.iter().enumerate() {
|
||||
mergeable_ranks.push((bytes.clone(), i as u32));
|
||||
}
|
||||
|
||||
// Sort merges by token id (so we can reconstruct bytes progressively)
|
||||
let mut sorted_merges: Vec<_> = self.merges.iter().collect();
|
||||
sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
|
||||
|
||||
for (&pair, &merged_id) in sorted_merges {
|
||||
let (left, right) = pair;
|
||||
let mut merged_bytes = token_bytes[left as usize].clone();
|
||||
merged_bytes.extend(&token_bytes[right as usize]);
|
||||
|
||||
if token_bytes.len() <= merged_id as usize {
|
||||
token_bytes.resize(merged_id as usize + 1, Vec::new());
|
||||
}
|
||||
token_bytes[merged_id as usize] = merged_bytes.clone();
|
||||
|
||||
mergeable_ranks.push((merged_bytes, merged_id));
|
||||
}
|
||||
|
||||
mergeable_ranks
|
||||
}
|
||||
|
||||
/// Encode a string into token IDs
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||
let mut all_ids = Vec::new();
|
||||
|
||||
// Split text using the regex pattern
|
||||
for m in self.compiled_pattern.find_iter(text) {
|
||||
let chunk = m.expect("regex match failed").as_str();
|
||||
|
||||
// Convert chunk to bytes then to u32 IDs
|
||||
let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();
|
||||
|
||||
// Apply merges iteratively
|
||||
while ids.len() >= 2 {
|
||||
// Find the best pair to merge
|
||||
let mut best_pair: Option<(usize, Pair, u32)> = None;
|
||||
|
||||
for i in 0..ids.len() - 1 {
|
||||
let pair: Pair = (ids[i], ids[i + 1]);
|
||||
if let Some(&new_id) = self.merges.get(&pair) {
|
||||
if best_pair.is_none() || new_id < best_pair.unwrap().2 {
|
||||
best_pair = Some((i, pair, new_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we found a pair to merge, apply it
|
||||
if let Some((idx, _pair, new_id)) = best_pair {
|
||||
ids[idx] = new_id;
|
||||
ids.remove(idx + 1);
|
||||
} else {
|
||||
// No more merges possible
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
all_ids.extend(ids);
|
||||
}
|
||||
|
||||
all_ids
|
||||
}
|
||||
}
|
||||
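For readers more comfortable in Python, the encode method above is the classic greedy BPE procedure: repeatedly find the adjacent pair whose merge has the lowest token id and apply it. A rough equivalent with a toy merges table (not a trained one):

```python
def bpe_encode(chunk_bytes, merges):
    """Greedy BPE: repeatedly apply the known merge with the lowest new token id."""
    ids = list(chunk_bytes)
    while len(ids) >= 2:
        # Find the mergeable adjacent pair whose resulting token id is smallest.
        best = None  # (index, new_id)
        for i in range(len(ids) - 1):
            new_id = merges.get((ids[i], ids[i + 1]))
            if new_id is not None and (best is None or new_id < best[1]):
                best = (i, new_id)
        if best is None:
            break  # no more merges apply
        i, new_id = best
        ids[i:i + 2] = [new_id]
    return ids

# Toy table: merge (104, 105) i.e. b"hi" into token 256, then (256, 33) into 257.
merges = {(104, 105): 256, (256, 33): 257}
print(bpe_encode(b"hi!", merges))  # [257]
```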
|
||||
#[pymodule]
|
||||
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
pyo3_log::init(); // forwards Rust `log` to Python's `logging`
|
||||
m.add_class::<Tokenizer>()?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,49 +1,76 @@
|
||||
"""
|
||||
Evlauate the CORE metric for a given model.
|
||||
Evaluate the CORE metric for a given model.
|
||||
|
||||
Run on a single GPU:
|
||||
python base_eval.py
|
||||
python -m scripts.base_eval
|
||||
|
||||
Run with torchrun on e.g. 8 GPUs:
|
||||
torchrun --nproc_per_node=8 base_eval.py
|
||||
torchrun --nproc_per_node=8 -m scripts.base_eval
|
||||
|
||||
The script will print the CORE metric to the console.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import csv
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import yaml
|
||||
import shutil
|
||||
import random
|
||||
import zipfile
|
||||
import tempfile
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# nanoChat specific function dealing with I/O etc.
|
||||
# nanochat specific function dealing with I/O etc.
|
||||
|
||||
# ~162MB of data needed to evaluate the CORE metric
|
||||
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
|
||||
|
||||
def place_eval_bundle(file_path):
|
||||
# here file_path is the path to the eval_bundle.zip file
|
||||
# we need to unzip it and place it in the base directory
|
||||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmpdir)
|
||||
extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
|
||||
shutil.move(extracted_bundle_dir, eval_bundle_dir)
|
||||
print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
|
||||
|
||||
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
"""
|
||||
Evaluate a base model on the CORE benchmark.
|
||||
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
|
||||
TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
|
||||
"""
|
||||
# Load config and task metadata
|
||||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
# Download the eval bundle to disk (and unzip if needed)
|
||||
if not os.path.exists(eval_bundle_dir):
|
||||
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
|
||||
config_path = os.path.join(eval_bundle_dir, "core.yaml")
|
||||
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
|
||||
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
|
||||
with open(config_path, 'r') as f:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
tasks = config['icl_tasks']
|
||||
eval_metadata = pd.read_csv(eval_meta_data)
|
||||
|
||||
# Load random baseline values from eval metadata
|
||||
random_baselines = {}
|
||||
with open(eval_meta_data, 'r', encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
task_name = row['Eval Task']
|
||||
random_baseline = row['Random baseline']
|
||||
random_baselines[task_name] = float(random_baseline)
|
||||
|
||||
# Evaluate each task
|
||||
results = {}
|
||||
@@ -61,11 +88,11 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
|
||||
# Load data for this task
|
||||
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
|
||||
with open(data_path, 'r') as f:
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
data = [json.loads(line.strip()) for line in f]
|
||||
|
||||
# shuffle the data because in many cases it appears ordered but we want
|
||||
# the abillity to only run a subset of the data for debugging purposes etc.
|
||||
# the ability to only run a subset of the data for debugging purposes etc.
|
||||
shuffle_rng = random.Random(1337)
|
||||
shuffle_rng.shuffle(data)
|
||||
if max_per_task > 0:
|
||||
@@ -75,8 +102,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
|
||||
|
||||
results[label] = accuracy
|
||||
row = eval_metadata[eval_metadata["Eval Task"] == label]
|
||||
random_baseline = row["Random baseline"].values[0]
|
||||
random_baseline = random_baselines[label]
|
||||
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
||||
centered_results[label] = centered_result
|
||||
end_time = time.time()
|
||||
@@ -123,6 +149,8 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
|
||||
parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
|
||||
args = parser.parse_args()
|
||||
|
||||
# distributed / precision setup
|
||||
@@ -140,7 +168,7 @@ def main():
|
||||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
else:
|
||||
# load a local model from the file system
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
model_name = f"base_model (step {meta['step']})" # just for logging
|
||||
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
|
||||
|
||||
@@ -158,7 +186,7 @@ def main():
|
||||
results = out["results"]
|
||||
centered_results = out["centered_results"]
|
||||
core_metric = out["core_metric"]
|
||||
with open(output_csv_path, 'w') as f:
|
||||
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
|
||||
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
|
||||
for label in results:
|
||||
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
|
||||
@@ -167,7 +195,7 @@ def main():
|
||||
print0("="*80)
|
||||
print0(f"Model: {model_name}")
|
||||
print0("="*80)
|
||||
with open(output_csv_path, 'r') as f:
|
||||
with open(output_csv_path, 'r', encoding='utf-8') as f:
|
||||
print0(f.read())
|
||||
|
||||
# Log to report
|
||||
|
||||
@@ -5,48 +5,108 @@ Loads a checkpoint, and:
|
||||
|
||||
Example run as:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
|
||||
To evaluate a HuggingFace model:
|
||||
python -m scripts.base_loss --hf-path openai-community/gpt2
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
|
||||
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# Configuration
|
||||
device_batch_size = 32
|
||||
split_tokens = 20*524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace loading utilities, making the APIs match up to those of nanochat
|
||||
|
||||
class ModelWrapper:
|
||||
"""Lightweight wrapper for a HuggingFace model"""
|
||||
def __init__(self, model, max_seq_len=None):
|
||||
self.model = model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
|
||||
logits = self.model(input_ids).logits
|
||||
if targets is None:
|
||||
return logits
|
||||
else:
|
||||
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
|
||||
def get_device(self):
|
||||
return next(self.model.parameters()).device
|
||||
|
||||
def load_hf_model(hf_path: str, device):
|
||||
print0(f"Loading model from: {hf_path}")
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
|
||||
model = ModelWrapper(model, max_seq_len=max_seq_len)
|
||||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
return model, tokenizer
|
||||
|
||||
def get_hf_token_bytes(tokenizer, device="cpu"):
|
||||
"""Compute token_bytes tensor for a HuggingFace tokenizer."""
|
||||
vocab_size = tokenizer.tokenizer.get_vocab_size()
|
||||
token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
|
||||
for token_id in range(vocab_size):
|
||||
token_str = tokenizer.tokenizer.decode([token_id])
|
||||
token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
|
||||
return token_bytes
|
||||
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")
|
||||
|
||||
if args.hf_path is not None:
|
||||
# Load HuggingFace model
|
||||
model, tokenizer = load_hf_model(args.hf_path, device)
|
||||
sequence_len = model.max_seq_len if model.max_seq_len else 1024
|
||||
token_bytes = get_hf_token_bytes(tokenizer, device=device)
|
||||
model_name = args.hf_path
|
||||
else:
|
||||
# Load local nanochat model
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"]
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
model_name = f"base_model (step {meta['step']})"
|
||||
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
print0(f"Evaluating model: {model_name}")
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = split_tokens // tokens_per_step
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
|
||||
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = args.split_tokens // tokens_per_step
|
||||
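# Worked example of the step arithmetic above (assumed values that match the defaults in this
# script and an 8-GPU d20 run): device_batch_size=32, sequence_len=2048, ddp_world_size=8 gives
# tokens_per_step = 32 * 2048 * 8 = 524,288 tokens; with the default split_tokens = 40 * 524288
# the divisibility assertion passes and steps = 40 evaluation iterations per split.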
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
|
||||
|
||||
# Master process also samples from the model
|
||||
# Master process also samples from the model for some basic knowledge-eliciting prompts (only for nanochat models)
|
||||
samples = []
|
||||
if ddp_rank == 0:
|
||||
if ddp_rank == 0 and args.hf_path is None:
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The chemical symbol of gold is",
|
||||
@@ -62,17 +122,33 @@ if ddp_rank == 0:
|
||||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample_str = tokenizer.decode(sample[0])
|
||||
print0("-" * 80)
|
||||
print0(sample_str)
|
||||
samples.append(sample_str)
|
||||
|
||||
# Draw some unconditioned samples from the model (only for nanochat models)
|
||||
unconditioned_samples = []
|
||||
if ddp_rank == 0 and args.hf_path is None:
|
||||
engine = Engine(model, tokenizer)
|
||||
tokens = tokenizer("", prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
samples, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
|
||||
for sample in samples:
|
||||
sample_str = tokenizer.decode(sample)
|
||||
print0("-" * 80)
|
||||
print0(sample_str)
|
||||
unconditioned_samples.append(sample_str)
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model loss", data=[
|
||||
{
|
||||
"model": model_name,
|
||||
"train bpb": bpb_results["train"],
|
||||
"val bpb": bpb_results["val"],
|
||||
},
|
||||
{f"sample {i}": sample for i, sample in enumerate(samples)},
|
||||
{f"unconditioned sample {i}": sample for i, sample in enumerate(unconditioned_samples)},
|
||||
])
|
||||
|
||||
# Cleanup
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
"""
|
||||
Train model. Run as:
|
||||
Train model. From root directory of the project, run as:
|
||||
|
||||
python base_train.py
|
||||
python -m scripts.base_train
|
||||
|
||||
or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 base_train.py
|
||||
torchrun --nproc_per_node=8 -m scripts.base_train
|
||||
|
||||
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
||||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
|
||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
@@ -20,60 +21,89 @@ import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from scripts.base_eval import evaluate_model
|
||||
print_banner()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Pretrain base model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
# Training horizon. Only one of these 3 will be used, in this order of precedence.
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
|
||||
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
gpu_device_name = torch.cuda.get_device_name(0)
|
||||
gpu_peak_flops = get_peak_flops(gpu_device_name)
|
||||
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||
else:
|
||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if HAS_FA3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
print0("!" * 80)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
@@ -82,89 +112,143 @@ vocab_size = tokenizer.get_vocab_size()
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
num_layers = depth
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = num_heads # 1:1 MQA ratio
|
||||
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
|
||||
num_layers = args.depth
|
||||
base_dim = args.depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
head_dim = model_dim // num_heads
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"head_dim: {head_dim}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
|
||||
batch_lr_scale = 1.0
|
||||
reference_batch_size = 2**19
|
||||
batch_ratio = args.total_batch_size / reference_batch_size
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard
|
||||
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
|
||||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
|
||||
# Create a new model with random weights
|
||||
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
|
||||
with torch.device("meta"):
|
||||
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
|
||||
model.init_weights() # All tensors get initialized
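A minimal sketch of the meta-device pattern used above, with a plain nn.Linear standing in for the GPT model: construct shapes only, then allocate real storage, then initialize explicitly:

```python
# Meta-device construction: shapes/dtypes first, storage and init afterwards.
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(1024, 1024, bias=False)    # no memory allocated yet
print(layer.weight.is_meta)                      # True

layer.to_empty(device="cpu")                     # real (uninitialized) storage on the target device
nn.init.normal_(layer.weight, std=0.02)          # explicit init, analogous to model.init_weights()
print(layer.weight.is_meta, layer.weight.shape)  # False torch.Size([1024, 1024])
```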
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
resuming = args.resume_from_step != -1
|
||||
if resuming:
|
||||
print0(f"Resuming optimization from step {args.resume_from_step}")
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
|
||||
model.load_state_dict(model_data, strict=True, assign=True)
|
||||
del model_data # free up this memory after the copy
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the input shapes may change)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# Detailed parameter counts
|
||||
param_counts = orig_model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print0(f"{key:24s}: {value:,}")
|
||||
num_params = param_counts['total']
|
||||
num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
|
||||
if num_iterations > 0:
|
||||
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
|
||||
if args.num_iterations > 0:
|
||||
num_iterations = args.num_iterations
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif target_flops > 0:
|
||||
elif args.target_flops > 0:
|
||||
# calculate the number of iterations from the target flops
|
||||
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
|
||||
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio
|
||||
target_tokens = target_param_data_ratio * num_params
|
||||
num_iterations = target_tokens // total_batch_size
|
||||
elif args.target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
|
||||
num_iterations = target_tokens // args.total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = total_batch_size * num_iterations
|
||||
total_tokens = args.total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
for opt, dat in zip(optimizers, optimizer_data):
|
||||
opt.load_state_dict(dat)
|
||||
del optimizer_data # free up the memory
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
||||
# Learning rate scheduler
|
||||
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
|
||||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(warmdown_ratio * num_iterations)
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
if it < warmup_iters:
|
||||
return (it + 1) / warmup_iters
|
||||
elif it <= num_iterations - warmdown_iters:
|
||||
return 1.0
|
||||
else:
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * final_lr_frac
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
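The multiplier above traces a trapezoid: an optional warmup, a flat region at 1.0, then a linear warmdown to final_lr_frac. A self-contained sketch using the CLI defaults (no warmup, 40% warmdown, final fraction 0.0):

```python
# Standalone sketch of the LR multiplier schedule above.
def lr_multiplier(it, num_iterations, warmup_ratio=0.0, warmdown_ratio=0.4, final_lr_frac=0.0):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    progress = (num_iterations - it) / warmdown_iters
    return progress * 1.0 + (1 - progress) * final_lr_frac

print([round(lr_multiplier(it, 1000), 2) for it in (0, 500, 600, 800, 1000)])
# [1.0, 1.0, 1.0, 0.5, 0.0]
```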
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
@@ -172,25 +256,41 @@ def get_muon_momentum(it):
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
|
||||
def get_weight_decay(it):
|
||||
return weight_decay_scaled * (1 - it / num_iterations)
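A small sketch of the two Muon schedules above; note the exact definition of the ramp fraction `frac` is cut off by this hunk, so it is passed in directly here:

```python
# Muon momentum ramps 0.85 -> 0.95 with some warmup fraction `frac`;
# the weight decay decays linearly to zero over training.
def muon_momentum(frac):
    return (1 - frac) * 0.85 + frac * 0.95

def muon_weight_decay(it, num_iterations, weight_decay_scaled):
    return weight_decay_scaled * (1 - it / num_iterations)

print(muon_momentum(0.0), muon_momentum(1.0))                              # 0.85 0.95
print(muon_weight_decay(0, 1000, 0.2), muon_weight_decay(500, 1000, 0.2))  # 0.2 0.1
```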
# -----------------------------------------------------------------------------
|
||||
# Loop state (variables updated by the training loop)
|
||||
|
||||
if not resuming:
|
||||
step = 0
|
||||
val_bpb = None # will be set if eval_every > 0
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
else:
|
||||
step = meta_data["step"]
|
||||
loop_state = meta_data["loop_state"]
|
||||
val_bpb = meta_data["val_bpb"]
|
||||
min_val_bpb = loop_state["min_val_bpb"]
|
||||
smooth_train_loss = loop_state["smooth_train_loss"]
|
||||
total_training_time = loop_state["total_training_time"]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
# note that we run +1 steps only so that we can eval and save at the end
|
||||
for step in range(num_iterations + 1):
|
||||
last_step = step == num_iterations
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
while True:
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
flops_so_far = num_flops_per_token * args.total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or step % eval_every == 0:
|
||||
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
@@ -204,10 +304,10 @@ for step in range(num_iterations + 1):
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
results = {}
|
||||
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
|
||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
@@ -219,7 +319,7 @@ for step in range(num_iterations + 1):
|
||||
|
||||
# once in a while: sample from the model (only on master process)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
|
||||
if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
|
||||
model.eval()
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
@@ -238,25 +338,31 @@ for step in range(num_iterations + 1):
|
||||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
"device_batch_size": args.device_batch_size,
|
||||
"max_seq_len": args.max_seq_len,
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
"smooth_train_loss": smooth_train_loss,
|
||||
"total_training_time": total_training_time,
|
||||
},
|
||||
},
|
||||
rank=ddp_rank,
|
||||
)
|
||||
|
||||
# termination conditions (TODO: possibly also add loss explosions etc.)
|
||||
if last_step:
|
||||
break
|
||||
|
||||
@@ -271,39 +377,49 @@ for step in range(num_iterations + 1):
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# gradient clipping (TODO: possibly experiment with)
|
||||
if grad_clip > 0.0:
|
||||
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
# logging (CPU action only)
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
steps_done = step - 10
|
||||
if steps_done > 0:
|
||||
avg_time_per_step = total_training_time / steps_done
|
||||
remaining_steps = num_iterations - step
|
||||
eta_seconds = remaining_steps * avg_time_per_step
|
||||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
epoch = dataloader_state_dict["epoch"]
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
wandb_run.log({
|
||||
log_data = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
@@ -312,12 +428,18 @@ for step in range(num_iterations + 1):
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
})
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
step += 1
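Two of the logged quantities above are easy to reproduce in isolation: the bias-corrected EMA of the training loss (the same correction Adam applies) and the MFU estimate against peak BF16 throughput. A sketch with made-up numbers:

```python
# Bias-corrected EMA of the loss, plus an MFU estimate (illustrative values only;
# in the real script gpu_peak_flops comes from get_peak_flops()).
ema_beta = 0.9
smooth, losses = 0.0, [4.0, 3.8, 3.7]
for i, loss in enumerate(losses):
    smooth = ema_beta * smooth + (1 - ema_beta) * loss
    debiased = smooth / (1 - ema_beta ** (i + 1))       # debias the EMA
    print(f"step {i}: raw {loss:.2f} debiased EMA {debiased:.4f}")

num_flops_per_token, total_batch_size, dt = 3.5e9, 524288, 0.5  # dt = seconds per step
gpu_peak_flops, ddp_world_size = 989e12, 8                       # BF16 H100 SXM peak
mfu = 100 * (num_flops_per_token * total_batch_size / dt) / (gpu_peak_flops * ddp_world_size)
print(f"MFU: {mfu:.2f}%")
```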
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
if val_bpb is not None:
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
@@ -328,14 +450,14 @@ get_report().log(section="Base model training", data=[
|
||||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
|
||||
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": warmup_ratio,
|
||||
"warmdown_ratio": warmdown_ratio,
|
||||
"final_lr_frac": final_lr_frac,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"warmdown_ratio": args.warmdown_ratio,
|
||||
"final_lr_frac": args.final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Evaluate the Chat model.
|
||||
All the generic code lives here, and all the evlauation-specific
|
||||
All the generic code lives here, and all the evaluation-specific
|
||||
code lives in nanochat directory and is imported from here.
|
||||
|
||||
Example runs:
|
||||
python -m scripts.chat_eval -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
python -m scripts.chat_eval -i mid -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i mid -a ARC-Easy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.spellingbee import SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
|
||||
@@ -116,7 +117,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
||||
logits = model(prompt_ids) # (B, T, V)
|
||||
|
||||
# Focus on the available answer on just the letters corresponding to choices
|
||||
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the avilable letters
|
||||
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
|
||||
# The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
|
||||
# letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
|
||||
for idx, conversation in enumerate(conversations):
|
||||
@@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
|
||||
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
|
||||
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
'GSM8K': partial(GSM8K, subset="main", split="test"),
|
||||
'SpellingBee': partial(SpellingBee, size=256, split="test"),
|
||||
}[task_name]
|
||||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
@@ -204,13 +206,14 @@ if __name__ == "__main__":
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Get the tasks to evaluate on
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||
baseline_accuracies = {
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ simpler and more similar to just REINFORCE:
|
||||
|
||||
1) Delete trust region, so there is no KL regularization to a reference model
|
||||
2) We are on policy, so there's no need for PPO ratio+clip.
|
||||
3) We use GAPO style normalization that is token-level, not sequence-level.
|
||||
3) We use DAPO style normalization that is token-level, not sequence-level.
|
||||
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
|
||||
|
||||
1 GPU:
|
||||
@@ -16,55 +16,68 @@ python -m scripts.chat_rl
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
|
||||
"""
import argparse
|
||||
import os
|
||||
import itertools
|
||||
import re
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
||||
# RL hyperparameters
|
||||
run = "dummy" # wandb run name
|
||||
source = "sft" # mid|sft
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
num_samples = 16 # number of samples per example (/question)
|
||||
max_new_tokens = 256
|
||||
temperature = 1.0
|
||||
top_k = 50 # TODO: try None?
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.05
|
||||
num_epochs = 1 # how many epochs of gsm8k to train on
|
||||
save_every = 60 # every how many steps to save the model
|
||||
eval_every = 60 # every how many steps to evaluate the model for val pass@k
|
||||
eval_examples = 400 # number of examples used for evaluating pass@k
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
|
||||
# Batch sizes / sampling
|
||||
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
|
||||
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
|
||||
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
|
||||
# Generation
|
||||
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
|
||||
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
|
||||
# Evaluation / checkpointing
|
||||
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
|
||||
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
|
||||
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Init compute/precision
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval")
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -72,7 +85,7 @@ engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
train_task = GSM8K(subset="main", split="train")
|
||||
val_task = GSM8K(subset="main", split="test")
|
||||
num_steps = (len(train_task) // examples_per_step) * num_epochs
|
||||
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
|
||||
print0(f"Calculated number of steps: {num_steps}")
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -93,16 +106,16 @@ def get_batch():
|
||||
model.eval() # ensure the model is in eval mode
|
||||
generated_token_sequences = []
|
||||
masks = []
|
||||
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
|
||||
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
with autocast_ctx:
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=device_batch_size,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
num_samples=args.device_batch_size,
|
||||
max_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
)
|
||||
generated_token_sequences.extend(generated_token_sequences_batch)
|
||||
@@ -160,7 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
||||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
# Generate k samples using batched generation inside the Engine
|
||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
||||
assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not...
|
||||
generated_token_sequences, masks = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=num_samples,
|
||||
@@ -189,16 +202,16 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
||||
|
||||
# Init the optimizer
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
@@ -206,10 +219,10 @@ def get_lr_multiplier(it):
|
||||
lrm = 1.0 - it / num_steps
|
||||
return lrm
|
||||
|
||||
# Calculate the number of examples each rank handles to achive the desired examples_per_step
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
|
||||
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
|
||||
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Calculated examples per rank: {examples_per_rank}")
# Kick off the training loop
|
||||
@@ -217,22 +230,22 @@ batch_iterator = get_batch()
|
||||
for step in range(num_steps):
|
||||
|
||||
# Evaluate the model once in a while and log to wandb
|
||||
if step % eval_every == 0:
|
||||
if step % args.eval_every == 0:
|
||||
model.eval()
|
||||
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
with autocast_ctx:
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
|
||||
records = list(records_iter) # collect all records
|
||||
for k in range(1, device_batch_size + 1):
|
||||
for k in range(1, args.device_batch_size + 1):
|
||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||
if ddp:
|
||||
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
|
||||
passk = passk / num_records.item() # normalize by the total number of records
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
|
||||
print0(f"Step {step} | {', '.join(print_passk)}")
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**log_passk,
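The pass@k numbers above come from per-example lists of correctness flags; a toy sketch of the same aggregation (the outcomes below are invented):

```python
# pass@k: did any of the first k samples for an example succeed, averaged over examples.
records = [
    {"outcomes": [{"is_correct": c} for c in flags]}
    for flags in ([False, True, False, True], [False, False, False, False], [True, True, False, False])
]
for k in range(1, 5):
    passk = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) / len(records)
    print(f"pass@{k}: {passk:.4f}")
# pass@1: 0.3333, pass@2: 0.6667, pass@3: 0.6667, pass@4: 0.6667
```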
@@ -247,11 +260,11 @@ for step in range(num_steps):
|
||||
# Evaluate the loss and gradients
|
||||
model.train() # ensure the model is in train mode
|
||||
# We need one more loop because we can never exceed the device_batch_size
|
||||
assert inputs_all.size(0) % device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // device_batch_size
|
||||
assert inputs_all.size(0) % args.device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // args.device_batch_size
|
||||
for pass_idx in range(num_passes):
|
||||
# Pluck out the batch for this pass
|
||||
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
|
||||
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
|
||||
inputs = inputs_all[b0:b1]
|
||||
targets = targets_all[b0:b1]
|
||||
rewards = rewards_all[b0:b1]
|
||||
@@ -304,11 +317,11 @@ for step in range(num_steps):
|
||||
})
|
||||
|
||||
# Master process saves the model once in a while. Skip first step. Save last step.
|
||||
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
|
||||
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
|
||||
@@ -9,8 +9,9 @@ Or torchrun for training:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
@@ -28,51 +29,54 @@ from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# input model options
|
||||
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.02
|
||||
# evaluation and logging thereof
|
||||
eval_every = 100
|
||||
eval_steps = 100
|
||||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
|
||||
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
|
||||
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
|
||||
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
|
||||
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config, save_code=True)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
@@ -86,7 +90,9 @@ train_ds = TaskMixture([
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -124,34 +130,36 @@ def sft_data_generator(dataset, batch_size):
|
||||
yield collate_and_yield(batch)
|
||||
batch = []
|
||||
|
||||
examples_per_step = device_batch_size * ddp_world_size
|
||||
print0(f"Target examples per step: {target_examples_per_step}")
|
||||
print0(f"Device batch size: {device_batch_size}")
|
||||
examples_per_step = args.device_batch_size * ddp_world_size
|
||||
print0(f"Target examples per step: {args.target_examples_per_step}")
|
||||
print0(f"Device batch size: {args.device_batch_size}")
|
||||
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
|
||||
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
|
||||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
assert args.target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
|
||||
grad_accum_steps = args.target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
if num_iterations == -1:
|
||||
if args.num_iterations == -1:
|
||||
# derive num_iterations from num_epochs and the size of the dataset
|
||||
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
assert args.num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // args.target_examples_per_step) * args.num_epochs
|
||||
else:
|
||||
num_iterations = args.num_iterations
|
||||
train_loader = sft_data_generator(train_ds, batch_size=args.device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_batch_size)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -164,17 +172,16 @@ def get_lr_multiplier(it):
|
||||
|
||||
# Go!
|
||||
step = 0
|
||||
train_iter = iter(train_loader)
|
||||
for step in range(num_iterations):
|
||||
last_step = step == num_iterations - 1
|
||||
|
||||
# evaluate the validation loss
|
||||
if last_step or step % eval_every == 0:
|
||||
if last_step or step % args.eval_every == 0:
|
||||
model.eval()
|
||||
val_iter = iter(build_val_loader())
|
||||
val_loader = build_val_loader()
|
||||
losses = []
|
||||
for _ in range(eval_steps):
|
||||
val_inputs, val_targets = next(val_iter)
|
||||
for _ in range(args.eval_steps):
|
||||
val_inputs, val_targets = next(val_loader)
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
loss = model(val_inputs, val_targets)
|
||||
losses.append(loss)
|
||||
@@ -189,14 +196,14 @@ for step in range(num_iterations):
|
||||
})
|
||||
model.train()
|
||||
|
||||
# evlauate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
# evaluate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % args.eval_metrics_every == 0):
|
||||
model.eval()
|
||||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
@@ -211,7 +218,7 @@ for step in range(num_iterations):
|
||||
# evaluate the gradient
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_iter)
|
||||
train_inputs, train_targets = next(train_loader)
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets)
|
||||
train_loss = loss.detach() # for logging
|
||||
@@ -248,8 +255,8 @@ for step in range(num_iterations):
|
||||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
|
||||
@@ -243,7 +243,7 @@ app.add_middleware(
|
||||
async def root():
|
||||
"""Serve the chat UI."""
|
||||
ui_html_path = os.path.join("nanochat", "ui.html")
|
||||
with open(ui_html_path, "r") as f:
|
||||
with open(ui_html_path, "r", encoding="utf-8") as f:
|
||||
html_content = f.read()
|
||||
# Replace the API_URL to use the same origin
|
||||
html_content = html_content.replace(
|
||||
|
||||
@@ -6,12 +6,12 @@ python -m scripts.mid_train
|
||||
|
||||
Or torchrun for training:
|
||||
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
import argparse
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
@@ -28,67 +28,78 @@ from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
eval_every = 150 # -1 = disable
|
||||
eval_tokens = 20*524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Midtrain the model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
# Output
|
||||
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=args.run, config=user_config)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
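For concreteness, a quick sanity check of the batch arithmetic above with the script defaults (device_batch_size=32, max_seq_len=2048, total_batch_size=524288), assuming the 8-GPU torchrun invocation from the docstring:

tokens_per_fwdbwd = 32 * 2048            # 65,536 tokens per rank per micro-batch
world_tokens_per_fwdbwd = 65536 * 8      # 524,288 tokens per micro-batch across 8 assumed ranks
grad_accum_steps = 524288 // 524288      # = 1, so no gradient accumulation is needed at these settings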
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
@@ -100,7 +111,9 @@ train_dataset = TaskMixture([
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
]) # total: 460K + 100K + 8K = 568K rows
|
||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
@@ -109,50 +122,102 @@ val_dataset = TaskMixture([
|
||||
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
||||
# A big problem is that we don't know the final num_iterations in advance. So we create
|
||||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the dataset
|
||||
last_step = False # we will toggle this to True when we reach the end of the training dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
def mid_data_generator(split):
|
||||
global last_step, approx_progress
|
||||
current_epoch = 1 # track epoch for logging
|
||||
def mid_data_generator_bos_bestfit(split, buffer_size=100):
|
||||
"""
|
||||
BOS-aligned dataloader for midtraining with bestfit-crop packing.
|
||||
|
||||
Each row in the batch starts with BOS (beginning of a conversation).
|
||||
Conversations are packed using best-fit algorithm to minimize cropping.
|
||||
This matches the BOS-aligned approach used in pretraining.
|
||||
"""
|
||||
global last_step, approx_progress, current_epoch
|
||||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||
|
||||
# Conversation buffer: list of token lists
|
||||
conv_buffer = []
|
||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||
epoch = 1
|
||||
it = 0 # iteration counter
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal cursor, epoch
|
||||
while len(conv_buffer) < buffer_size:
|
||||
conversation = dataset[cursor]
|
||||
ids, _ = tokenizer.render_conversation(conversation)
|
||||
token_buffer.extend(ids)
|
||||
conv_buffer.append(ids)
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
cursor = cursor % dataset_size
|
||||
epoch += 1
|
||||
# Note: last_step is now triggered based on consumption, not fetching
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(args.device_batch_size):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has conversations
|
||||
while len(conv_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest conversation that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, conv in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len <= remaining and conv_len > best_len:
|
||||
best_idx = i
|
||||
best_len = conv_len
|
||||
|
||||
if best_idx >= 0:
|
||||
# Found a conversation that fits - use it entirely
|
||||
conv = conv_buffer.pop(best_idx)
|
||||
row.extend(conv)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - crop first conversation to fill remaining
|
||||
conv = conv_buffer.pop(0)
|
||||
row.extend(conv[:remaining])
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if num_iterations > 0 and it >= num_iterations:
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
if 0 < args.num_iterations <= it and split == "train":
|
||||
last_step = True
|
||||
|
||||
# Update progress tracking (based on consumed, not cursor, to account for buffering)
|
||||
if split == "train":
|
||||
if num_iterations > 0:
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
current_epoch = epoch
|
||||
if args.num_iterations > 0:
|
||||
approx_progress = it / args.num_iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
approx_progress = consumed / dataset_size
|
||||
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
|
||||
if consumed >= dataset_size:
|
||||
last_step = True
|
||||
|
||||
# Build tensors
|
||||
use_cuda = device_type == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
build_val_loader = lambda: mid_data_generator("val")
|
||||
train_loader = mid_data_generator_bos_bestfit("train")
|
||||
build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
|
||||
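As a standalone illustration of the bestfit-crop policy described in the generator's docstring above (pick the longest buffered conversation that still fits the row, otherwise crop the oldest one), here is a minimal sketch with toy lengths and a hypothetical helper; it is not the training code:

def pick_best_fit(conv_buffer, remaining):
    # return index of the longest buffered conversation that still fits, or -1 if none fits
    best_idx, best_len = -1, 0
    for i, conv in enumerate(conv_buffer):
        if best_len < len(conv) <= remaining:
            best_idx, best_len = i, len(conv)
    return best_idx

# toy example: row has 10 slots left, buffered conversation lengths are [7, 3, 12, 9]
buf = [[0]*7, [1]*3, [2]*12, [3]*9]
assert pick_best_fit(buf, 10) == 3   # the length-9 conversation fills the row best
assert pick_best_fit(buf, 2) == -1   # nothing fits -> the caller crops the oldest conversation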
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
# Learning rate scheduler
|
||||
@@ -175,7 +240,7 @@ ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
step = 0
|
||||
while True:
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
flops_so_far = num_flops_per_token * args.total_batch_size * step
|
||||
|
||||
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
|
||||
if ddp:
|
||||
@@ -184,10 +249,10 @@ while True:
|
||||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if eval_every > 0 and (last_step or step % eval_every == 0):
|
||||
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
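With the defaults above (eval_tokens=20*524288, device_batch_size=32, max_seq_len=2048) and an assumed 8-rank run, the eval loop length works out as follows; the rank count is an assumption, not taken from this hunk:

eval_tokens = 20 * 524288                           # 10,485,760 tokens
tokens_per_eval_step = 32 * 2048 * 8                # 524,288 tokens per step across 8 assumed ranks
eval_steps = eval_tokens // tokens_per_eval_step    # = 20 validation iterations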
@@ -202,8 +267,8 @@ while True:
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step and not dry_run:
|
||||
output_dirname = f"d{depth}" # e.g. d12
|
||||
if master_process and last_step and not args.dry_run:
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
@@ -214,7 +279,7 @@ while True:
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": {
|
||||
"sequence_len": max_seq_len,
|
||||
"sequence_len": args.max_seq_len,
|
||||
"vocab_size": tokenizer.get_vocab_size(),
|
||||
"n_layer": depth,
|
||||
"n_head": model.config.n_head,
|
||||
@@ -264,13 +329,13 @@ while True:
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
@@ -281,6 +346,7 @@ while True:
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": current_epoch,
|
||||
})
|
||||
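A small worked example of the EMA debiasing used a few lines above: the raw EMA starts at zero and therefore reads low early on, while dividing by 1 - ema_beta**(step+1) recovers the true level. ema_beta=0.9 matches the script; the constant loss values are hypothetical.

ema_beta = 0.9
losses = [2.0, 2.0, 2.0]          # hypothetical constant training loss
smooth = 0.0
for step, loss in enumerate(losses):
    smooth = ema_beta * smooth + (1 - ema_beta) * loss
    debiased = smooth / (1 - ema_beta ** (step + 1))
    # raw EMA reads 0.2, 0.38, 0.542 (biased low); debiased reads 2.0 at every step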
|
||||
# print a few more stats
|
||||
@@ -289,7 +355,7 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
if not dry_run:
|
||||
if not args.dry_run:
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Train a tokenizer using the HuggingFace Tokenizers library.
|
||||
Train a tokenizer using our own BPE Tokenizer library.
|
||||
In the style of GPT-4 tokenizer.
|
||||
"""
|
||||
import os
|
||||
@@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
|
||||
# Parse command line arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
||||
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
|
||||
parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 2B)')
|
||||
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
|
||||
@@ -53,7 +53,7 @@ class Task:
|
||||
|
||||
class TaskMixture(Task):
|
||||
"""
|
||||
For SFT Training it becomes useful to train on a tax mixture of datasets.
|
||||
For SFT Training it becomes useful to train on a mixture of datasets.
|
||||
Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
|
||||
"""
|
||||
|
||||
|
||||
@@ -25,14 +25,14 @@ class CustomJSON(Task):
|
||||
print("-" * 80)
|
||||
print(f"Warning: File {filepath} does not exist")
|
||||
print("HINT (Oct 21 2025)")
|
||||
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
|
||||
print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
|
||||
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
|
||||
print("Quick fix: simply run the following command to download the file and you're done:")
|
||||
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
|
||||
print("-" * 80)
|
||||
|
||||
else:
|
||||
with open(filepath, 'r') as f:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line: # skip empty lines
|
||||
|
||||
@@ -74,7 +74,7 @@ class GSM8K(Task):
|
||||
else:
|
||||
# Regular text in between tool calls
|
||||
assistant_message_parts.append({"type": "text", "text": part})
|
||||
# No put it all together
|
||||
# Now put it all together
|
||||
messages = [
|
||||
{"role": "user", "content": question}, # note: simple string
|
||||
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
|
||||
|
||||
tasks/spellingbee.py (new file, 307 lines)
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Task intended to make nanochat better in spelling and counting, for example:
|
||||
|
||||
"How many r are in strawberry?" -> 3
|
||||
|
||||
An interesting part of this task is that we will get the assistant to
|
||||
solve the problem using a combination of manual counting and Python.
|
||||
This is a good problem solving "instinct" to mix into the model and RL
|
||||
may further refine it to trust one over the other. If we were extra fancy
|
||||
(which we could/should be) we'd add small errors here and there to allow
|
||||
the model to also learn recoveries. We can do this in future versions.
|
||||
|
||||
There are two tasks in this file:
|
||||
1. SpellingBee: Counting the number of occurrences of a letter in a word
|
||||
2. SimpleSpelling: Simply spelling words
|
||||
|
||||
(1) is the goal, but (2) exists as a highly condensed version of the part
|
||||
that makes (1) difficult, which is word spelling. This is non-trivial for an
|
||||
LLM because it has to learn how every token (a little semantic chunk/atom)
|
||||
maps to the sequence of individual characters that make it up. Larger models
|
||||
learn this eventually on their own, but if we want this capability to exist
|
||||
in smaller models, we have to actively encourage it by over-representing it
|
||||
in the training data. Midtraining is a good place to do this.
|
||||
|
||||
To preview a few example conversations, run:
|
||||
python -m tasks.spellingbee
|
||||
"""
|
||||
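A tiny illustration of the point above, in the same word:letters format the tasks below emit. The subword merge shown is hypothetical (it depends on the trained vocab); only the spelled-out target format comes from this file.

# a subword tokenizer might see "strawberry" as a few opaque chunks, e.g. ["str", "awberry"]
# while the spelled-out target the task teaches is character level:
word = "strawberry"
print(f"{word}:{','.join(word)}")   # -> strawberry:s,t,r,a,w,b,e,r,r,y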
|
||||
import re
|
||||
import random
|
||||
from tasks.common import Task
|
||||
from nanochat.common import download_file_with_lock
|
||||
|
||||
# Letters of the alphabet
|
||||
LETTERS = "abcdefghijklmnopqrstuvwxyz"
|
||||
# A list of 370K English words of large variety
|
||||
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"
|
||||
# A number bigger than 370K to separate train and test random seeds
|
||||
TEST_RANDOM_SEED_OFFSET = 10_000_000
|
||||
|
||||
# Identical to gsm8k's answer extraction
|
||||
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
|
||||
def extract_answer(completion):
|
||||
"""
|
||||
Extract the numerical answer after #### marker.
|
||||
"""
|
||||
match = ANSWER_RE.search(completion)
|
||||
if match:
|
||||
match_str = match.group(1).strip()
|
||||
match_str = match_str.replace(",", "")
|
||||
return match_str
|
||||
return None
|
||||
|
||||
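For reference, the behavior of extract_answer defined above (same #### convention as gsm8k):

# extract_answer("My final answer is:\n\n#### 3")   -> "3"
# extract_answer("#### 1,024")                      -> "1024"   (commas stripped)
# extract_answer("no marker here")                  -> None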
# User message templates for data augmentation
|
||||
USER_MSG_TEMPLATES = [
|
||||
"How many {letter} are in the word {word}",
|
||||
"How many {letter} are in {word}",
|
||||
"Count the number of {letter} in {word}",
|
||||
"How many times does {letter} appear in {word}",
|
||||
"What's the count of {letter} in {word}",
|
||||
"In the word {word}, how many {letter} are there",
|
||||
"How many letter {letter} are in the word {word}",
|
||||
"Count how many {letter} appear in {word}",
|
||||
"Tell me the number of {letter} in {word}",
|
||||
"How many occurrences of {letter} are in {word}",
|
||||
"Find the count of {letter} in {word}",
|
||||
"Can you count the {letter} letters in {word}",
|
||||
"What is the frequency of {letter} in {word}",
|
||||
"How many {letter}s are in {word}",
|
||||
"How many {letter}'s are in {word}",
|
||||
"Count all the {letter} in {word}",
|
||||
"How many times is {letter} in {word}",
|
||||
"Number of {letter} in {word}",
|
||||
"Total count of {letter} in {word}",
|
||||
"How many {letter} does {word} have",
|
||||
"How many {letter} does {word} contain",
|
||||
"What's the number of {letter} in {word}",
|
||||
"{word} has how many {letter}",
|
||||
"In {word}, count the {letter}",
|
||||
"How many {letter} appear in {word}",
|
||||
"Count the {letter} in {word}",
|
||||
"Give me the count of {letter} in {word}",
|
||||
"How many instances of {letter} in {word}",
|
||||
"Show me how many {letter} are in {word}",
|
||||
"Calculate the number of {letter} in {word}",
|
||||
# Spanish
|
||||
"¿Cuántas {letter} hay en {word}?",
|
||||
"¿Cuántas veces aparece {letter} en {word}?",
|
||||
"Cuenta las {letter} en {word}",
|
||||
"¿Cuántas letras {letter} tiene {word}?",
|
||||
# Chinese (Simplified)
|
||||
"{word}中有多少个{letter}",
|
||||
"{word}里有几个{letter}",
|
||||
"数一下{word}中的{letter}",
|
||||
"{word}这个词里有多少{letter}",
|
||||
# Korean
|
||||
"{word}에 {letter}가 몇 개 있나요",
|
||||
"{word}에서 {letter}의 개수는",
|
||||
"{word}에 {letter}가 몇 번 나오나요",
|
||||
"{word}라는 단어에 {letter}가 몇 개",
|
||||
# French
|
||||
"Combien de {letter} dans {word}",
|
||||
"Combien de fois {letter} apparaît dans {word}",
|
||||
"Compte les {letter} dans {word}",
|
||||
# German
|
||||
"Wie viele {letter} sind in {word}",
|
||||
"Wie oft kommt {letter} in {word} vor",
|
||||
"Zähle die {letter} in {word}",
|
||||
# Japanese
|
||||
"{word}に{letter}は何個ありますか",
|
||||
"{word}の中に{letter}がいくつ",
|
||||
"{word}に{letter}が何回出てくる",
|
||||
]
|
||||
|
||||
class SpellingBee(Task):
|
||||
|
||||
def __init__(self, size=1000, split="train", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
self.size = size
|
||||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, 'r', encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
self.words = words
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return self.size
|
||||
|
||||
def get_example(self, index):
|
||||
seed = index if self.split == 'train' else TEST_RANDOM_SEED_OFFSET + index
|
||||
rng = random.Random(seed)
|
||||
|
||||
# pick a random word
|
||||
word = rng.choice(self.words)
|
||||
# pick a letter from it (90%) or a random letter (10%)
|
||||
letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)
|
||||
|
||||
# get the correct answer by simply counting
|
||||
count = word.count(letter)
|
||||
|
||||
# create a user message, with a bunch of variations as data augmentation
|
||||
template = rng.choice(USER_MSG_TEMPLATES)
|
||||
# 30% chance to lowercase the template (lazy people don't use shift)
|
||||
if rng.random() < 0.3:
|
||||
template = template.lower()
|
||||
quote_options = ['', "'", '"']
|
||||
letter_quote = rng.choice(quote_options) # is the letter quoted?
|
||||
word_quote = rng.choice(quote_options) # is the word quoted?
|
||||
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
|
||||
word_wrapped = f"{word_quote}{word}{word_quote}"
|
||||
user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
|
||||
if rng.random() < 0.5: # 50% of people don't even use question marks
|
||||
user_msg += "?"
|
||||
|
||||
# Now create the ideal assistant response - build as parts (text + tool calls)
|
||||
assistant_parts = []
|
||||
word_letters = ",".join(list(word))
|
||||
manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first.
|
||||
|
||||
First spell the word out:
|
||||
{word}:{word_letters}
|
||||
|
||||
Then count the occurrences of '{letter}':
|
||||
"""
|
||||
# Little simulated loop of the solution process
|
||||
# TODO: This is where the fun starts, we could simulate cute little mistakes
|
||||
# and get the model to review its work and recover from them.
|
||||
# You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
|
||||
running_count = 0
|
||||
for i, char in enumerate(word, 1):
|
||||
if char == letter:
|
||||
running_count += 1
|
||||
# note: there deliberately cannot be a space here between i and char
|
||||
# because this would create a different token! (e.g. " a" and "a" are different tokens)
|
||||
manual_text += f"{i}:{char} hit! count={running_count}\n"
|
||||
else:
|
||||
manual_text += f"{i}:{char}\n"
|
||||
|
||||
manual_text += f"\nThis gives us {running_count}."
|
||||
assistant_parts.append({"type": "text", "text": manual_text})
|
||||
# Part 2: Python verification
|
||||
assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
|
||||
# Part 3: Python tool call
|
||||
python_expr = f"'{word}'.count('{letter}')"
|
||||
assistant_parts.append({"type": "python", "text": python_expr})
|
||||
# Part 4: Python output
|
||||
assistant_parts.append({"type": "python_output", "text": str(count)})
|
||||
# Part 5: Final answer
|
||||
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
|
||||
|
||||
# return the full conversation
|
||||
messages = [
|
||||
{"role": "user", "content": user_msg},
|
||||
{"role": "assistant", "content": assistant_parts}
|
||||
]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
return conversation
|
||||
|
||||
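For orientation, roughly what get_example above produces for a small word (text abbreviated; the exact user phrasing depends on the sampled template, quoting, and casing augmentation):

# {"messages": [
#   {"role": "user", "content": "How many 'p' are in apple?"},
#   {"role": "assistant", "content": [
#     {"type": "text",          "text": "... apple:a,p,p,l,e ... This gives us 2."},
#     {"type": "text",          "text": "\n\nLet me double check this using Python:\n\n"},
#     {"type": "python",        "text": "'apple'.count('p')"},
#     {"type": "python_output", "text": "2"},
#     {"type": "text",          "text": "\n\nPython gives us 2.\n\nMy final answer is:\n\n#### 2"},
#   ]},
# ]}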
def evaluate(self, conversation, assistant_response):
|
||||
"""
|
||||
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
|
||||
Identical to gsm8k's evaluation.
|
||||
"""
|
||||
assert isinstance(assistant_response, str), "Assuming simple string response for now"
|
||||
# First extract the ground truth answer from the conversation
|
||||
assistant_message = conversation['messages'][-1]
|
||||
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
|
||||
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
|
||||
# The last text part contains the final answer with ####
|
||||
last_text_part = assistant_message['content'][-1]['text']
|
||||
# Extract both the ground truth answer and the predicted answer
|
||||
ref_num = extract_answer(last_text_part)
|
||||
pred_num = extract_answer(assistant_response)
|
||||
# Compare and return the success as int
|
||||
is_correct = int(pred_num == ref_num)
|
||||
return is_correct
|
||||
|
||||
def reward(self, conversation, assistant_response):
|
||||
""" Use simple 0-1 reward just like gsm8k."""
|
||||
is_correct = self.evaluate(conversation, assistant_response)
|
||||
is_correct_float = float(is_correct)
|
||||
return is_correct_float
|
||||
|
||||
|
||||
class SimpleSpelling(Task):
|
||||
"""Much simpler task designed to get the model to just practice spelling words."""
|
||||
|
||||
def __init__(self, size=1000, split="train", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
self.size = size
|
||||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, 'r', encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
rng = random.Random(42)
|
||||
rng.shuffle(words) # use a different word order than the SpellingBee task
|
||||
self.words = words
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return self.size
|
||||
|
||||
def get_example(self, index):
|
||||
seed = index if self.split == 'train' else TEST_RANDOM_SEED_OFFSET + index
|
||||
rng = random.Random(seed)
|
||||
# pick a random word
|
||||
word = rng.choice(self.words)
|
||||
word_letters = ",".join(list(word))
|
||||
# return the full conversation
|
||||
messages = [
|
||||
{"role": "user", "content": f"Spell the word: {word}"},
|
||||
{"role": "assistant", "content": f"{word}:{word_letters}"}
|
||||
]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
return conversation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# preview the SpellingBee task, first 10 examples
|
||||
task = SpellingBee()
|
||||
for i in range(10):
|
||||
ex = task.get_example(i)
|
||||
print("=" * 100)
|
||||
print(ex['messages'][0]['content'])
|
||||
print("-" * 100)
|
||||
# Assistant content is now a list of parts
|
||||
assistant_parts = ex['messages'][1]['content']
|
||||
for part in assistant_parts:
|
||||
if part['type'] == 'text':
|
||||
print(part['text'], end='')
|
||||
elif part['type'] == 'python':
|
||||
print(f"<<{part['text']}=", end='')
|
||||
elif part['type'] == 'python_output':
|
||||
print(f"{part['text']}>>", end='')
|
||||
print()
|
||||
print("-" * 100)
|
||||
|
||||
# # preview the SimpleSpelling task, first 10 examples
|
||||
# task = SimpleSpelling()
|
||||
# for i in range(10):
|
||||
# ex = task.get_example(i)
|
||||
# print("=" * 100)
|
||||
# print(ex['messages'][0]['content'])
|
||||
# print("-" * 100)
|
||||
# print(ex['messages'][1]['content'])
|
||||
|
||||
# # also scrutinize the tokenization (last example only)
|
||||
# from nanochat.tokenizer import get_tokenizer
|
||||
# tokenizer = get_tokenizer()
|
||||
# ids, mask = tokenizer.render_conversation(ex)
|
||||
# print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))
|
||||
tests/test_attention_fallback.py (new file, 338 lines)
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
|
||||
|
||||
Run: python -m pytest tests/test_attention_fallback.py -v -s
|
||||
|
||||
Note on test structure:
|
||||
Tests are split into two classes due to dtype/device constraints:
|
||||
|
||||
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
|
||||
and verify they produce identical results. These require a Hopper GPU (FA3 only
|
||||
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
|
||||
|
||||
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
|
||||
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
|
||||
"""
|
||||
import torch
|
||||
import pytest
|
||||
import nanochat.flash_attention as fa_module
|
||||
from nanochat.flash_attention import flash_attn, HAS_FA3
|
||||
from nanochat.engine import KVCache
|
||||
|
||||
|
||||
def set_impl(impl):
|
||||
"""Set the implementation override ('fa3', 'sdpa', or None for auto)."""
|
||||
fa_module._override_impl = impl
|
||||
|
||||
|
||||
def run_both_impls(fn):
|
||||
"""Run a function with both FA3 and SDPA, return both outputs."""
|
||||
set_impl('fa3')
|
||||
out_fa3 = fn()
|
||||
set_impl('sdpa')
|
||||
out_sdpa = fn()
|
||||
set_impl(None) # reset
|
||||
return out_fa3, out_sdpa
|
||||
|
||||
|
||||
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
|
||||
"""Assert two tensors are close, with helpful error message."""
|
||||
max_diff = (t1 - t2).abs().max().item()
|
||||
mean_diff = (t1 - t2).abs().mean().item()
|
||||
assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \
|
||||
f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
|
||||
return max_diff, mean_diff
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FA3 vs SDPA comparison tests (require Hopper GPU)
|
||||
# =============================================================================
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
|
||||
class TestFA3VsSDPA:
|
||||
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
def test_basic_causal(self):
|
||||
"""Basic causal attention."""
|
||||
B, T, H, D = 2, 64, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
|
||||
print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_full_context(self):
|
||||
"""Full context (window_size=-1)."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
|
||||
print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_sliding_window(self):
|
||||
"""Sliding window attention."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
window = 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
|
||||
print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_gqa(self):
|
||||
"""Group Query Attention (fewer KV heads than Q heads)."""
|
||||
B, T, D = 2, 64, 32
|
||||
n_heads = 8
|
||||
n_kv_heads = 2
|
||||
|
||||
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
|
||||
print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_larger_model(self):
|
||||
"""Larger dimensions closer to real model."""
|
||||
B, T, H, D = 4, 256, 12, 64
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
|
||||
print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_kvcache_prefill(self):
|
||||
"""Test prefill (inserting multiple tokens into empty cache)."""
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
T_prefill = 16
|
||||
|
||||
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
|
||||
return flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
|
||||
print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_kvcache_single_token(self):
|
||||
"""Test single token generation (cache already has content)."""
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
T_prefill = 16
|
||||
|
||||
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_cache[:, :T_prefill, :, :] = k_init
|
||||
v_cache[:, :T_prefill, :, :] = v_init
|
||||
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
|
||||
return flash_attn.flash_attn_with_kvcache(
|
||||
q_single, k_cache, v_cache, k=k_single, v=v_single,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
|
||||
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_backward_gradients_match(self):
|
||||
"""Verify gradients are similar between FA3 and SDPA."""
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
|
||||
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
q = q_data.clone().requires_grad_(True)
|
||||
k = k_data.clone().requires_grad_(True)
|
||||
v = v_data.clone().requires_grad_(True)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
|
||||
|
||||
set_impl('fa3')
|
||||
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
|
||||
set_impl('sdpa')
|
||||
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
|
||||
set_impl(None)
|
||||
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
|
||||
print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
|
||||
print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
|
||||
print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
|
||||
print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA-only tests (run on any device)
|
||||
# =============================================================================
|
||||
class TestSDPAOnly:
|
||||
"""Test SDPA fallback works correctly. Runs on any device."""
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
def test_basic_forward(self):
|
||||
"""Test SDPA forward pass produces valid output."""
|
||||
set_impl('sdpa')
|
||||
B, T, H, D = 2, 64, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
assert y.shape == (B, T, H, D)
|
||||
assert not torch.isnan(y).any(), "Output contains NaN"
|
||||
set_impl(None)
|
||||
|
||||
def test_backward(self):
|
||||
"""Test gradients flow through SDPA."""
|
||||
set_impl('sdpa')
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
|
||||
assert q.grad is not None, "No gradient for q"
|
||||
assert k.grad is not None, "No gradient for k"
|
||||
assert v.grad is not None, "No gradient for v"
|
||||
assert not torch.isnan(q.grad).any(), "NaN in q gradient"
|
||||
set_impl(None)
|
||||
|
||||
def test_kvcache(self):
|
||||
"""Test SDPA with KV cache."""
|
||||
set_impl('sdpa')
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
n_layers = 1
|
||||
|
||||
cache = KVCache(
|
||||
batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
|
||||
num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
|
||||
)
|
||||
k_cache, v_cache = cache.get_layer_cache(0)
|
||||
|
||||
# Prefill
|
||||
T_prefill = 16
|
||||
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v,
|
||||
cache_seqlens=cache.cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
cache.advance(T_prefill)
|
||||
|
||||
assert y.shape == (B, T_prefill, H, D)
|
||||
assert cache.get_pos() == T_prefill
|
||||
|
||||
# Generate single token
|
||||
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
y_single = flash_attn.flash_attn_with_kvcache(
|
||||
q_single, k_cache, v_cache, k=k_single, v=v_single,
|
||||
cache_seqlens=cache.cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
cache.advance(1)
|
||||
|
||||
assert y_single.shape == (B, 1, H, D)
|
||||
assert cache.get_pos() == T_prefill + 1
|
||||
set_impl(None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Override mechanism tests
|
||||
# =============================================================================
|
||||
class TestOverrideMechanism:
|
||||
"""Test that the override mechanism works correctly."""
|
||||
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
|
||||
def test_override_fa3(self):
|
||||
"""Test that override='fa3' uses FA3."""
|
||||
set_impl('fa3')
|
||||
assert fa_module._use_fa3() == True
|
||||
set_impl(None)
|
||||
|
||||
def test_override_sdpa(self):
|
||||
"""Test that override='sdpa' uses SDPA."""
|
||||
set_impl('sdpa')
|
||||
assert fa_module._use_fa3() == False
|
||||
set_impl(None)
|
||||
|
||||
def test_override_auto(self):
|
||||
"""Test that override=None uses auto-detection."""
|
||||
set_impl(None)
|
||||
assert fa_module._use_fa3() == HAS_FA3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"PyTorch version: {torch.__version__}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"CUDA device: {torch.cuda.get_device_name()}")
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
print(f"Compute capability: {major}.{minor}")
|
||||
print(f"HAS_FA3: {HAS_FA3}")
|
||||
print()
|
||||
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
tests/test_engine.py (new file, 267 lines)
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Test Engine class. Example run:
|
||||
|
||||
python -m pytest tests/test_engine.py -v
|
||||
"""
|
||||
|
||||
import torch
|
||||
from nanochat.engine import KVCache, Engine
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Mock classes for testing Engine without loading a real model
|
||||
|
||||
@dataclass
|
||||
class MockConfig:
|
||||
"""Minimal config for Engine tests."""
|
||||
n_kv_head: int = 4
|
||||
n_head: int = 4
|
||||
n_embd: int = 64
|
||||
n_layer: int = 2
|
||||
sequence_len: int = 128
|
||||
|
||||
|
||||
class MockModel:
|
||||
"""
|
||||
Mock model that returns uniform logits over the vocab.
|
||||
This ensures that with temperature > 0, different samples should
|
||||
(with very high probability) produce different tokens.
|
||||
"""
|
||||
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
|
||||
self.vocab_size = vocab_size
|
||||
self.config = MockConfig()
|
||||
self._device = "cpu"
|
||||
|
||||
def get_device(self):
|
||||
return self._device
|
||||
|
||||
def forward(self, ids, kv_cache=None):
|
||||
"""Return uniform logits so sampling is spread across vocab."""
|
||||
B, T = ids.shape
|
||||
# With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
|
||||
if kv_cache is not None:
|
||||
kv_cache.advance(T)
|
||||
# Uniform logits -> equal probability for all tokens
|
||||
logits = torch.zeros(B, T, self.vocab_size)
|
||||
return logits
|
||||
|
||||
|
||||
class ByteTokenizer:
|
||||
"""
|
||||
Simple byte-level tokenizer for testing.
|
||||
Tokens 0-255 are raw bytes, 256+ are special tokens.
|
||||
"""
|
||||
def __init__(self):
|
||||
# Special tokens start at 256
|
||||
self._special_tokens = {
|
||||
"<|python_start|>": 256,
|
||||
"<|python_end|>": 257,
|
||||
"<|output_start|>": 258,
|
||||
"<|output_end|>": 259,
|
||||
"<|assistant_end|>": 260,
|
||||
"<|bos|>": 261,
|
||||
}
|
||||
self._bos = 261
|
||||
|
||||
def encode_special(self, s):
|
||||
return self._special_tokens[s]
|
||||
|
||||
def get_bos_token_id(self):
|
||||
return self._bos
|
||||
|
||||
def encode(self, s, prepend=None):
|
||||
tokens = list(s.encode("utf-8")) # bytes 0-255
|
||||
if prepend is not None:
|
||||
tokens = [prepend] + tokens
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
# Filter out special tokens before decoding
|
||||
byte_tokens = [t for t in tokens if t < 256]
|
||||
return bytes(byte_tokens).decode("utf-8", errors="replace")
|
||||
|
||||
def test_kv_cache_basic():
|
||||
"""Test basic KVCache functionality for FA3."""
|
||||
batch_size = 2
|
||||
num_heads = 3
|
||||
seq_len = 64
|
||||
head_dim = 5
|
||||
num_layers = 6
|
||||
|
||||
kv_cache = KVCache(
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
seq_len=seq_len,
|
||||
head_dim=head_dim,
|
||||
num_layers=num_layers,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
# Check initial state
|
||||
assert kv_cache.get_pos() == 0
|
||||
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
|
||||
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
# Test advance
|
||||
kv_cache.advance(10)
|
||||
assert kv_cache.get_pos() == 10
|
||||
|
||||
kv_cache.advance(5)
|
||||
assert kv_cache.get_pos() == 15
|
||||
|
||||
# Test reset
|
||||
kv_cache.reset()
|
||||
assert kv_cache.get_pos() == 0
|
||||
|
||||
# Test get_layer_cache returns correct views
|
||||
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
|
||||
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
|
||||
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
|
||||
def test_kv_cache_prefill():
|
||||
"""Test KVCache.prefill() copies data correctly."""
|
||||
batch_size = 1
|
||||
num_heads = 4
|
||||
head_dim = 8
|
||||
num_layers = 2
|
||||
|
||||
# Create source cache and advance it
|
||||
src_cache = KVCache(
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=32,
|
||||
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
|
||||
)
|
||||
# Write some data to source cache
|
||||
src_cache.k_cache[0, 0, :16, :, :] = 1.0
|
||||
src_cache.v_cache[0, 0, :16, :, :] = 2.0
|
||||
src_cache.advance(16)
|
||||
|
||||
# Create destination cache with larger seq_len
|
||||
dst_cache = KVCache(
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=64,
|
||||
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
dst_cache.prefill(src_cache)
|
||||
|
||||
# Check position was copied
|
||||
assert dst_cache.get_pos() == 16
|
||||
|
||||
# Check data was copied
|
||||
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
|
||||
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
|
||||
|
||||
|
||||
def test_multi_sample_first_token_diversity():
    """
    Test that when generating multiple samples, each sample gets an independently
    sampled first token (not a broadcast of the same token to all rows).

    Previously, the first token after prefill was sampled once and broadcast to all
    rows, causing all samples to start identically. The fix expands the prefill logits
    to num_samples and samples independently for each row.

    With uniform logits over 262 tokens and 16 samples, the probability that all
    samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
    all identical, it indicates tokens are being broadcast instead of independently sampled.
    """
    model = MockModel(vocab_size=262)
    tokenizer = ByteTokenizer()
    engine = Engine(model, tokenizer)

    # Generate 16 samples with temperature=1.0 (stochastic sampling)
    prompt_tokens = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
    num_samples = 16

    # Collect the first generated token from each sample
    first_tokens = []
    gen = engine.generate(
        prompt_tokens,
        num_samples=num_samples,
        max_tokens=1, # We only need the first token
        temperature=1.0,
        seed=42,
    )
    for token_column, token_masks in gen:
        first_tokens = token_column # This is the first (and only) yield

    # With uniform distribution and 16 samples, they should NOT all be identical
    # If they are all identical, the bug exists (broadcasting instead of sampling)
    unique_tokens = set(first_tokens)
    assert len(unique_tokens) > 1, (
        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )


def test_seed_reproducibility():
    """Same seed must produce identical output."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"

    for seed in [1, 42, 123, 999]:
        r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
        r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
        r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
        assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."


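# NOTE (general property, hedged): samplers typically treat temperature=0.0 as greedy
# argmax decoding, so the RNG (and therefore the seed) never enters the token choice.
# That is the behavior the next test pins down for the Engine.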
def test_temperature_zero_determinism():
    """Temperature=0 is deterministic regardless of seed."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
    r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
    r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
    assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."


def test_max_tokens_respected():
    """Generation stops at max_tokens limit."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    for max_tokens in [1, 4, 16, 64]:
        results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
        num_generated_tokens = len(results[0]) - len(prompt)
        assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."


def test_num_samples_count():
    """num_samples=N produces exactly N sequences."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    for num_samples in [1, 4, 16, 64]:
        results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
        assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"


def test_different_seeds_introduce_variation_when_temperature_nonzero():
    """With temperature > 0, different seeds should introduce sampling variation."""
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"

    outputs = set()

    for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
        results, _ = engine.generate_batch(
            prompt,
            temperature=1.0,
            max_tokens=5,
            seed=seed,
        )
        outputs.add(tuple(results[0]))

    # Sanity check: sampling actually introduces variation
    assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
@@ -1,635 +0,0 @@
"""
Comparing the training of:

1. (very slow) Python reference implementation
2. Optimized Python implementation
3. HuggingFace tokenizers training implementation
4. Our own custom RustBPE training implementation

All of these should calculate the same merges and produce
the same vocabulary and tokenizations.

Finally, for inference we will use tiktoken for efficiency.
So we want to make sure we can export our rustbpe tokenizer
into tiktoken and use it for inference with identical results.

Run with:
python -m pytest tests/test_rustbpe.py -v -s
(-v is verbose, -s shows prints)
"""

import regex as re
from collections import Counter, defaultdict
import time
import rustbpe
import tiktoken
import pytest

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
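# Illustrative (approximate, not asserted anywhere in this file): the GPT-4 split pattern
# chunks text before BPE so merges never cross chunk boundaries, e.g.
#     re.findall(GPT4_SPLIT_PATTERN, "Hello world!!")  ->  roughly ['Hello', ' world', '!!']
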

# -----------------------------------------------------------------------------
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe

def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally allows to update an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

class RegexTokenizer:

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)
        - special_tokens: str -> int dictionary of special tokens
          example: {'<|endoftext|>': 100257}
        """
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.merges = {} # (int, int) -> int
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique)
        ambiguous = False

        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)

        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # check if the merge is ambiguous - i.e. the max value is not unique
            pair_count = stats[pair]
            pairs_with_max_count = [pair for pair, count in stats.items() if count == pair_count]
            if len(pairs_with_max_count) > 1:
                # print the top 10 pairs with their counts
                # print(f"{i} Merge is ambiguous! {pair} has {pair_count} occurrences")
                # for print_pair, print_count in sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]:
                #     print(f"{print_pair}: {print_count}")
                ambiguous = True
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()
        return ambiguous

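    # Illustrative walk-through (not executed): training on the toy string "aaabdaaabac"
    # would first count pairs of byte values, with (97, 97) i.e. "aa" appearing 4 times,
    # merge that most frequent pair into new token 256, and repeat on the updated ids.
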
    def _encode_chunk(self, text_bytes):
        # return the token ids
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8") # raw bytes
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids

# -----------------------------------------------------------------------------
# Faster Python tokenizer, optimized version of the reference tokenizer

def fast_merge_inplace(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx in place
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    # Find all positions where the pair occurs
    i = 0
    while i < len(ids) - 1:
        if ids[i] == pair[0] and ids[i+1] == pair[1]:
            ids[i] = idx
            ids.pop(i+1)
        else:
            i += 1
    return ids


class FastRegexTokenizer:

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)
        - special_tokens: str -> int dictionary of special tokens
          example: {'<|endoftext|>': 100257}
        """
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
        self.merges = {}
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """
        A number of optimizations are introduced:
        - delete function call overhead by inlining functions
        - modify the list of ids in place with .pop() instead of creating a new list
        - collapse identical chunks to just the unique ones
        - update counts more cleverly - only around the affected chunks (see the note below)
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)

        # many, many chunks are identical, so we can "collapse" them to just the unique ones
        counts = Counter(text_chunks)
        unique_chunks = [ch for ch, count in counts.items()]
        chunk_counts = [count for ch, count in counts.items()]

        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in unique_chunks]
        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes

        # Initial count: build stats and position tracking
        stats = defaultdict(int)
        positions = defaultdict(set) # pair -> set of chunk indices that contain this pair

        for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)):
            for pair in zip(chunk_ids, chunk_ids[1:]):
                stats[pair] += count
                positions[pair].add(chunk_idx)

        for i in range(num_merges):
            if not stats:
                break

            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i

            # Get chunks that contain this pair
            affected_chunks = positions[pair]

            # Track count changes for incremental update
            count_changes = defaultdict(int)

            # Replace all occurrences of pair in affected chunks only
            for chunk_idx in affected_chunks:
                chunk_ids = ids[chunk_idx]
                chunk_count = chunk_counts[chunk_idx]
                ix = 0
                while ix < len(chunk_ids) - 1:
                    if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]:
                        # Track what pairs are being removed/added
                        # Remove: (prev, A), (A, B), (B, next)
                        if ix > 0:
                            old_left = (chunk_ids[ix-1], chunk_ids[ix])
                            count_changes[old_left] -= chunk_count

                        # The merged pair disappears
                        count_changes[pair] -= chunk_count

                        if ix + 2 < len(chunk_ids):
                            old_right = (chunk_ids[ix+1], chunk_ids[ix+2])
                            count_changes[old_right] -= chunk_count

                        # Apply the merge
                        chunk_ids[ix] = idx
                        chunk_ids.pop(ix+1)

                        # Add: (prev, C), (C, next)
                        if ix > 0:
                            new_left = (chunk_ids[ix-1], chunk_ids[ix])
                            count_changes[new_left] += chunk_count

                        if ix + 1 < len(chunk_ids):
                            new_right = (chunk_ids[ix], chunk_ids[ix+1])
                            count_changes[new_right] += chunk_count
                    else:
                        ix += 1

            # Apply incremental changes to stats and positions
            for changed_pair, delta in count_changes.items():
                if changed_pair == pair:
                    # The merged pair should disappear completely
                    continue

                stats[changed_pair] += delta

                # Update positions for changed pairs - only check affected chunks
                for chunk_idx in affected_chunks:
                    chunk_ids = ids[chunk_idx]
                    contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair
                                        for j in range(len(chunk_ids) - 1))
                    if contains_pair:
                        positions[changed_pair].add(chunk_idx)
                    else:
                        positions[changed_pair].discard(chunk_idx)

            # Remove the merged pair completely
            del stats[pair]
            del positions[pair]

            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()

    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        # given ids (list of integers), return Python string
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _encode_chunk(self, text_bytes):
        # return the token ids
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = fast_merge_inplace(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8") # raw bytes
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids

# -----------------------------------------------------------------------------
# HuggingFace tokenizer
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

class HuggingFaceTokenizer:
    """Light wrapper around HuggingFace Tokenizer for some utilities"""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        # train from an iterator of text
        # Configure the HuggingFace Tokenizer
        tokenizer = HFTokenizer(BPE(
            byte_fallback=True, # needed!
            unk_token=None,
            fuse_unk=False,
        ))
        # Normalizer: None
        tokenizer.normalizer = None
        # Pre-tokenizer: GPT-4 style
        gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        ])
        # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
        tokenizer.decoder = decoders.ByteLevel()
        # Post-processor: None
        tokenizer.post_processor = None
        # Trainer: BPE
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            show_progress=True,
            min_frequency=0, # no minimum frequency
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=[], # no special tokens
        )
        # Kick off the training
        tokenizer.train_from_iterator(text_iterator, trainer)
        return cls(tokenizer)

    def encode_ordinary(self, text):
        ids = self.tokenizer.encode(text, add_special_tokens=False).ids
        return ids

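# Illustrative usage (mirrors how the tests below call it; not executed here):
#     hf_tok = HuggingFaceTokenizer.train_from_iterator(["some training text"], vocab_size=300)
#     ids = hf_tok.encode_ordinary("some text to encode")
# Because of the ByteLevel alphabet, the raw-byte token ids generally differ from the
# reference tokenizers by a fixed permutation (see custom_match in test_correctness below).
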
# -----------------------------------------------------------------------------
# Test all of the above

@pytest.fixture(scope="module")
def enwik8_path():
    """Fixture to download and cache enwik8 dataset."""
    import os
    import zipfile
    from nanochat.common import get_base_dir
    base_dir = get_base_dir()
    # download and unzip enwik8 to .cache directory
    enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
    enwik8_local_path = os.path.join(base_dir, "enwik8")
    enwik8_local_path_zip = os.path.join(base_dir, "enwik8.zip")
    if not os.path.exists(enwik8_local_path):
        print(f"Downloading enwik8 to {enwik8_local_path_zip}")
        import requests
        response = requests.get(enwik8_url)
        with open(enwik8_local_path_zip, "wb") as f:
            f.write(response.content)
        with zipfile.ZipFile(enwik8_local_path_zip, "r") as zip_ref:
            zip_ref.extractall(base_dir)
        print(f"Unzipped enwik8 to {enwik8_local_path}")
        os.remove(enwik8_local_path_zip)
        print(f"Removed {enwik8_local_path_zip}")
    else:
        print(f"Using existing enwik8 at {enwik8_local_path}")
    return enwik8_local_path


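# Context: enwik8 is the first 10^8 bytes of an English Wikipedia XML dump, a standard
# compression/tokenization benchmark; the fixtures below slice off the first 100KB / 10MB
# of it to keep the tests fast.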
@pytest.fixture(scope="module")
def enwik8_small(enwik8_path):
    """Fixture providing 100KB of enwik8 for quick tests."""
    with open(enwik8_path, "r") as f:
        return f.read(100_000)

@pytest.fixture(scope="module")
def enwik8_large(enwik8_path):
    """Fixture providing 10MB of enwik8 for performance tests."""
    with open(enwik8_path, "r") as f:
        return f.read(10**7)

def time_function(func, *args, **kwargs):
    """Time a function call and return the result and elapsed time"""
    start_time = time.time()
    result = func(*args, **kwargs)
    end_time = time.time()
    elapsed = end_time - start_time
    return result, elapsed

def test_correctness(enwik8_small):
    """Test that all tokenizer implementations produce the same results."""
    text = enwik8_small
    encode_text = text
    vocab_size = 256 + 20 # 20 merges

    # Train slow reference
    print("\nTraining slow reference...")
    slow_reference_tokenizer = RegexTokenizer()
    ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
    slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text)
    print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
    print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
    print(slow_reference_ids[:20])

    if ambiguous_flag:
        print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size")
        print("The implementation could be correct but we might see different results below")
    else:
        print("✅ Merge order is NOT ambiguous")

    # Train fast reference
    print("\nTraining fast reference...")
    fast_reference_tokenizer = FastRegexTokenizer()
    _, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
    fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text)
    print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
    print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
    print(fast_reference_ids[:20])

    # Assert fast equals slow
    assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference"
    print("✅ Fast == Slow")

    # Train HuggingFace
    print("\nTraining HuggingFace...")
    hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
    hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text)
    print(f"HuggingFace train time: {hf_train_time:.4f}s")
    print(f"HuggingFace encode time: {hf_encode_time:.4f}s")
    print(hf_ids[:20])

    # HuggingFace has a different byte order, so we need custom matching
    def custom_match(ids1, ids2):
        perm = {}
        for x, y in zip(ids1, ids2):
            if x < 256:
                if x in perm:
                    if perm[x] != y:
                        return False
                perm[x] = y
            if x >= 256 and x != y:
                return False
        return True
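    # Why a permutation check (explanatory note): the ByteLevel pre-tokenizer maps raw bytes
    # onto its own printable alphabet, so HuggingFace's byte-level ids (< 256) differ from
    # ours by a fixed one-to-one mapping, while the learned merge tokens (>= 256) are minted
    # in the same order and must match exactly.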

    assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference"
    print("✅ HuggingFace == Fast")

    # Finally use our own Rust implementation
    print("\nTraining rustbpe...")
    rustbpe_tokenizer = rustbpe.Tokenizer()
    _, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
    rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text)
    print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
    print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s")
    print(rustbpe_ids[:20])

    assert rustbpe_ids == fast_reference_ids, "RustBPE should match fast reference"
    print("✅ RustBPE == Fast")

    # Now export rustbpe to tiktoken for more efficient inference
    print("\nTesting tiktoken export...")
    pattern = rustbpe_tokenizer.get_pattern()
    mergeable_ranks_list = rustbpe_tokenizer.get_mergeable_ranks()
    mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens={},
    )
    tiktoken_ids, tiktoken_encode_time = time_function(enc.encode, encode_text)
    print(f"Tiktoken encode time: {tiktoken_encode_time:.4f}s")
    print(tiktoken_ids[:20])

    assert tiktoken_ids == rustbpe_ids, "Tiktoken should match RustBPE"
    print("✅ Tiktoken == RustBPE")
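    # Note (explanatory): the tiktoken export needs only the split pattern and the
    # bytes -> rank table; tiktoken applies merges in rank order, so identical ranks and
    # pattern yield identical ids, which is what the assert above verifies.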


@pytest.mark.slow
def test_training_performance(enwik8_large):
    """Use a bigger dataset and compare the training speed of the optimized tokenizers (Python, Rust, HuggingFace)."""
    text = enwik8_large
    vocab_size = 2048
    print(f"\nText length: {len(text)}")

    # Commenting out because it's just way too slow to matter
    # Train optimized python version
    # print("Training optimized python version...")
    # optimized_python_tokenizer = FastRegexTokenizer()
    # _, optimized_python_train_time = time_function(optimized_python_tokenizer.train, text, vocab_size)
    # print(f"Optimized python train time: {optimized_python_train_time:.4f}s")

    # Train rustbpe
    print("\nTraining rustbpe...")
    rustbpe_tokenizer = rustbpe.Tokenizer()
    _, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
    print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
    assert rustbpe_train_time > 0, "Training should take some time"

    # Train HuggingFace
    print("\nTraining HuggingFace...")
    hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
    print(f"HuggingFace train time: {hf_train_time:.4f}s")
    assert hf_train_time > 0, "Training should take some time"

    # Print comparison
    print(f"\n📊 Performance comparison:")
    print(f"  RustBPE: {rustbpe_train_time:.4f}s")
    print(f"  HuggingFace: {hf_train_time:.4f}s")
    print(f"  Speedup: {hf_train_time/rustbpe_train_time:.2f}x")


def test_interface(enwik8_small):
    """Test the RustBPETokenizer interface for training, encoding, decoding, and serialization."""
    import tempfile
    from nanochat.tokenizer import RustBPETokenizer

    # Simple train test
    vocab_size = 300
    tok = RustBPETokenizer.train_from_iterator([enwik8_small], vocab_size)
    assert tok.get_vocab_size() == vocab_size, f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
    print(f"✅ Trained tokenizer with vocab size {vocab_size}")

    # Encode/decode text
    encode_text = "Hello world! How are you? 🙃"
    ids = tok.encode(encode_text)
    print(f"\nInput text: {encode_text}")
    print(f"IDs: {ids}")
    decoded = tok.decode(ids)
    print(f"Decoded: {decoded}")
    assert decoded == encode_text, f"Decoded text doesn't match: {decoded} != {encode_text}"
    print("✅ Encode/decode test passed")

    # Encode batch test
    ids_new = tok.encode([encode_text, encode_text])
    assert all(x == ids for x in ids_new), "Batch encoding should produce identical results"
    print("✅ Encode batch OK")

    # append/prepend functionality
    ids_special = tok.encode(encode_text, prepend="<|bos|>", append="<|bos|>")
    bos_token_id = tok.encode_special("<|bos|>")
    assert ids_special == [bos_token_id] + ids + [bos_token_id], "Special tokens not correctly added"
    print("✅ append/prepend OK")

    # Save/load test through a temporary directory
    with tempfile.TemporaryDirectory() as tmp_dir:
        tok.save(tmp_dir)
        tok_reloaded = RustBPETokenizer.from_directory(tmp_dir)
        ids_reloaded = tok_reloaded.encode(encode_text)
        assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
        print("✅ Save/load through temporary directory OK")