Update the CPU/MPS script to give reasonable results. After about 40 minutes of training on my MacBook, the model can at least answer that Paris is the capital of France and knows that the sky is blue. Also fixed a bug caused by the KVCache bfloat16 dtype assumption.

karpathy
2026-01-17 12:27:30 -08:00
parent f5425245f9
commit f9a7e0f111
4 changed files with 67 additions and 49 deletions

@@ -1,12 +1,15 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# This script was last updated/tuned on Jan 17, 2026.
# Run as:
# bash dev/cpu_demo_run.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as an educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/
# (This is why I hide this script away in dev/)
# You may also want to run this script manually, copy-pasting the commands into your terminal one by one.
# all the setup stuff
export OMP_NUM_THREADS=1
@@ -20,58 +23,48 @@ if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
# wipe the report
python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max-chars=1000000000
# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
python -m nanochat.dataset -n 8
python -m scripts.tok_train --max-chars=2000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
# train a small 4 layer model
# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
python -m scripts.base_train \
--depth=4 \
--max-seq-len=1024 \
--device-batch-size=1 \
--total-batch-size=1024 \
--eval-every=50 \
--eval-tokens=4096 \
--core-metric-every=50 \
--core-metric-max-per-task=12 \
--sample-every=50 \
--num-iterations=50 \
--depth=6 \
--head-dim=64 \
--window-pattern=L \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=100 \
--eval-tokens=524288 \
--core-metric-every=-1 \
--sample-every=100 \
--num-iterations=5000 \
--run=$WANDB_RUN
python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384
python -m scripts.base_eval --max-per-task=16
# midtraining
# midtraining (~10 minutes on my MacBook Pro M3 Max)
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
python -m scripts.mid_train \
--max-seq-len=1024 \
--device-batch-size=1 \
--eval-every=50 \
--eval-tokens=4096 \
--total-batch-size=1024 \
--num-iterations=100 \
--run=$WANDB_RUN
# eval results will be terrible; this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device-batch-size=1 \
--target-examples-per-step=4 \
--num-iterations=100 \
--eval-steps=4 \
--eval-metrics-max-problems=16 \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=200 \
--eval-tokens=524288 \
--num-iterations=1500 \
--run=$WANDB_RUN
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"
# (it's ~ok to skip SFT)
# Chat Web
# python -m scripts.chat_web
# Chat with the model over CLI
# The model should be able to answer that the capital of France is Paris.
# It might even know that the color of the sky is blue.
# Sometimes the model likes it if you first say Hi before you ask it questions.
# python -m scripts.chat_cli -i mid -p "What is the capital of France?"
python -m nanochat.report generate
# Chat with the model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web -i mid
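For a rough sense of the training budget implied by the base_train and chat_sft flags above (back-of-envelope, assuming --total-batch-size is measured in tokens, as the earlier comment in this script indicates):

base_tokens = 16_384 * 5_000  # base_train: total_batch_size * num_iterations = 81,920,000 (~82M) tokens
sft_tokens = 16_384 * 1_500   # chat_sft:   total_batch_size * num_iterations = 24,576,000 (~25M) tokens
print(f"base pretraining: ~{base_tokens/1e6:.0f}M tokens, SFT: ~{sft_tokens/1e6:.0f}M tokens")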

@@ -90,7 +90,7 @@ class KVCache:
    - Position tracked per batch element via cache_seqlens tensor
    """
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
@@ -172,6 +172,13 @@ class Engine:
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
device = self.model.get_device()
# NOTE: setting the dtype here and in this way is an ugly hack.
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
# In particular, the KVCache should allocate its tensors lazily
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
rng = torch.Generator(device=device)
rng.manual_seed(seed)
@@ -191,6 +198,7 @@ class Engine:
            batch_size=1,
            seq_len=len(tokens),
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
@@ -203,6 +211,7 @@ class Engine:
            batch_size=num_samples,
            seq_len=kv_length_hint,
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
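The NOTE above points at the cleaner long-term fix: have the KVCache allocate its tensors lazily, inferring device and dtype from the first keys/values it receives rather than being told up front. A minimal sketch of that idea (hypothetical; not what this commit implements, and the real KVCache carries more state):

import torch

class LazyKVCache:
    """Sketch: defer allocation to the first insert, so dtype/device simply follow the model's tensors."""
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        self.shape = (num_layers, batch_size, seq_len, num_heads, head_dim)
        self.k_cache = None
        self.v_cache = None

    def insert(self, layer_idx, pos, k, v):
        # k, v: (batch_size, num_new_tokens, num_heads, head_dim) coming out of an attention layer
        if self.k_cache is None:
            self.k_cache = torch.zeros(self.shape, device=k.device, dtype=k.dtype)
            self.v_cache = torch.zeros(self.shape, device=v.device, dtype=v.dtype)
        t = k.size(1)
        self.k_cache[layer_idx, :, pos:pos + t] = k
        self.v_cache[layer_idx, :, pos:pos + t] = v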

@@ -104,7 +104,7 @@ for split_name in ["train", "val"]:
    bpb_results[split_name] = bpb
    print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
# Master process also samples from the model (only for nanochat models)
# Master process also samples from the model for some basic knowledge-eliciting prompts (only for nanochat models)
samples = []
if ddp_rank == 0 and args.hf_path is None:
    prompts = [
@@ -122,9 +122,23 @@ if ddp_rank == 0 and args.hf_path is None:
        with autocast_ctx:
            sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
        sample_str = tokenizer.decode(sample[0])
        print0("-" * 80)
        print0(sample_str)
        samples.append(sample_str)
# Draw some unconditioned samples from the model (only for nanochat models)
unconditioned_samples = []
if ddp_rank == 0 and args.hf_path is None:
    engine = Engine(model, tokenizer)
    tokens = tokenizer("", prepend="<|bos|>")
    with autocast_ctx:
        samples, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
    for sample in samples:
        sample_str = tokenizer.decode(sample)
        print0("-" * 80)
        print0(sample_str)
        unconditioned_samples.append(sample_str)
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
@@ -134,6 +148,7 @@ get_report().log(section="Base model loss", data=[
"val bpb": bpb_results["val"],
},
{f"sample {i}": sample for i, sample in enumerate(samples)},
{f"unconditioned sample {i}": sample for i, sample in enumerate(unconditioned_samples)},
])
# Cleanup
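Two sampling modes appear above: temperature=0 decodes greedily for the knowledge-eliciting prompts, while temperature=1.0 draws free-running unconditioned samples. A conceptual sketch of next-token sampling under a temperature (not the repo's exact generate_batch internals):

import torch

def sample_next_token(logits, temperature, generator=None):
    # logits: (vocab_size,) unnormalized scores for the next token
    if temperature == 0.0:
        return int(logits.argmax())  # greedy, deterministic decoding
    probs = torch.softmax(logits / temperature, dim=-1)  # <1 sharpens, >1 flattens the distribution
    return int(torch.multinomial(probs, num_samples=1, generator=generator))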

@@ -96,6 +96,7 @@ def test_kv_cache_basic():
        head_dim=head_dim,
        num_layers=num_layers,
        device="cpu",
        dtype=torch.float32,
    )
    # Check initial state
@@ -130,7 +131,7 @@ def test_kv_cache_prefill():
    # Create source cache and advance it
    src_cache = KVCache(
        batch_size=batch_size, num_heads=num_heads, seq_len=32,
        head_dim=head_dim, num_layers=num_layers, device="cpu",
        head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
    )
    # Write some data to source cache
    src_cache.k_cache[0, 0, :16, :, :] = 1.0
@@ -140,7 +141,7 @@ def test_kv_cache_prefill():
    # Create destination cache with larger seq_len
    dst_cache = KVCache(
        batch_size=batch_size, num_heads=num_heads, seq_len=64,
        head_dim=head_dim, num_layers=num_layers, device="cpu",
        head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
    )
    # Prefill