don't evaluate the sampling evals during SFT, they are too slow. keep the multiple choice evals. delete unused imports
scripts/chat_sft.py
@@ -11,7 +11,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
 import os
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-import copy
 
 import wandb
 import torch
@@ -23,11 +22,9 @@ from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.engine import Engine
 from scripts.chat_eval import run_chat_eval
 
-from tasks.common import TaskMixture, TaskSequence
-from tasks.mmlu import MMLU
+from tasks.common import TaskMixture
 from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
-from tasks.humaneval import HumanEval
 from tasks.smoltalk import SmolTalk
 
 # -----------------------------------------------------------------------------
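Aside: for reference, the import block as it reads after this change, reconstructed directly from the hunk above. The in-loop evals refer to their tasks by name as strings through run_chat_eval, so the TaskSequence, MMLU and HumanEval symbols were never referenced in this file; the classes that stay are presumably the ones feeding the SFT TaskMixture.

# imports after the cleanup (reconstructed from the hunk above);
# the remaining task classes are presumably what the SFT TaskMixture is built from
from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk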
@@ -186,7 +183,7 @@ for step in range(num_iterations):
         })
         model.train()
 
-    # evlauate MMLU accuracy
+    # evlauate accuracy of the multiple choice tasks (which are quick to run)
     if last_step or (step > 0 and step % eval_metrics_every == 0):
         model.eval()
         metrics = {}
@@ -194,8 +191,6 @@ for step in range(num_iterations):
         # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
         metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
         metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
-        metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
-        metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
         metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
         print0(f"Step {step:05d} | {metrics_str}")
         wandb_run.log({
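Putting the last two hunks together, the periodic in-loop evaluation now runs only the cheap multiple-choice checks. A sketch of the resulting block, reconstructed from the context lines above; the indentation, the exact wandb payload, and the trailing model.train() are assumptions not fully shown in the diff:

    # evaluate accuracy of the multiple choice tasks (which are quick to run)
    if last_step or (step > 0 and step % eval_metrics_every == 0):
        model.eval()
        metrics = {}
        # note: these run inside no_grad, so we can usually afford ~2X the training batch size
        metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
        metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
        print0(f"Step {step:05d} | {metrics_str}")
        wandb_run.log({"step": step, **metrics})  # assumed payload; the diff only shows the opening brace
        model.train()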
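The removed GSM8K and HumanEval entries are the sampling evals named in the commit message: they require the model to generate full solutions, which is what makes them slow relative to the multiple-choice tasks, and that cost is what this commit takes out of the training loop. A hypothetical sketch of running them once after SFT instead, reusing run_chat_eval with the same arguments that were deleted above; the placement after training and the final_metrics/wandb logging are assumptions, and nanochat's standalone scripts.chat_eval may already serve this purpose:

# hypothetical post-training pass: run the slow, sampling-based evals once
# after the SFT loop finishes, rather than every eval_metrics_every steps
model.eval()
final_metrics = {}
final_metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
final_metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
print0("Final | " + ', '.join(f'{k}: {v:.6f}' for k, v in final_metrics.items()))
wandb_run.log(final_metrics)  # assumes the wandb run is still open at this point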