Skip to content

Paper-faithful IHA with sequence interleaving: 3.193 val loss (2hr track)#78

Open
ms337 wants to merge 5 commits into
qlabs-eng:mainfrom
ms337:reproduce-3.193
Open

Paper-faithful IHA with sequence interleaving: 3.193 val loss (2hr track)#78
ms337 wants to merge 5 commits into
qlabs-eng:mainfrom
ms337:reproduce-3.193

Conversation

@ms337
Copy link
Copy Markdown
Contributor

@ms337 ms337 commented Apr 21, 2026

Summary

  • Implements Algorithm 1 from the IHA paper (arxiv 2602.21371) with proper sequence-dimension interleaving
  • P=2 pseudo-heads per head → expanded attention sequence (T=2048 → T·P=4096) → P²=4 attention patterns per head
  • Val loss: 3.193 in 112.5m training (best quality we've achieved)
  • Paper-faithful FLOP-matched window schedule: short=N/(2P)=512, long=N·P=4096
  • Fits the 2hr track budget

Algorithm

Implements the paper's 5-step algorithm:

  1. Mix: α^{Q,K,V} ∈ [H, H, P] creates P pseudo-heads via einsum('btmd,mhp->bthpd')
  2. Interleave: [B,T,H,P,d] → permute → [B,T·P,H,d] — pseudo-tokens adjacent per position
  3. Attend: Flash attention on expanded T·P sequence with paper's FLOP-matched windows
  4. De-interleave: [B,T·P,H,d] → [B,T,P,H,d]
  5. Collapse: Learned R ∈ [H, P] reduces pseudo-heads → [B,T,H,d]

Results (2hr track, two_hour/train.py)

Config Epochs Val Loss Training Time
MTP baseline 11 3.222 57.5m
Interleaved IHA P=2 11 3.200 101.9m
Interleaved IHA P=2 12 3.193 112.5m

Reproduction verified on commit 0e63026:

  • Original: 3.192552 / 112.01m
  • Repro: 3.193094 / 112.54m (within 0.001 val loss noise)

Usage

torchrun --standalone --nproc_per_node=8 two_hour/train.py --num-epochs=12

IHA is on by default. Control via:

  • --iha-P=2 (default, pseudo-heads per head)
  • --iha-lr=0.02 (default, LR for mixing matrices)
  • --no-iha to disable

Implementation notes

  • Works correctly with torch.compile(model, dynamic=False) — no graph breaks
  • Peak memory: 65GB (vs ~50GB baseline) due to 2x attention activations
  • Compatible with existing features: MTP, stochastic depth, dupe layers, SWA, logit averaging
  • The 1.76x per-step cost is fundamental (2x flash attention sequence length)

Test plan

  • Algorithm matches paper (mixing + interleaving + attention + collapse)
  • Paper-faithful FLOP-matched windows (short=N/(2P), long=N·P)
  • 11-epoch run: 3.200 val loss
  • 12-epoch run: 3.193 val loss (reproduced twice)
  • Works with torch.compile
  • Compatible with MTP, dupe layers, SWA, logit averaging

🤖 Generated with Claude Code

ubuntu and others added 5 commits April 14, 2026 01:44
Implements cross-head Q/K/V mixing from the IHA paper (arxiv 2602.21371)
on top of the MTP baseline. Each head's query/key/value projection becomes
a learned linear combination of all heads' projections via H×H mixing
matrices, enabling richer cross-head attention patterns.

Key optimization: mixing matrices are fused into the Q/K/V projection
weights at forward time (W_fused[h] = sum_m mix[h,m] * W_orig[m]).
The [H,H]@[H,d*C] fusion matmul is negligible vs the main projection,
keeping per-step overhead to just 47ms (3.3%) over baseline.

Results (sub-1hr track, 11 epochs):
  MTP baseline:  3.222 val loss, 57.5m training
  IHA+MTP fused: 3.214 val loss, 59.7m training (-0.008, under 1hr)

CLI flags: --iha, --iha-v, --iha-lr (default: SCALAR_LR)
Best config: --iha --iha-v --iha-lr=0.02

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
IHA Q+K+V mixing with iha-lr=0.02 is now the default behavior.
No flags needed to reproduce the record — just run:
  torchrun --standalone --nproc_per_node=8 train.py

Use --no-iha to disable IHA if needed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replaces the previous weight-fusion cross-head mixing with full
Algorithm 1 from the IHA paper (arxiv 2602.21371):

1. Mix: α^{Q,K,V} ∈ [H, H, P] generate P pseudo-heads via einsum
2. Interleave: [B,T,H,P,d] → [B,T*P,H,d] — pseudo-tokens adjacent
3. Attend: Flash attention on expanded T*P sequence
4. De-interleave + collapse: learned R ∈ [H, P] back to [B,T,H,d]

Paper-faithful FLOP-matched window schedule:
- Short (S) layers: window = N/(2P) = 512 for P=2
- Long (L) layers: full expanded context = N*P = 4096 for P=2

Results on sub-1hr track config (12 epochs):
  Previous IHA (weight-fusion): 3.214 val loss in 59.7m
  Paper-faithful IHA (P=2):     3.193 val loss in 112.0m

Per-step cost is ~1.76x baseline due to 2x attention sequence length
(fundamental memory bandwidth cost, not FLOPs). Fits 2hr track, exceeds
1hr budget. For the 1hr track, revert to the previous commit which uses
weight-fusion mixing without sequence expansion.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move paper-faithful sequence-interleaving IHA to two_hour/train.py
where it fits the time budget, and restore root train.py to the
weight-fusion cross-head mixing variant which fits the sub-1hr track.

Sub-1hr track (root train.py, weight-fusion IHA):
  val loss 3.214 in 59.7m training (11 epochs)

2hr track (two_hour/train.py, paper-faithful IHA P=2):
  val loss 3.200 in 101.9m training (11 epochs)
  val loss 3.193 in 112.0m training (12 epochs)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Keeps the paper-faithful Algorithm 1 implementation (sequence-dim
interleaving) in root train.py alongside two_hour/train.py.
Root currently exceeds the 1hr budget — to be optimized separately.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant