mirror of https://github.com/karpathy/nanochat.git
simplify, clarify and slightly tune model initialization. should be very slightly better possibly, but certainly a lot clearer
@@ -112,10 +112,11 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
 # Create a new model with random weights
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
 with torch.device("meta"):
+    # All tensors are created as meta tensors (they have shape/dtype but no data)
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
-model.to_empty(device=device)
-model.init_weights()
+model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
+model.init_weights() # All tensors get initialized
 
 # If we are resuming, overwrite the model parameters with those of the checkpoint
 base_dir = get_base_dir()
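For context, the pattern the new comments document is PyTorch's three-step meta-device initialization: build the module graph on the meta device (no allocation), materialize real storage with to_empty, then run the actual initializer. Below is a minimal self-contained sketch of that pattern; TinyModel, its init_weights, and ckpt_path are hypothetical stand-ins for illustration, not nanochat's actual GPT/GPTConfig code.

import torch
import torch.nn as nn

# Hypothetical stand-in for nanochat's GPT (assumption, for illustration only).
class TinyModel(nn.Module):
    def __init__(self, vocab_size=256, n_embd=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, n_embd)
        self.proj = nn.Linear(n_embd, vocab_size, bias=False)

    def init_weights(self):
        # Fill the (currently uninitialized) storages with the intended values.
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)

device = "cuda" if torch.cuda.is_available() else "cpu"

with torch.device("meta"):
    # Tensors created here are meta tensors: shape/dtype but no data,
    # so construction is instant and allocates no memory.
    model = TinyModel()

model.to_empty(device=device)  # allocate real storage on the target device (contents are garbage)
model.init_weights()           # overwrite the garbage with the actual initialization

# If resuming, a checkpoint would then overwrite these parameters, e.g.
# model.load_state_dict(torch.load(ckpt_path, map_location=device))  # ckpt_path is hypothetical

The payoff of this ordering is that a large model is never initialized on the CPU only to be thrown away: construction on the meta device is free, and the one real initialization happens directly on the target device (or is skipped entirely in favor of a checkpoint load).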