Combine AdamW and Muon into single MuonAdamW optimizer, cleaner, ty @chrisjmccormick for idea/help

This commit is contained in:
Andrej Karpathy
2026-01-29 00:50:50 +00:00
parent 64a651a63c
commit 41bb2eac32
9 changed files with 595 additions and 574 deletions

View File

@@ -135,7 +135,6 @@ python -m pytest tests/test_engine.py -v -s
│ └── repackage_data_reference.py # Pretraining data shard generation │ └── repackage_data_reference.py # Pretraining data shard generation
├── nanochat ├── nanochat
│ ├── __init__.py # empty │ ├── __init__.py # empty
│ ├── adamw.py # Distributed AdamW optimizer
│ ├── checkpoint_manager.py # Save/Load model checkpoints │ ├── checkpoint_manager.py # Save/Load model checkpoints
│ ├── common.py # Misc small utilities, quality of life │ ├── common.py # Misc small utilities, quality of life
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper) │ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
@@ -146,7 +145,7 @@ python -m pytest tests/test_engine.py -v -s
│ ├── gpt.py # The GPT nn.Module Transformer │ ├── gpt.py # The GPT nn.Module Transformer
│ ├── logo.svg │ ├── logo.svg
│ ├── loss_eval.py # Evaluate bits per byte (instead of loss) │ ├── loss_eval.py # Evaluate bits per byte (instead of loss)
│ ├── muon.py # Distributed Muon optimizer │ ├── optim.py # AdamW + Muon optimizer, 1GPU and distributed
│ ├── report.py # Utilities for writing the nanochat Report │ ├── report.py # Utilities for writing the nanochat Report
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4 │ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
│ └── ui.html # HTML/CSS/JS for nanochat frontend │ └── ui.html # HTML/CSS/JS for nanochat frontend

View File

@@ -1,143 +0,0 @@
"""
Distributed AdamW optimizer with a fused step function.
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
"""
import torch
import torch.distributed as dist
from torch import Tensor
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
p: Tensor,
grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
step_t: Tensor,
lr_t: Tensor,
beta1_t: Tensor,
beta2_t: Tensor,
eps_t: Tensor,
wd_t: Tensor,
) -> None:
"""
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
All in one compiled graph to eliminate Python overhead between ops.
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
"""
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
exp_avg.lerp_(grad, 1 - beta1_t)
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
# Bias corrections
bias1 = 1 - beta1_t ** step_t
bias2 = 1 - beta2_t ** step_t
# Compute update and apply
denom = (exp_avg_sq / bias2).sqrt() + eps_t
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)
class DistAdamW(torch.optim.Optimizer):
"""
Distributed AdamW optimizer.
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
"""
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
rank = dist.get_rank()
world_size = dist.get_world_size()
# Validate
if rank == 0:
for group in param_groups:
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
for p in group['params']:
sliced = p.numel() >= 1024
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
if sliced: # large parameter tensors will be operated on in slices
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
reduce_futures: list[torch.Future] = []
gather_futures: list[torch.Future] = []
grad_slices = []
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
for group in self.param_groups:
params: list[Tensor] = group["params"]
for p in params:
grad = p.grad
# Small params: use all_reduce (no scatter/gather needed)
if p.numel() < 1024:
is_small.append(True)
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad)
else:
is_small.append(False)
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
grad_slice = torch.empty_like(grad[:rank_size])
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
beta1, beta2 = group['betas']
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for p in params:
reduce_futures[idx].wait()
g_slice = grad_slices[idx]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
# For small params, operate on full param; for large, operate on slice
if is_small[idx]:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
# Fill 0-D tensors with current values
eff_wd = wd * getattr(p, "wd_mul", 1.0)
self._step_t.fill_(state['step'])
self._lr_t.fill_(lr)
self._beta1_t.fill_(beta1)
self._beta2_t.fill_(beta2)
self._eps_t.fill_(eps)
self._wd_t.fill_(eff_wd)
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p_slice, g_slice, exp_avg, exp_avg_sq,
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
)
# Only large params need all_gather
if not is_small[idx]:
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
idx += 1
if gather_futures:
torch.futures.collect_all(gather_futures).wait()

View File

