Paper-faithful IHA with sequence interleaving: 3.193 val loss (2hr track) #78
Open
ms337 wants to merge 5 commits into
Conversation
Implements cross-head Q/K/V mixing from the IHA paper (arXiv 2602.21371) on top of the MTP baseline. Each head's query/key/value projection becomes a learned linear combination of all heads' projections via H×H mixing matrices, enabling richer cross-head attention patterns.

Key optimization: the mixing matrices are fused into the Q/K/V projection weights at forward time (`W_fused[h] = sum_m mix[h,m] * W_orig[m]`). The `[H,H] @ [H,d*C]` fusion matmul is negligible vs. the main projection, keeping per-step overhead to just 47 ms (3.3%) over baseline.

Results (sub-1hr track, 11 epochs):

- MTP baseline: 3.222 val loss, 57.5m training
- IHA+MTP fused: 3.214 val loss, 59.7m training (-0.008, under 1hr)

CLI flags: `--iha`, `--iha-v`, `--iha-lr` (default: SCALAR_LR). Best config: `--iha --iha-v --iha-lr=0.02`

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
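The fusion identity described above can be checked in isolation. A minimal PyTorch sketch with toy shapes (the `H`, `D`, `d` values here are hypothetical, not the PR's configuration):

```python
import torch

# Toy shapes (hypothetical, not the PR's config): H heads, model dim D, head dim d.
H, D, d = 4, 64, 16

# Per-head Q projection weights [H, d, D] and a learned H x H mixing matrix.
W_q = torch.randn(H, d, D)
mix = torch.softmax(torch.randn(H, H), dim=-1)

# Fuse the mixing into the weights: W_fused[h] = sum_m mix[h, m] * W_q[m].
W_fused = torch.einsum('hm,mdD->hdD', mix, W_q)

# Equivalence check: fusing first gives the same result as projecting
# per head and then mixing the projected heads.
x = torch.randn(4, D)                                  # a few token embeddings
q_heads = torch.einsum('mdD,nD->nmd', W_q, x)          # per-head projections
q_mixed = torch.einsum('hm,nmd->nhd', mix, q_heads)    # mix after projecting
q_fused = torch.einsum('hdD,nD->nhd', W_fused, x)      # fused projection
assert torch.allclose(q_mixed, q_fused, atol=1e-4)
```

Because the fusion is linear, the cheap `[H,H]` matmul on the weights replaces an extra per-token mixing step, which is why the overhead stays small.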
IHA Q+K+V mixing with `iha-lr=0.02` is now the default behavior. No flags needed to reproduce the record — just run: `torchrun --standalone --nproc_per_node=8 train.py`. Use `--no-iha` to disable IHA if needed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replaces the previous weight-fusion cross-head mixing with full
Algorithm 1 from the IHA paper (arxiv 2602.21371):
1. Mix: mixing tensors α^{Q,K,V} ∈ [H, H, P] generate P pseudo-heads per head via einsum
2. Interleave: [B,T,H,P,d] → [B,T*P,H,d] — pseudo-tokens adjacent
3. Attend: Flash attention on expanded T*P sequence
4. De-interleave + collapse: learned R ∈ [H, P] back to [B,T,H,d]
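The four steps above can be sketched end-to-end in plain PyTorch. All shapes are toy values, SDPA stands in for flash attention, and the mixing tensor `alpha` and collapse weights `R` are random stand-ins for the learned parameters:

```python
import torch
import torch.nn.functional as F

# Toy shapes (hypothetical): B batch, T tokens, H heads, P pseudo-heads, d head dim.
B, T, H, P, d = 2, 16, 4, 2, 8

q = torch.randn(B, T, H, d)   # stand-ins for the per-head Q/K/V projections
k = torch.randn(B, T, H, d)
v = torch.randn(B, T, H, d)

alpha = torch.randn(H, H, P)  # stand-in for a learned mixing tensor in [H, H, P]
R = torch.randn(H, P)         # stand-in for the learned collapse weights in [H, P]

def mix_and_interleave(x):
    # 1. Mix: each (head h, pseudo-head p) combines all source heads m.
    y = torch.einsum('btmd,mhp->bthpd', x, alpha)       # [B,T,H,P,d]
    # 2. Interleave: the P pseudo-tokens of one position sit adjacent in sequence.
    return y.permute(0, 1, 3, 2, 4).reshape(B, T * P, H, d)

qi, ki, vi = mix_and_interleave(q), mix_and_interleave(k), mix_and_interleave(v)

# 3. Attend on the expanded T*P sequence (SDPA here in place of flash attention).
o = F.scaled_dot_product_attention(
    qi.transpose(1, 2), ki.transpose(1, 2), vi.transpose(1, 2), is_causal=True
).transpose(1, 2)                                       # [B,T*P,H,d]

# 4. De-interleave and collapse the pseudo-heads with R.
out = torch.einsum('btphd,hp->bthd', o.reshape(B, T, P, H, d), R)  # [B,T,H,d]
```

The 2x sequence expansion in step 3 is where the ~1.76x per-step cost mentioned below comes from.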
Paper-faithful FLOP-matched window schedule:
- Short (S) layers: window = N/(2P) = 512 for P=2
- Long (L) layers: full expanded context = N*P = 4096 for P=2
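Both window sizes follow from a base sequence length of N = 2048 — an inference (not a number stated in the commit message), since it is the value for which N/(2P) = 512 and N·P = 4096 both hold at P = 2:

```python
# FLOP-matched window schedule. N = 2048 is an assumption inferred from the
# quoted windows: N/(2P) = 512 and N*P = 4096 are both satisfied at P = 2.
N, P = 2048, 2

short_window = N // (2 * P)  # short (S) layers: 512-token window
long_window = N * P          # long (L) layers: full expanded context, 4096
```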
Results on sub-1hr track config (12 epochs):
Previous IHA (weight-fusion): 3.214 val loss in 59.7m
Paper-faithful IHA (P=2): 3.193 val loss in 112.0m
Per-step cost is ~1.76x baseline due to 2x attention sequence length
(fundamental memory bandwidth cost, not FLOPs). Fits 2hr track, exceeds
1hr budget. For the 1hr track, revert to the previous commit which uses
weight-fusion mixing without sequence expansion.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move the paper-faithful sequence-interleaving IHA to two_hour/train.py, where it fits the time budget, and restore root train.py to the weight-fusion cross-head mixing variant, which fits the sub-1hr track.

- Sub-1hr track (root train.py, weight-fusion IHA): val loss 3.214 in 59.7m training (11 epochs)
- 2hr track (two_hour/train.py, paper-faithful IHA P=2): val loss 3.200 in 101.9m training (11 epochs); val loss 3.193 in 112.0m training (12 epochs)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Keeps the paper-faithful Algorithm 1 implementation (sequence-dim interleaving) in root train.py alongside two_hour/train.py. Root currently exceeds the 1hr budget — to be optimized separately. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
Algorithm
Implements the paper's 5-step algorithm:
1. Mix: `α^{Q,K,V} ∈ [H, H, P]` creates P pseudo-heads via `einsum('btmd,mhp->bthpd')`
2. Interleave: `[B,T,H,P,d]` → permute → `[B,T·P,H,d]` — pseudo-tokens adjacent per position
3. Attend: flash attention on the expanded T·P sequence
4. De-interleave: `[B,T·P,H,d]` → `[B,T,P,H,d]`
5. Collapse: learned `R ∈ [H, P]` reduces pseudo-heads → `[B,T,H,d]`

Results (2hr track, two_hour/train.py)

Reproduction verified on commit `0e63026`.

Usage
IHA is on by default. Control via:

- `--iha-P=2` (default, pseudo-heads per head)
- `--iha-lr=0.02` (default, LR for mixing matrices)
- `--no-iha` to disable

Implementation notes
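For illustration only, one plausible argparse wiring of the flags above — a sketch, not the actual train.py code, whose flag handling may differ:

```python
import argparse

# Hypothetical reconstruction of the CLI flags listed above.
parser = argparse.ArgumentParser()
parser.add_argument('--iha-P', type=int, default=2,
                    help='pseudo-heads per head')
parser.add_argument('--iha-lr', type=float, default=0.02,
                    help='learning rate for the mixing matrices')
parser.add_argument('--no-iha', action='store_true',
                    help='disable IHA (enabled by default)')

args = parser.parse_args([])   # no flags: IHA stays enabled with defaults
use_iha = not args.no_iha
```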
- `torch.compile(model, dynamic=False)` — no graph breaks

Test plan
🤖 Generated with Claude Code