Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (val_bpb=1.1216)#410

Closed
EthanYangTW wants to merge 17 commits into openai:main from EthanYangTW:submission/twophase-ttt-normrepair

Conversation

@EthanYangTW

Summary

Built on PR #374 with a novel two-phase test-time training approach:

  • Phase 1 — Norm-Only Recalibration (100 epochs, Adam lr=0.01): Unfreeze only the LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization.
  • Phase 2 — Selective-Freeze Block Adaptation (15 epochs, SGD lr=0.003): Unfreeze the last 2 transformer blocks + norms + scales + lm_head (~5.3M params). Adapts representations on the recalibrated foundation while preserving the SWA-averaged weights in the first 9 blocks.

Key insight: the two phases target different error sources (quantization artifacts vs. distribution mismatch), and their improvements are additive (-0.019 BPB combined).
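The phase boundary comes down to name-based parameter selectors: freeze everything, then unfreeze only the parameters whose names match the current phase. A minimal sketch of that selection logic (the parameter names below are hypothetical stand-ins for the real module names in train_gpt.py):

```python
def select_trainable(param_names, selector):
    """Return the (sorted) parameter names a TTT phase would unfreeze;
    everything else stays frozen."""
    return sorted(n for n in param_names if selector(n))

# Hypothetical parameter names standing in for the real model's modules.
names = [
    "blocks.8.attn.qkv", "blocks.9.mlp.fc", "blocks.10.mlp.fc",
    "blocks.10.ln_scale", "final_norm.scale", "tok_emb.weight",
]

# Phase 1: norm/scale recalibration only (~22K params in the real model).
phase1 = lambda n: "norm" in n or "ln_scale" in n

# Phase 2: last 2 of 11 blocks, plus norms/scales and the (tied) head.
phase2 = lambda n: (any(f"blocks.{i}." in n for i in (9, 10))
                    or phase1(n) or "tok_emb" in n or "lm_head" in n)
```

In the real script each phase then builds its optimizer (Adam for phase 1, SGD for phase 2) over only the selected parameters.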

Results

| Seed | val_bpb | Artifact | Training | TTT |
| --- | --- | --- | --- | --- |
| 1337 | 1.1258 | 15,762,005 bytes | 96ms/step, 6222 steps | 752s two-phase |
  • Post-SWA BPB: 1.1447
  • TTT improvement: -0.019 (1.1447 → 1.1258)

Architecture

  • 11L, dim=512, 8 heads / 4 KV (GQA), XSA last 4 layers
  • 3x MLP relu² + SmearGate + OrthoInit
  • Partial RoPE 16/64, LN Scale, BigramHash(2048)
  • Tight SWA, Late QAT (4%), int6 + zstd-22, 1% pruning

Command

torchrun --standalone --nproc_per_node=8 train_gpt.py

EthanYangTW and others added 16 commits March 21, 2026 20:35
Based on SOTA (10L_Int5MLP_MuonWD04_SWA50) with improvements:
- QAT with STE for int5/int6 quantization-aware training
- BigramHash increased from 10240 to 12288
- Eval stride reduced from 64 to 32 for better context
- Magnitude pruning increased from 3% to 5%
- SWA every 25 steps instead of 50
- Artifact size: ~15.89MB (under 16MB limit)
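For reference, the int6 side of QAT reduces to a symmetric fake-quantize round trip; a minimal pure-Python sketch (the STE part, which passes gradients straight through the rounding in the backward pass, lives in autograd and is omitted here):

```python
def fake_quant_int6(values):
    """Symmetric per-tensor int6 fake-quantization: map each value to an
    integer level in [-31, 31] using scale = max|v| / 31, then dequantize."""
    max_abs = max((abs(v) for v in values), default=0.0) or 1.0
    scale = max_abs / 31.0
    q = [max(-31, min(31, round(v / scale))) for v in values]
    return [qi * scale for qi in q]
```

During QAT the forward pass sees these dequantized values, so the network learns weights that survive the int6 export.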
Restore original train_gpt.py baseline. Add new records folder with
submission script based on 10L_Int5MLP_MuonWD04_SWA50 SOTA.

Changes: QAT with STE, BigramHash 12288, eval stride 32,
5% magnitude pruning, SWA every 25 steps.
Port LoRA TTT from records/2026-03-17_LoRA_TTT into our submission.
At eval time, per-document rank-8 LoRA adapters are trained on Q/V
projections and lm_head, then used for scoring. Expected -0.003 to
-0.005 bpb improvement on top of sliding window eval.
val_bpb=1.14443 (seed=2024), artifact=15.90MB
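The LoRA mechanics behind this: each per-document adapter is a low-rank update that gets merged into the frozen base weight for scoring. A plain-list sketch (shapes and the alpha scaling here are illustrative, not the exact values from the record script):

```python
def lora_merge(W, A, B, alpha, r):
    """Return W + (alpha / r) * (A @ B): the frozen base weight plus a
    trained rank-r LoRA adapter, as used when scoring after adaptation.
    W is rows x cols, A is rows x r, B is r x cols (plain nested lists)."""
    scale = alpha / r
    rows, cols, inner = len(A), len(B[0]), len(B)
    return [[W[i][j] + scale * sum(A[i][k] * B[k][j] for k in range(inner))
             for j in range(cols)] for i in range(rows)]
```

Only A and B are trained at eval time, so the per-document optimizer state stays tiny relative to the full model.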
…/train_gpt.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…le, EMA, Late QAT, TTT

Major rewrite targeting top-5 leaderboard:
- 11 layers (from 10), BigramHash reduced to 10240 to fit 16MB
- XSA (Exclusive Self-Attention) on last 4 layers
- Partial RoPE: 16/64 head dims get position encoding
- LN Scale: 1/sqrt(layer+1) dampening on deeper layers
- EMA (decay=0.997) replaces SWA
- Late QAT: STE int6 enabled only in final 4% of training
- TTT: 25-epoch SGD on val data post-quantization
- FA3 auto-detection with SDPA fallback
- Reverted SwiGLU back to relu² (confirmed worse by openai#340, openai#344)
…y 10 steps

- Disable FA3 (SDPA faster for GQA on PyTorch 2.9)
- BigramHash 10240 -> 8192 to fit 11L under 16MB
- EMA update every 10 steps with adjusted decay to reduce CPU overhead
- Simplify attention forward (remove FA3 code path)
Previous run: 16.94MB with BigramHash 8192 + 5% pruning.
BigramHash 2048 saves ~0.5MB, 10% pruning improves compression further.
v3 was 16.38MB with BigramHash 2048 + 10% pruning.
Removing BigramHash saves ~0.15MB, 15% pruning improves zstd compression.
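The pruning step behind those size savings is plain magnitude pruning: zero the smallest-magnitude fraction of the weights so the zstd stage finds longer runs of identical bytes. A minimal sketch on a flat list (the record scripts apply this per weight matrix):

```python
def magnitude_prune(weights, pct):
    """Zero out the pct smallest-|w| entries (ties resolved via a shared
    threshold). Costs little accuracy but compresses much better under zstd."""
    k = int(len(weights) * pct)
    if k == 0:
        return list(weights)
    threshold = sorted(abs(w) for w in weights)[k - 1]
    return [0.0 if abs(w) <= threshold else w for w in weights]
```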
Fork of unnir's openai#374 (1.1246 BPB) with TTT added:
- 11L, XSA4, Partial RoPE 16/64, LN Scale, Tight SWA
- Shared VE128, SmearGate, BigramHash 2048
- TTT: 25 epochs SGD on val data post-quantization
- Trimmed to 1476 lines (under 1500 limit)
Previous TTT took 7+ min per epoch (uncompiled, single GPU).
Now: torch.compile + DDP across 8 GPUs + 3 epochs + batch 64.
Should finish in ~2-3 min total.
flash_attn_interface (FA3 Hopper) not available on RunPod.
Falls back to flash_attn, then SDPA with GQA support.
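That fallback chain is a sequence of guarded imports; a sketch of the pattern (the package names are the real ones, but the helper itself is hypothetical):

```python
def pick_attention_backend():
    """Prefer FA3 (Hopper), then FlashAttention-2, else PyTorch SDPA."""
    try:
        import flash_attn_interface  # noqa: F401  (FA3 Hopper kernels)
        return "fa3"
    except ImportError:
        pass
    try:
        import flash_attn  # noqa: F401  (FlashAttention-2)
        return "fa2"
    except ImportError:
        return "sdpa"  # torch.nn.functional.scaled_dot_product_attention
```

Resolving the backend once at startup keeps the attention forward free of per-step import checks.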
Two-phase TTT on PR openai#374 base: phase 1 norm-only recalibration
(100ep Adam), phase 2 selective-freeze last 2 blocks (15ep SGD).
Artifact 15.76MB.
Copilot AI review requested due to automatic review settings March 22, 2026 06:44
@EthanYangTW EthanYangTW changed the title 11L XSA4 + Tight SWA + Two-Phase TTT (1.1258) Record: 11L XSA4 + Tight SWA + Two-Phase TTT (val_bpb=1.1258) Mar 22, 2026
Copilot AI left a comment

Pull request overview

Adds new /records/track_10min_16mb entries building on PR #374, introducing a two-phase test-time training (TTT) evaluation procedure (norm/scales “repair” phase + selective unfreeze of late blocks), alongside additional record scripts/logs for related runs.

Changes:

  • Added a new record run directory implementing two-phase TTT after int6 roundtrip export.
  • Added an 11-layer Tight-SWA + (single-phase) TTT record script variant.
  • Added a QAT + BigramHash(12K) + stride-32 record (script, README, submission metadata, and training log).

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 10 comments.

File summary:

  • records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_gpt.py — Implements two-phase selective-freeze TTT and the int6 export + roundtrip eval pipeline
  • records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/README.md — Documents the two-phase TTT approach and results
  • records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/submission.json — Submission metadata for the two-phase TTT record
  • records/track_10min_16mb/2026-03-22_11L_XSA4_TightSWA_TTT/train_gpt.py — 11-layer Tight-SWA baseline with (single-phase) TTT and int6 roundtrip eval
  • records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_gpt.py — QAT + stride-32 record script including pruning and mixed int6/int8 export
  • records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/README.md — Documents the QAT + BigramHash(12K) run configuration/results
  • records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/submission.json — Submission metadata for the QAT + BigramHash(12K) record
  • records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_seed2024.log — Captured training/eval log for the QAT + BigramHash(12K) run


Comment on lines +936 to +940
num_layers_total = max(
    (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
    default=0,
) + 1
late_k_layers = set(range(num_layers_total - 2, num_layers_total))
Copilot AI Mar 22, 2026

late_k_layers is computed but never used. If this is leftover from an earlier “late-K passthrough” experiment, it should be removed or wired into the quantization logic to avoid dead code.

Suggested change
num_layers_total = max(
    (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
    default=0,
) + 1
late_k_layers = set(range(num_layers_total - 2, num_layers_total))

Comment on lines +7 to +8
- **Phase 1 — Norm-Only Recalibration (100 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
- **Phase 2 — Selective-Freeze Block Adaptation (15 epochs, SGD lr=0.003, momentum=0.95):** Unfreeze last 2 transformer blocks + all norms + scales + lm_head (~5.3M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 9 blocks.
Copilot AI Mar 22, 2026

The README claims Phase 1 unfreezes “LayerNorm weights” and final_norm, but the implementation uses parameterless RMSNorm, so there are no norm weights to unfreeze. It also states Phase 2 unfreezes lm_head, but with tied embeddings the code effectively unfreezes tok_emb instead. Please update the README to reflect which parameters are actually trained in each phase (or change the code to match the described parameter set).

Suggested change
- **Phase 1 — Norm-Only Recalibration (100 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
- **Phase 2 — Selective-Freeze Block Adaptation (15 epochs, SGD lr=0.003, momentum=0.95):** Unfreeze last 2 transformer blocks + all norms + scales + lm_head (~5.3M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 9 blocks.
- **Phase 1 — Norm-Only Recalibration (100 epochs, Adam lr=0.01):** Only unfreeze the small set of normalization-related parameters that exist in this model (~22K params). The implementation uses parameterless `RMSNorm`, so there are no LayerNorm gamma/beta or separate `final_norm` weights to train; this phase focuses on recalibrating activation distributions damaged by int6 quantization and acts as post-quantization calibration via gradient descent.
- **Phase 2 — Selective-Freeze Block Adaptation (15 epochs, SGD lr=0.003, momentum=0.95):** Unfreeze the last 2 transformer blocks + associated norms/scales + the tied output head / token embedding matrix (`lm_head` is tied to `tok_emb`, so updating the head updates the embeddings; ~5.3M params). This adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 9 blocks.

Comment on lines +10 to +11
- **Magnitude pruning 5%:** Increased from 3% to improve compression ratio.
- **SWA every 50 steps:** Checkpoint averaging during warmdown.
Copilot AI Mar 22, 2026

The README states “Magnitude pruning 5%” (and mentions SWA every 50), but the accompanying script prunes at a fixed 15% quantile (0.15) and uses SWA_EVERY=25 by default. Please align the README (and/or make pruning percentage and SWA cadence match the described configuration used for this record).

Suggested change
- **Magnitude pruning 5%:** Increased from 3% to improve compression ratio.
- **SWA every 50 steps:** Checkpoint averaging during warmdown.
- **Magnitude pruning 15% (quantile=0.15):** Increased from 3% to improve compression ratio.
- **SWA every 25 steps:** Checkpoint averaging during warmdown.

Comment on lines +1277 to +1280
with torch.no_grad():
    for name, param in base_model.named_parameters():
        if param.ndim == 2 and param.numel() > 65536:
            threshold = torch.quantile(param.abs().float().flatten(), 0.15)
Copilot AI Mar 22, 2026

Magnitude pruning is hard-coded to the 15% quantile (0.15), which makes it easy for the record script and README/submission metadata to drift out of sync. Consider wiring this to an env/arg (similar to PRUNE_PCT in other record scripts) or updating the documentation/metadata to reflect the actual pruning level used.

Suggested change
with torch.no_grad():
    for name, param in base_model.named_parameters():
        if param.ndim == 2 and param.numel() > 65536:
            threshold = torch.quantile(param.abs().float().flatten(), 0.15)
# Make pruning percentage configurable via PRUNE_PCT env var (default 0.15)
prune_pct_str = os.getenv("PRUNE_PCT")
try:
    prune_pct = float(prune_pct_str) if prune_pct_str is not None else 0.15
except (TypeError, ValueError):
    prune_pct = 0.15
if master_process:
    log0(f"magnitude_pruning: using prune_pct={prune_pct}")
with torch.no_grad():
    for name, param in base_model.named_parameters():
        if param.ndim == 2 and param.numel() > 65536:
            threshold = torch.quantile(param.abs().float().flatten(), prune_pct)

@@ -0,0 +1,1474 @@
"""Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374."""
Copilot AI Mar 22, 2026

The file docstring mentions SAM (Sharpness-Aware Minimization), but the implementation doesn’t apply SAM (no perturb/second forward, etc.). This is likely to confuse readers and reviewers; either implement SAM in the TTT phases or update the docstring to reflect the actual approach (two-phase selective-freeze TTT).

Suggested change
"""Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374."""
"""Selective-freeze two-phase TTT (no SAM; Sharpness-Aware Minimization not applied here). Fork of #374."""

Comment on lines +1437 to +1441
# Phase 1: Norm-only (quantization repair)
def norm_selector(name):
    return "norm" in name or "ln_scale" in name or "attn_scale" in name or "mlp_scale" in name or "resid_mix" in name or "final_norm" in name
run_ttt_phase("phase1_norm", ttt_p1_epochs, ttt_p1_lr, norm_selector,
              lambda params, lr: torch.optim.Adam(params, lr=lr))
Copilot AI Mar 22, 2026

Phase 1 is described as “Norm-only”, but this model uses RMSNorm modules without learnable parameters, so "norm" in name/"final_norm" in name won’t actually select any parameters to train. If the intent is to recalibrate norms, you may need learnable norm weights (or explicitly target the parameters that actually exist, e.g., q_gain, smear.gate, etc.), and align the selector/documentation accordingly.

Comment on lines +1444 to +1449
def block_selector(name):
    is_last_blocks = any(f"blocks.{i}." in name for i in range(n_blocks - ttt_p2_unfreeze_last, n_blocks))
    is_norm = "norm" in name or "ln_scale" in name
    is_scale = "attn_scale" in name or "mlp_scale" in name or "resid_mix" in name
    is_head = "final_norm" in name or "tok_emb" in name or "lm_head" in name
    return is_last_blocks or is_norm or is_scale or is_head
Copilot AI Mar 22, 2026

Phase 2’s selector unfreezes tok_emb unconditionally. The PR description/README call out unfreezing lm_head, but with tied embeddings lm_head is None and unfreezing tok_emb is the real equivalent; with untied embeddings, this would also update embeddings unexpectedly. Consider gating tok_emb unfreezing on tie_embeddings (or update the README/description to explicitly include embeddings).

"name": "QAT + BigramHash(12288) + Stride 32",
"val_loss": 1.14443,
"bytes_total": 15902583,
"blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.",
Copilot AI Mar 22, 2026

The blurb claims “magnitude pruning 5%” and “SWA every 50 steps”, but the corresponding train_gpt.py uses a fixed 15% pruning quantile and defaults SWA_EVERY to 25. Please update the blurb (or the script defaults) so the submission metadata reflects the actual configuration used.

Suggested change
"blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.",
"blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 15%, SWA every 25 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.",

Comment on lines +1413 to +1419
for bs in range(0, usable, chunk):
    local_bs = bs + rank * ttt_batch_seqs
    local_be = local_bs + ttt_batch_seqs
    idx = perm[local_bs:local_be]
    bx = torch.stack([val_tokens[i * seq_len : i * seq_len + seq_len] for i in idx]).to(device=device, dtype=torch.int64)
    by = torch.stack([val_tokens[i * seq_len + 1 : i * seq_len + seq_len + 1] for i in idx]).to(device=device, dtype=torch.int64)
    ttt_opt.zero_grad(set_to_none=True)
Copilot AI Mar 22, 2026

In distributed TTT, usable can be smaller than chunk (e.g., small validation set or large TTT_BATCH_SEQS). For ranks where local_bs >= usable, idx becomes empty and torch.stack([... for i in idx]) will raise. Consider computing a per-iteration be = min(bs + chunk, usable), skipping when local_bs >= be (or idx.numel()==0), and weighting ep_loss/ep_count by the actual local batch size rather than ttt_batch_seqs unconditionally.
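One way to implement the clamping this comment asks for, sketched as a hypothetical helper (`rank_slice` is not in the script; it only illustrates the guard):

```python
def rank_slice(usable, bs, rank, batch_seqs, world_size):
    """Per-rank slice bounds for one distributed TTT step, clamped so a
    small validation set never yields an empty or out-of-range batch.
    Returns None when this rank has no work for this step."""
    be = min(bs + batch_seqs * world_size, usable)
    local_bs = bs + rank * batch_seqs
    local_be = min(local_bs + batch_seqs, be)
    if local_bs >= local_be:
        return None
    return local_bs, local_be
```

Callers would skip the step on None (contributing zero to the all-reduced loss) and weight ep_loss/ep_count by local_be - local_bs instead of ttt_batch_seqs.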

import math
import os
import random
import subprocess
Copilot AI Mar 22, 2026

subprocess is imported but never used in this script. Removing unused imports helps keep these record scripts easier to audit and reduces lint noise.

Suggested change
import subprocess

84.65ms/step with FA3 Hopper (was 96ms), 6939 steps.
Two-phase TTT: norm-only 100ep + selective-freeze 25ep.
Artifact 15.70MB. Seed 42 running for 3-seed validation.
@EthanYangTW EthanYangTW changed the title Record: 11L XSA4 + Tight SWA + Two-Phase TTT (val_bpb=1.1258) Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (val_bpb=1.1216) Mar 22, 2026
@EthanYangTW
Author

Superseded by #415 (1.1216 with FA3 Hopper)

JoeProAI pushed a commit to JoeProAI/parameter-golf that referenced this pull request Mar 22, 2026
Architecture discovered via GEPA (Gemini-driven evolutionary search).
SwiGLU FFN, Star-ReLU, U-Net skip gates, BigramHash 8192, XSA4.
AdamW TTT (lr=0.0005, 10ep) from @sjp611 (openai#442).
EMA, RoPE, LN Scale, QAT from @felipe-parodi (openai#398) and @fbedev (openai#410).

3-seed results: 1.06733 / 1.06833 / 1.06580
Mean: 1.06715, Std: 0.00104

Built by @joepro with AI agents via OpenClaw.
Compute provided by Modal.