@@ -20,8 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from nanochat.common import get_dist_info, print0 from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon from nanochat.optim import MuonAdamW, DistMuonAdamW
from nanochat.adamw import DistAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn from nanochat.flash_attention import flash_attn
@@ -346,9 +345,10 @@ class GPT(nn.Module):
'total': total, 'total': total,
} }
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
model_dim = self.config.n_embd model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info() ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into groups # Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters()) matrix_params = list(self.transformer.h.parameters())
value_embeds_params = list(self.value_embeds.parameters()) value_embeds_params = list(self.value_embeds.parameters())
@@ -357,30 +357,33 @@ class GPT(nn.Module):
resid_params = [self.resid_lambdas] resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas] x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5 dmodel_lr_scale = (model_dim / 768) ** -0.5
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), # Build param_groups with all required fields explicit
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), param_groups = [
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding # AdamW groups (embeddings, lm_head, scalars)
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
] ]
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon # Muon groups (matrix params, grouped by shape for stacking)
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) for shape in sorted({p.shape for p in matrix_params}):
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs) group_params = [p for p in matrix_params if p.shape == shape]
# Create the Muon optimizer for the linear layers param_groups.append(dict(
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay) kind='muon', params=group_params, lr=matrix_lr,
MuonFactory = DistMuon if ddp else Muon momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs) ))
# Combine them the two optimizers into one list
optimizers = [adamw_optimizer, muon_optimizer] Factory = DistMuonAdamW if ddp else MuonAdamW
for opt in optimizers: optimizer = Factory(param_groups)
for group in opt.param_groups: for group in optimizer.param_groups:
group["initial_lr"] = group["lr"] group["initial_lr"] = group["lr"]
return optimizers return optimizer
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size() B, T = idx.size()

View File

