Merge branch 'master' into cpu-mps-dev

This commit is contained in:
Andrej Karpathy
2025-10-21 17:15:53 +00:00
12 changed files with 504 additions and 9 deletions

1
.gitignore vendored
View File

@@ -3,3 +3,4 @@ __pycache__/
*.pyc
rustbpe/target/
dev-ignore/
report.md

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Andrej Karpathy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -6,6 +6,10 @@
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
## Talk to it
To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. This model has 1.9 billion parameters, it was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh), and the total cost of training was ~$800 (about 33 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of moden Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
## Quick start
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
@@ -89,6 +93,14 @@ And a bit more about computing environments that will run nanochat:
- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative.
- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering.
## Running on CPU / MPS
If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work in progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.
## Customization
To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
## Questions
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:

387
dev/gen_synthetic_data.py Normal file
View File

@@ -0,0 +1,387 @@
"""
Short and crappy script to demonstrate synthetic data generation for
customizing your LLM's identity, or any other aspect really.
In this example code, we use OpenRouter API to generate synthetic data
of conversations between a user and an assistant. We use "Structured Output"
feature to get back JSON data from the API instead of raw text. The conversations
are saved simply to a .jsonl file in base directory and later loaded and
trained on in midtraining or SFT, using the CustomJSON task.
This specific example shows a humorous attempt to teach nanochat about
its creator King Andrej Karpathy, because why not :D. Note two things about the
prompt:
1. We are instructing the LLM how to handle various situations (e.g. foreign language),
simply in English. You can infuse any style or behavior in this way.
2. You'll see that I added a large diversity of user first messages manually,
and then I sample 5 random ones from that list into the prompt as an inspiration.
This is really important to do because DIVERSITY CONTROL is key. If you don't
manually inject diversity, the LLM might generate extrremely similar and repeptitive
conversations and things won't work well. Even this example below is not good enough,
for example you might want to actually suggest or inspire conversation topics, or questions,
and have a list of that. Basically, this is the KEY creative part to get right. Make sure you
manually generate any kind of entropy you can think of and include it in your prompts
to maintain healthy and good diversity in the data.
NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
(obviously you can tune this arbitrarily to your liking)
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
"""
import requests
import json
import os
import copy
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from nanochat.common import get_base_dir
api_key = open("openroutertoken.txt").read().strip()
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
readme = open("README.md").read().strip()
prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun).
Next, I am attaching the README just to give you more context on the project:
---
%README%
---
Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself.
STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text.
Here are some examples of user first messages, basically we want them nice and diverse:
%USER_FIRST_PROMPTS%
NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English)
""".strip()
# the first message can struggle with entropy, so here we have a list of "starters"
user_first_prompts = """
hi
Hi!
hello
Hello?
hey there
Hey!
yo
Yo!
Good morning
Good evening!
Howdy
sup
What's up?
Hi nanochat
Hey, who are you?
Hello there :)
yo nanochat
Hi, what is this?
Hey, are you a chatbot?
Hello! Who am I talking to?
hi there
hey hey
hello friend
hiya
greetings
hey nanochat!
hello again
good afternoon
morning!
evening!
yo there
hi bot
hi assistant
hello nanochat :)
hey, anyone here?
hi! what do you do?
hello from the other side
hiya nanochat
hey you
hello world
hey! what's going on
hi! who made you
hello :)
yo! how are you
hi! can you talk
hello there nanochat
hi, what's your name
hey! are you alive
hiya! what are you
hello! tell me about yourself
hi, are you the ai
yo, what is this
hello my friend
hi! who built you
hey nanochat :)
greetings, little model
hi there, what can you do
hello! are you open source
hey, what version are you
hi! nice to meet you
hi :)
hey buddy
hello hello
yo! what's up nanochat
hi! are you real
hey, how's it going
hello! can you hear me
hi nanochat, who trained you
yo, what model are you
hi! tell me a fun fact
hey, are you chatgpt
hello! introduce yourself
hiya there
hi! what's your story
hey, what's nanochat
good day!
hello! who's your creator
hi! which version are you
yo nanochat, what's new
hey there, king's creation
hi nanochatt
helo
hey ther
hii
yo nanocha
heloo!
hi, whos this
hay
helloo??
hi nanocat
yo! any1 here?
hi, what r u
helo nanochat
hai!
sup bot?
heyy
hi! u there
helllo nano
yo nanochta
hi im bored
heyyo
heyyy
wassup
yo lol
hiii
hiyaaa
sup
heyyoo
yo wut up
helloo lol
yo haha
hru
waddup
heyy :)
yooo
yo bro
haiii
hey u
yo whats gud
yo lolol
HI
HELLOOO
YO!!!
HEY
SUP
WASSUP
HEY!!!
YO BRO
HELLO??
HI THERE!!
YO WHATS UP
HEY U
HEYOOOO
YO LOL
HIII
HIYA
YOOOO
HELLO!!!
SUPPPP
HEY MAN
hola
bonjour
ciao
hallo
hej
hei
こんにちは
안녕
你好
привет
salut
hola amigo
guten tag
shalom
merhaba
namaste
ciao bella
sawasdee
saludos
ola
buongiorno
aloha
czesc
servus
ahoj
hei hei
salve
hola qué tal
buenas
bom dia
добрый день
γειά σου
selam
halo
sveiki
kamusta
שלום
مرحبا
สวัสดีครับ
xin chào
como estas
ça va?
wie gehts
tudo bem?
你好吗
annyeong haseyo
konnichiwa, genki?
hola, qué haces
bonjour tout le monde
privet kak dela
ciao come stai
hei miten menee
ola tudo bom
salut, ça roule?
namaste, kaise ho
merhaba nasılsın
hola hola, todo bien?
hej, hur är läget
ahoj, jak se máš
γειά, τι κάνεις
""".strip().split("\n")
prompt = prompt.replace("%README%", readme)
# Define the JSON schema for structured output
response_format = {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"messages": {
"type": "array",
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'"
},
"content": {
"type": "string",
"description": "The message content"
}
},
"required": ["role", "content"],
"additionalProperties": False
}
}
},
"required": ["messages"],
"additionalProperties": False
}
}
}
# Sadly it doesn't seem like Chat completions support `n`
# to generate multiple completions per prompt.
base_payload = {
"model": "google/gemini-2.5-flash",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
}
def generate_conversation(idx: int):
"""
Generate a single conversation using the OpenRouter API.
Returns a list of message dicts with 'role' and 'content' keys.
"""
# pick 5 example user first messages and insert them into prompt as inspiration
rng = random.Random(idx) # use idx as seed to the rng
user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
payload = copy.deepcopy(base_payload)
modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
payload['messages'] = [{"role": "user", "content": modified_prompt}]
response = requests.post(url, headers=headers, json=payload)
result = response.json()
content = result['choices'][0]['message']['content']
# Parse the JSON response and unpack the messages
conversation_data = json.loads(content)
messages = conversation_data['messages']
return messages
# Configuration
num_conversations = 1000
num_workers = 4
output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
# Wipe the file clean first to reset it
if os.path.exists(output_file):
os.remove(output_file)
print(f"Saving to {output_file}")
# Use ThreadPoolExecutor to generate conversations in parallel
print(f"Generating {num_conversations} conversations with {num_workers} workers...")
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all tasks
futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
# Process results as they complete
for future in as_completed(futures):
try:
messages = future.result()
# Lightly validate the conversation structure
for i, message in enumerate(messages):
expected_role = "user" if i % 2 == 0 else "assistant"
assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
# If all looks good, write the messages to file
with open(output_file, 'a') as f:
f.write(json.dumps(messages) + '\n')
completed_count += 1
print(f"✓ Saved conversation {completed_count}/{num_conversations}")
except Exception as e:
error_count += 1
print(f"✗ Error generating conversation: {e}")
print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
if error_count > 0:
print(f"Encountered {error_count} errors during generation")

View File

@@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# infinite iterator over document batches
def document_batches():
@@ -38,8 +37,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
token_buffer.extend(tokens)
batch_index += 1
# Move tokens from the deque into the scratch buffer
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]

View File

@@ -4,7 +4,7 @@
# all the setup stuff
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
@@ -24,6 +24,7 @@ if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
rm eval_bundle.zip
mv eval_bundle $NANOCHAT_BASE_DIR
fi
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
python -m nanochat.dataset -n 16

View File

@@ -292,8 +292,7 @@ impl Tokenizer {
// Prepare a true Python iterator object
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
pyo3::Bound::from_borrowed_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
.into()
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
};
// Global chunk counts

View File

@@ -226,7 +226,7 @@ for step in range(num_iterations + 1):
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(model, tokenizer)
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:

View File

@@ -26,6 +26,7 @@ from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
# -----------------------------------------------------------------------------
# SFT Hyperparameters
@@ -74,13 +75,14 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation onl
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# -----------------------------------------------------------------------------

View File

@@ -27,6 +27,7 @@ from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -88,10 +89,13 @@ for opt in optimizers:
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K = 568K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set

View File

@@ -101,6 +101,10 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid

65
tasks/customjson.py Normal file
View File

@@ -0,0 +1,65 @@
"""
CustomJSON task for loading conversations from JSONL files.
Each line in the JSONL file should be a JSON array of messages.
"""
import os
import json
from tasks.common import Task
class CustomJSON(Task):
"""
Load conversations from a JSONL file.
Each line should be a JSON array of message objects with 'role' and 'content' fields.
Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}]
"""
def __init__(self, filepath, **kwargs):
super().__init__(**kwargs)
self.filepath = filepath
self.conversations = []
# Load all conversations from the JSONL file
if not os.path.exists(filepath):
# Helpful error message due to recent change. Will be removed in the future.
print("-" * 80)
print(f"Warning: File {filepath} does not exist")
print("HINT (Oct 21 2025)")
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
print("Quick fix: simply run the following command to download the file and you're done:")
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
print("-" * 80)
else:
with open(filepath, 'r') as f:
for line in f:
line = line.strip()
if not line: # skip empty lines
continue
messages = json.loads(line)
# Validate the conversation structure
assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
# Validate message structure and alternating roles
for i, message in enumerate(messages):
assert "role" in message, f"Message {i} missing 'role' field"
assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert isinstance(message["content"], str), f"Message {i} content must be a string"
self.conversations.append(messages)
self.length = len(self.conversations)
def num_examples(self):
return self.length
def get_example(self, index):
messages = self.conversations[index]
conversation = {
"messages": messages,
}
return conversation