mirror of https://github.com/karpathy/nanochat.git
synced 2026-01-30 04:22:02 +00:00
nudge hyperparameters of the base script with the results of the sweeps and miniseries. vocab size down to 32K. D:N ratio from 20 to 8. add miniseries script
@@ -1,11 +1,11 @@
 """
-Train model. Run as:
+Train model. From root directory of the project, run as:
 
-python base_train.py
+python -m scripts.base_train.py
 
 or distributed as:
 
-torchrun --nproc_per_node=8 base_train.py
+torchrun --nproc_per_node=8 -m scripts.base_train.py
 
 If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
 
 python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
@@ -39,11 +39,13 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
 parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 # Model architecture
 parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
+parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
+parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
 # Training horizon (only one used, in order of precedence)
 parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target_param_data_ratio", type=int, default=20, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
 # Optimization
 parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
 parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
@@ -51,6 +53,8 @@ parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning ra
 parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
 parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
 parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
+parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
 parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping value (0.0 = disabled)")
 parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
 parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
@@ -89,8 +93,8 @@ print0(f"Vocab size: {vocab_size:,}")
 
 # Model kwargs are derived from the desired depth of the model
 num_layers = args.depth
-model_dim = args.depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
-def find_num_heads(model_dim, target_head_dim=128):
+model_dim = args.depth * args.aspect_ratio
+def find_num_heads(model_dim, target_head_dim):
     # Find num_heads that divides model_dim evenly, with head_dim closest to target.
     ideal = max(1, round(model_dim / target_head_dim))
     for offset in range(model_dim):
@@ -98,7 +102,7 @@ def find_num_heads(model_dim, target_head_dim=128):
         if candidate > 0 and model_dim % candidate == 0:
            return candidate
    return 1
-num_heads = find_num_heads(model_dim)
+num_heads = find_num_heads(model_dim, args.head_dim)
 num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
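
The hunks above show only the top and bottom of the head-count search, so here is a standalone sketch of the idea the new --head_dim flag plugs into: among head counts that divide model_dim evenly, pick the one whose per-head dimension lands closest to the target. The helper name pick_num_heads and the brute-force divisor scan are illustrative, not the repository's exact implementation.

# Standalone sketch (not the repository's exact code) of head-count selection:
# among head counts that divide model_dim evenly, pick the one whose head_dim is
# closest to the --head_dim target.
def pick_num_heads(model_dim: int, target_head_dim: int = 128) -> int:
    divisors = [h for h in range(1, model_dim + 1) if model_dim % h == 0]
    return min(divisors, key=lambda h: abs(model_dim // h - target_head_dim))

# --depth=20 --aspect_ratio=64 (the defaults) -> model_dim = 1280 -> 10 heads of dim 128
print(pick_num_heads(20 * 64, 128))
# --depth=4 --aspect_ratio=64 (the CPU/Macbook example) -> model_dim = 256 -> 2 heads of dim 128
print(pick_num_heads(4 * 64, 128))
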
@@ -115,6 +119,17 @@ print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_l
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 
+# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
+batch_lr_scale = 1.0
+reference_batch_size = 2**19
+batch_ratio = args.total_batch_size / reference_batch_size
+if batch_ratio != 1.0:
+    # SGD: linear scaling with batch size is standard (not used in nanochat)
+    # AdamW: sqrt scaling is standard
+    # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
+    batch_lr_scale = batch_ratio ** 0.5
+    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
+
 # -----------------------------------------------------------------------------
 # Initialize the Model
 
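
The square-root rule added in this hunk is plain arithmetic on the ratio to the 2^19-token reference batch. A quick check of what it implies; the alternative batch sizes below are illustrative, not values from the commit:

# Worked example of the sqrt LR scaling introduced above (batch sizes are illustrative).
reference_batch_size = 2**19                      # 524,288 tokens, where the LRs were tuned
for total_batch_size in (2**18, 2**19, 2**20):
    batch_lr_scale = (total_batch_size / reference_batch_size) ** 0.5
    # e.g. the default matrix_lr of 0.02 would become 0.02 * batch_lr_scale
    print(f"{total_batch_size:>9,} tokens -> LR scale {batch_lr_scale:.4f} -> matrix_lr {0.02 * batch_lr_scale:.4f}")
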
@@ -141,7 +156,8 @@ if resuming:
 orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
 model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
 num_params = sum(p.numel() for p in model.parameters())
-print0(f"Number of parameters: {num_params:,}")
+num_scaling_params = orig_model.num_scaling_params()
+print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
 num_flops_per_token = model.estimate_flops()
 print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
 
@@ -155,20 +171,27 @@ elif args.target_flops > 0:
     num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
     print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
 elif args.target_param_data_ratio > 0:
-    # calculate the number of iterations from the target param data ratio
-    target_tokens = args.target_param_data_ratio * num_params
+    # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
+    target_tokens = args.target_param_data_ratio * num_scaling_params
     num_iterations = target_tokens // args.total_batch_size
     print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
 else:
     raise ValueError("No training horizon specified")
 total_tokens = args.total_batch_size * num_iterations
 print0(f"Total number of training tokens: {total_tokens:,}")
-print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
+print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 
 # -----------------------------------------------------------------------------
 # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
-optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
+adam_betas = (args.adam_beta1, args.adam_beta2)
+optimizers = model.setup_optimizers(
+    unembedding_lr=args.unembedding_lr * batch_lr_scale,
+    embedding_lr=args.embedding_lr * batch_lr_scale,
+    matrix_lr=args.matrix_lr * batch_lr_scale,
+    weight_decay=args.weight_decay,
+    adam_betas=adam_betas,
+)
 adamw_optimizer, muon_optimizer = optimizers
 
 if resuming:
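
Dropping the default data:param target from 20 to 8 shrinks the training horizon in direct proportion, via the same target_tokens arithmetic shown above. A small sketch of that calculation, using a hypothetical 560M scaling-parameter count purely for illustration:

# Re-deriving num_iterations from the data:param target. The 560M scaling-parameter
# count below is a hypothetical figure for illustration, not a number from this commit.
num_scaling_params = 560_000_000
total_batch_size = 524_288                # the script's default, in tokens
for target_param_data_ratio in (20, 8):   # old default vs. new default
    target_tokens = target_param_data_ratio * num_scaling_params
    num_iterations = target_tokens // total_batch_size
    print(f"ratio {target_param_data_ratio:>2}: {target_tokens:,} target tokens -> {num_iterations:,} iterations")
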
@@ -16,7 +16,7 @@ from nanochat.dataset import parquets_iter_batched
 parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
 parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
 parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
-parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
+parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
 args = parser.parse_args()
 print(f"max_chars: {args.max_chars:,}")
 print(f"doc_cap: {args.doc_cap:,}")
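
The matching tokenizer change halves the vocabulary (65536 to 32768), which also shrinks the embedding and unembedding tables. A rough sketch of that effect, assuming separate vocab_size x model_dim embedding and unembedding matrices (an assumption about the model layout, not something shown in this diff):

# Rough effect of halving the vocabulary on embedding/unembedding size, assuming
# separate vocab_size x model_dim embedding and unembedding matrices.
model_dim = 20 * 64                        # --depth=20 --aspect_ratio=64 defaults
for vocab_size in (65536, 32768):          # old default vs. new default
    table_params = 2 * vocab_size * model_dim   # embedding + unembedding
    print(f"vocab {vocab_size:,}: {table_params:,} embedding+unembedding parameters")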