Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (val_bpb=1.1216)#410
Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (val_bpb=1.1216)#410EthanYangTW wants to merge 17 commits intoopenai:mainfrom
Conversation
Based on SOTA (10L_Int5MLP_MuonWD04_SWA50) with improvements: - QAT with STE for int5/int6 quantization-aware training - BigramHash increased from 10240 to 12288 - Eval stride reduced from 64 to 32 for better context - Magnitude pruning increased from 3% to 5% - SWA every 25 steps instead of 50 - Artifact size: ~15.89MB (under 16MB limit)
Restore original train_gpt.py baseline. Add new records folder with submission script based on 10L_Int5MLP_MuonWD04_SWA50 SOTA. Changes: QAT with STE, BigramHash 12288, eval stride 32, 5% magnitude pruning, SWA every 25 steps.
Port LoRA TTT from records/2026-03-17_LoRA_TTT into our submission. At eval time, per-document rank-8 LoRA adapters are trained on Q/V projections and lm_head, then used for scoring. Expected -0.003 to -0.005 bpb improvement on top of sliding window eval.
val_bpb=1.14443 (seed=2024), artifact=15.90MB
…/train_gpt.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…le, EMA, Late QAT, TTT Major rewrite targeting top-5 leaderboard: - 11 layers (from 10), BigramHash reduced to 10240 to fit 16MB - XSA (Exclusive Self-Attention) on last 4 layers - Partial RoPE: 16/64 head dims get position encoding - LN Scale: 1/sqrt(layer+1) dampening on deeper layers - EMA (decay=0.997) replaces SWA - Late QAT: STE int6 enabled only in final 4% of training - TTT: 25-epoch SGD on val data post-quantization - FA3 auto-detection with SDPA fallback - Reverted SwiGLU back to relu² (confirmed worse by openai#340, openai#344)
…y 10 steps - Disable FA3 (SDPA faster for GQA on PyTorch 2.9) - BigramHash 10240 -> 8192 to fit 11L under 16MB - EMA update every 10 steps with adjusted decay to reduce CPU overhead - Simplify attention forward (remove FA3 code path)
Previous run: 16.94MB with BigramHash 8192 + 5% pruning. BigramHash 2048 saves ~0.5MB, 10% pruning improves compression further.
v3 was 16.38MB with BigramHash 2048 + 10% pruning. Removing BigramHash saves ~0.15MB, 15% pruning improves zstd compression.
Fork of unnir's openai#374 (1.1246 BPB) with TTT added: - 11L, XSA4, Partial RoPE 16/64, LN Scale, Tight SWA - Shared VE128, SmearGate, BigramHash 2048 - TTT: 25 epochs SGD on val data post-quantization - Trimmed to 1476 lines (under 1500 limit)
Previous TTT took 7+ min per epoch (uncompiled, single GPU). Now: torch.compile + DDP across 8 GPUs + 3 epochs + batch 64. Should finish in ~2-3 min total.
flash_attn_interface (FA3 Hopper) not available on RunPod. Falls back to flash_attn, then SDPA with GQA support.
Two-phase TTT on PR openai#374 base: phase 1 norm-only recalibration (100ep Adam), phase 2 selective-freeze last 2 blocks (15ep SGD). Artifact 15.76MB.
There was a problem hiding this comment.
Pull request overview
Adds new /records/track_10min_16mb entries building on PR #374, introducing a two-phase test-time training (TTT) evaluation procedure (norm/scales “repair” phase + selective unfreeze of late blocks), alongside additional record scripts/logs for related runs.
Changes:
- Added a new record run directory implementing two-phase TTT after int6 roundtrip export.
- Added an 11-layer Tight-SWA + (single-phase) TTT record script variant.
- Added a QAT + BigramHash(12K) + stride-32 record (script, README, submission metadata, and training log).
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 10 comments.
Show a summary per file
| File | Description |
|---|---|
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_gpt.py | Implements two-phase selective-freeze TTT and int6 export + roundtrip eval pipeline |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/README.md | Documents the two-phase TTT approach/results |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/submission.json | Submission metadata for the two-phase TTT record |
| records/track_10min_16mb/2026-03-22_11L_XSA4_TightSWA_TTT/train_gpt.py | 11-layer Tight-SWA baseline with (single-phase) TTT and int6 roundtrip eval |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_gpt.py | QAT + stride-32 record script including pruning + mixed int6/int8 export |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/README.md | Documents the QAT + BigramHash(12K) run configuration/results |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/submission.json | Submission metadata for the QAT + BigramHash(12K) record |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_seed2024.log | Captured training/eval log for the QAT + BigramHash(12K) run |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| num_layers_total = max( | ||
| (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), | ||
| default=0, | ||
| ) + 1 | ||
| late_k_layers = set(range(num_layers_total - 2, num_layers_total)) |
There was a problem hiding this comment.
late_k_layers is computed but never used. If this is leftover from an earlier “late-K passthrough” experiment, it should be removed or wired into the quantization logic to avoid dead code.
| num_layers_total = max( | |
| (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), | |
| default=0, | |
| ) + 1 | |
| late_k_layers = set(range(num_layers_total - 2, num_layers_total)) |
| - **Phase 1 — Norm-Only Recalibration (100 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent. | ||
| - **Phase 2 — Selective-Freeze Block Adaptation (15 epochs, SGD lr=0.003, momentum=0.95):** Unfreeze last 2 transformer blocks + all norms + scales + lm_head (~5.3M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 9 blocks. |
There was a problem hiding this comment.
The README claims Phase 1 unfreezes “LayerNorm weights” and final_norm, but the implementation uses parameterless RMSNorm, so there are no norm weights to unfreeze. It also states Phase 2 unfreezes lm_head, but with tied embeddings the code effectively unfreezes tok_emb instead. Please update the README to reflect which parameters are actually trained in each phase (or change the code to match the described parameter set).
| - **Phase 1 — Norm-Only Recalibration (100 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent. | |
| - **Phase 2 — Selective-Freeze Block Adaptation (15 epochs, SGD lr=0.003, momentum=0.95):** Unfreeze last 2 transformer blocks + all norms + scales + lm_head (~5.3M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 9 blocks. | |
| - **Phase 1 — Norm-Only Recalibration (100 epochs, Adam lr=0.01):** Only unfreeze the small set of normalization-related parameters that exist in this model (~22K params). The implementation uses parameterless `RMSNorm`, so there are no LayerNorm gamma/beta or separate `final_norm` weights to train; this phase focuses on recalibrating activation distributions damaged by int6 quantization and acts as post-quantization calibration via gradient descent. | |
| - **Phase 2 — Selective-Freeze Block Adaptation (15 epochs, SGD lr=0.003, momentum=0.95):** Unfreeze the last 2 transformer blocks + associated norms/scales + the tied output head / token embedding matrix (`lm_head` is tied to `tok_emb`, so updating the head updates the embeddings; ~5.3M params). This adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 9 blocks. |
| - **Magnitude pruning 5%:** Increased from 3% to improve compression ratio. | ||
| - **SWA every 50 steps:** Checkpoint averaging during warmdown. |
There was a problem hiding this comment.
The README states “Magnitude pruning 5%” (and mentions SWA every 50), but the accompanying script prunes at a fixed 15% quantile (0.15) and uses SWA_EVERY=25 by default. Please align the README (and/or make pruning percentage and SWA cadence match the described configuration used for this record).
| - **Magnitude pruning 5%:** Increased from 3% to improve compression ratio. | |
| - **SWA every 50 steps:** Checkpoint averaging during warmdown. | |
| - **Magnitude pruning 15% (quantile=0.15):** Increased from 3% to improve compression ratio. | |
| - **SWA every 25 steps:** Checkpoint averaging during warmdown. |
| with torch.no_grad(): | ||
| for name, param in base_model.named_parameters(): | ||
| if param.ndim == 2 and param.numel() > 65536: | ||
| threshold = torch.quantile(param.abs().float().flatten(), 0.15) |
There was a problem hiding this comment.
Magnitude pruning is hard-coded to the 15% quantile (0.15), which makes it easy for the record script and README/submission metadata to drift out of sync. Consider wiring this to an env/arg (similar to PRUNE_PCT in other record scripts) or updating the documentation/metadata to reflect the actual pruning level used.
| with torch.no_grad(): | |
| for name, param in base_model.named_parameters(): | |
| if param.ndim == 2 and param.numel() > 65536: | |
| threshold = torch.quantile(param.abs().float().flatten(), 0.15) | |
| # Make pruning percentage configurable via PRUNE_PCT env var (default 0.15) | |
| prune_pct_str = os.getenv("PRUNE_PCT") | |
| try: | |
| prune_pct = float(prune_pct_str) if prune_pct_str is not None else 0.15 | |
| except (TypeError, ValueError): | |
| prune_pct = 0.15 | |
| if master_process: | |
| log0(f"magnitude_pruning: using prune_pct={prune_pct}") | |
| with torch.no_grad(): | |
| for name, param in base_model.named_parameters(): | |
| if param.ndim == 2 and param.numel() > 65536: | |
| threshold = torch.quantile(param.abs().float().flatten(), prune_pct) |
| @@ -0,0 +1,1474 @@ | |||
| """Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374.""" | |||
There was a problem hiding this comment.
The file docstring mentions SAM (Sharpness-Aware Minimization), but the implementation doesn’t apply SAM (no perturb/second forward, etc.). This is likely to confuse readers and reviewers; either implement SAM in the TTT phases or update the docstring to reflect the actual approach (two-phase selective-freeze TTT).
| """Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374.""" | |
| """Selective-freeze two-phase TTT (no SAM; Sharpness-Aware Minimization not applied here). Fork of #374.""" |
| # Phase 1: Norm-only (quantization repair) | ||
| def norm_selector(name): | ||
| return "norm" in name or "ln_scale" in name or "attn_scale" in name or "mlp_scale" in name or "resid_mix" in name or "final_norm" in name | ||
| run_ttt_phase("phase1_norm", ttt_p1_epochs, ttt_p1_lr, norm_selector, | ||
| lambda params, lr: torch.optim.Adam(params, lr=lr)) |
There was a problem hiding this comment.
Phase 1 is described as “Norm-only”, but this model uses RMSNorm modules without learnable parameters, so "norm" in name/"final_norm" in name won’t actually select any parameters to train. If the intent is to recalibrate norms, you may need learnable norm weights (or explicitly target the parameters that actually exist, e.g., q_gain, smear.gate, etc.), and align the selector/documentation accordingly.
| def block_selector(name): | ||
| is_last_blocks = any(f"blocks.{i}." in name for i in range(n_blocks - ttt_p2_unfreeze_last, n_blocks)) | ||
| is_norm = "norm" in name or "ln_scale" in name | ||
| is_scale = "attn_scale" in name or "mlp_scale" in name or "resid_mix" in name | ||
| is_head = "final_norm" in name or "tok_emb" in name or "lm_head" in name | ||
| return is_last_blocks or is_norm or is_scale or is_head |
There was a problem hiding this comment.
Phase 2’s selector unfreezes tok_emb unconditionally. The PR description/README call out unfreezing lm_head, but with tied embeddings lm_head is None and unfreezing tok_emb is the real equivalent; with untied embeddings, this would also update embeddings unexpectedly. Consider gating tok_emb unfreezing on tie_embeddings (or update the README/description to explicitly include embeddings).
| "name": "QAT + BigramHash(12288) + Stride 32", | ||
| "val_loss": 1.14443, | ||
| "bytes_total": 15902583, | ||
| "blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.", |
There was a problem hiding this comment.
The blurb claims “magnitude pruning 5%” and “SWA every 50 steps”, but the corresponding train_gpt.py uses a fixed 15% pruning quantile and defaults SWA_EVERY to 25. Please update the blurb (or the script defaults) so the submission metadata reflects the actual configuration used.
| "blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.", | |
| "blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 15%, SWA every 25 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.", |
| for bs in range(0, usable, chunk): | ||
| local_bs = bs + rank * ttt_batch_seqs | ||
| local_be = local_bs + ttt_batch_seqs | ||
| idx = perm[local_bs:local_be] | ||
| bx = torch.stack([val_tokens[i * seq_len : i * seq_len + seq_len] for i in idx]).to(device=device, dtype=torch.int64) | ||
| by = torch.stack([val_tokens[i * seq_len + 1 : i * seq_len + seq_len + 1] for i in idx]).to(device=device, dtype=torch.int64) | ||
| ttt_opt.zero_grad(set_to_none=True) |
There was a problem hiding this comment.
In distributed TTT, usable can be smaller than chunk (e.g., small validation set or large TTT_BATCH_SEQS). For ranks where local_bs >= usable, idx becomes empty and torch.stack([... for i in idx]) will raise. Consider computing a per-iteration be = min(bs + chunk, usable), skipping when local_bs >= be (or idx.numel()==0), and weighting ep_loss/ep_count by the actual local batch size rather than ttt_batch_seqs unconditionally.
| import math | ||
| import os | ||
| import random | ||
| import subprocess |
There was a problem hiding this comment.
subprocess is imported but never used in this script. Removing unused imports helps keep these record scripts easier to audit and reduces lint noise.
| import subprocess |
84.65ms/step with FA3 Hopper (was 96ms), 6939 steps. Two-phase TTT: norm-only 100ep + selective-freeze 25ep. Artifact 15.70MB. Seed 42 running for 3-seed validation.
|
Superseded by #415 (1.1216 with FA3 Hopper) |
Architecture discovered via GEPA (Gemini-driven evolutionary search). SwiGLU FFN, Star-ReLU, U-Net skip gates, BigramHash 8192, XSA4. AdamW TTT (lr=0.0005, 10ep) from @sjp611 (openai#442). EMA, RoPE, LN Scale, QAT from @felipe-parodi (openai#398) and @fbedev (openai#410). 3-seed results: 1.06733 / 1.06833 / 1.06580 Mean: 1.06715, Std: 0.00104 Built by @joepro with AI agents via OpenClaw. Compute provided by Modal.
Architecture discovered via GEPA (Gemini-driven evolutionary search). SwiGLU FFN, Star-ReLU, U-Net skip gates, BigramHash 8192, XSA4. AdamW TTT (lr=0.0005, 10ep) from @sjp611 (openai#442). EMA, RoPE, LN Scale, QAT from @felipe-parodi (openai#398) and @fbedev (openai#410). 3-seed results: 1.06733 / 1.06833 / 1.06580 Mean: 1.06715, Std: 0.00104 Built by @joepro with AI agents via OpenClaw. Compute provided by Modal.
Summary
Built on PR #374 with a novel two-phase test-time training approach:
Key insight: the two phases target different error sources (quantization artifacts vs. distribution mismatch) and are additive (-0.019 BPB combined).
Results
Architecture
Command