add autodetect of device and related stuff. getting weird warnings/errors still, so wip
This commit is contained in:
@@ -19,7 +19,7 @@ import yaml
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
@@ -121,8 +121,10 @@ def main():
|
||||
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
|
||||
|
||||
# distributed / precision setup
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if len(sys.argv) >= 2:
|
||||
|
||||
Reference in New Issue
Block a user