# Non-record submission: 11L mixed int5/int6 + working QAT + TTT + 8 additions

**Historical run:** `1.1466 val_bpb` (sliding window, stride=32, original post-TTT flow) | **14.7 MB** artifact | 8xH100 SXM, 605s train + 340s eval

Built on PR #315 (1.1248). Ran with PyTorch SDPA instead of FA3, so throughput was 110ms/step instead of 85ms, which capped the run at 5,129 steps instead of ~7,000. The score should drop further once FA3 is restored.

Note: the historical `1.1466` number above came from the original pre-eval TTT flow in this run. The current script has since been updated to report plain no-TTT metrics and causal TTT metrics separately, so future runs do not adapt on unseen eval tokens before scoring them. Because of that change, the checked-in script should be rerun before it is used for a fresh official score claim.

## What we added to PR #315

**1. Working QAT.** PR #315's late QAT is dead code because `torch.compile` constant-folds `CastedLinear._qat_enabled` at first trace. We swap the `forward` method to `forward_qat` per instance and recompile. QAT noise matches the export scheme: int5 STE for MLP, int6 STE for attention. The current script also exposes `QAT_ENABLED`, `QAT_START_STEP`, and `QAT_START_FRAC` so we can turn QAT on earlier instead of hoping it only catches the last few steps.
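The per-instance `forward` swap can be sketched as follows. `CastedLinear`, `forward_qat`, and the int5/int6 bit widths come from the text above; the symmetric STE rounding scheme, the class layout, and all shapes are illustrative assumptions:

```python
import torch
import torch.nn as nn

def ste_fake_quant(w: torch.Tensor, n_bits: int) -> torch.Tensor:
    # Straight-through estimator: quantized values in the forward pass,
    # identity gradient in the backward pass via w + (q - w).detach().
    qmax = 2 ** (n_bits - 1) - 1                    # 15 for int5, 31 for int6
    scale = w.abs().max().clamp(min=1e-8) / qmax
    q = (w / scale).round().clamp(-qmax - 1, qmax) * scale
    return w + (q - w).detach()

class CastedLinear(nn.Linear):
    qat_bits = 6  # would be 5 for MLP weights, 6 for attention (per the text)

    def forward_qat(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, ste_fake_quant(self.weight, self.qat_bits), self.bias)

def enable_qat(model: nn.Module) -> None:
    # Swap the bound forward per instance so the compiled graph is retraced
    # with QAT active, instead of constant-folding a boolean flag away.
    for m in model.modules():
        if isinstance(m, CastedLinear):
            m.forward = m.forward_qat

layer = CastedLinear(4, 4, bias=False)
enable_qat(layer)
x = torch.randn(2, 4)
y = layer(x)  # forward now runs through the STE fake-quantized weight
```

The key design point is that the flag check moves out of the traced graph entirely: `torch.compile` never sees a branch to fold, only the new bound method.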

**2. Mixed int5/int6 quantization + magnitude pruning.** MLP weights get int5 ([-16, 15]), attention gets int6 ([-32, 31]), embeddings stay int8. 3% magnitude pruning before quantization. Result: 14.7MB with 1.3MB headroom.
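A minimal sketch of the prune-then-quantize step, assuming per-tensor symmetric scales; the actual script's scaling/grouping scheme may differ, but the ranges and the 3% pruning fraction are from the text:

```python
import torch

def prune_and_quantize(w: torch.Tensor, n_bits: int, prune_frac: float = 0.03):
    """Zero the smallest-magnitude weights, then symmetric-quantize to n_bits.
    Returns integer codes plus the scale needed to dequantize."""
    flat = w.flatten().abs()
    k = int(prune_frac * flat.numel())
    if k > 0:
        thresh = flat.kthvalue(k).values        # magnitude cutoff for bottom 3%
        w = torch.where(w.abs() <= thresh, torch.zeros_like(w), w)
    qmax = 2 ** (n_bits - 1) - 1                # 15 for int5, 31 for int6
    scale = w.abs().max().clamp(min=1e-8) / qmax
    codes = (w / scale).round().clamp(-qmax - 1, qmax).to(torch.int8)
    return codes, scale

mlp_w = torch.randn(256, 256)
codes5, s5 = prune_and_quantize(mlp_w, n_bits=5)    # MLP -> int5, [-16, 15]
attn_w = torch.randn(256, 256)
codes6, s6 = prune_and_quantize(attn_w, n_bits=6)   # attention -> int6, [-32, 31]
```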

**3. Test-time training.** This run originally used post-quantization SGD on validation tokens before final scoring. The script now also includes a causal TTT path that scores each eval chunk first and only then adapts on that chunk, which is the safer version for future experiments.
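The causal score-then-adapt loop might look like this. Only the ordering (score each chunk before adapting on it) is from the text; the toy linear "model", SGD learning rate, and chunk shapes are placeholders:

```python
import torch

def causal_ttt_eval(model, chunks, lr=1e-4, steps_per_chunk=1):
    """Score each chunk BEFORE adapting on it, so no chunk's score ever
    reflects gradients taken on its own (or any future) tokens."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    total_loss, total_targets = 0.0, 0
    for x, y in chunks:
        model.eval()
        with torch.no_grad():                   # score first, weights untouched
            loss = torch.nn.functional.cross_entropy(model(x), y)
        total_loss += loss.item() * y.numel()
        total_targets += y.numel()
        model.train()
        for _ in range(steps_per_chunk):        # then adapt on the chunk just scored
            opt.zero_grad()
            torch.nn.functional.cross_entropy(model(x), y).backward()
            opt.step()
    return total_loss / max(total_targets, 1)

# toy demo: a linear "LM" over 8 classes, 4 eval chunks (hypothetical shapes)
model = torch.nn.Linear(16, 8)
chunks = [(torch.randn(32, 16), torch.randint(0, 8, (32,))) for _ in range(4)]
avg_loss = causal_ttt_eval(model, chunks)
```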

**4. BigramHash 10240.** Up from 2048 in PR #315.
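BigramHash presumably buckets each (prev, cur) token pair into a fixed embedding table; only the bucket count 10240 is from the text, so the multiplicative hash, embedding width, and padding-of-the-first-position choice below are all assumptions:

```python
import torch
import torch.nn as nn

class BigramHash(nn.Module):
    """Hash each (prev, cur) token pair into a fixed bucket table of
    embeddings; bucket count bumped from 2048 to 10240 per the text."""
    def __init__(self, n_buckets: int = 10240, dim: int = 64):
        super().__init__()
        self.n_buckets = n_buckets
        self.table = nn.Embedding(n_buckets, dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:   # ids: (batch, seq)
        prev = torch.roll(ids, 1, dims=1)
        prev[:, 0] = 0                          # first token has no predecessor
        h = (prev * 1000003 + ids) % self.n_buckets  # cheap multiplicative hash
        return self.table(h)

bh = BigramHash()
ids = torch.randint(0, 1024, (2, 16))
emb = bh(ids)                                   # (2, 16, 64) bigram features
```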

**5. Memory tokens.** 64 learnable embeddings as global context. Overwritten during training (targets masked), prepended during eval (stripped after layers). 32K params.
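A sketch of the prepend/strip mechanics, assuming a hidden dim of 512 so that 64 embeddings come to the quoted 32K parameters (the real model's width may differ):

```python
import torch
import torch.nn as nn

class MemoryTokens(nn.Module):
    """64 learnable embeddings prepended as global context and stripped from
    the output, so real sequence positions and targets are unaffected."""
    def __init__(self, n_mem: int = 64, dim: int = 512):
        super().__init__()
        self.mem = nn.Parameter(torch.randn(n_mem, dim) * 0.02)

    def prepend(self, x: torch.Tensor) -> torch.Tensor:     # x: (batch, seq, dim)
        mem = self.mem.unsqueeze(0).expand(x.size(0), -1, -1)
        return torch.cat([mem, x], dim=1)

    def strip(self, h: torch.Tensor) -> torch.Tensor:
        return h[:, self.mem.size(0):, :]

mt = MemoryTokens(n_mem=64, dim=512)
x = torch.randn(2, 128, 512)
h = mt.prepend(x)    # (2, 192, 512): what the transformer layers would see
out = mt.strip(h)    # back to (2, 128, 512) before the output head
```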

**6. Backout connection.** Learned scalar (init=0.2) subtracts encoder/decoder boundary state from final output. One parameter.
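As a one-parameter module; the scalar init (0.2) and the subtraction are from the text, the shapes and where the boundary state is captured are illustrative:

```python
import torch
import torch.nn as nn

class Backout(nn.Module):
    """One learned scalar (init 0.2) that subtracts a saved encoder/decoder
    boundary hidden state from the final output."""
    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.2))

    def forward(self, final: torch.Tensor, boundary: torch.Tensor) -> torch.Tensor:
        return final - self.alpha * boundary

bo = Backout()
final = torch.randn(2, 16, 32)       # last-layer output
boundary = torch.randn(2, 16, 32)    # state saved at the U-Net midpoint
out = bo(final, boundary)
```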

**7. Per-head temperature.** Learned temperature per attention head. 88 params total.
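A sketch with 8 heads per layer (8 heads x 11 layers gives the quoted 88 parameters); the log-space parameterization and the plain softmax attention here are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PerHeadTempAttention(nn.Module):
    """One learned temperature per attention head, applied to the scores
    before softmax. Init at log(1) = 0 so training starts at temperature 1."""
    def __init__(self, num_heads: int = 8):
        super().__init__()
        self.log_t = nn.Parameter(torch.zeros(num_heads))

    def forward(self, q, k, v):                 # (batch, heads, seq, head_dim)
        scale = q.size(-1) ** -0.5
        temp = self.log_t.exp().view(1, -1, 1, 1)   # broadcast over batch/seq
        scores = (q @ k.transpose(-2, -1)) * scale / temp
        return F.softmax(scores, dim=-1) @ v

attn = PerHeadTempAttention(num_heads=8)
q = k = v = torch.randn(2, 8, 16, 8)
out = attn(q, k, v)
```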

**8. Eval stride 32.** Down from 64. Made no difference here (s32 and s64 both gave 1.1466).
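Sliding-window eval scores every token with up to a full window of context by re-running the model every `stride` tokens and counting only the targets no earlier window has scored. A toy-sized sketch of the bookkeeping (the model, window, and stride here are illustrative, not the run's 2048/32):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_loss(model, tokens, window=8, stride=2):
    """Mean next-token NLL where each window only contributes the targets
    not already scored by an earlier window."""
    n = tokens.numel()
    nll, scored_upto = 0.0, 1           # target positions start at index 1
    start = 0
    while scored_upto < n:
        end = min(start + window, n)
        x = tokens[start:end - 1].unsqueeze(0)
        y = tokens[start + 1:end]
        logits = model(x).squeeze(0)
        new = end - scored_upto          # targets new to this window
        nll += F.cross_entropy(logits[-new:], y[-new:], reduction="sum").item()
        scored_upto = end
        start += stride
    return nll / (n - 1)

# toy "LM": embedding -> linear head over a 16-token vocab
model = torch.nn.Sequential(torch.nn.Embedding(16, 8), torch.nn.Linear(8, 16))
tokens = torch.randint(0, 16, (40,))
loss = sliding_window_loss(model, tokens, window=8, stride=2)
```

Halving the stride doubles eval compute for (at best) a marginally better context per token, which is consistent with s32 and s64 tying at 1.1466 here.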

## What we kept from PR #315

11 layers, U-Net skips, XSA on last 4, EMA (0.997), partial RoPE (16/64 dims), LN scale, 3x MLP relu-squared, SmearGate, ortho+muP init, Muon (0.025, 0.99, WD=0.04), NTK RoPE, seq 2048, softcap 30.

## Results

| Metric | Value |
|--------|-------|
| Steps | 5,129 (110ms/step, SDPA) |
| Pre-quant val_bpb | 1.1597 |
| Post-quant val_bpb | 1.1697 |
| Quant gap | +0.0100 |
| Historical post-TTT sliding s32 | **1.1466** |
| Historical no-TTT roundtrip | 1.1697 |
| Artifact | 14,706,424 bytes |
| TTT time | 83s |
| Peak memory | 25,777 MiB/GPU |

## What would help

- FA3 (30% more training steps)
- 12th layer with the 1.3MB budget headroom
- Earlier QAT so it gets hundreds to thousands of steps instead of 1

## Papers behind these ideas

- Low-bit quantization direction: [QQQ: Quality Quattuor-Bit Quantization for Large Language Models](https://arxiv.org/abs/2406.09904)
- Very low-bit training motivation: [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)
- Test-time training: [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620)
- Faster Hopper attention kernels: [FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision](https://tridao.me/blog/2024/flash3/)

## How to run

```bash
cd /workspace
git clone -b submission/sota-attempt <your-fork-url> parameter-golf
cd parameter-golf
pip install huggingface-hub datasets sentencepiece tqdm zstandard
python3 data/cached_challenge_fineweb.py --variant sp1024
ATTN_BACKEND=auto QAT_START_FRAC=0.8 \
torchrun --standalone --nproc_per_node=8 records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py
```

If FA3 is installed, set `ATTN_BACKEND=fa3` to fail fast when the kernel is missing instead of silently falling back.

Single seed (1337), torch 2.4.1+cu124, 8xH100 SXM on RunPod.
{
"author": "Vytautas Bunevicius",
"github_id": "vytautas-bunevicius",
"name": "11L mixed int5/int6 + working QAT + TTT",
"blurb": "Non-record submission stacking 8 techniques on PR #315: working QAT (fixed dead code), mixed int5/int6 quantization, test-time training, BigramHash(10240), memory tokens, backout connection, per-head temperature, eval stride 32. Ran with PyTorch SDPA (no FA3), 5129 steps at 110ms/step.",
"date": "2026-03-22T12:20:51Z",
"val_loss": 1.93594245,
"val_bpb": 1.14657797,
"bytes_total": 14706424,
"bytes_code": 75257
}
logs/d6796b3c-a185-44f5-a817-d3e0c3e09b52.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:27911346
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9280 val_bpb:4.1031 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9304 train_time:154ms step_avg:153.85ms
step:2/20000 train_loss:8.4705 train_time:251ms step_avg:125.31ms
step:3/20000 train_loss:7.5179 train_time:360ms step_avg:119.94ms
step:4/20000 train_loss:8.1017 train_time:469ms step_avg:117.30ms
step:5/20000 train_loss:8.2824 train_time:579ms step_avg:115.74ms
step:6/20000 train_loss:7.9754 train_time:688ms step_avg:114.68ms
step:7/20000 train_loss:7.5070 train_time:797ms step_avg:113.87ms
step:8/20000 train_loss:7.1406 train_time:906ms step_avg:113.30ms
step:9/20000 train_loss:6.6636 train_time:1016ms step_avg:112.84ms
step:10/20000 train_loss:6.2675 train_time:1125ms step_avg:112.52ms
step:200/20000 train_loss:2.3890 train_time:22059ms step_avg:110.30ms
step:400/20000 train_loss:2.4267 train_time:44213ms step_avg:110.53ms
step:600/20000 train_loss:2.3474 train_time:66368ms step_avg:110.61ms
step:800/20000 train_loss:2.2505 train_time:88587ms step_avg:110.73ms
step:1000/20000 train_loss:2.2858 train_time:110732ms step_avg:110.73ms
step:1000/20000 val_loss:2.2363 val_bpb:1.3245 train_time:110749ms step_avg:110.75ms
step:1200/20000 train_loss:2.3632 train_time:132921ms step_avg:110.77ms
step:1400/20000 train_loss:2.1869 train_time:155095ms step_avg:110.78ms
step:1600/20000 train_loss:2.0789 train_time:177188ms step_avg:110.74ms
step:1800/20000 train_loss:2.1527 train_time:199334ms step_avg:110.74ms
step:2000/20000 train_loss:2.0611 train_time:221406ms step_avg:110.70ms
step:2000/20000 val_loss:2.1305 val_bpb:1.2618 train_time:221421ms step_avg:110.71ms
step:2200/20000 train_loss:2.1296 train_time:243546ms step_avg:110.70ms
step:2400/20000 train_loss:2.0630 train_time:265612ms step_avg:110.67ms
step:2600/20000 train_loss:2.1011 train_time:287726ms step_avg:110.66ms
step:2800/20000 train_loss:2.1464 train_time:309841ms step_avg:110.66ms
step:3000/20000 train_loss:2.1432 train_time:331881ms step_avg:110.63ms
step:3000/20000 val_loss:2.0737 val_bpb:1.2282 train_time:331898ms step_avg:110.63ms
step:3200/20000 train_loss:2.1530 train_time:353995ms step_avg:110.62ms
step:3400/20000 train_loss:1.9918 train_time:376141ms step_avg:110.63ms
step:3600/20000 train_loss:2.0610 train_time:398251ms step_avg:110.63ms
step:3800/20000 train_loss:2.0316 train_time:420294ms step_avg:110.60ms
step:4000/20000 train_loss:1.9355 train_time:442405ms step_avg:110.60ms
step:4000/20000 val_loss:2.0245 val_bpb:1.1990 train_time:442420ms step_avg:110.61ms
step:4200/20000 train_loss:2.1051 train_time:464511ms step_avg:110.60ms
step:4400/20000 train_loss:1.9798 train_time:486552ms step_avg:110.58ms
step:4600/20000 train_loss:1.7920 train_time:508661ms step_avg:110.58ms
step:4800/20000 train_loss:2.3816 train_time:530696ms step_avg:110.56ms
step:5000/20000 train_loss:2.0406 train_time:552793ms step_avg:110.56ms
step:5000/20000 val_loss:1.9650 val_bpb:1.1638 train_time:552808ms step_avg:110.56ms
late_qat:enabled step:5128 scale:0.0997
step:5129/20000 val_loss:1.9581 val_bpb:1.1597 train_time:605055ms step_avg:117.97ms
stopping_early: wallclock_cap train_time:605055ms step:5129/20000
peak memory allocated: 25777 MiB reserved: 26916 MiB
ema:applying EMA weights
Serialized model: 108015620 bytes
Code size: 75257 bytes
Serialized model int6+zstd: 14631167 bytes
Total submission size int6+zstd: 14706424 bytes
ttt:epoch:1/3 loss:1.9625
ttt:epoch:2/3 loss:1.9614
ttt:epoch:3/3 loss:1.9609
ttt:done time:82697ms
final_int6_roundtrip val_loss:1.9749 val_bpb:1.1697 eval_time:38727ms
final_int6_roundtrip_exact val_loss:1.97493070 val_bpb:1.16966520
final_int6_sliding_window val_loss:1.9359 val_bpb:1.1466 stride:32 eval_time:217603ms
final_int6_sliding_window_exact val_loss:1.93594245 val_bpb:1.14657797
final_int6_sliding_window_s64 val_loss:1.9360 val_bpb:1.1466 stride:64 eval_time:109694ms
final_int6_sliding_window_s64_exact val_loss:1.93596518 val_bpb:1.14659066