diff --git a/README.md b/README.md index fb8747f..7421152 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,6 @@ python -m pytest tests/test_engine.py -v -s │ └── repackage_data_reference.py # Pretraining data shard generation ├── nanochat │ ├── __init__.py # empty -│ ├── adamw.py # Distributed AdamW optimizer │ ├── checkpoint_manager.py # Save/Load model checkpoints │ ├── common.py # Misc small utilities, quality of life │ ├── core_eval.py # Evaluates base model CORE score (DCLM paper) @@ -146,7 +145,7 @@ python -m pytest tests/test_engine.py -v -s │ ├── gpt.py # The GPT nn.Module Transformer │ ├── logo.svg │ ├── loss_eval.py # Evaluate bits per byte (instead of loss) -│ ├── muon.py # Distributed Muon optimizer +│ ├── optim.py # AdamW + Muon optimizer, 1GPU and distributed │ ├── report.py # Utilities for writing the nanochat Report │ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4 │ └── ui.html # HTML/CSS/JS for nanochat frontend diff --git a/nanochat/adamw.py b/nanochat/adamw.py deleted file mode 100644 index 70ccf7b..0000000 --- a/nanochat/adamw.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Distributed AdamW optimizer with a fused step function. -A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt. -""" -import torch -import torch.distributed as dist -from torch import Tensor - -@torch.compile(dynamic=False, fullgraph=True) -def adamw_step_fused( - p: Tensor, - grad: Tensor, - exp_avg: Tensor, - exp_avg_sq: Tensor, - step_t: Tensor, - lr_t: Tensor, - beta1_t: Tensor, - beta2_t: Tensor, - eps_t: Tensor, - wd_t: Tensor, -) -> None: - """ - Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update - All in one compiled graph to eliminate Python overhead between ops. - The 0-D CPU tensors avoid recompilation when hyperparameter values change. - """ - # Weight decay (decoupled, applied before the update) - p.mul_(1 - lr_t * wd_t) - # Update running averages (lerp_ is cleaner and fuses well) - exp_avg.lerp_(grad, 1 - beta1_t) - exp_avg_sq.lerp_(grad.square(), 1 - beta2_t) - # Bias corrections - bias1 = 1 - beta1_t ** step_t - bias2 = 1 - beta2_t ** step_t - # Compute update and apply - denom = (exp_avg_sq / bias2).sqrt() + eps_t - step_size = lr_t / bias1 - p.add_(exp_avg / denom, alpha=-step_size) - - -class DistAdamW(torch.optim.Optimizer): - """ - Distributed AdamW optimizer. - In the style of ZeRO-2, i.e. 
sharded optimizer states and gradient reduction - """ - def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - rank = dist.get_rank() - world_size = dist.get_world_size() - # Validate - if rank == 0: - for group in param_groups: - assert isinstance(group, dict), "expecting param_groups to be a list of dicts" - assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors" - for p in group['params']: - sliced = p.numel() >= 1024 - print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}") - if sliced: # large parameter tensors will be operated on in slices - assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}" - super().__init__(param_groups, defaults) - # 0-D CPU tensors to avoid torch.compile recompilation when values change - self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_futures: list[torch.Future] = [] - gather_futures: list[torch.Future] = [] - grad_slices = [] - is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter) - - for group in self.param_groups: - params: list[Tensor] = group["params"] - for p in params: - grad = p.grad - # Small params: use all_reduce (no scatter/gather needed) - if p.numel() < 1024: - is_small.append(True) - reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad) - else: - is_small.append(False) - rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__ - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for p in params: - reduce_futures[idx].wait() - g_slice = grad_slices[idx] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - - # For small params, operate on full param; for large, operate on slice - if is_small[idx]: - p_slice = p - else: - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - - # State init - if not state: - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_slice) - state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - - # Fill 0-D tensors with current values - eff_wd = wd * getattr(p, "wd_mul", 1.0) - self._step_t.fill_(state['step']) - self._lr_t.fill_(lr) - self._beta1_t.fill_(beta1) - self._beta2_t.fill_(beta2) - self._eps_t.fill_(eps) - self._wd_t.fill_(eff_wd) - - # Fused update: weight_decay -> momentum -> bias_correction -> param_update - adamw_step_fused( - p_slice, g_slice, exp_avg, exp_avg_sq, - self._step_t, 
self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t, - ) - - # Only large params need all_gather - if not is_small[idx]: - gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - idx += 1 - - if gather_futures: - torch.futures.collect_all(gather_futures).wait() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 672af71..d23a516 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -20,8 +20,7 @@ import torch.nn as nn import torch.nn.functional as F from nanochat.common import get_dist_info, print0 -from nanochat.muon import Muon, DistMuon -from nanochat.adamw import DistAdamW +from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn @@ -346,9 +345,10 @@ class GPT(nn.Module): 'total': total, } - def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() + # Separate out all parameters into groups matrix_params = list(self.transformer.h.parameters()) value_embeds_params = list(self.value_embeds.parameters()) @@ -357,30 +357,33 @@ class GPT(nn.Module): resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) - # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars - # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) + + # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") - adam_groups = [ - dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), - dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), - dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding - dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream - dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars + + # Build param_groups with all required fields explicit + param_groups = [ + # AdamW groups (embeddings, lm_head, scalars) + dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 ] - adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon - AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) - 
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs) - # Create the Muon optimizer for the linear layers - muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay) - MuonFactory = DistMuon if ddp else Muon - muon_optimizer = MuonFactory(matrix_params, **muon_kwargs) - # Combine them the two optimizers into one list - optimizers = [adamw_optimizer, muon_optimizer] - for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - return optimizers + # Muon groups (matrix params, grouped by shape for stacking) + for shape in sorted({p.shape for p in matrix_params}): + group_params = [p for p in matrix_params if p.shape == shape] + param_groups.append(dict( + kind='muon', params=group_params, lr=matrix_lr, + momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, + )) + + Factory = DistMuonAdamW if ddp else MuonAdamW + optimizer = Factory(param_groups) + for group in optimizer.param_groups: + group["initial_lr"] = group["lr"] + return optimizer def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() diff --git a/nanochat/muon.py b/nanochat/muon.py deleted file mode 100644 index cfd2443..0000000 --- a/nanochat/muon.py +++ /dev/null @@ -1,352 +0,0 @@ -""" -Muon optimizer adapted and simplified from modded-nanogpt. -https://github.com/KellerJordan/modded-nanogpt - -Background: -Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a -quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose -of minimizing steps, it turns out to be empirically effective to keep increasing the slope at -zero even beyond the point where the iteration no longer converges all the way to one everywhere -on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T -where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model -performance at all relative to UV^T, where USV^T = G is the SVD. - -Here, an alternative to Newton-Schulz iteration with potentially better convergence properties: -Polar Express Sign Method for orthogonalization. -https://arxiv.org/pdf/2505.16932 -by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. - -Some of the changes in nanochat implementation: -- Uses a simpler, more general approach to parameter grouping and stacking -- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step -- Makes no assumptions about model architecture (e.g. 
that attention weights are fused into QKVO format) -""" - -import torch -from torch import Tensor -import torch.distributed as dist - -# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2) -# From https://arxiv.org/pdf/2505.16932 -polar_express_coeffs = [ - (8.156554524902461, -22.48329292557795, 15.878769915207462), - (4.042929935166739, -2.808917465908714, 0.5000178451051316), - (3.8916678022926607, -2.772484153217685, 0.5060648178503393), - (3.285753657755655, -2.3681294933425376, 0.46449024233003106), - (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), -] - -@torch.compile(dynamic=False, fullgraph=True) -def muon_step_fused( - stacked_grads: Tensor, - stacked_params: Tensor, - momentum_buffer: Tensor, - second_momentum_buffer: Tensor, - momentum_t: Tensor, - lr_t: Tensor, - wd_t: Tensor, - beta2_t: Tensor, - ns_steps: int, - red_dim: int, -) -> None: - """ - Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update - All in one compiled graph to eliminate Python overhead between ops. - Some of the constants are 0-D CPU tensors to avoid recompilation when values change. - """ - - # Nesterov momentum - momentum = momentum_t.to(stacked_grads.dtype) - momentum_buffer.lerp_(stacked_grads, 1 - momentum) - g = stacked_grads.lerp_(momentum_buffer, momentum) - - # Polar express - X = g.bfloat16() - if g.size(-2) > g.size(-1): - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) - for a, b, c in polar_express_coeffs[:ns_steps]: - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if g.size(-2) > g.size(-1): - X = X.mT - g = X - - # Variance reduction - beta2 = beta2_t.to(g.dtype) - v_mean = g.float().square().mean(dim=red_dim, keepdim=True) - red_dim_size = g.size(red_dim) - v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size - v_norm = v_norm_sq.sqrt() - second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) - step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt() - scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() - v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt() - final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10)) - g = g * final_scale.to(g.dtype) - - # Cautious weight decay + parameter update - lr = lr_t.to(g.dtype) - wd = wd_t.to(g.dtype) - mask = (g * stacked_params) >= 0 - stacked_params.sub_(lr * g + lr * wd * stacked_params * mask) - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. - - Arguments: - lr: The learning rate used by the internal SGD. - momentum: The momentum used by the internal SGD. - ns_steps: The number of Newton-Schulz iteration steps to use. - beta2: The decay rate for the second moment (variance) estimate. Set to None to disable. 
- weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree. - """ - def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0): - defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) - assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" - params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator - # Group by shape so we can stack tensors - shapes = sorted({p.shape for p in params}) - param_groups = [] - for shape in shapes: - group_params = [p for p in params if p.shape == shape] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # 0-D CPU tensors to avoid torch.compile recompilation when values change - self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - - @torch.no_grad() - def step(self): - for group in self.param_groups: - params: list[Tensor] = group["params"] - if not params: - continue - - # Get or create group-level buffers (stored in first param's state for convenience) - state = self.state[params[0]] - num_params = len(params) # e.g.: 12 (for a d12 model) - # e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections - shape, device, dtype = params[0].shape, params[0].device, params[0].dtype - - # Momentum for every individual parameter - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device) - momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072) - - # Second momentum buffer is factored, either per-row or per-column - if "second_momentum_buffer" not in state: - if shape[-2] >= shape[-1]: - state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device) - else: - state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device) - second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072) - red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2 - - # Stack grads and params - stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072) - stacked_params = torch.stack(params) # (12, 768, 3072) - - # Fill all the 0-D tensors with current values - self._momentum_t.fill_(group["momentum"]) - self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) - self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) - self._wd_t.fill_(group["weight_decay"]) - - # Single fused kernel: momentum -> polar_express -> variance_reduction -> update - muon_step_fused( - stacked_grads, - stacked_params, - momentum_buffer, - second_momentum_buffer, - self._momentum_t, - self._lr_t, - self._wd_t, - self._beta2_t, - group["ns_steps"], - red_dim, - ) - - # Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072) - torch._foreach_copy_(params, list(stacked_params.unbind(0))) - - -class DistMuon(torch.optim.Optimizer): - """ - Distributed version of the Muon optimizer. 
- """ - def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, - ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0): - defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) - assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" - params = list(params) - world_size = dist.get_world_size() - rank = dist.get_rank() - # Group all parameters by their shape - shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks - param_groups = [] - for shape in shapes: - group_params = [p for p in params if p.shape == shape] - device, dtype = group_params[0].device, group_params[0].dtype - assert all(p.device == device for p in group_params) - assert all(p.dtype == dtype for p in group_params) - # Compute chunk size for this group (how many params each rank owns) - chunk_size = (len(group_params) + world_size - 1) // world_size - if rank == 0: - print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}") - param_groups.append(dict(params=group_params, chunk_size=chunk_size)) - super().__init__(param_groups, defaults) - # 0-D CPU tensors to avoid torch.compile recompilation when values change - self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - - # Ensure all grads exist - assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" - - # First pass: stack grads and kick off reduce_scatter for each group - group_infos = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - chunk_size = group["chunk_size"] - padded_num_params = chunk_size * world_size - shape = params[0].shape - device, dtype = params[0].device, params[0].dtype - - # Stack all gradients into a single tensor (single kernel via torch.stack) - grad_stack = torch.stack([p.grad for p in params]) - stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device) - stacked_grads[:len(params)].copy_(grad_stack) - # Zero-pad if we have fewer params than padded size - if len(params) < padded_num_params: - stacked_grads[len(params):].zero_() - - # Output buffer for this rank's chunk - grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device) - - # Async reduce_scatter on the stacked tensor - reduce_future = dist.reduce_scatter_tensor( - grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True - ).get_future() - - group_infos.append(dict( - grad_chunk=grad_chunk, - reduce_future=reduce_future, - stacked_grads=stacked_grads, # reuse for all_gather output - )) - - # Second pass: wait for reduce, compute batched updates, kick off all_gather - all_gather_futures = [] - for group, info in zip(self.param_groups, group_infos): - info["reduce_future"].wait() - - params = group["params"] - chunk_size = group["chunk_size"] - shape = params[0].shape - device, dtype = params[0].device, params[0].dtype - grad_chunk = info["grad_chunk"] - - # How many params does this rank actually own? 
- start_idx = rank * chunk_size - num_owned = min(chunk_size, max(0, len(params) - start_idx)) - - # Get or create group-level state (stored keyed by first param) - state = self.state[params[0]] - - # Momentum buffer - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device) - momentum_buffer = state["momentum_buffer"] - - # Second momentum buffer is factored, either per-row or per-column - if "second_momentum_buffer" not in state: - if shape[-2] >= shape[-1]: - state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device) - else: - state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device) - second_momentum_buffer = state["second_momentum_buffer"] - red_dim = -1 if shape[-2] >= shape[-1] else -2 - - # Build updated_params tensor for all_gather - updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device) - - if num_owned > 0: - # Stack owned params (single kernel via torch.stack) - owned_params = [params[start_idx + i] for i in range(num_owned)] - stacked_owned_params = torch.stack(owned_params) - - # Get owned slices of buffers and grads - owned_grads = grad_chunk[:num_owned] - owned_momentum = momentum_buffer[:num_owned] - owned_second_momentum = second_momentum_buffer[:num_owned] - - # Fill 0-D tensors with current values - self._momentum_t.fill_(group["momentum"]) - self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) - self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) - self._wd_t.fill_(group["weight_decay"]) - - # Single fused kernel: momentum -> polar_express -> variance_reduction -> update - muon_step_fused( - owned_grads, - stacked_owned_params, - owned_momentum, - owned_second_momentum, - self._momentum_t, - self._lr_t, - self._wd_t, - self._beta2_t, - group["ns_steps"], - red_dim, - ) - - # Copy updated params to output buffer - updated_params[:num_owned].copy_(stacked_owned_params) - - # Zero-pad the rest (for ranks that own fewer params) - if num_owned < chunk_size: - updated_params[num_owned:].zero_() - - # Reuse stacked_grads buffer for all_gather output - stacked_params = info["stacked_grads"] - - # Async all_gather to replicate updated params to all ranks - gather_future = dist.all_gather_into_tensor( - stacked_params, updated_params, async_op=True - ).get_future() - - all_gather_futures.append(dict( - gather_future=gather_future, - stacked_params=stacked_params, - params=params, - )) - - # Final pass: wait for all_gather and copy back to params - for info in all_gather_futures: - info["gather_future"].wait() - stacked_params = info["stacked_params"] - params = info["params"] - # Batched copy back (single kernel instead of N individual copies) - torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0))) diff --git a/nanochat/optim.py b/nanochat/optim.py new file mode 100644 index 0000000..190a1ed --- /dev/null +++ b/nanochat/optim.py @@ -0,0 +1,528 @@ +""" +A nice and efficient mixed AdamW/Muon Combined Optimizer. +Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon. +Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed. + +Addapted from: https://github.com/KellerJordan/modded-nanogpt +Further contributions from @karpathy and @chrisjmccormick. 
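# A minimal, unfused sketch (illustrative, not part of the patch) of the same AdamW
# update performed by the fused kernel below: decoupled weight decay -> EMA moments ->
# bias correction -> parameter update. Handy as a plain-PyTorch reference when testing
# the compiled adamw_step_fused. All names here are illustrative.
import torch

def adamw_step_reference(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps, wd):
    p.mul_(1 - lr * wd)                           # decoupled weight decay
    exp_avg.lerp_(grad, 1 - beta1)                # first moment EMA
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)    # second moment EMA
    bias1 = 1 - beta1 ** step                     # bias corrections
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias2).sqrt() + eps
    p.add_(exp_avg / denom, alpha=-(lr / bias1))

# e.g.: p = torch.randn(64, 32); g = torch.randn_like(p)
# adamw_step_reference(p, g, torch.zeros_like(p), torch.zeros_like(p),
#                      step=1, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0)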
+""" + +import torch +import torch.distributed as dist +from torch import Tensor + +# ----------------------------------------------------------------------------- +""" +Good old AdamW optimizer, fused kernel. +https://arxiv.org/abs/1711.05101 +""" + +@torch.compile(dynamic=False, fullgraph=True) +def adamw_step_fused( + p: Tensor, # (32768, 768) - parameter tensor + grad: Tensor, # (32768, 768) - gradient, same shape as p + exp_avg: Tensor, # (32768, 768) - first moment, same shape as p + exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p + step_t: Tensor, # () - 0-D CPU tensor, step count + lr_t: Tensor, # () - 0-D CPU tensor, learning rate + beta1_t: Tensor, # () - 0-D CPU tensor, beta1 + beta2_t: Tensor, # () - 0-D CPU tensor, beta2 + eps_t: Tensor, # () - 0-D CPU tensor, epsilon + wd_t: Tensor, # () - 0-D CPU tensor, weight decay +) -> None: + """ + Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update + All in one compiled graph to eliminate Python overhead between ops. + The 0-D CPU tensors avoid recompilation when hyperparameter values change. + """ + # Weight decay (decoupled, applied before the update) + p.mul_(1 - lr_t * wd_t) + # Update running averages (lerp_ is cleaner and fuses well) + exp_avg.lerp_(grad, 1 - beta1_t) + exp_avg_sq.lerp_(grad.square(), 1 - beta2_t) + # Bias corrections + bias1 = 1 - beta1_t ** step_t + bias2 = 1 - beta2_t ** step_t + # Compute update and apply + denom = (exp_avg_sq / bias2).sqrt() + eps_t + step_size = lr_t / bias1 + p.add_(exp_avg / denom, alpha=-step_size) + +# ----------------------------------------------------------------------------- +""" +Muon optimizer adapted and simplified from modded-nanogpt. +https://github.com/KellerJordan/modded-nanogpt + +Background: +Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a +quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose +of minimizing steps, it turns out to be empirically effective to keep increasing the slope at +zero even beyond the point where the iteration no longer converges all the way to one everywhere +on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T +where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model +performance at all relative to UV^T, where USV^T = G is the SVD. + +Here, an alternative to Newton-Schulz iteration with potentially better convergence properties: +Polar Express Sign Method for orthogonalization. +https://arxiv.org/pdf/2505.16932 +by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + +Some of the changes in nanochat implementation: +- Uses a simpler, more general approach to parameter grouping and stacking +- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step +- Makes no assumptions about model architecture (e.g. 
that attention weights are fused into QKVO format) +""" + +# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2) +# From https://arxiv.org/pdf/2505.16932 +polar_express_coeffs = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +] + +@torch.compile(dynamic=False, fullgraph=True) +def muon_step_fused( + stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients + stacked_params: Tensor, # (12, 768, 3072) - stacked parameters + momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer + second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment + momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient + lr_t: Tensor, # () - 0-D CPU tensor, learning rate + wd_t: Tensor, # () - 0-D CPU tensor, weight decay + beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment + ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations + red_dim: int, # -1 or -2 - reduction dimension for variance +) -> None: + """ + Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update + All in one compiled graph to eliminate Python overhead between ops. + Some of the constants are 0-D CPU tensors to avoid recompilation when values change. + """ + + # Nesterov momentum + momentum = momentum_t.to(stacked_grads.dtype) + momentum_buffer.lerp_(stacked_grads, 1 - momentum) + g = stacked_grads.lerp_(momentum_buffer, momentum) + + # Polar express + X = g.bfloat16() + X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) + if g.size(-2) > g.size(-1): # Tall matrix + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X.mT @ X + B = b * A + c * (A @ A) + X = a * X + X @ B + else: # Wide matrix (original math) + for a, b, c in polar_express_coeffs[:ns_steps]: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + g = X + + # Variance reduction + beta2 = beta2_t.to(g.dtype) + v_mean = g.float().square().mean(dim=red_dim, keepdim=True) + red_dim_size = g.size(red_dim) + v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size + v_norm = v_norm_sq.sqrt() + second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt() + scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() + v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt() + final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10)) + g = g * final_scale.to(g.dtype) + + # Cautious weight decay + parameter update + lr = lr_t.to(g.dtype) + wd = wd_t.to(g.dtype) + mask = (g * stacked_params) >= 0 + stacked_params.sub_(lr * g + lr * wd * stacked_params * mask) + +# ----------------------------------------------------------------------------- +# Single GPU version of the MuonAdamW optimizer. +# Used mostly for reference, debugging and testing. + +class MuonAdamW(torch.optim.Optimizer): + """ + Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version. + + AdamW - Fused AdamW optimizer step. 
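# A small standalone check (illustrative, not from the patch) of the Polar Express
# orthogonalization that muon_step_fused applies above: after the five iterations the
# singular values of a random matrix land near 1 (roughly the 0.5..1.5 band described
# in the background note), i.e. the update direction is approximately orthogonalized.
# Assumes the module is importable as nanochat.optim; the helper name is illustrative.
import torch
from nanochat.optim import polar_express_coeffs

def orthogonalize(G: torch.Tensor, ns_steps: int = 5) -> torch.Tensor:
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT                                    # work in the wide orientation
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    for a, b, c in polar_express_coeffs[:ns_steps]:
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

# G = torch.randn(768, 3072)
# S = torch.linalg.svdvals(orthogonalize(G).float())
# print(S.min().item(), S.max().item())             # most values close to 1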
+ + Muon - MomentUm Orthogonalized by Newton-schulz + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - The Muon optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. + + Arguments: + param_groups: List of dicts, each containing: + - 'params': List of parameters + - 'kind': 'adamw' or 'muon' + - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' + - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' + """ + def __init__(self, param_groups: list[dict]): + super().__init__(param_groups, defaults={}) + # 0-D CPU tensors to avoid torch.compile recompilation when values change + # AdamW tensors + self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + # Muon tensors + self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + def _step_adamw(self, group: dict) -> None: + """ + AdamW update for each param in the group individually. + Lazy init the state, fill in all 0-D tensors, call the fused kernel. + """ + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + state = self.state[p] + + # State init + if not state: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + + # Fill 0-D tensors with current values + self._adamw_step_t.fill_(state['step']) + self._adamw_lr_t.fill_(group['lr']) + self._adamw_beta1_t.fill_(group['betas'][0]) + self._adamw_beta2_t.fill_(group['betas'][1]) + self._adamw_eps_t.fill_(group['eps']) + self._adamw_wd_t.fill_(group['weight_decay']) + + # Fused update: weight_decay -> momentum -> bias_correction -> param_update + adamw_step_fused( + p, grad, exp_avg, exp_avg_sq, + self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t, + self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t, + ) + + def _step_muon(self, group: dict) -> None: + """ + Muon update for all params in the group (stacked for efficiency). + Lazy init the state, fill in all 0-D tensors, call the fused kernel. 
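# The stacking trick _step_muon relies on, shown in isolation (illustrative, not from
# the patch): every param in a Muon group has the same shape, so K params can be
# stacked into one (K, *shape) tensor, updated by a single batched kernel, and written
# back with one foreach copy. A plain SGD-style update stands in for the fused Muon step.
import torch

params = [torch.randn(768, 3072) for _ in range(4)]      # e.g. 4 same-shape MLP weights
grads = [torch.randn_like(p) for p in params]

stacked_params = torch.stack(params)                     # (4, 768, 3072)
stacked_grads = torch.stack(grads)                       # (4, 768, 3072)
stacked_params.add_(stacked_grads, alpha=-0.02)          # placeholder for muon_step_fused

# copy the batched result back into the original, unstacked tensors
torch._foreach_copy_(params, list(stacked_params.unbind(0)))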
+ """ + params: list[Tensor] = group['params'] + if not params: + return + + # Get or create group-level buffers (stored in first param's state for convenience) + p = params[0] + state = self.state[p] + num_params = len(params) + shape, device, dtype = p.shape, p.device, p.dtype + + # Momentum for every individual parameter + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device) + momentum_buffer = state["momentum_buffer"] + + # Second momentum buffer is factored, either per-row or per-column + if "second_momentum_buffer" not in state: + state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1]) + state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) + second_momentum_buffer = state["second_momentum_buffer"] + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Stack grads and params (NOTE: this assumes all params have the same shape) + stacked_grads = torch.stack([p.grad for p in params]) + stacked_params = torch.stack(params) + + # Fill all the 0-D tensors with current values + self._muon_momentum_t.fill_(group["momentum"]) + self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) + self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) + self._muon_wd_t.fill_(group["weight_decay"]) + + # Single fused kernel: momentum -> polar_express -> variance_reduction -> update + muon_step_fused( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + self._muon_momentum_t, + self._muon_lr_t, + self._muon_wd_t, + self._muon_beta2_t, + group["ns_steps"], + red_dim, + ) + + # Copy back to original params + torch._foreach_copy_(params, list(stacked_params.unbind(0))) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + if group['kind'] == 'adamw': + self._step_adamw(group) + elif group['kind'] == 'muon': + self._step_muon(group) + else: + raise ValueError(f"Unknown optimizer kind: {group['kind']}") + +# ----------------------------------------------------------------------------- +# Distributed version of the MuonAdamW optimizer. +# Used for training on multiple GPUs. + +class DistMuonAdamW(torch.optim.Optimizer): + """ + Combined distributed optimizer: Muon for 2D matrix params, AdamW for others. + + See MuonAdamW for the algorithmic details of each optimizer. This class adds + distributed communication to enable multi-GPU training without PyTorch DDP. + + Design Goals: + - Overlap communication with computation (async ops) + - Minimize memory by sharding optimizer states across ranks (ZeRO-2 style) + - Batch small tensors into single comm ops where possible + + Communication Pattern (3-phase async): + We use a 3-phase structure to maximize overlap between communication and compute: + + Phase 1: Launch all async reduce ops + - Kick off all reduce_scatter/all_reduce operations + - Don't wait - let them run in background while we continue + + Phase 2: Wait for reduces, compute updates, launch gathers + - For each group: wait for its reduce, compute the update, launch gather + - By processing groups in order, earlier gathers run while later computes happen + + Phase 3: Wait for gathers, copy back + - Wait for all gathers to complete + - Copy updated params back to original tensors (Muon only) + + AdamW Communication (ZeRO-2 style): + - Small params (<1024 elements): all_reduce gradients, update full param on each rank. 
+ Optimizer state is replicated but these params are tiny (scalars, biases). + - Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update + only that slice, then all_gather the updated slices. Optimizer state (exp_avg, + exp_avg_sq) is sharded - each rank only stores state for its slice. + Requires param.shape[0] divisible by world_size. + + Muon Communication (stacked + chunked): + - All params in a Muon group must have the same shape (caller's responsibility). + - Stack all K params into a single (K, *shape) tensor for efficient comm. + - Divide K params across N ranks: each rank "owns" ceil(K/N) params. + - reduce_scatter the stacked grads so each rank gets its chunk. + - Each rank computes Muon update only for params it owns. + - all_gather the updated params back to all ranks. + - Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk. + - Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm, + then ignore the padding when copying back. + + Buffer Reuse: + - For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the + same buffer as the output for all_gather (stacked_params). This saves memory + since we don't need both buffers simultaneously. + + Arguments: + param_groups: List of dicts, each containing: + - 'params': List of parameters + - 'kind': 'adamw' or 'muon' + - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' + - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' + """ + def __init__(self, param_groups: list[dict]): + super().__init__(param_groups, defaults={}) + # 0-D CPU tensors to avoid torch.compile recompilation when values change + self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + + def _reduce_adamw(self, group: dict, world_size: int) -> dict: + """Launch async reduce ops for AdamW group. Returns info dict with per-param infos.""" + param_infos = {} + for p in group['params']: + grad = p.grad + if p.numel() < 1024: + # Small params: all_reduce (no scatter/gather needed) + future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + param_infos[p] = dict(future=future, grad_slice=grad, is_small=True) + else: + # Large params: reduce_scatter + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False) + return dict(param_infos=param_infos) + + def _reduce_muon(self, group: dict, world_size: int) -> dict: + """Launch async reduce op for Muon group. 
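# A minimal sketch (not from the patch) of the ZeRO-2 style pattern described above for
# one large AdamW parameter, with the optimizer math replaced by a placeholder SGD step.
# Assumes an initialized process group (e.g. launched via torchrun) and p.shape[0]
# divisible by the world size; the real class overlaps these phases across many params.
import torch
import torch.distributed as dist

def sharded_sgd_step(p: torch.Tensor, grad: torch.Tensor, lr: float = 1e-3) -> None:
    rank, world_size = dist.get_rank(), dist.get_world_size()
    rank_size = p.shape[0] // world_size
    # Phase 1: each rank receives the averaged gradient for its slice only (async)
    grad_slice = torch.empty_like(grad[:rank_size])
    reduce_work = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True)
    # Phase 2: wait, then update only the owned slice (optimizer state lives here only)
    reduce_work.wait()
    p_slice = p[rank * rank_size:(rank + 1) * rank_size]
    p_slice.add_(grad_slice, alpha=-lr)
    # Phase 3: replicate the updated slices back to every rank
    gather_work = dist.all_gather_into_tensor(p, p_slice, async_op=True)
    gather_work.wait()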
Returns info dict.""" + params = group['params'] + chunk_size = (len(params) + world_size - 1) // world_size + padded_num_params = chunk_size * world_size + p = params[0] + shape, device, dtype = p.shape, p.device, p.dtype + + # Stack grads and zero-pad to padded_num_params + grad_stack = torch.stack([p.grad for p in params]) + stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device) + stacked_grads[:len(params)].copy_(grad_stack) + if len(params) < padded_num_params: + stacked_grads[len(params):].zero_() + + # Reduce_scatter to get this rank's chunk + grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future() + + return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size) + + def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None: + """Wait for reduce, compute AdamW updates, launch gathers for large params.""" + param_infos = info['param_infos'] + for p in group['params']: + pinfo = param_infos[p] + pinfo['future'].wait() + grad_slice = pinfo['grad_slice'] + state = self.state[p] + + # For small params, operate on full param; for large, operate on slice + if pinfo['is_small']: + p_slice = p + else: + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + + # State init + if not state: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + state['step'] += 1 + + # Fill 0-D tensors and run fused kernel + self._adamw_step_t.fill_(state['step']) + self._adamw_lr_t.fill_(group['lr']) + self._adamw_beta1_t.fill_(group['betas'][0]) + self._adamw_beta2_t.fill_(group['betas'][1]) + self._adamw_eps_t.fill_(group['eps']) + self._adamw_wd_t.fill_(group['weight_decay']) + adamw_step_fused( + p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'], + self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t, + self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t, + ) + + # Large params need all_gather + if not pinfo['is_small']: + future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future() + gather_list.append(dict(future=future, params=None)) + + def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None: + """Wait for reduce, compute Muon updates, launch gather.""" + info['future'].wait() + params = group['params'] + chunk_size = info['chunk_size'] + grad_chunk = info['grad_chunk'] + p = params[0] + shape, device, dtype = p.shape, p.device, p.dtype + + # How many params does this rank own? 
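# Worked example (illustrative): the chunking/ownership arithmetic used by the Muon
# path, with K=12 same-shape matrices sharded across N=8 ranks. Each rank owns
# ceil(K/N)=2 params, the stacked comm buffers are zero-padded from 12 to 16 rows,
# and the last two ranks own nothing.
K, N = 12, 8
chunk_size = (K + N - 1) // N                                    # 2
padded_num_params = chunk_size * N                               # 16 (4 rows of zero padding)
owned = [min(chunk_size, max(0, K - r * chunk_size)) for r in range(N)]
print(owned)                                                     # [2, 2, 2, 2, 2, 2, 0, 0]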
+ start_idx = rank * chunk_size + num_owned = min(chunk_size, max(0, len(params) - start_idx)) + + # Get or create group-level state + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device) + if "second_momentum_buffer" not in state: + state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1]) + state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device) + red_dim = -1 if shape[-2] >= shape[-1] else -2 + + # Build output buffer for all_gather + updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device) + + if num_owned > 0: + owned_params = [params[start_idx + i] for i in range(num_owned)] + stacked_owned = torch.stack(owned_params) + + # Fill 0-D tensors and run fused kernel + self._muon_momentum_t.fill_(group["momentum"]) + self._muon_beta2_t.fill_(group["beta2"]) + self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) + self._muon_wd_t.fill_(group["weight_decay"]) + muon_step_fused( + grad_chunk[:num_owned], stacked_owned, + state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned], + self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t, + group["ns_steps"], red_dim, + ) + updated_params[:num_owned].copy_(stacked_owned) + + if num_owned < chunk_size: + updated_params[num_owned:].zero_() + + # Reuse stacked_grads buffer for all_gather output + stacked_params = info["stacked_grads"] + future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future() + gather_list.append(dict(future=future, stacked_params=stacked_params, params=params)) + + def _finish_gathers(self, gather_list: list) -> None: + """Wait for all gathers and copy Muon params back.""" + for info in gather_list: + info["future"].wait() + if info["params"] is not None: + # Muon: copy from stacked buffer back to individual params + torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0))) + + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + + # Phase 1: launch all async reduce ops + reduce_infos: list[dict] = [] + for group in self.param_groups: + if group['kind'] == 'adamw': + reduce_infos.append(self._reduce_adamw(group, world_size)) + elif group['kind'] == 'muon': + reduce_infos.append(self._reduce_muon(group, world_size)) + else: + raise ValueError(f"Unknown optimizer kind: {group['kind']}") + + # Phase 2: wait for reduces, compute updates, launch gathers + gather_list: list[dict] = [] + for group, info in zip(self.param_groups, reduce_infos): + if group['kind'] == 'adamw': + self._compute_adamw(group, info, gather_list, rank, world_size) + elif group['kind'] == 'muon': + self._compute_muon(group, info, gather_list, rank) + else: + raise ValueError(f"Unknown optimizer kind: {group['kind']}") + + # Phase 3: wait for gathers, copy back + self._finish_gathers(gather_list) diff --git a/scripts/base_train.py b/scripts/base_train.py index 4fa8fca..4bce6cd 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -211,9 +211,9 @@ print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") # ----------------------------------------------------------------------------- -# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) +# Initialize 
the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) adam_betas = (args.adam_beta1, args.adam_beta2) -optimizers = model.setup_optimizers( +optimizer = model.setup_optimizer( unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, matrix_lr=args.matrix_lr * batch_lr_scale, @@ -221,12 +221,10 @@ optimizers = model.setup_optimizers( adam_betas=adam_betas, scalar_lr=args.scalar_lr * batch_lr_scale, ) -adamw_optimizer, muon_optimizer = optimizers if resuming: - for opt, dat in zip(optimizers, optimizer_data): - opt.load_state_dict(dat) - del optimizer_data # free up the memory + optimizer.load_state_dict(optimizer_data) + del optimizer_data # ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val @@ -344,7 +342,7 @@ while True: checkpoint_dir, step, orig_model.state_dict(), # model parameters - [opt.state_dict() for opt in optimizers], # optimizer states + optimizer.state_dict(), # optimizer state { # metadata saved as json "step": step, "val_bpb": val_bpb, # loss at last step @@ -378,18 +376,16 @@ while True: loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward - # step the optimizers + # step the optimizer lrm = get_lr_multiplier(step) - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lrm muon_momentum = get_muon_momentum(step) muon_weight_decay = get_weight_decay(step) - for group in muon_optimizer.param_groups: - group["momentum"] = muon_momentum - group["weight_decay"] = muon_weight_decay - for opt in optimizers: - opt.step() + for group in optimizer.param_groups: + group["lr"] = group["initial_lr"] * lrm + if group['kind'] == 'muon': + group["momentum"] = muon_momentum + group["weight_decay"] = muon_weight_decay + optimizer.step() model.zero_grad(set_to_none=True) train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point synchronize() diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index eb8e48e..695c008 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -201,7 +201,7 @@ def run_gsm8k_eval(task, tokenizer, engine, # Training loop # Init the optimizer -optimizers = model.setup_optimizers( +optimizer = model.setup_optimizer( unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, @@ -209,10 +209,9 @@ optimizers = model.setup_optimizers( ) # Set the initial learning rate as a fraction of the base learning rate -for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["lr"] * args.init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later +for group in optimizer.param_groups: + group["lr"] = group["lr"] * args.init_lr_frac + group["initial_lr"] = group["lr"] # Learning rate scheduler: simple rampdown to zero over num_steps def get_lr_multiplier(it): @@ -305,11 +304,9 @@ for step in range(num_steps): # Update the model parameters lrm = get_lr_multiplier(step) - for opt in optimizers: # first set the learning rate - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lrm - for opt in optimizers: # then step the optimizers - opt.step() + for group in optimizer.param_groups: + group["lr"] = group["initial_lr"] * lrm + optimizer.step() model.zero_grad(set_to_none=True) wandb_run.log({ "step": step, diff --git 
a/scripts/chat_sft.py b/scripts/chat_sft.py index 9277cf9..c0471c4 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -150,17 +150,16 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_bat # ----------------------------------------------------------------------------- # Initialize the Optimizer -optimizers = model.setup_optimizers( +optimizer = model.setup_optimizer( unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay, ) # Set the initial learning rate as a fraction of the base learning rate -for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["lr"] * args.init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later +for group in optimizer.param_groups: + group["lr"] = group["lr"] * args.init_lr_frac + group["initial_lr"] = group["lr"] # ----------------------------------------------------------------------------- # Training loop @@ -230,13 +229,11 @@ for step in range(num_iterations): # learning rate scheduler lrm = get_lr_multiplier(step) - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lrm + for group in optimizer.param_groups: + group["lr"] = group["initial_lr"] * lrm - # step the optimizers - for opt in optimizers: - opt.step() + # step the optimizer + optimizer.step() model.zero_grad(set_to_none=True) # logging diff --git a/scripts/mid_train.py b/scripts/mid_train.py index c127c94..ebe9cd5 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -93,14 +93,12 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") token_bytes = get_token_bytes(device=device) -# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) -optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) -adamw_optimizer, muon_optimizer = optimizers +# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) # Override the initial learning rate as a fraction of the base learning rate -for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["lr"] * args.init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later +for group in optimizer.param_groups: + group["lr"] = group["lr"] * args.init_lr_frac + group["initial_lr"] = group["lr"] # Midtraining data mixture and DataLoader base_dir = get_base_dir() @@ -274,7 +272,7 @@ while True: checkpoint_dir, step, orig_model.state_dict(), - [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly + optimizer.state_dict(), { "step": step, "val_bpb": val_bpb, # loss at last step @@ -306,16 +304,14 @@ while True: loss.backward() x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward progress = max(progress, approx_progress) # only increase progress monotonically - # step the optimizers + # step the optimizer lrm = get_lr_multiplier(progress) - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lrm muon_momentum = 
get_muon_momentum(step)
-    for group in muon_optimizer.param_groups:
-        group["momentum"] = muon_momentum
-    for opt in optimizers:
-        opt.step()
+    for group in optimizer.param_groups:
+        group["lr"] = group["initial_lr"] * lrm
+        if group['kind'] == 'muon':
+            group["momentum"] = muon_momentum
+    optimizer.step()
     model.zero_grad(set_to_none=True)
     synchronize()
     t1 = time.time()
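Usage sketch (illustrative, not part of the patch): the new single-optimizer flow end to end, assuming `nanochat.optim` is importable and using stand-in shapes and schedule values. It mirrors the tagged `param_groups` built in `gpt.setup_optimizer` and the per-group LR/momentum scheduling now done in the training scripts.

import torch
from nanochat.optim import MuonAdamW

# Tagged param groups: AdamW for embeddings/scalars, Muon for same-shape 2D matrices
embed = [torch.nn.Parameter(torch.randn(32768, 768))]
matrices = [torch.nn.Parameter(torch.randn(768, 3072)) for _ in range(4)]
param_groups = [
    dict(kind='adamw', params=embed, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
    dict(kind='muon', params=matrices, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0),
]
optimizer = MuonAdamW(param_groups)            # DistMuonAdamW when running distributed
for group in optimizer.param_groups:
    group["initial_lr"] = group["lr"]

# Inside the training loop: one LR multiplier for all groups, Muon-only momentum/decay schedule
lrm, muon_momentum, muon_weight_decay = 1.0, 0.95, 0.0
for group in optimizer.param_groups:
    group["lr"] = group["initial_lr"] * lrm
    if group['kind'] == 'muon':
        group["momentum"] = muon_momentum
        group["weight_decay"] = muon_weight_decay
# ... loss.backward(); optimizer.step(); model.zero_grad(set_to_none=True)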