Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-30 04:22:02 +00:00
changes and optimizations to muon, making it more efficient and simpler/cleaner a bit
Changed file: nanochat/muon.py (455 changed lines)
@@ -1,7 +1,27 @@
 """
-Muon optimizer adapted (simplified) from modded-nanogpt.
+Muon optimizer adapted and simplified from modded-nanogpt.
 https://github.com/KellerJordan/modded-nanogpt
+
+Background:
+Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+zero even beyond the point where the iteration no longer converges all the way to one everywhere
+on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+performance at all relative to UV^T, where USV^T = G is the SVD.
+
+Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
+Polar Express Sign Method for orthogonalization.
+https://arxiv.org/pdf/2505.16932
+by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
+
+Some of the changes in the nanochat implementation:
+- Uses a simpler, more general approach to parameter grouping and stacking
+- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
+- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
 """
+
 import torch
 from torch import Tensor
 import torch.distributed as dist
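The new module docstring above summarizes the orthogonalization background. As a standalone sketch (not part of the committed diff), the snippet below runs the quintic Newton-Schulz iteration with the (3.4445, -4.7750, 2.0315) coefficients from the removed zeropower_via_newtonschulz5 and compares the result against the SVD-based U V^T. It uses float32 and a whole-matrix Frobenius-norm bound for brevity; the optimizer itself works in bfloat16 on stacked parameter tensors.

# Standalone sketch: the quintic iteration returns something like U S' V^T with
# singular values scattered around 1, not exactly U V^T, as the docstring claims.
import torch

def orthogonalize_ns5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)   # coefficients from the removed zeropower_via_newtonschulz5
    X = G / (G.norm() + 1e-7)             # Frobenius norm bounds the spectral norm, so this is <= 1
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X

G = torch.randn(256, 512)                 # wide matrix: rows <= cols, so no transpose needed
X = orthogonalize_ns5(G)
U, S, Vh = torch.linalg.svd(G, full_matrices=False)
exact = U @ Vh                            # the "true" zeroth power U V^T
print(torch.linalg.svdvals(X).aminmax())  # singular values end up scattered around 1
print((X - exact).norm() / exact.norm())  # noticeably nonzero, but harmless per the docstring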
@@ -16,97 +36,61 @@ polar_express_coeffs = [
     (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
 ]

-@torch.compile
-def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
-    """
-    Polar Express Sign Method for orthogonalization.
-    https://arxiv.org/pdf/2505.16932
-    by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
-
-    Alternative to Newton-Schulz iteration with potentially better convergence properties.
-    """
-    assert G.ndim >= 2
-    X = G.bfloat16()
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-
-    # Ensure spectral norm is at most 1 (with 2% safety factor)
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
-    # Perform the iterations (cap at available coefficients)
-    for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
-        A = X @ X.mT
-        B = b * A + c * (A @ A)
-        X = a * X + B @ X
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-    return X
-
-
-@torch.compile
-def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
-    """
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all the way to one everywhere
-    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
-    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
-    a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.bfloat16()
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-
-    # Ensure spectral norm is at most 1
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
-    # Perform the NS iterations
-    for _ in range(steps):
-        A = X @ X.mT
-        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + B @ X
-
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-    return X
-
-
-@torch.compile
-def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
-    """
-    NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
-    https://arxiv.org/pdf/2510.05491
-
-    Normalizes updates based on a running estimate of per-row (or per-column) variance.
-    The reduction dimension is determined by the shape of second_momentum_buffer.
-    """
-    # Determine reduction dimension from buffer shape
-    red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
-
-    # Compute per-row/col mean of squared values
-    v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
-    red_dim_size = v.size(red_dim)
-
-    # Compute current norm
-    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
-    v_norm = v_norm_sq.sqrt()
-
-    # Update second momentum buffer (EMA of variance)
-    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
-
-    # Compute scaling factor from second momentum
-    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
-    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
-    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
-
-    # Final scale preserves overall norm while adjusting per-row/col
-    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
-    return v.mul(final_scale.to(v.dtype))
+@torch.compile(dynamic=False, fullgraph=True)
+def muon_step_fused(
+    stacked_grads: Tensor,
+    stacked_params: Tensor,
+    momentum_buffer: Tensor,
+    second_momentum_buffer: Tensor,
+    momentum_t: Tensor,
+    lr_t: Tensor,
+    wd_t: Tensor,
+    beta2_t: Tensor,
+    ns_steps: int,
+    red_dim: int,
+) -> None:
+    """
+    Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
+    All in one compiled graph to eliminate Python overhead between ops.
+    Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
+    """
+    # Nesterov momentum
+    momentum = momentum_t.to(stacked_grads.dtype)
+    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
+    g = stacked_grads.lerp_(momentum_buffer, momentum)
+
+    # Polar express
+    X = g.bfloat16()
+    if g.size(-2) > g.size(-1):
+        X = X.mT
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
+    for a, b, c in polar_express_coeffs[:ns_steps]:
+        A = X @ X.mT
+        B = b * A + c * (A @ A)
+        X = a * X + B @ X
+    if g.size(-2) > g.size(-1):
+        X = X.mT
+    g = X
+
+    # Variance reduction
+    beta2 = beta2_t.to(g.dtype)
+    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
+    red_dim_size = g.size(red_dim)
+    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
+    v_norm = v_norm_sq.sqrt()
+    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
+    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
+    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
+    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
+    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
+    g = g * final_scale.to(g.dtype)
+
+    # Cautious weight decay + parameter update
+    lr = lr_t.to(g.dtype)
+    wd = wd_t.to(g.dtype)
+    mask = (g * stacked_params) >= 0
+    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)


 class Muon(torch.optim.Optimizer):
     """
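The variance-reduction math that used to live in apply_variance_reduction is now inlined in muon_step_fused. As a standalone toy sketch (not part of the committed diff), the snippet below reproduces that math for a single 2D update with a per-row buffer; the toy input scales and beta2 value are illustrative only. It shows the effect: per-row (or per-column) RMS normalization whose final rescale preserves the overall norm of the update.

# Standalone sketch of the NorMuon-style variance reduction step.
import torch

def variance_reduction(v, second_momentum_buffer, beta2=0.95):
    red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
    v_mean = v.float().square().mean(dim=red_dim, keepdim=True)   # per-row mean of squares
    red_dim_size = v.size(red_dim)
    v_norm = (v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size).sqrt()
    second_momentum_buffer.lerp_(v_mean.to(second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()   # ~ 1/RMS per row
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    return v * final_scale.to(v.dtype)

v = torch.randn(4, 8) * torch.tensor([[0.1], [1.0], [5.0], [10.0]])  # rows with very different scales
buf = torch.zeros(4, 1)              # per-row buffer (the rows >= cols case picks this shape)
out = variance_reduction(v, buf)
print(v.norm(), out.norm())          # overall norm of the update is preserved
print(out.square().mean(dim=-1))     # per-row energy is roughly equalized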
@@ -127,94 +111,112 @@ class Muon(torch.optim.Optimizer):
     Arguments:
         lr: The learning rate used by the internal SGD.
         momentum: The momentum used by the internal SGD.
-        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iteration steps to use.
         beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
         weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
     """
-    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
-        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
-        params: list[Tensor] = [*params]
+    def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
+        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
+        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
+        params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
+        # Group by shape so we can stack tensors
+        shapes = sorted({p.shape for p in params})
         param_groups = []
-        for size in {p.numel() for p in params}:
-            group = dict(params=[p for p in params if p.numel() == size])
-            param_groups.append(group)
+        for shape in shapes:
+            group_params = [p for p in params if p.shape == shape]
+            param_groups.append(dict(params=group_params))
         super().__init__(param_groups, defaults)
+        # 0-D CPU tensors to avoid torch.compile recompilation when values change
+        self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

     @torch.no_grad()
     def step(self):
         for group in self.param_groups:
             params: list[Tensor] = group["params"]
-            for p in params:
-                g = p.grad
-                assert g is not None
-                state = self.state[p]
-                if "momentum_buffer" not in state:
-                    state["momentum_buffer"] = torch.zeros_like(g)
-                buf: Tensor = state["momentum_buffer"]
-                buf.lerp_(g, 1 - group["momentum"])
-                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-                g = zeropower_via_polar_express(g, steps=group["ns_steps"])
-                # Variance reduction (NorMuon-style)
-                if group["beta2"] is not None:
-                    if "second_momentum_buffer" not in state:
-                        # Buffer shape determines reduction dim: reduce along larger dimension
-                        if p.size(-2) >= p.size(-1):
-                            state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
-                        else:
-                            state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
-                    g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
-                # Parameter update with cautious weight decay
-                effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
-                wd = group["weight_decay"]
-                if wd != 0:
-                    mask = (g * p) >= 0
-                    p.sub_(effective_lr * g + effective_lr * wd * p * mask)
-                else:
-                    p.sub_(effective_lr * g)
+            if not params:
+                continue
+            # Get or create group-level buffers (stored in first param's state for convenience)
+            state = self.state[params[0]]
+            num_params = len(params) # e.g.: 12 (for a d12 model)
+            # e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
+            shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
+
+            # Momentum for every individual parameter
+            if "momentum_buffer" not in state:
+                state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
+            momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
+
+            # Second momentum buffer is factored, either per-row or per-column
+            if "second_momentum_buffer" not in state:
+                if shape[-2] >= shape[-1]:
+                    state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
+                else:
+                    state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
+            second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
+            red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
+
+            # Stack grads and params
+            stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
+            stacked_params = torch.stack(params) # (12, 768, 3072)
+
+            # Fill all the 0-D tensors with current values
+            self._momentum_t.fill_(group["momentum"])
+            self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
+            self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
+            self._wd_t.fill_(group["weight_decay"])
+
+            # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
+            muon_step_fused(
+                stacked_grads,
+                stacked_params,
+                momentum_buffer,
+                second_momentum_buffer,
+                self._momentum_t,
+                self._lr_t,
+                self._wd_t,
+                self._beta2_t,
+                group["ns_steps"],
+                red_dim,
+            )
+
+            # Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
+            torch._foreach_copy_(params, list(stacked_params.unbind(0)))


 class DistMuon(torch.optim.Optimizer):
     """
-    Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
-    finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
-      - reduce_scatter(AVG) for gradient averaging
-      - all_gather to replicate updated weights
-
-    Notes:
-      * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
-        params like embeddings or scalars.
-      * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
-        by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
-        consolidate states beforehand.
-
-    Args:
-        params: iterable of Tensors
-        lr: learning rate
-        momentum: momentum coefficient in [0,1)
-        nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
-        ns_steps: number of Newton-Schulz iterations for the orthogonalization
-        beta2: decay rate for second moment (variance) estimate. Set to None to disable.
-        weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
+    Distributed version of the Muon optimizer.
     """
     def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
-                 nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
-        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
-        params = list(params)
+                 ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
+        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
         assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
+        params = list(params)
+        world_size = dist.get_world_size()
         rank = dist.get_rank()
         # Group all parameters by their shape
-        shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
+        shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
         param_groups = []
         for shape in shapes:
             group_params = [p for p in params if p.shape == shape]
             device, dtype = group_params[0].device, group_params[0].dtype
             assert all(p.device == device for p in group_params)
             assert all(p.dtype == dtype for p in group_params)
+            # Compute chunk size for this group (how many params each rank owns)
+            chunk_size = (len(group_params) + world_size - 1) // world_size
             if rank == 0:
-                print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
-            param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
+                print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
+            param_groups.append(dict(params=group_params, chunk_size=chunk_size))
         super().__init__(param_groups, defaults)
+        # 0-D CPU tensors to avoid torch.compile recompilation when values change
+        self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

     @torch.no_grad()
     def step(self):
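The new Muon.step() processes every same-shape parameter at once: stack, run one fused kernel on the batch, then copy back, while hyperparameters are passed as refillable 0-D tensors so the compiled graph is not recompiled when their values change. A standalone sketch of that pattern (not part of the committed diff) is below; it uses a placeholder scalar multiply instead of the real muon_step_fused and smaller shapes than the (768, 3072) example in the code comments.

# Standalone sketch: stack -> batched update -> batched copy back.
import torch

params = [torch.randn(64, 256) for _ in range(12)]  # 12 parameters of one shape

# 0-D CPU tensor for a hyperparameter: refill it each step instead of passing a new
# Python float, so a compiled kernel sees the same tensor input and does not recompile.
lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
lr_t.fill_(0.02)

stacked_params = torch.stack(params)     # (12, 64, 256): one batched tensor, one kernel per op
stacked_params.mul_(1 - lr_t)            # stand-in for the fused momentum/orthogonalize/update

# Batched copy back into the original, still-separate parameter tensors.
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
print(torch.allclose(params[0], stacked_params[0]))  # True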
@@ -224,72 +226,127 @@ class DistMuon(torch.optim.Optimizer):
         # Ensure all grads exist
         assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"

-        # Kick off all the reduce scatter operations to average up the gradients across all ranks
-        all_reduce_futures = []
-        for group in self.param_groups:
-            params = group["params"]
-            zero_buffer = group["zero_buffer"]
-            # Go through params in groups of world_size.
-            for base_i in range(0, len(params), world_size):
-                # The compute owner of each param is rank i % world_size
-                owner_idx = base_i + rank
-                # each rank stacks up its chunk of world_size params into a list
-                rs_input = [p.grad for p in params[base_i:base_i + world_size]]
-                # pad rs_input with the zero buffer to complete the group
-                rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
-                # the output buffer gets strided across the group based on the rank
-                rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
-                # reduce scatter the gradients within this group of world_size params
-                work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
-                all_reduce_futures.append(work)
-
-        # Now each rank computes the update and gathers
-        future_idx = 0
-        all_gather_futures = []
-        for group in self.param_groups:
-            params = group["params"]
-            zero_buffer = group["zero_buffer"]
-            # Go through params in groups of world_size.
-            for base_i in range(0, len(params), world_size):
-                # The compute owner of each param is rank i % world_size
-                owner_idx = base_i + rank # calculate the index of the param that this rank owns
-                # Wait for the reduce scatter to complete
-                all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
-                future_idx += 1
-                # Owner computes the Muon update, result is in its param
-                if owner_idx < len(params):
-                    p = params[owner_idx]
-                    g = p.grad # now averaged across ranks
-                    state = self.state[p]
-                    if "momentum_buffer" not in state:
-                        state["momentum_buffer"] = torch.zeros_like(g)
-                    buf: Tensor = state["momentum_buffer"]
-                    buf.lerp_(g, 1.0 - group["momentum"])
-                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-                    g = zeropower_via_polar_express(g, steps=group["ns_steps"])
-                    # Variance reduction (NorMuon-style)
-                    if group["beta2"] is not None:
-                        if "second_momentum_buffer" not in state:
-                            # Buffer shape determines reduction dim: reduce along larger dimension
-                            if p.size(-2) >= p.size(-1):
-                                state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
-                            else:
-                                state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
-                        g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
-                    # Parameter update with cautious weight decay
-                    effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
-                    wd = group["weight_decay"]
-                    if wd != 0:
-                        mask = (g * p) >= 0
-                        p.sub_(effective_lr * g + effective_lr * wd * p * mask)
-                    else:
-                        p.sub_(effective_lr * g)
-                # Replicate updated parameters to all ranks
-                ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
-                ag_output = params[base_i:base_i + world_size]
-                ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
-                work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
-                all_gather_futures.append(work)
-
-        # Wait for all work to finish
-        torch.futures.collect_all(all_gather_futures).wait()
+        # First pass: stack grads and kick off reduce_scatter for each group
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            chunk_size = group["chunk_size"]
+            padded_num_params = chunk_size * world_size
+            shape = params[0].shape
+            device, dtype = params[0].device, params[0].dtype
+
+            # Stack all gradients into a single tensor (single kernel via torch.stack)
+            grad_stack = torch.stack([p.grad for p in params])
+            stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
+            stacked_grads[:len(params)].copy_(grad_stack)
+            # Zero-pad if we have fewer params than padded size
+            if len(params) < padded_num_params:
+                stacked_grads[len(params):].zero_()
+
+            # Output buffer for this rank's chunk
+            grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
+
+            # Async reduce_scatter on the stacked tensor
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(dict(
+                grad_chunk=grad_chunk,
+                reduce_future=reduce_future,
+                stacked_grads=stacked_grads, # reuse for all_gather output
+            ))
+
+        # Second pass: wait for reduce, compute batched updates, kick off all_gather
+        all_gather_futures = []
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = group["params"]
+            chunk_size = group["chunk_size"]
+            shape = params[0].shape
+            device, dtype = params[0].device, params[0].dtype
+            grad_chunk = info["grad_chunk"]
+
+            # How many params does this rank actually own?
+            start_idx = rank * chunk_size
+            num_owned = min(chunk_size, max(0, len(params) - start_idx))
+
+            # Get or create group-level state (stored keyed by first param)
+            state = self.state[params[0]]
+
+            # Momentum buffer
+            if "momentum_buffer" not in state:
+                state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
+            momentum_buffer = state["momentum_buffer"]
+
+            # Second momentum buffer is factored, either per-row or per-column
+            if "second_momentum_buffer" not in state:
+                if shape[-2] >= shape[-1]:
+                    state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
+                else:
+                    state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
+            second_momentum_buffer = state["second_momentum_buffer"]
+            red_dim = -1 if shape[-2] >= shape[-1] else -2
+
+            # Build updated_params tensor for all_gather
+            updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
+
+            if num_owned > 0:
+                # Stack owned params (single kernel via torch.stack)
+                owned_params = [params[start_idx + i] for i in range(num_owned)]
+                stacked_owned_params = torch.stack(owned_params)
+
+                # Get owned slices of buffers and grads
+                owned_grads = grad_chunk[:num_owned]
+                owned_momentum = momentum_buffer[:num_owned]
+                owned_second_momentum = second_momentum_buffer[:num_owned]
+
+                # Fill 0-D tensors with current values
+                self._momentum_t.fill_(group["momentum"])
+                self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
+                self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
+                self._wd_t.fill_(group["weight_decay"])
+
+                # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
+                muon_step_fused(
+                    owned_grads,
+                    stacked_owned_params,
+                    owned_momentum,
+                    owned_second_momentum,
+                    self._momentum_t,
+                    self._lr_t,
+                    self._wd_t,
+                    self._beta2_t,
+                    group["ns_steps"],
+                    red_dim,
+                )
+
+                # Copy updated params to output buffer
+                updated_params[:num_owned].copy_(stacked_owned_params)
+
+            # Zero-pad the rest (for ranks that own fewer params)
+            if num_owned < chunk_size:
+                updated_params[num_owned:].zero_()
+
+            # Reuse stacked_grads buffer for all_gather output
+            stacked_params = info["stacked_grads"]
+
+            # Async all_gather to replicate updated params to all ranks
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_params, async_op=True
+            ).get_future()
+
+            all_gather_futures.append(dict(
+                gather_future=gather_future,
+                stacked_params=stacked_params,
+                params=params,
+            ))
+
+        # Final pass: wait for all_gather and copy back to params
+        for info in all_gather_futures:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            params = info["params"]
+            # Batched copy back (single kernel instead of N individual copies)
+            torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
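The new DistMuon flow replaces per-parameter reduce_scatter/all_gather calls with one reduce_scatter_tensor and one all_gather_into_tensor per shape group, splitting the stacked tensor evenly across ranks. A standalone sketch of the ownership arithmetic (not part of the committed diff) is below; it needs no process group and uses illustrative sizes. reduce_scatter_tensor hands rank r rows [r*chunk_size, (r+1)*chunk_size) of the stacked gradient tensor and all_gather_into_tensor reassembles the full stack, so padding to chunk_size * world_size keeps every rank's slice well defined even when the parameter count does not divide evenly.

# Standalone sketch: which parameters each rank owns under the chunked scheme.
world_size = 8
num_params = 12                                             # e.g. 12 matrices of one shape
chunk_size = (num_params + world_size - 1) // world_size    # ceil division -> 2
padded_num_params = chunk_size * world_size                 # 16 -> 4 zero-padded slots

total_owned = 0
for rank in range(world_size):
    start_idx = rank * chunk_size
    num_owned = min(chunk_size, max(0, num_params - start_idx))
    total_owned += num_owned
    print(f"rank {rank}: owns param indices {list(range(start_idx, start_idx + num_owned))}")
assert total_owned == num_params  # ranks 6 and 7 own nothing; they only contribute zero padding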