simplify, clarify, and slightly tune model initialization. possibly a very slight improvement, but certainly a lot clearer

Andrej Karpathy
2026-01-01 21:14:26 +00:00
parent 10231dfb40
commit 48abd7d85f
2 changed files with 36 additions and 24 deletions

@@ -112,10 +112,11 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
 # Create a new model with random weights
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
 with torch.device("meta"):
+    # All tensors are created as meta tensors (they have shape/dtype but no data)
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
-model.to_empty(device=device)
-model.init_weights()
+model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
+model.init_weights() # All tensors get initialized
 # If we are resuming, overwrite the model parameters with those of the checkpoint
 base_dir = get_base_dir()
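
The comments added in this commit describe the standard PyTorch three-step deferred-initialization pattern: construct the module under the meta device (shapes and dtypes only, no storage), allocate real storage with to_empty(), then run the custom initializer. For large models this skips a full CPU allocation plus host-to-device copy. Below is a minimal, self-contained sketch of the same pattern; TinyNet and its init_weights() are hypothetical stand-ins for the repo's GPT class and are not part of this commit.

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical stand-in for GPT; any nn.Module works the same way.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)

    def init_weights(self):
        # Custom init, analogous to GPT.init_weights() in the diff above
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.zeros_(self.fc.bias)

device = "cuda" if torch.cuda.is_available() else "cpu"

with torch.device("meta"):
    model = TinyNet()  # parameters are meta tensors: shape/dtype only, no data

model.to_empty(device=device)  # allocate real (uninitialized) storage on the target device
model.init_weights()           # fill in the actual values

out = model(torch.randn(2, 16, device=device))  # sanity check: forward pass works

Keeping init_weights() as a separate method, rather than initializing in __init__, is what makes the meta-device trick work: construction never touches real memory, so no effort is wasted initializing weights that to_empty() would immediately discard.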