add personality to nanochat. breaks previous code on git pull and requires download of a new file from s3, but there is a helpful error message so hopefully its ok
This commit is contained in:
@@ -26,6 +26,7 @@ from tasks.common import TaskMixture
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
@@ -74,13 +75,14 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation onl
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
|
||||
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
|
||||
train_ds = TaskMixture([
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -27,6 +27,7 @@ from tasks.common import TaskMixture
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
@@ -88,10 +89,13 @@ for opt in optimizers:
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||
train_dataset = TaskMixture([
|
||||
SmolTalk(split="train"), # 460K rows of general conversations
|
||||
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
]) # total: 460K + 100K + 8K = 568K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
|
||||
Reference in New Issue
Block a user