Skip to content

Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (3-seed mean val_bpb=1.1227)#417

Open
EthanYangTW wants to merge 18 commits intoopenai:mainfrom
EthanYangTW:submission/fa3-twophase-ttt-3seed
Open

Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (3-seed mean val_bpb=1.1227)#417
EthanYangTW wants to merge 18 commits intoopenai:mainfrom
EthanYangTW:submission/fa3-twophase-ttt-3seed

Conversation

@EthanYangTW
Copy link

Summary

Built on PR #374 with FA3 Hopper attention and a novel two-phase test-time training approach:

  • FA3 Hopper: 84.65ms/step, enabling ~7,000 training steps in 600s
  • Phase 1 — Norm-Only Recalibration (50 epochs, Adam lr=0.01): Only unfreeze LayerNorm weights + scales (~22K params). Recalibrates activation distributions damaged by int6 quantization.
  • Phase 2 — Selective-Freeze Block Adaptation (10 epochs, SGD lr=0.005): Unfreeze last 3 blocks + norms + scales + lm_head (~7.6M params). Adapts on the recalibrated foundation while preserving SWA-averaged weights in first 8 blocks.

Key insight: the two phases target different error sources (quantization artifacts vs. distribution mismatch) and are additive.

Results

Seed val_bpb Artifact
1337 1.1222 15,758,953 bytes
42 1.1230 15,798,468 bytes
2024 1.1228 15,689,654 bytes
Mean 1.1227 All under 16MB
  • Post-SWA BPB: ~1.1414
  • TTT improvement: ~-0.019
  • Training: 84.65ms/step, ~7000 steps, 600s
  • Eval (TTT + sliding window): ~500s

Architecture

  • 11L, dim=512, 8 heads / 4 KV (GQA), XSA last 4 layers
  • 3x MLP relu² + SmearGate + OrthoInit
  • Partial RoPE 16/64, LN Scale, BigramHash(2048)
  • Tight SWA, Late QAT (4%), int6 + zstd-22, 2% pruning
  • FA3 Hopper (flash_attn_interface)

Setup

pip install zstandard
pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291

Command

torchrun --standalone --nproc_per_node=8 train_gpt.py

EthanYangTW and others added 18 commits March 21, 2026 20:35
Based on SOTA (10L_Int5MLP_MuonWD04_SWA50) with improvements:
- QAT with STE for int5/int6 quantization-aware training
- BigramHash increased from 10240 to 12288
- Eval stride reduced from 64 to 32 for better context
- Magnitude pruning increased from 3% to 5%
- SWA every 25 steps instead of 50
- Artifact size: ~15.89MB (under 16MB limit)
Restore original train_gpt.py baseline. Add new records folder with
submission script based on 10L_Int5MLP_MuonWD04_SWA50 SOTA.

Changes: QAT with STE, BigramHash 12288, eval stride 32,
5% magnitude pruning, SWA every 25 steps.
Port LoRA TTT from records/2026-03-17_LoRA_TTT into our submission.
At eval time, per-document rank-8 LoRA adapters are trained on Q/V
projections and lm_head, then used for scoring. Expected -0.003 to
-0.005 bpb improvement on top of sliding window eval.
val_bpb=1.14443 (seed=2024), artifact=15.90MB
…/train_gpt.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…le, EMA, Late QAT, TTT

