Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-30 04:22:02 +00:00
Merge branch 've'
@@ -61,7 +61,6 @@ for d in "${DEPTHS[@]}"; do
    # No --target-flops, let it use the default ratio from base_train
    torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
        --depth=$d \
        --target-param-data-ratio=8 \
        --run="${WANDB_RUN}_d${d}" \
        --model-tag="${TAG}" \
        --core-metric-every=999999 \
@@ -28,8 +28,8 @@ from nanochat.flash_attention import flash_attn

@dataclass
class GPTConfig:
-    sequence_len: int = 1024
-    vocab_size: int = 50304
+    sequence_len: int = 2048
+    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA)
@@ -37,7 +37,7 @@ class GPTConfig:
    # Sliding window attention pattern string, tiled across layers. Final layer always L.
    # Characters: L=long (full context), S=short (half context)
    # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
-    window_pattern: str = "L"
+    window_pattern: str = "SSSL"


def norm(x):
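Editorial note, not part of the diff: a minimal sketch of how a window pattern string like "SSSL" could tile across layers, assuming S means half context and L means full context and that the final layer is forced to L, as the comments above state. The actual window-size computation lives elsewhere in gpt.py and is not shown in this diff; the helper name and mapping here are hypothetical.

# Hypothetical illustration of tiling a window pattern string across layers.
def tile_window_pattern(pattern: str, n_layer: int, sequence_len: int):
    sizes = []
    for i in range(n_layer):
        char = pattern[i % len(pattern)]
        if i == n_layer - 1:
            char = "L"  # final layer always gets full context
        sizes.append(sequence_len if char == "L" else sequence_len // 2)
    return sizes

print(tile_window_pattern("SSSL", 12, 2048))
# [1024, 1024, 1024, 2048, 1024, 1024, 1024, 2048, 1024, 1024, 1024, 2048]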
@@ -45,6 +45,10 @@ def norm(x):
    return F.rms_norm(x, (x.size(-1),))


+def has_ve(layer_idx, n_layer):
+    """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
+    return layer_idx % 2 == (n_layer - 1) % 2
+
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4 # multihead attention
    d = x.shape[3] // 2
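For intuition (not part of the diff): the has_ve rule above selects every other layer counting backwards from the last one, so the final layer always receives a value embedding regardless of whether n_layer is even or odd.

# Which layers get value embeddings under the alternating rule above:
def has_ve(layer_idx, n_layer):
    return layer_idx % 2 == (n_layer - 1) % 2

print([i for i in range(12) if has_ve(i, 12)])  # [1, 3, 5, 7, 9, 11]  -> last layer (11) included
print([i for i in range(13) if has_ve(i, 13)])  # [0, 2, 4, 6, 8, 10, 12] -> last layer (12) included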
@@ -67,8 +71,10 @@ class CausalSelfAttention(nn.Module):
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+        self.ve_gate_channels = 32
+        self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None

-    def forward(self, x, cos_sin, window_size, kv_cache):
+    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
@@ -77,6 +83,12 @@ class CausalSelfAttention(nn.Module):
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

+        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
+        if ve is not None:
+            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
+            gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
+            v = v + gate.unsqueeze(-1) * ve
+
        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
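Standalone shape check (illustrative, not part of the diff) of the gated value residual above, using toy batch/sequence sizes and the dims implied by the config shown (n_embd=768, n_kv_head=6, head_dim=128). The gate reads only the first 32 channels of x and produces one scalar per KV head, broadcast over head_dim.

import torch

B, T, n_kv_head, head_dim, n_embd, ve_gate_channels = 2, 8, 6, 128, 768, 32
x = torch.randn(B, T, n_embd)
v = torch.randn(B, T, n_kv_head, head_dim)    # values from c_v
ve = torch.randn(B, T, n_kv_head * head_dim)  # value embedding lookup for these tokens
ve_gate = torch.nn.Linear(ve_gate_channels, n_kv_head, bias=False)

ve = ve.view(B, T, n_kv_head, head_dim)
gate = 2 * torch.sigmoid(ve_gate(x[..., :ve_gate_channels]))  # (B, T, n_kv_head), each entry in (0, 2)
v = v + gate.unsqueeze(-1) * ve                               # gate broadcast over head_dim
print(v.shape)  # torch.Size([2, 8, 6, 128])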
@@ -126,8 +138,8 @@ class Block(nn.Module):
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

-    def forward(self, x, cos_sin, window_size, kv_cache):
-        x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
+    def forward(self, x, ve, cos_sin, window_size, kv_cache):
+        x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + self.mlp(norm(x))
        return x
@@ -160,6 +172,10 @@ class GPT(nn.Module):
        # Separate parameters so they can have different optimizer treatment
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
+        # Value embeddings (ResFormer-style): alternating layers, last layer always included
+        head_dim = config.n_embd // config.n_head
+        kv_dim = config.n_kv_head * head_dim
+        self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
        # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@@ -170,6 +186,7 @@ class GPT(nn.Module):
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)

+    @torch.no_grad()
    def init_weights(self):
        """
        Initialize the full model in this one function for maximum clarity.
@@ -201,18 +218,28 @@ class GPT(nn.Module):
            torch.nn.init.zeros_(block.mlp.c_proj.weight)

        # Per-layer scalars
-        with torch.no_grad():
-            self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
-            self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
+        self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
+        self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
+
+        # Value embeddings (init like c_v: uniform with same std)
+        for ve in self.value_embeds.values():
+            torch.nn.init.uniform_(ve.weight, -s, s)
+
+        # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
+        for block in self.transformer.h:
+            if block.attn.ve_gate is not None:
+                torch.nn.init.zeros_(block.attn.ve_gate.weight)

        # Rotary embeddings
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

-        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
+        # Cast embeddings to bf16: optimizer can tolerate it and it saves memory
        if self.transformer.wte.weight.device.type == "cuda":
            self.transformer.wte.to(dtype=torch.bfloat16)
+            for ve in self.value_embeds.values():
+                ve.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # TODO: bump base theta more? e.g. 100K is more common more recently
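Quick arithmetic check (illustrative only) of the comment above: with ve_gate.weight zero-initialized, the gate logit is 0 for any input, so the gate opens at exactly neutral strength at init.

import torch
logit = torch.zeros(1)
print(2 * torch.sigmoid(logit))  # tensor([1.])  -> value embedding mixed in at weight 1.0 at init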
@@ -277,7 +304,9 @@ class GPT(nn.Module):
        """
        nparams = sum(p.numel() for p in self.parameters())
        # Exclude non-matmul params: embeddings and per-layer scalars
-        nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
+        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
+        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
+                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
        # Sum attention FLOPs per layer, accounting for sliding window
        attn_flops = 0
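Hedged sketch (not part of the diff, and the full estimator body is not shown here): the exclusion above typically feeds a 6-FLOPs-per-matmul-parameter-per-token estimate, with the attention term added separately as the diff indicates.

def matmul_param_flops_per_token(nparams, nparams_exclude):
    # Embeddings (token + value) and per-layer scalars do no matmul work per token,
    # so only the remaining parameters contribute ~6 FLOPs per parameter per token.
    return 6 * (nparams - nparams_exclude)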
@@ -303,13 +332,14 @@ class GPT(nn.Module):
    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        model_dim = self.config.n_embd
        ddp, rank, local_rank, world_size = get_dist_info()
-        # Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
+        # Separate out all parameters into groups
        matrix_params = list(self.transformer.h.parameters())
+        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
        # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -317,6 +347,7 @@ class GPT(nn.Module):
        adam_groups = [
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
+            dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
            dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
            dict(params=x0_params, lr=scalar_lr),
        ]
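Worked example (illustrative): the AdamW LRs above are scaled by (model_dim / 768) ** -0.5, so wider models train their embedding/unembedding groups with smaller learning rates. The base LRs 0.2 and 0.004 come from the setup_optimizers signature in the diff; the model_dim values beyond 768 are hypothetical.

for model_dim in (768, 1280, 2048):
    scale = (model_dim / 768) ** -0.5
    print(model_dim, round(0.2 * scale, 4), round(0.004 * scale, 5))
# 768  -> 0.2     0.004
# 1280 -> 0.1549  0.0031
# 2048 -> 0.1225  0.00245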
@@ -351,7 +382,8 @@ class GPT(nn.Module):
        x0 = x # save initial normalized embedding for x0 residual
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
-            x = block(x, cos_sin, self.window_sizes[i], kv_cache)
+            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
+            x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
        x = norm(x)

        # Forward the lm_head (compute logits)
@@ -1,20 +1,23 @@
#!/bin/bash

+LABEL="jan16"

FLOPS_BUDGETS=(
    1e18
    3e18
    6e18
)
-DEPTHS=(8 10 12 14 16 18 20)
+DEPTHS=(6 7 8 9 10 11 12 13 14)

NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
-WANDB_RUN="${WANDB_RUN:-scaling}"
+WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)

export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
source .venv/bin/activate

-RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
+RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}"
mkdir -p "$RESULTS_DIR"
RESULTS_FILE="$RESULTS_DIR/results.csv"
@@ -47,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
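Hedged sketch (the exact base_train logic is not shown in this diff): a data:param ratio target implies roughly ratio * num_params training tokens, which at a fixed total batch size determines the number of optimization steps. The parameter count below is hypothetical; the batch size is the --total-batch-size default shown above, and ratio=8 matches the scaling script earlier in this commit.

num_params = 560e6            # hypothetical model size
total_batch_size = 524288     # tokens per optimization step (--total-batch-size default)
ratio = 8                     # target data:param ratio
target_tokens = ratio * num_params
num_iterations = round(target_tokens / total_batch_size)
print(num_iterations)         # ~8545 steps for this hypothetical model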
@@ -112,21 +112,19 @@ vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# Model kwargs are derived from the desired depth of the model
+# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
+# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
+# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
num_layers = args.depth
-model_dim = args.depth * args.aspect_ratio
-def find_num_heads(model_dim, target_head_dim):
-    # Find num_heads that divides model_dim evenly, with head_dim closest to target.
-    ideal = max(1, round(model_dim / target_head_dim))
-    for offset in range(model_dim):
-        for candidate in [ideal + offset, ideal - offset]:
-            if candidate > 0 and model_dim % candidate == 0:
-                return candidate
-    return 1
-num_heads = find_num_heads(model_dim, args.head_dim)
+base_dim = args.depth * args.aspect_ratio
+model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
+num_heads = model_dim // args.head_dim
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
head_dim = model_dim // num_heads
print0(f"num_layers: {num_layers}")
-print0(f"model_dim: {model_dim}")
+print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
print0(f"num_heads: {num_heads}")
print0(f"head_dim: {head_dim}")
print0(f"num_kv_heads: {num_kv_heads}")

# Optimizer / data / training length related hyperparameters
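Worked example of the nudge above (illustrative; the aspect_ratio=64 and head_dim=128 values are assumed here, not taken from this diff): model_dim is rounded up to the nearest multiple of head_dim so that head_dim divides it exactly.

def nudge(depth, aspect_ratio=64, head_dim=128):
    base_dim = depth * aspect_ratio
    model_dim = ((base_dim + head_dim - 1) // head_dim) * head_dim  # ceil to multiple of head_dim
    return base_dim, model_dim, model_dim // head_dim

for d in (6, 7, 8):
    print(d, nudge(d))
# 6 -> base 384, nudged 384, 3 heads (already a multiple)
# 7 -> base 448, nudged 512, 4 heads (+64 nudge)
# 8 -> base 512, nudged 512, 4 heads

Under these assumed defaults, depth 7 gets bumped to the same width as depth 8, which is the slight "unfair" advantage for odd depths mentioned in the comment.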