diff --git a/EXPERIMENT_PLAN.md b/EXPERIMENT_PLAN.md new file mode 100644 index 000000000..b88b655fd --- /dev/null +++ b/EXPERIMENT_PLAN.md @@ -0,0 +1,117 @@ +# Experiment Plan + +This plan is optimized for limited budget and the challenge rules. + +## Goals + +- Improve `final_int8_zlib_roundtrip_exact val_bpb` +- Improve `final_int8_ttt_lora val_bpb` +- Stay under the `16,000,000` byte artifact cap +- Avoid risky dataset changes until the safe path is exhausted + +## 5-Run Moonshot Sequence + +Run these in order on remote GPUs, using the current branch and `TRAIN_SHARDS=1`: + +1. `drope_eval` +2. `yarn_eval` +3. `mtp_low` +4. `muon_balance` +5. `hybrid_delta` + +Run the entire sequence: + +```bash +NPROC_PER_NODE=1 bash scripts/run_moonshot5.sh +``` + +This prints each run tail and then a ranked JSON summary against the control run `twice_eval2048_ttt1024_clean2`. + +Ranking priority: + +1. Lowest `final_int8_ttt_lora val_bpb` +2. Lowest `final_int8_zlib_roundtrip_exact val_bpb` +3. Smallest artifact +4. Fastest step time + +Promotion rules: + +- Promote any run that beats the control on at least one final metric without exceeding the artifact cap. +- Promote `hybrid_delta` if it beats the control on either final metric, even slightly. + +Next-step rules: + +- If `drope_eval` beats `yarn_eval`, keep DRoPE and drop YaRN. +- If `yarn_eval` beats `drope_eval`, keep YaRN and drop DRoPE. +- If `mtp_low` wins, sweep `MTP_DEPTH=3` and `MTP_LOSS_WEIGHT` in `0.05`, `0.1`, `0.2`. +- If `muon_balance` wins, sweep `MUON_UPDATE_BALANCE` in `0.25`, `0.5`, `0.75`. +- If `hybrid_delta` wins even slightly, open a dedicated hybrid branch next. + +## Next Moonshot + +New architecture branch: + +1. `shared_depth` + +Idea: + +- reuse `4` unique blocks across `10` logical layers +- keep tiny per-pass learned output scales so reused blocks can still specialize +- preserve the existing optimizer, export, and TTT paths + +## Dataset And Tokenizer Work + +The challenge allows tokenizer or dataset changes, but the repo says they will be examined carefully and you must prove the `val_bpb` calculation is correct. See [README.md](/Users/deividasmataciunas/Desktop/research/openai_golf/README.md#L168). + +Safest path: + +- Rebuild tokenizers from the published docs cache only +- Re-export shards from the same selected docs +- Keep validation on the fixed first `50k` docs + +Use: + +```bash +bash scripts/rebuild_tokenizer_export.sh +``` + +Default ablation config: + +- `sp_bpe_768` +- `sp_bpe_1024` +- `sp_bpe_1280` +- `sp_bpe_1536` +- `pure_byte_260` + +After the model-side shortlist settles, do these data sweeps: + +1. Rebuild `sp_bpe_768`, `sp_bpe_1280`, and `pure_byte_260` +2. Rerun the current best profile on `TRAIN_SHARDS=1` +3. Only promote tokenizer changes that help `final_int8_ttt_lora` without pushing artifact bytes in the wrong direction + +## Dataset Ideas That Look Safe + +- Vary tokenizer vocab size on the same published docs +- Compare pure-byte vs SentencePiece BPE +- Train on a prefix of shards, then do a short final stage on a higher-quality subset from the same docs +- Filter obviously low-value docs from the training side only +- Keep document boundaries clean during training and eval + +## Risky Ideas + +- External corpora +- Changing validation docs +- Any data use at eval time beyond what the rules allow +- Tokenizer changes without exact byte-accounting validation + +## Success Metrics + +For each run, record: + +- `val_bpb` +- `final_int8_zlib_roundtrip_exact val_bpb` +- `final_int8_ttt_lora val_bpb` +- `Total submission size int8+zlib` +- `step_avg` + +If a tokenizer change helps pre-quant quality but hurts artifact bytes, reject it early. diff --git a/REMOTE_RUNBOOK.md b/REMOTE_RUNBOOK.md new file mode 100644 index 000000000..8cf164efe --- /dev/null +++ b/REMOTE_RUNBOOK.md @@ -0,0 +1,83 @@ +# Remote Runbook + +This repo is ready for the CUDA path. + +## Recommended Path + +Use the official Runpod Parameter Golf template mentioned in [README.md](/Users/deividasmataciunas/Desktop/research/openai_golf/README.md). + +Start with one of these: + +- `1x H100`: cheapest realistic sanity-check path for code, logs, artifact size, and eval behavior. +- `8x H100 SXM`: record-track run once the recipe looks stable. + +## First-Time Remote Setup + +On the remote box: + +```bash +cd /workspace +git clone https://github.com/openai/parameter-golf.git +cd parameter-golf +git remote add myfork +git fetch myfork +git checkout +``` + +Then hydrate the published cache: + +```bash +TRAIN_SHARDS=1 bash scripts/remote_fetch_data.sh +``` + +For a fuller training prefix: + +```bash +TRAIN_SHARDS=10 bash scripts/remote_fetch_data.sh +``` + +## First Experiment + +This is the first recipe to run against our merged script: + +```bash +NPROC_PER_NODE=1 bash scripts/run_remote_experiment.sh +``` + +For a full multi-GPU run: + +```bash +NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh +``` + +## What This Recipe Uses + +- `10` layers +- fp16 tied-embedding export +- NTK-aware longer eval support +- sliding-window eval with stride `64` +- decoupled Muon weight decay +- overtone embedding init +- phase-shaped residual mixing init + +## First Ablations To Queue + +Run these one at a time after the first successful remote run: + +```bash +EVAL_STRIDE=0 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh +EVAL_SEQ_LEN=2048 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh +NUM_LAYERS=9 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh +MUON_WEIGHT_DECAY=0.00 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh +OVERTONE_INIT_POWER=0.00 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh +``` + +## What To Look For + +- `step_avg` +- final `val_bpb` +- final `final_int8_zlib_roundtrip_exact` +- final `final_int8_ttt_lora` +- total `int8+zlib` artifact bytes + +If you send me a remote log, I can turn it into the next ablation decision quickly. diff --git a/checkpoints/2026-03-20_twice_eval2048_ttt1024/REMOTE_CHECKPOINT.md b/checkpoints/2026-03-20_twice_eval2048_ttt1024/REMOTE_CHECKPOINT.md new file mode 100644 index 000000000..e5cd65a11 --- /dev/null +++ b/checkpoints/2026-03-20_twice_eval2048_ttt1024/REMOTE_CHECKPOINT.md @@ -0,0 +1,8 @@ +Best raw checkpoint metadata captured from the Runpod pod before shutdown: + +- Run ID: `twice_eval2048_ttt1024` +- Remote path: `/workspace/parameter-golf/final_model.pt` +- Size on pod: `72M` +- SHA256: `292d79fa54a638be348354f09d185f80b69710e7de8f4dfa42b36e43afccdc96` + +The raw `.pt` file itself was not copied into this repo because Runpod's SSH wrapper blocked automated binary transfer through `scp`. If you want to preserve the raw checkpoint, keep the pod or its volume alive until we manually copy it out tomorrow. diff --git a/data/tokenizer_specs.ablation.json b/data/tokenizer_specs.ablation.json new file mode 100644 index 000000000..ae09809c8 --- /dev/null +++ b/data/tokenizer_specs.ablation.json @@ -0,0 +1,29 @@ +{ + "tokenizers": [ + { + "name": "sp_bpe_768", + "dataset_suffix": "sp768", + "vocab_size": 768 + }, + { + "name": "sp_bpe_1024", + "dataset_suffix": "sp1024", + "vocab_size": 1024 + }, + { + "name": "sp_bpe_1280", + "dataset_suffix": "sp1280", + "vocab_size": 1280 + }, + { + "name": "sp_bpe_1536", + "dataset_suffix": "sp1536", + "vocab_size": 1536 + }, + { + "name": "pure_byte_260", + "dataset_suffix": "byte260", + "kind": "pure_byte" + } + ] +} diff --git a/program.md b/program.md new file mode 100644 index 000000000..42a44c932 --- /dev/null +++ b/program.md @@ -0,0 +1,86 @@ +# Parameter Golf Research Program + +You are working inside the OpenAI Parameter Golf repository. + +## Objective + +Improve the challenge score under these constraints: + +- optimize `final_int8_ttt_lora val_bpb` +- optimize `final_int8_zlib_roundtrip_exact val_bpb` +- keep `Total submission size int8+zlib` under `16,000,000` bytes +- preserve reproducibility + +Lower `val_bpb` is better. + +## Primary Rules + +1. Prefer small, ablation-friendly changes. +2. Keep changes concentrated in `train_gpt.py` unless there is a strong reason not to. +3. Reject changes that improve one metric but badly regress the other. +4. Reject changes that push artifact size toward the budget without a clear score win. +5. Do not change the validation set. +6. Treat tokenizer or dataset changes as higher-risk and require stronger evidence. + +## Current Priors + +- Sliding-window evaluation is high value. +- FP16 tied embedding export is high value. +- 10-layer small models are promising. +- Decoupled Muon weight decay is promising. +- `ATTN_TWICE_ALPHA=0.05` currently looks better than baseline. +- `Z_LOSS_COEF=0.0001` currently looks worse than baseline. + +## Current Best Known Local Results + +- `base10l` + - `roundtrip_val_bpb = 1.40296458` + - `ttt_val_bpb = 1.3976` + - `artifact_bytes = 10831123` + +- `twice_low` + - `roundtrip_val_bpb = 1.40177526` + - `ttt_val_bpb = 1.3969` + - `artifact_bytes = 10836065` + +## Experiment Order + +1. `twice_eval2048` +2. best `twice_*` variant on more seeds +3. training-context and batch tradeoff ablations +4. tokenizer ablations on published docs cache + +## Allowed Edit Zones + +- architecture details in `train_gpt.py` +- training schedule and optimizer settings +- quantization/export logic +- evaluation logic +- remote profile scripts + +## High-Risk Areas + +- external datasets +- validation handling +- complex multi-file refactors +- changes that increase code size substantially + +## Decision Policy + +Keep a change only if at least one is true: + +- `final_int8_ttt_lora` improves and `roundtrip_exact` does not materially regress +- `roundtrip_exact` improves and `ttt` does not materially regress +- artifact size drops meaningfully with near-flat score + +Reject a change if: + +- both `ttt` and `roundtrip_exact` regress +- artifact size grows with no score benefit +- it adds a lot of complexity without measurable value + +## Logging And Packaging + +- Use `scripts/run_remote_profile.sh` or `scripts/run_and_score.sh` +- Parse logs with `scripts/parse_run.py` +- Package strong candidates with `scripts/package_record.sh` diff --git a/records/_template/README.md b/records/_template/README.md new file mode 100644 index 000000000..414676107 --- /dev/null +++ b/records/_template/README.md @@ -0,0 +1,38 @@ +# Submission Name + +One-paragraph summary of the idea and why it matters for Parameter Golf. + +## Key Techniques + +1. Technique 1 +2. Technique 2 +3. Technique 3 + +## Results + +| Seed | val_loss | val_bpb | Steps | ms/step | +|------|----------|---------|-------|---------| +| 1337 | TBD | TBD | TBD | TBD | +| 42 | TBD | TBD | TBD | TBD | +| 7 | TBD | TBD | TBD | TBD | +| **Mean** | **TBD** | **TBD** | | | + +Artifact: `TBD` bytes | Eval time: `TBD` + +## Configuration + +```bash +# Paste the exact training command here +``` + +## Notes + +- Explain artifact accounting if needed +- Explain tokenizer/dataset changes if any +- Explain evaluation procedure if non-standard + +## Included Files + +- `train_gpt.py` +- `submission.json` +- `train_seed*.log` diff --git a/records/_template/submission.json b/records/_template/submission.json new file mode 100644 index 000000000..50c1a941f --- /dev/null +++ b/records/_template/submission.json @@ -0,0 +1,20 @@ +{ + "track": "10min_16mb", + "date": "YYYY-MM-DD", + "name": "Submission Name", + "author": "Your Name", + "github_id": "YourGitHubID", + "seed_results": { + "1337": { + "val_loss": 0.0, + "val_bpb": 0.0, + "steps": 0, + "ms_per_step": 0.0 + } + }, + "mean_val_loss": 0.0, + "mean_val_bpb": 0.0, + "p_value": 1.0, + "artifact_bytes": 0, + "code_bytes": 0 +} diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/README.md b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/README.md new file mode 100644 index 000000000..7f6e83e3e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/README.md @@ -0,0 +1,68 @@ +This folder checkpoints the first Runpod exploratory ablations run on March 20, 2026. + +These were not leaderboard-attempt runs. They used `1x H100`, `TRAIN_SHARDS=1`, and a strict `600s` wallclock cap to de-risk code changes and compare low-cost ablations before spending 8-GPU budget. + +Best observations: +- Best roundtrip score: `twice_eval2048` at `final_int8_zlib_roundtrip_exact val_bpb: 1.39909070` +- Best LoRA-TTT score: `twice_eval2048_ttt1024` at `final_int8_ttt_lora val_bpb: 1.3960` +- Best balanced profile so far: `twice_eval2048_ttt1024` + +Key takeaways: +- `ATTN_TWICE_ALPHA=0.05` helped versus baseline. +- `Z_LOSS_COEF=0.0001` regressed both roundtrip and TTT metrics. +- `EVAL_SEQ_LEN=2048` improved roundtrip scoring but hurt TTT when `TTT_EVAL_SEQ_LEN` also increased. +- Splitting eval settings (`EVAL_SEQ_LEN=2048`, `TTT_EVAL_SEQ_LEN=1024`) recovered the best TTT result seen so far. + +Best current training command: +```bash +RUN_ID=twice_eval2048_ttt1024 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +MAX_WALLCLOCK_SECONDS=600 \ +TRAIN_LOG_EVERY=100 \ +VAL_LOSS_EVERY=1000 \ +TRAIN_BATCH_TOKENS=524288 \ +TRAIN_SEQ_LEN=1024 \ +EVAL_SEQ_LEN=2048 \ +EVAL_STRIDE=64 \ +NUM_LAYERS=10 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=2 \ +TIE_EMBEDDINGS=1 \ +MATRIX_LR=0.02 \ +SCALAR_LR=0.02 \ +TIED_EMBED_LR=0.03 \ +WARMDOWN_ITERS=2500 \ +MUON_MOMENTUM=0.95 \ +MUON_BACKEND_STEPS=5 \ +MUON_WEIGHT_DECAY=0.02 \ +OVERTONE_INIT_POWER=0.5 \ +RESID_MIX_INIT_SCALE=3.0 \ +Z_LOSS_COEF=0.0 \ +ATTN_TWICE_ALPHA=0.05 \ +TTT_LORA_RANK=8 \ +TTT_LORA_LR=0.01 \ +TTT_CHUNK_SIZE=256 \ +TTT_EVAL_SEQ_LEN=1024 \ +TTT_BATCH_SIZE=64 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +``` + +Checkpoint note: +- The raw best checkpoint still lives on the Runpod pod at `/workspace/parameter-golf/final_model.pt`. +- Size on pod: `72M` +- SHA256: `292d79fa54a638be348354f09d185f80b69710e7de8f4dfa42b36e43afccdc96` +- Runpod's SSH wrapper blocked automated binary transfer, so this repo checkpoint stores the exact metrics and checkpoint manifest rather than the raw `.pt` file. + +Included files: +- `train_gpt.py` +- `submission.json` +- `results.json` +- `base10l.tail.txt` +- `zloss_low.tail.txt` +- `twice_low.tail.txt` +- `twice_eval2048.tail.txt` +- `twice_eval2048_ttt1024.tail.txt` diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/base10l.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/base10l.tail.txt new file mode 100644 index 000000000..c506fad36 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/base10l.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600528ms step:1061/20000 +peak memory allocated: 14656 MiB reserved: 15982 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10772539 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10831123 bytes +final_int8_zlib_roundtrip val_loss:2.3688 val_bpb:1.4030 eval_time:12984ms +final_int8_zlib_roundtrip_exact val_loss:2.36884692 val_bpb:1.40296458 +final_int8_ttt_lora val_loss:2.3598 val_bpb:1.3976 eval_time:507995ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/results.json b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/results.json new file mode 100644 index 000000000..349938dca --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/results.json @@ -0,0 +1,85 @@ +{ + "checkpoint_note": { + "best_raw_checkpoint_path": "/workspace/parameter-golf/final_model.pt", + "best_raw_checkpoint_size_human": "72M", + "best_raw_checkpoint_sha256": "292d79fa54a638be348354f09d185f80b69710e7de8f4dfa42b36e43afccdc96", + "transfer_status": "remote-only" + }, + "runs": [ + { + "run_id": "base10l", + "artifact_bytes": 10831123, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12984, + "roundtrip_val_bpb": 1.40296458, + "roundtrip_val_loss": 2.36884692, + "step_stop": 1061, + "total_train_time_ms": 600528, + "ttt_eval_time_ms": 507995, + "ttt_val_bpb": 1.3976, + "ttt_val_loss": 2.3598 + }, + { + "run_id": "zloss_low", + "artifact_bytes": 10804697, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 13026, + "roundtrip_val_bpb": 1.40313158, + "roundtrip_val_loss": 2.3691289, + "step_stop": 1058, + "total_train_time_ms": 600386, + "ttt_eval_time_ms": 508243, + "ttt_val_bpb": 1.3984, + "ttt_val_loss": 2.3611 + }, + { + "run_id": "twice_low", + "artifact_bytes": 10836065, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12989, + "roundtrip_val_bpb": 1.40177526, + "roundtrip_val_loss": 2.3668388, + "step_stop": 1063, + "total_train_time_ms": 600242, + "ttt_eval_time_ms": 508663, + "ttt_val_bpb": 1.3969, + "ttt_val_loss": 2.3586 + }, + { + "run_id": "twice_eval2048", + "artifact_bytes": 10924367, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12988, + "roundtrip_val_bpb": 1.3990907, + "roundtrip_val_loss": 2.36230604, + "step_stop": 1085, + "total_train_time_ms": 600071, + "ttt_eval_time_ms": 755165, + "ttt_val_bpb": 1.3995, + "ttt_val_loss": 2.363 + }, + { + "run_id": "twice_eval2048_ttt1024", + "artifact_bytes": 10868875, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12982, + "roundtrip_val_bpb": 1.4013468, + "roundtrip_val_loss": 2.36611538, + "step_stop": 1075, + "total_train_time_ms": 600439, + "ttt_eval_time_ms": 508393, + "ttt_val_bpb": 1.396, + "ttt_val_loss": 2.3571 + } + ], + "summary": { + "best_balanced_profile": "twice_eval2048_ttt1024", + "best_roundtrip_profile": "twice_eval2048", + "best_ttt_profile": "twice_eval2048_ttt1024", + "winner_notes": [ + "ATTN_TWICE_ALPHA=0.05 outperformed baseline.", + "Z_LOSS_COEF=0.0001 regressed both metrics.", + "Longer roundtrip eval context and shorter TTT eval context worked best together." + ] + } +} diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/submission.json b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/submission.json new file mode 100644 index 000000000..a08ec022e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/submission.json @@ -0,0 +1,17 @@ +{ + "author": "Deividas Mataciunas", + "github_id": "DeividasMat", + "name": "Runpod 1GPU Twicing Eval Split", + "blurb": "Exploratory 1x H100, 1-shard, 10-minute checkpoint run. The best balanced profile used ATTN_TWICE_ALPHA=0.05 with EVAL_SEQ_LEN=2048 and TTT_EVAL_SEQ_LEN=1024, reaching final_int8_ttt_lora val_bpb 1.3960 under the 16MB artifact cap.", + "date": "2026-03-20T18:30:00Z", + "track": "non-record-1gpu-16mb", + "val_loss": 2.3571, + "val_bpb": 1.3960, + "pre_quant_val_loss": 2.3308, + "pre_quant_val_bpb": 1.3805, + "step_stop": 1075, + "wallclock_seconds": 600.439, + "bytes_total": 10868875, + "bytes_model_int8_zlib": 10810291, + "bytes_code": 58584 +} diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/train_gpt.py b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/train_gpt.py new file mode 100644 index 000000000..5fee3aa8e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/train_gpt.py @@ -0,0 +1,1496 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0")) + eval_stride = int(os.environ.get("EVAL_STRIDE", "0")) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + z_loss_coef = float(os.environ.get("Z_LOSS_COEF", 0.0)) + attn_twice_alpha = float(os.environ.get("ATTN_TWICE_ALPHA", 0.0)) + overtone_init_power = float(os.environ.get("OVERTONE_INIT_POWER", 0.0)) + resid_mix_init_scale = float(os.environ.get("RESID_MIX_INIT_SCALE", 3.0)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", os.environ.get("EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", "1024")))) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len_override: int = 0, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + seq_len = seq_len_override if seq_len_override > 0 else args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + val_loss = val_loss_sum / val_token_count + return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (val_token_count.item() / val_byte_count.item())) + + +def eval_val_sliding( + logits_fn, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + stride: int, + eval_batch_seqs: int, +) -> tuple[float, float]: + windows: list[tuple[int, int]] = [] + total = val_tokens.numel() - 1 + pos = 0 + while pos + seq_len <= total: + windows.append((pos, 0 if pos == 0 else seq_len - stride)) + pos += stride + + per_rank = (len(windows) + world_size - 1) // world_size + my_windows = windows[rank * per_rank : min((rank + 1) * per_rank, len(windows))] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for i in range(0, len(my_windows), eval_batch_seqs): + batch = my_windows[i : i + eval_batch_seqs] + x = torch.stack([val_tokens[p : p + seq_len] for p, _ in batch]).to(device=device, dtype=torch.int64) + y = torch.stack([val_tokens[p + 1 : p + seq_len + 1] for p, _ in batch]).to(device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = logits_fn(x) + for b, (_, score_offset) in enumerate(batch): + scored_targets = y[b, score_offset:] + loss_sum += F.cross_entropy(logits[b, score_offset:].float(), scored_targets, reduction="sum").to(torch.float64) + tok_count += scored_targets.numel() + prev = x[b, score_offset : score_offset + scored_targets.numel()] + token_bytes = base_bytes_lut[scored_targets].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[scored_targets] & ~is_boundary_token_lut[prev]).to(dtype=torch.int16) + byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = loss_sum / tok_count + return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (tok_count.item() / byte_count.item())) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if name == "tok_emb.weight": + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + adjusted_base = self.base * (scale ** (self.dim / (self.dim - 2))) + inv_freq = 1.0 / ( + adjusted_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +def lm_loss(logits: Tensor, targets: Tensor, z_loss_coef: float, reduction: str) -> Tensor: + losses = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none").reshape_as(targets) + if z_loss_coef > 0: + losses = losses + z_loss_coef * torch.logsumexp(logits.float(), dim=-1).square() + return losses if reduction == "none" else losses.mean() +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + train_seq_len: int, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.num_kv_heads != self.num_heads: + k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + train_seq_len: int, + attn_twice_alpha: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.attn_twice_alpha = attn_twice_alpha + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + if self.attn_twice_alpha != 0.0: + attn_out = attn_out + self.attn_twice_alpha * (n - attn_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + train_seq_len: int, + z_loss_coef: float, + attn_twice_alpha: float, + overtone_init_power: float, + resid_mix_init_scale: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.z_loss_coef = z_loss_coef + self.overtone_init_power = overtone_init_power + self.resid_mix_init_scale = resid_mix_init_scale + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + attn_twice_alpha, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + if self.overtone_init_power > 0.0: + with torch.no_grad(): + u, s, v = torch.linalg.svd(self.tok_emb.weight.data, full_matrices=False) + self.tok_emb.weight.data = (u * (s[0] * torch.arange(1, s.shape[0] + 1, dtype=s.dtype).pow(-self.overtone_init_power))[None, :]) @ v + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + for i, block in enumerate(self.blocks): + with torch.no_grad(): + phase = torch.sigmoid(torch.tensor(self.resid_mix_init_scale * (i / max(len(self.blocks) - 1, 1) - 0.5))) + block.resid_mix.data[0].fill_(float(phase)) + block.resid_mix.data[1].fill_(float(1.0 - phase)) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + return lm_loss(logits, target_ids, self.z_loss_coef, reduction="none" if lora else "mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[bi](x, x0) + x = self.final_norm(x) + logits = F.linear(x, self.tok_emb.weight.to(x.dtype)) if self.tie_embeddings else self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + train_seq_len=args.train_seq_len, + z_loss_coef=args.z_loss_coef, + attn_twice_alpha=args.attn_twice_alpha, + overtone_init_power=args.overtone_init_power, + resid_mix_init_scale=args.resid_mix_init_scale, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0(f"eval_seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} eval_stride:{args.eval_stride} muon_weight_decay:{args.muon_weight_decay} z_loss_coef:{args.z_loss_coef} attn_twice_alpha:{args.attn_twice_alpha} overtone_init_power:{args.overtone_init_power}") + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + if args.muon_weight_decay > 0: + with torch.no_grad(): + shrink = 1.0 - args.muon_weight_decay * optimizer_muon.param_groups[0]["lr"] + for p in matrix_params: + p.mul_(shrink) + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_tokens_eval = load_validation_tokens(args.val_files, eval_seq_len) if eval_seq_len != args.train_seq_len else val_tokens + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0: + eval_batch_seqs = min(256, max(1, args.val_batch_size // eval_seq_len // max(world_size, 1))) + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False) + warmup_x = torch.zeros(eval_batch_seqs, eval_seq_len, dtype=torch.int64, device=device) + base_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + _ = compiled_logits(warmup_x) + q_val_loss, q_val_bpb = eval_val_sliding(compiled_logits, rank, world_size, device, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, eval_seq_len, args.eval_stride, eval_batch_seqs) + base_model.train() + else: + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, seq_len_override=eval_seq_len if eval_seq_len != args.train_seq_len else 0) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048.tail.txt new file mode 100644 index 000000000..13ca9e496 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600071ms step:1085/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10865783 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10924367 bytes +final_int8_zlib_roundtrip val_loss:2.3623 val_bpb:1.3991 eval_time:12988ms +final_int8_zlib_roundtrip_exact val_loss:2.36230604 val_bpb:1.39909070 +final_int8_ttt_lora val_loss:2.3630 val_bpb:1.3995 eval_time:755165ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048_ttt1024.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048_ttt1024.tail.txt new file mode 100644 index 000000000..aebb3ce9f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048_ttt1024.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600439ms step:1075/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10810291 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10868875 bytes +final_int8_zlib_roundtrip val_loss:2.3661 val_bpb:1.4013 eval_time:12982ms +final_int8_zlib_roundtrip_exact val_loss:2.36611538 val_bpb:1.40134680 +final_int8_ttt_lora val_loss:2.3571 val_bpb:1.3960 eval_time:508393ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_low.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_low.tail.txt new file mode 100644 index 000000000..0cddd1513 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_low.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600242ms step:1063/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10777481 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10836065 bytes +final_int8_zlib_roundtrip val_loss:2.3668 val_bpb:1.4018 eval_time:12989ms +final_int8_zlib_roundtrip_exact val_loss:2.36683880 val_bpb:1.40177526 +final_int8_ttt_lora val_loss:2.3586 val_bpb:1.3969 eval_time:508663ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/zloss_low.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/zloss_low.tail.txt new file mode 100644 index 000000000..58b9eee35 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/zloss_low.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600386ms step:1058/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10746113 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10804697 bytes +final_int8_zlib_roundtrip val_loss:2.3691 val_bpb:1.4031 eval_time:13026ms +final_int8_zlib_roundtrip_exact val_loss:2.36912890 val_bpb:1.40313158 +final_int8_ttt_lora val_loss:2.3611 val_bpb:1.3984 eval_time:508243ms diff --git a/scripts/package_record.sh b/scripts/package_record.sh new file mode 100755 index 000000000..eecc78554 --- /dev/null +++ b/scripts/package_record.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +if [ "$#" -lt 2 ]; then + echo "usage: bash scripts/package_record.sh [log2 ...]" >&2 + exit 1 +fi + +TARGET_DIR="$1" +shift + +mkdir -p "$TARGET_DIR" +cp records/_template/README.md "$TARGET_DIR/README.md" +cp records/_template/submission.json "$TARGET_DIR/submission.json" +cp train_gpt.py "$TARGET_DIR/train_gpt.py" + +for log_file in "$@"; do + cp "$log_file" "$TARGET_DIR/" +done + +echo "Packaged template into $TARGET_DIR" +echo "Next: fill README.md and submission.json with real metrics." diff --git a/scripts/parse_run.py b/scripts/parse_run.py new file mode 100755 index 000000000..c0dc05f5d --- /dev/null +++ b/scripts/parse_run.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import re +import sys +from pathlib import Path + + +PATTERNS = { + "roundtrip": re.compile(r"final_int8_zlib_roundtrip_exact val_loss:(\S+) val_bpb:(\S+)"), + "ttt": re.compile(r"final_int8_ttt_lora val_loss:(\S+) val_bpb:(\S+)"), + "artifact": re.compile(r"Total submission size int8\+zlib: (\d+) bytes"), + "step_avg": re.compile(r"step:(\d+)/\d+ val_loss:(\S+) val_bpb:(\S+) train_time:(\d+)ms step_avg:(\S+)ms"), + "peak_mem": re.compile(r"peak memory allocated: (\d+) MiB reserved: (\d+) MiB"), +} + + +def maybe_float(value: str) -> float | None: + try: + return float(value) + except ValueError: + return None + + +def parse_log(path: Path) -> dict[str, object]: + if not path.exists(): + return {"log": str(path), "missing": True} + text = path.read_text(encoding="utf-8", errors="replace") + out: dict[str, object] = {"log": str(path)} + if m := PATTERNS["roundtrip"].search(text): + val_loss = maybe_float(m.group(1)) + val_bpb = maybe_float(m.group(2)) + if val_loss is not None and val_bpb is not None: + out["roundtrip_val_loss"] = val_loss + out["roundtrip_val_bpb"] = val_bpb + if m := PATTERNS["ttt"].search(text): + val_loss = maybe_float(m.group(1)) + val_bpb = maybe_float(m.group(2)) + if val_loss is not None and val_bpb is not None: + out["ttt_val_loss"] = val_loss + out["ttt_val_bpb"] = val_bpb + if m := PATTERNS["artifact"].search(text): + out["artifact_bytes"] = int(m.group(1)) + if m := PATTERNS["peak_mem"].search(text): + out["peak_alloc_mib"] = int(m.group(1)) + out["peak_reserved_mib"] = int(m.group(2)) + step_matches = PATTERNS["step_avg"].findall(text) + if step_matches: + step, val_loss, val_bpb, train_time_ms, step_avg = step_matches[-1] + parsed_val_loss = maybe_float(val_loss) + parsed_val_bpb = maybe_float(val_bpb) + parsed_step_avg = maybe_float(step_avg) + if parsed_val_loss is not None and parsed_val_bpb is not None and parsed_step_avg is not None: + out["last_val_step"] = int(step) + out["last_val_loss"] = parsed_val_loss + out["last_val_bpb"] = parsed_val_bpb + out["train_time_ms"] = int(train_time_ms) + out["step_avg_ms"] = parsed_step_avg + return out + + +def main() -> int: + if len(sys.argv) < 2: + print("usage: python3 scripts/parse_run.py [log2 ...]", file=sys.stderr) + return 1 + rows = [parse_log(Path(arg)) for arg in sys.argv[1:]] + print(json.dumps(rows, indent=2, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/rank_moonshots.py b/scripts/rank_moonshots.py new file mode 100755 index 000000000..24622bbb9 --- /dev/null +++ b/scripts/rank_moonshots.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import math +import subprocess +import sys +from pathlib import Path + + +CONTROL_RUN = "twice_eval2048_ttt1024_clean2" +PROFILES = ["drope_eval", "yarn_eval", "mtp_low", "muon_balance", "hybrid_delta"] +ARTIFACT_CAP = 16_000_000 + + +def parse_logs(paths: list[str]) -> list[dict[str, object]]: + existing = [path for path in paths if Path(path).exists()] + if not existing: + return [] + out = subprocess.check_output([sys.executable, "scripts/parse_run.py", *existing], text=True) + return json.loads(out) + + +def score_key(row: dict[str, object]) -> tuple[float, float, int, float]: + return ( + float(row.get("ttt_val_bpb", math.inf)), + float(row.get("roundtrip_val_bpb", math.inf)), + int(row.get("artifact_bytes", 10**12)), + float(row.get("step_avg_ms", math.inf)), + ) + + +def run_id_from_log(path: str) -> str: + return Path(path).stem + + +def promote(row: dict[str, object], control: dict[str, object] | None) -> tuple[bool, str]: + artifact = int(row.get("artifact_bytes", ARTIFACT_CAP + 1)) + if artifact > ARTIFACT_CAP: + return False, "artifact cap exceeded" + if control is None: + return True, "no control available" + row_ttt = float(row.get("ttt_val_bpb", math.inf)) + row_roundtrip = float(row.get("roundtrip_val_bpb", math.inf)) + ctl_ttt = float(control.get("ttt_val_bpb", math.inf)) + ctl_roundtrip = float(control.get("roundtrip_val_bpb", math.inf)) + if run_id_from_log(str(row["log"])) == "hybrid_delta" and (row_ttt < ctl_ttt or row_roundtrip < ctl_roundtrip): + return True, "hybrid delta beat control on at least one final metric" + if row_ttt < ctl_ttt or row_roundtrip < ctl_roundtrip: + return True, "beat control on at least one final metric" + return False, "did not beat control" + + +def main() -> int: + if len(sys.argv) > 1: + paths = sys.argv[1:] + else: + paths = [f"logs/{name}.txt" for name in [CONTROL_RUN, *PROFILES] if Path(f"logs/{name}.txt").exists()] + rows = parse_logs(paths) + for row in rows: + row["run_id"] = run_id_from_log(str(row["log"])) + control = next((row for row in rows if row["run_id"] == CONTROL_RUN), None) + ranked = sorted((row for row in rows if row["run_id"] != CONTROL_RUN), key=score_key) + result = [] + for row in ranked: + ok, reason = promote(row, control) + row = dict(row) + row["promote"] = ok + row["promotion_reason"] = reason + result.append(row) + print(json.dumps({"control": control, "ranked": result}, indent=2, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/rebuild_tokenizer_export.sh b/scripts/rebuild_tokenizer_export.sh new file mode 100755 index 000000000..c1bf4b92f --- /dev/null +++ b/scripts/rebuild_tokenizer_export.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +REPO_ID="${REPO_ID:-willdepueoai/parameter-golf}" +REMOTE_ROOT="${REMOTE_ROOT:-datasets}" +OUTPUT_ROOT="${OUTPUT_ROOT:-$ROOT_DIR/data/exports/tokenizer_ablation}" +TOKENIZER_CONFIG="${TOKENIZER_CONFIG:-$ROOT_DIR/data/tokenizer_specs.ablation.json}" + +python3 data/download_hf_docs_and_tokenize.py \ + --repo-id "$REPO_ID" \ + --remote-root "$REMOTE_ROOT" \ + --output-root "$OUTPUT_ROOT" \ + --tokenizer-config "$TOKENIZER_CONFIG" diff --git a/scripts/remote_fetch_data.sh b/scripts/remote_fetch_data.sh new file mode 100755 index 000000000..0085bb8dd --- /dev/null +++ b/scripts/remote_fetch_data.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +TRAIN_SHARDS="${TRAIN_SHARDS:-1}" +VARIANT="${VARIANT:-sp1024}" + +python3 data/cached_challenge_fineweb.py --variant "$VARIANT" --train-shards "$TRAIN_SHARDS" diff --git a/scripts/run_and_score.sh b/scripts/run_and_score.sh new file mode 100755 index 000000000..b1031b038 --- /dev/null +++ b/scripts/run_and_score.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +PROFILE="${1:?usage: bash scripts/run_and_score.sh }" +shift || true + +bash scripts/run_remote_profile.sh "$PROFILE" "$@" + +LOG_PATH="logs/${RUN_ID:-${PROFILE}}.txt" +if [ ! -f "$LOG_PATH" ]; then + echo "log file not found: $LOG_PATH" >&2 + exit 1 +fi + +echo +echo "Parsed summary:" +python3 scripts/parse_run.py "$LOG_PATH" diff --git a/scripts/run_moonshot5.sh b/scripts/run_moonshot5.sh new file mode 100755 index 000000000..103ec90a9 --- /dev/null +++ b/scripts/run_moonshot5.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +RUNS=(drope_eval yarn_eval mtp_low muon_balance hybrid_delta) +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + +for run in "${RUNS[@]}"; do + echo + echo "=== Running ${run} ===" + NPROC_PER_NODE="$NPROC_PER_NODE" bash scripts/run_remote_profile.sh "$run" + echo + echo "--- tail ${run} ---" + tail -n 10 "logs/${run}.txt" +done + +echo +echo "=== Ranked moonshot summary ===" +python3 scripts/rank_moonshots.py diff --git a/scripts/run_moonshot5_smoke.sh b/scripts/run_moonshot5_smoke.sh new file mode 100644 index 000000000..ac3055707 --- /dev/null +++ b/scripts/run_moonshot5_smoke.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +RUNS=(drope_eval yarn_eval mtp_low muon_balance hybrid_delta) +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + +for run in "${RUNS[@]}"; do + echo + echo "=== Smoke ${run} ===" + NPROC_PER_NODE="$NPROC_PER_NODE" bash scripts/run_smoke_profile.sh "$run" + echo + echo "--- tail ${run}_smoke ---" + tail -n 10 "logs/${run}_smoke.txt" +done diff --git a/scripts/run_remote_experiment.sh b/scripts/run_remote_experiment.sh new file mode 100755 index 000000000..69b464d87 --- /dev/null +++ b/scripts/run_remote_experiment.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export RUN_ID="${RUN_ID:-remote_10l_slide64}" +export DATA_PATH="${DATA_PATH:-./data/datasets/fineweb10B_sp1024}" +export TOKENIZER_PATH="${TOKENIZER_PATH:-./data/tokenizers/fineweb_1024_bpe.model}" +export VOCAB_SIZE="${VOCAB_SIZE:-1024}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-1000}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-1024}" +export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-1024}" +export EVAL_STRIDE="${EVAL_STRIDE:-64}" +export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-$EVAL_SEQ_LEN}" +export ROUNDTRIP_EVAL_STRIDE="${ROUNDTRIP_EVAL_STRIDE:-$EVAL_STRIDE}" +export NUM_LAYERS="${NUM_LAYERS:-10}" +export MODEL_DIM="${MODEL_DIM:-512}" +export NUM_HEADS="${NUM_HEADS:-8}" +export NUM_KV_HEADS="${NUM_KV_HEADS:-4}" +export MLP_MULT="${MLP_MULT:-2}" +export TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" +export MATRIX_LR="${MATRIX_LR:-0.02}" +export SCALAR_LR="${SCALAR_LR:-0.02}" +export TIED_EMBED_LR="${TIED_EMBED_LR:-0.03}" +export WARMDOWN_ITERS="${WARMDOWN_ITERS:-2500}" +export MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}" +export MUON_BACKEND_STEPS="${MUON_BACKEND_STEPS:-5}" +export MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.02}" +export MUON_UPDATE_BALANCE="${MUON_UPDATE_BALANCE:-0.0}" +export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0}" +export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.0}" +export ATTN_TWICE_ALPHA_SLOPE="${ATTN_TWICE_ALPHA_SLOPE:-0.0}" +export OVERTONE_INIT_POWER="${OVERTONE_INIT_POWER:-0.5}" +export RESID_MIX_INIT_SCALE="${RESID_MIX_INIT_SCALE:-3.0}" +export ROPE_SCALING="${ROPE_SCALING:-ntk}" +export ROPE_SCALE="${ROPE_SCALE:-1.0}" +export ROUNDTRIP_ROPE_SCALING="${ROUNDTRIP_ROPE_SCALING:-$ROPE_SCALING}" +export ROUNDTRIP_ROPE_SCALE="${ROUNDTRIP_ROPE_SCALE:-$ROPE_SCALE}" +export GRAD_CLIP_NORM="${GRAD_CLIP_NORM:-0.0}" +export MTP_DEPTH="${MTP_DEPTH:-0}" +export MTP_LOSS_WEIGHT="${MTP_LOSS_WEIGHT:-0.25}" +export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-0}" +export SHARED_DEPTH_N="${SHARED_DEPTH_N:-0}" +export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.0}" +export SHARED_DEPTH_EDGE_UNIQUE="${SHARED_DEPTH_EDGE_UNIQUE:-0}" +export TTT_LORA_RANK="${TTT_LORA_RANK:-8}" +export TTT_LORA_LR="${TTT_LORA_LR:-0.01}" +export TTT_CHUNK_SIZE="${TTT_CHUNK_SIZE:-256}" +export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-$EVAL_SEQ_LEN}" +export TTT_ROPE_SCALING="${TTT_ROPE_SCALING:-$ROUNDTRIP_ROPE_SCALING}" +export TTT_ROPE_SCALE="${TTT_ROPE_SCALE:-$ROUNDTRIP_ROPE_SCALE}" +export TTT_BATCH_SIZE="${TTT_BATCH_SIZE:-64}" + +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + +echo "Starting run ${RUN_ID} with ${NPROC_PER_NODE} GPU(s)" +torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" train_gpt.py diff --git a/scripts/run_remote_profile.sh b/scripts/run_remote_profile.sh new file mode 100755 index 000000000..6ffd4666b --- /dev/null +++ b/scripts/run_remote_profile.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +PROFILE="${1:-base10l}" +shift || true + +case "$PROFILE" in + base10l) + export RUN_ID="${RUN_ID:-base10l}" + ;; + zloss_low) + export RUN_ID="${RUN_ID:-zloss_low}" + export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0001}" + ;; + zloss_med) + export RUN_ID="${RUN_ID:-zloss_med}" + export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0003}" + ;; + twice_low) + export RUN_ID="${RUN_ID:-twice_low}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + ;; + twice_layerwise) + export RUN_ID="${RUN_ID:-twice_layerwise}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ATTN_TWICE_ALPHA_SLOPE="${ATTN_TWICE_ALPHA_SLOPE:-0.5}" + ;; + zloss_twice) + export RUN_ID="${RUN_ID:-zloss_twice}" + export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0001}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + ;; + eval2048) + export RUN_ID="${RUN_ID:-eval2048}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-2048}" + ;; + twice_eval2048_ttt1024) + export RUN_ID="${RUN_ID:-twice_eval2048_ttt1024}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + ;; + drope_eval) + export RUN_ID="${RUN_ID:-drope_eval}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export ROUNDTRIP_ROPE_SCALING="${ROUNDTRIP_ROPE_SCALING:-drope}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + ;; + yarn_eval) + export RUN_ID="${RUN_ID:-yarn_eval}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export ROUNDTRIP_ROPE_SCALING="${ROUNDTRIP_ROPE_SCALING:-yarn}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + ;; + mtp_low) + export RUN_ID="${RUN_ID:-mtp_low}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export MTP_DEPTH="${MTP_DEPTH:-2}" + export MTP_LOSS_WEIGHT="${MTP_LOSS_WEIGHT:-0.1}" + ;; + muon_balance) + export RUN_ID="${RUN_ID:-muon_balance}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export MUON_UPDATE_BALANCE="${MUON_UPDATE_BALANCE:-0.5}" + ;; + hybrid_delta) + export RUN_ID="${RUN_ID:-hybrid_delta}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-4}" + ;; + shared_depth) + export RUN_ID="${RUN_ID:-shared_depth}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export SHARED_DEPTH_N="${SHARED_DEPTH_N:-4}" + export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.15}" + ;; + shared_depth_midshare) + export RUN_ID="${RUN_ID:-shared_depth_midshare}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export SHARED_DEPTH_N="${SHARED_DEPTH_N:-4}" + export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.05}" + export SHARED_DEPTH_EDGE_UNIQUE="${SHARED_DEPTH_EDGE_UNIQUE:-2}" + ;; + *) + echo "Unknown profile: $PROFILE" >&2 + echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth shared_depth_midshare" >&2 + exit 1 + ;; +esac + +bash scripts/run_remote_experiment.sh "$@" diff --git a/scripts/run_smoke_profile.sh b/scripts/run_smoke_profile.sh new file mode 100644 index 000000000..af414fd74 --- /dev/null +++ b/scripts/run_smoke_profile.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +PROFILE="${1:-base10l}" +shift || true + +export RUN_ID="${RUN_ID:-${PROFILE}_smoke}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-90}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-50}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-200}" +export WARMUP_STEPS="${WARMUP_STEPS:-5}" +export SKIP_FINAL_EVAL="${SKIP_FINAL_EVAL:-1}" + +bash scripts/run_remote_profile.sh "$PROFILE" "$@" diff --git a/train_gpt.py b/train_gpt.py index 85e2cc463..516e9f827 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -27,17 +27,7 @@ from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") @@ -45,21 +35,21 @@ class Hyperparameters: run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) seed = int(os.environ.get("SEED", 1337)) - # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("ROUNDTRIP_EVAL_SEQ_LEN", os.environ.get("EVAL_SEQ_LEN", "0"))) + eval_stride = int(os.environ.get("ROUNDTRIP_EVAL_STRIDE", os.environ.get("EVAL_STRIDE", "0"))) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + skip_final_eval = bool(int(os.environ.get("SKIP_FINAL_EVAL", "0"))) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) num_layers = int(os.environ.get("NUM_LAYERS", 9)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) @@ -68,9 +58,14 @@ class Hyperparameters: mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + rope_scaling = os.environ.get("ROPE_SCALING", "ntk").lower() + rope_scale = float(os.environ.get("ROPE_SCALE", "1.0")) + roundtrip_rope_scaling = os.environ.get("ROUNDTRIP_ROPE_SCALING", rope_scaling).lower() + roundtrip_rope_scale = float(os.environ.get("ROUNDTRIP_ROPE_SCALE", os.environ.get("ROPE_SCALE", "1.0"))) + ttt_rope_scaling = os.environ.get("TTT_ROPE_SCALING", roundtrip_rope_scaling).lower() + ttt_rope_scale = float(os.environ.get("TTT_ROPE_SCALE", os.environ.get("ROUNDTRIP_ROPE_SCALE", os.environ.get("ROPE_SCALE", "1.0")))) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) @@ -79,30 +74,33 @@ class Hyperparameters: scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + z_loss_coef = float(os.environ.get("Z_LOSS_COEF", 0.0)) + attn_twice_alpha = float(os.environ.get("ATTN_TWICE_ALPHA", 0.0)) + attn_twice_alpha_slope = float(os.environ.get("ATTN_TWICE_ALPHA_SLOPE", 0.0)) + overtone_init_power = float(os.environ.get("OVERTONE_INIT_POWER", 0.0)) + resid_mix_init_scale = float(os.environ.get("RESID_MIX_INIT_SCALE", 3.0)) + mtp_depth = int(os.environ.get("MTP_DEPTH", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.25)) + muon_update_balance = float(os.environ.get("MUON_UPDATE_BALANCE", 0.0)) + hybrid_delta_every = int(os.environ.get("HYBRID_DELTA_EVERY", 0)) + shared_depth_n = int(os.environ.get("SHARED_DEPTH_N", 0)) + shared_depth_gain = float(os.environ.get("SHARED_DEPTH_GAIN", 0.0)) + shared_depth_edge_unique = int(os.environ.get("SHARED_DEPTH_EDGE_UNIQUE", 0)) - # Test-time training (LoRA) hyperparameters. ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", "1024"))) ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps @@ -117,10 +115,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, update_balance: float = 0.0, nesterov: bool = True): super().__init__( params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, update_balance=update_balance, nesterov=nesterov), ) @torch.no_grad() @@ -141,6 +139,7 @@ def step(self, closure=None): lr = group["lr"] momentum = group["momentum"] backend_steps = group["backend_steps"] + update_balance = group["update_balance"] nesterov = group["nesterov"] total_params = sum(int(p.numel()) for p in params) @@ -158,8 +157,11 @@ def step(self, closure=None): if nesterov: g = g.add(buf, alpha=momentum) g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. g *= max(1, g.size(0) / g.size(1)) ** 0.5 + if update_balance > 0: + g_rms = g.float().pow(2).mean().sqrt().clamp_min_(1e-8) + p_rms = p.detach().float().pow(2).mean().sqrt().clamp_min_(1e-8) + g *= (p_rms / g_rms).pow(update_balance) updates_flat[curr : curr + p.numel()] = g.reshape(-1) curr += p.numel() @@ -174,16 +176,6 @@ def step(self, closure=None): return loss - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device ) -> tuple[Tensor, Tensor, Tensor]: @@ -215,7 +207,6 @@ def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: @@ -234,19 +225,18 @@ def eval_val( base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + seq_len_override: int = 0, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge + seq_len = seq_len_override if seq_len_override > 0 else args.train_seq_len local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: + if local_batch_tokens < seq_len: raise ValueError( "VAL_BATCH_SIZE must provide at least one sequence per rank; " f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len seq_start = (total_seqs * rank) // world_size seq_end = (total_seqs * (rank + 1)) // world_size val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) @@ -257,11 +247,11 @@ def eval_val( with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): batch_loss = model(x, y).detach() batch_token_count = float(y.numel()) @@ -278,19 +268,60 @@ def eval_val( dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + val_loss = val_loss_sum / val_token_count + return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (val_token_count.item() / val_byte_count.item())) + + +def eval_val_sliding( + logits_fn, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + stride: int, + eval_batch_seqs: int, +) -> tuple[float, float]: + windows: list[tuple[int, int]] = [] + total = val_tokens.numel() - 1 + pos = 0 + while pos + seq_len <= total: + windows.append((pos, 0 if pos == 0 else seq_len - stride)) + pos += stride + + per_rank = (len(windows) + world_size - 1) // world_size + my_windows = windows[rank * per_rank : min((rank + 1) * per_rank, len(windows))] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for i in range(0, len(my_windows), eval_batch_seqs): + batch = my_windows[i : i + eval_batch_seqs] + x = torch.stack([val_tokens[p : p + seq_len] for p, _ in batch]).to(device=device, dtype=torch.int64) + y = torch.stack([val_tokens[p + 1 : p + seq_len + 1] for p, _ in batch]).to(device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = logits_fn(x) + for b, (_, score_offset) in enumerate(batch): + scored_targets = y[b, score_offset:] + loss_sum += F.cross_entropy(logits[b, score_offset:].float(), scored_targets, reduction="sum").to(torch.float64) + tok_count += scored_targets.numel() + prev = x[b, score_offset : score_offset + scored_targets.numel()] + token_bytes = base_bytes_lut[scored_targets].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[scored_targets] & ~is_boundary_token_lut[prev]).to(dtype=torch.int16) + byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = loss_sum / tok_count + return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (tok_count.item() / byte_count.item())) -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern @@ -328,8 +359,6 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() @@ -340,18 +369,12 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - # Vectors / scalars use a simpler per-tensor scale. clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes quantized: dict[str, Tensor] = {} scales: dict[str, Tensor] = {} dtypes: dict[str, str] = {} @@ -375,8 +398,13 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): stats["int8_payload_bytes"] += tensor_nbytes(t) continue - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if name == "tok_emb.weight": + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: kept = keep_float_tensor(name, t, passthrough_orig_dtypes) passthrough[name] = kept @@ -414,13 +442,11 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: s = obj["scales"][name] if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() else: scale = float(s.item()) out[name] = (q.float() * scale).to(dtype=dtype).contiguous() for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. out_t = t.detach().to("cpu").contiguous() orig_dtype = passthrough_orig_dtypes.get(name) if isinstance(orig_dtype, str): @@ -428,16 +454,10 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out[name] = out_t return out - -# ----------------------------- -# DATA LOADING -# ----------------------------- - def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" Tensor: class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: @@ -482,8 +500,6 @@ def take(self, n: int) -> Tensor: class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size @@ -500,9 +516,6 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): @@ -514,30 +527,45 @@ def forward(self, x: Tensor) -> Tensor: class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. def forward(self, x: Tensor) -> Tensor: bias = self.bias.to(x.dtype) if self.bias is not None else None return F.linear(x, self.weight.to(x.dtype), bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: param.data = param.data.float() +def export_state_dict(module: nn.Module) -> dict[str, Tensor]: + return {name: tensor for name, tensor in module.state_dict().items() if not name.startswith("mtp_heads.")} + +def load_exported_state_dict(module: nn.Module, state_dict: dict[str, Tensor]) -> None: + missing, unexpected = module.load_state_dict(state_dict, strict=False) + bad_missing = [name for name in missing if not name.startswith("mtp_heads.")] + if bad_missing or unexpected: + raise RuntimeError(f"Export reload mismatch missing={bad_missing} unexpected={list(unexpected)}") class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, scaling: str = "ntk", scale: float = 1.0): super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.scaling = scaling + self.scale = scale inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None + def set_scaling(self, scaling: str, scale: float) -> None: + if self.scaling != scaling or self.scale != scale: + self.scaling, self.scale = scaling, scale + self._seq_len_cached, self._cos_cached, self._sin_cached = 0, None, None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None @@ -545,8 +573,23 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup or self._seq_len_cached != seq_len or self._cos_cached.device != device ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) + stretch = max(seq_len / self.train_seq_len, 1.0) + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + if seq_len > self.train_seq_len: + if self.scaling == "yarn": + t = t / (stretch * self.scale) + elif self.scaling == "drope": + t = t / ((stretch ** 0.5) * self.scale) + inv_freq = inv_freq / (stretch ** 0.25) + else: + adjusted_base = self.base * ((stretch * self.scale) ** (self.dim / (self.dim - 2))) + inv_freq = 1.0 / ( + adjusted_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + elif self.scale != 1.0: + t = t / self.scale + freqs = torch.outer(t, inv_freq) self._cos_cached = freqs.cos()[None, None, :, :] self._sin_cached = freqs.sin()[None, None, :, :] self._seq_len_cached = seq_len @@ -557,8 +600,11 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - +def lm_loss(logits: Tensor, targets: Tensor, z_loss_coef: float, reduction: str) -> Tensor: + losses = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none").reshape_as(targets) + if z_loss_coef > 0: + losses = losses + z_loss_coef * torch.logsumexp(logits.float(), dim=-1).square() + return losses if reduction == "none" else losses.mean() class CausalSelfAttention(nn.Module): def __init__( self, @@ -567,6 +613,9 @@ def __init__( num_kv_heads: int, rope_base: float, qk_gain_init: float, + train_seq_len: int, + rope_scaling: str, + rope_scale: float, ): super().__init__() if dim % num_heads != 0: @@ -585,7 +634,7 @@ def __init__( self.proj = CastedLinear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, scaling=rope_scaling, scale=rope_scale) def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: bsz, seqlen, dim = x.shape @@ -601,20 +650,15 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) + if self.num_kv_heads != self.num_heads: + k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = mlp_mult * dim @@ -625,8 +669,21 @@ def __init__(self, dim: int, mlp_mult: int): def forward(self, x: Tensor) -> Tensor: x = torch.relu(self.fc(x)) return self.proj(x.square()) +class DeltaMixer(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.in_proj = CastedLinear(dim, 2 * dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True - + def forward(self, x: Tensor) -> Tensor: + u, g = self.in_proj(x).chunk(2, dim=-1) + g = torch.sigmoid(g.float()).to(dtype=x.dtype) + state, outs = torch.zeros_like(u[:, 0]), [] + for t in range(x.size(1)): + state = g[:, t] * state + (1 - g[:, t]) * u[:, t] + outs.append(state) + return self.proj(torch.stack(outs, dim=1)) class Block(nn.Module): def __init__( self, @@ -636,11 +693,18 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, + train_seq_len: int, + attn_twice_alpha: float, + rope_scaling: str, + rope_scale: float, + hybrid_delta: bool, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.use_delta = hybrid_delta + self.attn = DeltaMixer(dim) if hybrid_delta else CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, rope_scaling, rope_scale) + self.attn_twice_alpha = attn_twice_alpha self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) @@ -650,9 +714,14 @@ def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Te mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) + if self.use_delta: + attn_out = self.attn(n) + else: + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + if self.attn_twice_alpha != 0.0: + attn_out = attn_out + self.attn_twice_alpha * (n - attn_out) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x @@ -672,18 +741,42 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, + train_seq_len: int, + z_loss_coef: float, + attn_twice_alpha: float, + attn_twice_alpha_slope: float, + overtone_init_power: float, + resid_mix_init_scale: float, + rope_scaling: str, + rope_scale: float, + mtp_depth: int, + mtp_loss_weight: float, + hybrid_delta_every: int, + shared_depth_n: int, + shared_depth_gain: float, + shared_depth_edge_unique: int, ): super().__init__() if logit_softcap <= 0.0: raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std + self.z_loss_coef = z_loss_coef + self.overtone_init_power = overtone_init_power + self.resid_mix_init_scale = resid_mix_init_scale + self.mtp_loss_weight = mtp_loss_weight self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.logical_num_layers = num_layers + edge = min(shared_depth_edge_unique, num_layers // 2) if shared_depth_n > 0 else 0 + shared_blocks = min(shared_depth_n, max(num_layers - 2 * edge, 1)) if shared_depth_n > 0 else num_layers self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.block_map = list(range(edge)) + [edge + (i % shared_blocks) for i in range(max(num_layers - 2 * edge, 0))] + [edge + shared_blocks + i for i in range(edge)] if shared_depth_n > 0 else list(range(num_layers)) + self.pass_scales = nn.Parameter((1.0 + shared_depth_gain * torch.linspace(-1, 1, num_layers)[:, None]).expand(-1, model_dim).contiguous()) if shared_depth_n > 0 else None + self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_depth)]) self.blocks = nn.ModuleList( [ Block( @@ -693,8 +786,13 @@ def __init__( mlp_mult, rope_base, qk_gain_init, + train_seq_len, + attn_twice_alpha * (1 + attn_twice_alpha_slope * (2 * i / max(num_layers - 1, 1) - 1)), + rope_scaling, + rope_scale, + hybrid_delta_every > 0 and (i + 1) % hybrid_delta_every == 0, ) - for i in range(num_layers) + for i in range(max(self.block_map) + 1) ] ) self.final_norm = RMSNorm() @@ -702,54 +800,82 @@ def __init__( if self.lm_head is not None: self.lm_head._zero_init = True self._init_weights() - def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + if self.overtone_init_power > 0.0: + with torch.no_grad(): + u, s, v = torch.linalg.svd(self.tok_emb.weight.data, full_matrices=False) + self.tok_emb.weight.data = (u * (s[0] * torch.arange(1, s.shape[0] + 1, dtype=s.dtype).pow(-self.overtone_init_power))[None, :]) @ v for module in self.modules(): if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + for i, block in enumerate(self.blocks): + with torch.no_grad(): + phase = torch.sigmoid(torch.tensor(self.resid_mix_init_scale * (i / max(len(self.blocks) - 1, 1) - 0.5))) + block.resid_mix.data[0].fill_(float(phase)) + block.resid_mix.data[1].fill_(float(1.0 - phase)) + def set_rope_scaling(self, scaling: str, scale: float) -> None: + for block in self.blocks: + if not block.use_delta: + block.attn.rotary.set_scaling(scaling, scale) + + def _forward_hidden(self, input_ids: Tensor, lora=None) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. for i in range(self.num_encoder_layers): + block = self.blocks[self.block_map[i]] qd = lora.q_loras[i] if lora else None vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) + x = block(x, x0, qd, vd) + if self.pass_scales is not None: + x = self.pass_scales[i].to(dtype=x.dtype)[None, None, :] * x skips.append(x) for i in range(self.num_decoder_layers): bi = self.num_encoder_layers + i if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + block = self.blocks[self.block_map[bi]] qd = lora.q_loras[bi] if lora else None vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) - x = self.final_norm(x) + x = block(x, x0, qd, vd) + if self.pass_scales is not None: + x = self.pass_scales[bi].to(dtype=x.dtype)[None, None, :] * x + return self.final_norm(x) + + def _hidden_to_logits(self, x: Tensor, lora=None) -> Tensor: if self.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") - + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self._forward_hidden(input_ids, lora=lora) + logits = self._hidden_to_logits(x, lora=lora) + loss = lm_loss(logits, target_ids, self.z_loss_coef, reduction="none" if lora else "mean") + if lora or not self.training or not self.mtp_heads or self.mtp_loss_weight <= 0: + return loss + aux = 0.0 + depth = 1 + for head in self.mtp_heads: + if target_ids.size(1) <= depth: + break + aux = aux + lm_loss( + self.logit_softcap * torch.tanh(head(x[:, :-depth]) / self.logit_softcap), + target_ids[:, depth:], + 0.0, + reduction="mean", + ) + depth += 1 + return loss + self.mtp_loss_weight * aux / max(len(self.mtp_heads), 1) + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self._forward_hidden(input_ids) + return self._hidden_to_logits(x) BOS_ID = 1 class BatchedLinearLoRA(nn.Module): @@ -771,6 +897,10 @@ def reset(self) -> None: self.A.uniform_(-bound, bound) # kaiming-uniform self.B.zero_() + +class NullLoRA(nn.Module): + def forward(self, x: Tensor) -> int: + return 0 class BatchedTTTLoRA(nn.Module): """All LoRA adapters for one batch: LM head and Q/V per block.""" def __init__(self, bsz: int, model: GPT, rank: int): @@ -780,15 +910,19 @@ def __init__(self, bsz: int, model: GPT, rank: int): self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) self.q_loras = nn.ModuleList() self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + for i in range(model.logical_num_layers): + block = model.blocks[model.block_map[i]] + if getattr(block, "use_delta", False): + self.q_loras.append(NullLoRA()) + self.v_loras.append(NullLoRA()) + else: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) def reset(self) -> None: for m in self.modules(): if isinstance(m, BatchedLinearLoRA): m.reset() - def _reset_ttt_optimizer(opt): for group in opt.param_groups: for p in group['params']: @@ -798,10 +932,8 @@ def _reset_ttt_optimizer(opt): s['exp_avg'].zero_() s['exp_avg_sq'].zero_() s['step'].fill_(0) - def _build_ttt_optimizer(lora, args: Hyperparameters): return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: """Return (start_offset, length) for each document, identified by BOS boundaries. @@ -818,7 +950,6 @@ def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[ assert end - start >= 2 docs.append((start, end - start)) return docs - def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" chunk_start = ci * chunk_size @@ -828,7 +959,6 @@ def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: i chunk_offset = chunk_start - win_start chunk_len = chunk_end - chunk_start return win_start, win_len, chunk_offset, chunk_len - def _accumulate_bpb( ptl: Tensor, x: Tensor, y: Tensor, batch_i: int, chunk_offset: int, chunk_len: int, @@ -844,7 +974,6 @@ def _accumulate_bpb( loss_sum += lbl.sum() byte_sum += tok_bytes.sum() token_count += chunk_len - def eval_val_ttt_lora( args: Hyperparameters, base_model: GPT, @@ -856,12 +985,10 @@ def eval_val_ttt_lora( is_boundary_token_lut: Tensor, ) -> tuple[float, float]: """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries files = sorted(glob.glob(args.val_files)) all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) docs = _find_docs(all_tokens) - # Each rank takes a contiguous slice of documents rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] chunk_size = args.ttt_chunk_size eval_seq_len = args.ttt_eval_seq_len @@ -919,7 +1046,6 @@ def eval_val_ttt_lora( y[b, :wl] = toks[1:] doc_info.append((co, cl)) - # Forward pass (keep grad graph alive only when we need to train) if needs_train: with torch.autocast(device_type="cuda", dtype=torch.bfloat16): ptl = base_model(x, y, lora=cur_lora) @@ -927,7 +1053,6 @@ def eval_val_ttt_lora( with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): ptl = base_model(x, y, lora=cur_lora) - # Score: accumulate loss and byte counts for BPB (before training on chunk) with torch.no_grad(): for b in range(bsz): if not active[b]: @@ -937,7 +1062,6 @@ def eval_val_ttt_lora( ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, loss_sum, byte_sum, token_count) - # Train: one Adam step on the LoRA params using this chunk's loss if needs_train: mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) @@ -953,22 +1077,12 @@ def eval_val_ttt_lora( val_loss = float(loss_sum.item() / token_count.item()) val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - def main() -> None: global zeropower_via_newtonschulz5 code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -988,7 +1102,6 @@ def main() -> None: dist.barrier() master_process = rank == 0 - # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp @@ -1012,7 +1125,6 @@ def log0(msg: str, console: bool = True) -> None: if logfile is not None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) - log0(code, console=False) log0("=" * 100, console=False) log0(f"Running Python {sys.version}", console=False) @@ -1023,9 +1135,6 @@ def log0(msg: str, console: bool = True) -> None: ) log0("=" * 100, console=False) - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- random.seed(args.seed) np.random.seed(args.seed) @@ -1049,9 +1158,6 @@ def log0(msg: str, console: bool = True) -> None: log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- base_model = GPT( vocab_size=args.vocab_size, @@ -1065,6 +1171,20 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + train_seq_len=args.train_seq_len, + z_loss_coef=args.z_loss_coef, + attn_twice_alpha=args.attn_twice_alpha, + attn_twice_alpha_slope=args.attn_twice_alpha_slope, + overtone_init_power=args.overtone_init_power, + resid_mix_init_scale=args.resid_mix_init_scale, + rope_scaling=args.rope_scaling, + rope_scale=args.rope_scale, + mtp_depth=args.mtp_depth, + mtp_loss_weight=args.mtp_loss_weight, + hybrid_delta_every=args.hybrid_delta_every, + shared_depth_n=args.shared_depth_n, + shared_depth_gain=args.shared_depth_gain, + shared_depth_edge_unique=args.shared_depth_edge_unique, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): @@ -1072,14 +1192,9 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, Rotary): module.inv_freq.data = module.inv_freq.data.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.hybrid_delta_every <= 0 else base_model model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ p @@ -1093,6 +1208,8 @@ def log0(msg: str, console: bool = True) -> None: ] if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + if base_model.pass_scales is not None: + scalar_params.append(base_model.pass_scales) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], @@ -1105,6 +1222,7 @@ def log0(msg: str, console: bool = True) -> None: lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, + update_balance=args.muon_update_balance, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr @@ -1115,6 +1233,15 @@ def log0(msg: str, console: bool = True) -> None: fused=True, ) optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if len(base_model.mtp_heads) > 0: + mtp_lr = args.head_lr if args.head_lr > 0 else args.scalar_lr + optimizer_mtp = torch.optim.Adam( + [{"params": list(base_model.mtp_heads.parameters()), "lr": mtp_lr, "base_lr": mtp_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_mtp) if base_model.lm_head is not None: optimizer_head = torch.optim.Adam( [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], @@ -1134,17 +1261,23 @@ def log0(msg: str, console: bool = True) -> None: f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" ) + log0( + f"roundtrip_eval_seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} " + f"roundtrip_eval_stride:{args.eval_stride} ttt_eval_seq_len:{args.ttt_eval_seq_len} " + f"rope_scaling:{args.rope_scaling} roundtrip_rope_scaling:{args.roundtrip_rope_scaling} " + f"ttt_rope_scaling:{args.ttt_rope_scaling} muon_weight_decay:{args.muon_weight_decay} " + f"muon_update_balance:{args.muon_update_balance} z_loss_coef:{args.z_loss_coef} " + f"attn_twice_alpha:{args.attn_twice_alpha} attn_twice_alpha_slope:{args.attn_twice_alpha_slope} " + f"mtp_depth:{args.mtp_depth} hybrid_delta_every:{args.hybrid_delta_every} shared_depth_n:{args.shared_depth_n} " + f"shared_depth_gain:{args.shared_depth_gain} shared_depth_edge_unique:{args.shared_depth_edge_unique} " + f"overtone_init_power:{args.overtone_init_power}" + ) log0( f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) def zero_grad_all() -> None: @@ -1164,8 +1297,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] @@ -1192,9 +1323,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- training_time_ms = 0.0 stop_after_step: int | None = None @@ -1263,6 +1391,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) for opt in optimizers: opt.step() + if args.muon_weight_decay > 0: + with torch.no_grad(): + shrink = 1.0 - args.muon_weight_decay * optimizer_muon.param_groups[0]["lr"] + for p in matrix_params: + p.mul_(shrink) zero_grad_all() step += 1 @@ -1277,7 +1410,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) - # Needed to sync whether we've reached the wallclock cap. reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) @@ -1290,22 +1422,15 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") + torch.save(export_state_dict(base_model), "final_model.pt") model_bytes = os.path.getsize("final_model.pt") code_bytes = len(code.encode("utf-8")) log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_obj, quant_stats = quantize_state_dict_int8(export_state_dict(base_model)) quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() @@ -1328,21 +1453,27 @@ def lr_mul(step: int, elapsed_ms: float) -> float: with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + load_exported_state_dict(base_model, dequantize_state_dict_int8(quant_state)) + if args.skip_final_eval: + log0("smoke_test: quantized reload ok; skipping final roundtrip and TTT eval") + if distributed: + dist.destroy_process_group() + return + eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_tokens_eval = load_validation_tokens(args.val_files, eval_seq_len) if eval_seq_len != args.train_seq_len else val_tokens torch.cuda.synchronize() t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) + if args.eval_stride > 0: + eval_batch_seqs = min(256, max(1, args.val_batch_size // eval_seq_len // max(world_size, 1))) + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False) if args.hybrid_delta_every <= 0 else base_model.forward_logits + warmup_x = torch.zeros(eval_batch_seqs, eval_seq_len, dtype=torch.int64, device=device) + base_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + _ = compiled_logits(warmup_x) + q_val_loss, q_val_bpb = eval_val_sliding(compiled_logits, rank, world_size, device, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, eval_seq_len, args.eval_stride, eval_batch_seqs) + base_model.train() + else: + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, seq_len_override=eval_seq_len if eval_seq_len != args.train_seq_len else 0) torch.cuda.synchronize() log0( f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " @@ -1350,7 +1481,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - # LoRA test-time training evaluation (the competition score) torch._dynamo.reset() torch.cuda.synchronize() t_ttt = time.perf_counter() @@ -1363,10 +1493,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" ) - if distributed: dist.destroy_process_group() - - if __name__ == "__main__": main()