changes and optimizations to muon, making it a bit more efficient and simpler/cleaner

Andrej Karpathy
2026-01-15 03:20:48 +00:00
parent 3142ca1a28
commit 6bb92403d5


@@ -1,7 +1,27 @@
""" """
Muon optimizer adapted (simplified) from modded-nanogpt. Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt https://github.com/KellerJordan/modded-nanogpt
Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Some of the changes in nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
""" """
import torch import torch
from torch import Tensor from torch import Tensor
import torch.distributed as dist import torch.distributed as dist
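As an aside (not part of the commit), the behavior described in the Background docstring is easy to check numerically: the quintic iteration returns roughly U V^T, with singular values spread around 1 rather than exactly 1. A minimal self-contained sketch, where newton_schulz5 is a local helper using the coefficients from the removed zeropower_via_newtonschulz5 below:

import torch

def newton_schulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration from the docstring, run in bfloat16 like the optimizer.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # spectral norm <= 1
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

G = torch.randn(256, 512)
O = newton_schulz5(G).float()
U, S, Vh = torch.linalg.svd(G, full_matrices=False)
print(torch.linalg.svdvals(O))  # spread around 1, not exactly 1 (cf. the US'V^T discussion above)
print(torch.nn.functional.cosine_similarity(O.flatten(), (U @ Vh).flatten(), dim=0))  # close to 1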
@@ -16,97 +36,61 @@ polar_express_coeffs = [
     (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
 ]
 
-@torch.compile
-def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
+@torch.compile(dynamic=False, fullgraph=True)
+def muon_step_fused(
+    stacked_grads: Tensor,
+    stacked_params: Tensor,
+    momentum_buffer: Tensor,
+    second_momentum_buffer: Tensor,
+    momentum_t: Tensor,
+    lr_t: Tensor,
+    wd_t: Tensor,
+    beta2_t: Tensor,
+    ns_steps: int,
+    red_dim: int,
+) -> None:
     """
-    Polar Express Sign Method for orthogonalization.
-    https://arxiv.org/pdf/2505.16932
-    by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
-    Alternative to Newton-Schulz iteration with potentially better convergence properties.
+    Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
+    All in one compiled graph to eliminate Python overhead between ops.
+    Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
     """
-    assert G.ndim >= 2
-    X = G.bfloat16()
-    if G.size(-2) > G.size(-1):
+    # Nesterov momentum
+    momentum = momentum_t.to(stacked_grads.dtype)
+    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
+    g = stacked_grads.lerp_(momentum_buffer, momentum)
+    # Polar express
+    X = g.bfloat16()
+    if g.size(-2) > g.size(-1):
         X = X.mT
+    # Ensure spectral norm is at most 1 (with 2% safety factor)
     X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
-    # Perform the iterations (cap at available coefficients)
-    for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
+    for a, b, c in polar_express_coeffs[:ns_steps]:
         A = X @ X.mT
         B = b * A + c * (A @ A)
         X = a * X + B @ X
-    if G.size(-2) > G.size(-1):
+    if g.size(-2) > g.size(-1):
         X = X.mT
-    return X
-
-@torch.compile
-def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
-    """
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all the way to one everywhere
-    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
-    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
-    a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.bfloat16()
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-    # Ensure spectral norm is at most 1
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
-    # Perform the NS iterations
-    for _ in range(steps):
-        A = X @ X.mT
-        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + B @ X
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-    return X
-
-@torch.compile
-def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
-    """
-    NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
-    https://arxiv.org/pdf/2510.05491
-    Normalizes updates based on a running estimate of per-row (or per-column) variance.
-    The reduction dimension is determined by the shape of second_momentum_buffer.
-    """
-    # Determine reduction dimension from buffer shape
-    red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
-    # Compute per-row/col mean of squared values
-    v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
-    red_dim_size = v.size(red_dim)
-    # Compute current norm
+    g = X
+    # Variance reduction
+    beta2 = beta2_t.to(g.dtype)
+    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
+    red_dim_size = g.size(red_dim)
     v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
     v_norm = v_norm_sq.sqrt()
-    # Update second momentum buffer (EMA of variance)
     second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
-    # Compute scaling factor from second momentum
     step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
     scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
     v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
-    # Final scale preserves overall norm while adjusting per-row/col
     final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
-    return v.mul(final_scale.to(v.dtype))
+    g = g * final_scale.to(g.dtype)
+    # Cautious weight decay + parameter update
+    lr = lr_t.to(g.dtype)
+    wd = wd_t.to(g.dtype)
+    mask = (g * stacked_params) >= 0
+    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
 
 class Muon(torch.optim.Optimizer):
     """
@@ -127,94 +111,112 @@ class Muon(torch.optim.Optimizer):
     Arguments:
         lr: The learning rate used by the internal SGD.
         momentum: The momentum used by the internal SGD.
-        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iteration steps to use.
         beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
         weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
     """
-    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
-        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
-        params: list[Tensor] = [*params]
+    def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
+        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
+        params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
+        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
+        # Group by shape so we can stack tensors
+        shapes = sorted({p.shape for p in params})
         param_groups = []
-        for size in {p.numel() for p in params}:
-            group = dict(params=[p for p in params if p.numel() == size])
-            param_groups.append(group)
+        for shape in shapes:
+            group_params = [p for p in params if p.shape == shape]
+            param_groups.append(dict(params=group_params))
         super().__init__(param_groups, defaults)
+        # 0-D CPU tensors to avoid torch.compile recompilation when values change
+        self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
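The 0-D CPU tensors added above exist only to feed scalars into the compiled muon_step_fused without baking their values into the graph: a plain Python float argument can be specialized on by torch.compile and trigger recompilation whenever the learning-rate schedule changes it. A minimal sketch of the idea (scaled_update is a hypothetical stand-in, not part of the commit):

import torch

@torch.compile(dynamic=False, fullgraph=True)
def scaled_update(p: torch.Tensor, g: torch.Tensor, lr_t: torch.Tensor) -> None:
    p.sub_(lr_t.to(p.dtype) * g)   # lr enters the graph as tensor data, not as a constant

p, g = torch.randn(4), torch.randn(4)
lr_t = torch.tensor(0.0)           # 0-D tensor, reused every step
for lr in [0.02, 0.01, 0.005]:
    lr_t.fill_(lr)                 # changing the value does not require retracing the graph
    scaled_update(p, g, lr_t)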
     @torch.no_grad()
     def step(self):
         for group in self.param_groups:
             params: list[Tensor] = group["params"]
-            for p in params:
-                g = p.grad
-                assert g is not None
-                state = self.state[p]
-                if "momentum_buffer" not in state:
-                    state["momentum_buffer"] = torch.zeros_like(g)
-                buf: Tensor = state["momentum_buffer"]
-                buf.lerp_(g, 1 - group["momentum"])
-                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-                g = zeropower_via_polar_express(g, steps=group["ns_steps"])
-                # Variance reduction (NorMuon-style)
-                if group["beta2"] is not None:
-                    if "second_momentum_buffer" not in state:
-                        # Buffer shape determines reduction dim: reduce along larger dimension
-                        if p.size(-2) >= p.size(-1):
-                            state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
-                        else:
-                            state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
-                    g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
-                # Parameter update with cautious weight decay
-                effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
-                wd = group["weight_decay"]
-                if wd != 0:
-                    mask = (g * p) >= 0
-                    p.sub_(effective_lr * g + effective_lr * wd * p * mask)
-                else:
-                    p.sub_(effective_lr * g)
+            if not params:
+                continue
+            # Get or create group-level buffers (stored in first param's state for convenience)
+            state = self.state[params[0]]
+            num_params = len(params) # e.g.: 12 (for a d12 model)
+            # e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
+            shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
+            # Momentum for every individual parameter
+            if "momentum_buffer" not in state:
+                state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
+            momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
+            # Second momentum buffer is factored, either per-row or per-column
+            if "second_momentum_buffer" not in state:
+                if shape[-2] >= shape[-1]:
+                    state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
+                else:
+                    state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
+            second_momentum_buffer = state["second_momentum_buffer"] # e.g.: (12, 1, 3072)
+            red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
+            # Stack grads and params
+            stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
+            stacked_params = torch.stack(params) # (12, 768, 3072)
+            # Fill all the 0-D tensors with current values
+            self._momentum_t.fill_(group["momentum"])
+            self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
+            self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
+            self._wd_t.fill_(group["weight_decay"])
+            # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
+            muon_step_fused(
+                stacked_grads,
+                stacked_params,
+                momentum_buffer,
+                second_momentum_buffer,
+                self._momentum_t,
+                self._lr_t,
+                self._wd_t,
+                self._beta2_t,
+                group["ns_steps"],
+                red_dim,
+            )
+            # Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
+            torch._foreach_copy_(params, list(stacked_params.unbind(0)))
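The shape-grouping and stacking pattern used by step() above, stripped of the optimizer math (a toy sketch, not part of the commit; the mul_ stands in for the fused update): parameters with identical shape are stacked into one (num_params, rows, cols) tensor, processed in a single kernel, and written back with one batched copy.

import torch

params = [torch.randn(8, 16) for _ in range(4)] + [torch.randn(16, 8) for _ in range(2)]
shapes = sorted({p.shape for p in params})
for shape in shapes:
    group = [p for p in params if p.shape == shape]
    stacked = torch.stack(group)                          # (len(group), *shape), one kernel
    stacked.mul_(0.5)                                     # stand-in for muon_step_fused
    torch._foreach_copy_(group, list(stacked.unbind(0)))  # batched copy back into the leaves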
 
 class DistMuon(torch.optim.Optimizer):
     """
-    Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
-    finally apply aspect-ratio scaled step.
-    Notes:
-    * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
-      params like embeddings or scalars.
-    * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
-      by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
-      consolidate states beforehand.
-    Args:
-        params: iterable of Tensors
-        lr: learning rate
-        momentum: momentum coefficient in [0,1)
-        nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
-        ns_steps: number of Newton-Schulz iterations for the orthogonalization
-        beta2: decay rate for second moment (variance) estimate. Set to None to disable.
-        weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
+    Distributed version of the Muon optimizer.
+    Performs its own distributed synchronization:
+    - reduce_scatter(AVG) for gradient averaging
+    - all_gather to replicate updated weights
     """
     def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
-                 nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
-        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
+                 ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
+        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
+        params = list(params)
         assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
-        params = list(params)
+        world_size = dist.get_world_size()
         rank = dist.get_rank()
         # Group all parameters by their shape
-        shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
+        shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
         param_groups = []
         for shape in shapes:
             group_params = [p for p in params if p.shape == shape]
             device, dtype = group_params[0].device, group_params[0].dtype
             assert all(p.device == device for p in group_params)
             assert all(p.dtype == dtype for p in group_params)
+            # Compute chunk size for this group (how many params each rank owns)
+            chunk_size = (len(group_params) + world_size - 1) // world_size
             if rank == 0:
-                print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
-            param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
+                print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
+            param_groups.append(dict(params=group_params, chunk_size=chunk_size))
         super().__init__(param_groups, defaults)
+        # 0-D CPU tensors to avoid torch.compile recompilation when values change
+        self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
     @torch.no_grad()
     def step(self):
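The ownership arithmetic used inside the new step() below, on its own (a toy sketch with made-up numbers, not part of the commit): each rank owns a contiguous chunk of the shape group, and the group is zero-padded up to chunk_size * world_size so the collectives always see equal-sized slices.

world_size = 8
num_params = 12                                            # e.g. 12 matrices of one shape
chunk_size = (num_params + world_size - 1) // world_size   # ceil division -> 2
padded_num_params = chunk_size * world_size                # 16, i.e. 4 zero-padded slots
for rank in range(world_size):
    start_idx = rank * chunk_size
    num_owned = min(chunk_size, max(0, num_params - start_idx))
    print(rank, start_idx, num_owned)  # ranks 0-5 own 2 params each; ranks 6-7 own only padding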
@@ -224,72 +226,127 @@ class DistMuon(torch.optim.Optimizer):
         # Ensure all grads exist
         assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
-        # Kick off all the reduce scatter operations to average up the gradients across all ranks
-        all_reduce_futures = []
+        # First pass: stack grads and kick off reduce_scatter for each group
+        group_infos = []
         for group in self.param_groups:
-            params = group["params"]
-            zero_buffer = group["zero_buffer"]
-            # Go through params in groups of world_size.
-            for base_i in range(0, len(params), world_size):
-                # The compute owner of each param is rank i % world_size
-                owner_idx = base_i + rank
-                # each rank stacks up its chunk of world_size params into a list
-                rs_input = [p.grad for p in params[base_i:base_i + world_size]]
-                # pad rs_input with the zero buffer to complete the group
-                rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
-                # the output buffer gets strided across the group based on the rank
-                rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
-                # reduce scatter the gradients within this group of world_size params
-                work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
-                all_reduce_futures.append(work)
-        # Now each rank computes the update and gathers
-        future_idx = 0
+            params: list[Tensor] = group["params"]
+            chunk_size = group["chunk_size"]
+            padded_num_params = chunk_size * world_size
+            shape = params[0].shape
+            device, dtype = params[0].device, params[0].dtype
+            # Stack all gradients into a single tensor (single kernel via torch.stack)
+            grad_stack = torch.stack([p.grad for p in params])
+            stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
+            stacked_grads[:len(params)].copy_(grad_stack)
+            # Zero-pad if we have fewer params than padded size
+            if len(params) < padded_num_params:
+                stacked_grads[len(params):].zero_()
+            # Output buffer for this rank's chunk
+            grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
+            # Async reduce_scatter on the stacked tensor
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+            group_infos.append(dict(
+                grad_chunk=grad_chunk,
+                reduce_future=reduce_future,
+                stacked_grads=stacked_grads, # reuse for all_gather output
+            ))
+        # Second pass: wait for reduce, compute batched updates, kick off all_gather
         all_gather_futures = []
-        for group in self.param_groups:
-            params = group["params"]
-            zero_buffer = group["zero_buffer"]
-            # Go through params in groups of world_size.
-            for base_i in range(0, len(params), world_size):
-                # The compute owner of each param is rank i % world_size
-                owner_idx = base_i + rank # calculate the index of the param that this rank owns
-                # Wait for the reduce scatter to complete
-                all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
-                future_idx += 1
-                # Owner computes the Muon update, result is in its param
-                if owner_idx < len(params):
-                    p = params[owner_idx]
-                    g = p.grad # now averaged across ranks
-                    state = self.state[p]
-                    if "momentum_buffer" not in state:
-                        state["momentum_buffer"] = torch.zeros_like(g)
-                    buf: Tensor = state["momentum_buffer"]
-                    buf.lerp_(g, 1.0 - group["momentum"])
-                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-                    g = zeropower_via_polar_express(g, steps=group["ns_steps"])
-                    # Variance reduction (NorMuon-style)
-                    if group["beta2"] is not None:
-                        if "second_momentum_buffer" not in state:
-                            # Buffer shape determines reduction dim: reduce along larger dimension
-                            if p.size(-2) >= p.size(-1):
-                                state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
-                            else:
-                                state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
-                        g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
-                    # Parameter update with cautious weight decay
-                    effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
-                    wd = group["weight_decay"]
-                    if wd != 0:
-                        mask = (g * p) >= 0
-                        p.sub_(effective_lr * g + effective_lr * wd * p * mask)
-                    else:
-                        p.sub_(effective_lr * g)
-                # Replicate updated parameters to all ranks
-                ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
-                ag_output = params[base_i:base_i + world_size]
-                ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
-                work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
-                all_gather_futures.append(work)
-        # Wait for all work to finish
-        torch.futures.collect_all(all_gather_futures).wait()
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+            params = group["params"]
+            chunk_size = group["chunk_size"]
+            shape = params[0].shape
+            device, dtype = params[0].device, params[0].dtype
+            grad_chunk = info["grad_chunk"]
+            # How many params does this rank actually own?
+            start_idx = rank * chunk_size
+            num_owned = min(chunk_size, max(0, len(params) - start_idx))
+            # Get or create group-level state (stored keyed by first param)
+            state = self.state[params[0]]
+            # Momentum buffer
+            if "momentum_buffer" not in state:
+                state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
+            momentum_buffer = state["momentum_buffer"]
+            # Second momentum buffer is factored, either per-row or per-column
+            if "second_momentum_buffer" not in state:
+                if shape[-2] >= shape[-1]:
+                    state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
+                else:
+                    state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
+            second_momentum_buffer = state["second_momentum_buffer"]
+            red_dim = -1 if shape[-2] >= shape[-1] else -2
+            # Build updated_params tensor for all_gather
+            updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
+            if num_owned > 0:
+                # Stack owned params (single kernel via torch.stack)
+                owned_params = [params[start_idx + i] for i in range(num_owned)]
+                stacked_owned_params = torch.stack(owned_params)
+                # Get owned slices of buffers and grads
+                owned_grads = grad_chunk[:num_owned]
+                owned_momentum = momentum_buffer[:num_owned]
+                owned_second_momentum = second_momentum_buffer[:num_owned]
+                # Fill 0-D tensors with current values
+                self._momentum_t.fill_(group["momentum"])
+                self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
+                self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
+                self._wd_t.fill_(group["weight_decay"])
+                # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
+                muon_step_fused(
+                    owned_grads,
+                    stacked_owned_params,
+                    owned_momentum,
+                    owned_second_momentum,
+                    self._momentum_t,
+                    self._lr_t,
+                    self._wd_t,
+                    self._beta2_t,
+                    group["ns_steps"],
+                    red_dim,
+                )
+                # Copy updated params to output buffer
+                updated_params[:num_owned].copy_(stacked_owned_params)
+            # Zero-pad the rest (for ranks that own fewer params)
+            if num_owned < chunk_size:
+                updated_params[num_owned:].zero_()
+            # Reuse stacked_grads buffer for all_gather output
+            stacked_params = info["stacked_grads"]
+            # Async all_gather to replicate updated params to all ranks
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_params, async_op=True
+            ).get_future()
+            all_gather_futures.append(dict(
+                gather_future=gather_future,
+                stacked_params=stacked_params,
+                params=params,
+            ))
+        # Final pass: wait for all_gather and copy back to params
+        for info in all_gather_futures:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            params = info["params"]
+            # Batched copy back (single kernel instead of N individual copies)
+            torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
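Finally, the collective round trip used in step() shown in isolation (a sketch, not part of the commit; it assumes torch.distributed is already initialized with a backend that supports ReduceOp.AVG, e.g. NCCL under torchrun, and that stacked.size(0) is a multiple of the world size):

import torch
import torch.distributed as dist

def average_update_replicate(stacked: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    chunk = torch.empty_like(stacked[: stacked.size(0) // world_size])
    dist.reduce_scatter_tensor(chunk, stacked, op=dist.ReduceOp.AVG)  # each rank receives its averaged slice
    chunk.mul_(0.5)                                                   # stand-in for muon_step_fused on the owned slice
    out = torch.empty_like(stacked)
    dist.all_gather_into_tensor(out, chunk)                           # every rank receives every updated slice
    return out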