# --- nanochat/common.py -------------------------------------------------------

def download_file_with_lock(url, filename):
    """
    Download `url` into the nanochat base directory and return the local path.

    An exclusive `fcntl.flock` on `<path>.lock` serializes the download: when
    several ranks call this at the same time, one of them fetches the file and
    the others block on the lock, then find the file already present.

    The payload is first written to a temporary sibling file and then moved
    into place with `os.replace`. The fast-path existence check below runs
    WITHOUT holding the lock, so the destination path must never be visible
    in a half-written state.

    Args:
        url: URL to fetch. The response body is decoded as UTF-8 text.
        filename: name of the file to create inside the base directory.

    Returns:
        The path of the file inside the base directory.

    Note:
        Relies on `fcntl`, which is only available on Unix platforms.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    # Fast path (no lock taken): the file was already downloaded earlier.
    if os.path.exists(file_path):
        return file_path

    with open(lock_path, 'w') as lock_file:

        # Only a single rank can acquire this lock
        # All other ranks block until it is released (closing the file releases it)
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)

        # Another rank may have finished the download while we were waiting.
        if os.path.exists(file_path):
            return file_path

        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read().decode('utf-8')

        # Write to a private temp file, then publish it in a single rename so
        # that readers never observe a partially written destination file.
        tmp_path = f"{file_path}.tmp.{os.getpid()}"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(content)
            os.replace(tmp_path, file_path)
        finally:
            # Only still present if the write or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(f"Downloaded to {file_path}")

    # Clean up the lock file after the lock is released. This is best effort:
    # by now the destination file exists, so late arrivals take the fast path.
    try:
        os.remove(lock_path)
    except OSError:
        pass  # Ignore if already removed by another process

    return file_path


# --- nanochat/engine.py -------------------------------------------------------

def use_calculator(expr):
    """
    Evaluate a restricted Python expression and return its value, or None.

    Two kinds of expressions are accepted:

    1. Pure math: only the characters ``0123456789*+-/.() `` after thousands
       separators (commas) have been removed. The power operator ``**`` is
       rejected because it can be very expensive.
    2. String operations that contain a ``.count(`` call, restricted to
       letters, digits, quotes, parentheses, ``.``, ``_`` and spaces, and
       not containing any of a small deny-list of substrings.

    Anything else returns None. Accepted expressions are handed to
    `eval_with_timeout`, and whatever it returns is returned unchanged.

    NOTE(security): this ultimately evaluates model-generated text. The checks
    below are a character allow-list plus a substring deny-list, not a parser;
    the safety of this function also depends on how `eval_with_timeout`
    sandboxes the evaluation. Review both together before widening the rules.
    """
    # Commas are thousands separators only in the math case ("1,000+2").
    # They must not be stripped from string expressions: doing so would
    # silently rewrite string literals, e.g. 'a,b'.count(',') -> 'ab'.count('').
    math_expr = expr.replace(",", "")

    # Check if it's a pure math expression (old behavior)
    if all(ch in "0123456789*+-/.() " for ch in math_expr):
        if "**" in math_expr:  # disallow power operator
            return None
        return eval_with_timeout(math_expr)

    # Check if it's a string operation we support
    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
    # (a comma is not in this set, so expressions containing one are rejected)
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all(ch in allowed_chars for ch in expr):
        return None

    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None

    # Only allow .count() method for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"), 'GSM8K': partial(GSM8K, subset="main", split="test"), + 'SpellingBee': partial(SpellingBee, size=256, split="test"), }[task_name] task_object = task_module() # Run the evaluation @@ -204,13 +206,14 @@ if __name__ == "__main__": engine = Engine(model, tokenizer) # Get the tasks to evaluate on - all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval'] + all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] baseline_accuracies = { 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25% 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25% 'MMLU': 0.25, # multiple choice 1 of 4 => 25% 'GSM8K': 0.0, # open-ended => 0% 'HumanEval': 0.0, # open-ended => 0% + 'SpellingBee': 0.0, # open-ended => 0% } task_names = all_tasks if args.task_name is None else args.task_name.split('|') diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index aeab77e..e6e4565 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -28,6 +28,7 @@ from tasks.arc import ARC from tasks.gsm8k import GSM8K from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON +from tasks.spellingbee import SimpleSpelling, SpellingBee # ----------------------------------------------------------------------------- # SFT Hyperparameters @@ -86,7 +87,9 @@ train_ds = TaskMixture([ GSM8K(subset="main", split="train"), # 8K rows SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations -]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows + SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple') + SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) 
+]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) # ----------------------------------------------------------------------------- diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 2835ebf..eedb262 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K from tasks.mmlu import MMLU from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON +from tasks.spellingbee import SimpleSpelling, SpellingBee # ----------------------------------------------------------------------------- run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) @@ -100,7 +101,9 @@ train_dataset = TaskMixture([ GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these -]) # total: 460K + 100K + 8K = 568K rows + SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') + SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) +]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios diff --git a/tasks/spellingbee.py b/tasks/spellingbee.py new file mode 100644 index 0000000..b394571 --- /dev/null +++ b/tasks/spellingbee.py @@ -0,0 +1,296 @@ +""" +Task intended to make nanochat better in spelling and counting, for example: + +"How many r are in strawberry?" 
# Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz"
# A list of 370K English words of large variety
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"

# Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")


def extract_answer(completion):
    """
    Return the number that follows the first '#### ' marker in `completion`.

    The result is the matched text with surrounding whitespace and all commas
    removed (it stays a string). None is returned when no marker is found.
    """
    found = ANSWER_RE.search(completion)
    if found is None:
        return None
    return found.group(1).strip().replace(",", "")


# User message templates for data augmentation
USER_MSG_TEMPLATES = [
    "How many {letter} are in the word {word}",
    "How many {letter} are in {word}",
    "Count the number of {letter} in {word}",
    "How many times does {letter} appear in {word}",
    "What's the count of {letter} in {word}",
    "In the word {word}, how many {letter} are there",
    "How many letter {letter} are in the word {word}",
    "Count how many {letter} appear in {word}",
    "Tell me the number of {letter} in {word}",
    "How many occurrences of {letter} are in {word}",
    "Find the count of {letter} in {word}",
    "Can you count the {letter} letters in {word}",
    "What is the frequency of {letter} in {word}",
    "How many {letter}s are in {word}",
    "How many {letter}'s are in {word}",
    "Count all the {letter} in {word}",
    "How many times is {letter} in {word}",
    "Number of {letter} in {word}",
    "Total count of {letter} in {word}",
    "How many {letter} does {word} have",
    "How many {letter} does {word} contain",
    "What's the number of {letter} in {word}",
    "{word} has how many {letter}",
    "In {word}, count the {letter}",
    "How many {letter} appear in {word}",
    "Count the {letter} in {word}",
    "Give me the count of {letter} in {word}",
    "How many instances of {letter} in {word}",
    "Show me how many {letter} are in {word}",
    "Calculate the number of {letter} in {word}",
    # Spanish
    "¿Cuántas {letter} hay en {word}?",
    "¿Cuántas veces aparece {letter} en {word}?",
    "Cuenta las {letter} en {word}",
    "¿Cuántas letras {letter} tiene {word}?",
    # Chinese (Simplified)
    "{word}中有多少个{letter}",
    "{word}里有几个{letter}",
    "数一下{word}中的{letter}",
    "{word}这个词里有多少{letter}",
    # Korean
    "{word}에 {letter}가 몇 개 있나요",
    "{word}에서 {letter}의 개수는",
    "{word}에 {letter}가 몇 번 나오나요",
    "{word}라는 단어에 {letter}가 몇 개",
    # French
    "Combien de {letter} dans {word}",
    "Combien de fois {letter} apparaît dans {word}",
    "Compte les {letter} dans {word}",
    # German
    "Wie viele {letter} sind in {word}",
    "Wie oft kommt {letter} in {word} vor",
    "Zähle die {letter} in {word}",
    # Japanese
    "{word}に{letter}は何個ありますか",
    "{word}の中に{letter}がいくつ",
    "{word}に{letter}が何回出てくる",
]
class SpellingBee(Task):
    """
    Synthetic task: count how many times a letter occurs in a word.

    Each example is a two-message conversation. The user asks the question
    using a randomly chosen template; the ideal assistant answer spells the
    word, counts manually, double checks with a Python `.count()` tool call,
    and ends with a `#### <count>` line.

    Examples are generated on the fly and are a deterministic function of
    (split, index).
    """

    # Seeds for the test split are `_TEST_SEED_OFFSET + index`. Negative seeds
    # must NOT be used to separate the splits: CPython reduces an integer seed
    # to its absolute value, so Random(-(i + 1)) is the same stream as
    # Random(i + 1) and test example i would duplicate train example i + 1.
    # The offset keeps the splits disjoint as long as the train split has
    # fewer than this many examples.
    _TEST_SEED_OFFSET = 10_000_000

    def __init__(self, size=1000, split="train", **kwargs):
        """
        Args:
            size: number of examples reported by `num_examples`.
            split: "train" or "test"; selects the seed range used per example.
            **kwargs: forwarded to the `Task` base class.
        """
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SpellingBee split must be train|test"
        self.size = size
        self.split = split
        filename = WORD_LIST_URL.split("/")[-1]
        word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
        with open(word_list_path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            # drop blank lines: an empty "word" would crash rng.choice(word) below
            words = [w for w in stripped if w]
        self.words = words

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return self.size

    def get_example(self, index):
        """Build the conversation for `index`; deterministic per (split, index)."""
        # train seeds are the raw index; test seeds live in a disjoint positive range
        seed = index if self.split == "train" else self._TEST_SEED_OFFSET + index
        rng = random.Random(seed)

        # pick a random word
        word = rng.choice(self.words)
        # pick a letter from it (90%) or a random letter (10%)
        letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)

        # get the correct answer by simply counting
        count = word.count(letter)

        # create a user message, with a bunch of variations as data augmentation
        template = rng.choice(USER_MSG_TEMPLATES)
        # 30% chance to lowercase the template (lazy people don't use shift)
        if rng.random() < 0.3:
            template = template.lower()
        quote_options = ['', "'", '"']
        letter_quote = rng.choice(quote_options)  # is the letter quoted?
        word_quote = rng.choice(quote_options)  # is the word quoted?
        letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
        word_wrapped = f"{word_quote}{word}{word_quote}"
        user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
        if rng.random() < 0.5:  # 50% of people don't even use question marks
            user_msg += "?"

        # Now create the ideal assistant response - build as parts (text + tool calls)
        assistant_parts = []
        word_letters = ",".join(word)
        manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first.

First spell the word out:
{word}:{word_letters}

Then count the occurrences of '{letter}':
"""
        # Little simulated loop of the solution process
        # TODO: This is where the fun starts, we could simulate cute little mistakes
        # and get the model to review its work and recover from them.
        # You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
        running_count = 0
        for i, char in enumerate(word, 1):
            if char == letter:
                running_count += 1
                # note: there deliberately cannot be a space here between i and char
                # because this would create a different token! (e.g. " a" and "a" are different tokens)
                manual_text += f"{i}:{char} hit! count={running_count}\n"
            else:
                manual_text += f"{i}:{char}\n"

        manual_text += f"\nThis gives us {running_count}."
        assistant_parts.append({"type": "text", "text": manual_text})
        # Part 2: Python verification
        assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
        # Part 3: Python tool call
        python_expr = f"'{word}'.count('{letter}')"
        assistant_parts.append({"type": "python", "text": python_expr})
        # Part 4: Python output
        assistant_parts.append({"type": "python_output", "text": str(count)})
        # Part 5: Final answer
        assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})

        # return the full conversation
        messages = [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_parts}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation

    def evaluate(self, conversation, assistant_response):
        """
        Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
        Identical to gsm8k's evaluation.

        The reference answer is read from the last part of the conversation's
        final (assistant) message; both answers are compared as the strings
        produced by `extract_answer`.
        """
        assert isinstance(assistant_response, str), "Assuming simple string response for now"
        # First extract the ground truth answer from the conversation
        assistant_message = conversation['messages'][-1]
        assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
        assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
        # The last text part contains the final answer with ####
        last_text_part = assistant_message['content'][-1]['text']
        # Extract both the ground truth answer and the predicted answer
        ref_num = extract_answer(last_text_part)
        pred_num = extract_answer(assistant_response)
        # Compare and return the success as int
        is_correct = int(pred_num == ref_num)
        return is_correct

    def reward(self, conversation, assistant_response):
        """ Use simple 0-1 reward just like gsm8k."""
        is_correct = self.evaluate(conversation, assistant_response)
        is_correct_float = float(is_correct)
        return is_correct_float
class SimpleSpelling(Task):
    """Much simpler task designed to get the model to just practice spelling words."""

    # Seeds for the test split are `_TEST_SEED_OFFSET + index`. Negative seeds
    # cannot be used to separate the splits because CPython reduces an integer
    # seed to its absolute value (Random(-k) and Random(k) are the same stream).
    # The offset keeps the splits disjoint as long as the train split has
    # fewer than this many examples.
    _TEST_SEED_OFFSET = 10_000_000

    def __init__(self, size=1000, split="train", **kwargs):
        """
        Args:
            size: number of examples reported by `num_examples`.
            split: "train" or "test"; selects the seed range used per example.
            **kwargs: forwarded to the `Task` base class.
        """
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SimpleSpelling split must be train|test"
        self.size = size
        self.split = split
        filename = WORD_LIST_URL.split("/")[-1]
        word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
        with open(word_list_path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            # drop blank lines so that an empty string is never picked as a word
            words = [w for w in stripped if w]
        rng = random.Random(42)
        rng.shuffle(words)  # use a different word order than the SpellingBee task
        self.words = words

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return self.size

    def get_example(self, index):
        """Build the conversation for `index`; deterministic per (split, index)."""
        # train seeds are the raw index; test seeds live in a disjoint positive range
        seed = index if self.split == "train" else self._TEST_SEED_OFFSET + index
        rng = random.Random(seed)
        # pick a random word
        word = rng.choice(self.words)
        word_letters = ",".join(word)
        # return the full conversation
        messages = [
            {"role": "user", "content": f"Spell the word: {word}"},
            {"role": "assistant", "content": f"{word}: {word_letters}"}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation


if __name__ == "__main__":

    # preview the SpellingBee task, first 10 examples
    task = SpellingBee()
    for i in range(10):
        ex = task.get_example(i)
        print("=" * 100)
        print(ex['messages'][0]['content'])
        print("-" * 100)
        # Assistant content is now a list of parts
        assistant_parts = ex['messages'][1]['content']
        for part in assistant_parts:
            if part['type'] == 'text':
                print(part['text'], end='')
            elif part['type'] == 'python':
                print(f"<<{part['text']}=", end='')
            elif part['type'] == 'python_output':
                print(f"{part['text']}>>", end='')
        print()
        print("-" * 100)

    # also scrutinize the tokenization (last example only)
    # from nanochat.tokenizer import get_tokenizer
    # tokenizer = get_tokenizer()
    # ids, mask = tokenizer.render_conversation(ex)
    # print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))