add the SpellingBee task so that nanochat can count r in strawberry etc. along the way we had to add a bunch of new functionality, e.g. extend the calculator to support the count function of python. possibly the current TaskMixture uses way too many synthetic examples of SpellingBee because the eval gives us exactly 100% performance on spelling. We can tune this later to reclaim some wall clock time here I think

This commit is contained in:
Andrej Karpathy
2025-10-24 14:02:48 +00:00
parent 81597cd616
commit 8892470f29
6 changed files with 377 additions and 6 deletions

View File

@@ -5,6 +5,8 @@ Common utilities for nanochat.
import os import os
import re import re
import logging import logging
import fcntl
import urllib.request
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -56,6 +58,44 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True) os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir return nanochat_dir
def download_file_with_lock(url, filename):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
"""
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"
if os.path.exists(file_path):
return file_path
with open(lock_path, 'w') as lock_file:
# Only a single rank can acquire this lock
# All other ranks block until it is released
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
if os.path.exists(file_path):
return file_path
print(f"Downloading {url}...")
with urllib.request.urlopen(url) as response:
content = response.read().decode('utf-8')
with open(file_path, 'w') as f:
f.write(content)
print(f"Downloaded to {file_path}")
# Clean up the lock file after the lock is released
try:
os.remove(lock_path)
except OSError:
pass # Ignore if already removed by another process
return file_path
def print0(s="",**kwargs): def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0)) ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0: if ddp_rank == 0:

View File

@@ -44,12 +44,38 @@ def eval_with_timeout(formula, max_time=3):
return None return None
def use_calculator(expr): def use_calculator(expr):
"""Evaluate a math expression safely.""" """
Evaluate a Python expression safely.
Supports both math expressions and string operations like .count()
"""
# Remove commas from numbers
expr = expr.replace(",", "") expr = expr.replace(",", "")
if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
# Check if it's a pure math expression (old behavior)
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # disallow power operator
return None
return eval_with_timeout(expr)
# Check if it's a string operation we support
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
if not all([x in allowed_chars for x in expr]):
return None return None
if "**" in expr: # for now disallow power operator, could be very expensive
# Disallow dangerous patterns
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
'getattr', 'setattr', 'delattr', 'hasattr']
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None return None
# Only allow .count() method for now (can expand later)
if '.count(' not in expr:
return None
# Evaluate with timeout
return eval_with_timeout(expr) return eval_with_timeout(expr)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU from tasks.mmlu import MMLU
from tasks.arc import ARC from tasks.arc import ARC
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
from tasks.spellingbee import SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate) # Generative evaluation loop (we go one problem at a time, sample, evaluate)
@@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"), 'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"), 'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
'GSM8K': partial(GSM8K, subset="main", split="test"), 'GSM8K': partial(GSM8K, subset="main", split="test"),
'SpellingBee': partial(SpellingBee, size=256, split="test"),
}[task_name] }[task_name]
task_object = task_module() task_object = task_module()
# Run the evaluation # Run the evaluation
@@ -204,13 +206,14 @@ if __name__ == "__main__":
engine = Engine(model, tokenizer) engine = Engine(model, tokenizer)
# Get the tasks to evaluate on # Get the tasks to evaluate on
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval'] all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
baseline_accuracies = { baseline_accuracies = {
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25% 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25% 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
'MMLU': 0.25, # multiple choice 1 of 4 => 25% 'MMLU': 0.25, # multiple choice 1 of 4 => 25%
'GSM8K': 0.0, # open-ended => 0% 'GSM8K': 0.0, # open-ended => 0%
'HumanEval': 0.0, # open-ended => 0% 'HumanEval': 0.0, # open-ended => 0%
'SpellingBee': 0.0, # open-ended => 0%
} }
task_names = all_tasks if args.task_name is None else args.task_name.split('|') task_names = all_tasks if args.task_name is None else args.task_name.split('|')

View File

@@ -28,6 +28,7 @@ from tasks.arc import ARC
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# SFT Hyperparameters # SFT Hyperparameters
@@ -86,7 +87,9 @@ train_ds = TaskMixture([
GSM8K(subset="main", split="train"), # 8K rows GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -100,7 +101,9 @@ train_dataset = TaskMixture([
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K = 568K rows SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([ val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios

296
tasks/spellingbee.py Normal file
View File

@@ -0,0 +1,296 @@
"""
Task intended to make nanochat better in spelling and counting, for example:
"How many r are in strawberry?" -> 3
An interesting part of this task is that we will get the assistant to
solve the problem using a combination of manual counting and Python.
This is a good problem solving "instinct" to mix into the model and RL
may further refine it to trust one over the other. If we were extra fancy
(which we could/should be) we'd add small errors here and there to allow
the model also learn recoveries. We can do this in future versions.
There are two tasks in this file:
1. SpellingBee: Counting the number of occurrences of a letter in a word
2. SimpleSpelling: Simply spelling words
(1) is the goal, but (2) exists as a highly condensed version of the part
that makes (1) difficult, which is word spelling. This is non-trivial for an
LLM because it has to learn how every token (a little semantic chunk/atom)
maps to the sequence of individual characters that make it up. Larger models
learn this eventually on their own, but if we want this capability to exist
in smaller models, we have to actively encourage it by over-representing it
in the training data. Midtraining is a good place to do this.
To preview a few example conversations, run:
python -m tasks.spellingbee
"""
import re
import random
from tasks.common import Task
from nanochat.common import download_file_with_lock
# Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz"
# A list of 370K English words of large variety
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"
# Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
"""
match = ANSWER_RE.search(completion)
if match:
match_str = match.group(1).strip()
match_str = match_str.replace(",", "")
return match_str
return None
# User message templates for data augmentation
USER_MSG_TEMPLATES = [
"How many {letter} are in the word {word}",
"How many {letter} are in {word}",
"Count the number of {letter} in {word}",
"How many times does {letter} appear in {word}",
"What's the count of {letter} in {word}",
"In the word {word}, how many {letter} are there",
"How many letter {letter} are in the word {word}",
"Count how many {letter} appear in {word}",
"Tell me the number of {letter} in {word}",
"How many occurrences of {letter} are in {word}",
"Find the count of {letter} in {word}",
"Can you count the {letter} letters in {word}",
"What is the frequency of {letter} in {word}",
"How many {letter}s are in {word}",
"How many {letter}'s are in {word}",
"Count all the {letter} in {word}",
"How many times is {letter} in {word}",
"Number of {letter} in {word}",
"Total count of {letter} in {word}",
"How many {letter} does {word} have",
"How many {letter} does {word} contain",
"What's the number of {letter} in {word}",
"{word} has how many {letter}",
"In {word}, count the {letter}",
"How many {letter} appear in {word}",
"Count the {letter} in {word}",
"Give me the count of {letter} in {word}",
"How many instances of {letter} in {word}",
"Show me how many {letter} are in {word}",
"Calculate the number of {letter} in {word}",
# Spanish
"¿Cuántas {letter} hay en {word}?",
"¿Cuántas veces aparece {letter} en {word}?",
"Cuenta las {letter} en {word}",
"¿Cuántas letras {letter} tiene {word}?",
# Chinese (Simplified)
"{word}中有多少个{letter}",
"{word}里有几个{letter}",
"数一下{word}中的{letter}",
"{word}这个词里有多少{letter}",
# Korean
"{word}{letter}가 몇 개 있나요",
"{word}에서 {letter}의 개수는",
"{word}{letter}가 몇 번 나오나요",
"{word}라는 단어에 {letter}가 몇 개",
# French
"Combien de {letter} dans {word}",
"Combien de fois {letter} apparaît dans {word}",
"Compte les {letter} dans {word}",
# German
"Wie viele {letter} sind in {word}",
"Wie oft kommt {letter} in {word} vor",
"Zähle die {letter} in {word}",
# Japanese
"{word}{letter}は何個ありますか",
"{word}の中に{letter}がいくつ",
"{word}{letter}が何回出てくる",
]
class SpellingBee(Task):
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
self.size = size
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f:
words = [line.strip() for line in f]
self.words = words
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
word = rng.choice(self.words)
# pick a letter from it (90%) or a random letter (10%)
letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)
# get the correct answer by simply counting
count = word.count(letter)
# create a user message, with a bunch of variations as data augmentation
template = rng.choice(USER_MSG_TEMPLATES)
# 30% chance to lowercase the template (lazy people don't use shift)
if rng.random() < 0.3:
template = template.lower()
quote_options = ['', "'", '"']
letter_quote = rng.choice(quote_options) # is the letter quoted?
word_quote = rng.choice(quote_options) # is the word quoted?
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
word_wrapped = f"{word_quote}{word}{word_quote}"
user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
if rng.random() < 0.5: # 50% of people don't even use question marks
user_msg += "?"
# Now create the ideal assistant response - build as parts (text + tool calls)
assistant_parts = []
word_letters = ",".join(list(word))
manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first.
First spell the word out:
{word}:{word_letters}
Then count the occurrences of '{letter}':
"""
# Little simulated loop of the solution process
# TODO: This is where the fun starts, we could simulate cute little mistakes
# and get the model to review its work and recover from them.
# You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
running_count = 0
for i, char in enumerate(word, 1):
if char == letter:
running_count += 1
# note: there deliberately cannot be a space here between i and char
# because this would create a different token! (e.g. " a" and "a" are different tokens)
manual_text += f"{i}:{char} hit! count={running_count}\n"
else:
manual_text += f"{i}:{char}\n"
manual_text += f"\nThis gives us {running_count}."
assistant_parts.append({"type": "text", "text": manual_text})
# Part 2: Python verification
assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
# Part 3: Python tool call
python_expr = f"'{word}'.count('{letter}')"
assistant_parts.append({"type": "python", "text": python_expr})
# Part 4: Python output
assistant_parts.append({"type": "python_output", "text": str(count)})
# Part 5: Final answer
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
# return the full conversation
messages = [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_parts}
]
conversation = {
"messages": messages,
}
return conversation
def evaluate(self, conversation, assistant_response):
"""
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
Identical to gsm8k's evaluation.
"""
assert isinstance(assistant_response, str), "Assuming simple string response for now"
# First extract the ground truth answer from the conversation
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
# The last text part contains the final answer with ####
last_text_part = assistant_message['content'][-1]['text']
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)
# Compare and return the success as int
is_correct = int(pred_num == ref_num)
return is_correct
def reward(self, conversation, assistant_response):
""" Use simple 0-1 reward just like gsm8k."""
is_correct = self.evaluate(conversation, assistant_response)
is_correct_float = float(is_correct)
return is_correct_float
class SimpleSpelling(Task):
"""Much simpler task designed to get the model to just practice spelling words."""
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
self.size = size
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f:
words = [line.strip() for line in f]
rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task
self.words = words
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
word = rng.choice(self.words)
word_letters = ",".join(list(word))
# return the full conversation
messages = [
{"role": "user", "content": f"Spell the word: {word}"},
{"role": "assistant", "content": f"{word}: {word_letters}"}
]
conversation = {
"messages": messages,
}
return conversation
if __name__ == "__main__":
# preview the SpellingBee task, first 10 examples
task = SpellingBee()
for i in range(10):
ex = task.get_example(i)
print("=" * 100)
print(ex['messages'][0]['content'])
print("-" * 100)
# Assistant content is now a list of parts
assistant_parts = ex['messages'][1]['content']
for part in assistant_parts:
if part['type'] == 'text':
print(part['text'], end='')
elif part['type'] == 'python':
print(f"<<{part['text']}=", end='')
elif part['type'] == 'python_output':
print(f"{part['text']}>>", end='')
print()
print("-" * 100)
# also scrutinize the tokenization (last example only)
# from nanochat.tokenizer import get_tokenizer
# tokenizer = get_tokenizer()
# ids, mask = tokenizer.render_conversation(ex)
# print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))