diff --git a/dev/LOG.md b/dev/LOG.md
index 068b35e..bba35ea 100644
--- a/dev/LOG.md
+++ b/dev/LOG.md
@@ -4,6 +4,140 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 ---
+## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
+
+Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2506.08046) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
+
+### Background
+
+The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
+
+### What We Tried
+
+**1. Full Engram module with context-aware gating (paper design)**
+```python
+# Hash bigrams to retrieve embeddings, then gate with the hidden state
+e = embed(hash(prev_token, curr_token))
+q = RMSNorm(h)           # hidden state as query
+k = RMSNorm(W_k @ e)     # projected embedding as key
+v = W_v @ e
+α = sigmoid(q · k / √d)  # scalar gate per position
+output = α * v
+```
+- Injected after block 1 (the paper found early injection optimal)
+- Slight improvement, but it added quite a bit of complexity.
+
+**2. Early-layer-only injection**
+- Only inject the bigram signal in the first 4 layers (where the paper claims static-pattern offloading helps most)
+- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.
+
+**3. Trigrams**
+- Extended to hash both 2-grams and 3-grams, concatenating embeddings
+- **Result:** No improvement over bigrams alone. The 3-gram entries dilute table capacity away from the more frequent 2-gram patterns.
+
+**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
+- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
+- Zero-init embedding table, learned per-layer lambdas
+- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
+- **Result:** This simple approach works and provides a consistent improvement.
+
+TL;DR: The winning approach follows modded-nanogpt's "engram-lite": simply add the following module and feed its output into the residual stream (gated by a per-layer learnable λ) before every single block:
+
+```python
+class BigramEmbed(nn.Module):
+    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
+        super().__init__()
+        self.table_size = vocab_size * table_multiplier
+        self.embed = nn.Embedding(self.table_size, embed_dim)
+        nn.init.zeros_(self.embed.weight)  # zero-init so the injection starts as a no-op
+
+    def forward(self, idx):
+        # hash each (prev, curr) token pair into the table
+        h = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
+        return self.embed(h)
+```
+
+As for optimal hyperparameters:
+
+- **Table size:** `vocab_size * 5` (~164K entries for the 32K vocab). Swept a number of settings and 5 was optimal.
+- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0); see the sketch below.
+- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings); this was slightly worse.
+- **Init:** Zero-init embedding table, so the injection starts as a no-op (tried a small noisy init; it was worse).
+- **Optimizer:** AdamW with the same LR as the token embeddings.
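+
+For concreteness, here is a minimal sketch of the per-layer injection pattern. This is illustrative only - the class/attribute names, the lambda initializations (other than the 0.1 bigram init above), and the zero-padding of the first position are assumptions, not the actual nanochat code:
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class TinyGPTWithBigrams(nn.Module):
+    """Hypothetical skeleton showing where the bigram signal enters the residual stream."""
+    def __init__(self, vocab_size=32768, dim=768, n_layer=12):
+        super().__init__()
+        self.wte = nn.Embedding(vocab_size, dim)
+        self.bigram_embed = BigramEmbed(vocab_size, dim)  # module from the snippet above
+        # stand-ins for the transformer blocks (attention + MLP)
+        self.blocks = nn.ModuleList([nn.Identity() for _ in range(n_layer)])
+        self.resid_lambdas = nn.Parameter(torch.ones(n_layer))
+        self.x0_lambdas = nn.Parameter(torch.zeros(n_layer))
+        self.bigram_lambdas = nn.Parameter(torch.full((n_layer,), 0.1))  # init 0.1 per the sweep above
+
+    def forward(self, idx):
+        x0 = self.wte(idx)                   # (B, T, D) token embeddings, reused at every layer
+        bg = self.bigram_embed(idx)          # (B, T-1, D), one vector per (prev, curr) pair
+        x0_bigram = F.pad(bg, (0, 0, 1, 0))  # position 0 has no previous token; pad with zeros
+        x = x0
+        for i, block in enumerate(self.blocks):
+            # x0-style injection: blend the running residual with the (fixed) embedding signals
+            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
+            x = block(x)
+        return x
+```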
+
+### Key Learnings
+
+1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."
+
+2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.
+
+3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.
+
+4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.
+
+### Parameters Added
+
+For a d12 model with `table_multiplier=5`:
+- Bigram embedding: 32768 × 5 × 768 = ~126M params
+- Per-layer lambdas: 12 scalars (negligible)
+
+If you're keeping track, we now have *a lot* of parameters, a significant fraction of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:
+
+```
+Parameter counts:
+wte                     : 25,165,824
+bigram_embed            : 125,829,120
+value_embeds            : 150,994,944
+lm_head                 : 25,165,824
+transformer_matrices    : 84,935,808
+scalars                 : 36
+total                   : 412,091,556
+```
+
+In other words, only about a quarter of the parameters are now weight projections; the vast majority sit in embedding tables.
+
+Still, on all axes (steps, wall-clock time, FLOPs), this somewhat parameter-bloated architecture beats the baseline and will now become the default.
+
+After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2.15e18, 4.64e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:
+
+```
+Kaplan-style (all projections including lm_head and no embeddings)
+
+Optimal configurations (from quadratic fits):
+FLOPs        Eff Params      Tokens          Ratio      Val BPB
+-----------------------------------------------------------------
+1e+18        110,678,115     1,241,505,403   11.2       0.8972
+2e+18        167,797,457     1,785,336,422   10.7       0.8616
+5e+18        250,650,865     2,642,234,152   10.8       0.8293
+1e+19        381,758,347     3,806,871,243   10.3       0.7999
+
+N \propto C^0.54, D \propto C^0.49
+
+Chinchilla-style (all parameters, period.)
+
+Optimal configurations (from quadratic fits):
+FLOPs        Eff Params      Tokens          Ratio      Val BPB
+-----------------------------------------------------------------
+1e+18        416,320,605     1,232,157,011   3.0        0.8974
+2e+18        560,239,841     1,763,669,281   3.2        0.8616
+5e+18        741,495,903     2,629,909,368   3.6        0.8291
+1e+19        988,644,331     3,884,841,895   4.0        0.7999
+
+N \propto C^0.37, D \propto C^0.50
+
+Transformer-only-style (only the projections inside the transformer)
+
+Optimal configurations (from quadratic fits):
+FLOPs        Eff Params      Tokens          Ratio      Val BPB
+-----------------------------------------------------------------
+1e+18        80,259,665      1,315,639,547   17.2       0.8966
+2e+18        131,488,566     1,864,134,141   14.5       0.8622
+5e+18        220,985,474     2,595,328,843   12.1       0.8302
+1e+19        401,213,504     3,328,704,512   8.5        0.7994
+
+N \propto C^0.70, D \propto C^0.41
+```
+
+Clearly, the Kaplan-style ratios are the most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can use a single fixed tokens:params ratio for compute-optimal models. This turns out to be about 10.5, which now becomes the new default.
+
+---
+
 ## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep
 
 Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters.
Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params. diff --git a/dev/scaling_analysis.ipynb b/dev/scaling_analysis.ipynb index a196bd1..e7761c5 100644 --- a/dev/scaling_analysis.ipynb +++ b/dev/scaling_analysis.ipynb @@ -15,14 +15,16 @@ "metadata": {}, "outputs": [], "source": [ + "%matplotlib inline\n", "import os\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# Load results\n", + "tag = \"jan26\"\n", "base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n", - "results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n", + "results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n", "\n", "df = pd.read_csv(results_path)\n", "flops_budgets = sorted(df['flops_budget'].unique())\n", @@ -31,6 +33,99 @@ "df" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# FILTERING: Remove incomplete or problematic runs\n", + "# =============================================================================\n", + "\n", + "print(f\"Before filtering: {len(df)} runs\")\n", + "\n", + "# Filter out runs with missing/invalid val_bpb (incomplete runs)\n", + "df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n", + "\n", + "# Optional: exclude specific flops budgets that aren't done yet\n", + "# exclude_flops = [1e19] # <-- adjust as runs complete\n", + "# df = df[~df['flops_budget'].isin(exclude_flops)]\n", + "\n", + "# Optional: exclude specific depths\n", + "# exclude_depths = [18, 20]\n", + "# df = df[~df['depth'].isin(exclude_depths)]\n", + "\n", + "print(f\"After filtering: {len(df)} runs\")\n", + "print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n", + "print(f\"Depths: {sorted(df['depth'].unique())}\")\n", + "\n", + "# Update flops_budgets list after filtering\n", + "flops_budgets = sorted(df['flops_budget'].unique())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Effective Parameter Count\n", + "\n", + "Different scaling law papers use different conventions for counting parameters:\n", + "- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n", + "- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n", + "\n", + "Our CSV now has granular counts:\n", + "- `params_wte` - token embedding (lookup table)\n", + "- `params_bigram_embed` - bigram hash embeddings (lookup table)\n", + "- `params_value_embeds` - value embeddings (lookup table)\n", + "- `params_lm_head` - unembedding projection (matmul)\n", + "- `params_transformer` - attention + MLP matrices (matmuls)\n", + "- `params_scalars` - resid/x0/bigram lambdas (tiny)\n", + "\n", + "**Experiment below** with different combinations to see which gives the cleanest scaling laws." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# EXPERIMENT HERE: Define which parameters to count for scaling laws\n", + "# =============================================================================\n", + "\n", + "def compute_effective_params(row):\n", + " \"\"\"\n", + " Compute the 'effective' parameter count for scaling law analysis.\n", + "\n", + " Modify this function to experiment with different conventions:\n", + " - Chinchilla-style: include everything\n", + " - Kaplan-style: exclude embeddings\n", + " - Matmul-only: just transformer + lm_head (the actual compute)\n", + " - etc.\n", + " \"\"\"\n", + " # Option 1: Chinchilla-style (all params)\n", + " # return row['params_total']\n", + "\n", + " # Option 2: Kaplan-style (exclude embeddings)\n", + " return row['params_transformer'] + row['params_lm_head']\n", + "\n", + " # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n", + " # return row['params_transformer']\n", + "\n", + "\n", + "# Compute derived columns\n", + "df['effective_params'] = df.apply(compute_effective_params, axis=1)\n", + "df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n", + "\n", + "# Show parameter breakdown for first few rows\n", + "print(\"Parameter breakdown (first row per flops budget):\")\n", + "param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n", + " 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n", + "df.groupby('flops_budget').first()[param_cols]" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -54,11 +149,11 @@ "optimal_by_bpb = []\n", "\n", "for flops, color in zip(flops_budgets, colors):\n", - " subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n", - " ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n", + " subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n", + " ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n", "\n", " # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n", - " log_params = np.log10(subset['num_scaling_params'])\n", + " log_params = np.log10(subset['effective_params'])\n", " coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n", " a, b, c = coeffs\n", "\n", @@ -83,13 +178,13 @@ " # Fallback to raw minimum if quadratic doesn't have minimum\n", " best_idx = subset['val_bpb'].idxmin()\n", " best = subset.loc[best_idx]\n", - " ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n", + " ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n", " zorder=5, edgecolors='black', linewidths=2)\n", - " optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n", + " optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n", " 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n", "\n", "ax.set_xscale('log')\n", - "ax.set_xlabel('Parameters')\n", + "ax.set_xlabel('Effective Parameters')\n", "ax.set_ylabel('Validation Loss (bpb)')\n", "ax.set_title('IsoFLOP Curves')\n", "ax.legend(title='FLOPs', loc='upper right')\n", @@ -138,10 +233,61 @@ "\n", "# Print the optimal points (from quadratic fits)\n", "print(\"\\nOptimal configurations (from 
quadratic fits):\")\n", - "print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n", + "print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n", "print(\"-\" * 65)\n", "for _, row in opt_df.iterrows():\n", - " print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n" + " print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# Optimal Ratio Summary (from power law fits)\n", + "# =============================================================================\n", + "\n", + "# From the power law fits: N ∝ C^a and D ∝ C^b\n", + "# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n", + "\n", + "if len(opt_df) >= 2:\n", + " log_f = np.log10(opt_df['flops'])\n", + " log_p = np.log10(opt_df['params'])\n", + " log_t = np.log10(opt_df['tokens'])\n", + "\n", + " # Fit power laws\n", + " slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n", + " slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n", + "\n", + " # The ratio D/N at a reference compute (geometric mean of our budgets)\n", + " ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n", + " log_ref = np.log10(ref_flops)\n", + "\n", + " # Predicted optimal N and D at reference compute\n", + " pred_log_n = intercept_n + slope_n * log_ref\n", + " pred_log_d = intercept_d + slope_d * log_ref\n", + " optimal_ratio = 10**(pred_log_d - pred_log_n)\n", + "\n", + " # Also compute from the fitted optimals directly (mean and std)\n", + " mean_ratio = opt_df['ratio'].mean()\n", + " std_ratio = opt_df['ratio'].std()\n", + "\n", + " print(\"=\" * 60)\n", + " print(\"OPTIMAL RATIO SUMMARY\")\n", + " print(\"=\" * 60)\n", + " print(f\"\\nPower law exponents:\")\n", + " print(f\" N ∝ C^{slope_n:.3f}\")\n", + " print(f\" D ∝ C^{slope_d:.3f}\")\n", + " print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n", + " print(f\"\\nOptimal ratio (tokens per effective param):\")\n", + " print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n", + " print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n", + " print(f\" Chinchilla reference: 20\")\n", + " print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n", + "else:\n", + " print(\"Need at least 2 flops budgets to compute power law fits\")" ] }, { diff --git a/nanochat/gpt.py b/nanochat/gpt.py index b810ec9..c55e893 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -364,15 +364,34 @@ class GPT(nn.Module): def num_scaling_params(self): """ - Return all of the parameters, same as Chinchilla paper. - Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws. - But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla). - My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law. - Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good). - Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad) + Return detailed parameter counts for scaling law analysis. + Different papers use different conventions: + - Kaplan et al. 
excluded embedding parameters + - Chinchilla included all parameters + Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper) + Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper) + + Returns a dict with counts for each parameter group, so downstream analysis + can experiment with which combination gives the cleanest scaling laws. """ - nparams = sum(p.numel() for p in self.parameters()) - return nparams + # Count each group separately (mirrors the grouping in setup_optimizers) + wte = sum(p.numel() for p in self.transformer.wte.parameters()) + bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters()) + value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) + lm_head = sum(p.numel() for p in self.lm_head.parameters()) + transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) + scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel() + total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars + assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" + return { + 'wte': wte, + 'bigram_embed': bigram_embed, + 'value_embeds': value_embeds, + 'lm_head': lm_head, + 'transformer_matrices': transformer_matrices, + 'scalars': scalars, + 'total': total, + } def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): model_dim = self.config.n_embd diff --git a/runs/scaling_laws.sh b/runs/scaling_laws.sh index 1f9dab8..f1e2fd4 100644 --- a/runs/scaling_laws.sh +++ b/runs/scaling_laws.sh @@ -1,13 +1,14 @@ #!/bin/bash -LABEL="jan16" +LABEL="jan26" FLOPS_BUDGETS=( 1e18 - 3e18 - 6e18 + 2.15e18 + 4.64e18 + 1e19 ) -DEPTHS=(6 7 8 9 10 11 12 13 14) +DEPTHS=(8 10 12 14 16 18 20) NPROC_PER_NODE="${NPROC_PER_NODE:-8}" WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" @@ -23,7 +24,7 @@ RESULTS_FILE="$RESULTS_DIR/results.csv" # Write CSV header only if file doesn't exist if [ ! 
-f "$RESULTS_FILE" ]; then - echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE" + echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE" fi log() { @@ -83,13 +84,19 @@ for flops in "${FLOPS_BUDGETS[@]}"; do # Extract training stats from the log LOG_FILE="$RESULTS_DIR/${TAG}_train.log" - NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',') - NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',') + + # Extract detailed parameter counts (for scaling law analysis with different conventions) + PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',') + NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',') # Calculate tokens trained (iterations * batch_size, default 524288) TOKENS_TRAINED=$((NUM_ITERS * 524288)) - # Param:data ratio (using scaling params per Kaplan et al.) 
- PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')") # Model dim MODEL_DIM=$((d * 64)) # Val BPB from final eval @@ -102,10 +109,10 @@ for flops in "${FLOPS_BUDGETS[@]}"; do CORE_SCORE="0.0" fi - log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE" + log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE" # Append to CSV - echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE" + echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE" done done diff --git a/scripts/base_train.py b/scripts/base_train.py index 02eeea3..4fa8fca 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -47,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") @@ -178,9 +178,14 @@ if resuming: orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe -num_params = sum(p.numel() for p in model.parameters()) -num_scaling_params = orig_model.num_scaling_params() -print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})") + +# Detailed parameter counts +param_counts = orig_model.num_scaling_params() +print0(f"Parameter counts:") +for key, value in param_counts.items(): + print0(f"{key:24s}: {value:,}") +num_params = param_counts['total'] +num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026 num_flops_per_token = model.estimate_flops() print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") @@ -195,14 +200,14 @@ elif args.target_flops > 0: print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") elif args.target_param_data_ratio > 0: # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.) 
- target_tokens = args.target_param_data_ratio * num_scaling_params + target_tokens = int(args.target_param_data_ratio * num_scaling_params) num_iterations = target_tokens // args.total_batch_size print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") else: raise ValueError("No training horizon specified") total_tokens = args.total_batch_size * num_iterations print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 +print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") # ----------------------------------------------------------------------------- @@ -445,7 +450,7 @@ get_report().log(section="Base model training", data=[ "Number of FLOPs per token": f"{num_flops_per_token:e}", "Calculated number of iterations": num_iterations, "Number of training tokens": total_tokens, - "Tokens : Params ratio": args.total_batch_size * num_iterations / num_params, + "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params, "DDP world size": ddp_world_size, "warmup_ratio": args.warmup_ratio, "warmdown_ratio": args.warmdown_ratio,