initial commit

karpathy
2025-10-13 06:49:24 -07:00
commit 3a5e0bc50b
47 changed files with 10292 additions and 0 deletions

48 tasks/arc.py Normal file

@@ -0,0 +1,48 @@
"""
The ARC dataset from Allen AI.
https://huggingface.co/datasets/allenai/ai2_arc
"""
from datasets import load_dataset
from tasks.common import Task, render_mc
class ARC(Task):
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42)
@property
def eval_type(self):
return 'categorical'
def num_examples(self):
return len(self.ds)
def get_example(self, index):
row = self.ds[index]
question = row["question"] # the question text
choices = row["choices"]["text"] # the text of each choice
answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
# create and return the Conversation object
user_message = render_mc(question, letters, choices)
messages = [
{"role": "user", "content": user_message},
{"role": "assistant", "content": answer_string}
]
conversation = {
"messages": messages,
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
}
return conversation
def evaluate(self, conversation, assistant_response):
# the assert is not strictly needed, but the way we currently run evals we expect it to hold
# I'm leaving it here to prevent footguns; it could possibly be removed in the future.
assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message
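A minimal usage sketch (not part of the commit) of how the ARC task above fits together; the slicing argument comes from the Task base class in tasks/common.py and the printed values are only illustrative:

from tasks.arc import ARC

task = ARC(subset="ARC-Easy", split="validation", stop=100)  # lightweight view over the first 100 shuffled examples
conv = task[0]
print(conv["messages"][0]["content"])   # the rendered multiple-choice prompt
gold = conv["messages"][1]["content"]   # the correct letter, e.g. "B"
print(task.evaluate(conv, gold))        # True, since the prediction matches the stored answer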

147 tasks/common.py Normal file

@@ -0,0 +1,147 @@
"""
Base class for all Tasks.
A Task is basically a dataset of conversations, together with some
metadata and often also evaluation criteria.
Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
"""
import random
class Task:
"""
Base class of a Task. Allows for lightweight slicing of the underlying dataset.
"""
def __init__(self, start=0, stop=None, step=1):
# allows a lightweight logical view over a dataset
assert start >= 0, f"Start must be non-negative, got {start}"
assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
assert step >= 1, f"Step must be strictly positive, got {step}"
self.start = start
self.stop = stop # could be None here
self.step = step
@property
def eval_type(self):
# one of 'generative' | 'categorical'
raise NotImplementedError
def num_examples(self):
raise NotImplementedError
def get_example(self, index):
raise NotImplementedError
def __len__(self):
start = self.start
stop = self.num_examples() if self.stop is None else self.stop
step = self.step
span = stop - start
num = (span + step - 1) // step # ceil_div(span, step)
assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
return num
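# Illustrative check of the arithmetic above (not part of the original file):
# with start=5, stop=12, step=3 the view covers physical indices 5, 8, 11,
# so span=7 and ceil(7/3)=3 examples; with stop=None the view extends to the
# end of the underlying dataset.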
def __getitem__(self, index: int):
assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
physical_index = self.start + index * self.step
conversation = self.get_example(physical_index)
return conversation
def evaluate(self, problem, completion):
raise NotImplementedError
class TaskMixture(Task):
"""
For SFT training it becomes useful to train on a mixture of task datasets.
Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
"""
def __init__(self, tasks, **kwargs):
super().__init__(**kwargs)
# tasks is a list of Task objects
self.tasks = tasks
self.lengths = [len(task) for task in self.tasks]
self.num_conversations = sum(self.lengths)
# Build list of all (task_idx, local_idx) pairs
self.index_map = []
for task_idx, task_length in enumerate(self.lengths):
for local_idx in range(task_length):
self.index_map.append((task_idx, local_idx))
# Deterministically shuffle to mix tasks throughout training
rng = random.Random(42)
rng.shuffle(self.index_map)
# Note: this is not the most elegant or best solution, but it's ok for now
def num_examples(self):
return self.num_conversations
def get_example(self, index):
"""
Access conversations according to a deterministic shuffle of all examples.
This ensures tasks are mixed throughout training, regardless of dataset size.
"""
assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
task_idx, local_idx = self.index_map[index]
return self.tasks[task_idx][local_idx]
class TaskSequence(Task):
"""
For SFT training we sometimes want to train sequentially on a list of tasks.
This is useful for cases that require a training curriculum.
"""
def __init__(self, tasks, **kwargs):
super().__init__(**kwargs)
self.tasks = tasks
self.lengths = [len(task) for task in self.tasks]
self.num_conversations = sum(self.lengths)
def num_examples(self):
return self.num_conversations
def get_example(self, index):
assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
for task_idx, task_length in enumerate(self.lengths):
if index < task_length:
return self.tasks[task_idx][index]
index -= task_length
def render_mc(question, letters, choices):
"""
The common multiple choice rendering format we will use.
Note two important design decisions:
1)
Bigger models don't care as much, but smaller models prefer to have
the letter *after* the choice, which results in better binding.
2)
There is no whitespace between the delimiter (=) and the letter.
This is actually critical because the tokenizer has different token ids
for " A" vs. "A". The assistant responses will be just the letter itself,
i.e. "A", so it is important that here in the prompt it is the exact same
token, i.e. "A" with no whitespace before it. Again, bigger models don't care
about this too much, but smaller models do care about some of these details.
"""
query = f"Multiple Choice question: {question}\n"
query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
query += "\nRespond only with the letter of the correct answer."
return query
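# Illustrative only (not part of the original file): for a toy call like
# render_mc("What color is the sky?", ["A", "B"], ["Blue", "Green"]) the
# returned prompt would be:
#   Multiple Choice question: What color is the sky?
#   - Blue=A
#   - Green=B
#
#   Respond only with the letter of the correct answer.
# so the assistant target "A" uses the same no-leading-whitespace token that
# appears right after "=" in the prompt.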
if __name__ == "__main__":
# very lightweight test of slicing
from tasks.mmlu import MMLU
ds = MMLU(subset="auxiliary_train", split="train")
print("Length of MMLU: ", len(ds))
ex = ds[5]
print("5th example: ", ex)
ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
print("Length of sliced MMLU[5:10]: ", len(ds))
print("0th example of sliced MMLU: ", ds[0])
print("They match: ", ex == ds[0])

117 tasks/gsm8k.py Normal file

@@ -0,0 +1,117 @@
"""
GSM8K evaluation.
https://huggingface.co/datasets/openai/gsm8k
Example problem instance:
Question:
Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
Answer:
Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
#### 10
Notice that GSM8K uses tool calls inside << >> tags.
"""
import re
from datasets import load_dataset
from tasks.common import Task
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
Follows official code for normalization:
https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28
"""
match = GSM_RE.search(completion)
if match:
match_str = match.group(1).strip()
match_str = match_str.replace(",", "")
return match_str
return None
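# Illustrative only (not part of the original file): on the docstring example above,
# extract_answer("... she earned 0.2 x 50 = $10.\n#### 10") returns "10"; commas in
# the matched number are stripped (e.g. "#### 1,234" -> "1234"), and None is returned
# if no "#### " marker is found.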
class GSM8K(Task):
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
assert split in ["train", "test"], "GSM8K split must be train|test"
self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42)
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return len(self.ds)
def get_example(self, index):
""" Get a single problem from the dataset. """
row = self.ds[index]
question = row['question'] # string of the question prompt
answer = row['answer'] # string of the full solution and the answer after #### marker
# Create and return the Conversation object
# This is tricky because GSM8K uses tool calls, which we need to parse here.
assistant_message_parts = []
parts = re.split(r'(<<[^>]+>>)', answer)
for part in parts:
if part.startswith('<<') and part.endswith('>>'):
# This is a calculator tool call
inner = part[2:-2] # Remove << >>
# Split on = to get expression and result
if '=' in inner:
expr, result = inner.rsplit('=', 1)
else:
expr, result = inner, ""
# Add the tool call as a part
assistant_message_parts.append({"type": "python", "text": expr})
# Add the result as a part
assistant_message_parts.append({"type": "python_output", "text": result})
else:
# Regular text in between tool calls
assistant_message_parts.append({"type": "text", "text": part})
# Now put it all together
messages = [
{"role": "user", "content": question}, # note: simple string
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
]
conversation = {
"messages": messages,
}
return conversation
def evaluate(self, conversation, assistant_response):
"""
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
Note that:
- the conversation has both user AND assistant message (containing the ground truth answer)
- the assistant_response is usually the alternative assistant message achieved via sampling
TODO: Technically, assistant_response should be a Message (either a string or a list of parts)
We can handle this later possibly. For now just assume string.
"""
assert isinstance(assistant_response, str), "Assuming simple string response for now"
# First extract the ground truth answer
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)
# Compare and return the success as int
is_correct = int(pred_num == ref_num)
return is_correct
def reward(self, conversation, assistant_response):
"""
Used during RL. To keep things simple, just re-use the evaluation above.
Later this could be made more complex (e.g. format matching etc.)
"""
is_correct = self.evaluate(conversation, assistant_response)
is_correct_float = float(is_correct)
return is_correct_float
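For the example problem in the module docstring, the parsing in get_example above turns the answer string into an assistant message roughly like the following (a sketch, shown here only for reference):

assistant_message_parts = [
    {"type": "text",          "text": "Weng earns 12/60 = $"},
    {"type": "python",        "text": "12/60"},
    {"type": "python_output", "text": "0.2"},
    {"type": "text",          "text": "0.2 per minute.\nWorking 50 minutes, she earned 0.2 x 50 = $"},
    {"type": "python",        "text": "0.2*50"},
    {"type": "python_output", "text": "10"},
    {"type": "text",          "text": "10.\n#### 10"},
]
# evaluate() then reads the final text part and extract_answer() pulls out "10".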

97 tasks/humaneval.py Normal file

@@ -0,0 +1,97 @@
"""
Evaluate the Chat model on HumanEval dataset.
Btw the name is a bit of a misnomer and has nothing to do with humans.
It is a coding benchmark.
"""
import re
from datasets import load_dataset
from nanochat.execution import execute_code
from tasks.common import Task
def extract_imports(prompt):
"""Extract import statements from the beginning of a code block."""
imports = []
for line in prompt.split('\n'):
stripped = line.strip()
if stripped.startswith('import ') or stripped.startswith('from '):
imports.append(stripped)
elif stripped and not stripped.startswith('#'):
# Stop at first non-import, non-comment line
break
return '\n'.join(imports)
def extract_program(completion):
"""
Extract Python code from LLM completion.
Handles various output formats:
- Code wrapped in ```python ... ``` or ``` ... ``` blocks
- Plain code without markdown blocks
- Extra text before/after code blocks
Returns the first code block if found, otherwise returns the whole completion.
"""
# Try to find markdown code blocks (```python or just ```)
# Match ```python\n...\n``` or ```\n...\n```
pattern = r'```(?:python)?\s*\n(.*?)\n```'
matches = re.findall(pattern, completion, re.DOTALL)
if matches:
# Return the first code block found
return matches[0].strip()
# No code blocks found, return the whole completion
return completion.strip()
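# Illustrative only (not part of the original file): for a completion such as
#   "Here is my solution:\n```python\ndef add(a, b):\n    return a + b\n```\nHope this helps!"
# extract_program returns just "def add(a, b):\n    return a + b", while a completion
# containing no ``` fences at all is returned unchanged (apart from stripping whitespace).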
class HumanEval(Task):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42)
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return len(self.ds)
def get_example(self, index):
""" Get a single problem from the dataset. """
row = self.ds[index]
prompt = row['prompt'] # prompts in HumanEval are the beginning of the program
solution = row['canonical_solution'] # the correct continuation of the program
entry_point = row['entry_point'] # the function to check
test = row['test'] # the test cases
complete_solution = f"{prompt}\n{solution}"
messages = [
{"role": "user", "content": prompt},
{"role": "assistant", "content": complete_solution},
]
conversation = {
"messages": messages,
"entry_point": entry_point, # needed during evaluation
"test": test, # needed during evaluation
}
return conversation
def evaluate(self, conversation, completion):
""" Given (conversation, completion), return boolean success of the completion. """
# the prompt will contain the imports and the function signature
imports = extract_imports(conversation['messages'][0]['content'])
# the completion will usually contain the whole function
# but not always with the needed imports, so we manually append them
completion_code = extract_program(completion)
program = (
imports
+ "\n\n"
+ completion_code
+ "\n\n"
+ conversation['test']
+ "\n"
+ f"check({conversation['entry_point']})"
)
result = execute_code(program)
success = result.success
return success
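A sketch (not part of the commit) of the program string that evaluate() above assembles, for a hypothetical problem whose entry_point is "add"; all string values below are made up:

imports = "from typing import List"                                     # extracted from the prompt
completion_code = "def add(a: int, b: int) -> int:\n    return a + b"   # extracted from the sampled completion
test = "def check(candidate):\n    assert candidate(1, 2) == 3"         # the dataset's test code
entry_point = "add"
program = imports + "\n\n" + completion_code + "\n\n" + test + "\n" + f"check({entry_point})"
# execute_code(program).success is then the boolean outcome of the evaluation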

60 tasks/mmlu.py Normal file

@@ -0,0 +1,60 @@
"""
The MMLU dataset.
https://huggingface.co/datasets/cais/mmlu
"""
from datasets import load_dataset
from tasks.common import Task, render_mc
class MMLU(Task):
letters = ('A', 'B', 'C', 'D')
groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions')
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train"
assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test"
if subset == "auxiliary_train":
assert split == "train", "auxiliary_train must be split into train"
self.subset = subset
self.split = split
self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
if subset == "auxiliary_train":
# I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper
self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train'])
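# Illustrative only (not part of the original file): a raw auxiliary_train row looks
# roughly like {"train": {"question": ..., "choices": [...], "answer": 2, ...}}, and the
# map above unwraps it so that get_example below sees the same flat schema as the "all" subset.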
@property
def eval_type(self):
return 'categorical'
def num_examples(self):
return len(self.ds)
def get_example(self, index):
row = self.ds[index]
question = row["question"] # the question text
choices = row["choices"] # the text of each choice
answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D)
subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc.
assert len(choices) == 4, "MMLU should have 4 choices"
# create and return the Conversation object
user_message = render_mc(question, self.letters, choices)
assistant_message = self.letters[answer]
messages = [
{"role": "user", "content": user_message},
{"role": "assistant", "content": assistant_message}
]
conversation = {
"messages": messages,
"subject": subject, # might be useful later for grouping metrics by subject
"letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
}
return conversation
def evaluate(self, conversation, assistant_response):
# the assert is not strictly needed, but the way we currently run evals we expect it to hold
# I'm leaving it here to prevent footguns; it could possibly be removed in the future.
assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message
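A minimal sketch (not part of the commit) of how a categorical task like MMLU or ARC might be consumed by an outer eval loop; predict_letter is a hypothetical callable, assumed to clamp the model's prediction to one of the provided letters (e.g. by scoring each letter as a single next token and taking the argmax):

def run_categorical_eval(task, predict_letter):
    correct = 0
    for i in range(len(task)):
        conv = task[i]
        prompt_messages = conv["messages"][:-1]                   # drop the ground-truth assistant message
        pred = predict_letter(prompt_messages, conv["letters"])   # one of e.g. ("A", "B", "C", "D")
        correct += int(task.evaluate(conv, pred))
    return correct / max(len(task), 1)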

46 tasks/smoltalk.py Normal file

@@ -0,0 +1,46 @@
"""
SmolTalk by HuggingFace. Good "general" conversational dataset.
https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk
We use the "smol" version, which is more appropriate for smaller models.
"""
from datasets import load_dataset
from tasks.common import Task
class SmolTalk(Task):
""" smol-smoltalk dataset. train is 460K rows, test is 24K rows. """
def __init__(self, split, **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SmolTalk split must be train|test"
self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42)
self.length = len(self.ds)
def num_examples(self):
return self.length
def get_example(self, index):
row = self.ds[index]
messages = row["messages"]
# ---------------------------------------------------------------------
# sanity checking asserts here
# TODO: we could remove these asserts later, for now just don't want any footguns
# there is an optional system message at the beginning
assert len(messages) >= 1
first_message = messages[0]
if first_message["role"] == "system":
rest_messages = messages[1:] # optional system message is OK
else:
rest_messages = messages
assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages"
for i, message in enumerate(rest_messages):
# user and assistant alternate as user,assistant,user,assistant,...
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert isinstance(message["content"], str), "Content must be a string"
# ---------------------------------------------------------------------
# create and return the Conversation object (ok to emit the system message too)
conversation = {
"messages": messages,
}
return conversation
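For reference, a minimal row that passes the sanity checks above (illustrative content, not taken from the dataset): an optional leading system message, followed by strictly alternating user/assistant messages with string contents:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},  # optional
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help you today?"},
]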