initial commit
This commit is contained in:
180
scripts/base_eval.py
Normal file
180
scripts/base_eval.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Evlauate the CORE metric for a given model.
|
||||
|
||||
Run on a single GPU:
|
||||
python base_eval.py
|
||||
|
||||
Run with torchrun on e.g. 8 GPUs:
|
||||
torchrun --nproc_per_node=8 base_eval.py
|
||||
|
||||
The script will print the CORE metric to the console.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import yaml
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# nanoChat specific function dealing with I/O etc.
|
||||
|
||||
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
"""
|
||||
Evaluate a base model on the CORE benchmark.
|
||||
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
|
||||
TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
|
||||
"""
|
||||
# Load config and task metadata
|
||||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
config_path = os.path.join(eval_bundle_dir, "core.yaml")
|
||||
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
|
||||
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
tasks = config['icl_tasks']
|
||||
eval_metadata = pd.read_csv(eval_meta_data)
|
||||
|
||||
# Evaluate each task
|
||||
results = {}
|
||||
centered_results = {}
|
||||
for task in tasks:
|
||||
start_time = time.time()
|
||||
label = task['label']
|
||||
task_meta = {
|
||||
'task_type': task['icl_task_type'],
|
||||
'dataset_uri': task['dataset_uri'],
|
||||
'num_fewshot': task['num_fewshot'][0],
|
||||
'continuation_delimiter': task.get('continuation_delimiter', ' ')
|
||||
}
|
||||
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
|
||||
|
||||
# Load data for this task
|
||||
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
|
||||
with open(data_path, 'r') as f:
|
||||
data = [json.loads(line.strip()) for line in f]
|
||||
|
||||
# shuffle the data because in many cases it appears ordered but we want
|
||||
# the abillity to only run a subset of the data for debugging purposes etc.
|
||||
shuffle_rng = random.Random(1337)
|
||||
shuffle_rng.shuffle(data)
|
||||
if max_per_task > 0:
|
||||
data = data[:max_per_task]
|
||||
|
||||
# run the evaluation for this task
|
||||
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
|
||||
|
||||
results[label] = accuracy
|
||||
row = eval_metadata[eval_metadata["Eval Task"] == label]
|
||||
random_baseline = row["Random baseline"].values[0]
|
||||
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
||||
centered_results[label] = centered_result
|
||||
end_time = time.time()
|
||||
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
|
||||
|
||||
core_metric = sum(centered_results.values()) / len(centered_results)
|
||||
out = {
|
||||
"results": results,
|
||||
"centered_results": centered_results,
|
||||
"core_metric": core_metric
|
||||
}
|
||||
return out
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace loading utilities and light wrappers for a model
|
||||
|
||||
class ModelWrapper:
|
||||
"""Lightweight wrapper for a HuggingFace model"""
|
||||
def __init__(self, model, max_seq_len=None):
|
||||
self.model = model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
def __call__(self, input_ids):
|
||||
outputs = self.model(input_ids)
|
||||
logits = outputs.logits
|
||||
return logits
|
||||
|
||||
def load_hf_model(hf_path: str, device):
|
||||
print0(f"Loading model from: {hf_path}")
|
||||
# Load the model
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
|
||||
model = ModelWrapper(model, max_seq_len=max_seq_len)
|
||||
# Load the tokenizer
|
||||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
return model, tokenizer
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
|
||||
|
||||
# distributed / precision setup
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if len(sys.argv) >= 2:
|
||||
# atm assume that if a path is given, it's a huggingface model path
|
||||
hf_path = sys.argv[1]
|
||||
print0(f"Loading huggingface model from: {hf_path}")
|
||||
model, tokenizer = load_hf_model(hf_path, device)
|
||||
model_name = hf_path # just for logging
|
||||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
else:
|
||||
# load a local model from the file system
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
model_name = f"base_model (step {meta['step']})" # just for logging
|
||||
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
|
||||
|
||||
# Evaluate the model
|
||||
with autocast_ctx:
|
||||
out = evaluate_model(model, tokenizer, device)
|
||||
|
||||
# Write out the results to a csv file
|
||||
core_metric = None
|
||||
centered_results = {}
|
||||
if ddp_rank == 0:
|
||||
base_dir = get_base_dir()
|
||||
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
|
||||
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
|
||||
results = out["results"]
|
||||
centered_results = out["centered_results"]
|
||||
core_metric = out["core_metric"]
|
||||
with open(output_csv_path, 'w') as f:
|
||||
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
|
||||
for label in results:
|
||||
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
|
||||
f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
|
||||
# Print the content of the csv file to console too
|
||||
print0("="*80)
|
||||
print0(f"Model: {model_name}")
|
||||
print0("="*80)
|
||||
with open(output_csv_path, 'r') as f:
|
||||
print0(f.read())
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model evaluation", data=[
|
||||
{
|
||||
"Model": model_name,
|
||||
"CORE metric": core_metric,
|
||||
},
|
||||
centered_results, # the full table
|
||||
])
|
||||
|
||||
compute_cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
78
scripts/base_loss.py
Normal file
78
scripts/base_loss.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Loads a checkpoint, and:
|
||||
- Evaluates the loss on a larger chunk of train/val splits
|
||||
- Samples from the model
|
||||
|
||||
Example run as:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# Configuration
|
||||
device_batch_size = 32
|
||||
split_tokens = 20*524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
|
||||
# Set up the precision we'll run with
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = split_tokens // tokens_per_step
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
bpb_results[split_name] = bpb
|
||||
|
||||
# Master process also samples from the model
|
||||
samples = []
|
||||
if ddp_rank == 0:
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The chemical symbol of gold is",
|
||||
"If yesterday was Friday, then tomorrow will be",
|
||||
"The opposite of hot is",
|
||||
"The planets of the solar system are:",
|
||||
"My favorite color is",
|
||||
"If 5*x + 3 = 13, then x is",
|
||||
]
|
||||
engine = Engine(model, tokenizer)
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample_str = tokenizer.decode(sample[0])
|
||||
print0(sample_str)
|
||||
samples.append(sample_str)
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model loss", data=[
|
||||
{
|
||||
"train bpb": bpb_results["train"],
|
||||
"val bpb": bpb_results["val"],
|
||||
},
|
||||
{f"sample {i}": sample for i, sample in enumerate(samples)},
|
||||
])
|
||||
|
||||
# Cleanup
|
||||
compute_cleanup()
|
||||
339
scripts/base_train.py
Normal file
339
scripts/base_train.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Train model. Run as:
|
||||
|
||||
python base_train.py
|
||||
|
||||
or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 base_train.py
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from scripts.base_eval import evaluate_model
|
||||
print_banner()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
# Training horizon. Only one of these 3 will be used, in this order of precedence.
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
|
||||
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
num_layers = depth
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = num_heads # 1:1 MQA ratio
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device="cuda")
|
||||
model.init_weights()
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
|
||||
if num_iterations > 0:
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif target_flops > 0:
|
||||
# calculate the number of iterations from the target flops
|
||||
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio
|
||||
target_tokens = target_param_data_ratio * num_params
|
||||
num_iterations = target_tokens // total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
# Initialize the DataLoaders for train/val
|
||||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
||||
# Learning rate scheduler
|
||||
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
|
||||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(warmdown_ratio * num_iterations)
|
||||
if it < warmup_iters:
|
||||
return (it + 1) / warmup_iters
|
||||
elif it <= num_iterations - warmdown_iters:
|
||||
return 1.0
|
||||
else:
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 300, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
# note that we run +1 steps only so that we can eval and save at the end
|
||||
for step in range(num_iterations + 1):
|
||||
last_step = step == num_iterations
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or step % eval_every == 0:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
})
|
||||
model.train()
|
||||
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if last_step or (step > 0 and step % core_metric_every == 0):
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"core_metric": results["core_metric"],
|
||||
"centered_results": results["centered_results"],
|
||||
})
|
||||
model.train()
|
||||
|
||||
# once in a while: sample from the model (only on master process)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
|
||||
model.eval()
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The chemical symbol of gold is",
|
||||
"If yesterday was Friday, then tomorrow will be",
|
||||
"The opposite of hot is",
|
||||
"The planets of the solar system are:",
|
||||
"My favorite color is",
|
||||
"If 5*x + 3 = 13, then x is",
|
||||
]
|
||||
engine = Engine(model, tokenizer)
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
)
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# gradient clipping (TODO possibly expertiment with)
|
||||
if grad_clip > 0.0:
|
||||
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 100 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model training", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of parameters": num_params,
|
||||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": warmup_ratio,
|
||||
"warmdown_ratio": warmdown_ratio,
|
||||
"final_lr_frac": final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results["core_metric"],
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
|
||||
}
|
||||
])
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
|
||||
99
scripts/chat_cli.py
Normal file
99
scripts/chat_cli.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
New and upgraded chat mode because a lot of the code has changed since the last one.
|
||||
|
||||
Intended to be run single GPU only atm:
|
||||
python -m scripts.chat_cli -i mid
|
||||
"""
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
parser = argparse.ArgumentParser(description='Chat with the model')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init the model and tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
|
||||
|
||||
# Create Engine for efficient generation
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
print("\nNanoChat Interactive Mode")
|
||||
print("-" * 50)
|
||||
print("Type 'quit' or 'exit' to end the conversation")
|
||||
print("Type 'clear' to start a new conversation")
|
||||
print("-" * 50)
|
||||
|
||||
conversation_tokens = [bos]
|
||||
|
||||
while True:
|
||||
|
||||
if args.prompt:
|
||||
# Get the prompt from the launch command
|
||||
user_input = args.prompt
|
||||
else:
|
||||
# Get the prompt interactively from the console
|
||||
try:
|
||||
user_input = input("\nUser: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
# Handle special commands
|
||||
if user_input.lower() in ['quit', 'exit']:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if user_input.lower() == 'clear':
|
||||
conversation_tokens = [bos]
|
||||
print("Conversation cleared.")
|
||||
continue
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Add User message to the conversation
|
||||
conversation_tokens.append(user_start)
|
||||
conversation_tokens.extend(tokenizer.encode(user_input))
|
||||
conversation_tokens.append(user_end)
|
||||
|
||||
# Kick off the assistant
|
||||
conversation_tokens.append(assistant_start)
|
||||
generate_kwargs = {
|
||||
"num_samples": 1,
|
||||
"max_tokens": 256,
|
||||
"temperature": args.temperature,
|
||||
"top_k": args.top_k,
|
||||
}
|
||||
response_tokens = []
|
||||
print("\nAssistant: ", end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||
response_tokens.append(token)
|
||||
token_text = tokenizer.decode([token])
|
||||
print(token_text, end="", flush=True)
|
||||
print()
|
||||
# we have to ensure that the assistant end token is the last token
|
||||
# so even if generation ends due to max tokens, we have to append it to the end
|
||||
if response_tokens[-1] != assistant_end:
|
||||
response_tokens.append(assistant_end)
|
||||
conversation_tokens.extend(response_tokens)
|
||||
|
||||
# In the prompt mode, we only want a single response and exit
|
||||
if args.prompt:
|
||||
break
|
||||
251
scripts/chat_eval.py
Normal file
251
scripts/chat_eval.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Evaluate the Chat model.
|
||||
All the generic code lives here, and all the evlauation-specific
|
||||
code lives in nanochat directory and is imported from here.
|
||||
|
||||
Example runs:
|
||||
python -m scripts.chat_eval -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
from tasks.humaneval import HumanEval
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
|
||||
|
||||
def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
device = model.get_device()
|
||||
|
||||
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
|
||||
|
||||
# Run the evaluation
|
||||
num_passed, total = 0, 0
|
||||
for i in range(ddp_rank, num_problems, ddp_world_size):
|
||||
conversation = task_object[i]
|
||||
|
||||
# Tokenize the prompt
|
||||
encoded_prompt = tokenizer.render_for_completion(conversation)
|
||||
# Get the completions
|
||||
results, _ = engine.generate_batch(
|
||||
encoded_prompt,
|
||||
num_samples=num_samples,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
)
|
||||
# Decode the completions as text
|
||||
prefix_length = len(encoded_prompt)
|
||||
completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
|
||||
# Evaluate success criteria
|
||||
outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
|
||||
passed = any(outcomes)
|
||||
|
||||
# Keep stats
|
||||
total += 1
|
||||
num_passed += int(passed)
|
||||
|
||||
# Logging (overwrite the same line in the console)
|
||||
print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)
|
||||
|
||||
# Finish the in-place progress line with a newline before final summary
|
||||
print()
|
||||
|
||||
# Aggregate results across all ranks
|
||||
if ddp:
|
||||
num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
|
||||
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
|
||||
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
|
||||
num_passed = num_passed_tensor.item()
|
||||
total = total_tensor.item()
|
||||
|
||||
print0("=" * 50)
|
||||
print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")
|
||||
|
||||
# Return the accuracy
|
||||
return num_passed/total
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Categorical evaluation loop
|
||||
# A lot easier because we don't have to sample. Therefore, we can actually go
|
||||
# batches at a time and just check the logits for correct answer choices.
|
||||
|
||||
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
device = model.get_device()
|
||||
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
|
||||
|
||||
# We'll process batches of independent problems at a time because there is no sampling needed
|
||||
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
|
||||
ceil_div = lambda x, y: -(-x // y)
|
||||
num_batches = ceil_div(num_problems, batch_size)
|
||||
|
||||
# Run the evaluation
|
||||
letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
|
||||
num_passed, total = 0, 0
|
||||
for i in range(ddp_rank, num_batches, ddp_world_size):
|
||||
i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
|
||||
|
||||
# Prepare the batch of problems. They might all be of different length, so we pad/collate them.
|
||||
conversations = [task_object[ii] for ii in range(i0, i1)]
|
||||
prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
|
||||
max_length = max(len(ids) for ids in prompt_ids)
|
||||
answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer)
|
||||
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
|
||||
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
|
||||
|
||||
# Get the logits for the whole batch of conversations in parallel (efficiency win here)
|
||||
with torch.no_grad():
|
||||
logits = model(prompt_ids) # (B, T, V)
|
||||
|
||||
# Focus on the available answer on just the letters corresponding to choices
|
||||
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the avilable letters
|
||||
# The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
|
||||
# letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
|
||||
for idx, conversation in enumerate(conversations):
|
||||
# get the token ids of all the available letters of this problem
|
||||
letters = conversation['letters']
|
||||
letter_ids = []
|
||||
for letter in letters:
|
||||
if not letter in letter_to_id_cache:
|
||||
encoded_letter = tokenizer.encode(letter)
|
||||
assert len(encoded_letter) == 1, "Each letter must be a single token"
|
||||
letter_to_id_cache[letter] = encoded_letter[0]
|
||||
letter_ids.append(letter_to_id_cache[letter])
|
||||
# focus logits just down to the answer position and the available letters of the answer
|
||||
answer_pos = answer_time_positions[idx]
|
||||
focus_logits = logits[idx, answer_pos, letter_ids]
|
||||
# get the argmax letter (the predicted answer)
|
||||
argmax_letter_id = focus_logits.argmax(dim=-1).item()
|
||||
predicted_letter = letters[argmax_letter_id]
|
||||
# evaluate the outcome
|
||||
outcome = task_object.evaluate(conversation, predicted_letter)
|
||||
num_passed += int(outcome)
|
||||
total += 1
|
||||
|
||||
# Aggregate results across all ranks
|
||||
if ddp:
|
||||
num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
|
||||
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
|
||||
dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
|
||||
num_passed = num_passed_tensor.item()
|
||||
total = total_tensor.item()
|
||||
|
||||
average = num_passed/total
|
||||
print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
|
||||
return average
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
def run_chat_eval(task_name, model, tokenizer, engine,
|
||||
batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
|
||||
max_problems=None):
|
||||
# Create the evaluation object
|
||||
task_module = {
|
||||
'HumanEval': HumanEval,
|
||||
'MMLU': partial(MMLU, subset="all", split="test"),
|
||||
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
|
||||
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
'GSM8K': partial(GSM8K, subset="main", split="test"),
|
||||
}[task_name]
|
||||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
if task_object.eval_type == 'generative':
|
||||
acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
|
||||
elif task_object.eval_type == 'categorical':
|
||||
acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
|
||||
else:
|
||||
raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
|
||||
return acc
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
||||
parser.add_argument('-n', '--num-samples', type=int, default=1)
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50)
|
||||
parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Get the tasks to evaluate on
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
|
||||
baseline_accuracies = {
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
||||
# Run all the task evaluations sequentially
|
||||
results = {}
|
||||
for task_name in task_names:
|
||||
with autocast_ctx:
|
||||
acc = run_chat_eval(
|
||||
task_name,
|
||||
model, tokenizer, engine,
|
||||
batch_size=args.batch_size,
|
||||
num_samples=args.num_samples,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
max_problems=args.max_problems,
|
||||
)
|
||||
results[task_name] = acc
|
||||
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks)
|
||||
# calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy)
|
||||
# this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance)
|
||||
chatcore_metric_dict = {}
|
||||
if all_tasks_were_evaluated:
|
||||
centered_mean = 0
|
||||
for task_name, acc in results.items():
|
||||
baseline_acc = baseline_accuracies.get(task_name, 0.0)
|
||||
centered_acc = (acc - baseline_acc) / (1.0 - baseline_acc)
|
||||
centered_mean += centered_acc
|
||||
chatcore_metric = centered_mean / len(results)
|
||||
chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}
|
||||
get_report().log(section="Chat evaluation " + args.source, data=[
|
||||
vars(args), # CLI args
|
||||
results,
|
||||
chatcore_metric_dict,
|
||||
])
|
||||
|
||||
compute_cleanup()
|
||||
331
scripts/chat_rl.py
Normal file
331
scripts/chat_rl.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Reinforcement learning on GSM8K via "GRPO".
|
||||
|
||||
I put GRPO in quotes because we actually end up with something a lot
|
||||
simpler and more similar to just REINFORCE:
|
||||
|
||||
1) Delete trust region, so there is no KL regularization to a reference model
|
||||
2) We are on policy, so there's no need for PPO ratio+clip.
|
||||
3) We use GAPO style normalization that is token-level, not sequence-level.
|
||||
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
|
||||
|
||||
1 GPU:
|
||||
python -m scripts.chat_rl
|
||||
|
||||
8 GPUs:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
|
||||
"""
|
||||
|
||||
import os
|
||||
import itertools
|
||||
import re
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
||||
# RL hyperparameters
|
||||
run = "dummy" # wandb run name
|
||||
source = "sft" # mid|sft
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
num_samples = 16 # number of samples per example (/question)
|
||||
max_new_tokens = 256
|
||||
temperature = 1.0
|
||||
top_k = 50 # TODO: try None?
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.05
|
||||
num_epochs = 1 # how many epochs of gsm8k to train on
|
||||
save_every = 60 # every how many steps to save the model
|
||||
eval_every = 60 # every how many steps to evaluate the model for val pass@k
|
||||
eval_examples = 400 # number of examples used for evaluating pass@k
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Init compute/precision
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval")
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Rollout / sampling generator loop that yields batches of examples for training
|
||||
|
||||
train_task = GSM8K(subset="main", split="train")
|
||||
val_task = GSM8K(subset="main", split="test")
|
||||
num_steps = (len(train_task) // examples_per_step) * num_epochs
|
||||
print0(f"Calculated number of steps: {num_steps}")
|
||||
|
||||
@torch.no_grad()
|
||||
def get_batch():
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
|
||||
rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
|
||||
for example_idx in itertools.cycle(rank_indices):
|
||||
|
||||
# First get the full conversation of both user and assistant messages
|
||||
conversation = train_task[example_idx]
|
||||
|
||||
# Tokenize the conversation, deleting the last Assistant message and priming the Assistant for a completion instead
|
||||
# (i.e. keep the <|assistant_start|>, but delete everything after it)
|
||||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
|
||||
# Generate num_samples samples using batched generation, use loop to avoid OOMs
|
||||
model.eval() # ensure the model is in eval mode
|
||||
generated_token_sequences = []
|
||||
masks = []
|
||||
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
with autocast_ctx:
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=device_batch_size,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
)
|
||||
generated_token_sequences.extend(generated_token_sequences_batch)
|
||||
masks.extend(masks_batch)
|
||||
|
||||
# Calculate the rewards for each sample
|
||||
rewards = []
|
||||
for sample_tokens in generated_token_sequences:
|
||||
# Get just the generated tokens (after the prompt)
|
||||
generated_tokens = sample_tokens[prefix_length:]
|
||||
# Decode the generated response
|
||||
generated_text = tokenizer.decode(generated_tokens)
|
||||
# Calculate the reward
|
||||
reward = train_task.reward(conversation, generated_text)
|
||||
rewards.append(reward)
|
||||
|
||||
# Pad the sequences so that their lengths (in time) match
|
||||
max_length = max(len(seq) for seq in generated_token_sequences)
|
||||
padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
|
||||
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
|
||||
# Stack up the sequences and masks into PyTorch tensors
|
||||
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
|
||||
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
|
||||
# Generate autoregressive inputs and targets to the Transformer
|
||||
inputs = ids[:, :-1]
|
||||
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
|
||||
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
|
||||
# NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
|
||||
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
|
||||
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
|
||||
# Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma)
|
||||
mu = rewards.mean()
|
||||
advantages = rewards - mu
|
||||
# yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
|
||||
yield generated_token_sequences, inputs, targets, rewards, advantages
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Simple evaluation loop for GSM8K pass@k
|
||||
def run_gsm8k_eval(task, tokenizer, engine,
|
||||
max_examples=None,
|
||||
num_samples=1,
|
||||
max_completion_tokens=256,
|
||||
temperature=0.0,
|
||||
top_k=50
|
||||
):
|
||||
"""
|
||||
Evaluates GSM8K task and returns a list of records of evaluation outcomes.
|
||||
In a distributed setting, all ranks cooperate but this function will NOT
|
||||
do the reduction across ranks. This is the responsibility of the caller.
|
||||
Because the evaluation can take a while, this function will yield records one by one.
|
||||
"""
|
||||
max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
|
||||
for idx in range(ddp_rank, max_examples, ddp_world_size):
|
||||
conversation = task[idx]
|
||||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
# Generate k samples using batched generation inside the Engine
|
||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
||||
generated_token_sequences, masks = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=num_samples,
|
||||
max_tokens=max_completion_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k
|
||||
)
|
||||
# Check each sample for correctness
|
||||
outcomes = []
|
||||
for sample_tokens in generated_token_sequences:
|
||||
generated_tokens = sample_tokens[prefix_length:]
|
||||
generated_text = tokenizer.decode(generated_tokens)
|
||||
is_correct = task.evaluate(conversation, generated_text)
|
||||
outcomes.append({
|
||||
"is_correct": is_correct
|
||||
})
|
||||
# A bit bloated because I wanted to do more complex logging at one point.
|
||||
record = {
|
||||
"idx": idx,
|
||||
"outcomes": outcomes,
|
||||
}
|
||||
yield record
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
||||
# Init the optimizer
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
def get_lr_multiplier(it):
|
||||
lrm = 1.0 - it / num_steps
|
||||
return lrm
|
||||
|
||||
# Calculate the number of examples each rank handles to achive the desired examples_per_step
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Calculated examples per rank: {examples_per_rank}")
|
||||
|
||||
# Kick off the training loop
|
||||
batch_iterator = get_batch()
|
||||
for step in range(num_steps):
|
||||
|
||||
# Evaluate the model once in a while and log to wandb
|
||||
if step % eval_every == 0:
|
||||
model.eval()
|
||||
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
with autocast_ctx:
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
|
||||
records = list(records_iter) # collect all records
|
||||
for k in range(1, device_batch_size + 1):
|
||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||
if ddp:
|
||||
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
|
||||
passk = passk / num_records.item() # normalize by the total number of records
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
|
||||
print0(f"Step {step} | {', '.join(print_passk)}")
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**log_passk,
|
||||
})
|
||||
|
||||
# Forward/Backward on rollouts over multiple examples in the dataset
|
||||
rewards_list = []
|
||||
sequence_lengths = []
|
||||
for example_step in range(examples_per_rank):
|
||||
# Get one batch corresponding to one example in the training dataset
|
||||
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
|
||||
# Evaluate the loss and gradients
|
||||
model.train() # ensure the model is in train mode
|
||||
# We need one more loop because we can never exceed the device_batch_size
|
||||
assert inputs_all.size(0) % device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // device_batch_size
|
||||
for pass_idx in range(num_passes):
|
||||
# Pluck out the batch for this pass
|
||||
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
|
||||
inputs = inputs_all[b0:b1]
|
||||
targets = targets_all[b0:b1]
|
||||
rewards = rewards_all[b0:b1]
|
||||
advantages = advantages_all[b0:b1]
|
||||
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
|
||||
with autocast_ctx:
|
||||
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
|
||||
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
|
||||
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
|
||||
# normalize by the number of valid tokens, number of passes, and examples_per_rank
|
||||
num_valid = (targets >= 0).sum().clamp(min=1)
|
||||
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
|
||||
# Note, there is no need to add PPO ratio+clip because we are on policy
|
||||
# Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
|
||||
loss = -pg_obj
|
||||
loss.backward()
|
||||
print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}")
|
||||
# For logging
|
||||
rewards_list.append(rewards_all.mean().item())
|
||||
sequence_lengths.extend(len(seq) for seq in sequences_all)
|
||||
|
||||
# A bunch of logging for how the rollouts went this step
|
||||
mean_reward = sum(rewards_list) / len(rewards_list)
|
||||
mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
|
||||
if ddp: # aggregate across ranks
|
||||
mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
|
||||
mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
|
||||
dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
|
||||
dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
|
||||
mean_reward = mean_reward_tensor.item()
|
||||
mean_sequence_length = mean_sequence_length_tensor.item()
|
||||
print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"reward": mean_reward,
|
||||
"sequence_length": mean_sequence_length,
|
||||
})
|
||||
|
||||
# Update the model parameters
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers: # first set the learning rate
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for opt in optimizers: # then step the optimizers
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
})
|
||||
|
||||
# Master process saves the model once in a while. Skip first step. Save last step.
|
||||
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
{
|
||||
"model_config": model_config_kwargs,
|
||||
}
|
||||
)
|
||||
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Chat RL", data=[
|
||||
user_config, # CLI args
|
||||
])
|
||||
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
|
||||
281
scripts/chat_sft.py
Normal file
281
scripts/chat_sft.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Finetune a base model to be a chat model.
|
||||
Run on one GPU e.g. for debugging:
|
||||
|
||||
python -m scripts.chat_sft
|
||||
|
||||
Or torchrun for training:
|
||||
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import copy
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
from tasks.common import TaskMixture, TaskSequence
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.humaneval import HumanEval
|
||||
from tasks.smoltalk import SmolTalk
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# input model options
|
||||
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.02
|
||||
# evaluation and logging there of
|
||||
eval_every = 100
|
||||
eval_steps = 100
|
||||
eval_metrics_every = 200
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
|
||||
train_ds = TaskMixture([
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DataLoader
|
||||
|
||||
def sft_data_generator(dataset, batch_size):
|
||||
pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
|
||||
# prepares a list of tokenized conversations into a batch and yields
|
||||
def collate_and_yield(batch):
|
||||
nrows = len(batch)
|
||||
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
|
||||
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
|
||||
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
|
||||
for i, (ids, mask) in enumerate(batch):
|
||||
n = len(ids)
|
||||
ids_tensor = torch.tensor(ids, dtype=torch.long)
|
||||
inputs[i, :n-1] = ids_tensor[:-1]
|
||||
# recall -1 is the ignore index, so mask out targets where mask is 0
|
||||
row_targets = ids_tensor[1:]
|
||||
# mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
|
||||
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
|
||||
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
|
||||
targets[i, :n-1] = row_targets
|
||||
inputs = inputs.to(device) # move to device
|
||||
targets = targets.to(device)
|
||||
return inputs, targets
|
||||
# iterates over the dataset in epochs, tokenizes
|
||||
batch = []
|
||||
while True:
|
||||
for i in range(ddp_rank, len(dataset), ddp_world_size):
|
||||
doc = dataset[i]
|
||||
ids, mask = tokenizer.render_conversation(doc)
|
||||
batch.append((ids, mask))
|
||||
if len(batch) == batch_size:
|
||||
yield collate_and_yield(batch)
|
||||
batch = []
|
||||
|
||||
examples_per_step = device_batch_size * ddp_world_size
|
||||
print0(f"Target examples per step: {target_examples_per_step}")
|
||||
print0(f"Device batch size: {device_batch_size}")
|
||||
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
|
||||
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
|
||||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
if max_iterations >= 0 and num_iterations > max_iterations:
|
||||
print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
|
||||
num_iterations = max_iterations
|
||||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(it):
|
||||
lrm = 1.0 - it / num_iterations
|
||||
return lrm
|
||||
|
||||
# Go!
|
||||
step = 0
|
||||
train_iter = iter(train_loader)
|
||||
for step in range(num_iterations):
|
||||
last_step = step == num_iterations - 1
|
||||
|
||||
# evaluate the validation loss
|
||||
if last_step or step % eval_every == 0:
|
||||
model.eval()
|
||||
val_iter = iter(build_val_loader())
|
||||
losses = []
|
||||
for _ in range(eval_steps):
|
||||
val_inputs, val_targets = next(val_iter)
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
loss = model(val_inputs, val_targets)
|
||||
losses.append(loss)
|
||||
val_loss = torch.stack(losses).mean() # average over eval_steps
|
||||
if ddp:
|
||||
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
|
||||
val_loss = val_loss.item()
|
||||
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
})
|
||||
model.train()
|
||||
|
||||
# evlauate MMLU accuracy
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
model.eval()
|
||||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
|
||||
metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**metrics,
|
||||
})
|
||||
model.train()
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
||||
# evaluate the gradient
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_iter)
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward() # accumulate the gradient
|
||||
num_tokens += (train_targets >= 0).sum()
|
||||
if ddp:
|
||||
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
|
||||
|
||||
# learning rate scheduler
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
|
||||
# step the optimizers
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# logging
|
||||
train_loss_item = train_loss.item()
|
||||
num_tokens_item = num_tokens.item()
|
||||
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
"train_loss": train_loss_item,
|
||||
"num_tokens": num_tokens_item,
|
||||
})
|
||||
step += 1
|
||||
|
||||
# Save the model at the end of the run
|
||||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
{
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
**metrics,
|
||||
"model_config": model_config_kwargs,
|
||||
}
|
||||
)
|
||||
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Chat SFT", data=[
|
||||
user_config, # CLI args
|
||||
{
|
||||
"Training rows": len(train_ds),
|
||||
"Number of iterations": num_iterations,
|
||||
"Training loss": train_loss_item,
|
||||
"Validation loss": val_loss,
|
||||
},
|
||||
])
|
||||
|
||||
# Cleanup
|
||||
wandb_run.finish()
|
||||
compute_cleanup()
|
||||
198
scripts/chat_web.py
Normal file
198
scripts/chat_web.py
Normal file
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified web chat server - serves both UI and API from a single FastAPI instance.
|
||||
Run with: python web_chat.py
|
||||
Then open http://localhost:8000 in your browser.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
||||
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
stream: Optional[bool] = True
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load model on startup."""
|
||||
print("Loading nanochat model...")
|
||||
app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
app.state.engine = Engine(app.state.model, app.state.tokenizer)
|
||||
print(f"Server ready at http://localhost:{args.port}")
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Serve the chat UI."""
|
||||
ui_html_path = os.path.join("nanochat", "ui.html")
|
||||
with open(ui_html_path, "r") as f:
|
||||
html_content = f.read()
|
||||
# Replace the API_URL to use the same origin
|
||||
html_content = html_content.replace(
|
||||
"const API_URL = `http://${window.location.hostname}:8000`;",
|
||||
"const API_URL = '';"
|
||||
)
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
|
||||
@app.get("/logo.svg")
|
||||
async def logo():
|
||||
"""Serve the NanoChat logo for favicon and header."""
|
||||
logo_path = os.path.join("nanochat", "logo.svg")
|
||||
return FileResponse(logo_path, media_type="image/svg+xml")
|
||||
|
||||
async def generate_stream(
|
||||
engine,
|
||||
tokenizer,
|
||||
tokens,
|
||||
temperature=None,
|
||||
max_new_tokens=None,
|
||||
top_k=None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate assistant response with streaming."""
|
||||
temperature = temperature if temperature is not None else args.temperature
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
|
||||
top_k = top_k if top_k is not None else args.top_k
|
||||
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in engine.generate(
|
||||
tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k
|
||||
):
|
||||
token = token_column[0]
|
||||
|
||||
if token == assistant_end or token == bos:
|
||||
break
|
||||
|
||||
token_text = tokenizer.decode([token])
|
||||
yield f"data: {json.dumps({'token': token_text})}\n\n"
|
||||
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
@app.post("/chat/completions")
|
||||
async def chat_completions(request: ChatRequest):
|
||||
"""Chat completion endpoint with streaming."""
|
||||
engine = app.state.engine
|
||||
tokenizer = app.state.tokenizer
|
||||
|
||||
# Build conversation tokens
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
user_start = tokenizer.encode_special("<|user_start|>")
|
||||
user_end = tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start = tokenizer.encode_special("<|assistant_start|>")
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||
|
||||
conversation_tokens = [bos]
|
||||
for message in request.messages:
|
||||
if message.role == "user":
|
||||
conversation_tokens.append(user_start)
|
||||
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||
conversation_tokens.append(user_end)
|
||||
elif message.role == "assistant":
|
||||
conversation_tokens.append(assistant_start)
|
||||
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||
conversation_tokens.append(assistant_end)
|
||||
|
||||
conversation_tokens.append(assistant_start)
|
||||
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
generate_stream(
|
||||
engine,
|
||||
tokenizer,
|
||||
conversation_tokens,
|
||||
temperature=request.temperature,
|
||||
max_new_tokens=request.max_tokens,
|
||||
top_k=request.top_k
|
||||
),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
# Non-streaming response
|
||||
temperature = request.temperature if request.temperature is not None else args.temperature
|
||||
max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
|
||||
top_k = request.top_k if request.top_k is not None else args.top_k
|
||||
|
||||
with autocast_ctx:
|
||||
result_tokens, masks = engine.generate_batch(
|
||||
conversation_tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k
|
||||
)[0]
|
||||
|
||||
response_tokens = result_tokens[len(conversation_tokens):]
|
||||
response_text = tokenizer.decode(response_tokens)
|
||||
return {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response_text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}]
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
return {
|
||||
"status": "ok",
|
||||
"ready": hasattr(app.state, 'model') and app.state.model is not None
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
print(f"Starting NanoChat Web Server")
|
||||
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
289
scripts/mid_train.py
Normal file
289
scripts/mid_train.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
Midtrain the model. Same as pretraining but simpler.
|
||||
Run as:
|
||||
|
||||
python -m scripts.mid_train
|
||||
|
||||
Or torchrun for training:
|
||||
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
import torch.distributed as dist
|
||||
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
eval_every = 150
|
||||
eval_tokens = 20*524288
|
||||
total_batch_size = 524288
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
train_dataset = TaskMixture([
|
||||
SmolTalk(split="train"), # 460K rows of general conversations
|
||||
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
]) # total: 460K + 100K + 8K = 568K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
|
||||
]) # total: 24K + 14K + 1.32K ~= 39K rows
|
||||
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
||||
# A big problem is that we don't know the final num_iterations in advance. So we create
|
||||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
def mid_data_generator(split):
|
||||
global last_step, approx_progress
|
||||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
conversation = dataset[cursor]
|
||||
ids, _ = tokenizer.render_conversation(conversation)
|
||||
token_buffer.extend(ids)
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
if split == "train":
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
build_val_loader = lambda: mid_data_generator("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(progress):
|
||||
return progress * 1.0 + (1 - progress) * final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 300, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
x, y = next(train_loader) # prefetch the very first batch of data
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
step = 0
|
||||
while True:
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
|
||||
if ddp:
|
||||
last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
|
||||
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
|
||||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or step % eval_every == 0:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
})
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
output_dirname = f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": {
|
||||
"sequence_len": max_seq_len,
|
||||
"vocab_size": tokenizer.get_vocab_size(),
|
||||
"n_layer": depth,
|
||||
"n_head": model.config.n_head,
|
||||
"n_kv_head": model.config.n_kv_head,
|
||||
"n_embd": model.config.n_embd,
|
||||
},
|
||||
"user_config": user_config, # inputs to the training script
|
||||
}
|
||||
)
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(progress)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# State
|
||||
step += 1
|
||||
|
||||
# logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
|
||||
265
scripts/tok_eval.py
Normal file
265
scripts/tok_eval.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Evaluate compression ratio of the tokenizer.
|
||||
"""
|
||||
|
||||
from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
|
||||
# Random text I got from a random website this morning
|
||||
news_text = r"""
|
||||
(Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025.
|
||||
|
||||
While USDA announced a risk-based phased port re-opening strategy for cattle, bison, and equine from Mexico beginning as early as July 7, 2025, this newly reported NWS case raises significant concern about the previously reported information shared by Mexican officials and severely compromises the outlined port reopening schedule of five ports from July 7-September 15. Therefore, in order to protect American livestock and our nation’s food supply, Secretary Rollins has ordered the closure of livestock trade through southern ports of entry effective immediately.
|
||||
|
||||
“The United States has promised to be vigilant — and after detecting this new NWS case, we are pausing the planned port reopening’s to further quarantine and target this deadly pest in Mexico. We must see additional progress combatting NWS in Veracruz and other nearby Mexican states in order to reopen livestock ports along the Southern border,” said U.S. Secretary of Agriculture Brooke L. Rollins. “Thanks to the aggressive monitoring by USDA staff in the U.S. and in Mexico, we have been able to take quick and decisive action to respond to the spread of this deadly pest.”
|
||||
""".strip()
|
||||
|
||||
# Random Korean text (to test non-English compression)
|
||||
korean_text = r"""
|
||||
정직한 사실 위에, 공정한 시선을 더하다
|
||||
Herald Korea Times
|
||||
|
||||
헤럴드코리아타임즈는 정치, 경제, 사회, 문화 등 한국 사회 전반의 주요 이슈를 심도 있게 다루는 종합 온라인 신문사입니다.
|
||||
|
||||
우리는 단순히 뉴스를 전달하는 것이 아니라, 사실(Fact)에 기반한 양측의 시각을 균형 있게 조명하며, 독자 여러분이 스스로 판단할 수 있는 ‘정보의 균형’을 제공합니다.
|
||||
|
||||
한국 언론의 오랜 문제로 지적되어 온 정치적 편향, 이념적 왜곡에서 벗어나
|
||||
오직 정직함과 공정함을 원칙으로 삼는 언론을 지향합니다.
|
||||
어느 한쪽의 주장만을 확대하거나 감추지 않고,
|
||||
**모든 쟁점에 대해 ‘무엇이 쟁점인지’, ‘누가 무엇을 주장하는지’, ‘사실은 무엇인지’**를 명확히 전달하는 데 집중합니다.
|
||||
""".strip()
|
||||
|
||||
# Random piece of code
|
||||
code_text = r"""
|
||||
class BasicTokenizer(Tokenizer):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def train(self, text, vocab_size, verbose=False):
|
||||
assert vocab_size >= 256
|
||||
num_merges = vocab_size - 256
|
||||
|
||||
# input text preprocessing
|
||||
text_bytes = text.encode("utf-8") # raw bytes
|
||||
ids = list(text_bytes) # list of integers in range 0..255
|
||||
|
||||
# iteratively merge the most common pairs to create new tokens
|
||||
merges = {} # (int, int) -> int
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
|
||||
for i in range(num_merges):
|
||||
# count up the number of times every consecutive pair appears
|
||||
stats = get_stats(ids)
|
||||
# find the pair with the highest count
|
||||
pair = max(stats, key=stats.get)
|
||||
# mint a new token: assign it the next available id
|
||||
idx = 256 + i
|
||||
# replace all occurrences of pair in ids with idx
|
||||
ids = merge(ids, pair, idx)
|
||||
# save the merge
|
||||
merges[pair] = idx
|
||||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
|
||||
# prints
|
||||
if verbose:
|
||||
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
|
||||
""".strip()
|
||||
|
||||
math_text = r"""
|
||||
\documentclass[12pt]{article}
|
||||
\usepackage{amsmath,amsthm,amssymb}
|
||||
\usepackage[margin=1in]{geometry}
|
||||
|
||||
\newtheorem{theorem}{Theorem}
|
||||
\newtheorem*{remark}{Remark}
|
||||
|
||||
\begin{document}
|
||||
|
||||
\begin{center}
|
||||
{\Large A Cute Identity: The Sum of Cubes is a Square}
|
||||
\end{center}
|
||||
|
||||
\begin{theorem}
|
||||
For every integer $n \ge 1$,
|
||||
\[
|
||||
\sum_{k=1}^{n} k^{3} \;=\; \left(\frac{n(n+1)}{2}\right)^{2}.
|
||||
\]
|
||||
\end{theorem}
|
||||
|
||||
\begin{proof}[Proof 1 (Induction)]
|
||||
Let $S(n) = \sum_{k=1}^{n} k^3$. For $n=1$, $S(1)=1=(1\cdot 2/2)^2$, so the base case holds.
|
||||
|
||||
Assume $S(n)=\big(\tfrac{n(n+1)}{2}\big)^2$ for some $n\ge 1$.
|
||||
Then
|
||||
\[
|
||||
S(n+1)
|
||||
= S(n) + (n+1)^3
|
||||
= \left(\frac{n(n+1)}{2}\right)^2 + (n+1)^3.
|
||||
\]
|
||||
Factor out $(n+1)^2$:
|
||||
\[
|
||||
S(n+1)
|
||||
= (n+1)^2\left( \frac{n^2}{4} + (n+1) \right)
|
||||
= (n+1)^2\left( \frac{n^2 + 4n + 4}{4} \right)
|
||||
= (n+1)^2\left( \frac{(n+2)^2}{4} \right).
|
||||
\]
|
||||
Thus
|
||||
\[
|
||||
S(n+1)=\left(\frac{(n+1)(n+2)}{2}\right)^2,
|
||||
\]
|
||||
which matches the claimed formula with $n$ replaced by $n+1$. By induction, the identity holds for all $n\ge 1$.
|
||||
\end{proof}
|
||||
|
||||
\begin{proof}[Proof 2 (Algebraic telescoping)]
|
||||
Recall the binomial identity
|
||||
\[
|
||||
(k+1)^4 - k^4 = 4k^3 + 6k^2 + 4k + 1.
|
||||
\]
|
||||
Summing both sides from $k=0$ to $n$ telescopes:
|
||||
\[
|
||||
(n+1)^4 - 0^4
|
||||
= \sum_{k=0}^{n}\big(4k^3 + 6k^2 + 4k + 1\big)
|
||||
= 4\sum_{k=1}^{n}k^3 + 6\sum_{k=1}^{n}k^2 + 4\sum_{k=1}^{n}k + (n+1).
|
||||
\]
|
||||
Using the standard sums
|
||||
\[
|
||||
\sum_{k=1}^{n}k = \frac{n(n+1)}{2}
|
||||
\quad\text{and}\quad
|
||||
\sum_{k=1}^{n}k^2 = \frac{n(n+1)(2n+1)}{6},
|
||||
\]
|
||||
solve for $\sum_{k=1}^{n}k^3$ to get
|
||||
\[
|
||||
\sum_{k=1}^{n}k^3 = \left(\frac{n(n+1)}{2}\right)^2.
|
||||
\]
|
||||
\end{proof}
|
||||
|
||||
\begin{remark}
|
||||
Geometrically, the identity says: ``adding up $1^3,2^3,\dots,n^3$ builds a perfect square’’—namely the square of the $n$th triangular number. This is why one sometimes calls it the \emph{sum-of-cubes is a square} phenomenon.
|
||||
\end{remark}
|
||||
|
||||
\end{document}
|
||||
""".strip()
|
||||
|
||||
science_text = r"""
|
||||
Photosynthesis is a photochemical energy transduction process in which light-harvesting pigment–protein complexes within the thylakoid membranes of oxygenic phototrophs absorb photons and initiate charge separation at the reaction center, driving the linear electron transport chain from water to NADP⁺ via photosystem II, the cytochrome b₆f complex, and photosystem I, concomitantly generating a trans-thylakoid proton motive force utilized by chloroplastic ATP synthase. The light-dependent reactions produce ATP and NADPH, which fuel the Calvin–Benson–Bassham cycle in the stroma, wherein ribulose-1,5-bisphosphate is carboxylated by ribulose-1,5-bisphosphate carboxylase/oxygenase (RuBisCO) to form 3-phosphoglycerate, subsequently reduced and regenerated through a series of enzymatic steps, enabling net assimilation of CO₂ into triose phosphates and ultimately carbohydrates. This process is tightly regulated by photoprotective mechanisms, redox feedback, and metabolite flux, representing a central biochemical pathway coupling solar energy capture to the biosphere’s primary productivity.
|
||||
""".strip()
|
||||
|
||||
# The tokenizer was trained on data from earlier shards, so it has seen this data
|
||||
train_docs = next(parquets_iter_batched(split="train"))
|
||||
train_text = "\n".join(train_docs)
|
||||
val_docs = next(parquets_iter_batched(split="val"))
|
||||
val_text = "\n".join(val_docs)
|
||||
|
||||
all_text = [
|
||||
("news", news_text),
|
||||
("korean", korean_text),
|
||||
("code", code_text),
|
||||
("math", math_text),
|
||||
("science", science_text),
|
||||
("fwe-train", train_text),
|
||||
]
|
||||
if val_text:
|
||||
all_text.append(("fwe-val", val_text))
|
||||
|
||||
# Try out current default compared to GPT-2 and GPT-4 tokenizers
|
||||
tokenizer_results = {}
|
||||
vocab_sizes = {}
|
||||
|
||||
for tokenizer_name in ["gpt2", "gpt4", "ours"]:
|
||||
|
||||
if tokenizer_name == "gpt2":
|
||||
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
|
||||
elif tokenizer_name == "gpt4":
|
||||
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
|
||||
else:
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size()
|
||||
tokenizer_results[tokenizer_name] = {}
|
||||
|
||||
for name, text in all_text:
|
||||
encoded = tokenizer.encode(text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert decoded == text
|
||||
|
||||
encoded_bytes = text.encode('utf-8')
|
||||
ratio = len(encoded_bytes) / len(encoded)
|
||||
tokenizer_results[tokenizer_name][name] = {
|
||||
'bytes': len(encoded_bytes),
|
||||
'tokens': len(encoded),
|
||||
'ratio': ratio
|
||||
}
|
||||
|
||||
# ANSI color codes
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
RESET = '\033[0m'
|
||||
|
||||
# Print vocab sizes
|
||||
print(f"\nVocab sizes:")
|
||||
print(f"GPT-2: {vocab_sizes['gpt2']}")
|
||||
print(f"GPT-4: {vocab_sizes['gpt4']}")
|
||||
print(f"Ours: {vocab_sizes['ours']}")
|
||||
|
||||
def print_comparison(baseline_name, baseline_results, ours_results, all_text):
|
||||
"""Print comparison table between baseline tokenizer and ours."""
|
||||
print(f"\nComparison with {baseline_name}:")
|
||||
print("=" * 95)
|
||||
print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}")
|
||||
print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}")
|
||||
print("-" * 95)
|
||||
|
||||
for name, text in all_text:
|
||||
baseline_data = baseline_results[name]
|
||||
ours_data = ours_results[name]
|
||||
|
||||
# Calculate relative difference (positive means ours is better, negative means worse)
|
||||
# Using tokens: fewer tokens is better, so we calculate (baseline_tokens - ours_tokens) / baseline_tokens
|
||||
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
|
||||
|
||||
# Determine which has better compression (higher ratio = better)
|
||||
if baseline_data['ratio'] > ours_data['ratio']:
|
||||
baseline_color, ours_color = GREEN, RED
|
||||
better = baseline_name
|
||||
diff_color = RED
|
||||
elif ours_data['ratio'] > baseline_data['ratio']:
|
||||
baseline_color, ours_color = RED, GREEN
|
||||
better = "Ours"
|
||||
diff_color = GREEN
|
||||
else:
|
||||
baseline_color, ours_color = "", ""
|
||||
better = "Tie"
|
||||
diff_color = ""
|
||||
|
||||
print(f"{name:<10} {baseline_data['bytes']:<8} "
|
||||
f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
|
||||
f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
|
||||
f"{ours_color}{ours_data['tokens']:<7}{RESET} "
|
||||
f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
|
||||
f"{diff_color}{relative_diff:+7.1f}%{RESET} "
|
||||
f"{better:<10}")
|
||||
|
||||
# Print comparisons
|
||||
print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
|
||||
print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'], all_text)
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
lines = []
|
||||
for baseline_name in ["GPT-2", "GPT-4"]:
|
||||
baseline_key = baseline_name.lower().replace('-', '')
|
||||
baseline_results = tokenizer_results[baseline_key]
|
||||
ours_results = tokenizer_results['ours']
|
||||
lines.append(f"### Comparison with {baseline_name}")
|
||||
lines.append("")
|
||||
lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |")
|
||||
lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
|
||||
for name, text in all_text:
|
||||
baseline_data = baseline_results[name]
|
||||
ours_data = ours_results[name]
|
||||
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
|
||||
lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |")
|
||||
lines.append("")
|
||||
report_markdown = "\n".join(lines)
|
||||
get_report().log(section="Tokenizer evaluation", data=[
|
||||
report_markdown,
|
||||
])
|
||||
106
scripts/tok_train.py
Normal file
106
scripts/tok_train.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Train a tokenizer using the HuggingFace Tokenizers library.
|
||||
In the style of GPT-4 tokenizer.
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.tokenizer import RustBPETokenizer
|
||||
from nanochat.common import get_base_dir
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Parse command line arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
||||
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
print(f"vocab_size: {args.vocab_size:,}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Text iterator
|
||||
|
||||
def text_iterator():
|
||||
"""
|
||||
1) Flatten the batches into a single iterator
|
||||
2) Crop every document to args.doc_cap characters
|
||||
3) Break when we've seen args.max_chars characters
|
||||
"""
|
||||
nchars = 0
|
||||
for batch in parquets_iter_batched(split="train"):
|
||||
for doc in batch:
|
||||
doc_text = doc
|
||||
if len(doc_text) > args.doc_cap:
|
||||
doc_text = doc_text[:args.doc_cap]
|
||||
nchars += len(doc_text)
|
||||
yield doc_text
|
||||
if nchars > args.max_chars:
|
||||
return
|
||||
text_iter = text_iterator()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Train the tokenizer
|
||||
t0 = time.time()
|
||||
tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size)
|
||||
t1 = time.time()
|
||||
train_time = t1 - t0
|
||||
print(f"Training time: {train_time:.2f}s")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Save the tokenizer to disk
|
||||
base_dir = get_base_dir()
|
||||
tokenizer_dir = os.path.join(base_dir, "tokenizer")
|
||||
tokenizer.save(tokenizer_dir)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Quick inline sanity check
|
||||
test_text = """Hello world! This is a test.
|
||||
Numbers: 123, 4567, 89
|
||||
Contractions: I'm, you're, it's
|
||||
Special chars: @#$%^&*()
|
||||
Unicode: 你好世界 🌍"""
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert decoded == test_text
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# One more thing: we wish to cache a mapping from token id to number of bytes of that token
|
||||
# for efficient evaluation of bits per byte. Unlike the typical mean loss, this
|
||||
# allows us to report a loss that is invariant to the vocab size of the tokenizer.
|
||||
# The bits per byte on the validation set is then one of the primary metrics we care about.
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
special_set = set(tokenizer.get_special_tokens())
|
||||
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
|
||||
token_bytes = []
|
||||
for token_id in range(vocab_size):
|
||||
token_str = token_strings[token_id] # the Python string representation of this token
|
||||
if token_str in special_set:
|
||||
token_bytes.append(0) # special characters are not counted
|
||||
else:
|
||||
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
|
||||
token_bytes.append(id_bytes)
|
||||
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
|
||||
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
|
||||
with open(token_bytes_path, "wb") as f:
|
||||
torch.save(token_bytes, f)
|
||||
print(f"Saved token_bytes to {token_bytes_path}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
|
||||
get_report().log(section="Tokenizer training", data=[
|
||||
vars(args), # argparse command line arguments
|
||||
{"train_time": train_time},
|
||||
{"num_special_tokens": len(special_set)},
|
||||
{
|
||||
"token_bytes_min": int(token_bytes_nonzero.min().item()),
|
||||
"token_bytes_max": int(token_bytes_nonzero.max().item()),
|
||||
"token_bytes_mean": token_bytes_nonzero.mean().item(),
|
||||
"token_bytes_std": token_bytes_nonzero.std().item(),
|
||||
}
|
||||
])
|
||||
Reference in New Issue
Block a user