full ve version works very well

commit 0b58d70e99 (parent e3f58b838e)
Author: Andrej Karpathy
Date:   2026-01-16 21:16:47 +00:00

@@ -165,15 +165,12 @@ class GPT(nn.Module):
         # Separate parameters so they can have different optimizer treatment
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        # Value residual (ResFormer-style): low-rank factorized embedding for value residual
+        # Value residual (ResFormer-style): separate embedding for values, mixed into later layers
         # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow
         # We apply to last 1/4 of layers as the paper shows later layers benefit most
-        # Low-rank factorization: (vocab, r) @ (r, kv_dim) instead of full (vocab, kv_dim)
         head_dim = config.n_embd // config.n_head
         kv_dim = config.n_kv_head * head_dim
-        value_rank = 32 # low-rank bottleneck dimension
-        self.value_embed_A = nn.Embedding(padded_vocab_size, value_rank) # token -> low-rank
-        self.value_embed_B = nn.Linear(value_rank, kv_dim, bias=False) # low-rank -> kv_dim
+        self.value_embed = nn.Embedding(padded_vocab_size, kv_dim)
         self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
         self.value_residual_start = config.n_layer - config.n_layer // 4 # last 1/4 of layers
         # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
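For context on the size trade-off this commit accepts, below is a minimal sketch comparing the parameter count of the removed low-rank factorization against the full value embedding it is replaced with. The vocab size and model dims are illustrative assumptions, not values from the repo's config.

import torch.nn as nn

# Illustrative sizes (assumptions for this sketch, not the repo's config)
padded_vocab_size = 50304
n_embd, n_head, n_kv_head = 768, 6, 6
head_dim = n_embd // n_head
kv_dim = n_kv_head * head_dim
value_rank = 32

# Old: low-rank factorization, a (vocab, r) embedding plus an (r, kv_dim) projection
value_embed_A = nn.Embedding(padded_vocab_size, value_rank)
value_embed_B = nn.Linear(value_rank, kv_dim, bias=False)
low_rank_params = value_embed_A.weight.numel() + value_embed_B.weight.numel()

# New (this commit): one full (vocab, kv_dim) embedding
value_embed = nn.Embedding(padded_vocab_size, kv_dim)
full_params = value_embed.weight.numel()

print(low_rank_params, full_params)  # ~1.6M vs ~38.6M parameters at these sizes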
@@ -222,9 +219,8 @@ class GPT(nn.Module):
         self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
         self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init
-        # Value embedding low-rank factors (init like embeddings/projections)
-        torch.nn.init.normal_(self.value_embed_A.weight, mean=0.0, std=1.0) # like wte
-        torch.nn.init.uniform_(self.value_embed_B.weight, -s, s) # like c_v
+        # Value embedding (init like c_v: uniform with same std)
+        torch.nn.init.uniform_(self.value_embed.weight, -s, s)
         # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
@@ -234,7 +230,7 @@ class GPT(nn.Module):
         # Cast embeddings to bf16: optimizer can tolerate it and it saves memory
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)
-            self.value_embed_A.to(dtype=torch.bfloat16)
+            self.value_embed.to(dtype=torch.bfloat16)

     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -299,7 +295,7 @@ class GPT(nn.Module):
""" """
nparams = sum(p.numel() for p in self.parameters()) nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars # Exclude non-matmul params: embeddings and per-layer scalars
nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed_A.weight.numel() + nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed.weight.numel() +
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel()) self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window # Sum attention FLOPs per layer, accounting for sliding window
@@ -330,12 +326,11 @@ class GPT(nn.Module):
         matrix_params = list(self.transformer.h.parameters())
         embedding_params = list(self.transformer.wte.parameters())
         lm_head_params = list(self.lm_head.parameters())
-        value_embed_A_params = list(self.value_embed_A.parameters())
-        value_embed_B_params = list(self.value_embed_B.parameters())
+        value_embed_params = list(self.value_embed.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
         v0_params = [self.v0_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_A_params) + len(value_embed_B_params) + len(resid_params) + len(x0_params) + len(v0_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_params) + len(resid_params) + len(x0_params) + len(v0_params)
         # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
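As a quick check of the 1/√dmodel learning-rate scaling referenced in the comment above (the widths below are illustrative):

# LRs were tuned at d_model = 768; wider models scale them down by (model_dim / 768) ** -0.5
for model_dim in (768, 1536, 3072):  # illustrative widths
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    print(model_dim, round(dmodel_lr_scale, 3))  # -> 1.0, 0.707, 0.5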
@@ -343,8 +338,7 @@ class GPT(nn.Module):
         adam_groups = [
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
-            dict(params=value_embed_A_params, lr=embedding_lr * dmodel_lr_scale), # low-rank embedding
-            dict(params=value_embed_B_params, lr=embedding_lr * dmodel_lr_scale), # low-rank projection
+            dict(params=value_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
             dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
             dict(params=x0_params, lr=scalar_lr),
             dict(params=v0_params, lr=scalar_lr),
@@ -378,8 +372,8 @@ class GPT(nn.Module):
         x = self.transformer.wte(idx)
         x = norm(x)
         x0 = x # save initial normalized embedding for x0 residual
-        # Value residual (ResFormer): low-rank factorized embedding for later layers
-        v0 = self.value_embed_B(self.value_embed_A(idx)) # (B, T, kv_dim)
+        # Value residual (ResFormer): separate value embedding for later layers
+        v0 = self.value_embed(idx) # (B, T, kv_dim)
         for i, block in enumerate(self.transformer.h):
             x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
             v0_for_layer = v0 if i >= self.value_residual_start else None
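How the block consumes v0_for_layer is not part of this diff. A plausible sketch, assuming the per-layer v0_lambdas gate an additive mix (consistent with the value residual being "disabled at init" when the lambda is 0.0), is below; the helper name and tensor shapes are assumptions for illustration only.

def mix_value_residual(v, v0, v0_lambda):
    # v:  (B, n_kv_head, T, head_dim) values from this layer's value projection (assumed shape)
    # v0: (B, T, kv_dim) token-level value embedding, or None for early layers
    if v0 is None:
        return v  # layers before value_residual_start carry no value residual
    B, H, T, D = v.shape
    v0 = v0.view(B, T, H, D).transpose(1, 2)  # reshape to (B, n_kv_head, T, head_dim)
    return v + v0_lambda * v0  # v0_lambda == 0.0 leaves v unchanged, i.e. disabled at init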