add the SpellingBee task so that nanochat can count r in strawberry etc. along the way we had to add a bunch of new functionality, e.g. extend the calculator to support the count function of python. possibly the current TaskMixture uses way too many synthetic examples of SpellingBee because the eval gives us exactly 100% performance on spelling. We can tune this later to reclaim some wall clock time here I think
This commit is contained in:
@@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.spellingbee import SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
|
||||
@@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
|
||||
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
|
||||
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
'GSM8K': partial(GSM8K, subset="main", split="test"),
|
||||
'SpellingBee': partial(SpellingBee, size=256, split="test"),
|
||||
}[task_name]
|
||||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
@@ -204,13 +206,14 @@ if __name__ == "__main__":
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Get the tasks to evaluate on
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||
baseline_accuracies = {
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
@@ -86,7 +87,9 @@ train_ds = TaskMixture([
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
@@ -100,7 +101,9 @@ train_dataset = TaskMixture([
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
]) # total: 460K + 100K + 8K = 568K rows
|
||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
|
||||
Reference in New Issue
Block a user