diff --git a/nanochat/adamw.py b/nanochat/adamw.py
index 8816057..0b97ae2 100644
--- a/nanochat/adamw.py
+++ b/nanochat/adamw.py
@@ -68,8 +68,8 @@ class DistAdamW(torch.optim.Optimizer):
                 bias1 = 1 - beta1 ** t
                 bias2 = 1 - beta2 ** t
                 # compute step
-                denom = exp_avg_sq.sqrt().add_(eps)
-                step_size = lr * (torch.sqrt(bias2) / bias1)
+                denom = (exp_avg_sq / bias2).sqrt().add_(eps)
+                step_size = lr / bias1
                 update = exp_avg.div(denom).mul_(step_size)
                 p_slice.add_(other=update, alpha=-1.0)
                 idx += 1
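
A minimal standalone sketch (not part of the patch, scalar values and variable names chosen for illustration) of why the two forms differ: the old code folds both bias corrections into step_size, which effectively scales eps by 1/sqrt(bias2); the new code bias-corrects exp_avg_sq before adding eps, matching the standard lr * m_hat / (sqrt(v_hat) + eps) AdamW update.

```python
import torch

# Hypothetical hyperparameters and moment estimates for illustration only.
lr, eps, beta1, beta2, t = 1e-3, 1e-8, 0.9, 0.95, 10
exp_avg = torch.tensor(1e-7)       # biased first moment m_t
exp_avg_sq = torch.tensor(1e-14)   # biased second moment v_t
bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t

# Before the patch: corrections folded into step_size, so eps is effectively
# divided by sqrt(bias2) relative to the corrected second moment.
denom_before = exp_avg_sq.sqrt() + eps
update_before = lr * (bias2 ** 0.5 / bias1) * exp_avg / denom_before

# After the patch: correct exp_avg_sq first, then add eps.
denom_after = (exp_avg_sq / bias2).sqrt() + eps
update_after = (lr / bias1) * exp_avg / denom_after

# Reference: the textbook update with explicit bias-corrected moments.
m_hat, v_hat = exp_avg / bias1, exp_avg_sq / bias2
update_ref = lr * m_hat / (v_hat.sqrt() + eps)

print(update_before.item(), update_after.item(), update_ref.item())
```

The updates agree whenever eps is negligible next to sqrt(v_hat); the patch only changes how eps enters the denominator, bringing it in line with the conventional AdamW formulation.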