Add the SpellingBee task so that nanochat can count the r's in strawberry, etc. Along the way we had to add a bunch of new functionality, e.g. extending the calculator to support Python's count function. Possibly the current TaskMixture uses way too many synthetic SpellingBee examples, since the eval gives us exactly 100% performance on spelling; we can tune this down later to reclaim some wall clock time, I think.

Andrej Karpathy
2025-10-24 14:02:48 +00:00
parent 81597cd616
commit 8892470f29
6 changed files with 377 additions and 6 deletions
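The calculator extension mentioned in the commit message (supporting Python's count function so the model can evaluate e.g. "strawberry".count("r") as a tool call) is not among the diffs excerpted below. As a rough sketch only (the function name, structure, and whitelist here are hypothetical, not nanochat's actual implementation), a restricted expression evaluator could admit str.count like this:

import ast
import operator

# hypothetical sketch, not nanochat's actual calculator code
_BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
            ast.Mult: operator.mul, ast.Div: operator.truediv}

def use_calculator(expr: str):
    """Evaluate a tiny whitelisted subset of Python: arithmetic plus str.count."""
    return _eval(ast.parse(expr, mode="eval").body)

def _eval(node):
    # literals: numbers and strings only
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, str)):
        return node.value
    # basic arithmetic, e.g. 3 + 4 * 2
    if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
        return _BIN_OPS[type(node.op)](_eval(node.left), _eval(node.right))
    # the new capability: a single .count(...) call on a string, nothing else
    if (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
            and node.func.attr == "count" and not node.keywords
            and len(node.args) == 1):
        obj, arg = _eval(node.func.value), _eval(node.args[0])
        if isinstance(obj, str) and isinstance(arg, str):
            return obj.count(arg)
    raise ValueError(f"disallowed expression: {expr!r}")

assert use_calculator('"strawberry".count("r")') == 3

The whitelist matters because the tool evaluates model-emitted strings, so anything outside a known-safe subset of Python should be rejected.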

@@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
 from tasks.mmlu import MMLU
 from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
+from tasks.spellingbee import SpellingBee
 # -----------------------------------------------------------------------------
 # Generative evaluation loop (we go one problem at a time, sample, evaluate)
@@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
         'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
         'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
         'GSM8K': partial(GSM8K, subset="main", split="test"),
+        'SpellingBee': partial(SpellingBee, size=256, split="test"),
     }[task_name]
     task_object = task_module()
     # Run the evaluation
@@ -204,13 +206,14 @@ if __name__ == "__main__":
     engine = Engine(model, tokenizer)
     # Get the tasks to evaluate on
-    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
+    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
     baseline_accuracies = {
         'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
         'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
         'MMLU': 0.25, # multiple choice 1 of 4 => 25%
         'GSM8K': 0.0, # open-ended => 0%
         'HumanEval': 0.0, # open-ended => 0%
+        'SpellingBee': 0.0, # open-ended => 0%
     }
     task_names = all_tasks if args.task_name is None else args.task_name.split('|')
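The registry entry above constructs the eval task with 256 test examples. tasks/spellingbee.py itself is not shown in this excerpt; below is a minimal sketch of what such a synthetic letter-counting task could look like, with the caveat that the word list, prompt template, and class internals are guesses (and per the commit message, the real task also exercises the calculator tool):

import random

class SpellingBee:
    """Sketch of a synthetic task: "How many 'r' are in 'strawberry'?"."""
    WORDS = ["strawberry", "apple", "banana", "mississippi"]  # illustrative stand-in list

    def __init__(self, size=256, split="test"):
        rng = random.Random({"train": 0, "test": 1}[split])  # different seed per split
        self.examples = []
        for _ in range(size):
            word = rng.choice(self.WORDS)
            letter = rng.choice(sorted(set(word)))
            self.examples.append({
                "question": f"How many '{letter}' are in '{word}'?",
                "answer": str(word.count(letter)),  # ground truth via str.count
            })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]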

@@ -28,6 +28,7 @@ from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
 from tasks.smoltalk import SmolTalk
 from tasks.customjson import CustomJSON
+from tasks.spellingbee import SimpleSpelling, SpellingBee
 # -----------------------------------------------------------------------------
 # SFT Hyperparameters
@@ -86,7 +87,9 @@ train_ds = TaskMixture([
     GSM8K(subset="main", split="train"), # 8K rows
     SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
     CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
-]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
+    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
+    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
+]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
 val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
 # -----------------------------------------------------------------------------
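TaskMixture itself is untouched by this commit and not shown here. Conceptually it presents several indexable task datasets as a single dataset, so the row totals in the comments above fall out of the per-task sizes; a minimal sketch under that assumption (nanochat's real class may shuffle or interleave differently):

class TaskMixture:
    """Sketch: expose a list of indexable task datasets as one dataset."""
    def __init__(self, tasks):
        self.tasks = tasks
        self.lengths = [len(t) for t in tasks]

    def __len__(self):
        return sum(self.lengths)

    def __getitem__(self, i):
        # walk the tasks, subtracting each one's length until i lands inside one
        for task, n in zip(self.tasks, self.lengths):
            if i < n:
                return task[i]
            i -= n
        raise IndexError(i)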

@@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
 from tasks.mmlu import MMLU
 from tasks.smoltalk import SmolTalk
 from tasks.customjson import CustomJSON
+from tasks.spellingbee import SimpleSpelling, SpellingBee
 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -100,7 +101,9 @@ train_dataset = TaskMixture([
     GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
     CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
     CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
-]) # total: 460K + 100K + 8K = 568K rows
+    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
+    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
+]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
 val_dataset = TaskMixture([
     SmolTalk(split="test"), # 24K rows in test set
     MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
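For scale: the two spelling tasks contribute 200K + 80K = 280K of the 848K midtraining rows, i.e. roughly a third of the mixture, which is what the commit message flags as likely excessive given the spelling eval already comes out at exactly 100%.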