add support for CPU and for MPS. I had to change a few cosmetic things. I also discovered what I think is a bit of a bug, where I was casting wte to bfloat16 in the wrong place (the model init) instead of in init_weights

karpathy
2025-10-16 10:04:43 -07:00
parent 722da4f543
commit 306bc380ab
6 changed files with 68 additions and 46 deletions
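
For context on the wte cast: a minimal before/after sketch, assuming a GPT-style module whose init_weights() (re)initializes parameters. The class, sizes, and init scheme here are illustrative assumptions, not the repo's exact code.

import torch
import torch.nn as nn

class GPT(nn.Module):
    # Illustrative stand-in; the real model and its init are more involved.
    def __init__(self, vocab_size=50304, n_embd=768):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        # Before (the bug): casting wte here, in the model init, splits the
        # cast from the initialization, so a code path that re-creates or
        # re-initializes the weights can silently leave wte in the wrong dtype.
        # self.wte.to(dtype=torch.bfloat16)
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.wte.weight, mean=0.0, std=0.02)
        # After (the fix): cast where the weights are initialized, so every
        # (re-)initialization leaves the embedding in bfloat16.
        self.wte.to(dtype=torch.bfloat16)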


@@ -31,7 +31,7 @@ print_banner()
 # User settings
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 # Runtime
-device_type = "cuda" # cuda|cpu
+device_type = "cuda" # cuda|cpu|mps
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
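
The script takes device_type as a plain user setting; purely for illustration, a hedged sketch of the kind of auto-detection fallback a caller could layer on top. This chain is an assumption, not code from the diff.

import torch

# Hypothetical fallback chain: prefer CUDA, then Apple's MPS, else CPU.
if torch.cuda.is_available():
    device_type = "cuda"
elif torch.backends.mps.is_available():
    device_type = "mps"
else:
    device_type = "cpu"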
@@ -64,7 +64,8 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # Compute init
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
+dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
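
To see how these pieces compose, a minimal sketch of one step under the same pattern as the diff: bf16 autocast on CUDA, fp32 elsewhere, and CUDA-only utilities that degrade to no-ops so the loop needs no per-device branching. The toy model and input are placeholders, not the repo's training step.

import torch
import torch.nn as nn

device_type = "cuda" if torch.cuda.is_available() else "cpu"  # or "mps"
device = torch.device(device_type)

dtype = torch.bfloat16 if device_type == "cuda" else torch.float32
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

model = nn.Linear(16, 16, device=device)  # stand-in for the real Transformer
x = torch.randn(4, 16, device=device)
with autocast_ctx:
    y = model(x)
synchronize()  # flush queued CUDA work so memory/timing stats are meaningful
print(f"peak memory: {get_max_memory():,} bytes")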