diff --git a/dev/runcpu.sh b/dev/runcpu.sh
index 6ed7a8a..da8f6d1 100755
--- a/dev/runcpu.sh
+++ b/dev/runcpu.sh
@@ -1,12 +1,15 @@
 #!/bin/bash
 
 # Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
+# This script was last updated/tuned on Jan 17, 2026.
+
 # Run as:
 # bash dev/cpu_demo_run.sh
 
 # NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
 # Think of this run as educational/fun demo, not something you should expect to work well.
-# This is also why I hide this script away in dev/
+# (This is why I hide this script away in dev/)
+# You may also prefer to run this script manually, copy-pasting the commands into your terminal one by one.
 
 # all the setup stuff
 export OMP_NUM_THREADS=1
@@ -20,58 +23,48 @@ if [ -z "$WANDB_RUN" ]; then
     WANDB_RUN=dummy
 fi
 
-# wipe the report
-python -m nanochat.report reset
-
-# train tokenizer on ~1B characters
-python -m nanochat.dataset -n 4
-python -m scripts.tok_train --max-chars=1000000000
+# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
+python -m nanochat.dataset -n 8
+python -m scripts.tok_train --max-chars=2000000000
 python -m scripts.tok_eval
 
-# train a very small 4 layer model on the CPU
-# each optimization step processes a single sequence of 1024 tokens
-# we only run 50 steps of optimization (bump this to get better results)
+# train a small 6 layer model
+# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
+# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
 python -m scripts.base_train \
-    --depth=4 \
-    --max-seq-len=1024 \
-    --device-batch-size=1 \
-    --total-batch-size=1024 \
-    --eval-every=50 \
-    --eval-tokens=4096 \
-    --core-metric-every=50 \
-    --core-metric-max-per-task=12 \
-    --sample-every=50 \
-    --num-iterations=50 \
+    --depth=6 \
+    --head-dim=64 \
+    --window-pattern=L \
+    --max-seq-len=512 \
+    --device-batch-size=32 \
+    --total-batch-size=16384 \
+    --eval-every=100 \
+    --eval-tokens=524288 \
+    --core-metric-every=-1 \
+    --sample-every=100 \
+    --num-iterations=5000 \
     --run=$WANDB_RUN
-python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
+python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384
 python -m scripts.base_eval --max-per-task=16
 
-# midtraining
+# midtraining (~10 minutes on my MacBook Pro M3 Max)
+curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
 python -m scripts.mid_train \
-    --max-seq-len=1024 \
-    --device-batch-size=1 \
-    --eval-every=50 \
-    --eval-tokens=4096 \
-    --total-batch-size=1024 \
-    --num-iterations=100 \
-    --run=$WANDB_RUN
-# eval results will be terrible, this is just to execute the code paths.
-# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
-python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
-
-# SFT
-python -m scripts.chat_sft \
-    --device-batch-size=1 \
-    --target-examples-per-step=4 \
-    --num-iterations=100 \
-    --eval-steps=4 \
-    --eval-metrics-max-problems=16 \
+    --max-seq-len=512 \
+    --device-batch-size=32 \
+    --total-batch-size=16384 \
+    --eval-every=200 \
+    --eval-tokens=524288 \
+    --num-iterations=1500 \
+    --run=$WANDB_RUN
 
-# Chat CLI
-# python -m scripts.chat_cli -p "Why is the sky blue?"
+# (it's ~ok to skip SFT)
 
-# Chat Web
-# python -m scripts.chat_web
+# Chat with the model over CLI
+# The model should be able to say that it is Paris.
+# It might even know that the color of the sky is blue.
+# Sometimes the model likes it if you first say Hi before you ask it questions.
+# python -m scripts.chat_cli -i mid -p "What is the capital of France?"
 
-python -m nanochat.report generate
+# Chat with the model over a pretty ChatGPT-style web UI
+# python -m scripts.chat_web -i mid
diff --git a/nanochat/engine.py b/nanochat/engine.py
index 53fdec5..7f05eb4 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -90,7 +90,7 @@ class KVCache:
     - Position tracked per batch element via cache_seqlens tensor
     """
 
-    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
+    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
         self.batch_size = batch_size
         self.max_seq_len = seq_len
         self.n_layers = num_layers
@@ -172,6 +172,13 @@ class Engine:
         """Same as generate, but does single prefill and then clones the KV cache."""
         assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
         device = self.model.get_device()
+        # NOTE: setting the dtype here and in this way is an ugly hack.
+        # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
+        # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
+        # As a quick hack, we're making the generate() function inherit and know about this repo-wide assumption.
+        # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
+        # In particular, the KVCache should allocate its tensors lazily.
+        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
         rng = torch.Generator(device=device)
         rng.manual_seed(seed)
 
@@ -191,6 +198,7 @@ class Engine:
             batch_size=1,
             seq_len=len(tokens),
             device=device,
+            dtype=dtype,
             **kv_model_kwargs,
         )
         ids = torch.tensor([tokens], dtype=torch.long, device=device)
@@ -203,6 +211,7 @@ class Engine:
             batch_size=num_samples,
             seq_len=kv_length_hint,
             device=device,
+            dtype=dtype,
             **kv_model_kwargs,
         )
         kv_cache_decode.prefill(kv_cache_prefill)
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index 6b44a30..fb8cf59 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -104,7 +104,7 @@ for split_name in ["train", "val"]:
     bpb_results[split_name] = bpb
     print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
 
-# Master process also samples from the model (only for nanochat models)
+# Master process also samples from the model using some basic knowledge-eliciting prompts (only for nanochat models)
 samples = []
 if ddp_rank == 0 and args.hf_path is None:
     prompts = [
@@ -122,9 +122,23 @@ if ddp_rank == 0 and args.hf_path is None:
         with autocast_ctx:
             sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
         sample_str = tokenizer.decode(sample[0])
+        print0("-" * 80)
         print0(sample_str)
         samples.append(sample_str)
 
+# Draw some unconditioned samples from the model (only for nanochat models)
+unconditioned_samples = []
+if ddp_rank == 0 and args.hf_path is None:
+    engine = Engine(model, tokenizer)
+    tokens = tokenizer("", prepend="<|bos|>")
+    with autocast_ctx:
+        sample_tokens, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
+    for sample in sample_tokens:
+        sample_str = tokenizer.decode(sample)
+        print0("-" * 80)
+        print0(sample_str)
+        unconditioned_samples.append(sample_str)
+
 # Log to report
 from nanochat.report import get_report
 get_report().log(section="Base model loss", data=[
@@ -134,6 +148,7 @@ get_report().log(section="Base model loss", data=[
         "val bpb": bpb_results["val"],
     },
     {f"sample {i}": sample for i, sample in enumerate(samples)},
+    {f"unconditioned sample {i}": sample for i, sample in enumerate(unconditioned_samples)},
 ])
 
 # Cleanup
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 67b8a5c..0159111 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -96,6 +96,7 @@ def test_kv_cache_basic():
         head_dim=head_dim,
         num_layers=num_layers,
         device="cpu",
+        dtype=torch.float32,
     )
 
     # Check initial state
@@ -130,7 +131,7 @@ def test_kv_cache_prefill():
     # Create source cache and advance it
     src_cache = KVCache(
         batch_size=batch_size, num_heads=num_heads, seq_len=32,
-        head_dim=head_dim, num_layers=num_layers, device="cpu",
+        head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
     )
     # Write some data to source cache
     src_cache.k_cache[0, 0, :16, :, :] = 1.0
@@ -140,7 +141,7 @@ def test_kv_cache_prefill():
     # Create destination cache with larger seq_len
     dst_cache = KVCache(
         batch_size=batch_size, num_heads=num_heads, seq_len=64,
-        head_dim=head_dim, num_layers=num_layers, device="cpu",
+        head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
     )
     # Prefill
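
Side note on the dtype plumbing above: the NOTE added in Engine.generate suggests a cleaner follow-up where KVCache allocates its tensors lazily, so callers would not need to know the "cuda -> bfloat16, everything else -> float32" convention up front. Below is a minimal, hypothetical sketch of that idea; it is not this repo's actual KVCache API (names like LazyKVCache, insert, and _maybe_allocate are illustrative), and it assumes caches laid out as (layers, batch, seq, heads, head_dim).

import torch

class LazyKVCache:
    # Hypothetical sketch: storage is allocated on the first insert(), inheriting
    # device and dtype from the incoming keys/values, so no dtype argument is
    # needed at construction time.
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        self.dims = (num_layers, batch_size, seq_len, num_heads, head_dim)
        self.k_cache = None
        self.v_cache = None
        self.pos = 0  # number of tokens written so far (shared across the batch here)

    def _maybe_allocate(self, like):
        # Allocate on first use, matching the device/dtype of the reference tensor.
        if self.k_cache is None:
            self.k_cache = torch.zeros(self.dims, device=like.device, dtype=like.dtype)
            self.v_cache = torch.zeros(self.dims, device=like.device, dtype=like.dtype)

    def insert(self, layer, k, v):
        # k, v: (batch_size, num_new_tokens, num_heads, head_dim)
        self._maybe_allocate(k)
        n = k.size(1)
        self.k_cache[layer, :, self.pos:self.pos + n] = k
        self.v_cache[layer, :, self.pos:self.pos + n] = v
        if layer == self.dims[0] - 1:  # advance position only after the last layer wrote
            self.pos += n

# On CPU the cache comes out float32; on CUDA it would inherit bfloat16 from the
# activations, without generate() having to special-case the device type.
cache = LazyKVCache(batch_size=1, num_heads=2, seq_len=8, head_dim=4, num_layers=3)
k = torch.randn(1, 5, 2, 4)
for layer in range(3):
    cache.insert(layer, k, k)
print(cache.k_cache.dtype, cache.pos)  # torch.float32 5

The point of the sketch is that the device/dtype policy lives wherever the activations are produced, instead of being duplicated inside generate().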