Don't evaluate the sampling evals during SFT, they are too slow. Keep the multiple choice evals. Delete unused imports.

Andrej Karpathy
2025-10-15 16:42:23 +00:00
parent b8076dd367
commit 190d9515d0


@@ -11,7 +11,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
 import os
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-import copy
 import wandb
 import torch
@@ -23,11 +22,9 @@ from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.engine import Engine
 from scripts.chat_eval import run_chat_eval
-from tasks.common import TaskMixture, TaskSequence
-from tasks.mmlu import MMLU
+from tasks.common import TaskMixture
 from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
-from tasks.humaneval import HumanEval
 from tasks.smoltalk import SmolTalk
 # -----------------------------------------------------------------------------
@@ -186,7 +183,7 @@ for step in range(num_iterations):
 })
 model.train()
-# evlauate MMLU accuracy
+# evlauate accuracy of the multiple choice tasks (which are quick to run)
 if last_step or (step > 0 and step % eval_metrics_every == 0):
 model.eval()
 metrics = {}
@@ -194,8 +191,6 @@ for step in range(num_iterations):
 # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
 metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
 metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
-metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
-metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
 metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
 print0(f"Step {step:05d} | {metrics_str}")
 wandb_run.log({
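For reference, a minimal sketch of what the in-training metrics block of scripts/chat_sft.py looks like after this commit, reassembled from the unchanged context lines above. Only the lines shown in the diff are taken from the source; the indentation, the enclosing training loop, and the wandb logging payload (elided here) are assumptions, not part of this diff.

# evaluate accuracy of the multiple choice tasks (which are quick to run);
# the slower sampling-based GSM8K / HumanEval evals are no longer run during SFT
if last_step or (step > 0 and step % eval_metrics_every == 0):
    model.eval()
    metrics = {}
    # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
    metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
    metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
    metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
    print0(f"Step {step:05d} | {metrics_str}")
    wandb_run.log({})  # actual logging payload continues in the file, elided here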