From 2afa1312eb2b892162e0e8899d826e8fa661ff91 Mon Sep 17 00:00:00 2001 From: Alberto Date: Sun, 22 Mar 2026 19:27:32 +0100 Subject: [PATCH] Record: 11L GradQuant + EMA + Sliding Eval (val_bpb=1.1396) 11 layers, 512 dim, MLP 3x with gradient-guided adaptive quantization (int5/int6/int7 per-tensor based on gradient sensitivity), EMA from init (alpha=0.997), SmearGate, NTK-RoPE, XSA last 4 layers, adaptive warmdown, sliding window eval stride=64 with full validation coverage. Reproduced on fresh 8xH100 SXM pod: 5347 steps, 15.9MB artifact. Co-Authored-By: Claude Opus 4.6 --- .../README.md | 54 + .../submission.json | 12 + .../train.log | 108 ++ .../train_gpt.py | 1305 +++++++++++++++++ 4 files changed, 1479 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/README.md create mode 100644 records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/submission.json create mode 100644 records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train.log create mode 100644 records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py diff --git a/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/README.md b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/README.md new file mode 100644 index 000000000..09c95cb2b --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/README.md @@ -0,0 +1,54 @@ +# 11L Gradient-Guided Quant + EMA + Sliding Eval + +**val_bpb: 1.1396** (post int8+zstd quantization roundtrip, sliding window eval stride=64, full validation coverage) + +## Run Command + +```bash +# Setup +pip install sentencepiece zstandard +python3 data/cached_challenge_fineweb.py + +# Train + evaluate (defaults baked into train_gpt.py) +torchrun --nproc_per_node=8 records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py +``` + +All hyperparameters are set as defaults in `train_gpt.py`. No env vars needed. + +## Architecture + +- 11 layers, 512 dim, 8 heads, 4 KV heads (GQA) +- MLP 3x expansion (hidden=1536), relu^2 activation +- Encoder-decoder skip connections with learned weights +- SmearGate residual mixing +- NTK-aware RoPE positional encoding +- XSA (cross-sequence attention) on last 4 layers +- Orthogonal initialization +- Tied input/output embeddings + +## Training + +- Muon optimizer: matrix_lr=0.025, scalar_lr=0.025, momentum=0.99 +- AdamW for embeddings/scalars: WD=0.04 +- Momentum warmup: 0.92 -> 0.99 over 1500 steps +- Adaptive warmdown: 3000 iters (auto-capped to 55% of total steps for hardware robustness) +- Warmup: 20 steps +- seq_len=2048, batch=786K tokens +- grad_clip=0.3 +- EMA: alpha=0.997, initialized from model init + +## Compression + +- Gradient-guided adaptive quantization: per-tensor bit assignment based on gradient sensitivity + - Top 45% (highest gradient): int7 (127 values) + - Middle 40%: int6 (63 values) + - Bottom 15% (lowest gradient): int5 (31 values) +- zstd level 22 compression +- Artifact: 15,913,419 bytes (code: ~59KB, model: ~15.9MB) + +## Evaluation + +- Sliding window eval with stride=64, full validation set coverage (~121K windows/GPU) +- Per-token loss scoring: only the last 64 tokens of each 2048-token window are scored (full context) +- Post-quantization roundtrip: quantize -> decompress -> evaluate +- Eval time: ~6 min (runs after training completes) diff --git a/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/submission.json b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/submission.json new file mode 100644 index 000000000..5ed5d3048 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/submission.json @@ -0,0 +1,12 @@ +{ + "author": "Alberto Luengo", + "github_id": "albertorkive", + "name": "11L Gradient-Guided Quant + EMA + Sliding Eval", + "blurb": "11 layers, 512 dim, MLP 3x, gradient-guided adaptive int5/int6/int7 quantization, EMA (alpha=0.997, from init), SmearGate, NTK-RoPE, XSA (last 4 layers), zstd-22 compression, sliding window eval stride=64. Muon optimizer with momentum warmup, orthogonal init, adaptive warmdown.", + "date": "2026-03-22T17:17:00Z", + "val_loss": 1.92417619, + "val_bpb": 1.13960858, + "bytes_total": 15913419, + "bytes_code": 59010, + "bytes_model_int8_zstd": 15854409 +} diff --git a/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train.log b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train.log new file mode 100644 index 000000000..f87376b10 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train.log @@ -0,0 +1,108 @@ +W0322 17:17:48.534000 125682743755392 torch/distributed/run.py:779] +W0322 17:17:48.534000 125682743755392 torch/distributed/run.py:779] ***************************************** +W0322 17:17:48.534000 125682743755392 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0322 17:17:48.534000 125682743755392 torch/distributed/run.py:779] ***************************************** +logs/ccc19fc1-632c-4c37-83d3-75602d4e859d.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26502232 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +ema:initialized from model init (alpha=0.997) +step:0/20000 val_loss:6.9321 val_bpb:4.1056 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9309 train_time:157ms step_avg:157.17ms +step:2/20000 train_loss:8.1553 train_time:259ms step_avg:129.68ms +step:3/20000 train_loss:8.4565 train_time:365ms step_avg:121.76ms +step:4/20000 train_loss:7.4809 train_time:471ms step_avg:117.77ms +step:5/20000 train_loss:6.7898 train_time:577ms step_avg:115.37ms +step:6/20000 train_loss:6.3670 train_time:683ms step_avg:113.79ms +step:7/20000 train_loss:6.2948 train_time:788ms step_avg:112.63ms +step:8/20000 train_loss:6.2930 train_time:894ms step_avg:111.73ms +step:9/20000 train_loss:6.0769 train_time:999ms step_avg:111.05ms +step:10/20000 train_loss:5.9117 train_time:1106ms step_avg:110.56ms +step:200/20000 train_loss:2.4001 train_time:21628ms step_avg:108.14ms +step:400/20000 train_loss:2.4496 train_time:43541ms step_avg:108.85ms +step:600/20000 train_loss:2.3498 train_time:65273ms step_avg:108.79ms +step:800/20000 train_loss:2.2478 train_time:87264ms step_avg:109.08ms +step:1000/20000 train_loss:2.2832 train_time:109011ms step_avg:109.01ms +step:1000/20000 val_loss:2.1985 val_bpb:1.3021 train_time:109016ms step_avg:109.02ms +step:1200/20000 train_loss:2.3597 train_time:130809ms step_avg:109.01ms +step:1400/20000 train_loss:2.1911 train_time:152694ms step_avg:109.07ms +step:1600/20000 train_loss:2.0803 train_time:174403ms step_avg:109.00ms +step:1800/20000 train_loss:2.1607 train_time:196374ms step_avg:109.10ms +step:2000/20000 train_loss:2.0712 train_time:218101ms step_avg:109.05ms +step:2000/20000 val_loss:2.0976 val_bpb:1.2423 train_time:218106ms step_avg:109.05ms +step:2200/20000 train_loss:2.1394 train_time:239841ms step_avg:109.02ms +step:2400/20000 train_loss:2.0699 train_time:261442ms step_avg:108.93ms +grad_guided_quant: started accumulating (79 tensors) +step:2600/20000 train_loss:2.1105 train_time:283954ms step_avg:109.21ms +step:2800/20000 train_loss:2.1478 train_time:307106ms step_avg:109.68ms +step:3000/20000 train_loss:2.1486 train_time:329999ms step_avg:110.00ms +step:3000/20000 val_loss:2.0416 val_bpb:1.2091 train_time:329999ms step_avg:110.00ms +step:3200/20000 train_loss:2.1548 train_time:353047ms step_avg:110.33ms +step:3400/20000 train_loss:1.9952 train_time:375958ms step_avg:110.58ms +step:3600/20000 train_loss:2.0697 train_time:399077ms step_avg:110.85ms +step:3800/20000 train_loss:2.0410 train_time:421956ms step_avg:111.04ms +step:4000/20000 train_loss:1.9437 train_time:445096ms step_avg:111.27ms +step:4000/20000 val_loss:1.9923 val_bpb:1.1799 train_time:445096ms step_avg:111.27ms +step:4200/20000 train_loss:2.1129 train_time:468136ms step_avg:111.46ms +step:4400/20000 train_loss:1.9926 train_time:490961ms step_avg:111.58ms +step:4600/20000 train_loss:1.7964 train_time:514095ms step_avg:111.76ms +step:4800/20000 train_loss:2.3775 train_time:537003ms step_avg:111.88ms +step:5000/20000 train_loss:2.0476 train_time:560152ms step_avg:112.03ms +step:5000/20000 val_loss:1.9315 val_bpb:1.1439 train_time:560153ms step_avg:112.03ms +step:5200/20000 train_loss:1.9847 train_time:582960ms step_avg:112.11ms +step:5347/20000 val_loss:1.9145 val_bpb:1.1339 train_time:600060ms step_avg:112.22ms +stopping_early: wallclock_cap train_time:600060ms step:5347/20000 +peak memory allocated: 28149 MiB reserved: 28408 MiB +ema:loading shadow weights (alpha=0.997) +Serialized model: 105001444 bytes +Code size: 59010 bytes +Total submission size: 105060454 bytes +grad_guided_quant: 79 tensors assigned adaptive bits + bit distribution: {5: 12, 6: 31, 7: 36} +Serialized model int8+zstd: 15854409 bytes (payload:26659168 raw_torch:26714058 payload_ratio:3.94x) +Total submission size int8+zstd: 15913419 bytes +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +/workspace/parameter-golf/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py:1285: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") +final_int8_zlib_roundtrip val_loss:1.9242 val_bpb:1.1396 eval_time:374873ms +final_int8_zlib_roundtrip_exact val_loss:1.92417619 val_bpb:1.13960858 diff --git a/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py new file mode 100644 index 000000000..c2acceff0 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_11L_GradQuant_EMA_SlidingEval/train_gpt.py @@ -0,0 +1,1305 @@ +""" +train_gpt.py — Parameter Golf competition entry. +Hard stop: this file must never exceed 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# --------------------- +# HYPERPARAMETERS +# --------------------- + +class Hyperparameters: + data_path = "./data/datasets/fineweb10B_sp1024" + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = "./data/tokenizers/fineweb_1024_bpe.model" + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", "1024")) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + window_size = int(os.environ.get("WINDOW_SIZE", 0)) + num_physical_layers = int(os.environ.get("NUM_PHYSICAL_LAYERS", 0)) + + init_type = os.environ.get("INIT_TYPE", "ortho") + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + data_curriculum = os.environ.get("DATA_CURRICULUM", "") + + compress_algo = os.environ.get("COMPRESS_ALGO", "zstd") + quant_bits_middle = int(os.environ.get("QUANT_BITS_MIDDLE", "6")) + quant_bits_mlp = int(os.environ.get("QUANT_BITS_MLP", "0")) + optimizer_variant = os.environ.get("OPTIMIZER_VARIANT", "standard") + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "1"))) + use_ema = bool(int(os.environ.get("USE_EMA", "1"))) + ema_alpha = float(os.environ.get("EMA_ALPHA", "0.997")) + ema_start_ratio = float(os.environ.get("EMA_START_RATIO", "0.6")) + ema_from_init = bool(int(os.environ.get("EMA_FROM_INIT", "1"))) + use_ntk_rope = bool(int(os.environ.get("USE_NTK_ROPE", "1"))) + xsa_last_n = int(os.environ.get("XSA_LAST_N", "4")) + fp16_embeddings = bool(int(os.environ.get("FP16_EMBEDDINGS", "0"))) + grad_guided_quant = bool(int(os.environ.get("GRAD_GUIDED_QUANT", "1"))) + seq_curriculum = os.environ.get("SEQ_CURRICULUM", "") + + skip_quant = bool(int(os.environ.get("SKIP_QUANT", "0"))) + dev_mode = bool(int(os.environ.get("DEV_MODE", "0"))) + skip_compile = bool(int(os.environ.get("SKIP_COMPILE", "0"))) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + adam_weight_decay = float(os.environ.get("ADAM_WD", 0.04)) + +# --------------------- +# MUON OPTIMIZER +# --------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, variant: str = "standard", weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self.variant = variant + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if self.variant == "normuon": + g = g / g.norm(dim=0, keepdim=True).clamp(min=1e-8) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + weight_decay = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if weight_decay > 0: + p.mul_(1 - lr * weight_decay) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# --------------------- +# TOKENIZER EVALUATION +# --------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + _accum_dtype = torch.float32 if device.type == "mps" else torch.float64 + val_loss_sum = torch.zeros((), device=device, dtype=_accum_dtype) + val_token_count = torch.zeros((), device=device, dtype=_accum_dtype) + val_byte_count = torch.zeros((), device=device, dtype=_accum_dtype) + + model.eval() + _stride = args.eval_stride + with torch.inference_mode(): + if _stride > 0: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() + _max_slide_iters = int(os.environ.get("MAX_SLIDE_ITERS", "999999")) + _slide_count = 0 + for pos in range(seq_start * seq_len, min(seq_end * seq_len, total_tokens - seq_len - 1), _stride): + if _slide_count >= _max_slide_iters: + break + _slide_count += 1 + end = min(pos + seq_len + 1, total_tokens) + local = val_tokens[pos:end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].unsqueeze(0) + y = local[1:].unsqueeze(0) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + per_token_loss = model(x, y, reduction="none").detach() + score_count = min(_stride, x.size(1)) + # Only score the LAST stride tokens (they have full context) + scored_loss = per_token_loss[-score_count:] + val_loss_sum += scored_loss.to(_accum_dtype).sum() + val_token_count += score_count + scored_prev = x[0, -score_count:].reshape(-1) + scored_tgt = y[0, -score_count:].reshape(-1) + tb = base_bytes_lut[scored_tgt].to(dtype=torch.int16) + tb += (has_leading_space_lut[scored_tgt] & ~is_boundary_token_lut[scored_prev]).to(dtype=torch.int16) + val_byte_count += tb.to(_accum_dtype).sum() + else: + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(_accum_dtype) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(_accum_dtype).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --------------------- +# POST-TRAINING QUANTIZATION +# --------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = (1 << (bits - 1)) - 1 + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +def _is_middle_layer(name: str, max_block_idx: int) -> bool: + import re as _re + m = _re.match(r"blocks\.(\d+)\.", name) + if not m or max_block_idx <= 1: + return False + idx = int(m.group(1)) + return 0 < idx < max_block_idx + +def _is_mlp_layer(name: str) -> bool: + return ".mlp." in name + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], quant_bits_middle: int = 8, + quant_bits_mlp: int = 0, fp16_embeddings: bool = False, + grad_bit_map: dict[str, int] | None = None): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", + "baseline_tensor_bytes", "int8_payload_bytes"), 0, + ) + import re as _re + _block_indices = [int(_re.match(r"blocks\.(\d+)\.", k).group(1)) for k in state_dict if _re.match(r"blocks\.(\d+)\.", k)] + _max_block_idx = max(_block_indices) if _block_indices else 0 + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if fp16_embeddings and ("tok_emb" in name or "lm_head" in name): + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + if grad_bit_map is not None and name in grad_bit_map: + bits = grad_bit_map[name] + elif _is_middle_layer(name, _max_block_idx): + if quant_bits_mlp > 0 and _is_mlp_layer(name): + bits = quant_bits_mlp + elif quant_bits_middle < 8: + bits = quant_bits_middle + else: + bits = 8 + else: + bits = 8 + q, s = quantize_float_tensor(t, bits=bits) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def compute_grad_bit_map(grad_sensitivity: dict[str, float], quant_bits_middle: int = 6, + target_avg_bits: float = 6.0) -> dict[str, int]: + if not grad_sensitivity: + return {} + names = sorted(grad_sensitivity.keys()) + sensitivities = [grad_sensitivity[n] for n in names] + total = sum(sensitivities) + if total <= 0: + return {} + indexed = sorted(enumerate(sensitivities), key=lambda x: x[1]) + n = len(indexed) + bit_map: dict[str, int] = {} + for rank_pos, (orig_idx, _) in enumerate(indexed): + frac = rank_pos / max(n - 1, 1) + if frac < 0.15: + bit_map[names[orig_idx]] = max(quant_bits_middle - 1, 4) + elif frac < 0.55: + bit_map[names[orig_idx]] = quant_bits_middle + else: + bit_map[names[orig_idx]] = min(quant_bits_middle + 1, 8) + return bit_map + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --------------------- +# DATA LOADING +# --------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device, + curriculum: str = ""): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern, curriculum=curriculum) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --------------------- +# TRANSFORMER MODULES +# --------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, ntk_seq_len: int = 0, max_seq_len: int = 2048): + super().__init__() + if ntk_seq_len > 1024: + alpha = ntk_seq_len / 1024.0 + base = base * (alpha ** (dim / max(dim - 2, 1))) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos_cached", freqs.cos()[None, None, :, :], persistent=False) + self.register_buffer("_sin_cached", freqs.sin()[None, None, :, :], persistent=False) + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + return ( + self._cos_cached[:, :, :seq_len, :].to(device=device, dtype=dtype), + self._sin_cached[:, :, :seq_len, :].to(device=device, dtype=dtype), + ) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, + qk_gain_init: float, window_size: int = 0, ntk_seq_len: int = 0, + max_seq_len: int = 2048, use_xsa: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.use_xsa = use_xsa + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, ntk_seq_len=ntk_seq_len, max_seq_len=max_seq_len) + self.window_size = window_size + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.window_size > 0: + idx = torch.arange(seqlen, device=x.device) + row, col = idx.unsqueeze(1), idx.unsqueeze(0) + blocked = (col > row) | (row - col >= self.window_size) + attn_mask = torch.zeros(seqlen, seqlen, device=x.device, dtype=q.dtype) + attn_mask = attn_mask.masked_fill(blocked, float("-inf")) + else: + attn_mask = None + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.unsqueeze(2).expand(-1, -1, n_rep, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + v = v.unsqueeze(2).expand(-1, -1, n_rep, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=(attn_mask is None)) + if self.use_xsa: + y_t = y.transpose(1, 2) + Hkv = self.num_kv_heads + group_size = self.num_heads // Hkv + v_orig = self.c_v(x).reshape(bsz, seqlen, Hkv, self.head_dim) + vn = F.normalize(v_orig, dim=-1).unsqueeze(-2) + y_grouped = y_t.reshape(bsz, seqlen, Hkv, group_size, self.head_dim) + dot = (y_grouped * vn).sum(-1, keepdim=True) + y = (y_grouped - dot * vn).reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, window_size: int = 0, + ntk_seq_len: int = 0, max_seq_len: int = 2048, use_xsa: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + window_size, ntk_seq_len=ntk_seq_len, + max_seq_len=max_seq_len, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int, + num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, + tied_embed_init_std: float, logit_softcap: float, rope_base: float, + qk_gain_init: float, num_physical_layers: int = 0, window_size: int = 0, + init_type: str = "standard", smear_gate: bool = False, + ntk_seq_len: int = 0, max_seq_len: int = 2048, xsa_last_n: int = 0): + super().__init__() + self.init_type = init_type + self.num_layers = num_layers + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + if num_physical_layers <= 0: + num_physical_layers = num_layers + self.num_physical_layers = num_physical_layers + self.use_smear_gate = smear_gate + if smear_gate: + self.smear_gate_param = nn.Parameter(torch.zeros(model_dim, dtype=torch.float32)) + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + window_size, ntk_seq_len=ntk_seq_len, max_seq_len=max_seq_len, + use_xsa=(i >= num_physical_layers - xsa_last_n) if xsa_last_n > 0 else False) + for i in range(num_physical_layers) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings and hasattr(self.tok_emb, 'weight'): + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if self.init_type == "ortho": + with torch.no_grad(): + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and module.weight.ndim == 2: + if module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj" in name or name.endswith(".proj"): + module.weight.mul_(1.0 / math.sqrt(2 * self.num_layers)) + + def _apply_smear_gate(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.smear_gate_param.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean") -> Tensor: + x = self.tok_emb(input_ids) + if self.use_smear_gate: + x = self._apply_smear_gate(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i % self.num_physical_layers](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + _layer_idx = self.num_encoder_layers + i + x = self.blocks[_layer_idx % self.num_physical_layers](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + elif self.lm_head is not None: + logits_proj = self.lm_head(x) + else: + raise RuntimeError("No output head configured") + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + loss = F.cross_entropy(logits.float(), targets, reduction=reduction) + return loss + +# --------------------- +# TRAINING +# --------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if not args.skip_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + _base_accum = int(os.environ.get("GRAD_ACCUM_STEPS", str(8 // world_size))) + grad_accum_steps = _base_accum + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + if not args.dev_mode: + raise RuntimeError("CUDA is required (set DEV_MODE=1 for local MPS/CPU testing)") + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + else: + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + if device.type == "cuda": + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + def sync() -> None: + if device.type == "cuda": + torch.cuda.synchronize() + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + if device.type == "cuda": + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION SETUP + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_physical_layers=args.num_physical_layers, + window_size=args.window_size, + init_type=args.init_type, + smear_gate=args.smear_gate, + ntk_seq_len=args.train_seq_len if args.use_ntk_rope else 0, + max_seq_len=args.train_seq_len, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + _skip_compile = args.dev_mode or args.skip_compile + _use_dynamic = bool(args.seq_curriculum) + if _use_dynamic: + torch._dynamo.config.cache_size_limit = 64 + if distributed: + torch._dynamo.config.optimize_ddp = False + if not _skip_compile: + try: + compiled_model = torch.compile( + base_model, + dynamic=_use_dynamic, + fullgraph=not _use_dynamic, + ) + if _use_dynamic: + log0(f"torch.compile: dynamic=True, cache_size_limit={torch._dynamo.config.cache_size_limit}") + except Exception as e: + print(f"WARNING: torch.compile failed ({e}), falling back to eager mode", file=sys.stderr) + compiled_model = base_model + else: + compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.use_smear_gate: + scalar_params.append(base_model.smear_gate_param) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=(device.type == "cuda"), + ) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, variant=args.optimizer_variant, + weight_decay=args.muon_weight_decay) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=(device.type == "cuda"), + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.AdamW( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=(device.type == "cuda"), + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:cudnn=False flash=True mem_efficient=True math=True") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + curriculum=args.data_curriculum) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + curriculum=args.data_curriculum) + + # EMA (Exponential Moving Average) + _ema_state: dict | None = None + _ema_started = False + _use_ema = args.use_ema + if _use_ema and args.ema_from_init: + _ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + _ema_started = True + log0(f"ema:initialized from model init (alpha={args.ema_alpha})") + + # Gradient-guided quant accumulator + _grad_sensitivity: dict[str, float] = {} + _grad_accum_started = False + + # Context curriculum + _seq_curriculum_stages: list[tuple[int, float]] = [] + if args.seq_curriculum: + for stage in args.seq_curriculum.split(","): + slen_str, frac_str = stage.strip().split(":") + _seq_curriculum_stages.append((int(slen_str), float(frac_str))) + _seq_curriculum_stages.sort(key=lambda x: x[1]) + log0(f"seq_curriculum: {_seq_curriculum_stages}") + + def _get_seq_len(frac: float) -> int: + if not _seq_curriculum_stages: + return args.train_seq_len + for slen, threshold in _seq_curriculum_stages: + if frac < threshold: + return slen + return args.train_seq_len + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + _prev_curriculum_seq_len = args.train_seq_len + sync() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + sync() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + sync() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Context curriculum + if _seq_curriculum_stages and max_wallclock_ms: + _wc_frac = elapsed_ms / max_wallclock_ms + _curr_seq_len = _get_seq_len(_wc_frac) + if _curr_seq_len != _prev_curriculum_seq_len: + log0(f"seq_curriculum: switching to seq_len={_curr_seq_len} at step {step} ({_wc_frac*100:.0f}% wallclock)") + _prev_curriculum_seq_len = _curr_seq_len + else: + _curr_seq_len = args.train_seq_len + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, _curr_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + # Gradient-guided quant: accumulate grad magnitudes during warmdown + if args.grad_guided_quant and scale < 1.0: + if not _grad_accum_started: + _grad_sensitivity = {n: 0.0 for n, p in base_model.named_parameters() if p.ndim >= 2} + _grad_accum_started = True + log0(f"grad_guided_quant: started accumulating ({len(_grad_sensitivity)} tensors)") + for n, p in base_model.named_parameters(): + if p.grad is not None and n in _grad_sensitivity: + _grad_sensitivity[n] += p.grad.detach().float().abs().mean().item() + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # EMA update + if _use_ema and _ema_started and args.ema_from_init: + with torch.no_grad(): + alpha = args.ema_alpha + for k, v in base_model.state_dict().items(): + _ema_state[k].mul_(alpha).add_(v.detach().float(), alpha=1.0 - alpha) + elif _use_ema and not args.ema_from_init and step >= 200 and scale < args.ema_start_ratio: + with torch.no_grad(): + if not _ema_started: + _ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + _ema_started = True + log0(f"ema:started step:{step} lr_scale:{scale:.4f} alpha:{args.ema_alpha}") + else: + alpha = args.ema_alpha + for k, v in base_model.state_dict().items(): + _ema_state[k].mul_(alpha).add_(v.detach().float(), alpha=1.0 - alpha) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # Dynamic warmdown: cap to 55% of estimated total steps for hardware robustness + if step == 100 and max_wallclock_ms is not None and args.warmdown_iters > 0: + _step_ms = approx_training_time_ms / step + _est_total = int(max_wallclock_ms / _step_ms) + _wd_cap = int(_est_total * 0.55) + if _wd_cap < args.warmdown_iters: + log0(f"warmdown:adaptive {args.warmdown_iters}->{_wd_cap} (est_total_steps={_est_total} step_ms={_step_ms:.1f})") + args.warmdown_iters = _wd_cap + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + if device.type == "cuda": + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Load EMA weights + if _ema_state is not None and _ema_started: + log0(f"ema:loading shadow weights (alpha={args.ema_alpha})") + ema_state_typed = {k: v.to(dtype=base_model.state_dict()[k].dtype) + for k, v in _ema_state.items()} + base_model.load_state_dict(ema_state_typed, strict=True) + del _ema_state + + # SERIALIZATION + ROUNDTRIP VALIDATION + if args.skip_quant: + if master_process: + log0("skip_quant=True: skipping serialization and roundtrip validation") + if distributed: + dist.destroy_process_group() + return + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + _grad_bit_map = None + if args.grad_guided_quant and _grad_sensitivity: + _grad_bit_map = compute_grad_bit_map(_grad_sensitivity, quant_bits_middle=args.quant_bits_middle) + if master_process: + log0(f"grad_guided_quant: {len(_grad_bit_map)} tensors assigned adaptive bits") + _bit_counts = {} + for b in _grad_bit_map.values(): + _bit_counts[b] = _bit_counts.get(b, 0) + 1 + log0(f" bit distribution: {dict(sorted(_bit_counts.items()))}") + quant_obj, quant_stats = quantize_state_dict_int8( + base_model.state_dict(), quant_bits_middle=args.quant_bits_middle, + quant_bits_mlp=args.quant_bits_mlp, fp16_embeddings=args.fp16_embeddings, + grad_bit_map=_grad_bit_map) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if args.compress_algo == "zstd": + import zstandard + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + _calgo = args.compress_algo + log0( + f"Serialized model int8+{_calgo}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+{_calgo}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if args.compress_algo == "zstd": + import zstandard + _decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + _decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + sync() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + sync() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()