mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30 04:22:02 +00:00)
first version of engram following modded nanogpt style
@@ -45,6 +45,41 @@ def norm(x):
     return F.rms_norm(x, (x.size(-1),))
 
 
+class BigramEmbed(nn.Module):
+    """
+    Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
+    Following modded-nanogpt's approach: single hash, no gating.
+
+    For each position t, hashes (token[t-1], token[t]) to an index in a large
+    embedding table. This provides O(1) lookup for local 2-gram patterns,
+    offloading static pattern reconstruction from the transformer layers.
+
+    Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
+    Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
+    """
+    def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
+        super().__init__()
+        self.bigram_vocab_size = vocab_size * table_multiplier
+        self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
+
+    def forward(self, idx: torch.Tensor) -> torch.Tensor:
+        """
+        idx: (B, T) token ids
+        Returns: (B, T, embed_dim) bigram embeddings
+        """
+        # Hash (prev_token, curr_token) -> index
+        # Position 0 gets a reserved index (no valid bigram)
+        rand_int_1 = 36313
+        rand_int_2 = 27191
+        mod = self.bigram_vocab_size - 1
+
+        h = torch.empty_like(idx, dtype=torch.long)
+        h[:, 0] = mod # reserved index for position 0
+        h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
+
+        return self.embed(h)
+
+
 def has_ve(layer_idx, n_layer):
     """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
     return layer_idx % 2 == (n_layer - 1) % 2
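For orientation, here is a minimal standalone sketch (not part of the commit) that reproduces just the index computation above on a toy batch; the hash constants and the reserved position-0 index come from the diff, while the vocab size and token ids are made up:

import torch

# Hypothetical sizes, for illustration only
vocab_size, table_multiplier = 1000, 5
bigram_vocab_size = vocab_size * table_multiplier
rand_int_1, rand_int_2 = 36313, 27191
mod = bigram_vocab_size - 1

idx = torch.tensor([[7, 42, 42, 7]])  # (B=1, T=4) toy token ids

h = torch.empty_like(idx, dtype=torch.long)
h[:, 0] = mod  # reserved bucket: position 0 has no preceding token
h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
print(h)  # each (prev, curr) pair lands in a fixed bucket in [0, mod)

Each bigram is reduced to a single embedding-table row, so the lookup cost is independent of vocabulary size, at the price of occasional hash collisions sharing a row.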
@@ -169,9 +204,13 @@ class GPT(nn.Module):
         # Per-layer learnable scalars (inspired by modded-nanogpt)
         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
+        # bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
         # Separate parameters so they can have different optimizer treatment
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
+        self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
+        # Bigram hash embeddings: O(1) lookup for local 2-gram patterns
+        self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
         # Value embeddings (ResFormer-style): alternating layers, last layer always included
         head_dim = config.n_embd // config.n_head
         kv_dim = config.n_kv_head * head_dim
@@ -219,7 +258,11 @@ class GPT(nn.Module):
 
         # Per-layer scalars
         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
-        self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
+        self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
+        self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
+
+        # Bigram embeddings: zero init so it starts as identity
+        nn.init.zeros_(self.bigram_embed.embed.weight)
 
         # Value embeddings (init like c_v: uniform with same std)
         for ve in self.value_embeds.values():
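A quick way to see the "starts as identity" claim: with the bigram table zeroed, the new path contributes nothing at initialization, so the per-layer input reduces to the pre-existing resid/x0 blend. A small check under toy shapes (the tensors here are made up, not from the commit):

import torch

n_embd = 8
x = torch.randn(1, 4, n_embd)          # residual stream entering a layer
x0 = torch.randn(1, 4, n_embd)         # normalized token embedding
x0_bigram = torch.zeros(1, 4, n_embd)  # bigram path: zero-initialized table => zero output

resid_lambda, x0_lambda, bigram_lambda = 1.0, 0.1, 0.1  # init values from the diff
blended = resid_lambda * x + x0_lambda * x0 + bigram_lambda * x0_bigram
assert torch.allclose(blended, x + 0.1 * x0)  # bigram term is a no-op until its table trains away from zero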
@@ -240,6 +283,7 @@ class GPT(nn.Module):
         self.transformer.wte.to(dtype=torch.bfloat16)
         for ve in self.value_embeds.values():
             ve.to(dtype=torch.bfloat16)
+        self.bigram_embed.to(dtype=torch.bfloat16)
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -305,8 +349,9 @@ class GPT(nn.Module):
         nparams = sum(p.numel() for p in self.parameters())
         # Exclude non-matmul params: embeddings and per-layer scalars
         value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
-        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
-                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
+        bigram_embed_numel = self.bigram_embed.embed.weight.numel()
+        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
+                           self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
         # Sum attention FLOPs per layer, accounting for sliding window
         attn_flops = 0
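The bigram table dominates the embedding parameter count, which is why it joins wte and the value embeddings in nparams_exclude: it is a lookup, not a matmul, so it should not inflate the FLOPs estimate. A rough back-of-the-envelope sketch with hypothetical numbers (not nanochat's actual config):

# Hypothetical sizes, for illustration only
vocab_size, n_embd, table_multiplier = 50_000, 768, 5
bigram_rows = vocab_size * table_multiplier   # 250,000 hash buckets
bigram_params = bigram_rows * n_embd          # 192,000,000 parameters
print(f"{bigram_params / 1e6:.0f}M params, but each token reads only one row per forward pass")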
@@ -339,7 +384,9 @@ class GPT(nn.Module):
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
+        bigram_embed_params = list(self.bigram_embed.parameters())
+        bigram_lambda_params = [self.bigram_lambdas]
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
         # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -348,8 +395,10 @@ class GPT(nn.Module):
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
             dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
+            dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
             dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
             dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
+            dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
         ]
         adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
         AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
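For reference, the ∝1/√dmodel scaling that the new bigram-embedding group inherits (it uses embedding_lr * dmodel_lr_scale, like the token embedding) works out as below; the model_dim values are illustrative:

# LR scale relative to the 768-dim model the LRs were tuned on
for model_dim in (768, 1536, 3072):
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    print(model_dim, round(dmodel_lr_scale, 3))  # 1.0, 0.707, 0.5

The bigram_lambdas group, by contrast, follows the unscaled scalar LR with the higher beta1, mirroring the x0 lambdas.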
@@ -377,11 +426,12 @@ class GPT(nn.Module):
         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
 
         # Forward the trunk of the Transformer
-        x = self.transformer.wte(idx)
+        x = self.transformer.wte(idx) # embed current token
+        x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
         x = norm(x)
         x0 = x # save initial normalized embedding for x0 residual
         for i, block in enumerate(self.transformer.h):
-            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
+            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
             ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
             x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
         x = norm(x)
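To make the data flow explicit: the bigram embedding is looked up once per forward pass and then re-injected into the residual stream at every layer, each with its own learned scalar, mirroring the x0 path. A toy sketch of just that loop (shapes and stand-in tensors are made up):

import torch

n_layer, n_embd = 3, 8
resid_lambdas = torch.ones(n_layer)
x0_lambdas = torch.full((n_layer,), 0.1)
bigram_lambdas = torch.full((n_layer,), 0.1)

x = torch.randn(1, 4, n_embd)          # normalized token embeddings entering the trunk
x0 = x.clone()                         # saved copy for the x0 skip connection
x0_bigram = torch.randn(1, 4, n_embd)  # stands in for self.bigram_embed(idx)

for i in range(n_layer):
    x = resid_lambdas[i] * x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram
    # ... in the real forward, block(x, ve, cos_sin, window_size, kv_cache) runs here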
@@ -1,11 +1,11 @@
 """
 Train model. From root directory of the project, run as:
 
-python -m scripts.base_train.py
+python -m scripts.base_train
 
 or distributed as:
 
-torchrun --nproc_per_node=8 -m scripts.base_train.py
+torchrun --nproc_per_node=8 -m scripts.base_train
 
 If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
 python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20