diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 000000000..62a811919 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,269 @@ +# Parameter Golf — Fractal Transformer Research Plan +**DGX Spark · GB10 · March 2026** + +--- + +## Challenge Summary + +| Constraint | Value | +|------------|-------| +| Artifact size | ≤16MB (code + int8 quantized + zlib compressed weights) | +| Training time | ≤10 minutes on 8×H100 | +| Metric | bits-per-byte (BPB) on FineWeb validation set | +| Baseline | 1.2244 BPB | +| Record threshold | ≤1.2194 BPB (must beat by ≥0.005) | +| 4-hour unlimited baseline | 1.2074 BPB | +| Challenge window | March 18 → April 30, 2026 | +| Repo | https://github.com/newjordan/parameter-golf | + +--- + +## Our Approach: Fractal Transformer + Gravity + AttnRes + +### Core Thesis + +Weight-shared transformer layers with learned gravitational auxiliary losses +and attention residuals will achieve lower BPB than the baseline's 9-unique-layer +architecture within the same 16MB parameter budget. + +### Three Innovations Combined + +**1. Fractal Architecture (Weight Sharing / Depth Recurrence)** + +Instead of 9 unique layers, use 3 unique layers repeated in 3 loops. + +``` +CURRENT BASELINE: + 9 unique layers × 512 dim = ~14M params + +OUR APPROACH: + 3 unique layers × 3 loops = 9 effective layers + Wider layers (~700 dim) with same total param count + Loop position embedding tells shared weights which pass they're on +``` + +Why this helps: +- Fewer unique parameters → more room in 16MB budget → wider layers +- Wider layers = richer features per layer +- Weight sharing compresses extremely well under int8+zlib +- Depth recurrence explicitly encouraged by the challenge README + +**2. Gravity (Learned Auxiliary Losses)** + +At the end of each loop, peek at the output using the shared lm_head and +compute an auxiliary cross-entropy loss. The weights are LEARNED parameters. + +```python +self.gravity_weights = nn.Parameter(torch.tensor([0.1, 0.3, 1.0])) + +total_loss = 0 +for loop in range(3): + x = run_shared_layers(x, loop_pos=loop) + loop_logits = lm_head(rms_norm(x)) + loop_loss = cross_entropy(loop_logits, targets) + total_loss += softplus(self.gravity_weights[loop]) * loop_loss +``` + +Why this helps: +- 3× gradient signal — every layer gets direct supervision, not diluted backprop +- Model discovers optimal loop weighting during training +- Especially powerful with weight sharing: same weights receive gradient from 3 depths +- Zero new parameters (3 scalars for weights, reuses existing lm_head) +- ~1.2% compute overhead (2 extra lm_head calls) + +The "gravity" analogy: +- Loop 1 output is far from the target → strong pull, large updates +- Loop 2 is closer → medium pull, refinement +- Loop 3 is nearest → full weight, precision +- Each loop starts from a better position because the previous loop was already pulled toward the answer + +**3. AttnRes (Attention Residuals)** + +Replace fixed skip connections with learned, input-dependent attention over depth. +From Moonshot's paper (arxiv:2603.15031). + +``` +Standard residuals: x = x + layer_output (fixed, uniform weight) +AttnRes: x = softmax(query · [prev_outputs]) · [prev_outputs] +``` + +Each layer has a single learned query vector w_l ∈ R^d that attends over all +previous loop outputs. The softmax produces content-aware, input-dependent +weights instead of fixed uniform accumulation. + +Why this helps: +- Paper shows 1.25× compute equivalent for near-zero parameter cost +- Replaces BOTH the baseline's U-Net skips AND resid_mix +- Only 9 × dim ≈ 4,608 new parameters +- Critical for weight sharing: lets later loops selectively reference earlier loops + +### What We Remove From Baseline + +| Component | Parameters | Replaced By | +|-----------|-----------|-------------| +| U-Net encoder/decoder split | structural | Fractal loops | +| skip_weights (9 × 512) | 4,608 | AttnRes queries | +| resid_mix (9 × 2 × 512) | 9,216 | AttnRes | +| **Total removed** | **~13,824** | | + +### What We Add + +| Component | Parameters | Purpose | +|-----------|-----------|---------| +| AttnRes queries (9 layers) | 4,608 | Selective depth attention | +| Loop position embeddings (3 loops) | ~2,100 | Tell weights which loop they're in | +| Gravity weights (3 scalars) | 3 | Learned auxiliary loss weighting | +| **Total added** | **~6,711** | | + +**Net: ~7,113 parameters saved → reinvested into wider layers.** + +--- + +## Architecture Diagram + +``` +INPUT TOKENS (1024 vocab) + │ + ▼ +EMBEDDING (1024 × ~700 dim) + │ + ▼ +LOOP 1 (broad strokes): + ├── Layer A (attention + MLP, loop_pos=0) + ├── Layer B (attention + MLP, loop_pos=0) + ├── Layer C (attention + MLP, loop_pos=0) + ├── GRAVITY: peek → compute loss₁ (learned weight ~0.1) + └── Store loop 1 output for AttnRes + │ + ▼ +LOOP 2 (refinement): + ├── AttnRes: attend over [embedding, loop1_output] + ├── Layer A (attention + MLP, loop_pos=1) ← same weights as loop 1 + ├── Layer B (attention + MLP, loop_pos=1) + ├── Layer C (attention + MLP, loop_pos=1) + ├── GRAVITY: peek → compute loss₂ (learned weight ~0.3) + └── Store loop 2 output for AttnRes + │ + ▼ +LOOP 3 (precision): + ├── AttnRes: attend over [embedding, loop1_output, loop2_output] + ├── Layer A (attention + MLP, loop_pos=2) ← same weights again + ├── Layer B (attention + MLP, loop_pos=2) + ├── Layer C (attention + MLP, loop_pos=2) + └── FINAL LOSS: full cross-entropy (weight = 1.0) + │ + ▼ +OUTPUT: logits → BPB +``` + +Each loop tightens the representation: +- Loop 1: rough sketch (only sees embedding) +- Loop 2: refinement (sees embedding + loop 1 output via AttnRes) +- Loop 3: precision (sees full history, committed to answer) + +--- + +## Information Tightening Mechanisms + +### Gravity (primary — Frosty's intuition) +Each loop is pulled toward the final answer by its own loss signal. Later loops +start from better positions because earlier loops were already course-correcting. +The model learns how hard each loop should pull (learned gravity weights). + +### AttnRes (secondary — from Moonshot paper) +Selective attention over previous loop outputs. Later loops can choose which +earlier representations are useful for each specific token, not a fixed blend. + +### Future: Ring Buffer + Temperature Cooling (Phase 4) +- Ring buffer: bounded memory with eviction of unhelpful previous states +- Temperature: AttnRes attention sharpens with depth (soft early, committed late) +- Only add if Phase 1-3 show signal + +--- + +## Experiment Sequence + +### Phase 1: Establish Weight Sharing Baselines +1. Run baseline as-is → establish local BPB reference +2. 3 shared layers × 3 loops, same total params, ~512 dim → does sharing work? +3. 3 shared layers × 3 loops, wider ~700 dim → does width help? +4. 2 shared layers × 4 loops, widest ~850 dim → more loops? +5. 4 shared layers × 2 loops, ~620 dim → fewer loops? + +### Phase 2: Add Gravity +6. Best config from Phase 1 + gravity with learned weights +7. Compare: gravity learned vs gravity fixed [0.1, 0.3, 1.0] vs no gravity + +### Phase 3: Add AttnRes +8. Best from Phase 2 + full AttnRes +9. Test: AttnRes before attention only / before MLP only / both +10. Test: AttnRes with vs without gravity + +### Phase 4: Advanced Mechanisms +11. Add ring buffer (bounded memory with eviction) +12. Add temperature cooling on AttnRes +13. Try combining all mechanisms + +### Phase 5: Optimize for Submission +14. Verify int8+zlib artifact ≤16MB +15. Tune width to maximize quality within size budget +16. Port winning config to official train_gpt.py style +17. Run on cloud 8×H100, verify 10-minute timing +18. Prepare submission folder for /records + +--- + +## Workflow + +### Local (DGX Spark, free, unlimited) +- Adapted research fork without Triton/torch.compile dependency +- Shorter training budget (2 min per experiment) +- Smaller batch size +- Same model, data, tokenizer, BPB metric +- Results won't match H100 numbers but relative ordering transfers +- Run 50-100 experiments to find winning configuration +- Autoresearch agent runs overnight (Phase 1-4) + +### Cloud (H100s, paid, limited) +- Take best configuration from local experiments +- Run at full scale: 8×H100, 10 minutes, full batch +- Verify BPB, artifact size, timing +- Prepare official submission + +--- + +## Source Material + +### Attention Residuals (Moonshot) +- Paper: arxiv:2603.15031 +- Repo: https://github.com/MoonshotAI/Attention-Residuals +- Core: replace fixed residual connections with softmax attention over depth +- Result: matches 1.25× compute baseline at near-zero parameter cost + +### Autoresearch (Karpathy) +- Repo: https://github.com/karpathy/autoresearch +- Core: AI agent modifies train.py, trains 5 min, keeps/discards, loops forever +- Adapted as our outer optimization loop + +### Parameter Golf Baseline +- Repo: https://github.com/openai/parameter-golf +- Architecture: 9-layer GPT, 512 dim, 1024 vocab, GQA, Muon optimizer +- Key features: U-Net skip connections, resid_mix, ReLU², logit softcapping +- BPB: 1.2244 (10 min), 1.2074 (4 hour) + +--- + +## Key Insight + +The competition rewards compression quality per parameter. Weight sharing is +the ultimate compression — the same function applied repeatedly. AttnRes gives +that repeated function the ability to selectively reference its earlier outputs. +Gravity ensures every repetition is actively pulled toward the correct answer. + +The fractal structure means each loop genuinely tightens the representation: +same weights, progressively richer input, direct loss supervision at every +stage. The model isn't just repeating — it's refining. + +--- + +*Plan authored by Octavian + Frosty · Spark-2949 · 2026-03-18* diff --git a/RESULTS.md b/RESULTS.md new file mode 100644 index 000000000..1d27a6b39 --- /dev/null +++ b/RESULTS.md @@ -0,0 +1,69 @@ +# Parameter Golf — Local Experiment Results +**DGX Spark GB10 · 2026-03-18** + +## Experiment Ladder (300 steps, 1 train shard, 1M eval tokens) + +| # | Config | val_bpb | Δ vs baseline | params | dim | ms/step | +|---|--------|--------:|----------:|-------:|----:|--------:| +| 1 | Baseline (9 unique layers, 512d) | 2.7927 | — | 17.05M | 512 | 167 | +| 2 | **Fractal only (3×3, 864d)** | **2.5953** | **-0.1975** | 16.57M | 864 | 333 | +| 3 | Fractal + Gravity (3×3, 864d) | 2.6149 | -0.1779 | 16.57M | 864 | 347 | +| 4 | Fractal + Gravity + AttnRes (3×3, 864d) | 2.6084 | -0.1843 | 16.58M | 864 | 425 | + +## Training Loss Comparison (300 steps) + +| Step | Baseline | Fractal | Fractal+Gravity | Fractal+Grav+AttnRes | +|------|----------|---------|-----------------|---------------------| +| 50 | 5.8850 | — | 5.8229 | — | +| 100 | 5.2427 | — | 5.0172 | — | +| 150 | 4.8926 | — | 4.6254 | — | +| 200 | 4.7830 | — | 4.5360 | — | +| 250 | 4.7162 | — | 4.4521 | — | +| 300 | 4.6554 | 4.3473 | 4.3794 | 4.3751 | + +## Key Findings + +1. **Weight sharing + wider layers is the dominant effect.** Fractal-only beats baseline + by 7.1% BPB with fewer total parameters. The 864d shared layers are significantly more + expressive than 512d unique layers. + +2. **Gravity slightly hurts at 300 steps.** The auxiliary losses on early loops add gradient + noise before those loops learn to produce useful predictions. The model learned weights + [0.13, 0.13, 0.70] — trying to minimize early loop influence but can't fully zero it. + +3. **AttnRes partially recovers the gravity penalty.** Selective depth attention helps + the model route around noisy early-loop outputs. + +4. **All fractal variants beat baseline convincingly.** Even the worst fractal config + (fractal+gravity at 2.6149) still beats baseline (2.7927) by 0.18 BPB. + +## Hypothesis for Full-Scale Runs + +Gravity and AttnRes should improve with more training steps because: +- Early loops need many steps to learn useful intermediate predictions +- At 13,000+ steps (H100 10-minute budget), the gravity signal should become useful +- The learned gravity weights should evolve from [0.13, 0.13, 0.70] toward something + that actually leverages early loops + +## Learned Gravity Weights (Experiments 3 & 4) + +Both converged to: `[0.127, 0.127, 0.699]` +- softplus(-2.0) = 0.127 (early loops, barely contributing) +- softplus(0.0) = 0.693 (final loop, dominant) +- The model essentially learned to "turn off" early gravity — confirming that at + 300 steps, direct early-loop supervision is noise rather than signal + +## Next Steps + +1. Try gravity with warmup: zero gravity for first 100 steps, then ramp up +2. Try different loop configs: 2×4, 4×2, 2×5 +3. Ship fractal-only (best local result) to cloud H100s for official timing +4. Ship fractal+gravity+attnres as second cloud experiment to test if it + overtakes with more training + +## Environment +- Hardware: DGX Spark GB10, 130.7GB unified VRAM +- PyTorch: 2.10.0+cu130 (no torch.compile, no Triton) +- Data: FineWeb sp1024, 1 train shard, ~100M train tokens +- Eval: 1M validation tokens (truncated for speed) +- Optimizer: AdamW (not Muon — local simplification) diff --git a/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/README.md b/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/README.md new file mode 100644 index 000000000..86505887c --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/README.md @@ -0,0 +1,46 @@ +## Record: 11L TTT Burst + EMA + GPTQ-lite + warmdown3500 + QAT@0.15 + +**val_bpb: 1.1236** (sliding window stride=64, 2-seed mean) | **15.59 MB** (mean) | 8xH100 SXM, 600s + +### Key Innovation Over PR #414 + +| Change | PR #414 | This | Impact | +|--------|---------|------|--------| +| **TTT Burst** | None | 2-epoch replay of last 100 training batches at 10% LR before EMA | -0.0001 BPB | + +Everything else inherited from PR #414: EMA(0.997), GPTQ-lite(5 percentiles), warmdown 3500, Late QAT@0.15, int6+zstd-22. + +### TTT Burst: Late-Stage Sharpening + +After the main training loop and before EMA application, we replay the last 100 training batches for 2 epochs at 10% of base LR. EMA is updated during the burst so it absorbs the sharpened signal. This gives the model a final sharpening pass on recent data before weight averaging and quantization. + +### Results (3 seeds, 8xH100 SXM) + +| Seed | Steps | val_loss | Sliding BPB (s64) | Artifact | +|------|-------|----------|-------------------|----------| +| **1337** | 6991 | 1.9246 | **1.1232** | 15.68 MB | +| 42 | 6994 | 1.9262 | 1.1240 | 16.37 MB* | +| **2024** | 6987 | 1.9255 | **1.1239** | 15.50 MB | + +**Mean (1337+2024): 1.1236 | Std: 0.0004** + +*Seed 42 artifact over size limit due to compression variance; BPB validates the approach. + +### Architecture + +11L, 512d, 8H/4KV, MLP 3x (relu^2), U-Net skips, XSA4, Partial RoPE 16/64, LN Scale, VE128, SmearGate, BigramHash(2048), FA3, Muon WD=0.04, EMA(0.997), Tight SWA, Late QAT@0.15, TTT Burst(2ep/10%LR), int6+zstd-22, GPTQ-lite. + +### Run Command + +```bash +SEED=1337 torchrun --nproc_per_node=8 train_gpt.py +``` + +### Test plan + +- [x] All seeds train in 600s on 8xH100 +- [x] Seeds 1337, 2024 under 16MB (15.68 MB, 15.50 MB) +- [x] Post-quant int6 roundtrip verified +- [x] Sliding window eval (stride=64) consistent across seeds (std=0.0004) +- [x] train_gpt.py under 1500 lines (1443) +- [x] No TTT on validation data diff --git a/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/submission.json b/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/submission.json new file mode 100644 index 000000000..7dfeee963 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/submission.json @@ -0,0 +1,15 @@ +{ + "track": "10min_16mb", + "date": "2026-03-22", + "val_bpb": 1.1236, + "model_bytes_mean": 15567198, + "code_bytes": 70306, + "hardware": "8xH100 SXM", + "train_time_seconds": 600, + "seeds": [1337, 42, 2024], + "seed_results": { + "1337": {"val_bpb": 1.12319, "artifact_bytes": 15679283}, + "42": {"val_bpb": 1.12397, "artifact_bytes": 16374443}, + "2024": {"val_bpb": 1.12390, "artifact_bytes": 15495725} + } +} diff --git a/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/train_gpt.py b/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/train_gpt.py new file mode 100644 index 000000000..66d2858e0 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_TTTburst_GPTQ15_EMA_QAT_1.1232/train_gpt.py @@ -0,0 +1,1443 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + # Late-stage TTT burst: sharp adaptation on recent training data before finalizing + ttt_burst_enabled = bool(int(os.environ.get("TTT_BURST_ENABLED", "1"))) + ttt_burst_epochs = int(os.environ.get("TTT_BURST_EPOCHS", 2)) # epochs over recent data + ttt_burst_lr_factor = float(os.environ.get("TTT_BURST_LR_FACTOR", 0.1)) # fraction of base LR + ttt_burst_steps = int(os.environ.get("TTT_BURST_STEPS", 100)) # how many recent batches to replay + ttt_burst_trigger = float(os.environ.get("TTT_BURST_TRIGGER", 0.05)) # trigger at scale < this +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + # Stash recent batches for TTT burst replay + if args.ttt_burst_enabled and scale < 0.2: + if not hasattr(train_loader, '_ttt_buffer'): + train_loader._ttt_buffer = [] + train_loader._ttt_buffer.append((x.detach().clone(), y.detach().clone())) + if len(train_loader._ttt_buffer) > args.ttt_burst_steps: + train_loader._ttt_buffer.pop(0) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # === TTT BURST: Late-stage sharpening on recent training data === + if args.ttt_burst_enabled and hasattr(train_loader, '_ttt_buffer') and len(train_loader._ttt_buffer) > 0: + ttt_buffer = train_loader._ttt_buffer + log0(f"ttt_burst:start epochs:{args.ttt_burst_epochs} buffer_size:{len(ttt_buffer)} lr_factor:{args.ttt_burst_lr_factor}") + # Use a small fraction of base LR for fine-grained adaptation + ttt_lr_scale = args.ttt_burst_lr_factor + for ttt_epoch in range(args.ttt_burst_epochs): + ttt_epoch_loss = 0.0 + for ttt_i, (bx, by) in enumerate(ttt_buffer): + zero_grad_all() + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * ttt_lr_scale + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ttt_loss = model(bx, by) + (ttt_loss * grad_scale).backward() + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + ttt_epoch_loss += ttt_loss.item() + # Update EMA during burst too + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + log0(f"ttt_burst:epoch:{ttt_epoch + 1}/{args.ttt_burst_epochs} avg_loss:{ttt_epoch_loss / len(ttt_buffer):.4f}") + log0("ttt_burst:done") + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-22_11L_distill_v2_1.1233/submission.json b/records/track_10min_16mb/2026-03-22_11L_distill_v2_1.1233/submission.json new file mode 100644 index 000000000..c5f840458 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_distill_v2_1.1233/submission.json @@ -0,0 +1,14 @@ +{ + "track": "10min_16mb", + "date": "2026-03-22", + "val_bpb": 1.1233, + "model_bytes": 15550451, + "code_bytes": 72287, + "total_bytes": 15622738, + "hardware": "8xH100 SXM", + "train_time_seconds": 600, + "seeds": [1337], + "seed_results": { + "1337": 1.12328 + } +} diff --git a/records/track_10min_16mb/2026-03-22_11L_distill_v2_1.1233/train_gpt.py b/records/track_10min_16mb/2026-03-22_11L_distill_v2_1.1233/train_gpt.py new file mode 100644 index 000000000..89418110f --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_distill_v2_1.1233/train_gpt.py @@ -0,0 +1,1478 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + # Self-distillation interjection: use EMA as teacher to smooth live model + distill_enabled = bool(int(os.environ.get("DISTILL_ENABLED", "1"))) + distill_trigger = float(os.environ.get("DISTILL_TRIGGER", 0.05)) # trigger at scale < this + distill_steps = int(os.environ.get("DISTILL_STEPS", 50)) # number of distillation steps + distill_lr_factor = float(os.environ.get("DISTILL_LR_FACTOR", 0.05)) # fraction of base LR + distill_temperature = float(os.environ.get("DISTILL_TEMPERATURE", 2.0)) # KL temperature + distill_alpha = float(os.environ.get("DISTILL_ALPHA", 0.7)) # weight of KL vs CE loss +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # === SELF-DISTILLATION: Use EMA as teacher to sharpen live model === + if args.distill_enabled: + log0(f"distill:start steps:{args.distill_steps} lr_factor:{args.distill_lr_factor} " + f"temp:{args.distill_temperature} alpha:{args.distill_alpha}") + # Build teacher model from EMA state (frozen) + current_state = base_model.state_dict() + teacher_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + teacher_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + dtg=args.dtg_enabled, ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in teacher_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(teacher_model) + teacher_model.load_state_dict(teacher_state, strict=True) + teacher_model.eval() + for p in teacher_model.parameters(): + p.requires_grad_(False) + compiled_teacher = torch.compile(teacher_model, dynamic=False, fullgraph=True) + # Distillation loop: KL(teacher || student) + CE on fresh training data + model.train() + T = args.distill_temperature + alpha = args.distill_alpha + for d_step in range(args.distill_steps): + zero_grad_all() + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * args.distill_lr_factor + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + # Student logits + student_logits = base_model.forward_logits(x) # (bsz, seq, vocab) + # Teacher logits (no grad) + with torch.no_grad(): + teacher_logits = compiled_teacher.forward_logits(x) + # KL divergence loss (teacher as target distribution) + student_log_probs = F.log_softmax(student_logits.float() / T, dim=-1) + teacher_probs = F.softmax(teacher_logits.float() / T, dim=-1) + kl_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (T * T) + # Standard CE loss + ce_loss = F.cross_entropy( + student_logits.reshape(-1, student_logits.size(-1)).float(), + y.reshape(-1), reduction="mean" + ) + # Combined loss + loss = alpha * kl_loss + (1.0 - alpha) * ce_loss + (loss * grad_scale).backward() + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # Update EMA during distillation too + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + if (d_step + 1) % 10 == 0 or d_step == 0: + log0(f"distill:step:{d_step + 1}/{args.distill_steps} kl:{kl_loss.item():.4f} ce:{ce_loss.item():.4f} total:{loss.item():.4f}") + # Free teacher model + del teacher_model, compiled_teacher + torch.cuda.empty_cache() + log0("distill:done") + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_gpt_v2.py b/train_gpt_v2.py new file mode 100644 index 000000000..89418110f --- /dev/null +++ b/train_gpt_v2.py @@ -0,0 +1,1478 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + # Self-distillation interjection: use EMA as teacher to smooth live model + distill_enabled = bool(int(os.environ.get("DISTILL_ENABLED", "1"))) + distill_trigger = float(os.environ.get("DISTILL_TRIGGER", 0.05)) # trigger at scale < this + distill_steps = int(os.environ.get("DISTILL_STEPS", 50)) # number of distillation steps + distill_lr_factor = float(os.environ.get("DISTILL_LR_FACTOR", 0.05)) # fraction of base LR + distill_temperature = float(os.environ.get("DISTILL_TEMPERATURE", 2.0)) # KL temperature + distill_alpha = float(os.environ.get("DISTILL_ALPHA", 0.7)) # weight of KL vs CE loss +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # === SELF-DISTILLATION: Use EMA as teacher to sharpen live model === + if args.distill_enabled: + log0(f"distill:start steps:{args.distill_steps} lr_factor:{args.distill_lr_factor} " + f"temp:{args.distill_temperature} alpha:{args.distill_alpha}") + # Build teacher model from EMA state (frozen) + current_state = base_model.state_dict() + teacher_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + teacher_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + dtg=args.dtg_enabled, ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in teacher_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(teacher_model) + teacher_model.load_state_dict(teacher_state, strict=True) + teacher_model.eval() + for p in teacher_model.parameters(): + p.requires_grad_(False) + compiled_teacher = torch.compile(teacher_model, dynamic=False, fullgraph=True) + # Distillation loop: KL(teacher || student) + CE on fresh training data + model.train() + T = args.distill_temperature + alpha = args.distill_alpha + for d_step in range(args.distill_steps): + zero_grad_all() + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * args.distill_lr_factor + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + # Student logits + student_logits = base_model.forward_logits(x) # (bsz, seq, vocab) + # Teacher logits (no grad) + with torch.no_grad(): + teacher_logits = compiled_teacher.forward_logits(x) + # KL divergence loss (teacher as target distribution) + student_log_probs = F.log_softmax(student_logits.float() / T, dim=-1) + teacher_probs = F.softmax(teacher_logits.float() / T, dim=-1) + kl_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (T * T) + # Standard CE loss + ce_loss = F.cross_entropy( + student_logits.reshape(-1, student_logits.size(-1)).float(), + y.reshape(-1), reduction="mean" + ) + # Combined loss + loss = alpha * kl_loss + (1.0 - alpha) * ce_loss + (loss * grad_scale).backward() + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # Update EMA during distillation too + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + if (d_step + 1) % 10 == 0 or d_step == 0: + log0(f"distill:step:{d_step + 1}/{args.distill_steps} kl:{kl_loss.item():.4f} ce:{ce_loss.item():.4f} total:{loss.item():.4f}") + # Free teacher model + del teacher_model, compiled_teacher + torch.cuda.empty_cache() + log0("distill:done") + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/train_local.py b/train_local.py new file mode 100644 index 000000000..5b7204e73 --- /dev/null +++ b/train_local.py @@ -0,0 +1,601 @@ +""" +Parameter Golf — Local Research Fork +===================================== +Simplified training script for DGX Spark (GB10). +No Triton/torch.compile dependency. Uses PyTorch native SDPA. +Same model architecture, data, tokenizer, and BPB metric as official. + +Usage: + source .venv/bin/activate + + # Baseline (standard 9-layer, no modifications) + python train_local.py --mode baseline + + # Fractal (weight-shared layers with loops) + python train_local.py --mode fractal --num-unique-layers 3 --num-loops 3 + + # Fractal + Gravity + python train_local.py --mode fractal --gravity + + # Fractal + Gravity + AttnRes + python train_local.py --mode fractal --gravity --attnres +""" + +from __future__ import annotations +import argparse +import glob +import io +import math +import os +import time +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +# ─── CLI ────────────────────────────────────────────────────────────────────── + +def get_args(): + p = argparse.ArgumentParser() + p.add_argument("--mode", choices=["baseline", "fractal"], default="baseline") + p.add_argument("--num-unique-layers", type=int, default=3) + p.add_argument("--num-loops", type=int, default=3) + p.add_argument("--model-dim", type=int, default=0, help="0 = auto-size to match baseline param count") + p.add_argument("--num-heads", type=int, default=8) + p.add_argument("--num-kv-heads", type=int, default=4) + p.add_argument("--vocab-size", type=int, default=1024) + p.add_argument("--seq-len", type=int, default=1024) + p.add_argument("--mlp-mult", type=int, default=2) + p.add_argument("--gravity", action="store_true", help="Enable learned gravity aux losses") + p.add_argument("--attnres", action="store_true", help="Enable attention residuals") + p.add_argument("--iterations", type=int, default=500) + p.add_argument("--batch-tokens", type=int, default=32768) + p.add_argument("--max-seconds", type=float, default=120.0) + p.add_argument("--lr", type=float, default=3e-4) + p.add_argument("--warmup-steps", type=int, default=20) + p.add_argument("--log-every", type=int, default=25) + p.add_argument("--data-path", type=str, default="./data/datasets/fineweb10B_sp1024") + p.add_argument("--tokenizer-path", type=str, default="./data/tokenizers/fineweb_1024_bpe.model") + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--eval-tokens", type=int, default=0, help="0 = full val set, >0 = truncated for speed") + p.add_argument("--run-id", type=str, default="local") + return p.parse_args() + +# ─── DATA LOADING ───────────────────────────────────────────────────────────── + +def load_shard(path: Path) -> Tensor: + header = np.fromfile(path, dtype=" Tensor: + chunks = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self.idx = (self.idx + 1) % len(self.files) + self.tokens = load_shard(Path(self.files[self.idx])) + self.pos = 0 + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +# ─── BPB EVALUATION ────────────────────────────────────────────────────────── + +def build_bpb_luts(sp, vocab_size, device): + sp_vs = int(sp.vocab_size()) + table_size = max(sp_vs, vocab_size) + base_bytes = np.zeros(table_size, dtype=np.int16) + has_space = np.zeros(table_size, dtype=np.bool_) + is_boundary = np.ones(table_size, dtype=np.bool_) + for tid in range(sp_vs): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): + continue + is_boundary[tid] = False + if sp.is_byte(tid): + base_bytes[tid] = 1 + continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): + has_space[tid] = True + piece = piece[1:] + base_bytes[tid] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes, dtype=torch.int16, device=device), + torch.tensor(has_space, dtype=torch.bool, device=device), + torch.tensor(is_boundary, dtype=torch.bool, device=device), + ) + +@torch.no_grad() +def eval_bpb(model, val_tokens, seq_len, batch_tokens, device, base_bytes_lut, has_space_lut, is_boundary_lut): + model.eval() + local_batch_seqs = max(1, batch_tokens // seq_len) + total_seqs = (val_tokens.numel() - 1) // seq_len + loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + + for start in range(0, total_seqs, local_batch_seqs): + end = min(start + local_batch_seqs, total_seqs) + raw_start = start * seq_len + raw_end = end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + if isinstance(loss, tuple): + loss = loss[0] # gravity returns (total_loss, final_loss) + n = float(y.numel()) + loss_sum += loss.item() * n + token_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_bytes_lut[tgt_ids].to(torch.int16) + tb += (has_space_lut[tgt_ids] & ~is_boundary_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum().item() + + model.train() + val_loss = loss_sum / token_count + bpt = val_loss / math.log(2.0) + tpb = token_count / byte_count + return val_loss, bpt * tpb + +# ─── MODEL: SHARED COMPONENTS ──────────────────────────────────────────────── + +class RMSNorm(nn.Module): + def forward(self, x): + return F.rms_norm(x, (x.size(-1),)) + +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._cache_len = 0 + self._cos = None + self._sin = None + + def forward(self, seq_len, device, dtype): + if self._cos is None or self._cache_len < seq_len or self._cos.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos = freqs.cos()[None, None, :, :] + self._sin = freqs.sin()[None, None, :, :] + self._cache_len = seq_len + return self._cos[:, :, :seq_len].to(dtype), self._sin[:, :, :seq_len].to(dtype) + +def apply_rope(x, cos, sin): + d = x.size(-1) // 2 + x1, x2 = x[..., :d], x[..., d:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class Attention(nn.Module): + def __init__(self, dim, n_heads, n_kv_heads, rope_base=10000.0): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = dim // n_heads + kv_dim = n_kv_heads * self.head_dim + self.c_q = nn.Linear(dim, dim, bias=False) + self.c_k = nn.Linear(dim, kv_dim, bias=False) + self.c_v = nn.Linear(dim, kv_dim, bias=False) + self.c_proj = nn.Linear(dim, dim, bias=False) + self.rotary = Rotary(self.head_dim, rope_base) + + def forward(self, x): + B, T, C = x.shape + q = self.c_q(x).reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(T, x.device, q.dtype) + q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True, + enable_gqa=(self.n_kv_heads != self.n_heads)) + return self.c_proj(y.transpose(1, 2).contiguous().reshape(B, T, C)) + +class MLP(nn.Module): + def __init__(self, dim, mult=2): + super().__init__() + hidden = dim * mult + self.fc = nn.Linear(dim, hidden, bias=False) + self.proj = nn.Linear(hidden, dim, bias=False) + + def forward(self, x): + return self.proj(F.relu(self.fc(x)).square()) + +class Block(nn.Module): + def __init__(self, dim, n_heads, n_kv_heads, mlp_mult): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = Attention(dim, n_heads, n_kv_heads) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim)) + self.mlp_scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x + self.attn_scale * self.attn(self.attn_norm(x)) + x = x + self.mlp_scale * self.mlp(self.mlp_norm(x)) + return x + +# ─── MODEL: BASELINE (standard 9-layer) ────────────────────────────────────── + +class BaselineGPT(nn.Module): + def __init__(self, vocab_size, num_layers, dim, n_heads, n_kv_heads, mlp_mult, + softcap=30.0): + super().__init__() + self.softcap = softcap + self.tok_emb = nn.Embedding(vocab_size, dim) + n_enc = num_layers // 2 + n_dec = num_layers - n_enc + n_skip = min(n_enc, n_dec) + self.n_enc = n_enc + self.n_dec = n_dec + self.skip_weights = nn.Parameter(torch.ones(n_skip, dim)) + self.blocks = nn.ModuleList([Block(dim, n_heads, n_kv_heads, mlp_mult) + for _ in range(num_layers)]) + self.final_norm = RMSNorm() + self.lm_head = nn.Linear(dim, vocab_size, bias=False) + # Tie embeddings + self.lm_head.weight = self.tok_emb.weight + self._init() + + def _init(self): + nn.init.normal_(self.tok_emb.weight, std=0.005) + for block in self.blocks: + for m in [block.attn.c_q, block.attn.c_k, block.attn.c_v, block.mlp.fc]: + nn.init.normal_(m.weight, std=0.02) + for m in [block.attn.c_proj, block.mlp.proj]: + nn.init.zeros_(m.weight) + + def forward(self, x_ids, targets): + x = F.rms_norm(self.tok_emb(x_ids), (self.tok_emb.weight.size(-1),)) + x0 = x + skips = [] + for i in range(self.n_enc): + x = self.blocks[i](x) + skips.append(x) + for i in range(self.n_dec): + if skips: + x = x + self.skip_weights[i] * skips.pop() + x = self.blocks[self.n_enc + i](x) + x = self.final_norm(x).reshape(-1, x.size(-1)) + logits = self.lm_head(x) + logits = self.softcap * torch.tanh(logits / self.softcap) + return F.cross_entropy(logits.float(), targets.reshape(-1)) + +# ─── MODEL: FRACTAL (weight-shared + gravity + attnres) ────────────────────── + +class AttnResModule(nn.Module): + """Attention over previous loop outputs. One learned query per layer.""" + def __init__(self, dim): + super().__init__() + self.query = nn.Parameter(torch.randn(dim) * 0.01) + self.norm = RMSNorm() + + def forward(self, loop_outputs): + """ + loop_outputs: list of [B, T, D] tensors (previous loop outputs) + Returns: [B, T, D] weighted combination + """ + if len(loop_outputs) == 1: + return loop_outputs[0] + V = torch.stack(loop_outputs, dim=0) # [N, B, T, D] + K = self.norm(V) + logits = torch.einsum('d, n b t d -> n b t', self.query, K) + weights = logits.softmax(dim=0) + return torch.einsum('n b t, n b t d -> b t d', weights, V) + +class FractalGPT(nn.Module): + def __init__(self, vocab_size, num_unique_layers, num_loops, dim, n_heads, + n_kv_heads, mlp_mult, use_gravity=False, use_attnres=False, + softcap=30.0): + super().__init__() + self.num_loops = num_loops + self.num_unique_layers = num_unique_layers + self.use_gravity = use_gravity + self.use_attnres = use_attnres + self.softcap = softcap + self.dim = dim + + self.tok_emb = nn.Embedding(vocab_size, dim) + self.blocks = nn.ModuleList([Block(dim, n_heads, n_kv_heads, mlp_mult) + for _ in range(num_unique_layers)]) + self.final_norm = RMSNorm() + self.lm_head = nn.Linear(dim, vocab_size, bias=False) + # Tie embeddings + self.lm_head.weight = self.tok_emb.weight + + # Loop position embeddings + self.loop_pos = nn.Parameter(torch.randn(num_loops, dim) * 0.01) + + # Gravity: learned auxiliary loss weights + if use_gravity: + self.gravity_logits = nn.Parameter(torch.tensor( + [-2.0] * (num_loops - 1) + [0.0] # softplus → ~[0.13, ..., 0.69] + )) + + # AttnRes: one module per loop (except first loop which has nothing to attend to) + if use_attnres: + total_layers = num_unique_layers * num_loops + self.attnres = nn.ModuleList([ + AttnResModule(dim) for _ in range(total_layers) + ]) + + self._init() + + def _init(self): + nn.init.normal_(self.tok_emb.weight, std=0.005) + for block in self.blocks: + for m in [block.attn.c_q, block.attn.c_k, block.attn.c_v, block.mlp.fc]: + nn.init.normal_(m.weight, std=0.02) + for m in [block.attn.c_proj, block.mlp.proj]: + nn.init.zeros_(m.weight) + + def _compute_logits(self, x): + x = self.final_norm(x).reshape(-1, x.size(-1)) + logits = self.lm_head(x) + return self.softcap * torch.tanh(logits / self.softcap) + + def forward(self, x_ids, targets): + x = F.rms_norm(self.tok_emb(x_ids), (self.tok_emb.weight.size(-1),)) + + loop_outputs = [x] # embedding is always available for AttnRes + gravity_losses = [] + flat_layer_idx = 0 + + for loop in range(self.num_loops): + # Add loop position embedding + x = x + self.loop_pos[loop] + + # Run shared layers + for layer_idx in range(self.num_unique_layers): + # AttnRes: attend over previous loop outputs before this layer + if self.use_attnres and len(loop_outputs) > 1: + x = self.attnres[flat_layer_idx](loop_outputs + [x]) + + x = self.blocks[layer_idx](x) + flat_layer_idx += 1 + + # Store this loop's output for future AttnRes + loop_outputs.append(x) + + # Gravity: compute auxiliary loss at loop boundary + if self.use_gravity and loop < self.num_loops - 1: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + aux_logits = self._compute_logits(x) + aux_loss = F.cross_entropy(aux_logits.float(), targets.reshape(-1)) + weight = F.softplus(self.gravity_logits[loop]) + gravity_losses.append(weight * aux_loss) + + # Final loss (always weight 1.0 equivalent) + final_logits = self._compute_logits(x) + final_loss = F.cross_entropy(final_logits.float(), targets.reshape(-1)) + + if self.use_gravity and gravity_losses: + final_weight = F.softplus(self.gravity_logits[-1]) + total_loss = sum(gravity_losses) + final_weight * final_loss + # Normalize so total weight sums to ~1 + total_weight = sum(F.softplus(self.gravity_logits[i]) for i in range(self.num_loops)) + total_loss = total_loss / total_weight + return total_loss + + return final_loss + +# ─── OPTIMIZER ──────────────────────────────────────────────────────────────── + +def make_optimizer(model, lr): + """Simple AdamW — we'll add Muon later if needed.""" + decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2] + nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2] + groups = [ + {"params": decay_params, "weight_decay": 0.1}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95), fused=True) + +def cosine_lr(step, max_steps, lr, warmup=20, min_frac=0.1): + if step < warmup: + return lr * step / warmup + decay = (step - warmup) / max(max_steps - warmup, 1) + return lr * (min_frac + (1 - min_frac) * 0.5 * (1 + math.cos(math.pi * decay))) + +# ─── AUTO-SIZE MODEL DIM ───────────────────────────────────────────────────── + +def estimate_params(dim, n_heads, n_kv_heads, mlp_mult, num_unique_layers, vocab_size): + head_dim = dim // n_heads + kv_dim = n_kv_heads * head_dim + per_layer = ( + dim * dim + # c_q + dim * kv_dim + # c_k + dim * kv_dim + # c_v + dim * dim + # c_proj + dim * (dim * mlp_mult) + # fc + (dim * mlp_mult) * dim + # proj + dim * 2 # scales + ) + total = ( + vocab_size * dim + # embedding (tied with lm_head) + num_unique_layers * per_layer # transformer layers + ) + return total + +def auto_dim(target_params, n_heads, n_kv_heads, mlp_mult, num_unique_layers, vocab_size): + """Find the largest dim (divisible by 2*n_heads for RoPE) that fits in target_params.""" + step = 2 * n_heads # must be divisible by 2*n_heads so head_dim is even + for dim in range(2048, 128, -step): + if estimate_params(dim, n_heads, n_kv_heads, mlp_mult, num_unique_layers, vocab_size) <= target_params: + return dim + return 256 + +# ─── MAIN ───────────────────────────────────────────────────────────────────── + +def main(): + args = get_args() + device = torch.device("cuda") + torch.manual_seed(args.seed) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + print("=" * 70) + print(f"PARAMETER GOLF LOCAL — mode={args.mode}") + print("=" * 70) + + # Tokenizer + BPB setup + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + base_bytes_lut, has_space_lut, is_boundary_lut = build_bpb_luts(sp, args.vocab_size, device) + + # Validation data + val_files = sorted(glob.glob(os.path.join(args.data_path, "fineweb_val_*.bin"))) + val_tokens = torch.cat([load_shard(Path(f)) for f in val_files]) + usable = ((val_tokens.numel() - 1) // args.seq_len) * args.seq_len + val_tokens = val_tokens[:usable + 1] + if args.eval_tokens > 0: + max_eval = min(args.eval_tokens + 1, val_tokens.numel()) + eval_usable = ((max_eval - 1) // args.seq_len) * args.seq_len + val_tokens = val_tokens[:eval_usable + 1] + print(f"Val tokens: {val_tokens.numel():,}{' (truncated)' if args.eval_tokens > 0 else ''}") + + # Train data + train_stream = TokenStream(os.path.join(args.data_path, "fineweb_train_*.bin")) + + # Baseline param count for auto-sizing + BASELINE_PARAMS = estimate_params(512, 8, 4, 2, 9, args.vocab_size) + + # Build model + if args.mode == "baseline": + dim = args.model_dim if args.model_dim > 0 else 512 + model = BaselineGPT( + vocab_size=args.vocab_size, num_layers=9, dim=dim, + n_heads=args.num_heads, n_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + ).to(device).bfloat16() + else: + # Auto-size dim to match baseline param count + if args.model_dim > 0: + dim = args.model_dim + else: + dim = auto_dim(BASELINE_PARAMS, args.num_heads, args.num_kv_heads, + args.mlp_mult, args.num_unique_layers, args.vocab_size) + # Ensure divisible by 2*num_heads (RoPE needs even head_dim) + step = 2 * args.num_heads + dim = (dim // step) * step + + model = FractalGPT( + vocab_size=args.vocab_size, + num_unique_layers=args.num_unique_layers, + num_loops=args.num_loops, + dim=dim, + n_heads=args.num_heads, + n_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + use_gravity=args.gravity, + use_attnres=args.attnres, + ).to(device).bfloat16() + + n_params = sum(p.numel() for p in model.parameters()) + print(f"Model: {n_params:,} params ({n_params/1e6:.1f}M)") + if args.mode == "fractal": + print(f" unique_layers={args.num_unique_layers} loops={args.num_loops} dim={dim}") + print(f" gravity={args.gravity} attnres={args.attnres}") + print(f" effective_depth={args.num_unique_layers * args.num_loops}") + else: + print(f" layers=9 dim={dim}") + print(f" baseline_params={BASELINE_PARAMS:,}") + + optimizer = make_optimizer(model, args.lr) + seq_len = args.seq_len + seqs_per_batch = max(1, args.batch_tokens // seq_len) + + # Training loop + print(f"\nTraining: {args.iterations} iters, {args.max_seconds:.0f}s max, " + f"batch={seqs_per_batch * seq_len} tokens") + model.train() + t_start = time.time() + train_time_ms = 0.0 + + for step in range(1, args.iterations + 1): + # LR schedule + lr = cosine_lr(step, args.iterations, args.lr, args.warmup_steps) + for pg in optimizer.param_groups: + pg["lr"] = lr + + # Get batch + chunk = train_stream.take(seqs_per_batch * seq_len + 1).to(torch.int64) + x = chunk[:-1].reshape(seqs_per_batch, seq_len).to(device) + y = chunk[1:].reshape(seqs_per_batch, seq_len).to(device) + + # Forward / backward + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + if isinstance(loss, tuple): + loss = loss[0] + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + + elapsed = time.time() - t_start + train_time_ms = elapsed * 1000 + + if step % args.log_every == 0 or step <= 5: + print(f"step:{step}/{args.iterations} train_loss:{loss.item():.4f} " + f"lr:{lr:.2e} time:{train_time_ms:.0f}ms " + f"step_avg:{train_time_ms/step:.1f}ms") + + # Wallclock cap + if args.max_seconds > 0 and elapsed >= args.max_seconds: + print(f"Wallclock cap reached at step {step} ({elapsed:.1f}s)") + break + + # Eval + print("\nEvaluating...") + val_loss, val_bpb = eval_bpb( + model, val_tokens, seq_len, args.batch_tokens, device, + base_bytes_lut, has_space_lut, is_boundary_lut, + ) + print(f"\nval_loss: {val_loss:.4f}") + print(f"val_bpb: {val_bpb:.6f}") + print(f"params: {n_params:,}") + print(f"time: {train_time_ms:.0f}ms") + print(f"steps: {step}") + + # Gravity weights (if applicable) + if args.mode == "fractal" and args.gravity: + gw = [F.softplus(model.gravity_logits[i]).item() for i in range(model.num_loops)] + print(f"gravity_weights: {['%.4f' % w for w in gw]}") + + # Quick size estimate + state = model.state_dict() + buf = io.BytesIO() + torch.save(state, buf) + raw = len(buf.getvalue()) + compressed = len(zlib.compress(buf.getvalue(), 9)) + print(f"raw_model_size: {raw:,} bytes ({raw/1e6:.1f}MB)") + print(f"zlib_compressed: {compressed:,} bytes ({compressed/1e6:.1f}MB)") + + peak_mem = torch.cuda.max_memory_allocated() / 1024**2 + print(f"peak_vram: {peak_mem:.0f} MiB") + +if __name__ == "__main__": + main()