add autodetect of device and related stuff. getting weird warnings/errors still, so wip

This commit is contained in:
karpathy
2025-10-16 10:26:19 -07:00
parent 279b74312c
commit 786119d593
4 changed files with 25 additions and 11 deletions

View File

@@ -7,7 +7,7 @@ or distributed as:
torchrun --nproc_per_node=8 base_train.py
python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_max_per_task=8 --total_batch_size=512 --num_iterations=500
If you have a Macbook, you're better off using device_type=mps instead of cpu
"""
@@ -19,7 +19,7 @@ import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@@ -31,7 +31,7 @@ print_banner()
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "cuda" # cuda|cpu|mps
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
@@ -62,6 +62,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS