From fca68ed408e27161537c59f5fe2e157db07caddb Mon Sep 17 00:00:00 2001 From: Vytautas <64431422+vytautas-bunevicius@users.noreply.github.com> Date: Sun, 22 Mar 2026 14:51:07 +0200 Subject: [PATCH 1/3] feat: add non-record submission 11L mixed int5/int6 + QAT + TTT (val_bpb=1.1466) --- .../README.md | 56 + .../submission.json | 11 + .../train.log | 92 + .../train_gpt.py | 1745 +++++++++++++++++ 4 files changed, 1904 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md create mode 100644 records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/submission.json create mode 100644 records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train.log create mode 100644 records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md new file mode 100644 index 000000000..bd2faaf29 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md @@ -0,0 +1,56 @@ +# Non-record submission: 11L mixed int5/int6 + working QAT + TTT + 8 additions + +**val_bpb = 1.1466** (sliding window, stride=32, post-TTT) | **14.7 MB** artifact | 8xH100 SXM, 605s train + 340s eval + +Built on PR #315 (1.1248). Ran with PyTorch SDPA instead of FA3, so throughput was 110ms/step instead of 85ms. Got 5,129 steps instead of ~7,000. Score should drop with FA3. + +## What we added to PR #315 + +**1. Working QAT.** PR #315's late QAT is dead code because `torch.compile` constant-folds `CastedLinear._qat_enabled` at first trace. We swap the `forward` method to `forward_qat` per instance and recompile. QAT noise matches the export scheme: int5 STE for MLP, int6 STE for attention. + +**2. Mixed int5/int6 quantization + magnitude pruning.** MLP weights get int5 ([-16, 15]), attention gets int6 ([-32, 31]), embeddings stay int8. 3% magnitude pruning before quantization. Result: 14.7MB with 1.3MB headroom. + +**3. Test-time training.** 3 epochs of SGD on validation tokens post-quantization. lr=0.002, momentum=0.9, first 2 blocks frozen. Gradients synced via all_reduce(AVG). Took 83s on 8xH100. Moved BPB from 1.1697 to 1.1466. + +**4. BigramHash 10240.** Up from 2048 in PR #315. + +**5. Memory tokens.** 64 learnable embeddings as global context. Overwritten during training (targets masked), prepended during eval (stripped after layers). 32K params. + +**6. Backout connection.** Learned scalar (init=0.2) subtracts encoder/decoder boundary state from final output. One parameter. + +**7. Per-head temperature.** Learned temperature per attention head. 88 params total. + +**8. Eval stride 32.** Down from 64. Made no difference here (s32 and s64 both gave 1.1466). + +## What we kept from PR #315 + +11 layers, U-Net skips, XSA on last 4, EMA (0.997), partial RoPE (16/64 dims), LN scale, 3x MLP relu-squared, SmearGate, ortho+muP init, Muon (0.025, 0.99, WD=0.04), NTK RoPE, seq 2048, softcap 30. + +## Results + +| Metric | Value | +|--------|-------| +| Steps | 5,129 (110ms/step, SDPA) | +| Pre-quant val_bpb | 1.1597 | +| Post-quant val_bpb | 1.1697 | +| Quant gap | +0.0100 | +| Post-TTT sliding s32 | **1.1466** | +| Artifact | 14,706,424 bytes | +| TTT time | 83s | +| Peak memory | 25,777 MiB/GPU | + +## What would help + +- FA3 (30% more training steps) +- 12th layer with the 1.3MB budget headroom +- QAT getting more than 1 step (it kicked in at step 5128, stopped at 5129) + +## How to run + +```bash +pip install huggingface-hub datasets sentencepiece tqdm zstandard +python3 data/cached_challenge_fineweb.py --variant sp1024 +torchrun --standalone --nproc_per_node=8 records/track_10min_16mb/2026-03-21_SOTA/train_gpt.py +``` + +Single seed (1337), torch 2.4.1+cu124, 8xH100 SXM on RunPod. diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/submission.json b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/submission.json new file mode 100644 index 000000000..76253e55a --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/submission.json @@ -0,0 +1,11 @@ +{ + "author": "Vytautas Bunevicius", + "github_id": "vytautas-bunevicius", + "name": "11L mixed int5/int6 + working QAT + TTT", + "blurb": "Non-record submission stacking 8 techniques on PR #315: working QAT (fixed dead code), mixed int5/int6 quantization, test-time training, BigramHash(10240), memory tokens, backout connection, per-head temperature, eval stride 32. Ran with PyTorch SDPA (no FA3), 5129 steps at 110ms/step.", + "date": "2026-03-22T12:20:51Z", + "val_loss": 1.93594245, + "val_bpb": 1.14657797, + "bytes_total": 14706424, + "bytes_code": 75257 +} diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train.log b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train.log new file mode 100644 index 000000000..ddb2b5355 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train.log @@ -0,0 +1,92 @@ +logs/d6796b3c-a185-44f5-a817-d3e0c3e09b52.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27911346 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9280 val_bpb:4.1031 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9304 train_time:154ms step_avg:153.85ms +step:2/20000 train_loss:8.4705 train_time:251ms step_avg:125.31ms +step:3/20000 train_loss:7.5179 train_time:360ms step_avg:119.94ms +step:4/20000 train_loss:8.1017 train_time:469ms step_avg:117.30ms +step:5/20000 train_loss:8.2824 train_time:579ms step_avg:115.74ms +step:6/20000 train_loss:7.9754 train_time:688ms step_avg:114.68ms +step:7/20000 train_loss:7.5070 train_time:797ms step_avg:113.87ms +step:8/20000 train_loss:7.1406 train_time:906ms step_avg:113.30ms +step:9/20000 train_loss:6.6636 train_time:1016ms step_avg:112.84ms +step:10/20000 train_loss:6.2675 train_time:1125ms step_avg:112.52ms +step:200/20000 train_loss:2.3890 train_time:22059ms step_avg:110.30ms +step:400/20000 train_loss:2.4267 train_time:44213ms step_avg:110.53ms +step:600/20000 train_loss:2.3474 train_time:66368ms step_avg:110.61ms +step:800/20000 train_loss:2.2505 train_time:88587ms step_avg:110.73ms +step:1000/20000 train_loss:2.2858 train_time:110732ms step_avg:110.73ms +step:1000/20000 val_loss:2.2363 val_bpb:1.3245 train_time:110749ms step_avg:110.75ms +step:1200/20000 train_loss:2.3632 train_time:132921ms step_avg:110.77ms +step:1400/20000 train_loss:2.1869 train_time:155095ms step_avg:110.78ms +step:1600/20000 train_loss:2.0789 train_time:177188ms step_avg:110.74ms +step:1800/20000 train_loss:2.1527 train_time:199334ms step_avg:110.74ms +step:2000/20000 train_loss:2.0611 train_time:221406ms step_avg:110.70ms +step:2000/20000 val_loss:2.1305 val_bpb:1.2618 train_time:221421ms step_avg:110.71ms +step:2200/20000 train_loss:2.1296 train_time:243546ms step_avg:110.70ms +step:2400/20000 train_loss:2.0630 train_time:265612ms step_avg:110.67ms +step:2600/20000 train_loss:2.1011 train_time:287726ms step_avg:110.66ms +step:2800/20000 train_loss:2.1464 train_time:309841ms step_avg:110.66ms +step:3000/20000 train_loss:2.1432 train_time:331881ms step_avg:110.63ms +step:3000/20000 val_loss:2.0737 val_bpb:1.2282 train_time:331898ms step_avg:110.63ms +step:3200/20000 train_loss:2.1530 train_time:353995ms step_avg:110.62ms +step:3400/20000 train_loss:1.9918 train_time:376141ms step_avg:110.63ms +step:3600/20000 train_loss:2.0610 train_time:398251ms step_avg:110.63ms +step:3800/20000 train_loss:2.0316 train_time:420294ms step_avg:110.60ms +step:4000/20000 train_loss:1.9355 train_time:442405ms step_avg:110.60ms +step:4000/20000 val_loss:2.0245 val_bpb:1.1990 train_time:442420ms step_avg:110.61ms +step:4200/20000 train_loss:2.1051 train_time:464511ms step_avg:110.60ms +step:4400/20000 train_loss:1.9798 train_time:486552ms step_avg:110.58ms +step:4600/20000 train_loss:1.7920 train_time:508661ms step_avg:110.58ms +step:4800/20000 train_loss:2.3816 train_time:530696ms step_avg:110.56ms +step:5000/20000 train_loss:2.0406 train_time:552793ms step_avg:110.56ms +step:5000/20000 val_loss:1.9650 val_bpb:1.1638 train_time:552808ms step_avg:110.56ms +late_qat:enabled step:5128 scale:0.0997 +step:5129/20000 val_loss:1.9581 val_bpb:1.1597 train_time:605055ms step_avg:117.97ms +stopping_early: wallclock_cap train_time:605055ms step:5129/20000 +peak memory allocated: 25777 MiB reserved: 26916 MiB +ema:applying EMA weights +Serialized model: 108015620 bytes +Code size: 75257 bytes +Serialized model int6+zstd: 14631167 bytes +Total submission size int6+zstd: 14706424 bytes +ttt:epoch:1/3 loss:1.9625 +ttt:epoch:2/3 loss:1.9614 +ttt:epoch:3/3 loss:1.9609 +ttt:done time:82697ms +final_int6_roundtrip val_loss:1.9749 val_bpb:1.1697 eval_time:38727ms +final_int6_roundtrip_exact val_loss:1.97493070 val_bpb:1.16966520 +final_int6_sliding_window val_loss:1.9359 val_bpb:1.1466 stride:32 eval_time:217603ms +final_int6_sliding_window_exact val_loss:1.93594245 val_bpb:1.14657797 +final_int6_sliding_window_s64 val_loss:1.9360 val_bpb:1.1466 stride:64 eval_time:109694ms +final_int6_sliding_window_s64_exact val_loss:1.93596518 val_bpb:1.14659066 diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py new file mode 100644 index 000000000..ad3205f0f --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py @@ -0,0 +1,1745 @@ +""" +train_gpt_submit.py — SOTA submission: 11L + Partial RoPE + LN Scale + XSA4 + EMA + +Mixed int5/int6 QAT + BigramHash(10240) + MemoryTokens + TTT + BackoutConnection + +Per-head temperature + stride=32 sliding window eval. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _fa3_func + _ATTN_BACKEND = "fa3" +except ImportError: + try: + from flash_attn import flash_attn_func as _fa2_func + _ATTN_BACKEND = "fa2" + except ImportError: + _ATTN_BACKEND = "sdpa" + +# HYPERPARAMETERS +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 32)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) + swa_every = int(os.environ.get("SWA_EVERY", 200)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + late_qat = bool(int(os.environ.get("LATE_QAT", "1"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # New features + num_memory_tokens = int(os.environ.get("NUM_MEMORY_TOKENS", 64)) + backout_enabled = bool(int(os.environ.get("BACKOUT_ENABLED", "1"))) + backout_lambda_init = float(os.environ.get("BACKOUT_LAMBDA_INIT", 0.2)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + +# MUON OPTIMIZER +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# POST-TRAINING QUANTIZATION +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,head_temp,skip_weight,skip_weights,smear,backout_lambda,memory_tokens", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# DATA LOADING + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +_QAT_ACTIVE = False + +def _enable_qat(module: nn.Module) -> None: + """Swap CastedLinear.forward to forward_qat with matched int5/int6 noise.""" + global _QAT_ACTIVE + _QAT_ACTIVE = True + for name, m in module.named_modules(): + if isinstance(m, CastedLinear) and m.weight.ndim == 2: + # Match QAT noise to actual export quantization scheme + if ".mlp." in name: + m._qat_clip_range = 15 # int5 for MLP + else: + m._qat_clip_range = 31 # int6 for attention + m.forward = m.forward_qat # type: ignore[assignment] + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + _qat_clip_range: int = 31 # per-instance: 15 for MLP (int5), 31 for attn (int6) + + def forward_qat(self, x: Tensor) -> Tensor: + w32 = self.weight.float() + cr = self._qat_clip_range + if w32.ndim == 2 and self.training: + with torch.no_grad(): + row_max = w32.abs().amax(dim=1) + scale = (row_max / cr).clamp_min(1.0 / cr) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -cr - 1, cr) * scale[:, None]) + w = w32 + (w_q - w32).detach() # STE: gradient flows through as if no quantization + w = w.to(x.dtype) + else: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.head_temp = nn.Parameter(torch.ones(num_heads, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + q = q * self.head_temp.to(dtype=q.dtype)[None, None, :, None] + if _ATTN_BACKEND == "fa3": + y = _fa3_func(q, k, v, causal=True) + elif _ATTN_BACKEND == "fa2": + y = _fa2_func(q, k, v, causal=True) + else: + # PyTorch SDPA: expects (B, H, T, D), manually expand KV for GQA + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + k_t = k_t.repeat_interleave(repeats, dim=1) + v_t = v_t.repeat_interleave(repeats, dim=1) + y = F.scaled_dot_product_attention( + q_t, k_t, v_t, attn_mask=None, is_causal=True, + ).transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + num_memory_tokens: int = 0, + backout_enabled: bool = False, + backout_lambda_init: float = 0.2, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.num_memory_tokens = num_memory_tokens + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + # Memory tokens: learnable global context scratchpad + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(1, num_memory_tokens, model_dim) * 0.02) + else: + self.memory_tokens = None + # Backout connection: subtract mid-layer hidden state from output + self.backout_lambda = nn.Parameter(torch.tensor(backout_lambda_init)) if backout_enabled else None + self.backout_layer = num_layers // 2 + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + bsz = input_ids.size(0) + K = self.num_memory_tokens + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + # Memory tokens: overwrite first K positions during training + if K > 0 and self.memory_tokens is not None: + mem = self.memory_tokens.expand(bsz, -1, -1).to(dtype=x.dtype) + mem = F.rms_norm(mem, (mem.size(-1),)) + x = x.clone() + x[:, :K, :] = mem + target_ids = target_ids.clone() + target_ids[:, :K] = -100 # ignore loss on memory positions + x0 = x + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + # Capture midpoint for backout connection + x_backout = x if self.backout_lambda is not None else None + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + if x_backout is not None: + x = x - self.backout_lambda.to(x.dtype) * x_backout + + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, ignore_index=-100, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, ignore_index=-100, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + bsz = input_ids.size(0) + K = self.num_memory_tokens + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + # Memory tokens: prepend during eval so all real tokens keep context + if K > 0 and self.memory_tokens is not None: + mem = self.memory_tokens.expand(bsz, -1, -1).to(dtype=x.dtype) + mem = F.rms_norm(mem, (mem.size(-1),)) + x = torch.cat([mem, x], dim=1) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + x_backout = x if self.backout_lambda is not None else None + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + if x_backout is not None: + x = x - self.backout_lambda.to(x.dtype) * x_backout + x = self.final_norm(x) + # Remove memory token positions from output + if K > 0 and self.memory_tokens is not None: + x = x[:, K:, :] + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# SLIDING WINDOW EVALUATION + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -clip_range - 1, clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range - 1, clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int56(state_dict: dict[str, Tensor], prune_frac: float = 0.03): + """Mixed int5 (MLP) / int6 (attention) quantization with magnitude pruning.""" + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Magnitude pruning: zero out smallest weights for better compression + if prune_frac > 0 and t.ndim == 2 and t.numel() > 65536: + threshold = torch.quantile(t.abs().float().flatten(), prune_frac) + t = t.clone() + t[t.abs() < threshold] = 0.0 + if cat == "mlp" and t.ndim >= 1: + # MLP weights: int5 [-16, 15] — higher zstd compression ratio + q, s = quantize_intN_per_row(t, clip_range=15) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + elif cat == "attn" and t.ndim >= 1: + # Attention weights: int6 [-32, 31] — precision-sensitive + q, s = quantize_intN_per_row(t, clip_range=31) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + # Embeddings and other: int8 + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +def ttt_adapt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: Tensor, + rank: int = 0, + world_size: int = 1, + log_fn=None, +) -> None: + """Test-time training: SGD adaptation on val data (backward-looking only).""" + seq_len = args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + batch_seqs = args.ttt_batch_seqs + + # Freeze first N blocks for stability + frozen_params = set() + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + frozen_params.add(id(p)) + + ttt_params = [p for p in model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + my_start = (total_seqs * rank) // world_size + my_end = (total_seqs * (rank + 1)) // world_size + + model.train() + for epoch in range(args.ttt_epochs): + epoch_loss = 0.0 + n_batches = 0 + for batch_start in range(my_start, my_end, batch_seqs): + batch_end = min(batch_start + batch_seqs, my_end) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + # Sync gradients across ranks + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer.step() + epoch_loss += loss.item() + n_batches += 1 + if log_fn and n_batches > 0: + log_fn(f"ttt:epoch:{epoch + 1}/{args.ttt_epochs} loss:{epoch_loss / n_batches:.4f}") + + # Unfreeze + for p in model.parameters(): + p.requires_grad_(True) + model.eval() + + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + num_memory_tokens=args.num_memory_tokens, + backout_enabled=args.backout_enabled, + backout_lambda_init=args.backout_lambda_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + torch._dynamo.config.optimize_ddp = False + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.backout_lambda is not None: + scalar_params.append(base_model.backout_lambda) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.memory_tokens is not None: + tok_params.append({"params": [base_model.memory_tokens], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not _QAT_ACTIVE: + _enable_qat(base_model) + # Recompile with QAT forward paths + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name].add_(t.detach().float()) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int56(sd_cpu) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int56.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open("final_model.int56.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + num_memory_tokens=args.num_memory_tokens, + backout_enabled=args.backout_enabled, + backout_lambda_init=args.backout_lambda_init, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # Test-Time Training: adapt quantized model on val data + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, rank, world_size, log_fn=log0) + torch.cuda.synchronize() + log0(f"ttt:done time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From ac8273eeee2e5f3cff8a8813958da6a81ccb76fd Mon Sep 17 00:00:00 2001 From: Vytautas <64431422+vytautas-bunevicius@users.noreply.github.com> Date: Sun, 22 Mar 2026 15:14:37 +0200 Subject: [PATCH 2/3] feat(records): add earlier QAT controls and experiment log --- .gitignore | 10 +- EXPERIMENT_LOG.md | 112 ++++++++++++++++++ .../README.md | 19 ++- .../train_gpt.py | 61 ++++++++-- 4 files changed, 190 insertions(+), 12 deletions(-) create mode 100644 EXPERIMENT_LOG.md diff --git a/.gitignore b/.gitignore index 3423c416a..f0ecf043c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,12 @@ data/manifest.json data/docs_selected.jsonl .mypy_cache/ .venv -logs/ \ No newline at end of file +logs/ + +# Local agent/editor artifacts +.claude/ +.codex/ +.cursor/ +.roo/ +.windsurf/ +.augment/ diff --git a/EXPERIMENT_LOG.md b/EXPERIMENT_LOG.md new file mode 100644 index 000000000..548cc4ccb --- /dev/null +++ b/EXPERIMENT_LOG.md @@ -0,0 +1,112 @@ +# Experiment Log + +This file is a short record of what has already been tried in this repo so we do not waste time repeating the same ideas without a clear change in approach. + +## Current best result + +- Branch: `submission/sota-attempt` +- Record folder: `records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466` +- Best result so far: `1.14657797 val_bpb` +- Artifact size: `14,706,424 bytes` +- Setting: `8xH100`, about `605s` train and `340s` eval +- Important note: this run used PyTorch SDPA, not FA3, so training throughput was worse than expected + +## What has already been tried + +### Baseline reproduction + +- Reproduced the public baseline path and used it as the starting point. +- Baseline leaderboard score in this repo: `1.2244 val_bpb`. +- No need to keep re-running the baseline unless checking infra or correctness. + +### 11-layer 512-dim Transformer line + +- Main working direction is an `11L`, `512d`, tied-embedding Transformer. +- Includes U-Net-style skip connections, partial RoPE, EMA, LN scale, XSA on last layers, relu-squared MLP, SmearGate, ortho+muP init, and Muon. + +### Mixed low-bit export + +- Mixed quantization is already in use and works better than a simple uniform export in this line: + - MLP weights: int5 + - Attention weights: int6 + - Embeddings: int8 +- Magnitude pruning before export is also already tried. +- Outcome: artifact dropped to `14.7MB` with useful headroom left. + +### QAT + +- QAT was already added and fixed. +- The earlier issue was that late QAT was effectively dead code under `torch.compile`. +- Current code now supports: + - `QAT_ENABLED=1` + - `QAT_START_STEP` + - `QAT_START_FRAC` +- What failed before: QAT only kicked in at the very end and got effectively `1` training step. +- The next sensible step is tuning earlier QAT, not proposing QAT from scratch again. + +### Test-time training + +- Post-quantization TTT is already implemented and clearly helps. +- Current setup: + - SGD + - 3 epochs + - `lr=0.002` + - `momentum=0.9` + - first 2 blocks frozen +- Outcome: + - post-quant roundtrip: about `1.1697` + - post-TTT sliding eval: about `1.1466` + +### Small byte-efficient additions + +- Already tried in the current best run: + - BigramHash increased to `10240` + - `64` memory tokens + - backout connection + - per-head temperature +- These are already part of the current recipe, so they are not new ideas unless the mechanism changes. + +### Eval stride + +- Sliding-window eval with stride `32` was tested against stride `64`. +- In this run, it made essentially no difference: both landed at `1.1466`. + +## Operational lessons + +### Remote failures were infra, not model bugs + +- A previous RunPod failure happened because the repo was not cloned on the pod. +- Another source of confusion was running from the wrong path. +- Before debugging model code on remote, verify: + - the repo is cloned + - the correct branch is checked out + - the target file exists + +### Attention backend matters a lot + +- The strong run used SDPA and got about `110ms/step`. +- FA3 is expected to be materially faster and is a high-priority next run. +- Current code supports `ATTN_BACKEND=auto|fa3|fa2|sdpa`. +- If FA3 is required, use `ATTN_BACKEND=fa3` so the run fails fast instead of silently falling back. + +## Things that likely deserve the next experiments + +- Earlier QAT that runs for a meaningful chunk of training +- FA3 on Hopper to buy more steps under the 10-minute budget +- Spending the remaining artifact headroom on a robust capacity increase: + - likely a 12th layer + - or a parameter-shared / recurrent refinement step +- Tighter TTT tuning instead of inventing a completely new eval trick + +## Papers already tied to this line + +- QQQ: `https://arxiv.org/abs/2406.09904` +- BitNet b1.58: `https://arxiv.org/abs/2402.17764` +- TTT: `https://arxiv.org/abs/2407.04620` +- FlashAttention-3: `https://tridao.me/blog/2024/flash3/` + +## Read this first + +- `README.md` +- `records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md` +- `records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py` diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md index bd2faaf29..3ddb236d1 100644 --- a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md @@ -6,7 +6,7 @@ Built on PR #315 (1.1248). Ran with PyTorch SDPA instead of FA3, so throughput w ## What we added to PR #315 -**1. Working QAT.** PR #315's late QAT is dead code because `torch.compile` constant-folds `CastedLinear._qat_enabled` at first trace. We swap the `forward` method to `forward_qat` per instance and recompile. QAT noise matches the export scheme: int5 STE for MLP, int6 STE for attention. +**1. Working QAT.** PR #315's late QAT is dead code because `torch.compile` constant-folds `CastedLinear._qat_enabled` at first trace. We swap the `forward` method to `forward_qat` per instance and recompile. QAT noise matches the export scheme: int5 STE for MLP, int6 STE for attention. The current script also exposes `QAT_ENABLED`, `QAT_START_STEP`, and `QAT_START_FRAC` so we can turn QAT on earlier instead of hoping it only catches the last few steps. **2. Mixed int5/int6 quantization + magnitude pruning.** MLP weights get int5 ([-16, 15]), attention gets int6 ([-32, 31]), embeddings stay int8. 3% magnitude pruning before quantization. Result: 14.7MB with 1.3MB headroom. @@ -43,14 +43,27 @@ Built on PR #315 (1.1248). Ran with PyTorch SDPA instead of FA3, so throughput w - FA3 (30% more training steps) - 12th layer with the 1.3MB budget headroom -- QAT getting more than 1 step (it kicked in at step 5128, stopped at 5129) +- Earlier QAT so it gets hundreds to thousands of steps instead of 1 + +## Papers behind these ideas + +- Low-bit quantization direction: [QQQ: Quality Quattuor-Bit Quantization for Large Language Models](https://arxiv.org/abs/2406.09904) +- Very low-bit training motivation: [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764) +- Test-time training: [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620) +- Faster Hopper attention kernels: [FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision](https://tridao.me/blog/2024/flash3/) ## How to run ```bash +cd /workspace +git clone -b submission/sota-attempt parameter-golf +cd parameter-golf pip install huggingface-hub datasets sentencepiece tqdm zstandard python3 data/cached_challenge_fineweb.py --variant sp1024 -torchrun --standalone --nproc_per_node=8 records/track_10min_16mb/2026-03-21_SOTA/train_gpt.py +ATTN_BACKEND=auto QAT_START_FRAC=0.8 \ +torchrun --standalone --nproc_per_node=8 records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py ``` +If FA3 is installed, set `ATTN_BACKEND=fa3` to fail fast when the kernel is missing instead of silently falling back. + Single seed (1337), torch 2.4.1+cu124, 8xH100 SXM on RunPod. diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py index ad3205f0f..7a22dfaf2 100644 --- a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py @@ -33,15 +33,29 @@ from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP -try: - from flash_attn_interface import flash_attn_func as _fa3_func - _ATTN_BACKEND = "fa3" -except ImportError: +_REQUESTED_ATTN_BACKEND = os.environ.get("ATTN_BACKEND", "auto").strip().lower() +if _REQUESTED_ATTN_BACKEND not in {"auto", "fa3", "fa2", "sdpa"}: + raise ValueError("ATTN_BACKEND must be one of: auto, fa3, fa2, sdpa") + +_fa3_func = None +_fa2_func = None +_ATTN_BACKEND = "sdpa" +if _REQUESTED_ATTN_BACKEND in {"auto", "fa3"}: + try: + from flash_attn_interface import flash_attn_func as _fa3_func + + _ATTN_BACKEND = "fa3" + except ImportError: + if _REQUESTED_ATTN_BACKEND == "fa3": + raise +if _ATTN_BACKEND == "sdpa" and _REQUESTED_ATTN_BACKEND in {"auto", "fa2"}: try: from flash_attn import flash_attn_func as _fa2_func + _ATTN_BACKEND = "fa2" except ImportError: - _ATTN_BACKEND = "sdpa" + if _REQUESTED_ATTN_BACKEND == "fa2": + raise # HYPERPARAMETERS # Default Simple Baseline run: @@ -114,6 +128,9 @@ class Hyperparameters: rope_dims = int(os.environ.get("ROPE_DIMS", 16)) ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) late_qat = bool(int(os.environ.get("LATE_QAT", "1"))) + qat_start_step = int(os.environ.get("QAT_START_STEP", "0")) + qat_start_frac = float(os.environ.get("QAT_START_FRAC", "0.8")) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) @@ -554,6 +571,16 @@ def _enable_qat(module: nn.Module) -> None: m._qat_clip_range = 31 # int6 for attention m.forward = m.forward_qat # type: ignore[assignment] + +def _resolve_qat_start_step(args: Hyperparameters) -> int | None: + if args.qat_enabled: + return 0 + if args.qat_start_step > 0: + return args.qat_start_step + if args.qat_start_frac > 0: + return min(max(int(args.iterations * args.qat_start_frac), 0), args.iterations) + return None + class CastedLinear(nn.Linear): def forward(self, x: Tensor) -> Tensor: w = self.weight.to(x.dtype) @@ -1337,10 +1364,13 @@ def log0(msg: str, console: bool = True) -> None: for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() + if args.qat_enabled: + _enable_qat(base_model) restore_low_dim_params_to_fp32(base_model) torch._dynamo.config.optimize_ddp = False compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + qat_start_step = _resolve_qat_start_step(args) # Optimizer split: # - token embedding (Adam) uses EMBED_LR @@ -1413,8 +1443,13 @@ def log0(msg: str, console: bool = True) -> None: log0(f"model_params:{n_params}") log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attn_backend:requested={_REQUESTED_ATTN_BACKEND} active={_ATTN_BACKEND}") log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"qat:active={_QAT_ACTIVE} late_qat:{args.late_qat} " + f"start_step:{qat_start_step} start_frac:{args.qat_start_frac:.3f} threshold:{args.qat_threshold:.4f}" + ) log0( f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " @@ -1527,13 +1562,23 @@ def lr_mul(step: int, elapsed_ms: float) -> float: elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) scale = lr_mul(step, elapsed_ms) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) - if args.late_qat and scale < qat_threshold and not _QAT_ACTIVE: + should_enable_late_qat = ( + args.late_qat + and not _QAT_ACTIVE + and ( + (qat_start_step is not None and step >= qat_start_step) + or scale < args.qat_threshold + ) + ) + if should_enable_late_qat: _enable_qat(base_model) # Recompile with QAT forward paths compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + log0( + f"late_qat:enabled step:{step} scale:{scale:.4f} " + f"start_step:{qat_start_step} threshold:{args.qat_threshold:.4f}" + ) zero_grad_all() train_loss = torch.zeros((), device=device) for micro_step in range(grad_accum_steps): From 75168efc0eed6e97de0dc6dba361d8d06766748a Mon Sep 17 00:00:00 2001 From: Vytautas <64431422+vytautas-bunevicius@users.noreply.github.com> Date: Sun, 22 Mar 2026 15:20:12 +0200 Subject: [PATCH 3/3] feat(records): add earlier QAT controls and causal TTT eval --- .gitignore | 10 +- EXPERIMENT_LOG.md | 112 -------- .../README.md | 9 +- .../train_gpt.py | 261 +++++++++++++++++- 4 files changed, 255 insertions(+), 137 deletions(-) delete mode 100644 EXPERIMENT_LOG.md diff --git a/.gitignore b/.gitignore index f0ecf043c..3423c416a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,12 +8,4 @@ data/manifest.json data/docs_selected.jsonl .mypy_cache/ .venv -logs/ - -# Local agent/editor artifacts -.claude/ -.codex/ -.cursor/ -.roo/ -.windsurf/ -.augment/ +logs/ \ No newline at end of file diff --git a/EXPERIMENT_LOG.md b/EXPERIMENT_LOG.md deleted file mode 100644 index 548cc4ccb..000000000 --- a/EXPERIMENT_LOG.md +++ /dev/null @@ -1,112 +0,0 @@ -# Experiment Log - -This file is a short record of what has already been tried in this repo so we do not waste time repeating the same ideas without a clear change in approach. - -## Current best result - -- Branch: `submission/sota-attempt` -- Record folder: `records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466` -- Best result so far: `1.14657797 val_bpb` -- Artifact size: `14,706,424 bytes` -- Setting: `8xH100`, about `605s` train and `340s` eval -- Important note: this run used PyTorch SDPA, not FA3, so training throughput was worse than expected - -## What has already been tried - -### Baseline reproduction - -- Reproduced the public baseline path and used it as the starting point. -- Baseline leaderboard score in this repo: `1.2244 val_bpb`. -- No need to keep re-running the baseline unless checking infra or correctness. - -### 11-layer 512-dim Transformer line - -- Main working direction is an `11L`, `512d`, tied-embedding Transformer. -- Includes U-Net-style skip connections, partial RoPE, EMA, LN scale, XSA on last layers, relu-squared MLP, SmearGate, ortho+muP init, and Muon. - -### Mixed low-bit export - -- Mixed quantization is already in use and works better than a simple uniform export in this line: - - MLP weights: int5 - - Attention weights: int6 - - Embeddings: int8 -- Magnitude pruning before export is also already tried. -- Outcome: artifact dropped to `14.7MB` with useful headroom left. - -### QAT - -- QAT was already added and fixed. -- The earlier issue was that late QAT was effectively dead code under `torch.compile`. -- Current code now supports: - - `QAT_ENABLED=1` - - `QAT_START_STEP` - - `QAT_START_FRAC` -- What failed before: QAT only kicked in at the very end and got effectively `1` training step. -- The next sensible step is tuning earlier QAT, not proposing QAT from scratch again. - -### Test-time training - -- Post-quantization TTT is already implemented and clearly helps. -- Current setup: - - SGD - - 3 epochs - - `lr=0.002` - - `momentum=0.9` - - first 2 blocks frozen -- Outcome: - - post-quant roundtrip: about `1.1697` - - post-TTT sliding eval: about `1.1466` - -### Small byte-efficient additions - -- Already tried in the current best run: - - BigramHash increased to `10240` - - `64` memory tokens - - backout connection - - per-head temperature -- These are already part of the current recipe, so they are not new ideas unless the mechanism changes. - -### Eval stride - -- Sliding-window eval with stride `32` was tested against stride `64`. -- In this run, it made essentially no difference: both landed at `1.1466`. - -## Operational lessons - -### Remote failures were infra, not model bugs - -- A previous RunPod failure happened because the repo was not cloned on the pod. -- Another source of confusion was running from the wrong path. -- Before debugging model code on remote, verify: - - the repo is cloned - - the correct branch is checked out - - the target file exists - -### Attention backend matters a lot - -- The strong run used SDPA and got about `110ms/step`. -- FA3 is expected to be materially faster and is a high-priority next run. -- Current code supports `ATTN_BACKEND=auto|fa3|fa2|sdpa`. -- If FA3 is required, use `ATTN_BACKEND=fa3` so the run fails fast instead of silently falling back. - -## Things that likely deserve the next experiments - -- Earlier QAT that runs for a meaningful chunk of training -- FA3 on Hopper to buy more steps under the 10-minute budget -- Spending the remaining artifact headroom on a robust capacity increase: - - likely a 12th layer - - or a parameter-shared / recurrent refinement step -- Tighter TTT tuning instead of inventing a completely new eval trick - -## Papers already tied to this line - -- QQQ: `https://arxiv.org/abs/2406.09904` -- BitNet b1.58: `https://arxiv.org/abs/2402.17764` -- TTT: `https://arxiv.org/abs/2407.04620` -- FlashAttention-3: `https://tridao.me/blog/2024/flash3/` - -## Read this first - -- `README.md` -- `records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md` -- `records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py` diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md index 3ddb236d1..4054375c9 100644 --- a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/README.md @@ -1,16 +1,18 @@ # Non-record submission: 11L mixed int5/int6 + working QAT + TTT + 8 additions -**val_bpb = 1.1466** (sliding window, stride=32, post-TTT) | **14.7 MB** artifact | 8xH100 SXM, 605s train + 340s eval +**Historical run:** `1.1466 val_bpb` (sliding window, stride=32, original post-TTT flow) | **14.7 MB** artifact | 8xH100 SXM, 605s train + 340s eval Built on PR #315 (1.1248). Ran with PyTorch SDPA instead of FA3, so throughput was 110ms/step instead of 85ms. Got 5,129 steps instead of ~7,000. Score should drop with FA3. +Note: the historical `1.1466` number above came from the original pre-eval TTT flow in this run. The current script has been updated to report plain no-TTT metrics and causal TTT metrics separately so future runs do not adapt on unseen eval tokens before scoring them. That means the checked-in script should be rerun before using it for a fresh official score claim. + ## What we added to PR #315 **1. Working QAT.** PR #315's late QAT is dead code because `torch.compile` constant-folds `CastedLinear._qat_enabled` at first trace. We swap the `forward` method to `forward_qat` per instance and recompile. QAT noise matches the export scheme: int5 STE for MLP, int6 STE for attention. The current script also exposes `QAT_ENABLED`, `QAT_START_STEP`, and `QAT_START_FRAC` so we can turn QAT on earlier instead of hoping it only catches the last few steps. **2. Mixed int5/int6 quantization + magnitude pruning.** MLP weights get int5 ([-16, 15]), attention gets int6 ([-32, 31]), embeddings stay int8. 3% magnitude pruning before quantization. Result: 14.7MB with 1.3MB headroom. -**3. Test-time training.** 3 epochs of SGD on validation tokens post-quantization. lr=0.002, momentum=0.9, first 2 blocks frozen. Gradients synced via all_reduce(AVG). Took 83s on 8xH100. Moved BPB from 1.1697 to 1.1466. +**3. Test-time training.** This run originally used post-quantization SGD on validation tokens before final scoring. The script now also includes a causal TTT path that scores each eval chunk first and only then adapts on that chunk, which is the safer version for future experiments. **4. BigramHash 10240.** Up from 2048 in PR #315. @@ -34,7 +36,8 @@ Built on PR #315 (1.1248). Ran with PyTorch SDPA instead of FA3, so throughput w | Pre-quant val_bpb | 1.1597 | | Post-quant val_bpb | 1.1697 | | Quant gap | +0.0100 | -| Post-TTT sliding s32 | **1.1466** | +| Historical post-TTT sliding s32 | **1.1466** | +| Historical no-TTT roundtrip | 1.1697 | | Artifact | 14,706,424 bytes | | TTT time | 83s | | Peak memory | 25,777 MiB/GPU | diff --git a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py index 7a22dfaf2..b6bb1ed7a 100644 --- a/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py +++ b/records/track_10min_16mb/2026-03-22_11L_MixedInt56_QAT_TTT_1.1466/train_gpt.py @@ -139,11 +139,11 @@ class Hyperparameters: backout_enabled = bool(int(os.environ.get("BACKOUT_ENABLED", "1"))) backout_lambda_init = float(os.environ.get("BACKOUT_LAMBDA_INIT", 0.2)) ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) - ttt_lr = float(os.environ.get("TTT_LR", 0.002)) - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_lr = float(os.environ.get("TTT_LR", 0.008)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 20)) ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) - ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) # MUON OPTIMIZER # @@ -1100,6 +1100,198 @@ def eval_val_sliding( return val_loss, bits_per_token * tokens_per_byte +def _broadcast_eval_pair(rank: int, device: torch.device, value: tuple[float, float]) -> tuple[float, float]: + pair = torch.tensor(value if rank == 0 else (0.0, 0.0), device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.broadcast(pair, src=0) + return float(pair[0].item()), float(pair[1].item()) + + +def _ttt_masked_targets(y: Tensor, spans: list[tuple[int, int]]) -> Tensor: + targets = torch.full_like(y, -100) + for i, (start, end) in enumerate(spans): + if end > start: + targets[i, start:end] = y[i, start:end] + return targets + + +def _ttt_sync_grads(params: list[Tensor]) -> None: + if dist.is_available() and dist.is_initialized(): + for p in params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + +def _ttt_freeze_for_eval(args: Hyperparameters, model: nn.Module) -> list[Tensor]: + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + return [p for p in model.parameters() if p.requires_grad] + + +def eval_val_causal_ttt( + args: Hyperparameters, + base_model: nn.Module | None, + rank: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + if rank != 0: + return _broadcast_eval_pair(rank, device, (0.0, 0.0)) + if base_model is None: + raise ValueError("base_model is required on rank 0 for causal TTT eval") + + seq_len = eval_seq_len or args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + batch_seqs = max(args.ttt_batch_seqs, 1) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + ttt_params = _ttt_freeze_for_eval(args, base_model) + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + for batch_seq_start in range(0, total_seqs, batch_seqs): + batch_seq_end = min(batch_seq_start + batch_seqs, total_seqs) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + base_model.eval() + with torch.inference_mode(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = compiled_logits(x) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ).reshape_as(y) + loss_sum += nll.to(torch.float64).sum() + token_count += float(y.numel()) + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + byte_count += token_bytes.to(torch.float64).sum() + + base_model.train() + for _ in range(args.ttt_epochs): + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = base_model(x, y) + loss.backward() + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + if log_fn is not None: + log_fn("ttt:mode:causal_nonoverlap") + return _broadcast_eval_pair(rank, device, (val_loss, bits_per_token * tokens_per_byte)) + + +def eval_val_sliding_causal_ttt( + args: Hyperparameters, + base_model: nn.Module | None, + rank: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + eval_seq_len: int | None = None, + log_fn=None, +) -> tuple[float, float]: + if rank != 0: + return _broadcast_eval_pair(rank, device, (0.0, 0.0)) + if base_model is None: + raise ValueError("base_model is required on rank 0 for causal sliding TTT eval") + + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + batch_seqs = max(args.ttt_batch_seqs, 1) + window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= 1] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + ttt_params = _ttt_freeze_for_eval(args, base_model) + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + for bi in range(0, len(window_starts), batch_seqs): + batch_ws = window_starts[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + score_spans: list[tuple[int, int]] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + score_start = 0 if ws == 0 else max(wlen - stride, 0) + score_spans.append((score_start, wlen)) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + base_model.eval() + with torch.inference_mode(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, (score_start, wlen) in enumerate(score_spans): + scored_nll = nll[i, score_start:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - score_start) + tgt = y_batch[i, score_start:wlen] + prev = x_batch[i, score_start:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + base_model.train() + adapt_targets = _ttt_masked_targets(y_batch, score_spans) + for _ in range(args.ttt_epochs): + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = base_model(x_batch, adapt_targets) + loss.backward() + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + if log_fn is not None: + log_fn("ttt:mode:causal_sliding") + return _broadcast_eval_pair(rank, device, (val_loss, bits_per_token * tokens_per_byte)) + + # INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) def _classify_param(name: str) -> str: @@ -1722,17 +1914,9 @@ def lr_mul(step: int, elapsed_ms: float) -> float: restore_low_dim_params_to_fp32(eval_model) eval_model.load_state_dict(deq_state, strict=True) - # Test-Time Training: adapt quantized model on val data - if args.ttt_enabled: - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_adapt(args, eval_model, device, val_tokens, rank, world_size, log_fn=log0) - torch.cuda.synchronize() - log0(f"ttt:done time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - # Standard non-overlapping eval (sanity check) + # Standard non-overlapping eval without TTT (sanity check / leakage-free baseline) torch.cuda.synchronize() t_qeval = time.perf_counter() q_val_loss, q_val_bpb = eval_val( @@ -1747,7 +1931,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - # Sliding window eval (submission score) + # Sliding window eval without TTT sw_seq_len = effective_eval_seq_len if args.eval_stride > 0 and args.eval_stride < sw_seq_len: torch.cuda.synchronize() @@ -1782,6 +1966,57 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Causal test-time training: score each chunk first, then adapt on that chunk. + if args.ttt_enabled: + ttt_model = copy.deepcopy(eval_model) if rank == 0 else None + torch.cuda.synchronize() + t_ttt = time.perf_counter() + q_ttt_val_loss, q_ttt_val_bpb = eval_val_causal_ttt( + args, + ttt_model, + rank, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"final_int6_causal_ttt_roundtrip val_loss:{q_ttt_val_loss:.4f} val_bpb:{q_ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_int6_causal_ttt_roundtrip_exact val_loss:{q_ttt_val_loss:.8f} val_bpb:{q_ttt_val_bpb:.8f}") + + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + ttt_model_slide = copy.deepcopy(eval_model) if rank == 0 else None + torch.cuda.synchronize() + t_ttt_slide = time.perf_counter() + sw_ttt_val_loss, sw_ttt_val_bpb = eval_val_sliding_causal_ttt( + args, + ttt_model_slide, + rank, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + log_fn=log0, + ) + torch.cuda.synchronize() + log0( + f"final_int6_causal_ttt_sliding_window val_loss:{sw_ttt_val_loss:.4f} val_bpb:{sw_ttt_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_ttt_slide):.0f}ms" + ) + log0( + f"final_int6_causal_ttt_sliding_window_exact val_loss:{sw_ttt_val_loss:.8f} " + f"val_bpb:{sw_ttt_val_bpb:.8f}" + ) + if distributed: dist.destroy_process_group()