mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-30 04:22:02 +00:00
Compare commits
4 Commits
d5418ea5a1
...
41bb2eac32
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41bb2eac32 | ||
|
|
64a651a63c | ||
|
|
65df0de42b | ||
|
|
74554be3b5 |
40
.claude/skills/read-arxiv-paper/SKILL.md
Normal file
40
.claude/skills/read-arxiv-paper/SKILL.md
Normal file
@@ -0,0 +1,40 @@
|
||||
---
|
||||
name: read-arxiv-paper
|
||||
description: Use this skill when when asked to read an arxiv paper given an arxiv URL
|
||||
---
|
||||
|
||||
You will be given a URL of an arxiv paper, for example:
|
||||
|
||||
https://www.arxiv.org/abs/2601.07372
|
||||
|
||||
### Part 1: Normalize the URL
|
||||
|
||||
The goal is to fetch the TeX Source of the paper (not the PDF!), the URL always looks like this:
|
||||
|
||||
https://www.arxiv.org/src/2601.07372
|
||||
|
||||
Notice the /src/ in the url. Once you have the URL:
|
||||
|
||||
### Part 2: Download the paper source
|
||||
|
||||
Fetch the url to a local .tar.gz file. A good location is `~/.cache/nanochat/knowledge/{arxiv_id}.tar.gz`.
|
||||
|
||||
(If the file already exists, there is no need to re-download it).
|
||||
|
||||
### Part 3: Unpack the file in that folder
|
||||
|
||||
Unpack the contents into `~/.cache/nanochat/knowledge/{arxiv_id}` directory.
|
||||
|
||||
### Part 4: Locate the entrypoint
|
||||
|
||||
Every latex source usually has an entrypoint, such as `main.tex` or something like that.
|
||||
|
||||
### Part 5: Read the paper
|
||||
|
||||
Once you've found the entrypoint, Read the contents and then recurse through all other relevant source files to read the paper.
|
||||
|
||||
#### Part 6: Report
|
||||
|
||||
Once you've read the paper, produce a summary of the paper into a markdown file at `./knowledge/summary_{tag}.md`. Notice that 1) use the local knowledge directory here (it's easier for me to open and reference here), not in `~/.cache`, and 2) generate some reasonable `tag` like e.g. `conditional_memory` or whatever seems appropriate given the paper. Probably make sure that the tag doesn't exist yet so you're not overwriting files.
|
||||
|
||||
As for the summary itself, remember that you're processing this paper within the context of the nanochat repository, so most often we we will be interested in how to apply the paper and its lessons to the nanochat project. Therefore, you should feel free to "remind yourself" of the related nanochat code by reading the relevant parts, and then explicitly make the connection of how this paper might relate to nanochat or what are things we might be inspired about or try.
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,6 +9,5 @@ eval_bundle/
|
||||
.env
|
||||
|
||||
# Local setup
|
||||
.claude
|
||||
CLAUDE.md
|
||||
wandb/
|
||||
|
||||
@@ -135,7 +135,6 @@ python -m pytest tests/test_engine.py -v -s
|
||||
│ └── repackage_data_reference.py # Pretraining data shard generation
|
||||
├── nanochat
|
||||
│ ├── __init__.py # empty
|
||||
│ ├── adamw.py # Distributed AdamW optimizer
|
||||
│ ├── checkpoint_manager.py # Save/Load model checkpoints
|
||||
│ ├── common.py # Misc small utilities, quality of life
|
||||
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
|
||||
@@ -146,7 +145,7 @@ python -m pytest tests/test_engine.py -v -s
|
||||
│ ├── gpt.py # The GPT nn.Module Transformer
|
||||
│ ├── logo.svg
|
||||
│ ├── loss_eval.py # Evaluate bits per byte (instead of loss)
|
||||
│ ├── muon.py # Distributed Muon optimizer
|
||||
│ ├── optim.py # AdamW + Muon optimizer, 1GPU and distributed
|
||||
│ ├── report.py # Utilities for writing the nanochat Report
|
||||
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
|
||||
│ └── ui.html # HTML/CSS/JS for nanochat frontend
|
||||
|
||||
@@ -4,6 +4,12 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-28: Reverted Bigram Hash Embeddings
|
||||
|
||||
Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall clock time. It also bloated the VRAM used. The extra parameters and complexity aren't justified.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
|
||||
|
||||
Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
"""
|
||||
Distributed AdamW optimizer with a fused step function.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
exp_avg: Tensor,
|
||||
exp_avg_sq: Tensor,
|
||||
step_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
beta1_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
eps_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed AdamW optimizer.
|
||||
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
|
||||
"""
|
||||
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
for group in param_groups:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
reduce_futures: list[torch.Future] = []
|
||||
gather_futures: list[torch.Future] = []
|
||||
grad_slices = []
|
||||
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
|
||||
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._step_t.fill_(state['step'])
|
||||
self._lr_t.fill_(lr)
|
||||
self._beta1_t.fill_(beta1)
|
||||
self._beta2_t.fill_(beta2)
|
||||
self._eps_t.fill_(eps)
|
||||
self._wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
115
nanochat/gpt.py
115
nanochat/gpt.py
@@ -20,8 +20,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
@@ -45,41 +44,6 @@ def norm(x):
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
class BigramEmbed(nn.Module):
|
||||
"""
|
||||
Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
|
||||
Following modded-nanogpt's approach: single hash, no gating.
|
||||
|
||||
For each position t, hashes (token[t-1], token[t]) to an index in a large
|
||||
embedding table. This provides O(1) lookup for local 2-gram patterns,
|
||||
offloading static pattern reconstruction from the transformer layers.
|
||||
|
||||
Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
|
||||
Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
|
||||
"""
|
||||
def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
|
||||
super().__init__()
|
||||
self.bigram_vocab_size = vocab_size * table_multiplier
|
||||
self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
|
||||
|
||||
def forward(self, idx: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
idx: (B, T) token ids
|
||||
Returns: (B, T, embed_dim) bigram embeddings
|
||||
"""
|
||||
# Hash (prev_token, curr_token) -> index
|
||||
# Position 0 gets a reserved index (no valid bigram)
|
||||
rand_int_1 = 36313
|
||||
rand_int_2 = 27191
|
||||
mod = self.bigram_vocab_size - 1
|
||||
|
||||
h = torch.empty_like(idx, dtype=torch.long)
|
||||
h[:, 0] = mod # reserved index for position 0
|
||||
h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
|
||||
|
||||
return self.embed(h)
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
|
||||
return layer_idx % 2 == (n_layer - 1) % 2
|
||||
@@ -204,13 +168,9 @@ class GPT(nn.Module):
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
# bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Bigram hash embeddings: O(1) lookup for local 2-gram patterns
|
||||
self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
@@ -259,10 +219,6 @@ class GPT(nn.Module):
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
|
||||
self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
|
||||
|
||||
# Bigram embeddings: zero init so it starts as identity
|
||||
nn.init.zeros_(self.bigram_embed.embed.weight)
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
@@ -283,7 +239,6 @@ class GPT(nn.Module):
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
self.bigram_embed.to(dtype=torch.bfloat16)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
@@ -349,9 +304,8 @@ class GPT(nn.Module):
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
bigram_embed_numel = self.bigram_embed.embed.weight.numel()
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
@@ -376,16 +330,14 @@ class GPT(nn.Module):
|
||||
"""
|
||||
# Count each group separately (mirrors the grouping in setup_optimizers)
|
||||
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
||||
bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
|
||||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel()
|
||||
total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
'wte': wte,
|
||||
'bigram_embed': bigram_embed,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
@@ -393,9 +345,10 @@ class GPT(nn.Module):
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
|
||||
# Separate out all parameters into groups
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
@@ -403,35 +356,34 @@ class GPT(nn.Module):
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
bigram_embed_params = list(self.bigram_embed.parameters())
|
||||
bigram_lambda_params = [self.bigram_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
|
||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||
dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||
dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
|
||||
dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
|
||||
|
||||
# Build param_groups with all required fields explicit
|
||||
param_groups = [
|
||||
# AdamW groups (embeddings, lm_head, scalars)
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
]
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
|
||||
MuonFactory = DistMuon if ddp else Muon
|
||||
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
||||
# Combine them the two optimizers into one list
|
||||
optimizers = [adamw_optimizer, muon_optimizer]
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizers
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
))
|
||||
|
||||
Factory = DistMuonAdamW if ddp else MuonAdamW
|
||||
optimizer = Factory(param_groups)
|
||||
for group in optimizer.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
@@ -446,11 +398,10 @@ class GPT(nn.Module):
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
|
||||
x = norm(x)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = norm(x)
|
||||
|
||||
352
nanochat/muon.py
352
nanochat/muon.py
@@ -1,352 +0,0 @@
|
||||
"""
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
# Group by shape so we can stack tensors
|
||||
shapes = sorted({p.shape for p in params})
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
param_groups.append(dict(params=group_params))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
if not params:
|
||||
continue
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed version of the Muon optimizer.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# First pass: stack grads and kick off reduce_scatter for each group
|
||||
group_infos = []
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
# Async reduce_scatter on the stacked tensor
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
group_infos.append(dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
))
|
||||
|
||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
||||
all_gather_futures = []
|
||||
for group, info in zip(self.param_groups, group_infos):
|
||||
info["reduce_future"].wait()
|
||||
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
all_gather_futures.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
for info in all_gather_futures:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
528
nanochat/optim.py
Normal file
528
nanochat/optim.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
A nice and efficient mixed AdamW/Muon Combined Optimizer.
|
||||
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
|
||||
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
|
||||
|
||||
Addapted from: https://github.com/KellerJordan/modded-nanogpt
|
||||
Further contributions from @karpathy and @chrisjmccormick.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
Good old AdamW optimizer, fused kernel.
|
||||
https://arxiv.org/abs/1711.05101
|
||||
"""
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor, # (32768, 768) - parameter tensor
|
||||
grad: Tensor, # (32768, 768) - gradient, same shape as p
|
||||
exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
|
||||
exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
|
||||
step_t: Tensor, # () - 0-D CPU tensor, step count
|
||||
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
|
||||
beta1_t: Tensor, # () - 0-D CPU tensor, beta1
|
||||
beta2_t: Tensor, # () - 0-D CPU tensor, beta2
|
||||
eps_t: Tensor, # () - 0-D CPU tensor, epsilon
|
||||
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
|
||||
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
|
||||
momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
|
||||
second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
|
||||
momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
|
||||
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
|
||||
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
|
||||
beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
|
||||
ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
|
||||
red_dim: int, # -1 or -2 - reduction dimension for variance
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
if g.size(-2) > g.size(-1): # Tall matrix
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X.mT @ X
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + X @ B
|
||||
else: # Wide matrix (original math)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Single GPU version of the MuonAdamW optimizer.
|
||||
# Used mostly for reference, debugging and testing.
|
||||
|
||||
class MuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
|
||||
|
||||
AdamW - Fused AdamW optimizer step.
|
||||
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
Arguments:
|
||||
param_groups: List of dicts, each containing:
|
||||
- 'params': List of parameters
|
||||
- 'kind': 'adamw' or 'muon'
|
||||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
super().__init__(param_groups, defaults={})
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
# AdamW tensors
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# Muon tensors
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
def _step_adamw(self, group: dict) -> None:
|
||||
"""
|
||||
AdamW update for each param in the group individually.
|
||||
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
|
||||
"""
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
state = self.state[p]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p, grad, exp_avg, exp_avg_sq,
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
|
||||
def _step_muon(self, group: dict) -> None:
|
||||
"""
|
||||
Muon update for all params in the group (stacked for efficiency).
|
||||
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
|
||||
"""
|
||||
params: list[Tensor] = group['params']
|
||||
if not params:
|
||||
return
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
p = params[0]
|
||||
state = self.state[p]
|
||||
num_params = len(params)
|
||||
shape, device, dtype = p.shape, p.device, p.dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Stack grads and params (NOTE: this assumes all params have the same shape)
|
||||
stacked_grads = torch.stack([p.grad for p in params])
|
||||
stacked_params = torch.stack(params)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._muon_momentum_t,
|
||||
self._muon_lr_t,
|
||||
self._muon_wd_t,
|
||||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
if group['kind'] == 'adamw':
|
||||
self._step_adamw(group)
|
||||
elif group['kind'] == 'muon':
|
||||
self._step_muon(group)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Distributed version of the MuonAdamW optimizer.
|
||||
# Used for training on multiple GPUs.
|
||||
|
||||
class DistMuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
|
||||
|
||||
See MuonAdamW for the algorithmic details of each optimizer. This class adds
|
||||
distributed communication to enable multi-GPU training without PyTorch DDP.
|
||||
|
||||
Design Goals:
|
||||
- Overlap communication with computation (async ops)
|
||||
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
|
||||
- Batch small tensors into single comm ops where possible
|
||||
|
||||
Communication Pattern (3-phase async):
|
||||
We use a 3-phase structure to maximize overlap between communication and compute:
|
||||
|
||||
Phase 1: Launch all async reduce ops
|
||||
- Kick off all reduce_scatter/all_reduce operations
|
||||
- Don't wait - let them run in background while we continue
|
||||
|
||||
Phase 2: Wait for reduces, compute updates, launch gathers
|
||||
- For each group: wait for its reduce, compute the update, launch gather
|
||||
- By processing groups in order, earlier gathers run while later computes happen
|
||||
|
||||
Phase 3: Wait for gathers, copy back
|
||||
- Wait for all gathers to complete
|
||||
- Copy updated params back to original tensors (Muon only)
|
||||
|
||||
AdamW Communication (ZeRO-2 style):
|
||||
- Small params (<1024 elements): all_reduce gradients, update full param on each rank.
|
||||
Optimizer state is replicated but these params are tiny (scalars, biases).
|
||||
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
|
||||
only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
|
||||
exp_avg_sq) is sharded - each rank only stores state for its slice.
|
||||
Requires param.shape[0] divisible by world_size.
|
||||
|
||||
Muon Communication (stacked + chunked):
|
||||
- All params in a Muon group must have the same shape (caller's responsibility).
|
||||
- Stack all K params into a single (K, *shape) tensor for efficient comm.
|
||||
- Divide K params across N ranks: each rank "owns" ceil(K/N) params.
|
||||
- reduce_scatter the stacked grads so each rank gets its chunk.
|
||||
- Each rank computes Muon update only for params it owns.
|
||||
- all_gather the updated params back to all ranks.
|
||||
- Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
|
||||
- Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
|
||||
then ignore the padding when copying back.
|
||||
|
||||
Buffer Reuse:
|
||||
- For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
|
||||
same buffer as the output for all_gather (stacked_params). This saves memory
|
||||
since we don't need both buffers simultaneously.
|
||||
|
||||
Arguments:
|
||||
param_groups: List of dicts, each containing:
|
||||
- 'params': List of parameters
|
||||
- 'kind': 'adamw' or 'muon'
|
||||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
super().__init__(param_groups, defaults={})
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
|
||||
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
|
||||
param_infos = {}
|
||||
for p in group['params']:
|
||||
grad = p.grad
|
||||
if p.numel() < 1024:
|
||||
# Small params: all_reduce (no scatter/gather needed)
|
||||
future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
|
||||
else:
|
||||
# Large params: reduce_scatter
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
|
||||
return dict(param_infos=param_infos)
|
||||
|
||||
def _reduce_muon(self, group: dict, world_size: int) -> dict:
|
||||
"""Launch async reduce op for Muon group. Returns info dict."""
|
||||
params = group['params']
|
||||
chunk_size = (len(params) + world_size - 1) // world_size
|
||||
padded_num_params = chunk_size * world_size
|
||||
p = params[0]
|
||||
shape, device, dtype = p.shape, p.device, p.dtype
|
||||
|
||||
# Stack grads and zero-pad to padded_num_params
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Reduce_scatter to get this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
|
||||
return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
|
||||
|
||||
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
|
||||
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
|
||||
param_infos = info['param_infos']
|
||||
for p in group['params']:
|
||||
pinfo = param_infos[p]
|
||||
pinfo['future'].wait()
|
||||
grad_slice = pinfo['grad_slice']
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if pinfo['is_small']:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors and run fused kernel
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(group['lr'])
|
||||
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||
self._adamw_eps_t.fill_(group['eps'])
|
||||
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||
adamw_step_fused(
|
||||
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
|
||||
# Large params need all_gather
|
||||
if not pinfo['is_small']:
|
||||
future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
|
||||
gather_list.append(dict(future=future, params=None))
|
||||
|
||||
def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
|
||||
"""Wait for reduce, compute Muon updates, launch gather."""
|
||||
info['future'].wait()
|
||||
params = group['params']
|
||||
chunk_size = info['chunk_size']
|
||||
grad_chunk = info['grad_chunk']
|
||||
p = params[0]
|
||||
shape, device, dtype = p.shape, p.device, p.dtype
|
||||
|
||||
# How many params does this rank own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build output buffer for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned = torch.stack(owned_params)
|
||||
|
||||
# Fill 0-D tensors and run fused kernel
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"])
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
muon_step_fused(
|
||||
grad_chunk[:num_owned], stacked_owned,
|
||||
state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
|
||||
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
|
||||
group["ns_steps"], red_dim,
|
||||
)
|
||||
updated_params[:num_owned].copy_(stacked_owned)
|
||||
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
|
||||
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
|
||||
|
||||
def _finish_gathers(self, gather_list: list) -> None:
|
||||
"""Wait for all gathers and copy Muon params back."""
|
||||
for info in gather_list:
|
||||
info["future"].wait()
|
||||
if info["params"] is not None:
|
||||
# Muon: copy from stacked buffer back to individual params
|
||||
torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Phase 1: launch all async reduce ops
|
||||
reduce_infos: list[dict] = []
|
||||
for group in self.param_groups:
|
||||
if group['kind'] == 'adamw':
|
||||
reduce_infos.append(self._reduce_adamw(group, world_size))
|
||||
elif group['kind'] == 'muon':
|
||||
reduce_infos.append(self._reduce_muon(group, world_size))
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
# Phase 2: wait for reduces, compute updates, launch gathers
|
||||
gather_list: list[dict] = []
|
||||
for group, info in zip(self.param_groups, reduce_infos):
|
||||
if group['kind'] == 'adamw':
|
||||
self._compute_adamw(group, info, gather_list, rank, world_size)
|
||||
elif group['kind'] == 'muon':
|
||||
self._compute_muon(group, info, gather_list, rank)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
||||
|
||||
# Phase 3: wait for gathers, copy back
|
||||
self._finish_gathers(gather_list)
|
||||
@@ -211,9 +211,9 @@ print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
@@ -221,12 +221,10 @@ optimizers = model.setup_optimizers(
|
||||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
for opt, dat in zip(optimizers, optimizer_data):
|
||||
opt.load_state_dict(dat)
|
||||
del optimizer_data # free up the memory
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
@@ -344,7 +342,7 @@ while True:
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
optimizer.state_dict(), # optimizer state
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
@@ -378,18 +376,16 @@ while True:
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizers
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
|
||||
@@ -201,7 +201,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
||||
# Training loop
|
||||
|
||||
# Init the optimizer
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
@@ -209,10 +209,9 @@ optimizers = model.setup_optimizers(
|
||||
)
|
||||
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
def get_lr_multiplier(it):
|
||||
@@ -305,11 +304,9 @@ for step in range(num_steps):
|
||||
|
||||
# Update the model parameters
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers: # first set the learning rate
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for opt in optimizers: # then step the optimizers
|
||||
opt.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
||||
@@ -150,17 +150,16 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_bat
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
@@ -230,13 +229,11 @@ for step in range(num_iterations):
|
||||
|
||||
# learning rate scheduler
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
|
||||
# step the optimizers
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
# step the optimizer
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# logging
|
||||
|
||||
@@ -93,14 +93,12 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
@@ -274,7 +272,7 @@ while True:
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
optimizer.state_dict(),
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
@@ -306,16 +304,14 @@ while True:
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizers
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(progress)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
|
||||
Reference in New Issue
Block a user