diff --git a/dev/LOG.md b/dev/LOG.md
index 449cd7f..13fc08e 100644
--- a/dev/LOG.md
+++ b/dev/LOG.md
@@ -4,6 +4,65 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
 ---
 
+## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay
+
+Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
+
+### Changes Made
+
+**1. Polar Express Orthogonalization**
+- Replaced Newton-Schulz iteration with the "Polar Express Sign Method" from [arxiv.org/pdf/2505.16932](https://arxiv.org/pdf/2505.16932)
+- Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
+- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
+- **Result:** No noticeable difference in training, but keeping the new Polar Express as the default.
+
+**2. Variance Reduction (NorMuon-style)**
+- Added a low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
+- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
+- Normalizes updates based on a running per-row/col variance estimate (beta2=0.95)
+- Memory overhead: ~1/max(rows, cols) of each param's size, negligible
+- **Result:** Led to a very small improvement; kept and enabled by default.
+
+**3. Cautious Weight Decay**
+- Only decays weights where `update * weight >= 0` (same sign), from [arxiv.org/abs/2411.16085](https://arxiv.org/abs/2411.16085)
+- Standard WD always pulls toward zero; cautious WD skips decay when the gradient is pushing the weight away from zero (see the sketch at the end of this section)
+- **Implementation note:** Had to inline the logic rather than use a separate `@torch.compile` function. Passing changing float values (like `weight_decay` during scheduling) as function arguments triggers recompilation. Reading from `group["weight_decay"]` inside the step avoids this.
+- **Result:** Solid improvements; the cautious version in particular was better than standard weight decay.
+- Now defaults to ON for Muon via the `weight_decay` param. AdamW remains hardcoded to 0.0 weight decay; might try to re-tune this later.
+
+**4. Weight decay schedule**
+- Added a linear schedule to weight decay, on by default, from 1.0 to 0.0 (i.e. start with the maximum weight decay at the beginning of training, then ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule, but it is implemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule.)
+
+### Weight Decay Scaling Experiments
+
+Swept weight decay values at d8, d12, d16, d20 to find optimal values and a scaling law.
+
+**Optimal Values Found:**
+| Depth | Width (channels) | Optimal WD |
+|-------|------------------|------------|
+| d8    | 512              | ~0.40      |
+| d12   | 768              | ~0.22      |
+| d16   | 1024             | ~0.10      |
+| d20   | 1280             | ~0.08      |
+
+**Scaling Law:**
+- Fit a power law `WD = k / channels^α` in log-log space
+- Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²
+
+**Practical Formula:**
+```
+WD_target = WD_reference × (d_reference / d_target)²
+```
+Example: if the d12 optimum is 0.22, then the d20 optimum ≈ 0.22 × (12/20)² ≈ 0.08
+
+**Reference:** The Moonlight paper uses a fixed WD=0.1 for their 15B MoE model. Our experiments indicated a scaling law where the optimal WD changes with depth, so we go with the empirical scaling law instead.
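+
+To make the new update rule concrete, here is a minimal sketch tying the cautious decay and the depth scaling together (illustrative only, not the actual optimizer code; `p` is a 2D weight tensor, `g` its orthogonalized update, `lr` the Muon learning rate, and the helper names are made up for this log):
+
+```
+def scaled_weight_decay(depth, wd_ref=0.2, depth_ref=12):
+    # WD ∝ 1/width², and width ∝ depth at our constant aspect ratio
+    return wd_ref * (depth_ref / depth) ** 2
+
+def cautious_step(p, g, lr, wd):
+    # decay only where the update and the weight agree in sign, i.e. skip the
+    # decay wherever the gradient is pushing the weight away from zero
+    mask = (g * p) >= 0
+    p.sub_(lr * g + lr * wd * p * mask)
+
+scaled_weight_decay(20, wd_ref=0.22)  # ≈ 0.08, matching the d20 optimum above
+```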
+
+### Summary
+
+Muon now uses Polar Express orthogonalization, Adafactor-style variance reduction, and cautious weight decay with a schedule that ramps linearly to zero. All of these changes follow the modded-nanogpt repo, but each was also validated piece by piece to yield an improvement in nanochat, with the exception of the Polar Express change, which was in the noise. Weight decay is on by default and configurable with `--weight_decay`, using simply 0.2 at d12 and ∝ 1/width² scaling. The meaning of the `--weight_decay` kwarg therefore changes as of this commit: it used to configure standard weight decay for AdamW, and it is now used exclusively by Muon (AdamW is hardcoded to 0.0) and scaled based on depth.
+
+---
+
 ## 2026-01-08: exp_grad_clip - Gradient Clipping
 
 **Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
@@ -18,6 +77,4 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
 **Observartion:** modded-nanogpt does not appear to clip either right now.
 
-**Recommendation:** Disable by default (`--grad_clip=0.0`). The code naturally produces well-behaved gradients.
-
----
+**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This also improves MFU a bit because we no longer have to calculate and sync grad norms.
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 478f687..2ffdc50 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -260,11 +260,11 @@ class GPT(nn.Module):
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
         ]
-        adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=weight_decay)
+        adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
         AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
         adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
         # Create the Muon optimizer for the linear layers
-        muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
+        muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
         MuonFactory = DistMuon if ddp else Muon
         muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
         # Combine them the two optimizers into one list
diff --git a/nanochat/muon.py b/nanochat/muon.py
index d916103..7ae5ffd 100644
--- a/nanochat/muon.py
+++ b/nanochat/muon.py
@@ -1,11 +1,50 @@
 """
-Muon optimizer from Keller et al.
-Also a lot of borrowing of ideas from modded-nanogpt.
+Muon optimizer adapted (simplified) from modded-nanogpt.
+https://github.com/KellerJordan/modded-nanogpt
 """
 import torch
 from torch import Tensor
 import torch.distributed as dist
 
+# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
+# From https://arxiv.org/pdf/2505.16932
+polar_express_coeffs = [
+    (8.156554524902461, -22.48329292557795, 15.878769915207462),
+    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
+    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
+    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
+    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
+]
+
+
+@torch.compile
+def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
+    """
+    Polar Express Sign Method for orthogonalization.
+    https://arxiv.org/pdf/2505.16932
+    by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
+
+    Alternative to Newton-Schulz iteration with potentially better convergence properties.
+    """
+    assert G.ndim >= 2
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1 (with 2% safety factor)
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
+
+    # Perform the iterations (cap at available coefficients)
+    for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
+        A = X @ X.mT
+        B = b * A + c * (A @ A)
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+
 @torch.compile
 def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
     """
@@ -35,6 +74,40 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
         X = X.mT
     return X
 
+
+@torch.compile
+def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
+    """
+    NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
+    https://arxiv.org/pdf/2510.05491
+
+    Normalizes updates based on a running estimate of per-row (or per-column) variance.
+    The reduction dimension is determined by the shape of second_momentum_buffer.
+    """
+    # Determine reduction dimension from buffer shape
+    red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
+
+    # Compute per-row/col mean of squared values
+    v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
+    red_dim_size = v.size(red_dim)
+
+    # Compute current norm
+    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
+    v_norm = v_norm_sq.sqrt()
+
+    # Update second momentum buffer (EMA of variance)
+    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
+
+    # Compute scaling factor from second momentum
+    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
+    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
+    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
+
+    # Final scale preserves overall norm while adjusting per-row/col
+    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
+    return v.mul(final_scale.to(v.dtype))
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -56,9 +129,11 @@ class Muon(torch.optim.Optimizer):
         momentum: The momentum used by the internal SGD.
         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iteration steps to use.
+        beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
+        weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
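+
+    Example (an illustrative sketch, not taken from the repo's training script; assumes
+    `model` and `loss` come from an existing training setup, and every parameter handed
+    to Muon is a 2D weight matrix):
+
+        matrix_params = [p for p in model.parameters() if p.ndim == 2]
+        opt = Muon(matrix_params, lr=0.02, momentum=0.95, weight_decay=0.2)
+        loss.backward()
+        opt.step()
+        opt.zero_grad(set_to_none=True)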
""" - def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): - defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) + def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0): + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) params: list[Tensor] = [*params] param_groups = [] for size in {p.numel() for p in params}: @@ -79,13 +154,29 @@ class Muon(torch.optim.Optimizer): buf: Tensor = state["momentum_buffer"] buf.lerp_(g, 1 - group["momentum"]) g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf - g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5) + g = zeropower_via_polar_express(g, steps=group["ns_steps"]) + # Variance reduction (NorMuon-style) + if group["beta2"] is not None: + if "second_momentum_buffer" not in state: + # Buffer shape determines reduction dim: reduce along larger dimension + if p.size(-2) >= p.size(-1): + state["second_momentum_buffer"] = torch.zeros_like(g[..., :1]) + else: + state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :]) + g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"]) + # Parameter update with cautious weight decay + effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5 + wd = group["weight_decay"] + if wd != 0: + mask = (g * p) >= 0 + p.sub_(effective_lr * g + effective_lr * wd * p * mask) + else: + p.sub_(effective_lr * g) class DistMuon(torch.optim.Optimizer): """ - Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz, + Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express, finally apply aspect-ratio scaled step. Performs its own distributed synchronization: - reduce_scatter(AVG) for gradient averaging - all_gather to replicate updated weights @@ -102,11 +193,13 @@ class DistMuon(torch.optim.Optimizer): lr: learning rate momentum: momentum coefficient in [0,1) nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf - ns_steps: number of Newton–Schulz iterations for the orthogonalization + ns_steps: number of Newton-Schulz iterations for the orthogonalization + beta2: decay rate for second moment (variance) estimate. Set to None to disable. + weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree. 
""" def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, - nesterov: bool = True, ns_steps: int = 5): - defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) + nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0): + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay) params = list(params) assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" rank = dist.get_rank() @@ -173,9 +266,24 @@ class DistMuon(torch.optim.Optimizer): buf: Tensor = state["momentum_buffer"] buf.lerp_(g, 1.0 - group["momentum"]) g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf - g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) - p.add_(g, alpha=-group["lr"] * scale) + g = zeropower_via_polar_express(g, steps=group["ns_steps"]) + # Variance reduction (NorMuon-style) + if group["beta2"] is not None: + if "second_momentum_buffer" not in state: + # Buffer shape determines reduction dim: reduce along larger dimension + if p.size(-2) >= p.size(-1): + state["second_momentum_buffer"] = torch.zeros_like(g[..., :1]) + else: + state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :]) + g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"]) + # Parameter update with cautious weight decay + effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) + wd = group["weight_decay"] + if wd != 0: + mask = (g * p) >= 0 + p.sub_(effective_lr * g + effective_lr * wd * p * mask) + else: + p.sub_(effective_lr * g) # Replicate updated parameters to all ranks ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer ag_output = params[base_i:base_i + world_size] diff --git a/scripts/base_train.py b/scripts/base_train.py index e3df0f0..84d44bf 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -51,7 +51,7 @@ parser.add_argument("--device_batch_size", type=int, default=32, help="per-devic parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens") parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") +parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") @@ -129,6 +129,11 @@ if batch_ratio != 1.0: batch_lr_scale = batch_ratio ** 0.5 print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})") +# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) +weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 +if args.depth != 12: + print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth 
{args.depth}") + # ----------------------------------------------------------------------------- # Initialize the Model @@ -188,7 +193,7 @@ optimizers = model.setup_optimizers( unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, matrix_lr=args.matrix_lr * batch_lr_scale, - weight_decay=args.weight_decay, + weight_decay=weight_decay_scaled, adam_betas=adam_betas, ) adamw_optimizer, muon_optimizer = optimizers @@ -227,6 +232,10 @@ def get_muon_momentum(it): momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum +# Weight decay scheduler for Muon optimizer (linear to zero over the course of training) +def get_weight_decay(it): + return weight_decay_scaled * (1 - it / num_iterations) + # ----------------------------------------------------------------------------- # Loop state (variables updated by the training loop) @@ -257,7 +266,7 @@ while True: eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) with autocast_ctx: val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) - print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") + print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") if val_bpb < min_val_bpb: min_val_bpb = val_bpb wandb_run.log({ @@ -351,8 +360,10 @@ while True: for group in opt.param_groups: group["lr"] = group["initial_lr"] * lrm muon_momentum = get_muon_momentum(step) + muon_weight_decay = get_weight_decay(step) for group in muon_optimizer.param_groups: group["momentum"] = muon_momentum + group["weight_decay"] = muon_weight_decay for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) @@ -402,7 +413,7 @@ while True: print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") if val_bpb is not None: - print0(f"Minimum validation bpb: {min_val_bpb:.4f}") + print0(f"Minimum validation bpb: {min_val_bpb:.6f}") # Log to report from nanochat.report import get_report