Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (val_bpb=1.1216) #410

Changes from all commits: 6ed2fa5, bccc688, 8a0ebf2, db5c5dd, 65f54ac, 32790dd, 4bd048c, a3b1212, 38dff06, 67fa031, 943597d, 308ed62, 78f998e, 4c37972, e83a277, 573f735, 358a426

@@ -0,0 +1,31 @@
# QAT + BigramHash(12288) + Stride 32

## Summary

Built on the current SOTA (`10L_Int5MLP_MuonWD04_SWA50`) with the following improvements:

- **QAT (Quantization-Aware Training):** STE fake-quantize during training (int5 for MLP layers, int6 for attention). Reduces post-quantization degradation; see the STE sketch after this list.
- **BigramHash 12288:** Increased from 10240 to 12288 buckets for better bigram coverage (a sketch follows the Architecture list below).
- **Eval stride 32:** Reduced from 64 to 32 for more overlapping context windows during evaluation (a sketch follows the Results section below).
- **Magnitude pruning 5%:** Increased from 3% to improve the compression ratio; see the pruning sketch after this list.
- **SWA every 50 steps:** Checkpoint averaging during warmdown; see the SWA sketch after this list.
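
As a rough illustration of the QAT step, here is a minimal PyTorch sketch of straight-through fake quantization. The symmetric per-tensor scaling and the function name `fake_quantize` are assumptions for illustration, not the PR's actual code:

```python
import torch

def fake_quantize(w: torch.Tensor, bits: int) -> torch.Tensor:
    # Symmetric per-tensor scale (an assumption; the PR's scheme may differ).
    qmax = 2 ** (bits - 1) - 1                    # 15 for int5, 31 for int6
    scale = w.abs().max().clamp(min=1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax, qmax) * scale
    # Straight-through estimator: forward sees w_q, backward treats it as identity.
    return w + (w_q - w).detach()
```

Under this scheme, MLP weights would pass through `fake_quantize(w, 5)` and attention weights through `fake_quantize(w, 6)` during training, so the model learns weights that survive the later real int5/int6 rounding.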
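
Magnitude pruning at 5% presumably zeroes the smallest-magnitude weights so the artifact compresses better under zstd. A sketch, assuming a per-tensor quantile threshold (the actual granularity is not stated in this PR):

```python
import torch

@torch.no_grad()
def magnitude_prune(w: torch.Tensor, frac: float = 0.05) -> torch.Tensor:
    # Zero the smallest `frac` of entries by absolute value.
    threshold = w.abs().flatten().quantile(frac)
    return w * (w.abs() >= threshold).to(w.dtype)
```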
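
SWA here is checkpoint averaging during warmdown: every 50 steps the live weights are folded into a running mean, and the averaged weights are what get evaluated. A minimal sketch, with `swa_update` and `avg_state` as hypothetical names:

```python
import torch

@torch.no_grad()
def swa_update(avg_state: dict, model: torch.nn.Module, n_averaged: int) -> int:
    # Running mean over checkpoints: avg += (w - avg) / (n + 1).
    for name, w in model.state_dict().items():
        if name not in avg_state:
            avg_state[name] = w.detach().clone().float()
        else:
            avg_state[name] += (w.float() - avg_state[name]) / (n_averaged + 1)
    return n_averaged + 1
```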

## Architecture

- 10 transformer layers, dim=512, 8 heads, 4 KV heads
- 3x MLP with SmearGate
- BigramHash(12288) with bigram_dim=128 (see the sketch below)
- Mixed quantization: int5 MLP, int6 attention
- zstd-22 compression
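
For context on BigramHash(12288): it presumably hashes each (previous token, current token) pair into one of 12288 buckets and looks up a learned 128-dim embedding. A sketch; the class layout, the hash mixing constant, and how the output joins the residual stream are all assumptions:

```python
import torch
import torch.nn as nn

class BigramHash(nn.Module):
    # Hashed bigram embedding: (prev, cur) token pairs -> n_buckets.
    def __init__(self, n_buckets: int = 12288, bigram_dim: int = 128):
        super().__init__()
        self.n_buckets = n_buckets
        self.emb = nn.Embedding(n_buckets, bigram_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:  # tokens: (B, T) int64
        prev = torch.roll(tokens, shifts=1, dims=1)
        prev[:, 0] = 0  # no real predecessor at position 0
        # Multiplicative hash of the pair (constant is illustrative).
        h = (prev * 1000003 + tokens) % self.n_buckets
        return self.emb(h)  # (B, T, bigram_dim)
```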

## Results

```
seed=2024: val_bpb=1.14443, artifact=15,902,583 bytes
```
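
On eval stride 32: the evaluation presumably slides a fixed-length window over the validation stream in steps of 32 tokens and scores only the last 32 targets of each window, so nearly every token is scored with close-to-full left context. A sketch, assuming a byte-level tokenizer (tokens and bytes coincide) and hypothetical `model`/`seq_len` arguments:

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def eval_bpb(model, tokens: torch.Tensor, seq_len: int = 1024, stride: int = 32) -> float:
    # Overlapping windows: only the final `stride` targets of each window are scored.
    nll_nats, n_scored = 0.0, 0
    for start in range(0, tokens.numel() - seq_len - 1, stride):
        window = tokens[start : start + seq_len + 1]
        logits = model(window[:-1].unsqueeze(0))  # (1, seq_len, vocab)
        loss = F.cross_entropy(logits[0, -stride:], window[-stride:], reduction="sum")
        nll_nats += loss.item()
        n_scored += stride
    # Bits per byte, assuming one token per byte.
    return nll_nats / n_scored / math.log(2)
```

Halving the stride from 64 to 32 roughly doubles eval compute but gives each scored token more context, which is presumably where the small bpb improvement comes from.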

## Command

```bash
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

@@ -0,0 +1,9 @@
{
  "name": "QAT + BigramHash(12288) + Stride 32",
  "val_loss": 1.14443,
  "bytes_total": 15902583,
  "blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.",

Suggested change:

- "blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 5%, SWA every 50 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.",
+ "blurb": "10 layers, QAT with STE (int5 MLP / int6 attn), BigramHash 12288, eval stride 32, magnitude pruning 15%, SWA every 25 steps, zstd-22. Based on 10L_Int5MLP_MuonWD04_SWA50.",
The README states “Magnitude pruning 5%” (and mentions SWA every 50 steps), but the accompanying script prunes at a fixed 15% quantile (`0.15`) and uses `SWA_EVERY=25` by default. Please align the README, or make the pruning percentage and SWA cadence match the configuration actually used for this record.