diff --git a/nanochat/common.py b/nanochat/common.py
index 05e371c..86c18de 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -89,6 +89,17 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+    if torch.cuda.is_available():
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
+    else:
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
 def compute_init(device_type="cuda"): # cuda|cpu|mps
     """Basic initialization that we keep doing over and over, so make common."""
diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index a566d49..2d58d87 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -19,7 +19,7 @@ import yaml
 import pandas as pd
 import torch
 
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import HuggingFaceTokenizer
 from nanochat.checkpoint_manager import load_model
 from nanochat.core_eval import evaluate_task
@@ -121,8 +121,10 @@ def main():
     assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
 
     # distributed / precision setup
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+    device_type = autodetect_device_type()
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+    dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
 
     # Load model and tokenizer from command line or from file system
     if len(sys.argv) >= 2:
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index ba3876d..1609d83 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
 import os
 import torch
 from nanochat.checkpoint_manager import load_model
-from nanochat.common import compute_init, print0, compute_cleanup
+from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
 from nanochat.dataloader import tokenizing_distributed_data_loader
 from nanochat.tokenizer import get_token_bytes
 from nanochat.loss_eval import evaluate_bpb
@@ -20,15 +20,16 @@ device_batch_size = 32
 split_tokens = 20*524288 # number of tokens to evaluate per split
 model_tag = None # optional model tag for the output directory name
 model_step = None # optional model step for the output directory name
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 
 # Load the base model and the tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
 model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
 sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
-
-# Set up the precision we'll run with
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
 
 # Evaluate the loss on each split
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size
diff --git a/scripts/base_train.py b/scripts/base_train.py
index ebc5ff4..147fce2 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -7,7 +7,7 @@ or distributed as:
 torchrun --nproc_per_node=8 base_train.py
 
-python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
+python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_max_per_task=8 --total_batch_size=512 --num_iterations=500
 
 If you have a Macbook, you're better off using device_type=mps instead of cpu
 """
@@ -19,7 +19,7 @@ import torch
 
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -31,7 +31,7 @@ print_banner()
 # User settings
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 # Runtime
-device_type = "cuda" # cuda|cpu|mps
+device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
@@ -62,6 +62,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 
 # -----------------------------------------------------------------------------
 # Compute init
+device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
 dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS