trying to add basic cpu support, will try mps too
@@ -89,15 +89,16 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
-def compute_init():
+def compute_init(device_type="cuda"): # cuda|cpu
     """Basic initialization that we keep doing over and over, so make common."""
 
     # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+    # assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
 
     # Reproducibility
     torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
     # skipping full reproducibility for now, possibly investigate slowdown later
     # torch.use_deterministic_algorithms(True)
     # torch.backends.cudnn.deterministic = True
@@ -106,15 +107,15 @@ def compute_init():
     # Precision
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
 
-    # Distributed setup: Distributed Data Parallel (DDP), optional
+    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp:
+    if ddp and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
     else:
-        device = torch.device("cuda")
+        device = torch.device(device_type) # cuda|cpu
 
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
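
Read together, the two hunks above leave compute_init roughly as follows. This is a sketch reassembled only from the lines visible in the diff, plus a final return statement inferred from the call site in base_train.py; any lines of the function that the diff does not show are omitted:

def compute_init(device_type="cuda"): # cuda|cpu
    """Basic initialization that we keep doing over and over, so make common."""

    # CUDA is currently required
    # assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"

    # Reproducibility
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)

    # Precision
    torch.set_float32_matmul_precision("high")  # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device)  # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device(device_type)  # cuda|cpu

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")
    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
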
@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
 from nanochat.dataset import parquets_iter_batched
 from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
+def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
     """Stream pretraining text from parquet files, tokenize, yield training batches."""
     assert split in ["train", "val"], "split must be 'train' or 'val'"
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@@ -44,6 +44,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
         targets_cpu = scratch[1:]
         # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
+        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
+        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
         yield inputs, targets
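
With the new device parameter the same generator can feed a CPU-only run. A hypothetical usage sketch follows (the B/T values are illustrative, not from the commit); note that non_blocking=True is simply ignored for CPU-to-CPU copies, so passing device="cpu" is harmless:

# unchanged behaviour on a CUDA box: batches are moved to the GPU asynchronously
train_loader = tokenizing_distributed_data_loader(B=32, T=2048, split="train", device="cuda")

# CPU-only machine: same call, batches just stay on the CPU
cpu_loader = tokenizing_distributed_data_loader(B=1, T=512, split="train", device="cpu")
x, y = next(cpu_loader)  # x: int32 inputs of shape (1, 512), y: int64 targets
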
@@ -6,6 +6,9 @@ python base_train.py
 or distributed as:
 
 torchrun --nproc_per_node=8 base_train.py
+
+If you just want to see it run on CPU (you won't get far but it should run), try something like:
+python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --device_type=cpu --eval_tokens=512 --total_batch_size=512 --num_iterations=1000
 """
 
 import os
@@ -27,6 +30,8 @@ print_banner()
 # -----------------------------------------------------------------------------
 # User settings
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+# Runtime
+device_type = "cuda" # cuda|cpu
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
@@ -57,9 +62,11 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 # -----------------------------------------------------------------------------
 
 # Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
+synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
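
The two shims above are what let the CUDA-only calls later in the script (torch.cuda.synchronize around the step timer, torch.cuda.max_memory_allocated in the reports) degrade to no-ops on CPU without scattering if device_type == "cuda" checks everywhere; torch.amp.autocast with bfloat16 is likewise supported for device_type="cpu". A minimal standalone sketch of the pattern, mirroring the names in the diff (not a verbatim excerpt from base_train.py):

import time
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# real calls on CUDA, harmless stand-ins on CPU
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

synchronize()  # flush pending GPU work before starting the timer (no-op on CPU)
t0 = time.time()
y = torch.randn(256, 256, device=device_type) @ torch.randn(256, 256, device=device_type)  # stand-in for one step
synchronize()  # make the timing honest on CUDA; still a no-op on CPU
dt = time.time() - t0
print(f"step time: {dt*1e3:.2f}ms, peak memory: {get_max_memory() / 1024 / 1024:.2f}MiB")
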
@@ -96,7 +103,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
 with torch.device("meta"):
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
-model.to_empty(device="cuda")
+model.to_empty(device=device)
 model.init_weights()
 orig_model = model # original, uncompiled model, for saving raw model state_dict
 model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@@ -133,8 +140,8 @@ adamw_optimizer, muon_optimizer = optimizers
 # Initialize the DataLoaders for train/val
 base_dir = get_base_dir()
 tokens_dir = os.path.join(base_dir, "tokenized_data")
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
-build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
+train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
+build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
 x, y = next(train_loader) # kick off load of the very first batch of data
 
 # -----------------------------------------------------------------------------
@@ -252,7 +259,7 @@ for step in range(num_iterations + 1):
     # -------------------------------------------------------------------------
     # single training step
     # evaluate the gradient
-    torch.cuda.synchronize()
+    synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
         with autocast_ctx:
@@ -275,7 +282,7 @@ for step in range(num_iterations + 1):
     for opt in optimizers:
         opt.step()
     model.zero_grad(set_to_none=True)
-    torch.cuda.synchronize()
+    synchronize()
     t1 = time.time()
     dt = t1 - t0
     # -------------------------------------------------------------------------
@@ -304,7 +311,7 @@ for step in range(num_iterations + 1):
         })
 
 # print a few more stats
-print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
 
@@ -330,7 +337,7 @@ get_report().log(section="Base model training", data=[
         "MFU %": f"{mfu:.2f}%",
         "Total training flops": f"{flops_so_far:e}",
         "Total training time": f"{total_training_time/60:.2f}m",
-        "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
+        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
     }
 ])
 
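
The commit message says MPS is next; the same two shims would presumably extend along these lines. This is purely a hypothetical sketch, not part of the diff: it assumes torch.backends.mps.is_available(), torch.mps.synchronize() and torch.mps.current_allocated_memory(), which exist in recent PyTorch builds, and since MPS has no max_memory_allocated equivalent the current allocation is only a rough proxy for peak usage:

import torch

# hypothetical follow-up: pick the best available backend
if torch.cuda.is_available():
    device_type = "cuda"
elif torch.backends.mps.is_available():
    device_type = "mps"
else:
    device_type = "cpu"

synchronize = {
    "cuda": torch.cuda.synchronize,
    "mps": getattr(torch.mps, "synchronize", lambda: None),
}.get(device_type, lambda: None)

get_max_memory = {
    "cuda": torch.cuda.max_memory_allocated,
    "mps": getattr(torch.mps, "current_allocated_memory", lambda: 0),  # rough proxy, not a true peak
}.get(device_type, lambda: 0)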
|