also add base_train change example for how to swap LinearFP8

This commit is contained in:
Andrej Karpathy
2026-01-13 17:08:10 +00:00
parent a6382a6ce8
commit 69b1ed245e

View File

@@ -22,6 +22,7 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0 from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW from nanochat.adamw import DistAdamW
from nanochat.fp8_static import LinearFP8
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar) # Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os import os
@@ -159,7 +160,7 @@ class GPT(nn.Module):
"wte": nn.Embedding(padded_vocab_size, config.n_embd), "wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
}) })
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False) self.lm_head = LinearFP8(config.n_embd, padded_vocab_size, bias=False, x_scale=100/448, w_scale=1.6/448, monitor=False)
# Per-layer learnable scalars (inspired by modded-nanogpt) # Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)