Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (3-seed mean val_bpb=1.1227)#417
EthanYangTW wants to merge 18 commits into openai:main from
Conversation
Based on SOTA (10L_Int5MLP_MuonWD04_SWA50) with improvements:
- QAT with STE for int5/int6 quantization-aware training
- BigramHash increased from 10240 to 12288
- Eval stride reduced from 64 to 32 for better context
- Magnitude pruning increased from 3% to 5%
- SWA every 25 steps instead of 50
- Artifact size: ~15.89MB (under 16MB limit)
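The "QAT with STE" item refers to fake-quantizing weights in the forward pass while letting gradients flow through unchanged. A minimal sketch of the idea, assuming a symmetric per-tensor scale (the function name and quantization scheme are illustrative, not the record's exact code):

```python
import torch

def fake_quant_ste(w: torch.Tensor, bits: int = 6) -> torch.Tensor:
    """Symmetric fake-quantization with a straight-through estimator (STE).

    The forward pass sees int-quantized weights; the backward pass
    treats the quantizer as identity via the .detach() trick.
    """
    qmax = 2 ** (bits - 1) - 1                      # e.g. 31 for int6
    scale = w.abs().max().clamp(min=1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax, qmax) * scale
    return w + (w_q - w).detach()                   # forward = w_q, backward = identity

w = torch.randn(4, 4, requires_grad=True)
out = fake_quant_ste(w, bits=6)
out.sum().backward()
# the gradient is all ones: the STE passes gradients straight through
```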
Restore original train_gpt.py baseline. Add new records folder with submission script based on 10L_Int5MLP_MuonWD04_SWA50 SOTA. Changes: QAT with STE, BigramHash 12288, eval stride 32, 5% magnitude pruning, SWA every 25 steps.
Port LoRA TTT from records/2026-03-17_LoRA_TTT into our submission. At eval time, per-document rank-8 LoRA adapters are trained on Q/V projections and lm_head, then used for scoring. Expected -0.003 to -0.005 bpb improvement on top of sliding window eval.
val_bpb=1.14443 (seed=2024), artifact=15.90MB
…/train_gpt.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…le, EMA, Late QAT, TTT

Major rewrite targeting top-5 leaderboard:
- 11 layers (from 10), BigramHash reduced to 10240 to fit 16MB
- XSA (Exclusive Self-Attention) on last 4 layers
- Partial RoPE: 16/64 head dims get position encoding
- LN Scale: 1/sqrt(layer+1) dampening on deeper layers
- EMA (decay=0.997) replaces SWA
- Late QAT: STE int6 enabled only in final 4% of training
- TTT: 25-epoch SGD on val data post-quantization
- FA3 auto-detection with SDPA fallback
- Reverted SwiGLU back to relu² (confirmed worse by openai#340, openai#344)
…y 10 steps

- Disable FA3 (SDPA faster for GQA on PyTorch 2.9)
- BigramHash 10240 -> 8192 to fit 11L under 16MB
- EMA update every 10 steps with adjusted decay to reduce CPU overhead
- Simplify attention forward (remove FA3 code path)
Previous run: 16.94MB with BigramHash 8192 + 5% pruning. BigramHash 2048 saves ~0.5MB, 10% pruning improves compression further.
v3 was 16.38MB with BigramHash 2048 + 10% pruning. Removing BigramHash saves ~0.15MB, 15% pruning improves zstd compression.
Fork of unnir's openai#374 (1.1246 BPB) with TTT added:
- 11L, XSA4, Partial RoPE 16/64, LN Scale, Tight SWA
- Shared VE128, SmearGate, BigramHash 2048
- TTT: 25 epochs SGD on val data post-quantization
- Trimmed to 1476 lines (under 1500 limit)
Previous TTT took 7+ min per epoch (uncompiled, single GPU). Now: torch.compile + DDP across 8 GPUs + 3 epochs + batch 64. Should finish in ~2-3 min total.
flash_attn_interface (FA3 Hopper) not available on RunPod. Falls back to flash_attn, then SDPA with GQA support.
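The fallback chain described here amounts to a try/except import ladder. A sketch under those assumptions (the helper name is illustrative; `flash_attn_func` is the entry point both flash-attn packages export, and SDPA is PyTorch's built-in fallback):

```python
import torch.nn.functional as F

def pick_attention_backend():
    """Prefer FA3 (Hopper), then FA2, then PyTorch SDPA."""
    try:
        from flash_attn_interface import flash_attn_func  # FA3 Hopper build
        return "fa3", flash_attn_func
    except ImportError:
        pass
    try:
        from flash_attn import flash_attn_func            # standard flash-attn
        return "fa2", flash_attn_func
    except ImportError:
        pass
    # SDPA is always available; recent PyTorch handles GQA natively
    return "sdpa", F.scaled_dot_product_attention

backend, attn = pick_attention_backend()
```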
Two-phase TTT on PR openai#374 base: phase 1 norm-only recalibration (100ep Adam), phase 2 selective-freeze last 2 blocks (15ep SGD). Artifact 15.76MB.
84.65ms/step with FA3 Hopper (was 96ms), 6939 steps. Two-phase TTT: norm-only 100ep + selective-freeze 25ep. Artifact 15.70MB. Seed 42 running for 3-seed validation.
FA3 Hopper 84.65ms/step, two-phase TTT (norm-only + selective-freeze). 3 seeds: 1.1222, 1.1230, 1.1228. All artifacts under 16MB.
Pull request overview
Adds new /records/track_10min_16mb entries documenting improved 11-layer GPT results via FA3 (Hopper) and a two-phase test-time training (TTT) procedure, along with associated run artifacts (logs/config/scripts).
Changes:
- Added a new “Two-Phase TTT (Norm repair + selective unfreeze)” record folder with training script, README, requirements, submission metadata, and 3 seed logs.
- Added an additional 11L Tight-SWA+TTT record training script snapshot.
- Added a separate 2026-03-21 QAT/BigramHash record (script/log/README/submission).
Reviewed changes
Copilot reviewed 8 out of 12 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_gpt.py | Main training + int6 export + two-phase TTT + sliding-window eval implementation for the new record |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/README.md | Human-readable description of the two-phase TTT approach and results |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/requirements.txt | Dependencies needed to run the record |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/submission.json | Submission metadata (score/bytes/blurb) for the record |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_seed42.log | Seed-42 run log artifact |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_seed2024.log | Seed-2024 run log artifact |
| records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_seed1337.log | Seed-1337 run log artifact |
| records/track_10min_16mb/2026-03-22_11L_XSA4_TightSWA_TTT/train_gpt.py | Additional record snapshot script (baseline/related) |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_gpt.py | Separate QAT/BigramHash record script |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_seed2024.log | Separate QAT/BigramHash record log artifact |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/README.md | Separate QAT/BigramHash record description |
| records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/submission.json | Separate QAT/BigramHash record submission metadata |
```python
    (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
    default=0,
) + 1
late_k_layers = set(range(num_layers_total - 2, num_layers_total))
```
late_k_layers is computed but never used, which makes the quantization code harder to understand/maintain. Please remove it or use it (e.g., if the intent was different quantization for late layers).
Suggested change:
```diff
-late_k_layers = set(range(num_layers_total - 2, num_layers_total))
```
```python
for bs in range(0, usable, chunk):
    local_bs = bs + rank * ttt_batch_seqs
    local_be = local_bs + ttt_batch_seqs
    idx = perm[local_bs:local_be]
    bx = torch.stack([val_tokens[i * seq_len : i * seq_len + seq_len] for i in idx]).to(device=device, dtype=torch.int64)
    by = torch.stack([val_tokens[i * seq_len + 1 : i * seq_len + seq_len + 1] for i in idx]).to(device=device, dtype=torch.int64)
```
In distributed runs where `n_seqs < ttt_batch_seqs * world_size` (or whenever `usable < chunk` and you fall back to `usable = n_seqs`), ranks with `rank > 0` will slice an empty `idx`, causing `torch.stack([])` to throw. Add a guard like `if local_bs >= usable: continue` and handle partial batches (or ensure `usable` is always at least `chunk`).
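A hedged sketch of the suggested guard, reusing the snippet's variable names (`perm`, `usable`, `ttt_batch_seqs`). Note that in a real DDP run, ranks that skip a step must still participate in collectives; this sketch only covers the indexing:

```python
import torch

def rank_batches(perm, usable, rank, world_size, ttt_batch_seqs):
    """Yield this rank's index slices, skipping steps where the rank's
    slice starts past `usable`, and truncating partial final batches."""
    chunk = ttt_batch_seqs * world_size
    for bs in range(0, usable, chunk):
        local_bs = bs + rank * ttt_batch_seqs
        if local_bs >= usable:                   # this rank has no data this step
            continue
        local_be = min(local_bs + ttt_batch_seqs, usable)  # allow partial batch
        yield perm[local_bs:local_be]

perm = torch.arange(10)
# 4 ranks x batch 4 = chunk 16 > usable 10: ranks 2 and 3 must not crash
slices_rank3 = list(rank_batches(perm, usable=10, rank=3, world_size=4, ttt_batch_seqs=4))
slices_rank2 = list(rank_batches(perm, usable=10, rank=2, world_size=4, ttt_batch_seqs=4))
```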
```python
# clear compiled graph for next phase
torch._dynamo.reset()
```
torch._dynamo.reset() is a private/internal API and may break across PyTorch versions or unexpectedly invalidate other compiled graphs. Prefer avoiding global resets (e.g., keep separate compiled functions per phase, or use supported APIs to manage compilation cache) so the script is more robust.
Suggested change:
```diff
-# clear compiled graph for next phase
-torch._dynamo.reset()
+# compiled graph is tied to ttt_compiled and will be released after this function returns
```
- **Phase 1 — Norm-Only Recalibration (50 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
- **Phase 2 — Selective-Freeze Block Adaptation (10 epochs, SGD lr=0.005, momentum=0.9):** Unfreeze last 3 transformer blocks + all norms + scales + lm_head (~7.6M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 8 blocks.
Phase 1 is described as unfreezing “LayerNorm weights … and final_norm”, but the model uses RMSNorm implemented via F.rms_norm (no learnable weight/bias), so there aren’t any LayerNorm/RMSNorm weights to unfreeze. Please update this description to reflect what’s actually being trained in phase 1 (e.g., scale/residual-mix/gains), so the README matches the code/logs.
Suggested change:
```diff
-- **Phase 1 — Norm-Only Recalibration (50 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
-- **Phase 2 — Selective-Freeze Block Adaptation (10 epochs, SGD lr=0.005, momentum=0.9):** Unfreeze last 3 transformer blocks + all norms + scales + lm_head (~7.6M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 8 blocks.
+- **Phase 1 — Norm-Only Recalibration (50 epochs, Adam lr=0.01):** Only unfreeze lightweight scalar parameters (scale / residual-mix / gain parameters; ~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
+- **Phase 2 — Selective-Freeze Block Adaptation (10 epochs, SGD lr=0.005, momentum=0.9):** Unfreeze last 3 transformer blocks + all Phase-1 scale / residual-mix / gain parameters + lm_head (~7.6M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 8 blocks.
```
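One hypothetical way to implement the phase-1 "unfreeze only lightweight scalar parameters" selection is name filtering over `named_parameters()`. The keyword list and module names below are illustrative, not the record's actual parameter names:

```python
import torch
import torch.nn as nn

def set_phase1_trainable(model: nn.Module, keywords=("scale", "gain", "mix")) -> int:
    """Freeze everything, then unfreeze parameters whose names match the
    keyword list. Returns the number of trainable scalars selected."""
    for p in model.parameters():
        p.requires_grad_(False)
    trainable = 0
    for name, p in model.named_parameters():
        if any(k in name for k in keywords):
            p.requires_grad_(True)
            trainable += p.numel()
    return trainable

class Tiny(nn.Module):
    """Toy stand-in for a transformer block: one big matrix, one scale vector."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.res_scale = nn.Parameter(torch.ones(8))

model = Tiny()
n_trainable = set_phase1_trainable(model)  # only res_scale (8 scalars) is selected
```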
```diff
@@ -0,0 +1,2 @@
+zstandard
+flash_attn_3
```
requirements.txt lists flash_attn_3, but the code imports flash_attn_interface (preferred) and falls back to flash_attn. If flash_attn_3 doesn’t provide those import names in the execution environment, installs will succeed but runtime imports will still fail. Consider pinning/adding the package that actually provides flash_attn_interface, or clarify installation instructions so requirements and imports align.
Suggested change:
```diff
-flash_attn_3
+flash-attn
```
```json
{
  "name": "QAT + BigramHash(12288) + Stride 32",
  "val_loss": 1.14443,
  "bytes_total": 15902583,
  "blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50."
}
```
The PR description focuses on the 2026-03-22 two-phase TTT record, but this PR also adds a separate 2026-03-21 QAT/BigramHash record. If this extra record is intentional, it would help to mention it in the PR description (or split into a separate PR) so reviewers understand the scope.
```diff
@@ -0,0 +1,1474 @@
+"""Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374."""
```
The module docstring says this fork uses SAM (Sharpness-Aware Minimization), but the code below only uses Adam/SGD and doesn’t implement SAM. This is misleading for readers trying to reproduce results; please either implement SAM or update the docstring to match the actual optimizers used.
Suggested change:
```diff
-"""Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374."""
+"""Selective-freeze TTT training script. Fork of #374."""
```
```python
compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
with torch.inference_mode():
    for bi in range(0, len(my_windows), batch_seqs):
        batch_ws = my_windows[bi:bi + batch_seqs]
        bsz = len(batch_ws)
        x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
        y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
```
`eval_val_sliding()` compiles `forward_logits` with `dynamic=False`/`fullgraph=True` but then calls it with a batch size (`bsz`) that can vary (especially on the last batch). This can trigger recompiles or failures and adds substantial overhead during eval. Consider using a fixed `batch_seqs` tensor shape (pad the final batch), compiling with `dynamic=True`, and/or compiling once outside the loop with a warmup input (see the similar pattern in other records scripts).
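The padding approach suggested here can be sketched as follows (function and variable names are illustrative; the caller would use the returned mask to keep padded rows out of the loss). Every yielded batch has the same shape, so a `dynamic=False` compiled function never sees a new input shape:

```python
import torch

def fixed_size_batches(windows, batch_seqs, seq_len, pad_id=0):
    """Yield (tokens, validity-mask) batches of constant shape
    (batch_seqs, seq_len), padding the final partial batch."""
    for bi in range(0, len(windows), batch_seqs):
        batch = windows[bi:bi + batch_seqs]
        real = len(batch)
        x = torch.full((batch_seqs, seq_len), pad_id, dtype=torch.int64)
        x[:real] = torch.stack(batch)
        mask = torch.zeros(batch_seqs, dtype=torch.bool)
        mask[:real] = True          # rows beyond `real` are padding
        yield x, mask

windows = [torch.arange(8) for _ in range(5)]
batches = list(fixed_size_batches(windows, batch_seqs=2, seq_len=8))
shapes = {tuple(x.shape) for x, _ in batches}
```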
Major changes:
- DDP gradient sharding: each GPU processes batch_seqs sequences, manual all_reduce on gradients (matches PR openai#415/openai#417 approach)
- Two-phase TTT (TTT_TWO_PHASE=1):
  - Phase 1: norm-only recalibration (50 epochs Adam, ~22K params)
  - Phase 2: selective block adaptation (10 epochs SGD, last 3 blocks)
- TTT_BATCH_SEQS=64 per GPU (512 total with 8 GPUs)
- Falls back to single-phase SGD if TTT_TWO_PHASE=0

Expected speedup: ~235x (from 1344s/epoch to ~5.7s/epoch)
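The "manual all_reduce on gradients" step amounts to sum-then-divide averaging across ranks after the local backward pass. A minimal sketch (guarded so it also runs without an initialized process group):

```python
import torch
import torch.distributed as dist

def all_reduce_grads(model: torch.nn.Module) -> None:
    """Average gradients across ranks by summing with all_reduce and
    dividing by world size; a no-op on a single uninitialized process."""
    world = dist.get_world_size() if dist.is_initialized() else 1
    for p in model.parameters():
        if p.grad is None:
            continue
        if world > 1:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world)

model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
all_reduce_grads(model)   # single-process run: gradients are left unchanged
```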
Summary
Built on PR #374 with FA3 Hopper attention and a novel two-phase test-time training approach:
Key insight: the two phases target different error sources (quantization artifacts vs. distribution mismatch) and are additive.
Results
Architecture
Setup
Command