delete grad_clip. appears to not be necessary at all. not only was it buggy because the clipping happened per gpu before grad synchronization, but it costs ~2% MFU, and it also doesn't even help. I tried deleting it a while ago and back then it did help. So I'm guessing that some hyperparameter tuning obviated the reason for it since then

This commit is contained in:
Andrej Karpathy
2026-01-08 02:16:50 +00:00
parent e8c30c3b19
commit 061f83c152
2 changed files with 24 additions and 10 deletions

View File

@@ -55,7 +55,6 @@ parser.add_argument("--weight_decay", type=float, default=0.0, help="weight deca
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping value (0.0 = disabled)")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
@@ -346,11 +345,6 @@ while True:
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# gradient clipping
grad_clip_enabled = args.grad_clip > 0.0
if grad_clip_enabled:
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), args.grad_clip)
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
@@ -378,7 +372,6 @@ while True:
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
# Calculate ETA based on average time per step (excluding first 10 steps)
steps_done = step - 10
if steps_done > 0:
@@ -388,7 +381,7 @@ while True:
eta_str = f" | eta: {eta_seconds/60:.1f}m"
else:
eta_str = ""
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,
@@ -400,8 +393,6 @@ while True:
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
}
if grad_clip_enabled:
log_data["train/grad_norm"] = grad_norm
wandb_run.log(log_data)
# state update