add autodetect of device and related stuff. getting weird warnings/errors still, so wip

This commit is contained in:
karpathy
2025-10-16 10:26:19 -07:00
parent 279b74312c
commit 786119d593
4 changed files with 25 additions and 11 deletions

View File

@@ -19,7 +19,7 @@ import yaml
import pandas as pd
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@@ -121,8 +121,10 @@ def main():
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
# distributed / precision setup
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
# Load model and tokenizer from command line or from file system
if len(sys.argv) >= 2: