fix buggy midtrain and update all kwargs to be idiomatic. That is, argparse uses dashes, variables use underscores. The underscores are just a remnant of the previous Configurator object. This is the right way

Andrej Karpathy
2026-01-13 22:45:27 +00:00
parent 3b50b77ed3
commit 7312ec9898
11 changed files with 144 additions and 139 deletions
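Note on the convention the commit message refers to (not part of the commit): argparse derives the Python attribute name from the long flag by replacing dashes with underscores, so renaming --max_chars to --max-chars changes nothing at the call sites, which keep reading args.max_chars. A minimal standalone sketch:

    import argparse

    parser = argparse.ArgumentParser()
    # dashes are idiomatic on the command line...
    parser.add_argument("--max-chars", type=int, default=10_000_000_000)
    args = parser.parse_args(["--max-chars", "1000000000"])
    # ...but argparse maps '--max-chars' to the attribute 'max_chars'
    print(args.max_chars)  # 1000000000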

View File

@@ -25,7 +25,7 @@ python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
-python -m scripts.tok_train --max_chars=1000000000
+python -m scripts.tok_train --max-chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
@@ -33,37 +33,37 @@ python -m scripts.tok_eval
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
    --depth=4 \
-    --max_seq_len=1024 \
+    --max-seq-len=1024 \
-    --device_batch_size=1 \
+    --device-batch-size=1 \
-    --total_batch_size=1024 \
+    --total-batch-size=1024 \
-    --eval_every=50 \
+    --eval-every=50 \
-    --eval_tokens=4096 \
+    --eval-tokens=4096 \
-    --core_metric_every=50 \
+    --core-metric-every=50 \
-    --core_metric_max_per_task=12 \
+    --core-metric-max-per-task=12 \
-    --sample_every=50 \
+    --sample-every=50 \
-    --num_iterations=50
+    --num-iterations=50
-python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
+python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
python -m scripts.base_eval --max-per-task=16
# midtraining
python -m scripts.mid_train \
-    --max_seq_len=1024 \
+    --max-seq-len=1024 \
-    --device_batch_size=1 \
+    --device-batch-size=1 \
-    --eval_every=50 \
+    --eval-every=50 \
-    --eval_tokens=4096 \
+    --eval-tokens=4096 \
-    --total_batch_size=1024 \
+    --total-batch-size=1024 \
-    --num_iterations=100
+    --num-iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
-    --device_batch_size=1 \
+    --device-batch-size=1 \
-    --target_examples_per_step=4 \
+    --target-examples-per-step=4 \
-    --num_iterations=100 \
+    --num-iterations=100 \
-    --eval_steps=4 \
+    --eval-steps=4 \
-    --eval_metrics_max_problems=16
+    --eval-metrics-max-problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

View File

@@ -20,7 +20,7 @@ if [ -z "$SKIP_SETUP" ]; then
# Tokenizer, download 1000 shards for pretraining
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
python -m nanochat.dataset -n 1000
-python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768
+python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
else
    source .venv/bin/activate
fi
@@ -58,16 +58,16 @@ for d in "${DEPTHS[@]}"; do
START_TIME=$(date +%s)
# Train the model with natural horizon (target_param_data_ratio default)
-# No --target_flops, let it use the default ratio from base_train
+# No --target-flops, let it use the default ratio from base_train
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
    --depth=$d \
-    --target_param_data_ratio=8 \
+    --target-param-data-ratio=8 \
    --run="${WANDB_RUN}_d${d}" \
-    --model_tag="${TAG}" \
+    --model-tag="${TAG}" \
-    --core_metric_every=999999 \
+    --core-metric-every=999999 \
-    --core_metric_max_per_task=-1 \
+    --core-metric-max-per-task=-1 \
-    --sample_every=-1 \
+    --sample-every=-1 \
-    --save_every=-1 \
+    --save-every=-1 \
    2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)

View File

@@ -23,15 +23,15 @@ python -m nanochat.dataset -n 16
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
python -m nanochat.dataset -n 1200 &
# todo: download the rest of it
-python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536
+python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
python -m scripts.tok_eval
# Documenting my process for determining the hyperparameters for this run1000.sh script:
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
# 1) I guessed the model size for this to be about depth=32
# 2) Determine the device_batch_size that fits:
-# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
+# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
-# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
+# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
# So the training script was running ok and showed:
# Vocab size: 65,536
@@ -73,13 +73,13 @@ python -m scripts.tok_eval
# Number of processes/GPUs to use
NPROC_PER_NODE=8
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# sft

View File

@@ -64,15 +64,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
# CORE eval happens once at the end (999999 ensures only final step)
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
    --depth=$d \
-    --target_flops=$flops \
+    --target-flops=$flops \
-    --target_param_data_ratio=-1 \
+    --target-param-data-ratio=-1 \
    --run="${WANDB_RUN}_${TAG}" \
-    --model_tag="${TAG}" \
+    --model-tag="${TAG}" \
-    --eval_tokens=$EVAL_TOKENS \
+    --eval-tokens=$EVAL_TOKENS \
-    --core_metric_every=999999 \
+    --core-metric-every=999999 \
-    --core_metric_max_per_task=-1 \
+    --core-metric-max-per-task=-1 \
-    --sample_every=-1 \
+    --sample-every=-1 \
-    --save_every=-1 \
+    --save-every=-1 \
    2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)

View File

@@ -7,7 +7,7 @@ Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
To evaluate a HuggingFace model:
-python -m scripts.base_loss --hf_path openai-community/gpt2
+python -m scripts.base_loss --hf-path openai-community/gpt2
"""
import argparse
from contextlib import nullcontext
@@ -61,12 +61,12 @@ def get_hf_token_bytes(tokenizer, device="cpu"):
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
+parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load")
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
-parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
+parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
args = parser.parse_args()
# Load the base model and the tokenizer

View File

@@ -8,7 +8,7 @@ or distributed as:
torchrun --nproc_per_node=8 -m scripts.base_train.py
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
+python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
"""
import os
@@ -36,40 +36,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
-parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
+parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
-parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
+parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
-parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
+parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
-parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
+parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
-parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
-parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
+parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
+parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
-parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
+parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
+parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
+parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
-parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
+parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
-parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
+parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
-parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
+parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
-parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
+parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
-parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
+parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
-parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
+parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
-parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
+parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
-parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
+parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
+parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
-parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
+parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
-parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
+parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
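Aside on the "Training horizon" hunk above (illustrative, not code from this commit): per the help strings, an explicit --num-iterations wins, then --target-flops, then --target-param-data-ratio. A hedged sketch of how that precedence could be resolved; the helper name and the flops_per_token / tokens_per_step / num_params parameters are assumptions for illustration:

    def resolve_num_iterations(args, num_params, flops_per_token, tokens_per_step):
        # 1) explicit step count takes precedence
        if args.num_iterations > 0:
            return args.num_iterations
        # 2) otherwise run until a FLOPs budget is spent
        if args.target_flops > 0:
            return round(args.target_flops / (flops_per_token * tokens_per_step))
        # 3) otherwise hold a data:param token ratio (Chinchilla-style would be 20)
        target_tokens = args.target_param_data_ratio * num_params
        return round(target_tokens / tokens_per_step)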

View File

@@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
-parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
+parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
-parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
+parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
-parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
+parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
-parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
+parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
# Generation
-parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
+parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
-parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
+parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
+parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
-parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
+parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
-parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
+parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
-parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
+parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

View File

@@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
-parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
+parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
-parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
# Batch sizes
-parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
+parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
-parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
+parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
# Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
+parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
# Evaluation
-parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
+parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
-parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
+parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
-parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
+parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
-parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
+parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

View File

@@ -6,7 +6,7 @@ python -m scripts.mid_train
Or torchrun for training:
-torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
+torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
"""
import argparse
@@ -36,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
-parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
-parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
+parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
+parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
+parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
-parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
+parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
-parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
+parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@@ -79,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
-    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
+    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
@@ -142,7 +142,8 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
# Conversation buffer: list of token lists
conv_buffer = []
-cursor = ddp_rank # Each rank processes different conversations
+cursor = ddp_rank # Each rank processes different conversations (for fetching)
+consumed = ddp_rank # Track actual consumption separately from buffering
epoch = 1
it = 0 # iteration counter

@@ -156,8 +157,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
if cursor >= dataset_size:
    cursor = cursor % dataset_size
    epoch += 1
-    if split == "train":
-        last_step = True # toggle last_step to True, which will terminate the training loop
+    # Note: last_step is now triggered based on consumption, not fetching
while True:
    rows = []
@@ -183,10 +183,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
row.extend(conv)
+consumed += ddp_world_size # Track actual consumption
else:
    # No conversation fits - crop first conversation to fill remaining
    conv = conv_buffer.pop(0)
    row.extend(conv[:remaining])
+    consumed += ddp_world_size # Track actual consumption
rows.append(row[:row_capacity])
@@ -195,13 +197,16 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
if 0 < args.num_iterations <= it and split == "train":
    last_step = True
-# Update progress tracking
+# Update progress tracking (based on consumed, not cursor, to account for buffering)
if split == "train":
    current_epoch = epoch
    if args.num_iterations > 0:
        approx_progress = it / args.num_iterations
    else:
-        approx_progress = cursor / dataset_size
+        approx_progress = consumed / dataset_size
+    # Trigger last_step when we've consumed enough (instead of when cursor wraps)
+    if consumed >= dataset_size:
+        last_step = True
# Build tensors
use_cuda = device_type == "cuda"
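Why the fix above tracks consumed separately from cursor: the generator prefetches conversations into conv_buffer ahead of actually packing them into rows, so the fetch cursor wraps the dataset while conversations are still sitting in the buffer. Ending the epoch when the cursor wraps (the old code) therefore terminated training early; ending it when everything has actually been consumed (the new code) does not. A toy sketch of the distinction, a hypothetical simplified loop rather than the repo's code:

    # hypothetical: one rank, buffer of 3, dataset of 5 conversations
    dataset = list(range(5))
    buffer, cursor, consumed = [], 0, 0
    while consumed < len(dataset):          # new: terminate on consumption
        while len(buffer) < 3 and cursor < len(dataset):
            buffer.append(dataset[cursor])  # fetching runs ahead of packing...
            cursor += 1                     # ...cursor reaches 5 with items still buffered
        buffer.pop(0)                       # pack one conversation into a row
        consumed += 1                       # count work actually done
    # the old code keyed the epoch end off cursor, which here hits 5 after
    # only 3 conversations were packed -- the premature-termination bug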

View File

@@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
-parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
+parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
-parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
+parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
-parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
+parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")

View File

@@ -59,7 +59,7 @@ python -m nanochat.dataset -n 8
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
-python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
+python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
@@ -81,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
NPROC_PER_NODE=8
# pretrain the d20 model
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks