delete the configurator in favor of argparse and clean up a lot of kwarg details to make them more consistent across all scripts

Andrej Karpathy
2026-01-04 19:14:23 +00:00
parent 507d54224a
commit eb7bbc1b66
9 changed files with 546 additions and 450 deletions
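The change in a nutshell: every script used to declare plain module-level globals and then exec nanochat/configurator.py, which parsed --key=val arguments (or a config file, per the comments it replaced) and overwrote matching globals; now each script declares an explicit argparse parser. A condensed, hypothetical sketch of the two patterns (the real configurator.py has more validation):

import sys
import argparse

# Before: implicit overrides of module globals via an exec'd configurator
device_batch_size = 32  # a plain global, doubling as the default
for arg in sys.argv[1:]:
    assert arg.startswith('--'), "expected --key=val overrides"
    key, val = arg[2:].split('=')
    assert key in globals(), f"unknown config key: {key}"
    globals()[key] = type(globals()[key])(val)  # cast to the default's type

# After: the same override, declared explicitly and self-documenting
parser = argparse.ArgumentParser()
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
args = parser.parse_args()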

scripts/base_loss.py

@@ -6,7 +6,7 @@ Loads a checkpoint, and:
Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
"""
import os
import argparse
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
@@ -16,29 +16,30 @@ from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
# Configuration
device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split_tokens", type=int, default=20*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model_step", type=int, default=None, help="model step to load")
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
args = parser.parse_args()
# Load the base model and the tokenizer
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = split_tokens // tokens_per_step
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = args.split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
loader = tokenizing_distributed_data_loader(args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")

scripts/base_train.py

@@ -13,6 +13,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import argparse
import time
from contextlib import nullcontext
@@ -30,46 +31,46 @@ from scripts.base_eval import evaluate_model
print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# CLI arguments
parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target_param_data_ratio", type=int, default=20, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping value (0.0 = disabled)")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown_ratio", type=float, default=0.2, help="ratio of iterations for LR warmdown")
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps")
parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
@@ -77,8 +78,8 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
@@ -87,8 +88,8 @@ vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_layers = args.depth
model_dim = args.depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
@@ -98,19 +99,19 @@ print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
# Create a new model with random weights
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):
# All tensors are created as meta tensors (they have shape/dtype but no data)
model_config = GPTConfig(**model_config_kwargs)
@@ -120,12 +121,12 @@ model.init_weights() # All tensors get initialized
# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = resume_from_step != -1
resuming = args.resume_from_step != -1
if resuming:
print0(f"Resuming optimization from step {resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
print0(f"Resuming optimization from step {args.resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
model.load_state_dict(model_data, strict=True, assign=True)
del model_data # free up this memory after the copy
@@ -137,28 +138,29 @@ num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
num_iterations = args.num_iterations
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
elif args.target_flops > 0:
# calculate the number of iterations from the target flops
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
elif args.target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio
target_tokens = target_param_data_ratio * num_params
num_iterations = target_tokens // total_batch_size
target_tokens = args.target_param_data_ratio * num_params
num_iterations = target_tokens // args.total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers
if resuming:
@@ -170,8 +172,8 @@ if resuming:
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
train_loader = tokenizing_distributed_data_loader_with_state(args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@@ -179,15 +181,15 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir
# Learning rate scheduler
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
return progress * 1.0 + (1 - progress) * args.final_lr_frac
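With the defaults (warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0) this is a flat schedule for the first 80% of training followed by a linear decay to zero. A standalone copy of the function with a hypothetical 1000-iteration horizon:

def lr_multiplier(it, num_iterations=1000, warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    progress = (num_iterations - it) / warmdown_iters
    return progress * 1.0 + (1 - progress) * final_lr_frac

print([lr_multiplier(it) for it in (0, 400, 800, 900, 1000)])  # [1.0, 1.0, 1.0, 0.5, 0.0]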
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@@ -215,13 +217,13 @@ else:
# Training loop
while True:
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
flops_so_far = num_flops_per_token * total_batch_size * step
flops_so_far = num_flops_per_token * args.total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or step % eval_every == 0:
if last_step or step % args.eval_every == 0:
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
@@ -238,10 +240,10 @@ while True:
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
results = {}
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
"step": step,
@@ -253,7 +255,7 @@ while True:
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
if master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
@@ -273,7 +275,7 @@ while True:
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
@@ -284,8 +286,8 @@ while True:
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
"device_batch_size": args.device_batch_size,
"max_seq_len": args.max_seq_len,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
@@ -313,9 +315,9 @@ while True:
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# gradient clipping
grad_clip_enabled = grad_clip > 0.0
grad_clip_enabled = args.grad_clip > 0.0
if grad_clip_enabled:
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), args.grad_clip)
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
# step the optimizers
lrm = get_lr_multiplier(step)
@@ -338,8 +340,8 @@ while True:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
@@ -378,11 +380,11 @@ get_report().log(section="Base model training", data=[
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
"DDP world size": ddp_world_size,
"warmup_ratio": warmup_ratio,
"warmdown_ratio": warmdown_ratio,
"final_lr_frac": final_lr_frac,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,
"final_lr_frac": args.final_lr_frac,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,

scripts/chat_rl.py

@@ -16,57 +16,69 @@ python -m scripts.chat_rl
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
"""
import argparse
import os
import itertools
import re
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K
# RL hyperparameters
run = "dummy" # wandb run name
source = "sft" # mid|sft
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
num_samples = 16 # number of samples per example (/question)
max_new_tokens = 256
temperature = 1.0
top_k = 50 # TODO: try None?
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.05
num_epochs = 1 # how many epochs of gsm8k to train on
save_every = 60 # every how many steps to save the model
eval_every = 60 # every how many steps to evaluate the model for val pass@k
eval_examples = 400 # number of examples used for evaluating pass@k
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
# Generation
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer) # for sampling rollouts
# -----------------------------------------------------------------------------
@@ -74,7 +86,7 @@ engine = Engine(model, tokenizer) # for sampling rollouts
train_task = GSM8K(subset="main", split="train")
val_task = GSM8K(subset="main", split="test")
num_steps = (len(train_task) // examples_per_step) * num_epochs
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad()
@@ -95,16 +107,16 @@ def get_batch():
model.eval() # ensure the model is in eval mode
generated_token_sequences = []
masks = []
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
num_samples=device_batch_size,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences.extend(generated_token_sequences_batch)
@@ -191,16 +203,16 @@ def run_gsm8k_eval(task, tokenizer, engine,
# Init the optimizer
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Learning rate scheduler: simple rampdown to zero over num_steps
@@ -209,9 +221,9 @@ def get_lr_multiplier(it):
return lrm
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = examples_per_step // ddp_world_size # per GPU
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
print0(f"Calculated examples per rank: {examples_per_rank}")
# Kick off the training loop
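Concretely, with the defaults on the 8-GPU torchrun example: 16 questions x 16 samples = 256 sequences per optimization step, 2 questions per rank, and 2 sequential generate_batch calls per question to stay under device_batch_size (a sketch of the arithmetic):

examples_per_step, num_samples, device_batch_size = 16, 16, 8
ddp_world_size = 8  # assumption: the torchrun example above
assert examples_per_step % ddp_world_size == 0
print(examples_per_step * num_samples)       # 256 sequences per step
print(examples_per_step // ddp_world_size)   # 2 examples per rank
print(num_samples // device_batch_size)      # 2 sampling passes per example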
@@ -219,22 +231,22 @@ batch_iterator = get_batch()
for step in range(num_steps):
# Evaluate the model once in a while and log to wandb
if step % eval_every == 0:
if step % args.eval_every == 0:
model.eval()
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx:
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
for k in range(1, device_batch_size + 1):
for k in range(1, args.device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
if ddp:
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
passk = passk / num_records.item() # normalize by the total number of records
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
print0(f"Step {step} | {', '.join(print_passk)}")
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
wandb_run.log({
"step": step,
**log_passk,
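pass@k here is the fraction of eval questions for which at least one of the first k sampled completions is correct (counts are summed across ranks before normalizing). A self-contained sketch of the counting, using hypothetical outcome records:

records = [  # hypothetical: 3 questions x 3 samples each
    {"outcomes": [{"is_correct": c} for c in flags]}
    for flags in ([False, True, False], [False, False, False], [True, True, False])
]
for k in (1, 2, 3):
    passk = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) / len(records)
    print(f"pass@{k}: {passk:.4f}")  # 0.3333, 0.6667, 0.6667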
@@ -249,11 +261,11 @@ for step in range(num_steps):
# Evaluate the loss and gradients
model.train() # ensure the model is in train mode
# We need one more loop because we can never exceed the device_batch_size
assert inputs_all.size(0) % device_batch_size == 0
num_passes = inputs_all.size(0) // device_batch_size
assert inputs_all.size(0) % args.device_batch_size == 0
num_passes = inputs_all.size(0) // args.device_batch_size
for pass_idx in range(num_passes):
# Pluck out the batch for this pass
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
inputs = inputs_all[b0:b1]
targets = targets_all[b0:b1]
rewards = rewards_all[b0:b1]
@@ -306,10 +318,10 @@ for step in range(num_steps):
})
# Master process saves the model once in a while. Skip first step. Save last step.
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
base_dir = get_base_dir()
depth = model.config.n_layer
output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(

scripts/chat_sft.py

@@ -9,6 +9,7 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
"""
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
@@ -31,49 +32,51 @@ from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# SFT Hyperparameters
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# input model options
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02
# evaluation and logging thereof
eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# CLI arguments
parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
# Batch sizes
parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config, save_code=True)
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model(args.source, device, phase="train", model_tag=args.model_tag, step=args.model_step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
@@ -127,34 +130,36 @@ def sft_data_generator(dataset, batch_size):
yield collate_and_yield(batch)
batch = []
examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
examples_per_step = args.device_batch_size * ddp_world_size
print0(f"Target examples per step: {args.target_examples_per_step}")
print0(f"Device batch size: {args.device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = target_examples_per_step // examples_per_step
assert args.target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = args.target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
if num_iterations == -1:
if args.num_iterations == -1:
# derive num_iterations from num_epochs and the size of the dataset
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
assert args.num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
num_iterations = (len(train_ds) // args.target_examples_per_step) * args.num_epochs
else:
num_iterations = args.num_iterations
train_loader = sft_data_generator(train_ds, batch_size=args.device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_batch_size)
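Same accumulation idea as base_train, but counted in examples rather than tokens. A worked sketch with the defaults on 8 GPUs and a hypothetical dataset size (the real len(train_ds) depends on the task mixture):

device_batch_size, target_examples_per_step, num_epochs = 4, 32, 1
ddp_world_size = 8
examples_per_step = device_batch_size * ddp_world_size             # 32
assert target_examples_per_step % examples_per_step == 0
grad_accum_steps = target_examples_per_step // examples_per_step   # 1
len_train_ds = 20_000  # hypothetical
num_iterations = (len_train_ds // target_examples_per_step) * num_epochs
print(grad_accum_steps, num_iterations)  # 1 625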
# -----------------------------------------------------------------------------
# Initialize the Optimizer
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# -----------------------------------------------------------------------------
@@ -171,11 +176,11 @@ for step in range(num_iterations):
last_step = step == num_iterations - 1
# evaluate the validation loss
if last_step or step % eval_every == 0:
if last_step or step % args.eval_every == 0:
model.eval()
val_loader = build_val_loader()
losses = []
for _ in range(eval_steps):
for _ in range(args.eval_steps):
val_inputs, val_targets = next(val_loader)
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
@@ -192,13 +197,13 @@ for step in range(num_iterations):
model.train()
# evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
if last_step or (step > 0 and step % args.eval_metrics_every == 0):
model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
@@ -250,7 +255,7 @@ for step in range(num_iterations):
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(

scripts/mid_train.py

@@ -9,6 +9,7 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
"""
import argparse
from collections import deque
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
@@ -31,65 +32,75 @@ from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# CLI arguments
parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=args.run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Midtraining data mixture and DataLoader
@@ -120,7 +131,7 @@ def mid_data_generator(split):
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
needed_tokens = args.device_batch_size * args.max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
@@ -139,18 +150,18 @@ def mid_data_generator(split):
last_step = True # toggle last_step to True, which will terminate the training loop
# Stopping condition to respect num_iterations, if given
it += 1
if 0 < num_iterations <= it and split == "train":
if 0 < args.num_iterations <= it and split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
if split == "train":
if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations
if args.num_iterations > 0:
approx_progress = it / args.num_iterations # calculate progress from the max number of iterations
else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets
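The generator above streams tokens through a deque and slices off exactly device_batch_size * max_seq_len + 1 of them per batch, so targets are inputs shifted by one position. A toy CPU-only sketch of that batching (no pinning or device copies):

from collections import deque
import torch

device_batch_size, max_seq_len = 2, 4
needed_tokens = device_batch_size * max_seq_len + 1  # +1 so targets can shift by one
token_buffer = deque(range(100))  # stand-in for a stream of tokenized documents
scratch = torch.empty(needed_tokens, dtype=torch.int64)
for i in range(needed_tokens):
    scratch[i] = token_buffer.popleft()
inputs = scratch[:-1].view(device_batch_size, max_seq_len)
targets = scratch[1:].view(device_batch_size, max_seq_len)
print(inputs.tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(targets.tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]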
@@ -179,7 +190,7 @@ ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
flops_so_far = num_flops_per_token * args.total_batch_size * step
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
@@ -188,10 +199,10 @@ while True:
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if eval_every > 0 and (last_step or step % eval_every == 0):
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
@@ -206,8 +217,8 @@ while True:
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
if master_process and last_step and not args.dry_run:
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
@@ -218,7 +229,7 @@ while True:
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": max_seq_len,
"sequence_len": args.max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
"n_layer": depth,
"n_head": model.config.n_head,
@@ -268,8 +279,8 @@ while True:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
@@ -293,7 +304,7 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
if not dry_run:
if not args.dry_run:
from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
user_config, # CLI args