From 9d4c9b786d885d4816b3e27d949d745ff280d267 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Mon, 5 Jan 2026 00:38:09 +0000
Subject: [PATCH] many small fixes to base_train: reporting ETA, allowing some
 additional kwarg flexibility, making sure we don't crash when e.g. depth = 11
 - we now calculate the closest num_heads that works

---
 scripts/base_train.py | 35 +++++++++++++++++++++++++++--------
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/scripts/base_train.py b/scripts/base_train.py
index 6118ad6..2390b68 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -57,11 +57,11 @@ parser.add_argument("--warmdown_ratio", type=float, default=0.2, help="ratio of
 parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
 parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
 # Evaluation
-parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps")
+parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
 parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
 parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
 parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps")
+parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
 parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
 # Output
 parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
@@ -90,7 +90,15 @@ print0(f"Vocab size: {vocab_size:,}")
 # Model kwargs are derived from the desired depth of the model
 num_layers = args.depth
 model_dim = args.depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
-num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
+def find_num_heads(model_dim, target_head_dim=128):
+    # Find num_heads that divides model_dim evenly, with head_dim closest to target.
+    ideal = max(1, round(model_dim / target_head_dim))
+    for offset in range(model_dim):
+        for candidate in [ideal + offset, ideal - offset]:
+            if candidate > 0 and model_dim % candidate == 0:
+                return candidate
+    return 1
+num_heads = find_num_heads(model_dim)
 num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
@@ -202,6 +210,7 @@ def get_muon_momentum(it):
 
 if not resuming:
     step = 0
+    val_bpb = None # will be set if eval_every > 0
     min_val_bpb = float("inf")
     smooth_train_loss = 0 # EMA of training loss
     total_training_time = 0 # total wall-clock time of training
@@ -220,7 +229,7 @@ while True:
     flops_so_far = num_flops_per_token * args.total_batch_size * step
 
     # once in a while: evaluate the val bpb (all ranks participate)
-    if last_step or step % args.eval_every == 0:
+    if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
         model.eval()
         val_loader = build_val_loader()
         eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
@@ -255,7 +264,7 @@ while True:
 
     # once in a while: sample from the model (only on master process)
     # use the original uncompiled model because the inputs keep changing shape
-    if master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
+    if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
         model.eval()
         prompts = [
             "The capital of France is",
@@ -347,7 +356,16 @@ while True:
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
     print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    # Calculate ETA based on average time per step (excluding first 10 steps)
+    steps_done = step - 10
+    if steps_done > 0:
+        avg_time_per_step = total_training_time / steps_done
+        remaining_steps = num_iterations - step
+        eta_seconds = remaining_steps * avg_time_per_step
+        eta_str = f" | eta: {eta_seconds/60:.1f}m"
+    else:
+        eta_str = ""
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
     if step % 100 == 0:
         log_data = {
             "step": step,
@@ -369,7 +387,8 @@ while True:
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
-print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
+if val_bpb is not None:
+    print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
 
 # Log to report
 from nanochat.report import get_report
@@ -387,7 +406,7 @@ get_report().log(section="Base model training", data=[
         "final_lr_frac": args.final_lr_frac,
     },
     { # stats about training outcomes
-        "Minimum validation bpb": min_val_bpb,
+        "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
         "Final validation bpb": val_bpb,
         "CORE metric estimate": results.get("core_metric", None),
         "MFU %": f"{mfu:.2f}%",
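
Note on the num_heads fix: for some depths, model_dim = depth * 64 is not a multiple of 128, and the old ceil-division formula could return a head count that does not divide model_dim (depth = 11 gives model_dim = 704, where ceil(704 / 128) = 6 does not divide 704). Below is a minimal standalone sketch of the find_num_heads logic from the patch, with an illustrative walkthrough for depth = 11; the demo harness and printed values are added here for illustration only and are not part of base_train.py.

    # Standalone sketch of the patch's find_num_heads logic; the depth = 11
    # demo below is illustrative and not part of base_train.py.
    def find_num_heads(model_dim, target_head_dim=128):
        # Walk outward from the ideal head count until a candidate divides
        # model_dim evenly, keeping head_dim as close to the target as possible.
        ideal = max(1, round(model_dim / target_head_dim))
        for offset in range(model_dim):
            for candidate in [ideal + offset, ideal - offset]:
                if candidate > 0 and model_dim % candidate == 0:
                    return candidate
        return 1

    if __name__ == "__main__":
        depth = 11
        model_dim = depth * 64  # 704, not a multiple of 128
        num_heads = find_num_heads(model_dim)
        # ideal = round(704 / 128) = 6; 6, 7 and 5 do not divide 704, so the
        # search settles on 8 heads with head_dim = 704 // 8 = 88.
        print(f"depth={depth} model_dim={model_dim} num_heads={num_heads} head_dim={model_dim // num_heads}")

Walking outward from the ideal count guarantees an exact divisor of model_dim while keeping head_dim near 128. The ETA string added in the same patch is simply remaining_steps * (total_training_time / (step - 10)), i.e. it extrapolates from the average step time measured after the first 10 untimed warmup steps.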