mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30 04:22:02 +00:00)
Combine AdamW and Muon into a single MuonAdamW optimizer; cleaner. Thanks to @chrisjmccormick for the idea/help.
@@ -135,7 +135,6 @@ python -m pytest tests/test_engine.py -v -s
 │   └── repackage_data_reference.py  # Pretraining data shard generation
 ├── nanochat
 │   ├── __init__.py                  # empty
-│   ├── adamw.py                     # Distributed AdamW optimizer
 │   ├── checkpoint_manager.py        # Save/Load model checkpoints
 │   ├── common.py                    # Misc small utilities, quality of life
 │   ├── core_eval.py                 # Evaluates base model CORE score (DCLM paper)
@@ -146,7 +145,7 @@ python -m pytest tests/test_engine.py -v -s
 │   ├── gpt.py                       # The GPT nn.Module Transformer
 │   ├── logo.svg
 │   ├── loss_eval.py                 # Evaluate bits per byte (instead of loss)
-│   ├── muon.py                      # Distributed Muon optimizer
+│   ├── optim.py                     # AdamW + Muon optimizer, 1GPU and distributed
 │   ├── report.py                    # Utilities for writing the nanochat Report
 │   ├── tokenizer.py                 # BPE Tokenizer wrapper in style of GPT-4
 │   └── ui.html                      # HTML/CSS/JS for nanochat frontend
nanochat/adamw.py (deleted)
@@ -1,143 +0,0 @@
"""
Distributed AdamW optimizer with a fused step function.
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
"""
import torch
import torch.distributed as dist
from torch import Tensor

@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
    p: Tensor,
    grad: Tensor,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    step_t: Tensor,
    lr_t: Tensor,
    beta1_t: Tensor,
    beta2_t: Tensor,
    eps_t: Tensor,
    wd_t: Tensor,
) -> None:
    """
    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
    All in one compiled graph to eliminate Python overhead between ops.
    The 0-D CPU tensors avoid recompilation when hyperparameter values change.
    """
    # Weight decay (decoupled, applied before the update)
    p.mul_(1 - lr_t * wd_t)
    # Update running averages (lerp_ is cleaner and fuses well)
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias corrections
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    # Compute update and apply
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)


class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
    """
    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # Validate
        if rank == 0:
            for group in param_groups:
                assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
                assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
                for p in group['params']:
                    sliced = p.numel() >= 1024
                    print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
                    if sliced:  # large parameter tensors will be operated on in slices
                        assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
        super().__init__(param_groups, defaults)
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_futures: list[torch.Future] = []
        gather_futures: list[torch.Future] = []
        grad_slices = []
        is_small = []  # track which params are small (use all_reduce) vs large (use reduce_scatter)

        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for p in params:
                grad = p.grad
                # Small params: use all_reduce (no scatter/gather needed)
                if p.numel() < 1024:
                    is_small.append(True)
                    reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                    grad_slices.append(grad)
                else:
                    is_small.append(False)
                    rank_size = grad.shape[0] // world_size  # p.shape[0] % world_size == 0 is checked in __init__
                    grad_slice = torch.empty_like(grad[:rank_size])
                    reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                    grad_slices.append(grad_slice)

        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for p in params:
                reduce_futures[idx].wait()
                g_slice = grad_slices[idx]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]

                # For small params, operate on full param; for large, operate on slice
                if is_small[idx]:
                    p_slice = p
                else:
                    rank_size = p.shape[0] // world_size
                    p_slice = p[rank * rank_size:(rank + 1) * rank_size]

                # State init
                if not state:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1

                # Fill 0-D tensors with current values
                eff_wd = wd * getattr(p, "wd_mul", 1.0)
                self._step_t.fill_(state['step'])
                self._lr_t.fill_(lr)
                self._beta1_t.fill_(beta1)
                self._beta2_t.fill_(beta2)
                self._eps_t.fill_(eps)
                self._wd_t.fill_(eff_wd)

                # Fused update: weight_decay -> momentum -> bias_correction -> param_update
                adamw_step_fused(
                    p_slice, g_slice, exp_avg, exp_avg_sq,
                    self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
                )

                # Only large params need all_gather
                if not is_small[idx]:
                    gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
                idx += 1

        if gather_futures:
            torch.futures.collect_all(gather_futures).wait()
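The fused kernel in this deleted file is just the standard decoupled-weight-decay AdamW update (weight decay, moment updates, bias correction, parameter update) arranged so torch.compile can fuse it into one graph. As a sanity check, here is a minimal sketch (not part of the commit; shapes and hyperparameters are arbitrary) of the same math written in eager PyTorch and compared against torch.optim.AdamW for a single step:

import torch

torch.manual_seed(0)
lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.01
p0 = torch.randn(64, 32)
g = torch.randn_like(p0)

# Reference: PyTorch's own AdamW for one step
ref = torch.nn.Parameter(p0.clone())
opt = torch.optim.AdamW([ref], lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=wd)
ref.grad = g.clone()
opt.step()

# Hand-rolled step, mirroring adamw_step_fused: decay -> moments -> bias correction -> update
p = p0.clone()
exp_avg = torch.zeros_like(p)
exp_avg_sq = torch.zeros_like(p)
step = 1
p.mul_(1 - lr * wd)                          # decoupled weight decay
exp_avg.lerp_(g, 1 - beta1)                  # first moment
exp_avg_sq.lerp_(g.square(), 1 - beta2)      # second moment
denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt() + eps
p.add_(exp_avg / denom, alpha=-(lr / (1 - beta1 ** step)))

print(torch.allclose(p, ref.detach(), atol=1e-6))  # expected: True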
nanochat/gpt.py
@@ -20,8 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from nanochat.common import get_dist_info, print0
-from nanochat.muon import Muon, DistMuon
-from nanochat.adamw import DistAdamW
+from nanochat.optim import MuonAdamW, DistMuonAdamW
 
 # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
 from nanochat.flash_attention import flash_attn
@@ -346,9 +345,10 @@ class GPT(nn.Module):
             'total': total,
         }
 
-    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
+    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()
         # Separate out all parameters into groups
         matrix_params = list(self.transformer.h.parameters())
         value_embeds_params = list(self.value_embeds.parameters())
@@ -357,30 +357,33 @@ class GPT(nn.Module):
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
         assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
-        # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
-        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
+        # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
         print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
-        adam_groups = [
-            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
-            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
-            dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
-            dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
-            dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
-        ]
-        adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
-        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
-        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
-        # Create the Muon optimizer for the linear layers
-        muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
-        MuonFactory = DistMuon if ddp else Muon
-        muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
-        # Combine them the two optimizers into one list
-        optimizers = [adamw_optimizer, muon_optimizer]
-        for opt in optimizers:
-            for group in opt.param_groups:
-                group["initial_lr"] = group["lr"]
-        return optimizers
+        # Build param_groups with all required fields explicit
+        param_groups = [
+            # AdamW groups (embeddings, lm_head, scalars)
+            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
+            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
+            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
+            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
+            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
+        ]
+        # Muon groups (matrix params, grouped by shape for stacking)
+        for shape in sorted({p.shape for p in matrix_params}):
+            group_params = [p for p in matrix_params if p.shape == shape]
+            param_groups.append(dict(
+                kind='muon', params=group_params, lr=matrix_lr,
+                momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
+            ))
+        Factory = DistMuonAdamW if ddp else MuonAdamW
+        optimizer = Factory(param_groups)
+        for group in optimizer.param_groups:
+            group["initial_lr"] = group["lr"]
+        return optimizer
 
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
         B, T = idx.size()
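With setup_optimizer returning a single object, the caller just builds kind-tagged param_groups and steps one optimizer. A toy sketch of that contract (not from the repo; shapes, learning rates, and the dummy loss are made up, and the first step will trigger torch.compile of the fused kernels):

import torch
from nanochat.optim import MuonAdamW

emb = torch.nn.Parameter(torch.randn(1000, 64))                      # AdamW-style parameter
mats = [torch.nn.Parameter(torch.randn(64, 256)) for _ in range(4)]  # Muon-style 2D matrices, same shape

param_groups = [
    dict(kind='adamw', params=[emb], lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
    dict(kind='muon', params=mats, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0),
]
optimizer = MuonAdamW(param_groups)
for group in optimizer.param_groups:
    group["initial_lr"] = group["lr"]    # mirrors what setup_optimizer does for LR scheduling

for _ in range(3):
    loss = emb.square().sum() + sum(m.square().sum() for m in mats)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()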
nanochat/muon.py (352 lines, deleted)
@@ -1,352 +0,0 @@
"""
|
|
||||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
|
||||||
https://github.com/KellerJordan/modded-nanogpt
|
|
||||||
|
|
||||||
Background:
|
|
||||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
|
||||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
|
||||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
|
||||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
|
||||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
|
||||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
|
||||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
|
||||||
|
|
||||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
|
||||||
Polar Express Sign Method for orthogonalization.
|
|
||||||
https://arxiv.org/pdf/2505.16932
|
|
||||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
|
||||||
|
|
||||||
Some of the changes in nanochat implementation:
|
|
||||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
|
||||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
|
||||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
|
||||||
# From https://arxiv.org/pdf/2505.16932
|
|
||||||
polar_express_coeffs = [
|
|
||||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
|
||||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
|
||||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
|
||||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
|
||||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
|
||||||
]
|
|
||||||
|
|
||||||
@torch.compile(dynamic=False, fullgraph=True)
|
|
||||||
def muon_step_fused(
|
|
||||||
stacked_grads: Tensor,
|
|
||||||
stacked_params: Tensor,
|
|
||||||
momentum_buffer: Tensor,
|
|
||||||
second_momentum_buffer: Tensor,
|
|
||||||
momentum_t: Tensor,
|
|
||||||
lr_t: Tensor,
|
|
||||||
wd_t: Tensor,
|
|
||||||
beta2_t: Tensor,
|
|
||||||
ns_steps: int,
|
|
||||||
red_dim: int,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
|
||||||
All in one compiled graph to eliminate Python overhead between ops.
|
|
||||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Nesterov momentum
|
|
||||||
momentum = momentum_t.to(stacked_grads.dtype)
|
|
||||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
|
||||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
|
||||||
|
|
||||||
# Polar express
|
|
||||||
X = g.bfloat16()
|
|
||||||
if g.size(-2) > g.size(-1):
|
|
||||||
X = X.mT
|
|
||||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
|
||||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
|
||||||
A = X @ X.mT
|
|
||||||
B = b * A + c * (A @ A)
|
|
||||||
X = a * X + B @ X
|
|
||||||
if g.size(-2) > g.size(-1):
|
|
||||||
X = X.mT
|
|
||||||
g = X
|
|
||||||
|
|
||||||
# Variance reduction
|
|
||||||
beta2 = beta2_t.to(g.dtype)
|
|
||||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
|
||||||
red_dim_size = g.size(red_dim)
|
|
||||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
|
||||||
v_norm = v_norm_sq.sqrt()
|
|
||||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
|
||||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
|
||||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
|
||||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
|
||||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
|
||||||
g = g * final_scale.to(g.dtype)
|
|
||||||
|
|
||||||
# Cautious weight decay + parameter update
|
|
||||||
lr = lr_t.to(g.dtype)
|
|
||||||
wd = wd_t.to(g.dtype)
|
|
||||||
mask = (g * stacked_params) >= 0
|
|
||||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
|
||||||
|
|
||||||
class Muon(torch.optim.Optimizer):
|
|
||||||
"""
|
|
||||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
||||||
|
|
||||||
https://kellerjordan.github.io/posts/muon/
|
|
||||||
|
|
||||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
|
||||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
|
||||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
|
||||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
|
||||||
|
|
||||||
Some warnings:
|
|
||||||
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
|
||||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
|
||||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
lr: The learning rate used by the internal SGD.
|
|
||||||
momentum: The momentum used by the internal SGD.
|
|
||||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
|
||||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
|
||||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
|
||||||
"""
|
|
||||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
|
||||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
|
||||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
|
||||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
|
||||||
# Group by shape so we can stack tensors
|
|
||||||
shapes = sorted({p.shape for p in params})
|
|
||||||
param_groups = []
|
|
||||||
for shape in shapes:
|
|
||||||
group_params = [p for p in params if p.shape == shape]
|
|
||||||
param_groups.append(dict(params=group_params))
|
|
||||||
super().__init__(param_groups, defaults)
|
|
||||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
|
||||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def step(self):
|
|
||||||
for group in self.param_groups:
|
|
||||||
params: list[Tensor] = group["params"]
|
|
||||||
if not params:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
|
||||||
state = self.state[params[0]]
|
|
||||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
|
||||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
|
||||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
|
||||||
|
|
||||||
# Momentum for every individual parameter
|
|
||||||
if "momentum_buffer" not in state:
|
|
||||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
|
||||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
|
||||||
|
|
||||||
# Second momentum buffer is factored, either per-row or per-column
|
|
||||||
if "second_momentum_buffer" not in state:
|
|
||||||
if shape[-2] >= shape[-1]:
|
|
||||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
|
||||||
else:
|
|
||||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
|
||||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
|
||||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
|
||||||
|
|
||||||
# Stack grads and params
|
|
||||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
|
||||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
|
||||||
|
|
||||||
# Fill all the 0-D tensors with current values
|
|
||||||
self._momentum_t.fill_(group["momentum"])
|
|
||||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
|
||||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
|
||||||
self._wd_t.fill_(group["weight_decay"])
|
|
||||||
|
|
||||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
|
||||||
muon_step_fused(
|
|
||||||
stacked_grads,
|
|
||||||
stacked_params,
|
|
||||||
momentum_buffer,
|
|
||||||
second_momentum_buffer,
|
|
||||||
self._momentum_t,
|
|
||||||
self._lr_t,
|
|
||||||
self._wd_t,
|
|
||||||
self._beta2_t,
|
|
||||||
group["ns_steps"],
|
|
||||||
red_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
|
||||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
|
||||||
|
|
||||||
|
|
||||||
class DistMuon(torch.optim.Optimizer):
|
|
||||||
"""
|
|
||||||
Distributed version of the Muon optimizer.
|
|
||||||
"""
|
|
||||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
|
||||||
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
|
||||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
|
||||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
|
||||||
params = list(params)
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
rank = dist.get_rank()
|
|
||||||
# Group all parameters by their shape
|
|
||||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
|
||||||
param_groups = []
|
|
||||||
for shape in shapes:
|
|
||||||
group_params = [p for p in params if p.shape == shape]
|
|
||||||
device, dtype = group_params[0].device, group_params[0].dtype
|
|
||||||
assert all(p.device == device for p in group_params)
|
|
||||||
assert all(p.dtype == dtype for p in group_params)
|
|
||||||
# Compute chunk size for this group (how many params each rank owns)
|
|
||||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
|
||||||
if rank == 0:
|
|
||||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
|
||||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
|
||||||
super().__init__(param_groups, defaults)
|
|
||||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
|
||||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def step(self):
|
|
||||||
rank = dist.get_rank()
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
|
|
||||||
# Ensure all grads exist
|
|
||||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
|
||||||
|
|
||||||
# First pass: stack grads and kick off reduce_scatter for each group
|
|
||||||
group_infos = []
|
|
||||||
for group in self.param_groups:
|
|
||||||
params: list[Tensor] = group["params"]
|
|
||||||
chunk_size = group["chunk_size"]
|
|
||||||
padded_num_params = chunk_size * world_size
|
|
||||||
shape = params[0].shape
|
|
||||||
device, dtype = params[0].device, params[0].dtype
|
|
||||||
|
|
||||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
|
||||||
grad_stack = torch.stack([p.grad for p in params])
|
|
||||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
|
||||||
stacked_grads[:len(params)].copy_(grad_stack)
|
|
||||||
# Zero-pad if we have fewer params than padded size
|
|
||||||
if len(params) < padded_num_params:
|
|
||||||
stacked_grads[len(params):].zero_()
|
|
||||||
|
|
||||||
# Output buffer for this rank's chunk
|
|
||||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
# Async reduce_scatter on the stacked tensor
|
|
||||||
reduce_future = dist.reduce_scatter_tensor(
|
|
||||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
|
||||||
).get_future()
|
|
||||||
|
|
||||||
group_infos.append(dict(
|
|
||||||
grad_chunk=grad_chunk,
|
|
||||||
reduce_future=reduce_future,
|
|
||||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
|
||||||
))
|
|
||||||
|
|
||||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
|
||||||
all_gather_futures = []
|
|
||||||
for group, info in zip(self.param_groups, group_infos):
|
|
||||||
info["reduce_future"].wait()
|
|
||||||
|
|
||||||
params = group["params"]
|
|
||||||
chunk_size = group["chunk_size"]
|
|
||||||
shape = params[0].shape
|
|
||||||
device, dtype = params[0].device, params[0].dtype
|
|
||||||
grad_chunk = info["grad_chunk"]
|
|
||||||
|
|
||||||
# How many params does this rank actually own?
|
|
||||||
start_idx = rank * chunk_size
|
|
||||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
|
||||||
|
|
||||||
# Get or create group-level state (stored keyed by first param)
|
|
||||||
state = self.state[params[0]]
|
|
||||||
|
|
||||||
# Momentum buffer
|
|
||||||
if "momentum_buffer" not in state:
|
|
||||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
|
||||||
momentum_buffer = state["momentum_buffer"]
|
|
||||||
|
|
||||||
# Second momentum buffer is factored, either per-row or per-column
|
|
||||||
if "second_momentum_buffer" not in state:
|
|
||||||
if shape[-2] >= shape[-1]:
|
|
||||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
|
||||||
else:
|
|
||||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
|
||||||
second_momentum_buffer = state["second_momentum_buffer"]
|
|
||||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
|
||||||
|
|
||||||
# Build updated_params tensor for all_gather
|
|
||||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
if num_owned > 0:
|
|
||||||
# Stack owned params (single kernel via torch.stack)
|
|
||||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
|
||||||
stacked_owned_params = torch.stack(owned_params)
|
|
||||||
|
|
||||||
# Get owned slices of buffers and grads
|
|
||||||
owned_grads = grad_chunk[:num_owned]
|
|
||||||
owned_momentum = momentum_buffer[:num_owned]
|
|
||||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
|
||||||
|
|
||||||
# Fill 0-D tensors with current values
|
|
||||||
self._momentum_t.fill_(group["momentum"])
|
|
||||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
|
||||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
|
||||||
self._wd_t.fill_(group["weight_decay"])
|
|
||||||
|
|
||||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
|
||||||
muon_step_fused(
|
|
||||||
owned_grads,
|
|
||||||
stacked_owned_params,
|
|
||||||
owned_momentum,
|
|
||||||
owned_second_momentum,
|
|
||||||
self._momentum_t,
|
|
||||||
self._lr_t,
|
|
||||||
self._wd_t,
|
|
||||||
self._beta2_t,
|
|
||||||
group["ns_steps"],
|
|
||||||
red_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Copy updated params to output buffer
|
|
||||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
|
||||||
|
|
||||||
# Zero-pad the rest (for ranks that own fewer params)
|
|
||||||
if num_owned < chunk_size:
|
|
||||||
updated_params[num_owned:].zero_()
|
|
||||||
|
|
||||||
# Reuse stacked_grads buffer for all_gather output
|
|
||||||
stacked_params = info["stacked_grads"]
|
|
||||||
|
|
||||||
# Async all_gather to replicate updated params to all ranks
|
|
||||||
gather_future = dist.all_gather_into_tensor(
|
|
||||||
stacked_params, updated_params, async_op=True
|
|
||||||
).get_future()
|
|
||||||
|
|
||||||
all_gather_futures.append(dict(
|
|
||||||
gather_future=gather_future,
|
|
||||||
stacked_params=stacked_params,
|
|
||||||
params=params,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Final pass: wait for all_gather and copy back to params
|
|
||||||
for info in all_gather_futures:
|
|
||||||
info["gather_future"].wait()
|
|
||||||
stacked_params = info["stacked_params"]
|
|
||||||
params = info["params"]
|
|
||||||
# Batched copy back (single kernel instead of N individual copies)
|
|
||||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
|
||||||
nanochat/optim.py (new file, 528 lines)
@@ -0,0 +1,528 @@
"""
A nice and efficient mixed AdamW/Muon Combined Optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.

Adapted from: https://github.com/KellerJordan/modded-nanogpt
Further contributions from @karpathy and @chrisjmccormick.
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
"""
|
||||||
|
Good old AdamW optimizer, fused kernel.
|
||||||
|
https://arxiv.org/abs/1711.05101
|
||||||
|
"""
|
||||||
|
|
||||||
|
@torch.compile(dynamic=False, fullgraph=True)
|
||||||
|
def adamw_step_fused(
|
||||||
|
p: Tensor, # (32768, 768) - parameter tensor
|
||||||
|
grad: Tensor, # (32768, 768) - gradient, same shape as p
|
||||||
|
exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
|
||||||
|
exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
|
||||||
|
step_t: Tensor, # () - 0-D CPU tensor, step count
|
||||||
|
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
|
||||||
|
beta1_t: Tensor, # () - 0-D CPU tensor, beta1
|
||||||
|
beta2_t: Tensor, # () - 0-D CPU tensor, beta2
|
||||||
|
eps_t: Tensor, # () - 0-D CPU tensor, epsilon
|
||||||
|
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||||
|
All in one compiled graph to eliminate Python overhead between ops.
|
||||||
|
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||||
|
"""
|
||||||
|
# Weight decay (decoupled, applied before the update)
|
||||||
|
p.mul_(1 - lr_t * wd_t)
|
||||||
|
# Update running averages (lerp_ is cleaner and fuses well)
|
||||||
|
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||||
|
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||||
|
# Bias corrections
|
||||||
|
bias1 = 1 - beta1_t ** step_t
|
||||||
|
bias2 = 1 - beta2_t ** step_t
|
||||||
|
# Compute update and apply
|
||||||
|
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||||
|
step_size = lr_t / bias1
|
||||||
|
p.add_(exp_avg / denom, alpha=-step_size)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
"""
|
||||||
|
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||||
|
https://github.com/KellerJordan/modded-nanogpt
|
||||||
|
|
||||||
|
Background:
|
||||||
|
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||||
|
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||||
|
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||||
|
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||||
|
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||||
|
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||||
|
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||||
|
|
||||||
|
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||||
|
Polar Express Sign Method for orthogonalization.
|
||||||
|
https://arxiv.org/pdf/2505.16932
|
||||||
|
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||||
|
|
||||||
|
Some of the changes in nanochat implementation:
|
||||||
|
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||||
|
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||||
|
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||||
|
# From https://arxiv.org/pdf/2505.16932
|
||||||
|
polar_express_coeffs = [
|
||||||
|
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||||
|
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||||
|
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||||
|
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||||
|
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.compile(dynamic=False, fullgraph=True)
|
||||||
|
def muon_step_fused(
|
||||||
|
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
|
||||||
|
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
|
||||||
|
momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
|
||||||
|
second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
|
||||||
|
momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
|
||||||
|
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
|
||||||
|
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
|
||||||
|
beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
|
||||||
|
ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
|
||||||
|
red_dim: int, # -1 or -2 - reduction dimension for variance
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||||
|
All in one compiled graph to eliminate Python overhead between ops.
|
||||||
|
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Nesterov momentum
|
||||||
|
momentum = momentum_t.to(stacked_grads.dtype)
|
||||||
|
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||||
|
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||||
|
|
||||||
|
# Polar express
|
||||||
|
X = g.bfloat16()
|
||||||
|
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||||
|
if g.size(-2) > g.size(-1): # Tall matrix
|
||||||
|
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||||
|
A = X.mT @ X
|
||||||
|
B = b * A + c * (A @ A)
|
||||||
|
X = a * X + X @ B
|
||||||
|
else: # Wide matrix (original math)
|
||||||
|
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||||
|
A = X @ X.mT
|
||||||
|
B = b * A + c * (A @ A)
|
||||||
|
X = a * X + B @ X
|
||||||
|
g = X
|
||||||
|
|
||||||
|
# Variance reduction
|
||||||
|
beta2 = beta2_t.to(g.dtype)
|
||||||
|
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||||
|
red_dim_size = g.size(red_dim)
|
||||||
|
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||||
|
v_norm = v_norm_sq.sqrt()
|
||||||
|
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||||
|
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||||
|
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||||
|
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||||
|
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||||
|
g = g * final_scale.to(g.dtype)
|
||||||
|
|
||||||
|
# Cautious weight decay + parameter update
|
||||||
|
lr = lr_t.to(g.dtype)
|
||||||
|
wd = wd_t.to(g.dtype)
|
||||||
|
mask = (g * stacked_params) >= 0
|
||||||
|
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
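For intuition, the polar-express loop above approximately maps the momentum-averaged gradient G = U S Vᵀ to U Vᵀ, i.e. it pushes all singular values toward 1 before the variance reduction and cautious update. A standalone sketch (not part of the file; the matrix size is arbitrary and this runs in float32 rather than bfloat16):

import torch

polar_express_coeffs = [   # copied from the list above
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]

torch.manual_seed(0)
G = torch.randn(256, 512)                  # "wide" matrix, so no transpose is needed
X = G / (G.norm() * 1.02 + 1e-6)           # same normalization as the kernel
for a, b, c in polar_express_coeffs:
    A = X @ X.mT
    B = b * A + c * (A @ A)
    X = a * X + B @ X

print(torch.linalg.svdvals(G)[:4])         # widely spread singular values of the raw matrix
print(torch.linalg.svdvals(X)[:4])         # after 5 iterations: all roughly 1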
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Single GPU version of the MuonAdamW optimizer.
|
||||||
|
# Used mostly for reference, debugging and testing.
|
||||||
|
|
||||||
|
class MuonAdamW(torch.optim.Optimizer):
|
||||||
|
"""
|
||||||
|
Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
|
||||||
|
|
||||||
|
AdamW - Fused AdamW optimizer step.
|
||||||
|
|
||||||
|
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||||
|
https://kellerjordan.github.io/posts/muon/
|
||||||
|
|
||||||
|
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||||
|
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||||
|
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||||
|
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||||
|
|
||||||
|
Some warnings:
|
||||||
|
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||||
|
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||||
|
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
param_groups: List of dicts, each containing:
|
||||||
|
- 'params': List of parameters
|
||||||
|
- 'kind': 'adamw' or 'muon'
|
||||||
|
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||||
|
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||||
|
"""
|
||||||
|
def __init__(self, param_groups: list[dict]):
|
||||||
|
super().__init__(param_groups, defaults={})
|
||||||
|
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||||
|
# AdamW tensors
|
||||||
|
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
# Muon tensors
|
||||||
|
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
|
||||||
|
def _step_adamw(self, group: dict) -> None:
|
||||||
|
"""
|
||||||
|
AdamW update for each param in the group individually.
|
||||||
|
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
|
||||||
|
"""
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
# State init
|
||||||
|
if not state:
|
||||||
|
state['step'] = 0
|
||||||
|
state['exp_avg'] = torch.zeros_like(p)
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(p)
|
||||||
|
exp_avg = state['exp_avg']
|
||||||
|
exp_avg_sq = state['exp_avg_sq']
|
||||||
|
state['step'] += 1
|
||||||
|
|
||||||
|
# Fill 0-D tensors with current values
|
||||||
|
self._adamw_step_t.fill_(state['step'])
|
||||||
|
self._adamw_lr_t.fill_(group['lr'])
|
||||||
|
self._adamw_beta1_t.fill_(group['betas'][0])
|
||||||
|
self._adamw_beta2_t.fill_(group['betas'][1])
|
||||||
|
self._adamw_eps_t.fill_(group['eps'])
|
||||||
|
self._adamw_wd_t.fill_(group['weight_decay'])
|
||||||
|
|
||||||
|
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||||
|
adamw_step_fused(
|
||||||
|
p, grad, exp_avg, exp_avg_sq,
|
||||||
|
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||||
|
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _step_muon(self, group: dict) -> None:
|
||||||
|
"""
|
||||||
|
Muon update for all params in the group (stacked for efficiency).
|
||||||
|
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
|
||||||
|
"""
|
||||||
|
params: list[Tensor] = group['params']
|
||||||
|
if not params:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||||
|
p = params[0]
|
||||||
|
state = self.state[p]
|
||||||
|
num_params = len(params)
|
||||||
|
shape, device, dtype = p.shape, p.device, p.dtype
|
||||||
|
|
||||||
|
# Momentum for every individual parameter
|
||||||
|
if "momentum_buffer" not in state:
|
||||||
|
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||||
|
momentum_buffer = state["momentum_buffer"]
|
||||||
|
|
||||||
|
# Second momentum buffer is factored, either per-row or per-column
|
||||||
|
if "second_momentum_buffer" not in state:
|
||||||
|
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
|
||||||
|
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||||
|
second_momentum_buffer = state["second_momentum_buffer"]
|
||||||
|
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||||
|
|
||||||
|
# Stack grads and params (NOTE: this assumes all params have the same shape)
|
||||||
|
stacked_grads = torch.stack([p.grad for p in params])
|
||||||
|
stacked_params = torch.stack(params)
|
||||||
|
|
||||||
|
# Fill all the 0-D tensors with current values
|
||||||
|
self._muon_momentum_t.fill_(group["momentum"])
|
||||||
|
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||||
|
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||||
|
self._muon_wd_t.fill_(group["weight_decay"])
|
||||||
|
|
||||||
|
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||||
|
muon_step_fused(
|
||||||
|
stacked_grads,
|
||||||
|
stacked_params,
|
||||||
|
momentum_buffer,
|
||||||
|
second_momentum_buffer,
|
||||||
|
self._muon_momentum_t,
|
||||||
|
self._muon_lr_t,
|
||||||
|
self._muon_wd_t,
|
||||||
|
self._muon_beta2_t,
|
||||||
|
group["ns_steps"],
|
||||||
|
red_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy back to original params
|
||||||
|
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self):
|
||||||
|
for group in self.param_groups:
|
||||||
|
if group['kind'] == 'adamw':
|
||||||
|
self._step_adamw(group)
|
||||||
|
elif group['kind'] == 'muon':
|
||||||
|
self._step_muon(group)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
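One detail worth calling out from muon_step_fused: the weight decay is "cautious", i.e. it is masked to entries where the orthogonalized update and the current weight agree in sign, so decay is never applied against the direction of the update. A tiny numeric sketch of that masking (values made up):

import torch

p = torch.tensor([1.0, -2.0, 3.0])      # current weights
g = torch.tensor([0.5, 1.0, -0.25])     # post-orthogonalization update direction
lr, wd = 0.1, 0.5
mask = (g * p) >= 0                     # tensor([True, False, False])
p_new = p - (lr * g + lr * wd * p * mask)
print(p_new)                            # tensor([0.9000, -2.1000, 3.0250]); only the first entry is decayed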
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Distributed version of the MuonAdamW optimizer.
|
||||||
|
# Used for training on multiple GPUs.
|
||||||
|
|
||||||
|
class DistMuonAdamW(torch.optim.Optimizer):
|
||||||
|
"""
|
||||||
|
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
|
||||||
|
|
||||||
|
See MuonAdamW for the algorithmic details of each optimizer. This class adds
|
||||||
|
distributed communication to enable multi-GPU training without PyTorch DDP.
|
||||||
|
|
||||||
|
Design Goals:
|
||||||
|
- Overlap communication with computation (async ops)
|
||||||
|
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
|
||||||
|
- Batch small tensors into single comm ops where possible
|
||||||
|
|
||||||
|
Communication Pattern (3-phase async):
|
||||||
|
We use a 3-phase structure to maximize overlap between communication and compute:
|
||||||
|
|
||||||
|
Phase 1: Launch all async reduce ops
|
||||||
|
- Kick off all reduce_scatter/all_reduce operations
|
||||||
|
- Don't wait - let them run in background while we continue
|
||||||
|
|
||||||
|
Phase 2: Wait for reduces, compute updates, launch gathers
|
||||||
|
- For each group: wait for its reduce, compute the update, launch gather
|
||||||
|
- By processing groups in order, earlier gathers run while later computes happen
|
||||||
|
|
||||||
|
Phase 3: Wait for gathers, copy back
|
||||||
|
- Wait for all gathers to complete
|
||||||
|
- Copy updated params back to original tensors (Muon only)
|
||||||
|
|
||||||
|
AdamW Communication (ZeRO-2 style):
|
||||||
|
- Small params (<1024 elements): all_reduce gradients, update full param on each rank.
|
||||||
|
Optimizer state is replicated but these params are tiny (scalars, biases).
|
||||||
|
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
|
||||||
|
only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
|
||||||
|
exp_avg_sq) is sharded - each rank only stores state for its slice.
|
||||||
|
Requires param.shape[0] divisible by world_size.
|
||||||
|
|
||||||
|
Muon Communication (stacked + chunked):
|
||||||
|
- All params in a Muon group must have the same shape (caller's responsibility).
|
||||||
|
- Stack all K params into a single (K, *shape) tensor for efficient comm.
|
||||||
|
- Divide K params across N ranks: each rank "owns" ceil(K/N) params.
|
||||||
|
- reduce_scatter the stacked grads so each rank gets its chunk.
|
||||||
|
- Each rank computes Muon update only for params it owns.
|
||||||
|
- all_gather the updated params back to all ranks.
|
||||||
|
- Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
|
||||||
|
- Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
|
||||||
|
then ignore the padding when copying back.
|
||||||
|
|
||||||
|
Buffer Reuse:
|
||||||
|
- For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
|
||||||
|
same buffer as the output for all_gather (stacked_params). This saves memory
|
||||||
|
since we don't need both buffers simultaneously.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
param_groups: List of dicts, each containing:
|
||||||
|
- 'params': List of parameters
|
||||||
|
- 'kind': 'adamw' or 'muon'
|
||||||
|
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||||
|
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||||
|
"""
|
||||||
|
def __init__(self, param_groups: list[dict]):
|
||||||
|
super().__init__(param_groups, defaults={})
|
||||||
|
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||||
|
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||||
|
|
||||||
|
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
|
||||||
|
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
|
||||||
|
param_infos = {}
|
||||||
|
for p in group['params']:
|
||||||
|
grad = p.grad
|
||||||
|
if p.numel() < 1024:
|
||||||
|
# Small params: all_reduce (no scatter/gather needed)
|
||||||
|
future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||||
|
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
|
||||||
|
else:
|
||||||
|
# Large params: reduce_scatter
|
||||||
|
rank_size = grad.shape[0] // world_size
|
||||||
|
grad_slice = torch.empty_like(grad[:rank_size])
|
||||||
|
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||||
|
param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
|
||||||
|
return dict(param_infos=param_infos)
|
||||||
|
|
||||||
|
def _reduce_muon(self, group: dict, world_size: int) -> dict:
|
||||||
|
"""Launch async reduce op for Muon group. Returns info dict."""
|
||||||
|
params = group['params']
|
||||||
|
chunk_size = (len(params) + world_size - 1) // world_size
|
||||||
|
padded_num_params = chunk_size * world_size
|
||||||
|
p = params[0]
|
||||||
|
shape, device, dtype = p.shape, p.device, p.dtype
|
||||||
|
|
||||||
|
# Stack grads and zero-pad to padded_num_params
|
||||||
|
grad_stack = torch.stack([p.grad for p in params])
|
||||||
|
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||||
|
stacked_grads[:len(params)].copy_(grad_stack)
|
||||||
|
if len(params) < padded_num_params:
|
||||||
|
stacked_grads[len(params):].zero_()
|
||||||
|
|
||||||
|
# Reduce_scatter to get this rank's chunk
|
||||||
|
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||||
|
future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||||
|
|
||||||
|
return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
|
||||||
|
|
||||||
|
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
|
||||||
|
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
|
||||||
|
param_infos = info['param_infos']
|
||||||
|
for p in group['params']:
|
||||||
            pinfo = param_infos[p]
            pinfo['future'].wait()
            grad_slice = pinfo['grad_slice']
            state = self.state[p]

            # For small params, operate on full param; for large, operate on slice
            if pinfo['is_small']:
                p_slice = p
            else:
                rank_size = p.shape[0] // world_size
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]

            # State init
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p_slice)
                state['exp_avg_sq'] = torch.zeros_like(p_slice)
            state['step'] += 1

            # Fill 0-D tensors and run fused kernel
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(
                p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
                self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
            )

            # Large params need all_gather
            if not pinfo['is_small']:
                future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
                gather_list.append(dict(future=future, params=None))

    def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
        """Wait for reduce, compute Muon updates, launch gather."""
        info['future'].wait()
        params = group['params']
        chunk_size = info['chunk_size']
        grad_chunk = info['grad_chunk']
        p = params[0]
        shape, device, dtype = p.shape, p.device, p.dtype

        # How many params does this rank own?
        start_idx = rank * chunk_size
        num_owned = min(chunk_size, max(0, len(params) - start_idx))

        # Get or create group-level state
        state = self.state[p]
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2

        # Build output buffer for all_gather
        updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)

        if num_owned > 0:
            owned_params = [params[start_idx + i] for i in range(num_owned)]
            stacked_owned = torch.stack(owned_params)

            # Fill 0-D tensors and run fused kernel
            self._muon_momentum_t.fill_(group["momentum"])
            self._muon_beta2_t.fill_(group["beta2"])
            self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
            self._muon_wd_t.fill_(group["weight_decay"])
            muon_step_fused(
                grad_chunk[:num_owned], stacked_owned,
                state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
                self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
                group["ns_steps"], red_dim,
            )
            updated_params[:num_owned].copy_(stacked_owned)

        if num_owned < chunk_size:
            updated_params[num_owned:].zero_()

        # Reuse stacked_grads buffer for all_gather output
        stacked_params = info["stacked_grads"]
        future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
        gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))

    def _finish_gathers(self, gather_list: list) -> None:
        """Wait for all gathers and copy Muon params back."""
        for info in gather_list:
            info["future"].wait()
            if info["params"] is not None:
                # Muon: copy from stacked buffer back to individual params
                torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))

    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Phase 1: launch all async reduce ops
        reduce_infos: list[dict] = []
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                reduce_infos.append(self._reduce_adamw(group, world_size))
            elif group['kind'] == 'muon':
                reduce_infos.append(self._reduce_muon(group, world_size))
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")

        # Phase 2: wait for reduces, compute updates, launch gathers
        gather_list: list[dict] = []
        for group, info in zip(self.param_groups, reduce_infos):
            if group['kind'] == 'adamw':
                self._compute_adamw(group, info, gather_list, rank, world_size)
            elif group['kind'] == 'muon':
                self._compute_muon(group, info, gather_list, rank)
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")

        # Phase 3: wait for gathers, copy back
        self._finish_gathers(gather_list)
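The step() above overlaps communication with compute in three phases: all gradient reductions are launched asynchronously up front, each group's update runs as soon as its own reduction lands, and the sharded parameter all-gathers are waited on last. Below is a minimal, self-contained sketch of that launch-then-wait pattern, run as a single-process gloo group on CPU; the tensors, the learning rate, and the plain gradient-descent update are illustrative stand-ins for the fused kernels above, not nanochat's actual code.

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
world_size = dist.get_world_size()

params = [torch.randn(4, 4) for _ in range(3)]
grads = [torch.randn_like(p) for p in params]

# Phase 1: launch every gradient reduction asynchronously and keep the handles
works = [dist.all_reduce(g, async_op=True) for g in grads]

# Phase 2: wait per-tensor, then apply a placeholder update (stands in for the fused kernels)
for p, g, w in zip(params, grads, works):
    w.wait()
    p.add_(g / world_size, alpha=-0.01)

# Phase 3: in the sharded case, the updated parameter shards would now be all-gathered back to every rank
dist.destroy_process_group()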
@@ -211,9 +211,9 @@ print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations
 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

 # -----------------------------------------------------------------------------
-# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
+# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
 adam_betas = (args.adam_beta1, args.adam_beta2)
-optimizers = model.setup_optimizers(
+optimizer = model.setup_optimizer(
     unembedding_lr=args.unembedding_lr * batch_lr_scale,
     embedding_lr=args.embedding_lr * batch_lr_scale,
     matrix_lr=args.matrix_lr * batch_lr_scale,
@@ -221,12 +221,10 @@ optimizers = model.setup_optimizers(
     adam_betas=adam_betas,
     scalar_lr=args.scalar_lr * batch_lr_scale,
 )
-adamw_optimizer, muon_optimizer = optimizers

 if resuming:
-    for opt, dat in zip(optimizers, optimizer_data):
-        opt.load_state_dict(dat)
-    del optimizer_data # free up the memory
+    optimizer.load_state_dict(optimizer_data)
+    del optimizer_data

 # -----------------------------------------------------------------------------
 # Initialize the DataLoaders for train/val
@@ -344,7 +342,7 @@ while True:
             checkpoint_dir,
             step,
             orig_model.state_dict(), # model parameters
-            [opt.state_dict() for opt in optimizers], # optimizer states
+            optimizer.state_dict(), # optimizer state
             { # metadata saved as json
                 "step": step,
                 "val_bpb": val_bpb, # loss at last step
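With a single combined optimizer there is exactly one optimizer state dict to checkpoint and to restore on resume. A minimal, generic sketch of that round trip follows; the file name, toy model, and plain torch.save/torch.load calls are illustrative, not nanochat's checkpoint_manager.

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# save: one model state dict, one optimizer state dict
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "ckpt.pt")

# resume: load both back into freshly constructed objects
ckpt = torch.load("ckpt.pt")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])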
@@ -378,18 +376,16 @@ while True:
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
         x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
-    # step the optimizers
+    # step the optimizer
     lrm = get_lr_multiplier(step)
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lrm
     muon_momentum = get_muon_momentum(step)
     muon_weight_decay = get_weight_decay(step)
-    for group in muon_optimizer.param_groups:
-        group["momentum"] = muon_momentum
-        group["weight_decay"] = muon_weight_decay
-    for opt in optimizers:
-        opt.step()
+    for group in optimizer.param_groups:
+        group["lr"] = group["initial_lr"] * lrm
+        if group['kind'] == 'muon':
+            group["momentum"] = muon_momentum
+            group["weight_decay"] = muon_weight_decay
+    optimizer.step()
     model.zero_grad(set_to_none=True)
     train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
     synchronize()
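Because the AdamW and Muon parameter groups now live inside one optimizer, the training-loop scheduler simply walks optimizer.param_groups and uses the 'kind' tag to decide which hyperparameters to touch. Here is a standalone toy of that pattern with made-up parameters and schedules (PyTorch keeps extra keys such as 'kind' and 'initial_lr' in a param group untouched); it is a sketch of the scheduling idiom, not the MuonAdamW optimizer itself.

import torch

matrix_params = [torch.nn.Parameter(torch.randn(64, 64))]
other_params = [torch.nn.Parameter(torch.randn(64))]
optimizer = torch.optim.SGD([
    dict(params=matrix_params, lr=0.02, momentum=0.95, kind='muon'),
    dict(params=other_params, lr=0.004, momentum=0.0, kind='adamw'),
])
for group in optimizer.param_groups:
    group["initial_lr"] = group["lr"]  # remember the base lr so it can be decayed later

num_steps = 100

def get_lr_multiplier(it):
    return 1.0 - it / num_steps  # simple rampdown to zero

def get_muon_momentum(it):
    return min(0.95, 0.85 + 0.10 * it / 300)  # illustrative momentum warmup

for step in range(3):
    lrm = get_lr_multiplier(step)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group["kind"] == "muon":
            group["momentum"] = muon_momentum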
@@ -201,7 +201,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
 # Training loop

 # Init the optimizer
-optimizers = model.setup_optimizers(
+optimizer = model.setup_optimizer(
     unembedding_lr=args.unembedding_lr,
     embedding_lr=args.embedding_lr,
     matrix_lr=args.matrix_lr,
@@ -209,10 +209,9 @@ optimizers = model.setup_optimizers(
 )

 # Set the initial learning rate as a fraction of the base learning rate
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["lr"] = group["lr"] * args.init_lr_frac
-        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+for group in optimizer.param_groups:
+    group["lr"] = group["lr"] * args.init_lr_frac
+    group["initial_lr"] = group["lr"]

 # Learning rate scheduler: simple rampdown to zero over num_steps
 def get_lr_multiplier(it):
@@ -305,11 +304,9 @@ for step in range(num_steps):

     # Update the model parameters
     lrm = get_lr_multiplier(step)
-    for opt in optimizers: # first set the learning rate
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lrm
-    for opt in optimizers: # then step the optimizers
-        opt.step()
+    for group in optimizer.param_groups:
+        group["lr"] = group["initial_lr"] * lrm
+    optimizer.step()
     model.zero_grad(set_to_none=True)
     wandb_run.log({
         "step": step,
@@ -150,17 +150,16 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_bat
 # -----------------------------------------------------------------------------
 # Initialize the Optimizer

-optimizers = model.setup_optimizers(
+optimizer = model.setup_optimizer(
     unembedding_lr=args.unembedding_lr,
     embedding_lr=args.embedding_lr,
     matrix_lr=args.matrix_lr,
     weight_decay=args.weight_decay,
 )
 # Set the initial learning rate as a fraction of the base learning rate
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["lr"] = group["lr"] * args.init_lr_frac
-        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+for group in optimizer.param_groups:
+    group["lr"] = group["lr"] * args.init_lr_frac
+    group["initial_lr"] = group["lr"]

 # -----------------------------------------------------------------------------
 # Training loop
@@ -230,13 +229,11 @@ for step in range(num_iterations):

     # learning rate scheduler
     lrm = get_lr_multiplier(step)
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lrm
+    for group in optimizer.param_groups:
+        group["lr"] = group["initial_lr"] * lrm

-    # step the optimizers
-    for opt in optimizers:
-        opt.step()
+    # step the optimizer
+    optimizer.step()
     model.zero_grad(set_to_none=True)

     # logging
@@ -93,14 +93,12 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 token_bytes = get_token_bytes(device=device)

-# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
-optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
-adamw_optimizer, muon_optimizer = optimizers
+# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
+optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
 # Override the initial learning rate as a fraction of the base learning rate
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["lr"] = group["lr"] * args.init_lr_frac
-        group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
+for group in optimizer.param_groups:
+    group["lr"] = group["lr"] * args.init_lr_frac
+    group["initial_lr"] = group["lr"]

 # Midtraining data mixture and DataLoader
 base_dir = get_base_dir()
@@ -274,7 +272,7 @@ while True:
             checkpoint_dir,
             step,
             orig_model.state_dict(),
-            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
+            optimizer.state_dict(),
             {
                 "step": step,
                 "val_bpb": val_bpb, # loss at last step
@@ -306,16 +304,14 @@ while True:
         loss.backward()
         x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     progress = max(progress, approx_progress) # only increase progress monotonically
-    # step the optimizers
+    # step the optimizer
     lrm = get_lr_multiplier(progress)
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lrm
     muon_momentum = get_muon_momentum(step)
-    for group in muon_optimizer.param_groups:
-        group["momentum"] = muon_momentum
-    for opt in optimizers:
-        opt.step()
+    for group in optimizer.param_groups:
+        group["lr"] = group["initial_lr"] * lrm
+        if group['kind'] == 'muon':
+            group["momentum"] = muon_momentum
+    optimizer.step()
     model.zero_grad(set_to_none=True)
     synchronize()
     t1 = time.time()