Merge branch 'engram'
@@ -1,11 +1,11 @@
 """
 Train model. From root directory of the project, run as:

-python -m scripts.base_train.py
+python -m scripts.base_train

 or distributed as:

-torchrun --nproc_per_node=8 -m scripts.base_train.py
+torchrun --nproc_per_node=8 -m scripts.base_train

 If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
 python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
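A note on the command fix above: python -m takes a dotted module path, not a filename, so the old scripts.base_train.py form makes Python look for a submodule named py inside scripts.base_train and fail before training starts. A minimal sketch to confirm the corrected module path resolves (assumes it is run from the repo root; this helper is not part of the repo):

import importlib.util

# "scripts.base_train" is the module path used with python -m above;
# a trailing ".py" would be parsed as an extra submodule named "py".
spec = importlib.util.find_spec("scripts.base_train")
print("resolves to:", spec.origin if spec else None)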
@@ -47,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding
 # Training horizon (only one used, in order of precedence)
 parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
 # Optimization
 parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
 parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
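One practical effect of the type change above: fractional ratios such as the new 10.5 default now parse, whereas type=int would reject a value like "10.5" on the command line. A small standalone sketch (only this one flag, not the script's full parser):

import argparse

# Hypothetical standalone parser carrying just the changed flag.
parser = argparse.ArgumentParser()
parser.add_argument("--target-param-data-ratio", type=float, default=10.5)
args = parser.parse_args(["--target-param-data-ratio", "10.5"])
print(args.target_param_data_ratio)  # 10.5; with type=int argparse would reject "10.5"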
@@ -178,9 +178,14 @@ if resuming:

 orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
 model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
-num_params = sum(p.numel() for p in model.parameters())
-num_scaling_params = orig_model.num_scaling_params()
-print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
+
+# Detailed parameter counts
+param_counts = orig_model.num_scaling_params()
+print0(f"Parameter counts:")
+for key, value in param_counts.items():
+    print0(f"{key:24s}: {value:,}")
+num_params = param_counts['total']
+num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
 num_flops_per_token = model.estimate_flops()
 print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

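The block added above implies that num_scaling_params() now returns a dict of per-component parameter counts rather than a single number; only the keys total, transformer_matrices, and lm_head are visible in this diff. A hedged sketch of the consuming logic with made-up counts (the real dict likely has more entries):

# Hypothetical per-component counts; only the three keys below are taken from the diff.
param_counts = {
    "transformer_matrices": 520_000_000,
    "lm_head": 40_000_000,
    "total": 600_000_000,
}
for key, value in param_counts.items():
    print(f"{key:24s}: {value:,}")
num_params = param_counts["total"]
# transformer matrices + lm_head: the combination reported to give the cleanest scaling laws
num_scaling_params = param_counts["transformer_matrices"] + param_counts["lm_head"]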
@@ -195,14 +200,14 @@ elif args.target_flops > 0:
     print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
 elif args.target_param_data_ratio > 0:
     # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
-    target_tokens = args.target_param_data_ratio * num_scaling_params
+    target_tokens = int(args.target_param_data_ratio * num_scaling_params)
     num_iterations = target_tokens // args.total_batch_size
     print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
 else:
     raise ValueError("No training horizon specified")
 total_tokens = args.total_batch_size * num_iterations
 print0(f"Total number of training tokens: {total_tokens:,}")
-print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
+print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

 # -----------------------------------------------------------------------------
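A worked example of the horizon arithmetic above, with a hypothetical scaling-parameter count, showing why the int() cast matters once the ratio default is the float 10.5:

target_param_data_ratio = 10.5      # new float default from the argparse change earlier
num_scaling_params = 560_988_160    # hypothetical scaling-parameter count
total_batch_size = 524288           # default --total-batch-size in tokens

# 10.5 * N is a float; without int(), float // int stays a float, so
# num_iterations would be the float 11235.0 instead of the int 11235.
target_tokens = int(target_param_data_ratio * num_scaling_params)  # 5,890,375,680 tokens
num_iterations = target_tokens // total_batch_size                 # 11,235 steps
print(f"{num_iterations:,} iterations, tokens:scaling-params ratio {target_tokens / num_scaling_params:.2f}")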
@@ -445,7 +450,7 @@ get_report().log(section="Base model training", data=[
     "Number of FLOPs per token": f"{num_flops_per_token:e}",
     "Calculated number of iterations": num_iterations,
     "Number of training tokens": total_tokens,
-    "Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
+    "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
     "DDP world size": ddp_world_size,
     "warmup_ratio": args.warmup_ratio,
     "warmdown_ratio": args.warmdown_ratio,