From e3f58b838e98a5ea013a3c1773fde9d4a3c5d090 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 16 Jan 2026 20:59:42 +0000 Subject: [PATCH] ranked version --- nanochat/gpt.py | 48 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 86f440b..ffb7862 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -68,7 +68,7 @@ class CausalSelfAttention(nn.Module): self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) - def forward(self, x, cos_sin, window_size, kv_cache): + def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -77,6 +77,11 @@ class CausalSelfAttention(nn.Module): k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) + # Value residual (ResFormer): mix in projected initial embedding for later layers + if v0 is not None: + v0_reshaped = v0.view(B, T, self.n_kv_head, self.head_dim) + v = v + v0_lambda * v0_reshaped + # Apply Rotary Embeddings to queries and keys to get relative positional encoding cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) @@ -126,8 +131,8 @@ class Block(nn.Module): self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, cos_sin, window_size, kv_cache): - x = x + self.attn(norm(x), cos_sin, window_size, kv_cache) + def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda): + x = x + self.attn(norm(x), cos_sin, window_size, kv_cache, v0, v0_lambda) x = x + self.mlp(norm(x)) return x @@ -160,6 +165,17 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + # Value residual (ResFormer-style): low-rank factorized embedding for value residual + # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow + # We apply to last 1/4 of layers as the paper shows later layers benefit most + # Low-rank factorization: (vocab, r) @ (r, kv_dim) instead of full (vocab, kv_dim) + head_dim = config.n_embd // config.n_head + kv_dim = config.n_kv_head * head_dim + value_rank = 32 # low-rank bottleneck dimension + self.value_embed_A = nn.Embedding(padded_vocab_size, value_rank) # token -> low-rank + self.value_embed_B = nn.Linear(value_rank, kv_dim, bias=False) # low-rank -> kv_dim + self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + self.value_residual_start = config.n_layer - config.n_layer // 4 # last 1/4 of layers # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -204,15 +220,21 @@ class GPT(nn.Module): with torch.no_grad(): self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init + self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init + + # Value embedding low-rank factors (init like embeddings/projections) + torch.nn.init.normal_(self.value_embed_A.weight, mean=0.0, std=1.0) # like wte + torch.nn.init.uniform_(self.value_embed_B.weight, -s, s) # like c_v # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin - # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory + # Cast embeddings to bf16: optimizer can tolerate it and it saves memory if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) + self.value_embed_A.to(dtype=torch.bfloat16) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -277,7 +299,8 @@ class GPT(nn.Module): """ nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars - nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel() + nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed_A.weight.numel() + + self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -303,13 +326,16 @@ class GPT(nn.Module): def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() - # Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas) + # Separate out all parameters into groups (matrix, embedding, lm_head, value_embed, resid_lambdas, x0_lambdas, v0_lambdas) matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) + value_embed_A_params = list(self.value_embed_A.parameters()) + value_embed_B_params = list(self.value_embed_B.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params) + v0_params = [self.v0_lambdas] + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_A_params) + len(value_embed_B_params) + len(resid_params) + len(x0_params) + len(v0_params) # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -317,8 +343,11 @@ class GPT(nn.Module): adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), + dict(params=value_embed_A_params, lr=embedding_lr * dmodel_lr_scale), # low-rank embedding + dict(params=value_embed_B_params, lr=embedding_lr * dmodel_lr_scale), # low-rank projection dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(params=x0_params, lr=scalar_lr), + dict(params=v0_params, lr=scalar_lr), ] adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) @@ -349,9 +378,12 @@ class GPT(nn.Module): x = self.transformer.wte(idx) x = norm(x) x0 = x # save initial normalized embedding for x0 residual + # Value residual (ResFormer): low-rank factorized embedding for later layers + v0 = self.value_embed_B(self.value_embed_A(idx)) # (B, T, kv_dim) for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - x = block(x, cos_sin, self.window_sizes[i], kv_cache) + v0_for_layer = v0 if i >= self.value_residual_start else None + x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0_for_layer, self.v0_lambdas[i]) x = norm(x) # Forward the lm_head (compute logits)