add autodetect of device type and related changes. still getting weird warnings/errors, so wip
nanochat/common.py
@@ -89,6 +89,16 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+    if torch.cuda.is_available():
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
+    else:
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
 def compute_init(device_type="cuda"): # cuda|cpu|mps
     """Basic initialization that we keep doing over and over, so make common."""
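A quick way to sanity-check the new helper on a given machine (just a sketch, not part of the commit; it assumes the repo root is on PYTHONPATH so nanochat.common is importable, and that allocating a small tensor is a good-enough smoke test):

import torch
from nanochat.common import autodetect_device_type

device_type = autodetect_device_type()  # resolves in order: CUDA > MPS > CPU
x = torch.zeros(8, device=device_type)  # raises if the chosen backend is not actually usable
print(device_type, x.device)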
scripts/base_eval.py
@@ -19,7 +19,7 @@ import yaml
 import pandas as pd
 import torch
 
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import HuggingFaceTokenizer
 from nanochat.checkpoint_manager import load_model
 from nanochat.core_eval import evaluate_task
@@ -121,8 +121,10 @@ def main():
     assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
 
     # distributed / precision setup
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+    device_type = autodetect_device_type()
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+    dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
 
     # Load model and tokenizer from command line or from file system
     if len(sys.argv) >= 2:
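The resolve-device / pick-dtype / build-autocast triplet above is then entered around the forward passes. A minimal sketch of that usage, with a toy model standing in for the loaded one (not code from the diff); note that under fp32 PyTorch typically warns and disables autocast, which is harmless but may be related to the "weird warnings" the commit message mentions:

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in for autodetect_device_type()
dtype = torch.bfloat16 if device_type == "cuda" else torch.float32  # fp32 on CPU|MPS
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)

model = torch.nn.Linear(16, 16).to(device_type)
x = torch.randn(4, 16, device=device_type)
with autocast_ctx:  # bfloat16 matmuls on CUDA; effectively a no-op in fp32 elsewhere
    y = model(x)
print(y.dtype)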
scripts/base_loss.py
@@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
 import os
 import torch
 from nanochat.checkpoint_manager import load_model
-from nanochat.common import compute_init, print0, compute_cleanup
+from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
 from nanochat.dataloader import tokenizing_distributed_data_loader
 from nanochat.tokenizer import get_token_bytes
 from nanochat.loss_eval import evaluate_bpb
@@ -20,15 +20,16 @@ device_batch_size = 32
 split_tokens = 20*524288 # number of tokens to evaluate per split
 model_tag = None # optional model tag for the output directory name
 model_step = None # optional model step for the output directory name
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 
 # Load the base model and the tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
 model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
 sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
-
-# Set up the precision we'll run with
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
 
 # Evaluate the loss on each split
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size
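For context, the empty device_type default works together with the exec'd configurator: any --key=value on the command line overrides the module-level default before autodetection runs. The snippet below is only a rough approximation of that pattern (an assumption in the spirit of a nanoGPT-style configurator, not the actual contents of nanochat's configurator.py):

import sys
from ast import literal_eval

device_type = ""  # empty => autodetect, same default as in the diff

# apply --key=value overrides from the command line to module-level globals
for arg in sys.argv[1:]:
    if arg.startswith("--") and "=" in arg:
        key, val = arg[2:].split("=", 1)
        if key in globals():
            try:
                globals()[key] = literal_eval(val)  # numbers, bools, tuples, ...
            except (ValueError, SyntaxError):
                globals()[key] = val  # keep bare strings like mps as-is
print(device_type or "(autodetect)")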
scripts/base_train.py
@@ -7,7 +7,7 @@ or distributed as:
 
 torchrun --nproc_per_node=8 base_train.py
 
-python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
+python -m scripts.base_train --device_type=cpu --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_max_per_task=8 --total_batch_size=512 --num_iterations=500
 If you have a Macbook, you're better off using device_type=mps instead of cpu
 """
 
@@ -19,7 +19,7 @@ import torch
 
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -31,7 +31,7 @@ print_banner()
 # User settings
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 # Runtime
-device_type = "cuda" # cuda|cpu|mps
+device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
@@ -62,6 +62,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
 
 # Compute init
+device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
 dtype = torch.bfloat16 if device_type == "cuda" else torch.float32 # use fp32 on CPU|MPS
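With the default now empty, the device flag can simply be omitted and resolved automatically, or still forced as before. Hypothetical invocations in the style of the docstring above (flag values are illustrative only):

python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1   # autodetect: CUDA > MPS > CPU
python -m scripts.base_train --device_type=mps --depth=4 --max_seq_len=512 --device_batch_size=1   # force MPS explicitly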