# Non-Record Submission: PrismLM v3 — DiffTransformer V2 + NorMuon + TrigramHash

## Score: val_bpb = 1.1715 (post-quant int6+zstd, no sliding window)

Trained on 8×H100 SXM in 600 seconds. 15.59MB artifact (int6+zstd-22). Single seed run.

This is a non-record submission exploring **three novel techniques** not yet attempted in any merged or open PR, built on top of the proven technique stack from PR #315.

## Novel Contributions

### 1. DiffTransformer V2 Attention (Last 2 Layers)

Based on [Differential Transformer](https://arxiv.org/abs/2410.05258) (Microsoft, ICLR 2025 Oral). Computes two separate softmax attention maps and subtracts them, cancelling noise in the attention pattern:

```
attn = softmax(Q1 @ K1^T) - λ · softmax(Q2 @ K2^T)
```

Applied only to the last 2 layers where attention refinement matters most. The scalar `λ` is learned per-head via `lambda_init` reparameterization. Remaining layers use standard GQA + XSA.
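
A minimal single-head NumPy sketch of the subtraction above (causal masking, multi-head reshaping, and the per-head normalization the paper applies to the output are omitted; `lam` stands in for the learned per-head λ):

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def diff_attention(q1, k1, q2, k2, v, lam):
    """Single-head differential attention: compute two softmax maps and
    subtract the second, scaled by lam, so shared noise cancels."""
    d = q1.shape[-1]
    a1 = softmax(q1 @ k1.T / np.sqrt(d))
    a2 = softmax(q2 @ k2.T / np.sqrt(d))
    return (a1 - lam * a2) @ v
```

With `lam = 0` this reduces to standard softmax attention, which is what makes the `lambda_init` reparameterization a safe starting point.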

### 2. NorMuon Optimizer

Replaces standard Muon ([Keller Jordan](https://kellerjordan.github.io/posts/muon/)) with NorMuon, which adds **per-neuron row normalization** after the Newton-Schulz orthogonalization step: each row of the orthogonalized update is normalized by an EMA estimate of its second moment, giving ~11% better compute efficiency. Uses `beta2=0.95` for the second-moment EMA.
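
The two stages can be sketched in NumPy as follows; the quintic Newton-Schulz coefficients are the ones Muon uses, while the exact form of the row normalization and the final norm-preserving rescale are assumptions about NorMuon, not a transcript of `train_gpt.py`:

```python
import numpy as np

def newton_schulz(G, steps=5):
    """Muon's quintic Newton-Schulz iteration: approximately orthogonalizes G."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)
    transposed = G.shape[0] > G.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X

def normuon_step(grad, v, beta2=0.95, eps=1e-8):
    """One NorMuon update: orthogonalize, then divide each row by the sqrt of
    an EMA of that row's second moment, and rescale to preserve overall norm."""
    O = newton_schulz(grad)
    v = beta2 * v + (1 - beta2) * (O ** 2).mean(axis=1)   # per-neuron EMA
    update = O / (np.sqrt(v)[:, None] + eps)
    update *= np.linalg.norm(O) / (np.linalg.norm(update) + eps)
    return update, v
```

The caller would then apply `W -= lr * update`, keeping one `v` vector per weight matrix across steps.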

### 3. TrigramHash + Context-Aware N-gram Gating

Extends BigramHash with a TrigramHash table (2048 buckets, dim 64) that captures three-token patterns via `(t0 * 961 + t1 * 31 + t2) % (vocab_size - 1) + 1`. Both n-gram signals are modulated by a **context-aware gate** (inspired by [DeepSeek Engram](https://github.com/deepseek-ai/Engram)) that learns when to rely on n-gram vs. neural predictions:

```
gate = sigmoid(linear(hidden_state))
output = hidden + gate * (bigram_signal + trigram_signal)
```
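
A NumPy sketch of the gating step, assuming a per-dimension linear gate (`W_gate` and `b_gate` are hypothetical names, and the real code may gate the bigram and trigram signals separately rather than jointly):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_ngram_mix(hidden, bigram_sig, trigram_sig, W_gate, b_gate):
    """Context-aware gate: the hidden state decides, per dimension, how much
    of the combined n-gram signal to add back into the residual stream."""
    gate = sigmoid(hidden @ W_gate + b_gate)
    return hidden + gate * (bigram_sig + trigram_sig)
```

At initialization with zero gate weights, `sigmoid(0) = 0.5`, so half the n-gram signal passes through and the model can learn to open or close the gate from there.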

## Full Architecture

| Component | Value |
|-----------|-------|
| Layers | 11 |
| Model dim | 512 |
| Heads / KV heads | 8 / 4 (GQA) |
| MLP expansion | 3× (hidden=1536), ReLU² |
| XSA layers | Last 6 |
| DiffAttn layers | Last 2 |
| Partial RoPE | 16/64 dims (25%) |
| LN depth scaling | 1/√(layer+1) |
| SmearGate | Yes |
| BigramHash | 2048 buckets, dim 128 |
| TrigramHash | 2048 buckets, dim 64 |
| N-gram gating | Context-aware sigmoid gate |
| U-Net skips | Yes |
| Logit softcap | 30.0 |
| Tied embeddings | Yes (FP16) |
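
Some of these components are one-liners; for instance the logit softcap of 30.0 from the table above smoothly bounds logits via `tanh` (a standard formulation, assumed to match the code):

```python
import numpy as np

def softcap(logits, cap=30.0):
    """Smoothly bound logits to [-cap, cap]; near-identity for small values."""
    return cap * np.tanh(logits / cap)
```

This keeps the loss differentiable everywhere while preventing extreme logits from dominating training, unlike a hard clamp.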

## Training Configuration

| Parameter | Value |
|-----------|-------|
| Optimizer (matrices) | NorMuon (lr=0.04, momentum=0.95, WD=0.02, beta2=0.95) |
| Optimizer (embeddings/scalars) | AdamW (lr=0.04, WD=0.01) |
| Tied embed LR | 0.05 |
| Batch size | 786,432 tokens |
| Sequence length | 2048 |
| Warmdown iters | 1200 |
| Grad clip | 0.3 |
| SWA | Enabled (every 200 steps) |
| Late QAT | Enabled (when lr_scale < 0.1) |
| Warmup | 20 steps |

## Quantization & Compression

- **Int6** per-row quantization on MLP and attention weight matrices
- **FP16** for tied embeddings
- **3% magnitude pruning** before quantization (adaptive up to 15% if over budget)
- **zstd level 22** compression
- Falls back from Flash Attention 3 to `F.scaled_dot_product_attention` (FA3 not available in our environment)
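
The prune-then-quantize step can be sketched as follows; symmetric per-row scaling into `[-31, 31]` and a global magnitude threshold are assumptions about `train_gpt.py` (the 6-bit packing and zstd pass happen at serialization time):

```python
import numpy as np

def prune_and_quant_int6(W, prune_frac=0.03):
    """Zero the smallest-magnitude weights, then symmetric per-row int6 quant."""
    W = W.copy()
    k = int(prune_frac * W.size)
    if k:
        # Global magnitude threshold at the prune_frac quantile.
        thresh = np.partition(np.abs(W).ravel(), k)[k]
        W[np.abs(W) < thresh] = 0.0
    scale = np.abs(W).max(axis=1, keepdims=True) / 31.0  # one scale per row
    scale[scale == 0] = 1.0
    q = np.round(W / scale).astype(np.int8)  # packed to 6 bits + zstd on disk
    return q, scale

def dequant(q, scale):
    return q.astype(np.float32) * scale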

## Key Metrics

- **val_bpb (post-quant): 1.1715** (standard eval, no sliding window)
- Pre-quant val_bpb: 1.1607
- Quantization penalty: ~0.011 bpb
- Steps completed: 4,600 / 20,000 (wallclock-capped at 600s)
- Step average: 130.43 ms/step
- Model params: 27,518,587
- Artifact size: 15,586,651 bytes (15.59MB)
- Model int6+zstd: 15,521,912 bytes
- Code: 64,739 bytes
- Peak memory: 25,921 MiB allocated, 26,460 MiB reserved
- GPU: 8×H100 SXM (Modal)

## Training Progression

| Step | val_loss | val_bpb |
|------|----------|---------|
| 0 | 6.9300 | 4.1043 |
| 1000 | 2.2005 | 1.3033 |
| 2000 | 2.1133 | 1.2516 |
| 3000 | 2.0876 | 1.2364 |
| 4000 | 2.0209 | 1.1969 |
| 4600 (final) | 1.9598 | 1.1607 |
| **Post-quant** | **1.9780** | **1.1715** |

## Gap Analysis vs. SOTA

Our score of 1.1715 is ~0.029 bpb behind the merged SOTA (1.1428) and ~0.047 bpb behind the unmerged frontier (1.1248). Key factors:

1. **No sliding window eval** — disabled to save eval time. Sliding window eval typically gives ~0.03 bpb of improvement; it has been re-enabled in the submitted code for future runs.
2. **Small n-gram tables** — BigramHash(2048) vs. the SOTA's BigramHash(10240). Larger tables are worth ~0.005 bpb.
3. **NorMuon hyperparameters** — momentum=0.95 vs. proven momentum=0.99. The lower momentum may have hurt convergence in the warmdown phase.
4. **DiffAttn parameter overhead** — 1.5× attention parameters on 2 layers reduces capacity available for other components. The noise-cancellation benefit at this scale is unclear.
5. **SDPA fallback** — Flash Attention 3 was unavailable; SDPA is functionally equivalent but ~10% slower, meaning fewer training steps.

## What We'd Change With More Compute

1. Increase BigramHash to 10240 buckets (~0.005 bpb)
2. Re-enable sliding window eval (~0.03 bpb)
3. Tune NorMuon momentum to 0.99
4. Try EMA instead of SWA (works better with XSA per community data)
5. Ablate DiffAttn vs. standard attention to quantify its contribution
6. Increase TrigramHash to 8192 buckets

## Included Files

- `train_gpt.py` — Self-contained training + evaluation script (bug-fixed: correct 16MB decimal limit, sliding window eval re-enabled)
- `train.log` — Training log from the 8×H100 run (seed 1337)
- `submission.json` — Leaderboard metadata
{
"author": "Yash Verma",
"github_id": "yasboop",
"name": "PrismLM v3 — DiffTransformer V2 + NorMuon + TrigramHash",
"blurb": "Non-record 8xH100 submission exploring 3 novel techniques on top of PR #315's technique stack: (1) DiffTransformer V2 attention on last 2 layers for noise-cancelled attention maps, (2) NorMuon optimizer with per-neuron row normalization for 11% better convergence, (3) TrigramHash + context-aware n-gram gating. 11 layers, 512 dim, XSA on 6 layers, Partial RoPE (16 dims), LN depth scaling, SmearGate, BigramHash(2048), int6+zstd-22 quantization, SWA, late QAT. Post-quant val_bpb 1.1715 (without sliding window eval).",
"date": "2026-03-22T00:00:00Z",
"track": "non-record-16mb",
"val_loss": 1.97797457,
"val_bpb": 1.17146796,
"pre_quant_val_loss": 1.9598,
"pre_quant_val_bpb": 1.1607,
"step_stop": 4600,
"wallclock_seconds": 599.974,
"bytes_total": 15586651,
"bytes_model_int6_zstd": 15521912,
"bytes_code": 64739,
"gpu": "8xH100 SXM (Modal)"
}
logs/prism_seed1337_gpu8.txt
flash_attn_3:False
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:100
val_loader:shards pattern=/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
torch.compile(fullgraph=True): registered (compiles on first real forward)
model_params:27518587
diff_attn_last_n:2 xsa_last_n:6
bigram_vocab:2048 trigram_vocab:2048
optimizer:NorMuon matrix_lr:0.04 muon_momentum:0.95
world_size:8 grad_accum_steps:1
tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9319 train_time:185ms step_avg:185.24ms
step:2/20000 train_loss:10.2254 train_time:311ms step_avg:155.55ms
step:3/20000 train_loss:8.5880 train_time:439ms step_avg:146.27ms
step:4/20000 train_loss:7.7303 train_time:573ms step_avg:143.23ms
step:5/20000 train_loss:7.1825 train_time:706ms step_avg:141.23ms
step:6/20000 train_loss:6.8526 train_time:836ms step_avg:139.38ms
step:7/20000 train_loss:6.7896 train_time:964ms step_avg:137.72ms
step:8/20000 train_loss:6.7501 train_time:1096ms step_avg:137.04ms
step:9/20000 train_loss:6.4141 train_time:1227ms step_avg:136.36ms
step:10/20000 train_loss:5.9556 train_time:1354ms step_avg:135.37ms
step:200/20000 train_loss:2.3882 train_time:26134ms step_avg:130.67ms
step:400/20000 train_loss:2.4336 train_time:52379ms step_avg:130.95ms
step:600/20000 train_loss:2.3296 train_time:78406ms step_avg:130.68ms
step:800/20000 train_loss:2.2175 train_time:104554ms step_avg:130.69ms
step:1000/20000 train_loss:2.2512 train_time:130559ms step_avg:130.56ms
step:1000/20000 val_loss:2.2005 val_bpb:1.3033 train_time:130567ms step_avg:130.57ms
step:1200/20000 train_loss:2.3233 train_time:156718ms step_avg:130.60ms
step:1400/20000 train_loss:2.1494 train_time:182849ms step_avg:130.61ms
step:1600/20000 train_loss:2.0411 train_time:208894ms step_avg:130.56ms
step:1800/20000 train_loss:2.1322 train_time:235041ms step_avg:130.58ms
step:2000/20000 train_loss:2.0504 train_time:261028ms step_avg:130.51ms
step:2000/20000 val_loss:2.1133 val_bpb:1.2516 train_time:261033ms step_avg:130.52ms
step:2200/20000 train_loss:2.1147 train_time:287151ms step_avg:130.52ms
step:2400/20000 train_loss:2.0524 train_time:313147ms step_avg:130.48ms
step:2600/20000 train_loss:2.1013 train_time:339231ms step_avg:130.47ms
step:2800/20000 train_loss:2.1496 train_time:365361ms step_avg:130.49ms
step:3000/20000 train_loss:2.1562 train_time:391369ms step_avg:130.46ms
step:3000/20000 val_loss:2.0876 val_bpb:1.2364 train_time:391375ms step_avg:130.46ms
step:3200/20000 train_loss:2.1661 train_time:417498ms step_avg:130.47ms
step:3400/20000 train_loss:2.0158 train_time:443423ms step_avg:130.42ms
step:3600/20000 train_loss:2.0787 train_time:469562ms step_avg:130.43ms
step:3800/20000 train_loss:2.0383 train_time:495533ms step_avg:130.40ms
step:4000/20000 train_loss:1.9308 train_time:521677ms step_avg:130.42ms
step:4000/20000 val_loss:2.0209 val_bpb:1.1969 train_time:521687ms step_avg:130.42ms
swa:start step:4200
step:4200/20000 train_loss:2.0923 train_time:547809ms step_avg:130.43ms
step:4400/20000 train_loss:1.9583 train_time:573818ms step_avg:130.41ms
late_qat:enabled step:4481 scale:0.0994
step:4600/20000 train_loss:1.7641 train_time:599968ms step_avg:130.43ms
step:4600/20000 val_loss:1.9598 val_bpb:1.1607 train_time:599974ms step_avg:130.43ms
stopping_early: wallclock_cap train_time:599974ms step:4600/20000
peak memory allocated: 25921 MiB reserved: 26460 MiB
swa:applying averaged 3 checkpoints
Serialized model: 108278477 bytes
Code size: 64739 bytes
Serialized model int6+zstd: 15521912 bytes
Total submission size int6+zstd: 15586651 bytes
Under 16MB: True
final_int6_roundtrip val_loss:1.9780 val_bpb:1.1715 eval_time:22821ms
final_int6_roundtrip_exact val_loss:1.97797457 val_bpb:1.17146796