add support for CPU and for MPS. I had to change a few cosmetic things. I also discovered what I think is a small bug, where I was casting wte to bfloat16 in the wrong place (the model init) instead of in init_weights
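A minimal sketch of what relocating that cast might look like, assuming a GPT-style module with a wte embedding and an init_weights method as named in the commit message; the surrounding module layout and sizes are illustrative, not taken from the repo:

import torch
import torch.nn as nn

class GPT(nn.Module):
    def __init__(self, vocab_size=50304, n_embd=768):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        # bug: the bfloat16 cast used to live here, in the model init
        # self.wte.to(dtype=torch.bfloat16)

    def init_weights(self):
        # (re)initialize the embedding weights
        nn.init.normal_(self.wte.weight, mean=0.0, std=0.02)
        # fix: cast wte here instead, once the weights are actually initialized
        self.wte.to(dtype=torch.bfloat16)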
@@ -31,7 +31,7 @@ print_banner()
 # User settings
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 # Runtime
-device_type = "cuda" # cuda|cpu
+device_type = "cuda" # cuda|cpu|mps
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
@@ -64,7 +64,8 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # Compute init
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
+dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
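For reference, a rough sketch of how these device-agnostic helpers could be used in a training step; the model, optimizer, batch names, and the loss-returning call convention are placeholders, not taken from the repo:

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"  # "mps" also works on Apple silicon

# the helpers from the diff above
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32  # use fp32 on CPU|MPS
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

def train_step(model, optimizer, x, y):
    with autocast_ctx:  # bf16 autocast on CUDA, plain fp32 elsewhere
        loss = model(x, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    synchronize()  # no-op off CUDA, so timing code stays uniform across devices
    return loss.item(), get_max_memory()  # max memory reports 0 off CUDA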