@@ -1,352 +0,0 @@
"""
Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Some of the changes in nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""
import torch
from torch import Tensor
import torch.distributed as dist
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
stacked_grads: Tensor,
stacked_params: Tensor,
momentum_buffer: Tensor,
second_momentum_buffer: Tensor,
momentum_t: Tensor,
lr_t: Tensor,
wd_t: Tensor,
beta2_t: Tensor,
ns_steps: int,
red_dim: int,
) -> None:
"""
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
All in one compiled graph to eliminate Python overhead between ops.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
"""
# Nesterov momentum
momentum = momentum_t.to(stacked_grads.dtype)
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
if g.size(-2) > g.size(-1):
X = X.mT
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
if g.size(-2) > g.size(-1):
X = X.mT
g = X
# Variance reduction
beta2 = beta2_t.to(g.dtype)
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = g.size(red_dim)
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
g = g * final_scale.to(g.dtype)
# Cautious weight decay + parameter update
lr = lr_t.to(g.dtype)
wd = wd_t.to(g.dtype)
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- This optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
ns_steps: The number of Newton-Schulz iteration steps to use.
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
# Group by shape so we can stack tensors
shapes = sorted({p.shape for p in params})
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
param_groups.append(dict(params=group_params))
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
for group in self.param_groups:
params: list[Tensor] = group["params"]
if not params:
continue
# Get or create group-level buffers (stored in first param's state for convenience)
state = self.state[params[0]]
num_params = len(params) # e.g.: 12 (for a d12 model)
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
# Momentum for every individual parameter
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
if shape[-2] >= shape[-1]:
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
else:
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
# Stack grads and params
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
stacked_params = torch.stack(params) # (12, 768, 3072)
# Fill all the 0-D tensors with current values
self._momentum_t.fill_(group["momentum"])
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._momentum_t,
self._lr_t,
self._wd_t,
self._beta2_t,
group["ns_steps"],
red_dim,
)
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
class DistMuon(torch.optim.Optimizer):
"""
Distributed version of the Muon optimizer.
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
params = list(params)
world_size = dist.get_world_size()
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
device, dtype = group_params[0].device, group_params[0].dtype
assert all(p.device == device for p in group_params)
assert all(p.dtype == dtype for p in group_params)
# Compute chunk size for this group (how many params each rank owns)
chunk_size = (len(group_params) + world_size - 1) // world_size
if rank == 0:
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
# First pass: stack grads and kick off reduce_scatter for each group
group_infos = []
for group in self.param_groups:
params: list[Tensor] = group["params"]
chunk_size = group["chunk_size"]
padded_num_params = chunk_size * world_size
shape = params[0].shape
device, dtype = params[0].device, params[0].dtype
# Stack all gradients into a single tensor (single kernel via torch.stack)
grad_stack = torch.stack([p.grad for p in params])
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
stacked_grads[:len(params)].copy_(grad_stack)
# Zero-pad if we have fewer params than padded size
if len(params) < padded_num_params:
stacked_grads[len(params):].zero_()
# Output buffer for this rank's chunk
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
# Async reduce_scatter on the stacked tensor
reduce_future = dist.reduce_scatter_tensor(
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
).get_future()
group_infos.append(dict(
grad_chunk=grad_chunk,
reduce_future=reduce_future,
stacked_grads=stacked_grads, # reuse for all_gather output
))
# Second pass: wait for reduce, compute batched updates, kick off all_gather
all_gather_futures = []
for group, info in zip(self.param_groups, group_infos):
info["reduce_future"].wait()
params = group["params"]
chunk_size = group["chunk_size"]
shape = params[0].shape
device, dtype = params[0].device, params[0].dtype
grad_chunk = info["grad_chunk"]
# How many params does this rank actually own?
start_idx = rank * chunk_size
num_owned = min(chunk_size, max(0, len(params) - start_idx))
# Get or create group-level state (stored keyed by first param)
state = self.state[params[0]]
# Momentum buffer
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
if shape[-2] >= shape[-1]:
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
else:
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Build updated_params tensor for all_gather
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
if num_owned > 0:
# Stack owned params (single kernel via torch.stack)
owned_params = [params[start_idx + i] for i in range(num_owned)]
stacked_owned_params = torch.stack(owned_params)
# Get owned slices of buffers and grads
owned_grads = grad_chunk[:num_owned]
owned_momentum = momentum_buffer[:num_owned]
owned_second_momentum = second_momentum_buffer[:num_owned]
# Fill 0-D tensors with current values
self._momentum_t.fill_(group["momentum"])
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
owned_grads,
stacked_owned_params,
owned_momentum,
owned_second_momentum,
self._momentum_t,
self._lr_t,
self._wd_t,
self._beta2_t,
group["ns_steps"],
red_dim,
)
# Copy updated params to output buffer
updated_params[:num_owned].copy_(stacked_owned_params)
# Zero-pad the rest (for ranks that own fewer params)
if num_owned < chunk_size:
updated_params[num_owned:].zero_()
# Reuse stacked_grads buffer for all_gather output
stacked_params = info["stacked_grads"]
# Async all_gather to replicate updated params to all ranks
gather_future = dist.all_gather_into_tensor(
stacked_params, updated_params, async_op=True
).get_future()
all_gather_futures.append(dict(
gather_future=gather_future,
stacked_params=stacked_params,
params=params,
))
# Final pass: wait for all_gather and copy back to params
for info in all_gather_futures:
info["gather_future"].wait()
stacked_params = info["stacked_params"]
params = info["params"]
# Batched copy back (single kernel instead of N individual copies)
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))

528
nanochat/optim.py Normal file
View File