Major rewrite targeting top-5 leaderboard:
- 11 layers (from 10), BigramHash reduced to 10240 to fit 16MB
- XSA (Exclusive Self-Attention) on last 4 layers
- Partial RoPE: 16/64 head dims get position encoding
- LN Scale: 1/sqrt(layer+1) dampening on deeper layers
- EMA (decay=0.997) replaces SWA
- Late QAT: STE int6 enabled only in final 4% of training
- TTT: 25-epoch SGD on val data post-quantization
- FA3 auto-detection with SDPA fallback
- Reverted SwiGLU back to relu² (confirmed worse by openai#340, openai#344)
…y 10 steps

- Disable FA3 (SDPA faster for GQA on PyTorch 2.9)
- BigramHash 10240 -> 8192 to fit 11L under 16MB
- EMA update every 10 steps with adjusted decay to reduce CPU overhead
- Simplify attention forward (remove FA3 code path)
Previous run: 16.94MB with BigramHash 8192 + 5% pruning.
BigramHash 2048 saves ~0.5MB, 10% pruning improves compression further.
v3 was 16.38MB with BigramHash 2048 + 10% pruning.
Removing BigramHash saves ~0.15MB, 15% pruning improves zstd compression.
Fork of unnir's openai#374 (1.1246 BPB) with TTT added:
- 11L, XSA4, Partial RoPE 16/64, LN Scale, Tight SWA
- Shared VE128, SmearGate, BigramHash 2048
- TTT: 25 epochs SGD on val data post-quantization
- Trimmed to 1476 lines (under 1500 limit)
Previous TTT took 7+ min per epoch (uncompiled, single GPU).
Now: torch.compile + DDP across 8 GPUs + 3 epochs + batch 64.
Should finish in ~2-3 min total.
flash_attn_interface (FA3 Hopper) not available on RunPod.
Falls back to flash_attn, then SDPA with GQA support.
Two-phase TTT on PR openai#374 base: phase 1 norm-only recalibration
(100ep Adam), phase 2 selective-freeze last 2 blocks (15ep SGD).
Artifact 15.76MB.
84.65ms/step with FA3 Hopper (was 96ms), 6939 steps.
Two-phase TTT: norm-only 100ep + selective-freeze 25ep.
Artifact 15.70MB. Seed 42 running for 3-seed validation.
FA3 Hopper 84.65ms/step, two-phase TTT (norm-only + selective-freeze).
3 seeds: 1.1222, 1.1230, 1.1228. All artifacts under 16MB.
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds new /records/track_10min_16mb entries documenting improved 11-layer GPT results via FA3 (Hopper) and a two-phase test-time training (TTT) procedure, along with associated run artifacts (logs/config/scripts).

Changes:

  • Added a new “Two-Phase TTT (Norm repair + selective unfreeze)” record folder with training script, README, requirements, submission metadata, and 3 seed logs.
  • Added an additional 11L Tight-SWA+TTT record training script snapshot.
  • Added a separate 2026-03-21 QAT/BigramHash record (script/log/README/submission).

Reviewed changes

Copilot reviewed 8 out of 12 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_gpt.py Main training + int6 export + two-phase TTT + sliding-window eval implementation for the new record
records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/README.md Human-readable description of the two-phase TTT approach and results
records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/requirements.txt Dependencies needed to run the record
records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/submission.json Submission metadata (score/bytes/blurb) for the record
records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_seed42.log Seed-42 run log artifact
records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_seed2024.log Seed-2024 run log artifact
records/track_10min_16mb/2026-03-22_TwoPhase_TTT_NormRepair/train_seed1337.log Seed-1337 run log artifact
records/track_10min_16mb/2026-03-22_11L_XSA4_TightSWA_TTT/train_gpt.py Additional record snapshot script (baseline/related)
records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_gpt.py Separate QAT/BigramHash record script
records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_seed2024.log Separate QAT/BigramHash record log artifact
records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/README.md Separate QAT/BigramHash record description
records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/submission.json Separate QAT/BigramHash record submission metadata

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

(int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
default=0,
) + 1
late_k_layers = set(range(num_layers_total - 2, num_layers_total))
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

late_k_layers is computed but never used, which makes the quantization code harder to understand/maintain. Please remove it or use it (e.g., if the intent was different quantization for late layers).

Suggested change
late_k_layers = set(range(num_layers_total - 2, num_layers_total))

Copilot uses AI. Check for mistakes.
Comment on lines +1413 to +1418
for bs in range(0, usable, chunk):
local_bs = bs + rank * ttt_batch_seqs
local_be = local_bs + ttt_batch_seqs
idx = perm[local_bs:local_be]
bx = torch.stack([val_tokens[i * seq_len : i * seq_len + seq_len] for i in idx]).to(device=device, dtype=torch.int64)
by = torch.stack([val_tokens[i * seq_len + 1 : i * seq_len + seq_len + 1] for i in idx]).to(device=device, dtype=torch.int64)
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In distributed runs where n_seqs < ttt_batch_seqs * world_size (or whenever usable < chunk and you fall back to usable = n_seqs), ranks with rank > 0 will slice an empty idx, causing torch.stack([]) to throw. Add a guard like if local_bs >= usable: continue and handle partial batches (or ensure usable is always at least chunk).

Copilot uses AI. Check for mistakes.
Comment on lines +1433 to +1434
# clear compiled graph for next phase
torch._dynamo.reset()
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch._dynamo.reset() is a private/internal API and may break across PyTorch versions or unexpectedly invalidate other compiled graphs. Prefer avoiding global resets (e.g., keep separate compiled functions per phase, or use supported APIs to manage compilation cache) so the script is more robust.

Suggested change
# clear compiled graph for next phase
torch._dynamo.reset()
# compiled graph is tied to ttt_compiled and will be released after this function returns

Copilot uses AI. Check for mistakes.
Comment on lines +8 to +9
- **Phase 1 — Norm-Only Recalibration (50 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
- **Phase 2 — Selective-Freeze Block Adaptation (10 epochs, SGD lr=0.005, momentum=0.9):** Unfreeze last 3 transformer blocks + all norms + scales + lm_head (~7.6M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 8 blocks.
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Phase 1 is described as unfreezing “LayerNorm weights … and final_norm”, but the model uses RMSNorm implemented via F.rms_norm (no learnable weight/bias), so there aren’t any LayerNorm/RMSNorm weights to unfreeze. Please update this description to reflect what’s actually being trained in phase 1 (e.g., scale/residual-mix/gains), so the README matches the code/logs.

Suggested change
- **Phase 1 — Norm-Only Recalibration (50 epochs, Adam lr=0.01):** Only unfreeze LayerNorm weights, scales, and final_norm (~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
- **Phase 2 — Selective-Freeze Block Adaptation (10 epochs, SGD lr=0.005, momentum=0.9):** Unfreeze last 3 transformer blocks + all norms + scales + lm_head (~7.6M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 8 blocks.
- **Phase 1 — Norm-Only Recalibration (50 epochs, Adam lr=0.01):** Only unfreeze lightweight scalar parameters (scale / residual-mix / gain parameters; ~22K params). Recalibrates activation distributions damaged by int6 quantization. Acts as post-quantization calibration via gradient descent.
- **Phase 2 — Selective-Freeze Block Adaptation (10 epochs, SGD lr=0.005, momentum=0.9):** Unfreeze last 3 transformer blocks + all Phase‑1 scale / residual-mix / gain parameters + lm_head (~7.6M params). Adapts representations on the recalibrated foundation while preserving SWA-averaged weights in the first 8 blocks.

Copilot uses AI. Check for mistakes.
@@ -0,0 +1,2 @@
zstandard
flash_attn_3
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requirements.txt lists flash_attn_3, but the code imports flash_attn_interface (preferred) and falls back to flash_attn. If flash_attn_3 doesn’t provide those import names in the execution environment, installs will succeed but runtime imports will still fail. Consider pinning/adding the package that actually provides flash_attn_interface, or clarify installation instructions so requirements and imports align.

Suggested change
flash_attn_3
flash-attn

Copilot uses AI. Check for mistakes.
Comment on lines +1 to +5
{
"name": "QAT + BigramHash(12288) + Stride 32",
"val_loss": 1.14443,
"bytes_total": 15902583,
"blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.",
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description focuses on the 2026-03-22 two-phase TTT record, but this PR also adds a separate 2026-03-21 QAT/BigramHash record. If this extra record is intentional, it would help to mention it in the PR description (or split into a separate PR) so reviewers understand the scope.

Copilot uses AI. Check for mistakes.
@@ -0,0 +1,1474 @@
"""Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374."""
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module docstring says this fork uses SAM (Sharpness-Aware Minimization), but the code below only uses Adam/SGD and doesn’t implement SAM. This is misleading for readers trying to reproduce results; please either implement SAM or update the docstring to match the actual optimizers used.

Suggested change
"""Selective-freeze TTT with SAM (Sharpness-Aware). Fork of #374."""
"""Selective-freeze TTT training script. Fork of #374."""

Copilot uses AI. Check for mistakes.
Comment on lines +870 to +876
compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
with torch.inference_mode():
for bi in range(0, len(my_windows), batch_seqs):
batch_ws = my_windows[bi:bi + batch_seqs]
bsz = len(batch_ws)
x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eval_val_sliding() compiles forward_logits with dynamic=False/fullgraph=True but then calls it with a batch size (bsz) that can vary (especially on the last batch). This can trigger recompiles or failures and adds substantial overhead during eval. Consider using a fixed batch_seqs tensor shape (pad the final batch) or compile with dynamic=True, and/or compile once outside the loop with a warmup input (see similar pattern in other records scripts).

Copilot uses AI. Check for mistakes.
mrdavtan added a commit to mrdavtan/parameter-golf that referenced this pull request Mar 22, 2026
Major changes:
- DDP gradient sharding: each GPU processes batch_seqs sequences,
  manual all_reduce on gradients (matches PR openai#415/openai#417 approach)
- Two-phase TTT (TTT_TWO_PHASE=1):
  Phase 1: norm-only recalibration (50 epochs Adam, ~22K params)
  Phase 2: selective block adaptation (10 epochs SGD, last 3 blocks)
- TTT_BATCH_SEQS=64 per GPU (512 total with 8 GPUs)
- Falls back to single-phase SGD if TTT_TWO_PHASE=0

Expected speedup: ~235x (from 1344s/epoch to ~5.7s/epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants