add autodetect of device and related stuff. getting weird warnings/errors still, so wip
This commit is contained in:
@@ -7,7 +7,7 @@ or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 base_train.py
|
||||
|
||||
python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
|
||||
python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_max_per_task=8 --total_batch_size=512 --num_iterations=500
|
||||
If you have a Macbook, you're better off using device_type=mps instead of cpu
|
||||
"""
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
@@ -31,7 +31,7 @@ print_banner()
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# Runtime
|
||||
device_type = "cuda" # cuda|cpu|mps
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
@@ -62,6 +62,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
|
||||
|
||||
Reference in New Issue
Block a user