@@ -0,0 +1,528 @@
"""
A nice and efficient mixed AdamW/Muon Combined Optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
Addapted from: https://github.com/KellerJordan/modded-nanogpt
Further contributions from @karpathy and @chrisjmccormick.
"""
import torch
import torch.distributed as dist
from torch import Tensor
# -----------------------------------------------------------------------------
"""
Good old AdamW optimizer, fused kernel.
https://arxiv.org/abs/1711.05101
"""
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
p: Tensor, # (32768, 768) - parameter tensor
grad: Tensor, # (32768, 768) - gradient, same shape as p
exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
step_t: Tensor, # () - 0-D CPU tensor, step count
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
beta1_t: Tensor, # () - 0-D CPU tensor, beta1
beta2_t: Tensor, # () - 0-D CPU tensor, beta2
eps_t: Tensor, # () - 0-D CPU tensor, epsilon
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
) -> None:
"""
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
All in one compiled graph to eliminate Python overhead between ops.
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
"""
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
exp_avg.lerp_(grad, 1 - beta1_t)
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
# Bias corrections
bias1 = 1 - beta1_t ** step_t
bias2 = 1 - beta2_t ** step_t
# Compute update and apply
denom = (exp_avg_sq / bias2).sqrt() + eps_t
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)
# -----------------------------------------------------------------------------
"""
Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Some of the changes in nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
red_dim: int, # -1 or -2 - reduction dimension for variance
) -> None:
"""
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
All in one compiled graph to eliminate Python overhead between ops.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
"""
# Nesterov momentum
momentum = momentum_t.to(stacked_grads.dtype)
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
if g.size(-2) > g.size(-1): # Tall matrix
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X.mT @ X
B = b * A + c * (A @ A)
X = a * X + X @ B
else: # Wide matrix (original math)
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
g = X
# Variance reduction
beta2 = beta2_t.to(g.dtype)
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = g.size(red_dim)
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
g = g * final_scale.to(g.dtype)
# Cautious weight decay + parameter update
lr = lr_t.to(g.dtype)
wd = wd_t.to(g.dtype)
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
class MuonAdamW(torch.optim.Optimizer):
"""
Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
AdamW - Fused AdamW optimizer step.
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
# AdamW tensors
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Muon tensors
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _step_adamw(self, group: dict) -> None:
"""
AdamW update for each param in the group individually.
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
"""
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
# Fill 0-D tensors with current values
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
def _step_muon(self, group: dict) -> None:
"""
Muon update for all params in the group (stacked for efficiency).
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
"""
params: list[Tensor] = group['params']
if not params:
return
# Get or create group-level buffers (stored in first param's state for convenience)
p = params[0]
state = self.state[p]
num_params = len(params)
shape, device, dtype = p.shape, p.device, p.dtype
# Momentum for every individual parameter
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Stack grads and params (NOTE: this assumes all params have the same shape)
stacked_grads = torch.stack([p.grad for p in params])
stacked_params = torch.stack(params)
# Fill all the 0-D tensors with current values
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._muon_wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._muon_momentum_t,
self._muon_lr_t,
self._muon_wd_t,
self._muon_beta2_t,
group["ns_steps"],
red_dim,
)
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
@torch.no_grad()
def step(self):
for group in self.param_groups:
if group['kind'] == 'adamw':
self._step_adamw(group)
elif group['kind'] == 'muon':
self._step_muon(group)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# -----------------------------------------------------------------------------
# Distributed version of the MuonAdamW optimizer.
# Used for training on multiple GPUs.
class DistMuonAdamW(torch.optim.Optimizer):
"""
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
See MuonAdamW for the algorithmic details of each optimizer. This class adds
distributed communication to enable multi-GPU training without PyTorch DDP.
Design Goals:
- Overlap communication with computation (async ops)
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
- Batch small tensors into single comm ops where possible
Communication Pattern (3-phase async):
We use a 3-phase structure to maximize overlap between communication and compute:
Phase 1: Launch all async reduce ops
- Kick off all reduce_scatter/all_reduce operations
- Don't wait - let them run in background while we continue
Phase 2: Wait for reduces, compute updates, launch gathers
- For each group: wait for its reduce, compute the update, launch gather
- By processing groups in order, earlier gathers run while later computes happen
Phase 3: Wait for gathers, copy back
- Wait for all gathers to complete
- Copy updated params back to original tensors (Muon only)
AdamW Communication (ZeRO-2 style):
- Small params (<1024 elements): all_reduce gradients, update full param on each rank.
Optimizer state is replicated but these params are tiny (scalars, biases).
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
exp_avg_sq) is sharded - each rank only stores state for its slice.
Requires param.shape[0] divisible by world_size.
Muon Communication (stacked + chunked):
- All params in a Muon group must have the same shape (caller's responsibility).
- Stack all K params into a single (K, *shape) tensor for efficient comm.
- Divide K params across N ranks: each rank "owns" ceil(K/N) params.
- reduce_scatter the stacked grads so each rank gets its chunk.
- Each rank computes Muon update only for params it owns.
- all_gather the updated params back to all ranks.
- Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
- Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
then ignore the padding when copying back.
Buffer Reuse:
- For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
same buffer as the output for all_gather (stacked_params). This saves memory
since we don't need both buffers simultaneously.
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
param_infos = {}
for p in group['params']:
grad = p.grad
if p.numel() < 1024:
# Small params: all_reduce (no scatter/gather needed)
future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
else:
# Large params: reduce_scatter
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
return dict(param_infos=param_infos)
def _reduce_muon(self, group: dict, world_size: int) -> dict:
"""Launch async reduce op for Muon group. Returns info dict."""
params = group['params']
chunk_size = (len(params) + world_size - 1) // world_size
padded_num_params = chunk_size * world_size
p = params[0]
shape, device, dtype = p.shape, p.device, p.dtype
# Stack grads and zero-pad to padded_num_params
grad_stack = torch.stack([p.grad for p in params])
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
stacked_grads[:len(params)].copy_(grad_stack)
if len(params) < padded_num_params:
stacked_grads[len(params):].zero_()
# Reduce_scatter to get this rank's chunk
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
param_infos = info['param_infos']
for p in group['params']:
pinfo = param_infos[p]
pinfo['future'].wait()
grad_slice = pinfo['grad_slice']
state = self.state[p]
# For small params, operate on full param; for large, operate on slice
if pinfo['is_small']:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
state['step'] += 1
# Fill 0-D tensors and run fused kernel
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
adamw_step_fused(
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
# Large params need all_gather
if not pinfo['is_small']:
future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
gather_list.append(dict(future=future, params=None))
def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
"""Wait for reduce, compute Muon updates, launch gather."""
info['future'].wait()
params = group['params']
chunk_size = info['chunk_size']
grad_chunk = info['grad_chunk']
p = params[0]
shape, device, dtype = p.shape, p.device, p.dtype
# How many params does this rank own?
start_idx = rank * chunk_size
num_owned = min(chunk_size, max(0, len(params) - start_idx))
# Get or create group-level state
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
if "second_momentum_buffer" not in state:
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Build output buffer for all_gather
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
if num_owned > 0:
owned_params = [params[start_idx + i] for i in range(num_owned)]
stacked_owned = torch.stack(owned_params)
# Fill 0-D tensors and run fused kernel
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"])
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._muon_wd_t.fill_(group["weight_decay"])
muon_step_fused(
grad_chunk[:num_owned], stacked_owned,
state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
group["ns_steps"], red_dim,
)
updated_params[:num_owned].copy_(stacked_owned)
if num_owned < chunk_size:
updated_params[num_owned:].zero_()
# Reuse stacked_grads buffer for all_gather output
stacked_params = info["stacked_grads"]
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
def _finish_gathers(self, gather_list: list) -> None:
"""Wait for all gathers and copy Muon params back."""
for info in gather_list:
info["future"].wait()
if info["params"] is not None:
# Muon: copy from stacked buffer back to individual params
torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Phase 1: launch all async reduce ops
reduce_infos: list[dict] = []
for group in self.param_groups:
if group['kind'] == 'adamw':
reduce_infos.append(self._reduce_adamw(group, world_size))
elif group['kind'] == 'muon':
reduce_infos.append(self._reduce_muon(group, world_size))
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# Phase 2: wait for reduces, compute updates, launch gathers
gather_list: list[dict] = []
for group, info in zip(self.param_groups, reduce_infos):
if group['kind'] == 'adamw':
self._compute_adamw(group, info, gather_list, rank, world_size)
elif group['kind'] == 'muon':
self._compute_muon(group, info, gather_list, rank)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# Phase 3: wait for gathers, copy back
self._finish_gathers(gather_list)

View File

@@ -211,9 +211,9 @@ print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
adam_betas = (args.adam_beta1, args.adam_beta2) adam_betas = (args.adam_beta1, args.adam_beta2)
optimizers = model.setup_optimizers( optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr * batch_lr_scale, unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale,
matrix_lr=args.matrix_lr * batch_lr_scale, matrix_lr=args.matrix_lr * batch_lr_scale,
@@ -221,12 +221,10 @@ optimizers = model.setup_optimizers(
adam_betas=adam_betas, adam_betas=adam_betas,
scalar_lr=args.scalar_lr * batch_lr_scale, scalar_lr=args.scalar_lr * batch_lr_scale,
) )
adamw_optimizer, muon_optimizer = optimizers
if resuming: if resuming:
for opt, dat in zip(optimizers, optimizer_data): optimizer.load_state_dict(optimizer_data)
opt.load_state_dict(dat) del optimizer_data
del optimizer_data # free up the memory
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val # Initialize the DataLoaders for train/val
@@ -344,7 +342,7 @@ while True:
checkpoint_dir, checkpoint_dir,
step, step,
orig_model.state_dict(), # model parameters orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states optimizer.state_dict(), # optimizer state
{ # metadata saved as json { # metadata saved as json
"step": step, "step": step,
"val_bpb": val_bpb, # loss at last step "val_bpb": val_bpb, # loss at last step
@@ -378,18 +376,16 @@ while True:
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizers # step the optimizer
lrm = get_lr_multiplier(step) lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step) muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step) muon_weight_decay = get_weight_decay(step)
for group in muon_optimizer.param_groups: for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay group["weight_decay"] = muon_weight_decay
for opt in optimizers: optimizer.step()
opt.step()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
synchronize() synchronize()

View File

@@ -201,7 +201,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
# Training loop # Training loop
# Init the optimizer # Init the optimizer
optimizers = model.setup_optimizers( optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr, unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr, embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr, matrix_lr=args.matrix_lr,
@@ -209,10 +209,9 @@ optimizers = model.setup_optimizers(
) )
# Set the initial learning rate as a fraction of the base learning rate # Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers: for group in optimizer.param_groups:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later group["initial_lr"] = group["lr"]
# Learning rate scheduler: simple rampdown to zero over num_steps # Learning rate scheduler: simple rampdown to zero over num_steps
def get_lr_multiplier(it): def get_lr_multiplier(it):
@@ -305,11 +304,9 @@ for step in range(num_steps):
# Update the model parameters # Update the model parameters
lrm = get_lr_multiplier(step) lrm = get_lr_multiplier(step)
for opt in optimizers: # first set the learning rate for group in optimizer.param_groups:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm group["lr"] = group["initial_lr"] * lrm
for opt in optimizers: # then step the optimizers optimizer.step()
opt.step()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
wandb_run.log({ wandb_run.log({
"step": step, "step": step,

View File

@@ -150,17 +150,16 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_bat
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Initialize the Optimizer # Initialize the Optimizer
optimizers = model.setup_optimizers( optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr, unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr, embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr, matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
) )
# Set the initial learning rate as a fraction of the base learning rate # Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers: for group in optimizer.param_groups:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later group["initial_lr"] = group["lr"]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Training loop # Training loop
@@ -230,13 +229,11 @@ for step in range(num_iterations):
# learning rate scheduler # learning rate scheduler
lrm = get_lr_multiplier(step) lrm = get_lr_multiplier(step)
for opt in optimizers: for group in optimizer.param_groups:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm group["lr"] = group["initial_lr"] * lrm
# step the optimizers # step the optimizer
for opt in optimizers: optimizer.step()
opt.step()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
# logging # logging

View File

@@ -93,14 +93,12 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device) token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate # Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers: for group in optimizer.param_groups:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later group["initial_lr"] = group["lr"]
# Midtraining data mixture and DataLoader # Midtraining data mixture and DataLoader
base_dir = get_base_dir() base_dir = get_base_dir()
@@ -274,7 +272,7 @@ while True:
checkpoint_dir, checkpoint_dir,
step, step,
orig_model.state_dict(), orig_model.state_dict(),
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly optimizer.state_dict(),
{ {
"step": step, "step": step,
"val_bpb": val_bpb, # loss at last step "val_bpb": val_bpb, # loss at last step
@@ -306,16 +304,14 @@ while True:
loss.backward() loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizers # step the optimizer
lrm = get_lr_multiplier(progress) lrm = get_lr_multiplier(progress)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step) muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups: for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum group["momentum"] = muon_momentum
for opt in optimizers: optimizer.step()
opt.step()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
synchronize() synchronize()
t1 = time.time() t1 = time.time()