mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-30 04:22:02 +00:00
simply one VE per layer, works best
This commit is contained in:
@@ -165,14 +165,12 @@ class GPT(nn.Module):
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Value residual (ResFormer-style): separate embedding for values, mixed into later layers
|
||||
# Value residual (ResFormer-style): every layer gets its own value embedding
|
||||
# Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow
|
||||
# We apply to last 1/4 of layers as the paper shows later layers benefit most
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embed = nn.Embedding(padded_vocab_size, kv_dim)
|
||||
self.value_embeds = nn.ModuleList([nn.Embedding(padded_vocab_size, kv_dim) for _ in range(config.n_layer)])
|
||||
self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.value_residual_start = config.n_layer - config.n_layer // 4 # last 1/4 of layers
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
@@ -219,8 +217,9 @@ class GPT(nn.Module):
|
||||
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
|
||||
self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init
|
||||
|
||||
# Value embedding (init like c_v: uniform with same std)
|
||||
torch.nn.init.uniform_(self.value_embed.weight, -s, s)
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds:
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
@@ -230,7 +229,8 @@ class GPT(nn.Module):
|
||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
self.value_embed.to(dtype=torch.bfloat16)
|
||||
for ve in self.value_embeds:
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
@@ -295,7 +295,8 @@ class GPT(nn.Module):
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed.weight.numel() +
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds)
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
@@ -322,15 +323,15 @@ class GPT(nn.Module):
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into groups (matrix, embedding, lm_head, value_embed, resid_lambdas, x0_lambdas, v0_lambdas)
|
||||
# Separate out all parameters into groups (matrix, embedding, lm_head, value_embeds, resid_lambdas, x0_lambdas, v0_lambdas)
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
value_embed_params = list(self.value_embed.parameters())
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
v0_params = [self.v0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_params) + len(resid_params) + len(x0_params) + len(v0_params)
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(v0_params)
|
||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
@@ -338,7 +339,7 @@ class GPT(nn.Module):
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
dict(params=value_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||
dict(params=x0_params, lr=scalar_lr),
|
||||
dict(params=v0_params, lr=scalar_lr),
|
||||
@@ -372,12 +373,11 @@ class GPT(nn.Module):
|
||||
x = self.transformer.wte(idx)
|
||||
x = norm(x)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
# Value residual (ResFormer): separate value embedding for later layers
|
||||
v0 = self.value_embed(idx) # (B, T, kv_dim)
|
||||
# Value residual (ResFormer): every layer gets its own value embedding
|
||||
v0s = [ve(idx) for ve in self.value_embeds] # n_layer x (B, T, kv_dim)
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
v0_for_layer = v0 if i >= self.value_residual_start else None
|
||||
x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0_for_layer, self.v0_lambdas[i])
|
||||
x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0s[i], self.v0_lambdas[i])
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
|
||||
Reference in New Issue
Block a user