From 9ef3c1365ec17756c8dbb14172c3204e74572def Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Fri, 20 Mar 2026 17:58:48 +0000 Subject: [PATCH 01/17] Checkpoint Runpod ablations and tooling --- EXPERIMENT_PLAN.md | 81 + REMOTE_RUNBOOK.md | 83 + .../REMOTE_CHECKPOINT.md | 8 + data/tokenizer_specs.ablation.json | 29 + program.md | 86 + records/_template/README.md | 38 + records/_template/submission.json | 20 + .../2026-03-20_Runpod1GPU_Ablations/README.md | 68 + .../base10l.tail.txt | 10 + .../results.json | 85 + .../submission.json | 17 + .../train_gpt.py | 1496 +++++++++++++++++ .../twice_eval2048.tail.txt | 10 + .../twice_eval2048_ttt1024.tail.txt | 10 + .../twice_low.tail.txt | 10 + .../zloss_low.tail.txt | 10 + scripts/package_record.sh | 25 + scripts/parse_run.py | 50 + scripts/rebuild_tokenizer_export.sh | 16 + scripts/remote_fetch_data.sh | 10 + scripts/run_and_score.sh | 20 + scripts/run_remote_experiment.sh | 45 + scripts/run_remote_profile.sh | 43 + train_gpt.py | 214 ++- 24 files changed, 2439 insertions(+), 45 deletions(-) create mode 100644 EXPERIMENT_PLAN.md create mode 100644 REMOTE_RUNBOOK.md create mode 100644 checkpoints/2026-03-20_twice_eval2048_ttt1024/REMOTE_CHECKPOINT.md create mode 100644 data/tokenizer_specs.ablation.json create mode 100644 program.md create mode 100644 records/_template/README.md create mode 100644 records/_template/submission.json create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/README.md create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/base10l.tail.txt create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/results.json create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/submission.json create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/train_gpt.py create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048.tail.txt create mode 100644 
records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048_ttt1024.tail.txt create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_low.tail.txt create mode 100644 records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/zloss_low.tail.txt create mode 100755 scripts/package_record.sh create mode 100755 scripts/parse_run.py create mode 100755 scripts/rebuild_tokenizer_export.sh create mode 100755 scripts/remote_fetch_data.sh create mode 100755 scripts/run_and_score.sh create mode 100755 scripts/run_remote_experiment.sh create mode 100755 scripts/run_remote_profile.sh diff --git a/EXPERIMENT_PLAN.md b/EXPERIMENT_PLAN.md new file mode 100644 index 000000000..18cc73b76 --- /dev/null +++ b/EXPERIMENT_PLAN.md @@ -0,0 +1,81 @@ +# Experiment Plan + +This plan is optimized for limited budget and the challenge rules. + +## Goals + +- Improve `final_int8_zlib_roundtrip_exact val_bpb` +- Improve `final_int8_ttt_lora val_bpb` +- Stay under the `16,000,000` byte artifact cap +- Avoid risky dataset changes until the safe path is exhausted + +## Model Ablations First + +Run these in order on remote GPUs: + +1. `base10l` +2. `zloss_low` +3. `zloss_med` +4. `twice_low` +5. `zloss_twice` +6. `eval2048` + +Use the named launcher: + +```bash +NPROC_PER_NODE=1 bash scripts/run_remote_profile.sh base10l +NPROC_PER_NODE=1 bash scripts/run_remote_profile.sh zloss_low +``` + +Then repeat the best 2-3 profiles on `8x H100 SXM`. + +## Dataset And Tokenizer Work + +The challenge allows tokenizer or dataset changes, but the repo says they will be examined carefully and you must prove the `val_bpb` calculation is correct. See [README.md](/Users/deividasmataciunas/Desktop/research/openai_golf/README.md#L168). 
+ +Safest path: + +- Rebuild tokenizers from the published docs cache only +- Re-export shards from the same selected docs +- Keep validation on the fixed first `50k` docs + +Use: + +```bash +bash scripts/rebuild_tokenizer_export.sh +``` + +Default ablation config: + +- `sp_bpe_768` +- `sp_bpe_1024` +- `sp_bpe_1280` +- `sp_bpe_1536` +- `pure_byte_260` + +## Dataset Ideas That Look Safe + +- Vary tokenizer vocab size on the same published docs +- Compare pure-byte vs SentencePiece BPE +- Train on a prefix of shards, then do a short final stage on a higher-quality subset from the same docs +- Filter obviously low-value docs from the training side only +- Keep document boundaries clean during training and eval + +## Risky Ideas + +- External corpora +- Changing validation docs +- Any data use at eval time beyond what the rules allow +- Tokenizer changes without exact byte-accounting validation + +## Success Metrics + +For each run, record: + +- `val_bpb` +- `final_int8_zlib_roundtrip_exact val_bpb` +- `final_int8_ttt_lora val_bpb` +- `Total submission size int8+zlib` +- `step_avg` + +If a tokenizer change helps pre-quant quality but hurts artifact bytes, reject it early. diff --git a/REMOTE_RUNBOOK.md b/REMOTE_RUNBOOK.md new file mode 100644 index 000000000..8cf164efe --- /dev/null +++ b/REMOTE_RUNBOOK.md @@ -0,0 +1,83 @@ +# Remote Runbook + +This repo is ready for the CUDA path. + +## Recommended Path + +Use the official Runpod Parameter Golf template mentioned in [README.md](/Users/deividasmataciunas/Desktop/research/openai_golf/README.md). + +Start with one of these: + +- `1x H100`: cheapest realistic sanity-check path for code, logs, artifact size, and eval behavior. +- `8x H100 SXM`: record-track run once the recipe looks stable. 
+
+## First-Time Remote Setup
+
+On the remote box:
+
+```bash
+cd /workspace
+git clone https://github.com/openai/parameter-golf.git
+cd parameter-golf
+git remote add myfork <your-fork-url>
+git fetch myfork
+git checkout <your-branch>
+```
+
+Then hydrate the published cache:
+
+```bash
+TRAIN_SHARDS=1 bash scripts/remote_fetch_data.sh
+```
+
+For a fuller training prefix:
+
+```bash
+TRAIN_SHARDS=10 bash scripts/remote_fetch_data.sh
+```
+
+## First Experiment
+
+This is the first recipe to run against our merged script:
+
+```bash
+NPROC_PER_NODE=1 bash scripts/run_remote_experiment.sh
+```
+
+For a full multi-GPU run:
+
+```bash
+NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh
+```
+
+## What This Recipe Uses
+
+- `10` layers
+- fp16 tied-embedding export
+- NTK-aware longer eval support
+- sliding-window eval with stride `64`
+- decoupled Muon weight decay
+- overtone embedding init
+- phase-shaped residual mixing init
+
+## First Ablations To Queue
+
+Run these one at a time after the first successful remote run:
+
+```bash
+EVAL_STRIDE=0 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh
+EVAL_SEQ_LEN=2048 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh
+NUM_LAYERS=9 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh
+MUON_WEIGHT_DECAY=0.00 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh
+OVERTONE_INIT_POWER=0.00 NPROC_PER_NODE=8 bash scripts/run_remote_experiment.sh
+```
+
+## What To Look For
+
+- `step_avg`
+- final `val_bpb`
+- final `final_int8_zlib_roundtrip_exact`
+- final `final_int8_ttt_lora`
+- total `int8+zlib` artifact bytes
+
+If you send me a remote log, I can turn it into the next ablation decision quickly.
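For triage, the handful of summary lines at the tail of the log are usually all that matters. A hedged sketch — the stand-in log below copies the line formats `train_gpt.py` prints, with values taken from the recorded `base10l` run:

```bash
# Stand-in tail of a training log; real runs emit lines in these formats.
cat > /tmp/run_tail.log <<'EOF'
Total submission size int8+zlib: 10831123 bytes
final_int8_zlib_roundtrip_exact val_loss:2.36884692 val_bpb:1.40296458
final_int8_ttt_lora val_loss:2.3598 val_bpb:1.3976 eval_time:507995ms
EOF

# Keep only the headline metric lines for a quick ablation decision.
grep -E 'val_bpb|Total submission size' /tmp/run_tail.log
```

Point the `grep` at the real remote log instead of the stand-in file when triaging a pod.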
diff --git a/checkpoints/2026-03-20_twice_eval2048_ttt1024/REMOTE_CHECKPOINT.md b/checkpoints/2026-03-20_twice_eval2048_ttt1024/REMOTE_CHECKPOINT.md new file mode 100644 index 000000000..e5cd65a11 --- /dev/null +++ b/checkpoints/2026-03-20_twice_eval2048_ttt1024/REMOTE_CHECKPOINT.md @@ -0,0 +1,8 @@ +Best raw checkpoint metadata captured from the Runpod pod before shutdown: + +- Run ID: `twice_eval2048_ttt1024` +- Remote path: `/workspace/parameter-golf/final_model.pt` +- Size on pod: `72M` +- SHA256: `292d79fa54a638be348354f09d185f80b69710e7de8f4dfa42b36e43afccdc96` + +The raw `.pt` file itself was not copied into this repo because Runpod's SSH wrapper blocked automated binary transfer through `scp`. If you want to preserve the raw checkpoint, keep the pod or its volume alive until we manually copy it out tomorrow. diff --git a/data/tokenizer_specs.ablation.json b/data/tokenizer_specs.ablation.json new file mode 100644 index 000000000..ae09809c8 --- /dev/null +++ b/data/tokenizer_specs.ablation.json @@ -0,0 +1,29 @@ +{ + "tokenizers": [ + { + "name": "sp_bpe_768", + "dataset_suffix": "sp768", + "vocab_size": 768 + }, + { + "name": "sp_bpe_1024", + "dataset_suffix": "sp1024", + "vocab_size": 1024 + }, + { + "name": "sp_bpe_1280", + "dataset_suffix": "sp1280", + "vocab_size": 1280 + }, + { + "name": "sp_bpe_1536", + "dataset_suffix": "sp1536", + "vocab_size": 1536 + }, + { + "name": "pure_byte_260", + "dataset_suffix": "byte260", + "kind": "pure_byte" + } + ] +} diff --git a/program.md b/program.md new file mode 100644 index 000000000..42a44c932 --- /dev/null +++ b/program.md @@ -0,0 +1,86 @@ +# Parameter Golf Research Program + +You are working inside the OpenAI Parameter Golf repository. 
+ +## Objective + +Improve the challenge score under these constraints: + +- optimize `final_int8_ttt_lora val_bpb` +- optimize `final_int8_zlib_roundtrip_exact val_bpb` +- keep `Total submission size int8+zlib` under `16,000,000` bytes +- preserve reproducibility + +Lower `val_bpb` is better. + +## Primary Rules + +1. Prefer small, ablation-friendly changes. +2. Keep changes concentrated in `train_gpt.py` unless there is a strong reason not to. +3. Reject changes that improve one metric but badly regress the other. +4. Reject changes that push artifact size toward the budget without a clear score win. +5. Do not change the validation set. +6. Treat tokenizer or dataset changes as higher-risk and require stronger evidence. + +## Current Priors + +- Sliding-window evaluation is high value. +- FP16 tied embedding export is high value. +- 10-layer small models are promising. +- Decoupled Muon weight decay is promising. +- `ATTN_TWICE_ALPHA=0.05` currently looks better than baseline. +- `Z_LOSS_COEF=0.0001` currently looks worse than baseline. + +## Current Best Known Local Results + +- `base10l` + - `roundtrip_val_bpb = 1.40296458` + - `ttt_val_bpb = 1.3976` + - `artifact_bytes = 10831123` + +- `twice_low` + - `roundtrip_val_bpb = 1.40177526` + - `ttt_val_bpb = 1.3969` + - `artifact_bytes = 10836065` + +## Experiment Order + +1. `twice_eval2048` +2. best `twice_*` variant on more seeds +3. training-context and batch tradeoff ablations +4. 
tokenizer ablations on published docs cache + +## Allowed Edit Zones + +- architecture details in `train_gpt.py` +- training schedule and optimizer settings +- quantization/export logic +- evaluation logic +- remote profile scripts + +## High-Risk Areas + +- external datasets +- validation handling +- complex multi-file refactors +- changes that increase code size substantially + +## Decision Policy + +Keep a change only if at least one is true: + +- `final_int8_ttt_lora` improves and `roundtrip_exact` does not materially regress +- `roundtrip_exact` improves and `ttt` does not materially regress +- artifact size drops meaningfully with near-flat score + +Reject a change if: + +- both `ttt` and `roundtrip_exact` regress +- artifact size grows with no score benefit +- it adds a lot of complexity without measurable value + +## Logging And Packaging + +- Use `scripts/run_remote_profile.sh` or `scripts/run_and_score.sh` +- Parse logs with `scripts/parse_run.py` +- Package strong candidates with `scripts/package_record.sh` diff --git a/records/_template/README.md b/records/_template/README.md new file mode 100644 index 000000000..414676107 --- /dev/null +++ b/records/_template/README.md @@ -0,0 +1,38 @@ +# Submission Name + +One-paragraph summary of the idea and why it matters for Parameter Golf. + +## Key Techniques + +1. Technique 1 +2. Technique 2 +3. 
Technique 3 + +## Results + +| Seed | val_loss | val_bpb | Steps | ms/step | +|------|----------|---------|-------|---------| +| 1337 | TBD | TBD | TBD | TBD | +| 42 | TBD | TBD | TBD | TBD | +| 7 | TBD | TBD | TBD | TBD | +| **Mean** | **TBD** | **TBD** | | | + +Artifact: `TBD` bytes | Eval time: `TBD` + +## Configuration + +```bash +# Paste the exact training command here +``` + +## Notes + +- Explain artifact accounting if needed +- Explain tokenizer/dataset changes if any +- Explain evaluation procedure if non-standard + +## Included Files + +- `train_gpt.py` +- `submission.json` +- `train_seed*.log` diff --git a/records/_template/submission.json b/records/_template/submission.json new file mode 100644 index 000000000..50c1a941f --- /dev/null +++ b/records/_template/submission.json @@ -0,0 +1,20 @@ +{ + "track": "10min_16mb", + "date": "YYYY-MM-DD", + "name": "Submission Name", + "author": "Your Name", + "github_id": "YourGitHubID", + "seed_results": { + "1337": { + "val_loss": 0.0, + "val_bpb": 0.0, + "steps": 0, + "ms_per_step": 0.0 + } + }, + "mean_val_loss": 0.0, + "mean_val_bpb": 0.0, + "p_value": 1.0, + "artifact_bytes": 0, + "code_bytes": 0 +} diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/README.md b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/README.md new file mode 100644 index 000000000..7f6e83e3e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/README.md @@ -0,0 +1,68 @@ +This folder checkpoints the first Runpod exploratory ablations run on March 20, 2026. + +These were not leaderboard-attempt runs. They used `1x H100`, `TRAIN_SHARDS=1`, and a strict `600s` wallclock cap to de-risk code changes and compare low-cost ablations before spending 8-GPU budget. 
+ +Best observations: +- Best roundtrip score: `twice_eval2048` at `final_int8_zlib_roundtrip_exact val_bpb: 1.39909070` +- Best LoRA-TTT score: `twice_eval2048_ttt1024` at `final_int8_ttt_lora val_bpb: 1.3960` +- Best balanced profile so far: `twice_eval2048_ttt1024` + +Key takeaways: +- `ATTN_TWICE_ALPHA=0.05` helped versus baseline. +- `Z_LOSS_COEF=0.0001` regressed both roundtrip and TTT metrics. +- `EVAL_SEQ_LEN=2048` improved roundtrip scoring but hurt TTT when `TTT_EVAL_SEQ_LEN` also increased. +- Splitting eval settings (`EVAL_SEQ_LEN=2048`, `TTT_EVAL_SEQ_LEN=1024`) recovered the best TTT result seen so far. + +Best current training command: +```bash +RUN_ID=twice_eval2048_ttt1024 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +MAX_WALLCLOCK_SECONDS=600 \ +TRAIN_LOG_EVERY=100 \ +VAL_LOSS_EVERY=1000 \ +TRAIN_BATCH_TOKENS=524288 \ +TRAIN_SEQ_LEN=1024 \ +EVAL_SEQ_LEN=2048 \ +EVAL_STRIDE=64 \ +NUM_LAYERS=10 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=2 \ +TIE_EMBEDDINGS=1 \ +MATRIX_LR=0.02 \ +SCALAR_LR=0.02 \ +TIED_EMBED_LR=0.03 \ +WARMDOWN_ITERS=2500 \ +MUON_MOMENTUM=0.95 \ +MUON_BACKEND_STEPS=5 \ +MUON_WEIGHT_DECAY=0.02 \ +OVERTONE_INIT_POWER=0.5 \ +RESID_MIX_INIT_SCALE=3.0 \ +Z_LOSS_COEF=0.0 \ +ATTN_TWICE_ALPHA=0.05 \ +TTT_LORA_RANK=8 \ +TTT_LORA_LR=0.01 \ +TTT_CHUNK_SIZE=256 \ +TTT_EVAL_SEQ_LEN=1024 \ +TTT_BATCH_SIZE=64 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +``` + +Checkpoint note: +- The raw best checkpoint still lives on the Runpod pod at `/workspace/parameter-golf/final_model.pt`. +- Size on pod: `72M` +- SHA256: `292d79fa54a638be348354f09d185f80b69710e7de8f4dfa42b36e43afccdc96` +- Runpod's SSH wrapper blocked automated binary transfer, so this repo checkpoint stores the exact metrics and checkpoint manifest rather than the raw `.pt` file. 
+ +Included files: +- `train_gpt.py` +- `submission.json` +- `results.json` +- `base10l.tail.txt` +- `zloss_low.tail.txt` +- `twice_low.tail.txt` +- `twice_eval2048.tail.txt` +- `twice_eval2048_ttt1024.tail.txt` diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/base10l.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/base10l.tail.txt new file mode 100644 index 000000000..c506fad36 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/base10l.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600528ms step:1061/20000 +peak memory allocated: 14656 MiB reserved: 15982 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10772539 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10831123 bytes +final_int8_zlib_roundtrip val_loss:2.3688 val_bpb:1.4030 eval_time:12984ms +final_int8_zlib_roundtrip_exact val_loss:2.36884692 val_bpb:1.40296458 +final_int8_ttt_lora val_loss:2.3598 val_bpb:1.3976 eval_time:507995ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/results.json b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/results.json new file mode 100644 index 000000000..349938dca --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/results.json @@ -0,0 +1,85 @@ +{ + "checkpoint_note": { + "best_raw_checkpoint_path": "/workspace/parameter-golf/final_model.pt", + "best_raw_checkpoint_size_human": "72M", + "best_raw_checkpoint_sha256": "292d79fa54a638be348354f09d185f80b69710e7de8f4dfa42b36e43afccdc96", + "transfer_status": "remote-only" + }, + "runs": [ + { + "run_id": "base10l", + "artifact_bytes": 10831123, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12984, + "roundtrip_val_bpb": 1.40296458, + "roundtrip_val_loss": 2.36884692, + "step_stop": 1061, + "total_train_time_ms": 600528, + 
"ttt_eval_time_ms": 507995, + "ttt_val_bpb": 1.3976, + "ttt_val_loss": 2.3598 + }, + { + "run_id": "zloss_low", + "artifact_bytes": 10804697, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 13026, + "roundtrip_val_bpb": 1.40313158, + "roundtrip_val_loss": 2.3691289, + "step_stop": 1058, + "total_train_time_ms": 600386, + "ttt_eval_time_ms": 508243, + "ttt_val_bpb": 1.3984, + "ttt_val_loss": 2.3611 + }, + { + "run_id": "twice_low", + "artifact_bytes": 10836065, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12989, + "roundtrip_val_bpb": 1.40177526, + "roundtrip_val_loss": 2.3668388, + "step_stop": 1063, + "total_train_time_ms": 600242, + "ttt_eval_time_ms": 508663, + "ttt_val_bpb": 1.3969, + "ttt_val_loss": 2.3586 + }, + { + "run_id": "twice_eval2048", + "artifact_bytes": 10924367, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12988, + "roundtrip_val_bpb": 1.3990907, + "roundtrip_val_loss": 2.36230604, + "step_stop": 1085, + "total_train_time_ms": 600071, + "ttt_eval_time_ms": 755165, + "ttt_val_bpb": 1.3995, + "ttt_val_loss": 2.363 + }, + { + "run_id": "twice_eval2048_ttt1024", + "artifact_bytes": 10868875, + "code_bytes": 58584, + "roundtrip_eval_time_ms": 12982, + "roundtrip_val_bpb": 1.4013468, + "roundtrip_val_loss": 2.36611538, + "step_stop": 1075, + "total_train_time_ms": 600439, + "ttt_eval_time_ms": 508393, + "ttt_val_bpb": 1.396, + "ttt_val_loss": 2.3571 + } + ], + "summary": { + "best_balanced_profile": "twice_eval2048_ttt1024", + "best_roundtrip_profile": "twice_eval2048", + "best_ttt_profile": "twice_eval2048_ttt1024", + "winner_notes": [ + "ATTN_TWICE_ALPHA=0.05 outperformed baseline.", + "Z_LOSS_COEF=0.0001 regressed both metrics.", + "Longer roundtrip eval context and shorter TTT eval context worked best together." 
+ ] + } +} diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/submission.json b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/submission.json new file mode 100644 index 000000000..a08ec022e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/submission.json @@ -0,0 +1,17 @@ +{ + "author": "Deividas Mataciunas", + "github_id": "DeividasMat", + "name": "Runpod 1GPU Twicing Eval Split", + "blurb": "Exploratory 1x H100, 1-shard, 10-minute checkpoint run. The best balanced profile used ATTN_TWICE_ALPHA=0.05 with EVAL_SEQ_LEN=2048 and TTT_EVAL_SEQ_LEN=1024, reaching final_int8_ttt_lora val_bpb 1.3960 under the 16MB artifact cap.", + "date": "2026-03-20T18:30:00Z", + "track": "non-record-1gpu-16mb", + "val_loss": 2.3571, + "val_bpb": 1.3960, + "pre_quant_val_loss": 2.3308, + "pre_quant_val_bpb": 1.3805, + "step_stop": 1075, + "wallclock_seconds": 600.439, + "bytes_total": 10868875, + "bytes_model_int8_zlib": 10810291, + "bytes_code": 58584 +} diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/train_gpt.py b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/train_gpt.py new file mode 100644 index 000000000..5fee3aa8e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/train_gpt.py @@ -0,0 +1,1496 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0")) + eval_stride = int(os.environ.get("EVAL_STRIDE", "0")) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + z_loss_coef = float(os.environ.get("Z_LOSS_COEF", 0.0)) + attn_twice_alpha = float(os.environ.get("ATTN_TWICE_ALPHA", 0.0)) + overtone_init_power = float(os.environ.get("OVERTONE_INIT_POWER", 0.0)) + resid_mix_init_scale = float(os.environ.get("RESID_MIX_INIT_SCALE", 3.0)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", os.environ.get("EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", "1024")))) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. 
+ # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
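The nats-to-bits-per-byte conversion the evaluators use is just a change of log base plus a tokens-to-bytes rescale. A minimal sketch of the arithmetic (the numbers below are illustrative, not from a real run):

```python
import math

def bpb_from_loss(val_loss_nats: float, token_count: float, byte_count: float) -> float:
    """Convert mean per-token cross-entropy (natural log) into bits-per-byte:
    nats -> bits via dividing by ln(2), then rescale by tokens per byte."""
    return (val_loss_nats / math.log(2.0)) * (token_count / byte_count)

# e.g. a BPE that averages ~2.4 bytes per token at loss ~2.37 nats:
print(round(bpb_from_loss(2.3688, 1_000_000, 2_400_000), 4))
```

A coarser tokenizer lowers the per-byte cost for the same per-token loss, which is exactly why the byte accounting above has to be airtight.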
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len_override: int = 0, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + seq_len = seq_len_override if seq_len_override > 0 else args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) 
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + val_loss = val_loss_sum / val_token_count + return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (val_token_count.item() / val_byte_count.item())) + + +def eval_val_sliding( + logits_fn, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + stride: int, + eval_batch_seqs: int, +) -> tuple[float, float]: + windows: list[tuple[int, int]] = [] + total = val_tokens.numel() - 1 + pos = 0 + while pos + seq_len <= total: + windows.append((pos, 0 if pos == 0 else seq_len - stride)) + pos += stride + + per_rank = (len(windows) + world_size - 1) // world_size + my_windows = windows[rank * per_rank : min((rank + 1) * per_rank, len(windows))] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for i in range(0, len(my_windows), eval_batch_seqs): + batch = my_windows[i : i + eval_batch_seqs] + x = torch.stack([val_tokens[p : p + seq_len] for p, _ in batch]).to(device=device, 
dtype=torch.int64)
+            y = torch.stack([val_tokens[p + 1 : p + seq_len + 1] for p, _ in batch]).to(device=device, dtype=torch.int64)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                logits = logits_fn(x)
+            for b, (_, score_offset) in enumerate(batch):
+                scored_targets = y[b, score_offset:]
+                loss_sum += F.cross_entropy(logits[b, score_offset:].float(), scored_targets, reduction="sum").to(torch.float64)
+                tok_count += scored_targets.numel()
+                prev = x[b, score_offset : score_offset + scored_targets.numel()]
+                token_bytes = base_bytes_lut[scored_targets].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[scored_targets] & ~is_boundary_token_lut[prev]).to(dtype=torch.int16)
+                byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(tok_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
+    val_loss = loss_sum / tok_count
+    return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (tok_count.item() / byte_count.item()))
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, having come in under the size limit.
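The scheme the exporter below implements can be sketched end-to-end on a toy weight matrix: clip each row at a high percentile of its absolute values, derive one scale per row, round to int8, zlib the payload, then dequantize for evaluation. This is a standalone illustration, not the exporter itself; the `64x128` matrix and the `0.9999` clip quantile are made-up demo values rather than the real constants used below.

```python
import zlib

import torch

# Toy round trip of the export idea: per-row symmetric int8 with percentile
# clipping, zlib over the int8 payload, then dequantization for evaluation.
torch.manual_seed(0)
w = torch.randn(64, 128)  # stand-in for one 2D weight matrix

clip_q = 0.9999  # illustrative clip quantile, not the exporter's constant
clip_abs = torch.quantile(w.abs(), clip_q, dim=1)   # one clip value per row
scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)   # per-row quantization step
clipped = torch.minimum(torch.maximum(w, -clip_abs[:, None]), clip_abs[:, None])
q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8)

payload = zlib.compress(q.numpy().tobytes())  # what would land in the artifact
w_hat = q.float() * scale[:, None]            # dequantized weights for eval
rel_err = ((w - w_hat).norm() / w.norm()).item()

print(f"fp32 bytes: {w.numel() * 4}, compressed int8 bytes: {len(payload)}")
print(f"relative reconstruction error: {rel_err:.4f}")
```

The per-row scale is what makes this cheap to store alongside the payload: one fp16 number per output row, versus the 4x byte cost of keeping the row in fp32.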
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if name == "tok_emb.weight": + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Each shard is a 256-int32 header followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    tokens = np.frombuffer(file.read_bytes(), dtype="<u2", offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Walks the shard files in sorted order, handing out tokens sequentially and
+    # wrapping back to the first shard once the last one is exhausted.
+    def __init__(self, pattern: str):
+        self.files = sorted(Path(p) for p in glob.glob(pattern))
+        if not self.files:
+            raise FileNotFoundError(f"No data shards match pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + adjusted_base = self.base * (scale ** (self.dim / (self.dim - 2))) + inv_freq = 1.0 / ( + adjusted_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +def lm_loss(logits: Tensor, targets: Tensor, z_loss_coef: float, reduction: str) -> Tensor: + losses = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none").reshape_as(targets) + if z_loss_coef > 0: + losses = losses + z_loss_coef * torch.logsumexp(logits.float(), dim=-1).square() + return losses if reduction == "none" else losses.mean() +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + 
num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + train_seq_len: int, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.num_kv_heads != self.num_heads: + k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return 
self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + train_seq_len: int, + attn_twice_alpha: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.attn_twice_alpha = attn_twice_alpha + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + if self.attn_twice_alpha != 0.0: + attn_out = attn_out + self.attn_twice_alpha * (n - attn_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: 
float, + rope_base: float, + qk_gain_init: float, + train_seq_len: int, + z_loss_coef: float, + attn_twice_alpha: float, + overtone_init_power: float, + resid_mix_init_scale: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.z_loss_coef = z_loss_coef + self.overtone_init_power = overtone_init_power + self.resid_mix_init_scale = resid_mix_init_scale + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + attn_twice_alpha, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + if self.overtone_init_power > 0.0: + with torch.no_grad(): + u, s, v = torch.linalg.svd(self.tok_emb.weight.data, full_matrices=False) + self.tok_emb.weight.data = (u * (s[0] * torch.arange(1, s.shape[0] + 1, dtype=s.dtype).pow(-self.overtone_init_power))[None, :]) @ v + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + for i, block in enumerate(self.blocks): + with torch.no_grad(): + phase = 
torch.sigmoid(torch.tensor(self.resid_mix_init_scale * (i / max(len(self.blocks) - 1, 1) - 0.5))) + block.resid_mix.data[0].fill_(float(phase)) + block.resid_mix.data[1].fill_(float(1.0 - phase)) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + return lm_loss(logits, target_ids, self.z_loss_coef, reduction="none" if lora else "mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[bi](x, x0) + x = self.final_norm(x) + logits = F.linear(x, self.tok_emb.weight.to(x.dtype)) if self.tie_embeddings else self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# 
----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. 
+ continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = 
base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + 
chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, 
op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", 
encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + train_seq_len=args.train_seq_len, + z_loss_coef=args.z_loss_coef, + attn_twice_alpha=args.attn_twice_alpha, + overtone_init_power=args.overtone_init_power, + resid_mix_init_scale=args.resid_mix_init_scale, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": 
args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0(f"eval_seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} eval_stride:{args.eval_stride} muon_weight_decay:{args.muon_weight_decay} z_loss_coef:{args.z_loss_coef} attn_twice_alpha:{args.attn_twice_alpha} overtone_init_power:{args.overtone_init_power}") + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> 
float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- 
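The wallclock-aware warmdown above (`lr_mul`) is the subtlest piece of the schedule: with no wallclock cap it decays linearly over the final `WARMDOWN_ITERS` steps, and with a cap it estimates per-step cost from elapsed time and starts decaying once the remaining budget fits inside the projected warmdown window. A minimal standalone sketch of the same two branches; the default `iterations=20000` and `warmdown_iters=2500` are taken from the run scripts and log tails, not from the live `args`:

```python
# Standalone sketch of the lr_mul schedule above (not the training script itself).
# iterations/warmdown_iters defaults mirror WARMDOWN_ITERS=2500 and the
# "step:.../20000" seen in the log tails; max_wallclock_ms is optional.
def lr_mul(step, elapsed_ms, iterations=20000, warmdown_iters=2500, max_wallclock_ms=None):
    if warmdown_iters <= 0:
        return 1.0
    if max_wallclock_ms is None:
        # Step-indexed branch: linear decay over the last warmdown_iters steps.
        warmdown_start = max(iterations - warmdown_iters, 0)
        if warmdown_start <= step < iterations:
            return max((iterations - step) / max(warmdown_iters, 1), 0.0)
        return 1.0
    # Time-indexed branch: estimate per-step cost, then decay once the remaining
    # wallclock budget fits inside the projected warmdown window.
    step_ms = elapsed_ms / max(step, 1)
    warmdown_ms = warmdown_iters * step_ms
    remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
    return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

full_lr = lr_mul(0, 0.0)                                   # before warmdown: 1.0
late_lr = lr_mul(19999, 0.0)                               # last step: 1/2500
capped = lr_mul(100, 60000.0, max_wallclock_ms=600000.0)   # 540000/1500000 = 0.36
```

A consequence of the time-indexed branch is that early-stopped runs (as in the `stopping_early: wallclock_cap` log lines) still get a full warmdown, just compressed into whatever budget remains.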
+ + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in 
opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + if args.muon_weight_decay > 0: + with torch.no_grad(): + shrink = 1.0 - args.muon_weight_decay * optimizer_muon.param_groups[0]["lr"] + for p in matrix_params: + p.mul_(shrink) + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
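The roundtrip this section performs (quantize to int8, `torch.save`, zlib level 9, then decompress/dequantize and re-evaluate) can be sketched in miniature. The per-tensor absmax scheme and the scale-header payload layout below are assumptions for illustration, not the exact format of `quantize_state_dict_int8`:

```python
import struct
import zlib

def quantize_int8(vals):
    """Symmetric absmax int8 quantization for one flat tensor (list of floats)."""
    scale = max((abs(v) for v in vals), default=0.0) / 127.0 or 1e-8
    return scale, [int(round(v / scale)) for v in vals]

def dequantize_int8(scale, q):
    return [x * scale for x in q]

vals = [0.5, -1.0, 0.25, 0.0]
scale, q = quantize_int8(vals)

# Serialize: float32 scale header + int8 payload, compressed with zlib level 9
# (hypothetical layout; the script stores a torch.save blob instead).
payload = struct.pack("<f", scale) + struct.pack(f"<{len(q)}b", *q)
blob = zlib.compress(payload, level=9)

# Roundtrip: decompress, unpack, dequantize; reconstruction error should stay
# within half a quantization step (plus float32 noise on the stored scale).
raw = zlib.decompress(blob)
r_scale = struct.unpack_from("<f", raw)[0]
r_q = list(struct.unpack_from(f"<{len(q)}b", raw, 4))
recon = dequantize_int8(r_scale, r_q)
max_err = max(abs(a - b) for a, b in zip(vals, recon))
```

Evaluating the roundtripped weights rather than the in-memory ones, as the section above does, is what makes `final_int8_zlib_roundtrip_exact` an honest score for the artifact that actually ships.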
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_tokens_eval = load_validation_tokens(args.val_files, eval_seq_len) if eval_seq_len != args.train_seq_len else val_tokens + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0: + eval_batch_seqs = min(256, max(1, args.val_batch_size // eval_seq_len // max(world_size, 1))) + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False) + warmup_x = torch.zeros(eval_batch_seqs, eval_seq_len, dtype=torch.int64, device=device) + base_model.eval() 
+ with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + _ = compiled_logits(warmup_x) + q_val_loss, q_val_bpb = eval_val_sliding(compiled_logits, rank, world_size, device, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, eval_seq_len, args.eval_stride, eval_batch_seqs) + base_model.train() + else: + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, seq_len_override=eval_seq_len if eval_seq_len != args.train_seq_len else 0) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048.tail.txt new file mode 100644 index 000000000..13ca9e496 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600071ms step:1085/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes 
+Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10865783 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10924367 bytes +final_int8_zlib_roundtrip val_loss:2.3623 val_bpb:1.3991 eval_time:12988ms +final_int8_zlib_roundtrip_exact val_loss:2.36230604 val_bpb:1.39909070 +final_int8_ttt_lora val_loss:2.3630 val_bpb:1.3995 eval_time:755165ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048_ttt1024.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048_ttt1024.tail.txt new file mode 100644 index 000000000..aebb3ce9f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_eval2048_ttt1024.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600439ms step:1075/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10810291 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10868875 bytes +final_int8_zlib_roundtrip val_loss:2.3661 val_bpb:1.4013 eval_time:12982ms +final_int8_zlib_roundtrip_exact val_loss:2.36611538 val_bpb:1.40134680 +final_int8_ttt_lora val_loss:2.3571 val_bpb:1.3960 eval_time:508393ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_low.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_low.tail.txt new file mode 100644 index 000000000..0cddd1513 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/twice_low.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600242ms step:1063/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 
10777481 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10836065 bytes +final_int8_zlib_roundtrip val_loss:2.3668 val_bpb:1.4018 eval_time:12989ms +final_int8_zlib_roundtrip_exact val_loss:2.36683880 val_bpb:1.40177526 +final_int8_ttt_lora val_loss:2.3586 val_bpb:1.3969 eval_time:508663ms diff --git a/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/zloss_low.tail.txt b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/zloss_low.tail.txt new file mode 100644 index 000000000..58b9eee35 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_Runpod1GPU_Ablations/zloss_low.tail.txt @@ -0,0 +1,10 @@ +stopping_early: wallclock_cap train_time:600386ms step:1058/20000 +peak memory allocated: 14528 MiB reserved: 14696 MiB +Serialized model: 74578510 bytes +Code size: 58584 bytes +Total submission size: 74637094 bytes +Serialized model int8+zlib: 10746113 bytes (payload:19030336 raw_torch:19079980 payload_ratio:3.92x) +Total submission size int8+zlib: 10804697 bytes +final_int8_zlib_roundtrip val_loss:2.3691 val_bpb:1.4031 eval_time:13026ms +final_int8_zlib_roundtrip_exact val_loss:2.36912890 val_bpb:1.40313158 +final_int8_ttt_lora val_loss:2.3611 val_bpb:1.3984 eval_time:508243ms diff --git a/scripts/package_record.sh b/scripts/package_record.sh new file mode 100755 index 000000000..eecc78554 --- /dev/null +++ b/scripts/package_record.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" +cd "$ROOT_DIR" + +if [ "$#" -lt 2 ]; then + echo "usage: bash scripts/package_record.sh <target_dir> <log1> [log2 ...]" >&2 + exit 1 +fi + +TARGET_DIR="$1" +shift + +mkdir -p "$TARGET_DIR" +cp records/_template/README.md "$TARGET_DIR/README.md" +cp records/_template/submission.json "$TARGET_DIR/submission.json" +cp train_gpt.py "$TARGET_DIR/train_gpt.py" + +for log_file in "$@"; do + cp "$log_file" "$TARGET_DIR/" +done + +echo "Packaged template into $TARGET_DIR" +echo "Next: fill README.md and submission.json with real metrics." diff --git a/scripts/parse_run.py b/scripts/parse_run.py new file mode 100755 index 000000000..b1167f042 --- /dev/null +++ b/scripts/parse_run.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import re +import sys +from pathlib import Path + + +PATTERNS = { + "roundtrip": re.compile(r"final_int8_zlib_roundtrip_exact val_loss:(\S+) val_bpb:(\S+)"), + "ttt": re.compile(r"final_int8_ttt_lora val_loss:(\S+) val_bpb:(\S+)"), + "artifact": re.compile(r"Total submission size int8\+zlib: (\d+) bytes"), + "step_avg": re.compile(r"step:(\d+)/\d+ val_loss:(\S+) val_bpb:(\S+) train_time:(\d+)ms step_avg:(\S+)ms"), +} + + +def parse_log(path: Path) -> dict[str, object]: + text = path.read_text(encoding="utf-8", errors="replace") + out: dict[str, object] = {"log": str(path)} + if m := PATTERNS["roundtrip"].search(text): + out["roundtrip_val_loss"] = float(m.group(1)) + out["roundtrip_val_bpb"] = float(m.group(2)) + if m := PATTERNS["ttt"].search(text): + out["ttt_val_loss"] = float(m.group(1)) + out["ttt_val_bpb"] = float(m.group(2)) + if m := PATTERNS["artifact"].search(text): + out["artifact_bytes"] = int(m.group(1)) + step_matches = PATTERNS["step_avg"].findall(text) + if step_matches: + step, val_loss, val_bpb, train_time_ms, step_avg = step_matches[-1] + out["last_val_step"] = int(step) + out["last_val_loss"] = float(val_loss) + out["last_val_bpb"] = float(val_bpb) + out["train_time_ms"] = int(train_time_ms) + 
out["step_avg_ms"] = float(step_avg) + return out + + +def main() -> int: + if len(sys.argv) < 2: + print("usage: python3 scripts/parse_run.py <log1> [log2 ...]", file=sys.stderr) + return 1 + rows = [parse_log(Path(arg)) for arg in sys.argv[1:]] + print(json.dumps(rows, indent=2, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/rebuild_tokenizer_export.sh b/scripts/rebuild_tokenizer_export.sh new file mode 100755 index 000000000..c1bf4b92f --- /dev/null +++ b/scripts/rebuild_tokenizer_export.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +REPO_ID="${REPO_ID:-willdepueoai/parameter-golf}" +REMOTE_ROOT="${REMOTE_ROOT:-datasets}" +OUTPUT_ROOT="${OUTPUT_ROOT:-$ROOT_DIR/data/exports/tokenizer_ablation}" +TOKENIZER_CONFIG="${TOKENIZER_CONFIG:-$ROOT_DIR/data/tokenizer_specs.ablation.json}" + +python3 data/download_hf_docs_and_tokenize.py \ + --repo-id "$REPO_ID" \ + --remote-root "$REMOTE_ROOT" \ + --output-root "$OUTPUT_ROOT" \ + --tokenizer-config "$TOKENIZER_CONFIG" diff --git a/scripts/remote_fetch_data.sh b/scripts/remote_fetch_data.sh new file mode 100755 index 000000000..0085bb8dd --- /dev/null +++ b/scripts/remote_fetch_data.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +TRAIN_SHARDS="${TRAIN_SHARDS:-1}" +VARIANT="${VARIANT:-sp1024}" + +python3 data/cached_challenge_fineweb.py --variant "$VARIANT" --train-shards "$TRAIN_SHARDS" diff --git a/scripts/run_and_score.sh b/scripts/run_and_score.sh new file mode 100755 index 000000000..b1031b038 --- /dev/null +++ b/scripts/run_and_score.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" +cd "$ROOT_DIR" + +PROFILE="${1:?usage: bash scripts/run_and_score.sh <profile>}" +shift || true + +bash scripts/run_remote_profile.sh "$PROFILE" "$@" + +LOG_PATH="logs/${RUN_ID:-${PROFILE}}.txt" +if [ ! -f "$LOG_PATH" ]; then + echo "log file not found: $LOG_PATH" >&2 + exit 1 +fi + +echo +echo "Parsed summary:" +python3 scripts/parse_run.py "$LOG_PATH" diff --git a/scripts/run_remote_experiment.sh b/scripts/run_remote_experiment.sh new file mode 100755 index 000000000..a00ffd1e7 --- /dev/null +++ b/scripts/run_remote_experiment.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export RUN_ID="${RUN_ID:-remote_10l_slide64}" +export DATA_PATH="${DATA_PATH:-./data/datasets/fineweb10B_sp1024}" +export TOKENIZER_PATH="${TOKENIZER_PATH:-./data/tokenizers/fineweb_1024_bpe.model}" +export VOCAB_SIZE="${VOCAB_SIZE:-1024}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-1000}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-1024}" +export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-1024}" +export EVAL_STRIDE="${EVAL_STRIDE:-64}" +export NUM_LAYERS="${NUM_LAYERS:-10}" +export MODEL_DIM="${MODEL_DIM:-512}" +export NUM_HEADS="${NUM_HEADS:-8}" +export NUM_KV_HEADS="${NUM_KV_HEADS:-4}" +export MLP_MULT="${MLP_MULT:-2}" +export TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" +export MATRIX_LR="${MATRIX_LR:-0.02}" +export SCALAR_LR="${SCALAR_LR:-0.02}" +export TIED_EMBED_LR="${TIED_EMBED_LR:-0.03}" +export WARMDOWN_ITERS="${WARMDOWN_ITERS:-2500}" +export MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}" +export MUON_BACKEND_STEPS="${MUON_BACKEND_STEPS:-5}" +export MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.02}" +export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0}" +export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.0}" +export OVERTONE_INIT_POWER="${OVERTONE_INIT_POWER:-0.5}" +export 
RESID_MIX_INIT_SCALE="${RESID_MIX_INIT_SCALE:-3.0}" +export GRAD_CLIP_NORM="${GRAD_CLIP_NORM:-0.0}" +export TTT_LORA_RANK="${TTT_LORA_RANK:-8}" +export TTT_LORA_LR="${TTT_LORA_LR:-0.01}" +export TTT_CHUNK_SIZE="${TTT_CHUNK_SIZE:-256}" +export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-$EVAL_SEQ_LEN}" +export TTT_BATCH_SIZE="${TTT_BATCH_SIZE:-64}" + +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + +echo "Starting run ${RUN_ID} with ${NPROC_PER_NODE} GPU(s)" +torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" train_gpt.py diff --git a/scripts/run_remote_profile.sh b/scripts/run_remote_profile.sh new file mode 100755 index 000000000..d93c4e233 --- /dev/null +++ b/scripts/run_remote_profile.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +PROFILE="${1:-base10l}" +shift || true + +case "$PROFILE" in + base10l) + export RUN_ID="${RUN_ID:-base10l}" + ;; + zloss_low) + export RUN_ID="${RUN_ID:-zloss_low}" + export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0001}" + ;; + zloss_med) + export RUN_ID="${RUN_ID:-zloss_med}" + export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0003}" + ;; + twice_low) + export RUN_ID="${RUN_ID:-twice_low}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + ;; + zloss_twice) + export RUN_ID="${RUN_ID:-zloss_twice}" + export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0001}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + ;; + eval2048) + export RUN_ID="${RUN_ID:-eval2048}" + export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-2048}" + ;; + *) + echo "Unknown profile: $PROFILE" >&2 + echo "Profiles: base10l zloss_low zloss_med twice_low zloss_twice eval2048" >&2 + exit 1 + ;; +esac + +bash scripts/run_remote_experiment.sh "$@" diff --git a/train_gpt.py b/train_gpt.py index 85e2cc463..5fee3aa8e 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -56,6 +56,8 @@ class Hyperparameters: warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = 
int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0")) + eval_stride = int(os.environ.get("EVAL_STRIDE", "0")) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) @@ -79,18 +81,23 @@ class Hyperparameters: scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + z_loss_coef = float(os.environ.get("Z_LOSS_COEF", 0.0)) + attn_twice_alpha = float(os.environ.get("ATTN_TWICE_ALPHA", 0.0)) + overtone_init_power = float(os.environ.get("OVERTONE_INIT_POWER", 0.0)) + resid_mix_init_scale = float(os.environ.get("RESID_MIX_INIT_SCALE", 3.0)) # Test-time training (LoRA) hyperparameters. 
ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", os.environ.get("EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", "1024")))) ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) # ----------------------------- @@ -234,19 +241,21 @@ def eval_val( base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + seq_len_override: int = 0, ) -> tuple[float, float]: # Validation computes two metrics: # - val_loss: token cross-entropy (natural log) # - val_bpb: tokenizer-agnostic compression metric used by the challenge + seq_len = seq_len_override if seq_len_override > 0 else args.train_seq_len local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: + if local_batch_tokens < seq_len: raise ValueError( "VAL_BATCH_SIZE must provide at least one sequence per rank; " f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len seq_start = (total_seqs * rank) // world_size seq_end = (total_seqs * (rank + 1)) // world_size val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) @@ -257,11 +266,11 @@ def eval_val( with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * 
args.train_seq_len + 1 + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): batch_loss = model(x, y).detach() batch_token_count = float(y.numel()) @@ -278,11 +287,59 @@ def eval_val( dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + val_loss = val_loss_sum / val_token_count + return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (val_token_count.item() / val_byte_count.item())) + + +def eval_val_sliding( + logits_fn, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + stride: int, + eval_batch_seqs: int, +) -> tuple[float, float]: + windows: list[tuple[int, int]] = [] + total = val_tokens.numel() - 1 + pos = 0 + while pos + seq_len <= total: + windows.append((pos, 0 if pos == 0 else seq_len - stride)) + pos += stride + + per_rank = (len(windows) + world_size - 1) // world_size + my_windows = windows[rank * per_rank : min((rank + 1) * per_rank, len(windows))] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for i in range(0, len(my_windows), eval_batch_seqs): + batch 
= my_windows[i : i + eval_batch_seqs] + x = torch.stack([val_tokens[p : p + seq_len] for p, _ in batch]).to(device=device, dtype=torch.int64) + y = torch.stack([val_tokens[p + 1 : p + seq_len + 1] for p, _ in batch]).to(device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = logits_fn(x) + for b, (_, score_offset) in enumerate(batch): + scored_targets = y[b, score_offset:] + loss_sum += F.cross_entropy(logits[b, score_offset:].float(), scored_targets, reduction="sum").to(torch.float64) + tok_count += scored_targets.numel() + prev = x[b, score_offset : score_offset + scored_targets.numel()] + token_bytes = base_bytes_lut[scored_targets].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[scored_targets] & ~is_boundary_token_lut[prev]).to(dtype=torch.int16) + byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = loss_sum / tok_count + return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (tok_count.item() / byte_count.item())) # ----------------------------- # POST-TRAINING QUANTIZATION @@ -375,6 +432,13 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): stats["int8_payload_bytes"] += tensor_nbytes(t) continue + if name == "tok_emb.weight": + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + # Small float tensors are cheap enough to keep directly. We still downcast # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
         if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
@@ -530,8 +594,11 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
 
 class Rotary(nn.Module):
     # Caches cos/sin tables per sequence length on the current device.
-    def __init__(self, dim: int, base: float = 10000.0):
+    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024):
         super().__init__()
+        self.dim = dim
+        self.base = base
+        self.train_seq_len = train_seq_len
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self._seq_len_cached = 0
@@ -545,8 +612,16 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup
             or self._seq_len_cached != seq_len
             or self._cos_cached.device != device
         ):
-            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
-            freqs = torch.outer(t, self.inv_freq.to(device))
+            if seq_len > self.train_seq_len:
+                scale = seq_len / self.train_seq_len
+                adjusted_base = self.base * (scale ** (self.dim / (self.dim - 2)))
+                inv_freq = 1.0 / (
+                    adjusted_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
+                )
+            else:
+                inv_freq = self.inv_freq.to(device)
+            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+            freqs = torch.outer(t, inv_freq)
             self._cos_cached = freqs.cos()[None, None, :, :]
             self._sin_cached = freqs.sin()[None, None, :, :]
             self._seq_len_cached = seq_len
@@ -557,8 +632,11 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
     half = x.size(-1) // 2
     x1, x2 = x[..., :half], x[..., half:]
     return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
-
-
+def lm_loss(logits: Tensor, targets: Tensor, z_loss_coef: float, reduction: str) -> Tensor:
+    losses = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none").reshape_as(targets)
+    if z_loss_coef > 0:
+        losses = losses + z_loss_coef * torch.logsumexp(logits.float(), dim=-1).square()
+    return losses if reduction == "none" else losses.mean()
 class CausalSelfAttention(nn.Module):
     def __init__(
         self,
@@ -567,6 +645,7 @@ def __init__(
         num_kv_heads: int,
         rope_base: float,
         qk_gain_init: float,
+        train_seq_len: int,
     ):
         super().__init__()
         if dim % num_heads != 0:
@@ -585,7 +664,7 @@ def __init__(
         self.proj = CastedLinear(dim, dim, bias=False)
         self.proj._zero_init = True
         self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rotary = Rotary(self.head_dim, base=rope_base)
+        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len)
 
     def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor:
         bsz, seqlen, dim = x.shape
@@ -601,14 +680,10 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor:
         q = apply_rotary_emb(q, cos, sin)
         k = apply_rotary_emb(k, cos, sin)
         q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
-        y = F.scaled_dot_product_attention(
-            q,
-            k,
-            v,
-            attn_mask=None,
-            is_causal=True,
-            enable_gqa=(self.num_kv_heads != self.num_heads),
-        )
+        if self.num_kv_heads != self.num_heads:
+            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
         y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
         return self.proj(y)
@@ -636,11 +711,14 @@ def __init__(
         mlp_mult: int,
         rope_base: float,
         qk_gain_init: float,
+        train_seq_len: int,
+        attn_twice_alpha: float,
     ):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len)
+        self.attn_twice_alpha = attn_twice_alpha
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -653,6 +731,8 @@ def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Te
         qd = q_delta_fn(n) if q_delta_fn is not None else None
         vd = v_delta_fn(n) if v_delta_fn is not None else None
         attn_out = self.attn(n, qd, vd)
+        if self.attn_twice_alpha != 0.0:
+            attn_out = attn_out + self.attn_twice_alpha * (n - attn_out)
         x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
         x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
         return x
@@ -672,12 +752,20 @@ def __init__(
         logit_softcap: float,
         rope_base: float,
         qk_gain_init: float,
+        train_seq_len: int,
+        z_loss_coef: float,
+        attn_twice_alpha: float,
+        overtone_init_power: float,
+        resid_mix_init_scale: float,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
         self.tie_embeddings = tie_embeddings
         self.tied_embed_init_std = tied_embed_init_std
+        self.z_loss_coef = z_loss_coef
+        self.overtone_init_power = overtone_init_power
+        self.resid_mix_init_scale = resid_mix_init_scale
         self.logit_softcap = logit_softcap
         self.tok_emb = nn.Embedding(vocab_size, model_dim)
         self.num_encoder_layers = num_layers // 2
@@ -693,6 +781,8 @@ def __init__(
                     mlp_mult,
                     rope_base,
                     qk_gain_init,
+                    train_seq_len,
+                    attn_twice_alpha,
                 )
                 for i in range(num_layers)
             ]
@@ -706,9 +796,18 @@ def __init__(
     def _init_weights(self) -> None:
         if self.tie_embeddings:
             nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        if self.overtone_init_power > 0.0:
+            with torch.no_grad():
+                u, s, v = torch.linalg.svd(self.tok_emb.weight.data, full_matrices=False)
+                self.tok_emb.weight.data = (u * (s[0] * torch.arange(1, s.shape[0] + 1, dtype=s.dtype).pow(-self.overtone_init_power))[None, :]) @ v
         for module in self.modules():
             if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                 nn.init.zeros_(module.weight)
+        for i, block in enumerate(self.blocks):
+            with torch.no_grad():
+                phase = torch.sigmoid(torch.tensor(self.resid_mix_init_scale * (i / max(len(self.blocks) - 1, 1) - 0.5)))
+                block.resid_mix.data[0].fill_(float(phase))
+                block.resid_mix.data[1].fill_(float(1.0 - phase))
 
     def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor:
         x = self.tok_emb(input_ids)
@@ -736,11 +835,24 @@ def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor:
             logits = self.lm_head(x)
         logits = logits + (lora.lm_head_lora(x) if lora else 0)
         logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
-        if lora:
-            bsz, sl, V = logits.shape
-            return F.cross_entropy(
-                logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl)
-        return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean")
+        return lm_loss(logits, target_ids, self.z_loss_coef, reduction="none" if lora else "mean")
+
+    def forward_logits(self, input_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            bi = self.num_encoder_layers + i
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[bi](x, x0)
+        x = self.final_norm(x)
+        logits = F.linear(x, self.tok_emb.weight.to(x.dtype)) if self.tie_embeddings else self.lm_head(x)
+        return self.logit_softcap * torch.tanh(logits / self.logit_softcap)
 
 
 # -----------------------------
@@ -1065,6 +1177,11 @@ def log0(msg: str, console: bool = True) -> None:
         logit_softcap=args.logit_softcap,
         rope_base=args.rope_base,
         qk_gain_init=args.qk_gain_init,
+        train_seq_len=args.train_seq_len,
+        z_loss_coef=args.z_loss_coef,
+        attn_twice_alpha=args.attn_twice_alpha,
+        overtone_init_power=args.overtone_init_power,
+        resid_mix_init_scale=args.resid_mix_init_scale,
     ).to(device).bfloat16()
     for module in base_model.modules():
         if isinstance(module, CastedLinear):
@@ -1134,6 +1251,7 @@ def log0(msg: str, console: bool = True) -> None:
         f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
         f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
     )
+    log0(f"eval_seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} eval_stride:{args.eval_stride} muon_weight_decay:{args.muon_weight_decay} z_loss_coef:{args.z_loss_coef} attn_twice_alpha:{args.attn_twice_alpha} overtone_init_power:{args.overtone_init_power}")
     log0(
         f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
         f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
@@ -1263,6 +1381,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
                 torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
             for opt in optimizers:
                 opt.step()
+            if args.muon_weight_decay > 0:
+                with torch.no_grad():
+                    shrink = 1.0 - args.muon_weight_decay * optimizer_muon.param_groups[0]["lr"]
+                    for p in matrix_params:
+                        p.mul_(shrink)
             zero_grad_all()
             step += 1
@@ -1329,20 +1452,21 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
         quant_blob_disk = f.read()
     quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
     base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len
+    val_tokens_eval = load_validation_tokens(args.val_files, eval_seq_len) if eval_seq_len != args.train_seq_len else val_tokens
     torch.cuda.synchronize()
     t_qeval = time.perf_counter()
-    q_val_loss, q_val_bpb = eval_val(
-        args,
-        model,
-        rank,
-        world_size,
-        device,
-        grad_accum_steps,
-        val_tokens,
-        base_bytes_lut,
-        has_leading_space_lut,
-        is_boundary_token_lut,
-    )
+    if args.eval_stride > 0:
+        eval_batch_seqs = min(256, max(1, args.val_batch_size // eval_seq_len // max(world_size, 1)))
+        compiled_logits = torch.compile(base_model.forward_logits, dynamic=False)
+        warmup_x = torch.zeros(eval_batch_seqs, eval_seq_len, dtype=torch.int64, device=device)
+        base_model.eval()
+        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+            _ = compiled_logits(warmup_x)
+        q_val_loss, q_val_bpb = eval_val_sliding(compiled_logits, rank, world_size, device, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, eval_seq_len, args.eval_stride, eval_batch_seqs)
+        base_model.train()
+    else:
+        q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, seq_len_override=eval_seq_len if eval_seq_len != args.train_seq_len else 0)
     torch.cuda.synchronize()
     log0(
         f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "

From c482c6a0cac3d49502ace5a55debb44f820c4944 Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Fri, 20 Mar 2026 20:00:09 +0000
Subject: [PATCH 02/17] Add MTP rope scaling and hybrid ablations

---
 EXPERIMENT_PLAN.md            |  27 ++-
 scripts/run_remote_profile.sh |  45 ++++-
 train_gpt.py                  | 321 +++++++++++++++++-----------------
 3 files changed, 226 insertions(+), 167 deletions(-)

diff --git a/EXPERIMENT_PLAN.md b/EXPERIMENT_PLAN.md
index 18cc73b76..00d9ff1c6 100644
--- a/EXPERIMENT_PLAN.md
+++ b/EXPERIMENT_PLAN.md
@@ -14,21 +14,30 @@ This plan is optimized for limited budget and the challenge rules.
 Run these in order on remote GPUs:
 
 1. `base10l`
-2. `zloss_low`
-3. `zloss_med`
-4. `twice_low`
-5. `zloss_twice`
-6. `eval2048`
+2. `twice_low`
+3. `twice_eval2048_ttt1024`
+4. `twice_layerwise`
+5. `drope_eval`
+6. `yarn_eval`
+7. `mtp_low`
+8. `muon_balance`
+9. `hybrid_delta`
 
 Use the named launcher:
 
 ```bash
 NPROC_PER_NODE=1 bash scripts/run_remote_profile.sh base10l
-NPROC_PER_NODE=1 bash scripts/run_remote_profile.sh zloss_low
+NPROC_PER_NODE=1 bash scripts/run_remote_profile.sh twice_eval2048_ttt1024
 ```
 
 Then repeat the best 2-3 profiles on `8x H100 SXM`.
 
+Interpretation order:
+
+- If `drope_eval` or `yarn_eval` beats `twice_eval2048_ttt1024`, keep the better rope scaling and discard the other.
+- If `mtp_low` wins, sweep `MTP_DEPTH=3` and `MTP_LOSS_WEIGHT` around `0.05-0.2`.
+- If `hybrid_delta` wins even slightly, open a dedicated hybrid branch before changing more optimizer knobs.
+
 ## Dataset And Tokenizer Work
 
 The challenge allows tokenizer or dataset changes, but the repo says they will be examined carefully and you must prove the `val_bpb` calculation is correct. See [README.md](/Users/deividasmataciunas/Desktop/research/openai_golf/README.md#L168).
@@ -53,6 +62,12 @@ Default ablation config:
 - `sp_bpe_1536`
 - `pure_byte_260`
 
+After the model-side shortlist settles, do these data sweeps:
+
+1. Rebuild `sp_bpe_768`, `sp_bpe_1280`, and `pure_byte_260`
+2. Rerun the current best profile on `TRAIN_SHARDS=1`
+3. Only promote tokenizer changes that help `final_int8_ttt_lora` without pushing artifact bytes in the wrong direction
+
 ## Dataset Ideas That Look Safe
 
 - Vary tokenizer vocab size on the same published docs
diff --git a/scripts/run_remote_profile.sh b/scripts/run_remote_profile.sh
index d93c4e233..0fde57cf4 100755
--- a/scripts/run_remote_profile.sh
+++ b/scripts/run_remote_profile.sh
@@ -23,6 +23,11 @@ case "$PROFILE" in
     export RUN_ID="${RUN_ID:-twice_low}"
     export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
     ;;
+  twice_layerwise)
+    export RUN_ID="${RUN_ID:-twice_layerwise}"
+    export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+    export ATTN_TWICE_ALPHA_SLOPE="${ATTN_TWICE_ALPHA_SLOPE:-0.5}"
+    ;;
   zloss_twice)
     export RUN_ID="${RUN_ID:-zloss_twice}"
     export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0001}"
@@ -30,12 +35,48 @@ case "$PROFILE" in
   eval2048)
     export RUN_ID="${RUN_ID:-eval2048}"
-    export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-2048}"
+    export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}"
     export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-2048}"
     ;;
+  twice_eval2048_ttt1024)
+    export RUN_ID="${RUN_ID:-twice_eval2048_ttt1024}"
+    export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+    export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}"
+    export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}"
+    ;;
+  drope_eval)
+    export RUN_ID="${RUN_ID:-drope_eval}"
+    export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+    export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}"
+    export ROUNDTRIP_ROPE_SCALING="${ROUNDTRIP_ROPE_SCALING:-drope}"
+    export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}"
+    ;;
+  yarn_eval)
+    export RUN_ID="${RUN_ID:-yarn_eval}"
+    export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+    export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}"
+    export ROUNDTRIP_ROPE_SCALING="${ROUNDTRIP_ROPE_SCALING:-yarn}"
+    export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}"
+    ;;
+  mtp_low)
+    export RUN_ID="${RUN_ID:-mtp_low}"
+    export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+    export MTP_DEPTH="${MTP_DEPTH:-2}"
+    export MTP_LOSS_WEIGHT="${MTP_LOSS_WEIGHT:-0.1}"
+    ;;
+  muon_balance)
+    export RUN_ID="${RUN_ID:-muon_balance}"
+    export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+    export MUON_UPDATE_BALANCE="${MUON_UPDATE_BALANCE:-0.5}"
+    ;;
+  hybrid_delta)
+    export RUN_ID="${RUN_ID:-hybrid_delta}"
+    export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+    export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-4}"
+    ;;
   *)
     echo "Unknown profile: $PROFILE" >&2
-    echo "Profiles: base10l zloss_low zloss_med twice_low zloss_twice eval2048" >&2
+    echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta" >&2
     exit 1
     ;;
 esac
diff --git a/train_gpt.py b/train_gpt.py
index 5fee3aa8e..fbbd69757 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -27,17 +27,8 @@
 from torch import Tensor, nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-# -----------------------------
-# HYPERPARAMETERS
-# -----------------------------
-# Default Simple Baseline run:
-# - 9 transformer blocks at width 512
-# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
-# - vocab size 1024, sequence length 1024, tied embeddings
-# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
 
 class Hyperparameters:
-    # Data paths are shard globs produced by the existing preprocessing pipeline.
     data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
     train_files = os.path.join(data_path, "fineweb_train_*.bin")
     val_files = os.path.join(data_path, "fineweb_val_*.bin")
@@ -45,23 +36,20 @@ class Hyperparameters:
     run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
     seed = int(os.environ.get("SEED", 1337))
 
-    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
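For reference, the z-loss term that the new `lm_loss` helper in this patch adds penalizes the squared log-partition of each logit row. A scalar sketch of the rule (illustrative only, not the repo's tensor code):

```python
import math

def z_loss_term(logit_row, coef=1e-4):
    # log_z = logsumexp(logits); the penalty coef * log_z**2 pulls the
    # softmax normalizer toward 1 (log_z toward 0) without moving the argmax.
    m = max(logit_row)
    log_z = m + math.log(sum(math.exp(l - m) for l in logit_row))
    return coef * log_z ** 2

# A single logit of 1.0 has log_z exactly 1, so the penalty equals coef.
penalty = z_loss_term([1.0], coef=1e-4)
```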
     val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
     val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
     train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))
 
-    # Training length.
     iterations = int(os.environ.get("ITERATIONS", 20000))
     warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
     warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
     train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
     train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
-    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0"))
-    eval_stride = int(os.environ.get("EVAL_STRIDE", "0"))
+    eval_seq_len = int(os.environ.get("ROUNDTRIP_EVAL_SEQ_LEN", os.environ.get("EVAL_SEQ_LEN", "0")))
+    eval_stride = int(os.environ.get("ROUNDTRIP_EVAL_STRIDE", os.environ.get("EVAL_STRIDE", "0")))
     max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
     qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
 
-    # Model shape.
     vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
     num_layers = int(os.environ.get("NUM_LAYERS", 9))
     num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
@@ -70,9 +58,14 @@ class Hyperparameters:
     mlp_mult = int(os.environ.get("MLP_MULT", 2))
     tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
     rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
+    rope_scaling = os.environ.get("ROPE_SCALING", "ntk").lower()
+    rope_scale = float(os.environ.get("ROPE_SCALE", "1.0"))
+    roundtrip_rope_scaling = os.environ.get("ROUNDTRIP_ROPE_SCALING", rope_scaling).lower()
+    roundtrip_rope_scale = float(os.environ.get("ROUNDTRIP_ROPE_SCALE", os.environ.get("ROPE_SCALE", "1.0")))
+    ttt_rope_scaling = os.environ.get("TTT_ROPE_SCALING", roundtrip_rope_scaling).lower()
+    ttt_rope_scale = float(os.environ.get("TTT_ROPE_SCALE", os.environ.get("ROUNDTRIP_ROPE_SCALE", os.environ.get("ROPE_SCALE", "1.0"))))
     logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
 
-    # Optimizer hyperparameters.
     embed_lr = float(os.environ.get("EMBED_LR", 0.6))
     head_lr = float(os.environ.get("HEAD_LR", 0.008))
     tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
@@ -90,26 +83,22 @@ class Hyperparameters:
     grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))
     z_loss_coef = float(os.environ.get("Z_LOSS_COEF", 0.0))
     attn_twice_alpha = float(os.environ.get("ATTN_TWICE_ALPHA", 0.0))
+    attn_twice_alpha_slope = float(os.environ.get("ATTN_TWICE_ALPHA_SLOPE", 0.0))
     overtone_init_power = float(os.environ.get("OVERTONE_INIT_POWER", 0.0))
     resid_mix_init_scale = float(os.environ.get("RESID_MIX_INIT_SCALE", 3.0))
+    mtp_depth = int(os.environ.get("MTP_DEPTH", 0))
+    mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.25))
+    muon_update_balance = float(os.environ.get("MUON_UPDATE_BALANCE", 0.0))
+    hybrid_delta_every = int(os.environ.get("HYBRID_DELTA_EVERY", 0))
 
-    # Test-time training (LoRA) hyperparameters.
     ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8))
     ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01))
     ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256))
-    ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", os.environ.get("EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", "1024"))))
+    ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", "1024")))
     ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64))
 
 
-# -----------------------------
-# MUON OPTIMIZER
-# -----------------------------
-#
-# As borrowed from modded-nanogpt
-# Background on Muon: https://kellerjordan.github.io/posts/muon/
 def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
-    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
-    # Muon uses this to normalize matrix-shaped gradients before applying them.
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.bfloat16()
     X /= X.norm() + eps
@@ -124,10 +113,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -
 
 
 class Muon(torch.optim.Optimizer):
-    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int, update_balance: float = 0.0, nesterov: bool = True):
         super().__init__(
             params,
-            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, update_balance=update_balance, nesterov=nesterov),
         )
 
     @torch.no_grad()
@@ -148,6 +137,7 @@ def step(self, closure=None):
             lr = group["lr"]
             momentum = group["momentum"]
             backend_steps = group["backend_steps"]
+            update_balance = group["update_balance"]
             nesterov = group["nesterov"]
 
             total_params = sum(int(p.numel()) for p in params)
@@ -165,8 +155,11 @@ def step(self, closure=None):
                 if nesterov:
                     g = g.add(buf, alpha=momentum)
                 g = zeropower_via_newtonschulz5(g, steps=backend_steps)
-                # Scale correction from Muon reference implementations.
                 g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                if update_balance > 0:
+                    g_rms = g.float().pow(2).mean().sqrt().clamp_min_(1e-8)
+                    p_rms = p.detach().float().pow(2).mean().sqrt().clamp_min_(1e-8)
+                    g *= (p_rms / g_rms).pow(update_balance)
                 updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                 curr += p.numel()
 
@@ -182,14 +175,6 @@ def step(self, closure=None):
         return loss
 
 
-# -----------------------------
-# TOKENIZER-AGNOSTIC EVALUATION SETUP
-# -----------------------------
-#
-# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
-# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
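The `update_balance` branch added to `Muon.step` above rescales each orthogonalized update toward the RMS of the parameter it applies to. A list-based sketch of just that rescaling rule (not the optimizer itself):

```python
def balance_update(update, param, balance=0.5, eps=1e-8):
    # g *= (rms(p) / rms(g)) ** balance, as in the hunk above; balance=0 is a
    # no-op, balance=1 matches the update's RMS to the parameter's RMS.
    rms = lambda xs: max((sum(x * x for x in xs) / len(xs)) ** 0.5, eps)
    scale = (rms(param) / rms(update)) ** balance
    return [x * scale for x in update]

# With balance=1 the rescaled update's RMS equals the parameter's RMS.
scaled = balance_update([2.0, 2.0], [0.5, 0.5], balance=1.0)
```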
-# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
-# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
 def build_sentencepiece_luts(
     sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
@@ -222,7 +207,6 @@ def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
     files = [Path(p) for p in sorted(glob.glob(pattern))]
     if not files:
         raise FileNotFoundError(f"No files found for pattern: {pattern}")
-    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
     tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
     usable = ((tokens.numel() - 1) // seq_len) * seq_len
     if usable <= 0:
@@ -243,9 +227,6 @@ def eval_val(
     is_boundary_token_lut: Tensor,
     seq_len_override: int = 0,
 ) -> tuple[float, float]:
-    # Validation computes two metrics:
-    # - val_loss: token cross-entropy (natural log)
-    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
     seq_len = seq_len_override if seq_len_override > 0 else args.train_seq_len
     local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
     if local_batch_tokens < seq_len:
@@ -341,13 +322,6 @@ def eval_val_sliding(
     val_loss = loss_sum / tok_count
     return float(val_loss.item()), float((val_loss.item() / math.log(2.0)) * (tok_count.item() / byte_count.item()))
 
-# -----------------------------
-# POST-TRAINING QUANTIZATION
-# -----------------------------
-#
-# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
-# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
-# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
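The LUT-based byte accounting used by both eval paths credits each scored token its base byte length plus one byte for an implied leading space, unless the previous token already ends at a boundary. A minimal sketch of that rule with toy lookup tables (not the SentencePiece-derived ones):

```python
def token_byte_counts(targets, prevs, base_bytes, has_leading_space, is_boundary):
    # Per scored position: base_bytes[t] + (has_leading_space[t] and not is_boundary[prev]),
    # matching the LUT arithmetic in the sliding eval hunk.
    return [
        base_bytes[t] + (1 if has_leading_space[t] and not is_boundary[p] else 0)
        for t, p in zip(targets, prevs)
    ]

# Token 1 normally carries a leading space (4 bytes total), but after a
# boundary token (prev=2) the space is not counted.
base = {1: 3, 2: 1}
lead = {1: True, 2: False}
bound = {1: False, 2: True}
counts = token_byte_counts([1, 1], [1, 2], base, lead, bound)
```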
 CONTROL_TENSOR_NAME_PATTERNS = tuple(
     pattern
@@ -385,8 +359,6 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s
 def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
     t32 = t.float()
     if t32.ndim == 2:
-        # Matrices get one scale per row, which usually tracks output-channel
-        # ranges much better than a single tensor-wide scale.
         clip_abs = (
             torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
             if t32.numel()
@@ -397,18 +369,12 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
         q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
         return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
 
-    # Vectors / scalars use a simpler per-tensor scale.
     clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
     scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
     q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
     return q, scale
 
 def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
-    # Single supported clean-script export format:
-    # - per-row int8 for 2D float tensors
-    # - per-tensor int8 for other float tensors
-    # - exact passthrough for non-floats
-    # - passthrough for small float tensors, stored as fp16 to save bytes
     quantized: dict[str, Tensor] = {}
     scales: dict[str, Tensor] = {}
     dtypes: dict[str, str] = {}
@@ -439,8 +405,6 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
             stats["int8_payload_bytes"] += tensor_nbytes(kept)
             continue
 
-        # Small float tensors are cheap enough to keep directly. We still downcast
-        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
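`quantize_float_tensor` stores one scale per matrix row. A simplified roundtrip sketch of per-row symmetric int8, using a max-abs clip where the script uses a quantile clip (`INT8_CLIP_Q`):

```python
def quantize_row(row):
    # Symmetric per-row int8: the scale maps the largest magnitude to 127.
    clip = max(abs(x) for x in row)
    scale = clip / 127.0 if clip > 0 else 1.0
    return [max(-127, min(127, round(x / scale))) for x in row], scale

def dequantize_row(q, scale):
    # Broadcast the row scale back, as dequantize_state_dict_int8 does.
    return [v * scale for v in q]

row = [0.5, -1.0, 0.25]
q, scale = quantize_row(row)
restored = dequantize_row(q, scale)
# Each restored value lies within half a quantization step of the original.
```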
         if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
             kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
             passthrough[name] = kept
@@ -478,13 +442,11 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
         s = obj["scales"][name]
         if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
             s = s.to(dtype=torch.float32)
-            # Broadcast the saved row scale back across trailing dimensions.
             out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
         else:
             scale = float(s.item())
             out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
     for name, t in obj["passthrough"].items():
-        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
         out_t = t.detach().to("cpu").contiguous()
         orig_dtype = passthrough_orig_dtypes.get(name)
         if isinstance(orig_dtype, str):
@@ -493,15 +455,11 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
     return out
 
 
-# -----------------------------
-# DATA LOADING
-# -----------------------------
 def load_data_shard(file: Path) -> Tensor:
     header_bytes = 256 * np.dtype("
 Tensor:
 
 class TokenStream:
-    # Reads shards sequentially and wraps around forever. The training loop therefore
-    # has deterministic, simple streaming behavior with no sampling or workers.
     def __init__(self, pattern: str):
         self.files = [Path(p) for p in sorted(glob.glob(pattern))]
         if not self.files:
@@ -546,8 +502,6 @@ def take(self, n: int) -> Tensor:
 
 class DistributedTokenLoader:
-    # Each call consumes a contiguous chunk from the shared token stream, then slices out
-    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
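`DistributedTokenLoader` relies on the extra-token trick described in the removed comment: each rank takes a disjoint span with one trailing token, so inputs and targets come from a single shift. A toy sketch:

```python
def rank_xy(tokens, rank, seq_len):
    # Rank r takes span [r*seq_len, r*seq_len + seq_len]; the "+1" trailing
    # token allows x = span[:-1], y = span[1:] without a second read.
    span = tokens[rank * seq_len : rank * seq_len + seq_len + 1]
    return span[:-1], span[1:]

x, y = rank_xy(list(range(20)), rank=1, seq_len=5)
# x is [5..9] and y is the same span shifted by one, [6..10].
```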
     def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
         self.rank = rank
         self.world_size = world_size
@@ -564,9 +518,6 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) ->
         y = local[1:].reshape(-1, seq_len)
         return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
 
-# -----------------------------
-# TRANSFORMER MODULES
-# -----------------------------
 
 class RMSNorm(nn.Module):
     def __init__(self, eps: float | None = None):
@@ -578,33 +529,41 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 class CastedLinear(nn.Linear):
-    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
     def forward(self, x: Tensor) -> Tensor:
         bias = self.bias.to(x.dtype) if self.bias is not None else None
         return F.linear(x, self.weight.to(x.dtype), bias)
 
 
 def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
-    # Keep small/control parameters in fp32 even when the model body runs in bf16.
     with torch.no_grad():
         for name, param in module.named_parameters():
             if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                 param.data = param.data.float()
 
 
+def export_state_dict(module: nn.Module) -> dict[str, Tensor]:
+    return {name: tensor for name, tensor in module.state_dict().items() if not name.startswith("mtp_heads.")}
+
+
 class Rotary(nn.Module):
-    # Caches cos/sin tables per sequence length on the current device.
-    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024):
+    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, scaling: str = "ntk", scale: float = 1.0):
         super().__init__()
         self.dim = dim
         self.base = base
         self.train_seq_len = train_seq_len
+        self.scaling = scaling
+        self.scale = scale
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self._seq_len_cached = 0
         self._cos_cached: Tensor | None = None
         self._sin_cached: Tensor | None = None
 
+    def set_scaling(self, scaling: str, scale: float) -> None:
+        if self.scaling != scaling or self.scale != scale:
+            self.scaling, self.scale = scaling, scale
+            self._seq_len_cached, self._cos_cached, self._sin_cached = 0, None, None
+
     def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
         if (
             self._cos_cached is None
@@ -612,15 +571,22 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup
             or self._seq_len_cached != seq_len
             or self._cos_cached.device != device
         ):
-            if seq_len > self.train_seq_len:
-                scale = seq_len / self.train_seq_len
-                adjusted_base = self.base * (scale ** (self.dim / (self.dim - 2)))
-                inv_freq = 1.0 / (
-                    adjusted_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
-                )
-            else:
-                inv_freq = self.inv_freq.to(device)
+            stretch = max(seq_len / self.train_seq_len, 1.0)
+            inv_freq = self.inv_freq.to(device)
             t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+            if seq_len > self.train_seq_len:
+                if self.scaling == "yarn":
+                    t = t / (stretch * self.scale)
+                elif self.scaling == "drope":
+                    t = t / ((stretch ** 0.5) * self.scale)
+                    inv_freq = inv_freq / (stretch ** 0.25)
+                else:
+                    adjusted_base = self.base * ((stretch * self.scale) ** (self.dim / (self.dim - 2)))
+                    inv_freq = 1.0 / (
+                        adjusted_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
+                    )
+            elif self.scale != 1.0:
+                t = t / self.scale
             freqs = torch.outer(t, inv_freq)
             self._cos_cached = freqs.cos()[None, None, :, :]
             self._sin_cached = freqs.sin()[None, None, :, :]
@@ -646,6 +612,8 @@ def __init__(
         rope_base: float,
         qk_gain_init: float,
         train_seq_len: int,
+        rope_scaling: str,
+        rope_scale: float,
     ):
         super().__init__()
         if dim % num_heads != 0:
@@ -664,7 +632,7 @@ def __init__(
         self.proj = CastedLinear(dim, dim, bias=False)
         self.proj._zero_init = True
         self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len)
+        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, scaling=rope_scaling, scale=rope_scale)
 
     def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor:
         bsz, seqlen, dim = x.shape
@@ -689,7 +657,6 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor:
 
 
 class MLP(nn.Module):
-    # relu^2 MLP from the original modded-nanogpt setup
     def __init__(self, dim: int, mlp_mult: int):
         super().__init__()
         hidden = mlp_mult * dim
@@ -702,6 +669,23 @@ def forward(self, x: Tensor) -> Tensor:
         return self.proj(x.square())
 
 
+class DeltaMixer(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.in_proj = CastedLinear(dim, 2 * dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        u, g = self.in_proj(x).chunk(2, dim=-1)
+        g = torch.sigmoid(g.float()).to(dtype=x.dtype)
+        state, outs = torch.zeros_like(u[:, 0]), []
+        for t in range(x.size(1)):
+            state = g[:, t] * state + (1 - g[:, t]) * u[:, t]
+            outs.append(state)
+        return self.proj(torch.stack(outs, dim=1))
+
+
 class Block(nn.Module):
     def __init__(
         self,
@@ -713,11 +697,15 @@ def __init__(
         qk_gain_init: float,
         train_seq_len: int,
         attn_twice_alpha: float,
+        rope_scaling: str,
+        rope_scale: float,
+        hybrid_delta: bool,
     ):
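The three extension branches in the new `Rotary.forward` differ only in where the stretch factor lands. A scalar restatement of the branch logic (note the `yarn`/`drope` names are this patch's own labels for position interpolation variants, not the published methods):

```python
def rope_extension(seq_len, train_seq_len, scaling, dim, base=10000.0, scale=1.0):
    # Mirrors Rotary.forward: "yarn" compresses positions by the full stretch,
    # "drope" splits the stretch between positions and frequencies, and the
    # default NTK branch raises the rope base instead.
    stretch = max(seq_len / train_seq_len, 1.0)
    pos_div, freq_div, new_base = 1.0, 1.0, base
    if seq_len > train_seq_len:
        if scaling == "yarn":
            pos_div = stretch * scale
        elif scaling == "drope":
            pos_div = (stretch ** 0.5) * scale
            freq_div = stretch ** 0.25
        else:  # "ntk"
            new_base = base * (stretch * scale) ** (dim / (dim - 2))
    return pos_div, freq_div, new_base

# Doubling the context: positions halve under "yarn"; the base grows under NTK.
yarn = rope_extension(2048, 1024, "yarn", dim=64)
ntk = rope_extension(2048, 1024, "ntk", dim=64)
```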
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len)
+        self.use_delta = hybrid_delta
+        self.attn = DeltaMixer(dim) if hybrid_delta else CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, rope_scaling, rope_scale)
         self.attn_twice_alpha = attn_twice_alpha
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -728,9 +716,12 @@ def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Te
         mix = self.resid_mix.to(dtype=x.dtype)
         x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
         n = self.attn_norm(x)
-        qd = q_delta_fn(n) if q_delta_fn is not None else None
-        vd = v_delta_fn(n) if v_delta_fn is not None else None
-        attn_out = self.attn(n, qd, vd)
+        if self.use_delta:
+            attn_out = self.attn(n)
+        else:
+            qd = q_delta_fn(n) if q_delta_fn is not None else None
+            vd = v_delta_fn(n) if v_delta_fn is not None else None
+            attn_out = self.attn(n, qd, vd)
         if self.attn_twice_alpha != 0.0:
             attn_out = attn_out + self.attn_twice_alpha * (n - attn_out)
         x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
@@ -755,8 +746,14 @@ def __init__(
         train_seq_len: int,
         z_loss_coef: float,
         attn_twice_alpha: float,
+        attn_twice_alpha_slope: float,
         overtone_init_power: float,
         resid_mix_init_scale: float,
+        rope_scaling: str,
+        rope_scale: float,
+        mtp_depth: int,
+        mtp_loss_weight: float,
+        hybrid_delta_every: int,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
         self.tie_embeddings = tie_embeddings
         self.tied_embed_init_std = tied_embed_init_std
         self.z_loss_coef = z_loss_coef
         self.overtone_init_power = overtone_init_power
         self.resid_mix_init_scale = resid_mix_init_scale
+        self.mtp_loss_weight = mtp_loss_weight
         self.logit_softcap = logit_softcap
         self.tok_emb = nn.Embedding(vocab_size, model_dim)
         self.num_encoder_layers = num_layers // 2
         self.num_decoder_layers = num_layers - self.num_encoder_layers
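The `DeltaMixer` that `hybrid_delta` blocks swap in for attention is a gated leaky accumulator over time. Its recurrence, reduced to scalars (the real module wraps this in linear projections):

```python
def delta_mix(u, g):
    # state = g*state + (1-g)*u at each step, i.e. an input-gated EMA;
    # the module then projects the stacked states back to model width.
    state, outs = 0.0, []
    for ut, gt in zip(u, g):
        state = gt * state + (1 - gt) * ut
        outs.append(state)
    return outs

# A constant input with gate 0.5 converges geometrically toward the input.
outs = delta_mix([1.0, 1.0, 1.0], [0.5, 0.5, 0.5])
```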
self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_depth)]) self.blocks = nn.ModuleList( [ Block( @@ -782,7 +781,10 @@ def __init__( rope_base, qk_gain_init, train_seq_len, - attn_twice_alpha, + attn_twice_alpha * (1 + attn_twice_alpha_slope * (2 * i / max(num_layers - 1, 1) - 1)), + rope_scaling, + rope_scale, + hybrid_delta_every > 0 and (i + 1) % hybrid_delta_every == 0, ) for i in range(num_layers) ] @@ -809,13 +811,17 @@ def _init_weights(self) -> None: block.resid_mix.data[0].fill_(float(phase)) block.resid_mix.data[1].fill_(float(1.0 - phase)) - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + def set_rope_scaling(self, scaling: str, scale: float) -> None: + for block in self.blocks: + if not block.use_delta: + block.attn.rotary.set_scaling(scaling, scale) + + def _forward_hidden(self, input_ids: Tensor, lora=None) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x skips: list[Tensor] = [] - # First half stores skips; second half reuses them in reverse order. 
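The encoder/decoder split above pairs layers U-Net style: the first half pushes block outputs onto `skips`, and the second half pops them in reverse order. A tiny index-only sketch (hypothetical `num_layers = 10`) of which decoder block consumes which encoder output:

```python
# Sketch of the skip wiring: encoder blocks push, decoder blocks pop (LIFO),
# so the deepest encoder output feeds the first decoder block.
num_layers = 10
num_enc = num_layers // 2
skips = []
pairs = []
for i in range(num_enc):
    skips.append(i)  # encoder block i stores its output
for i in range(num_layers - num_enc):
    pairs.append((num_enc + i, skips.pop()))  # decoder block reuses latest skip

assert pairs == [(5, 4), (6, 3), (7, 2), (8, 1), (9, 0)]
```

This is why `num_skip_weights` is the min of the two halves: each pop pairs one decoder block with exactly one stored encoder output.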
for i in range(self.num_encoder_layers): qd = lora.q_loras[i] if lora else None vd = lora.v_loras[i] if lora else None @@ -828,39 +834,39 @@ def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: qd = lora.q_loras[bi] if lora else None vd = lora.v_loras[bi] if lora else None x = self.blocks[bi](x, x0, qd, vd) - x = self.final_norm(x) + return self.final_norm(x) + + def _hidden_to_logits(self, x: Tensor, lora=None) -> Tensor: if self.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - return lm_loss(logits, target_ids, self.z_loss_coef, reduction="none" if lora else "mean") + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self._forward_hidden(input_ids, lora=lora) + logits = self._hidden_to_logits(x, lora=lora) + loss = lm_loss(logits, target_ids, self.z_loss_coef, reduction="none" if lora else "mean") + if lora or not self.training or not self.mtp_heads or self.mtp_loss_weight <= 0: + return loss + aux = 0.0 + for depth, head in enumerate(self.mtp_heads, start=1): + if target_ids.size(1) <= depth: + break + aux = aux + lm_loss( + self.logit_softcap * torch.tanh(head(x[:, :-depth]) / self.logit_softcap), + target_ids[:, depth:], + 0.0, + reduction="mean", + ) + return loss + self.mtp_loss_weight * aux / max(len(self.mtp_heads), 1) def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[bi](x, x0) - x = 
self.final_norm(x) - logits = F.linear(x, self.tok_emb.weight.to(x.dtype)) if self.tie_embeddings else self.lm_head(x) - return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + x = self._forward_hidden(input_ids) + return self._hidden_to_logits(x) -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. BOS_ID = 1 @@ -883,6 +889,11 @@ def reset(self) -> None: self.A.uniform_(-bound, bound) # kaiming-uniform self.B.zero_() + +class NullLoRA(nn.Module): + def forward(self, x: Tensor) -> int: + return 0 + class BatchedTTTLoRA(nn.Module): """All LoRA adapters for one batch: LM head and Q/V per block.""" def __init__(self, bsz: int, model: GPT, rank: int): @@ -893,8 +904,12 @@ def __init__(self, bsz: int, model: GPT, rank: int): self.q_loras = nn.ModuleList() self.v_loras = nn.ModuleList() for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + if getattr(block, "use_delta", False): + self.q_loras.append(NullLoRA()) + self.v_loras.append(NullLoRA()) + else: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) def reset(self) -> None: for m in self.modules(): @@ -968,12 +983,10 @@ def eval_val_ttt_lora( is_boundary_token_lut: Tensor, ) -> tuple[float, float]: """Evaluate with batched LoRA test-time training. 
Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries files = sorted(glob.glob(args.val_files)) all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) docs = _find_docs(all_tokens) - # Each rank takes a contiguous slice of documents rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] chunk_size = args.ttt_chunk_size eval_seq_len = args.ttt_eval_seq_len @@ -1031,7 +1044,6 @@ def eval_val_ttt_lora( y[b, :wl] = toks[1:] doc_info.append((co, cl)) - # Forward pass (keep grad graph alive only when we need to train) if needs_train: with torch.autocast(device_type="cuda", dtype=torch.bfloat16): ptl = base_model(x, y, lora=cur_lora) @@ -1039,7 +1051,6 @@ def eval_val_ttt_lora( with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): ptl = base_model(x, y, lora=cur_lora) - # Score: accumulate loss and byte counts for BPB (before training on chunk) with torch.no_grad(): for b in range(bsz): if not active[b]: @@ -1049,7 +1060,6 @@ def eval_val_ttt_lora( ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, loss_sum, byte_sum, token_count) - # Train: one Adam step on the LoRA params using this chunk's loss if needs_train: mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) @@ -1066,9 +1076,6 @@ def eval_val_ttt_lora( val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) return val_loss, val_bpb -# ----------------------------- -# TRAINING -# ----------------------------- def main() -> None: global zeropower_via_newtonschulz5 @@ -1077,9 +1084,6 @@ def main() -> None: args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ 
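`eval_val_ttt_lora` shards work with the contiguous-slice idiom `docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size]`. A small standalone sketch (hypothetical counts) showing the slices tile the document list exactly once, even when it does not divide evenly:

```python
def rank_slice(n_docs, rank, world_size):
    # Contiguous per-rank bounds; floor division guarantees the slices
    # partition range(n_docs) with no gaps and no overlap.
    return (n_docs * rank) // world_size, (n_docs * (rank + 1)) // world_size

# 10 docs over 4 ranks -> slice sizes 2/3/2/3, covering every index once.
bounds = [rank_slice(10, r, 4) for r in range(4)]
assert bounds == [(0, 2), (2, 5), (5, 7), (7, 10)]
```

Because each document gets its own adapter, this partition introduces no cross-rank dependency during TTT evaluation.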
rank = int(os.environ.get("RANK", "0")) @@ -1100,7 +1104,6 @@ def main() -> None: dist.barrier() master_process = rank == 0 - # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp @@ -1135,9 +1138,6 @@ def log0(msg: str, console: bool = True) -> None: ) log0("=" * 100, console=False) - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- random.seed(args.seed) np.random.seed(args.seed) @@ -1161,9 +1161,6 @@ def log0(msg: str, console: bool = True) -> None: log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- base_model = GPT( vocab_size=args.vocab_size, @@ -1180,8 +1177,14 @@ def log0(msg: str, console: bool = True) -> None: train_seq_len=args.train_seq_len, z_loss_coef=args.z_loss_coef, attn_twice_alpha=args.attn_twice_alpha, + attn_twice_alpha_slope=args.attn_twice_alpha_slope, overtone_init_power=args.overtone_init_power, resid_mix_init_scale=args.resid_mix_init_scale, + rope_scaling=args.rope_scaling, + rope_scale=args.rope_scale, + mtp_depth=args.mtp_depth, + mtp_loss_weight=args.mtp_loss_weight, + hybrid_delta_every=args.hybrid_delta_every, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): @@ -1192,11 +1195,6 @@ def log0(msg: str, console: bool = True) -> None: compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks 
use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ p @@ -1222,6 +1220,7 @@ def log0(msg: str, console: bool = True) -> None: lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, + update_balance=args.muon_update_balance, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr @@ -1232,6 +1231,15 @@ def log0(msg: str, console: bool = True) -> None: fused=True, ) optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if len(base_model.mtp_heads) > 0: + mtp_lr = args.head_lr if args.head_lr > 0 else args.scalar_lr + optimizer_mtp = torch.optim.Adam( + [{"params": list(base_model.mtp_heads.parameters()), "lr": mtp_lr, "base_lr": mtp_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_mtp) if base_model.lm_head is not None: optimizer_head = torch.optim.Adam( [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], @@ -1251,7 +1259,15 @@ def log0(msg: str, console: bool = True) -> None: f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" ) - log0(f"eval_seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} eval_stride:{args.eval_stride} muon_weight_decay:{args.muon_weight_decay} z_loss_coef:{args.z_loss_coef} attn_twice_alpha:{args.attn_twice_alpha} overtone_init_power:{args.overtone_init_power}") + log0( + f"roundtrip_eval_seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} " + f"roundtrip_eval_stride:{args.eval_stride} ttt_eval_seq_len:{args.ttt_eval_seq_len} " + f"rope_scaling:{args.rope_scaling} roundtrip_rope_scaling:{args.roundtrip_rope_scaling} " + f"ttt_rope_scaling:{args.ttt_rope_scaling} muon_weight_decay:{args.muon_weight_decay} " + 
f"muon_update_balance:{args.muon_update_balance} z_loss_coef:{args.z_loss_coef} " + f"attn_twice_alpha:{args.attn_twice_alpha} attn_twice_alpha_slope:{args.attn_twice_alpha_slope} " + f"mtp_depth:{args.mtp_depth} hybrid_delta_every:{args.hybrid_delta_every} overtone_init_power:{args.overtone_init_power}" + ) log0( f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " @@ -1259,9 +1275,6 @@ def log0(msg: str, console: bool = True) -> None: ) log0(f"seed:{args.seed}") - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) @@ -1282,8 +1295,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] @@ -1310,9 +1321,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- training_time_ms = 0.0 stop_after_step: int | None = None @@ -1400,7 +1408,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) - # Needed to sync whether we've reached the wallclock cap. 
reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) @@ -1414,21 +1421,16 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. if master_process: - torch.save(base_model.state_dict(), "final_model.pt") + torch.save(export_state_dict(base_model), "final_model.pt") model_bytes = os.path.getsize("final_model.pt") code_bytes = len(code.encode("utf-8")) log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_obj, quant_stats = quantize_state_dict_int8(export_state_dict(base_model)) quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() @@ -1451,7 +1453,8 @@ def lr_mul(step: int, elapsed_ms: float) -> float: with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=False) + base_model.set_rope_scaling(args.roundtrip_rope_scaling, args.roundtrip_rope_scale) eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len val_tokens_eval = load_validation_tokens(args.val_files, eval_seq_len) if eval_seq_len != args.train_seq_len else val_tokens torch.cuda.synchronize() @@ -1474,8 
+1477,8 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - # LoRA test-time training evaluation (the competition score) torch._dynamo.reset() + base_model.set_rope_scaling(args.ttt_rope_scaling, args.ttt_rope_scale) torch.cuda.synchronize() t_ttt = time.perf_counter() ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( From e3e02c690d585a6c2489de62dd64fe787ddc50d4 Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Sun, 22 Mar 2026 12:01:42 +0000 Subject: [PATCH 03/17] Restore strict reload and fixed eval rope --- train_gpt.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index fbbd69757..bf5cd354d 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1453,8 +1453,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=False) - base_model.set_rope_scaling(args.roundtrip_rope_scaling, args.roundtrip_rope_scale) + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len val_tokens_eval = load_validation_tokens(args.val_files, eval_seq_len) if eval_seq_len != args.train_seq_len else val_tokens torch.cuda.synchronize() @@ -1478,7 +1477,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") torch._dynamo.reset() - base_model.set_rope_scaling(args.ttt_rope_scaling, args.ttt_rope_scale) torch.cuda.synchronize() t_ttt = time.perf_counter() ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( From 0f1ee35b0ba95b646d2de835adb39d77dfc65ca8 Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Sun, 22 Mar 2026 12:20:02 
+0000 Subject: [PATCH 04/17] Add 5-run moonshot runner and ranking --- EXPERIMENT_PLAN.md | 47 +++++++++++--------- scripts/parse_run.py | 4 ++ scripts/rank_moonshots.py | 73 ++++++++++++++++++++++++++++++++ scripts/run_moonshot5.sh | 21 +++++++++ scripts/run_remote_experiment.sh | 13 ++++++ 5 files changed, 139 insertions(+), 19 deletions(-) create mode 100755 scripts/rank_moonshots.py create mode 100755 scripts/run_moonshot5.sh diff --git a/EXPERIMENT_PLAN.md b/EXPERIMENT_PLAN.md index 00d9ff1c6..bdeff4753 100644 --- a/EXPERIMENT_PLAN.md +++ b/EXPERIMENT_PLAN.md @@ -9,34 +9,43 @@ This plan is optimized for limited budget and the challenge rules. - Stay under the `16,000,000` byte artifact cap - Avoid risky dataset changes until the safe path is exhausted -## Model Ablations First +## 5-Run Moonshot Sequence -Run these in order on remote GPUs: +Run these in order on remote GPUs, using the current branch and `TRAIN_SHARDS=1`: -1. `base10l` -2. `twice_low` -3. `twice_eval2048_ttt1024` -4. `twice_layerwise` -5. `drope_eval` -6. `yarn_eval` -7. `mtp_low` -8. `muon_balance` -9. `hybrid_delta` +1. `drope_eval` +2. `yarn_eval` +3. `mtp_low` +4. `muon_balance` +5. `hybrid_delta` -Use the named launcher: +Run the entire sequence: ```bash -NPROC_PER_NODE=1 bash scripts/run_remote_profile.sh base10l -NPROC_PER_NODE=1 bash scripts/run_remote_profile.sh twice_eval2048_ttt1024 +NPROC_PER_NODE=1 bash scripts/run_moonshot5.sh ``` -Then repeat the best 2-3 profiles on `8x H100 SXM`. +This prints each run tail and then a ranked JSON summary against the control run `twice_eval2048_ttt1024_clean2`. -Interpretation order: +Ranking priority: -- If `drope_eval` or `yarn_eval` beats `twice_eval2048_ttt1024`, keep the better rope scaling and discard the other. -- If `mtp_low` wins, sweep `MTP_DEPTH=3` and `MTP_LOSS_WEIGHT` around `0.05-0.2`. -- If `hybrid_delta` wins even slightly, open a dedicated hybrid branch before changing more optimizer knobs. +1. 
Lowest `final_int8_ttt_lora val_bpb` +2. Lowest `final_int8_zlib_roundtrip_exact val_bpb` +3. Smallest artifact +4. Fastest step time + +Promotion rules: + +- Promote any run that beats the control on at least one final metric without exceeding the artifact cap. +- Promote `hybrid_delta` if it beats the control on either final metric, even slightly. + +Next-step rules: + +- If `drope_eval` beats `yarn_eval`, keep DRoPE and drop YaRN. +- If `yarn_eval` beats `drope_eval`, keep YaRN and drop DRoPE. +- If `mtp_low` wins, sweep `MTP_DEPTH=3` and `MTP_LOSS_WEIGHT` in `0.05`, `0.1`, `0.2`. +- If `muon_balance` wins, sweep `MUON_UPDATE_BALANCE` in `0.25`, `0.5`, `0.75`. +- If `hybrid_delta` wins even slightly, open a dedicated hybrid branch next. ## Dataset And Tokenizer Work diff --git a/scripts/parse_run.py b/scripts/parse_run.py index b1167f042..daf9c0ede 100755 --- a/scripts/parse_run.py +++ b/scripts/parse_run.py @@ -12,6 +12,7 @@ "ttt": re.compile(r"final_int8_ttt_lora val_loss:(\S+) val_bpb:(\S+)"), "artifact": re.compile(r"Total submission size int8\+zlib: (\d+) bytes"), "step_avg": re.compile(r"step:(\d+)/\d+ val_loss:(\S+) val_bpb:(\S+) train_time:(\d+)ms step_avg:(\S+)ms"), + "peak_mem": re.compile(r"peak memory allocated: (\d+) MiB reserved: (\d+) MiB"), } @@ -26,6 +27,9 @@ def parse_log(path: Path) -> dict[str, object]: out["ttt_val_bpb"] = float(m.group(2)) if m := PATTERNS["artifact"].search(text): out["artifact_bytes"] = int(m.group(1)) + if m := PATTERNS["peak_mem"].search(text): + out["peak_alloc_mib"] = int(m.group(1)) + out["peak_reserved_mib"] = int(m.group(2)) step_matches = PATTERNS["step_avg"].findall(text) if step_matches: step, val_loss, val_bpb, train_time_ms, step_avg = step_matches[-1] diff --git a/scripts/rank_moonshots.py b/scripts/rank_moonshots.py new file mode 100755 index 000000000..25687657d --- /dev/null +++ b/scripts/rank_moonshots.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import 
math +import subprocess +import sys +from pathlib import Path + + +CONTROL_RUN = "twice_eval2048_ttt1024_clean2" +PROFILES = ["drope_eval", "yarn_eval", "mtp_low", "muon_balance", "hybrid_delta"] +ARTIFACT_CAP = 16_000_000 + + +def parse_logs(paths: list[str]) -> list[dict[str, object]]: + out = subprocess.check_output([sys.executable, "scripts/parse_run.py", *paths], text=True) + return json.loads(out) + + +def score_key(row: dict[str, object]) -> tuple[float, float, int, float]: + return ( + float(row.get("ttt_val_bpb", math.inf)), + float(row.get("roundtrip_val_bpb", math.inf)), + int(row.get("artifact_bytes", 10**12)), + float(row.get("step_avg_ms", math.inf)), + ) + + +def run_id_from_log(path: str) -> str: + return Path(path).stem + + +def promote(row: dict[str, object], control: dict[str, object] | None) -> tuple[bool, str]: + artifact = int(row.get("artifact_bytes", ARTIFACT_CAP + 1)) + if artifact > ARTIFACT_CAP: + return False, "artifact cap exceeded" + if control is None: + return True, "no control available" + row_ttt = float(row.get("ttt_val_bpb", math.inf)) + row_roundtrip = float(row.get("roundtrip_val_bpb", math.inf)) + ctl_ttt = float(control.get("ttt_val_bpb", math.inf)) + ctl_roundtrip = float(control.get("roundtrip_val_bpb", math.inf)) + if run_id_from_log(str(row["log"])) == "hybrid_delta" and (row_ttt < ctl_ttt or row_roundtrip < ctl_roundtrip): + return True, "hybrid delta beat control on at least one final metric" + if row_ttt < ctl_ttt or row_roundtrip < ctl_roundtrip: + return True, "beat control on at least one final metric" + return False, "did not beat control" + + +def main() -> int: + if len(sys.argv) > 1: + paths = sys.argv[1:] + else: + paths = [f"logs/{name}.txt" for name in [CONTROL_RUN, *PROFILES] if Path(f"logs/{name}.txt").exists()] + rows = parse_logs(paths) + for row in rows: + row["run_id"] = run_id_from_log(str(row["log"])) + control = next((row for row in rows if row["run_id"] == CONTROL_RUN), None) + ranked = sorted((row 
for row in rows if row["run_id"] != CONTROL_RUN), key=score_key) + result = [] + for row in ranked: + ok, reason = promote(row, control) + row = dict(row) + row["promote"] = ok + row["promotion_reason"] = reason + result.append(row) + print(json.dumps({"control": control, "ranked": result}, indent=2, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/run_moonshot5.sh b/scripts/run_moonshot5.sh new file mode 100755 index 000000000..103ec90a9 --- /dev/null +++ b/scripts/run_moonshot5.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +RUNS=(drope_eval yarn_eval mtp_low muon_balance hybrid_delta) +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + +for run in "${RUNS[@]}"; do + echo + echo "=== Running ${run} ===" + NPROC_PER_NODE="$NPROC_PER_NODE" bash scripts/run_remote_profile.sh "$run" + echo + echo "--- tail ${run} ---" + tail -n 10 "logs/${run}.txt" +done + +echo +echo "=== Ranked moonshot summary ===" +python3 scripts/rank_moonshots.py diff --git a/scripts/run_remote_experiment.sh b/scripts/run_remote_experiment.sh index a00ffd1e7..e4972bedd 100755 --- a/scripts/run_remote_experiment.sh +++ b/scripts/run_remote_experiment.sh @@ -15,6 +15,8 @@ export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-1024}" export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-1024}" export EVAL_STRIDE="${EVAL_STRIDE:-64}" +export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-$EVAL_SEQ_LEN}" +export ROUNDTRIP_EVAL_STRIDE="${ROUNDTRIP_EVAL_STRIDE:-$EVAL_STRIDE}" export NUM_LAYERS="${NUM_LAYERS:-10}" export MODEL_DIM="${MODEL_DIM:-512}" export NUM_HEADS="${NUM_HEADS:-8}" @@ -28,15 +30,26 @@ export WARMDOWN_ITERS="${WARMDOWN_ITERS:-2500}" export MUON_MOMENTUM="${MUON_MOMENTUM:-0.95}" export MUON_BACKEND_STEPS="${MUON_BACKEND_STEPS:-5}" export MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.02}" +export 
MUON_UPDATE_BALANCE="${MUON_UPDATE_BALANCE:-0.0}" export Z_LOSS_COEF="${Z_LOSS_COEF:-0.0}" export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.0}" +export ATTN_TWICE_ALPHA_SLOPE="${ATTN_TWICE_ALPHA_SLOPE:-0.0}" export OVERTONE_INIT_POWER="${OVERTONE_INIT_POWER:-0.5}" export RESID_MIX_INIT_SCALE="${RESID_MIX_INIT_SCALE:-3.0}" +export ROPE_SCALING="${ROPE_SCALING:-ntk}" +export ROPE_SCALE="${ROPE_SCALE:-1.0}" +export ROUNDTRIP_ROPE_SCALING="${ROUNDTRIP_ROPE_SCALING:-$ROPE_SCALING}" +export ROUNDTRIP_ROPE_SCALE="${ROUNDTRIP_ROPE_SCALE:-$ROPE_SCALE}" export GRAD_CLIP_NORM="${GRAD_CLIP_NORM:-0.0}" +export MTP_DEPTH="${MTP_DEPTH:-0}" +export MTP_LOSS_WEIGHT="${MTP_LOSS_WEIGHT:-0.25}" +export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-0}" export TTT_LORA_RANK="${TTT_LORA_RANK:-8}" export TTT_LORA_LR="${TTT_LORA_LR:-0.01}" export TTT_CHUNK_SIZE="${TTT_CHUNK_SIZE:-256}" export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-$EVAL_SEQ_LEN}" +export TTT_ROPE_SCALING="${TTT_ROPE_SCALING:-$ROUNDTRIP_ROPE_SCALING}" +export TTT_ROPE_SCALE="${TTT_ROPE_SCALE:-$ROUNDTRIP_ROPE_SCALE}" export TTT_BATCH_SIZE="${TTT_BATCH_SIZE:-64}" NPROC_PER_NODE="${NPROC_PER_NODE:-1}" From 4953987e7c7a574c910f30c0e512b2fdafedebdc Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Sun, 22 Mar 2026 13:47:57 +0000 Subject: [PATCH 05/17] Fix MTP compile loop on Runpod --- train_gpt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index bf5cd354d..3e19ce4f4 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -851,7 +851,8 @@ def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: if lora or not self.training or not self.mtp_heads or self.mtp_loss_weight <= 0: return loss aux = 0.0 - for depth, head in enumerate(self.mtp_heads, start=1): + depth = 1 + for head in self.mtp_heads: if target_ids.size(1) <= depth: break aux = aux + lm_loss( @@ -860,6 +861,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: 0.0, 
reduction="mean", ) + depth += 1 return loss + self.mtp_loss_weight * aux / max(len(self.mtp_heads), 1) def forward_logits(self, input_ids: Tensor) -> Tensor: From 981cd8e39a0fe33cb304626ae685565775342ab0 Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Sun, 22 Mar 2026 13:51:14 +0000 Subject: [PATCH 06/17] Add fast smoke runner for moonshots --- scripts/run_moonshot5_smoke.sh | 17 +++++++++++++++++ scripts/run_smoke_profile.sh | 17 +++++++++++++++++ train_gpt.py | 16 ++++++---------- 3 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 scripts/run_moonshot5_smoke.sh create mode 100644 scripts/run_smoke_profile.sh diff --git a/scripts/run_moonshot5_smoke.sh b/scripts/run_moonshot5_smoke.sh new file mode 100644 index 000000000..ac3055707 --- /dev/null +++ b/scripts/run_moonshot5_smoke.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +RUNS=(drope_eval yarn_eval mtp_low muon_balance hybrid_delta) +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + +for run in "${RUNS[@]}"; do + echo + echo "=== Smoke ${run} ===" + NPROC_PER_NODE="$NPROC_PER_NODE" bash scripts/run_smoke_profile.sh "$run" + echo + echo "--- tail ${run}_smoke ---" + tail -n 10 "logs/${run}_smoke.txt" +done diff --git a/scripts/run_smoke_profile.sh b/scripts/run_smoke_profile.sh new file mode 100644 index 000000000..af414fd74 --- /dev/null +++ b/scripts/run_smoke_profile.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)"
+cd "$ROOT_DIR"
+
+PROFILE="${1:-base10l}"
+shift || true
+
+export RUN_ID="${RUN_ID:-${PROFILE}_smoke}"
+export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-90}"
+export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-50}"
+export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-200}"
+export WARMUP_STEPS="${WARMUP_STEPS:-5}"
+export SKIP_FINAL_EVAL="${SKIP_FINAL_EVAL:-1}"
+
+bash scripts/run_remote_profile.sh "$PROFILE" "$@"
diff --git a/train_gpt.py b/train_gpt.py
index 3e19ce4f4..8343189aa 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -27,7 +27,6 @@
 from torch import Tensor, nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-
 class Hyperparameters:
     data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
     train_files = os.path.join(data_path, "fineweb_train_*.bin")
@@ -48,6 +47,7 @@ class Hyperparameters:
     eval_seq_len = int(os.environ.get("ROUNDTRIP_EVAL_SEQ_LEN", os.environ.get("EVAL_SEQ_LEN", "0")))
     eval_stride = int(os.environ.get("ROUNDTRIP_EVAL_STRIDE", os.environ.get("EVAL_STRIDE", "0")))
     max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    skip_final_eval = bool(int(os.environ.get("SKIP_FINAL_EVAL", "0")))
 
     qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
     vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
@@ -97,7 +97,6 @@ class Hyperparameters:
     ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", "1024")))
     ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64))
 
-
 def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.bfloat16()
@@ -174,8 +173,6 @@ def step(self, closure=None):
         return loss
 
-
-
 def build_sentencepiece_luts(
     sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
 ) -> tuple[Tensor, Tensor, Tensor]:
@@ -454,8 +451,6 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
         out[name] = out_t
     return out
 
-
-
 def load_data_shard(file: Path) -> Tensor:
     header_bytes = 256 * np.dtype(" Tensor:
         x = torch.relu(self.fc(x))
         return self.proj(x.square())
 
-
-
 class DeltaMixer(nn.Module):
     def __init__(self, dim: int):
         super().__init__()
@@ -684,8 +677,6 @@ def forward(self, x: Tensor) -> Tensor:
             state = g[:, t] * state + (1 - g[:, t]) * u[:, t]
             outs.append(state)
         return self.proj(torch.stack(outs, dim=1))
 
-
-
 class Block(nn.Module):
     def __init__(
         self,
@@ -1456,6 +1447,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
             quant_blob_disk = f.read()
         quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
         base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    if args.skip_final_eval:
+        log0("smoke_test: quantized reload ok; skipping final roundtrip and TTT eval")
+        if distributed:
+            dist.destroy_process_group()
+        return
     eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len
     val_tokens_eval = load_validation_tokens(args.val_files, eval_seq_len) if eval_seq_len != args.train_seq_len else val_tokens
     torch.cuda.synchronize()

From 0166e9ab1e3acf35f1ad55b5605fd3899396c33a Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Sun, 22 Mar 2026 14:11:27 +0000
Subject: [PATCH 07/17] Allow export reload without MTP heads

---
 train_gpt.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/train_gpt.py b/train_gpt.py
index 8343189aa..2e548a2a4 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -535,10 +535,14 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
         if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
             param.data = param.data.float()
 
-
 def export_state_dict(module: nn.Module) -> dict[str, Tensor]:
     return {name: tensor for name, tensor in module.state_dict().items() if not name.startswith("mtp_heads.")}
 
+def load_exported_state_dict(module: nn.Module, state_dict: dict[str, Tensor]) -> None:
+    missing, unexpected = module.load_state_dict(state_dict, strict=False)
+    bad_missing = [name for name in missing if not name.startswith("mtp_heads.")]
+    if bad_missing or unexpected:
+        raise RuntimeError(f"Export reload mismatch missing={bad_missing} unexpected={list(unexpected)}")
 
 class Rotary(nn.Module):
     def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, scaling: str = "ntk", scale: float = 1.0):
@@ -1446,7 +1450,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
         with open("final_model.int8.ptz", "rb") as f:
             quant_blob_disk = f.read()
         quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
-        base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+        load_exported_state_dict(base_model, dequantize_state_dict_int8(quant_state))
     if args.skip_final_eval:
         log0("smoke_test: quantized reload ok; skipping final roundtrip and TTT eval")
         if distributed:

From d9eb32443c53cc7279c7c47eb9c2c3bb6cb45dbb Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Sun, 22 Mar 2026 14:39:45 +0000
Subject: [PATCH 08/17] Skip torch.compile for hybrid delta runs

---
 train_gpt.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train_gpt.py b/train_gpt.py
index 2e548a2a4..570f8747d 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -1189,7 +1189,7 @@ def log0(msg: str, console: bool = True) -> None:
         if isinstance(module, Rotary):
             module.inv_freq.data = module.inv_freq.data.float()
     restore_low_dim_params_to_fp32(base_model)
-    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.hybrid_delta_every <= 0 else base_model
     model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
 
     block_named_params = list(base_model.blocks.named_parameters())
@@ -1462,7 +1462,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     t_qeval = time.perf_counter()
     if args.eval_stride > 0:
         eval_batch_seqs = min(256, max(1, args.val_batch_size // eval_seq_len // max(world_size, 1)))
-        compiled_logits = torch.compile(base_model.forward_logits, dynamic=False)
+        compiled_logits = torch.compile(base_model.forward_logits, dynamic=False) if args.hybrid_delta_every <= 0 else base_model.forward_logits
         warmup_x = torch.zeros(eval_batch_seqs, eval_seq_len, dtype=torch.int64, device=device)
         base_model.eval()
         with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):

From 0bb6be5e3804a6f31c0a485c2de76c14531a194a Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Sun, 22 Mar 2026 14:52:16 +0000
Subject: [PATCH 09/17] Make moonshot ranking robust to partial logs

---
 scripts/parse_run.py      | 37 ++++++++++++++++++++++++++++---------
 scripts/rank_moonshots.py |  5 ++++-
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/scripts/parse_run.py b/scripts/parse_run.py
index daf9c0ede..c0dc05f5d 100755
--- a/scripts/parse_run.py
+++ b/scripts/parse_run.py
@@ -16,15 +16,30 @@
 }
 
 
+def maybe_float(value: str) -> float | None:
+    try:
+        return float(value)
+    except ValueError:
+        return None
+
+
 def parse_log(path: Path) -> dict[str, object]:
+    if not path.exists():
+        return {"log": str(path), "missing": True}
     text = path.read_text(encoding="utf-8", errors="replace")
     out: dict[str, object] = {"log": str(path)}
     if m := PATTERNS["roundtrip"].search(text):
-        out["roundtrip_val_loss"] = float(m.group(1))
-        out["roundtrip_val_bpb"] = float(m.group(2))
+        val_loss = maybe_float(m.group(1))
+        val_bpb = maybe_float(m.group(2))
+        if val_loss is not None and val_bpb is not None:
+            out["roundtrip_val_loss"] = val_loss
+            out["roundtrip_val_bpb"] = val_bpb
     if m := PATTERNS["ttt"].search(text):
-        out["ttt_val_loss"] = float(m.group(1))
-        out["ttt_val_bpb"] = float(m.group(2))
+        val_loss = maybe_float(m.group(1))
+        val_bpb = maybe_float(m.group(2))
+        if val_loss is not None and val_bpb is not None:
+            out["ttt_val_loss"] = val_loss
+            out["ttt_val_bpb"] = val_bpb
     if m := PATTERNS["artifact"].search(text):
         out["artifact_bytes"] = int(m.group(1))
     if m := PATTERNS["peak_mem"].search(text):
@@ -33,11 +48,15 @@ def parse_log(path: Path) -> dict[str, object]:
     step_matches = PATTERNS["step_avg"].findall(text)
     if step_matches:
         step, val_loss, val_bpb, train_time_ms, step_avg = step_matches[-1]
-        out["last_val_step"] = int(step)
-        out["last_val_loss"] = float(val_loss)
-        out["last_val_bpb"] = float(val_bpb)
-        out["train_time_ms"] = int(train_time_ms)
-        out["step_avg_ms"] = float(step_avg)
+        parsed_val_loss = maybe_float(val_loss)
+        parsed_val_bpb = maybe_float(val_bpb)
+        parsed_step_avg = maybe_float(step_avg)
+        if parsed_val_loss is not None and parsed_val_bpb is not None and parsed_step_avg is not None:
+            out["last_val_step"] = int(step)
+            out["last_val_loss"] = parsed_val_loss
+            out["last_val_bpb"] = parsed_val_bpb
+            out["train_time_ms"] = int(train_time_ms)
+            out["step_avg_ms"] = parsed_step_avg
     return out
 
diff --git a/scripts/rank_moonshots.py b/scripts/rank_moonshots.py
index 25687657d..24622bbb9 100755
--- a/scripts/rank_moonshots.py
+++ b/scripts/rank_moonshots.py
@@ -14,7 +14,10 @@
 
 def parse_logs(paths: list[str]) -> list[dict[str, object]]:
-    out = subprocess.check_output([sys.executable, "scripts/parse_run.py", *paths], text=True)
+    existing = [path for path in paths if Path(path).exists()]
+    if not existing:
+        return []
+    out = subprocess.check_output([sys.executable, "scripts/parse_run.py", *existing], text=True)
     return json.loads(out)

From b3a7429792759e35a5a16e58fbcbde6f954f Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Sun, 22 Mar 2026 16:02:36 +0000
Subject: [PATCH 10/17] Add shared-depth recursive moonshot

---
 EXPERIMENT_PLAN.md               | 12 ++++++++
 scripts/run_remote_experiment.sh |  2 ++
 scripts/run_remote_profile.sh    |  8 ++++-
 train_gpt.py                     | 51 ++++++++++++++++----------
 4 files changed, 46 insertions(+), 27 deletions(-)

diff --git a/EXPERIMENT_PLAN.md b/EXPERIMENT_PLAN.md
index bdeff4753..b88b655fd 100644
--- a/EXPERIMENT_PLAN.md
+++ b/EXPERIMENT_PLAN.md
@@ -47,6 +47,18 @@ Next-step rules:
 - If `muon_balance` wins, sweep `MUON_UPDATE_BALANCE` in `0.25`, `0.5`, `0.75`.
 - If `hybrid_delta` wins even slightly, open a dedicated hybrid branch next.
 
+## Next Moonshot
+
+New architecture branch:
+
+1. `shared_depth`
+
+Idea:
+
+- reuse `4` unique blocks across `10` logical layers
+- keep tiny per-pass learned output scales so reused blocks can still specialize
+- preserve the existing optimizer, export, and TTT paths
+
 ## Dataset And Tokenizer Work
 
 The challenge allows tokenizer or dataset changes, but the repo says they will be examined carefully and you must prove the `val_bpb` calculation is correct. See [README.md](/Users/deividasmataciunas/Desktop/research/openai_golf/README.md#L168).
diff --git a/scripts/run_remote_experiment.sh b/scripts/run_remote_experiment.sh
index e4972bedd..162a891b0 100755
--- a/scripts/run_remote_experiment.sh
+++ b/scripts/run_remote_experiment.sh
@@ -44,6 +44,8 @@ export GRAD_CLIP_NORM="${GRAD_CLIP_NORM:-0.0}"
 export MTP_DEPTH="${MTP_DEPTH:-0}"
 export MTP_LOSS_WEIGHT="${MTP_LOSS_WEIGHT:-0.25}"
 export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-0}"
+export SHARED_DEPTH_N="${SHARED_DEPTH_N:-0}"
+export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.0}"
 export TTT_LORA_RANK="${TTT_LORA_RANK:-8}"
 export TTT_LORA_LR="${TTT_LORA_LR:-0.01}"
 export TTT_CHUNK_SIZE="${TTT_CHUNK_SIZE:-256}"
diff --git a/scripts/run_remote_profile.sh b/scripts/run_remote_profile.sh
index 0fde57cf4..9e65fac13 100755
--- a/scripts/run_remote_profile.sh
+++ b/scripts/run_remote_profile.sh
@@ -74,9 +74,15 @@ case "$PROFILE" in
         export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
         export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-4}"
         ;;
+    shared_depth)
+        export RUN_ID="${RUN_ID:-shared_depth}"
+        export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+        export SHARED_DEPTH_N="${SHARED_DEPTH_N:-4}"
+        export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.15}"
+        ;;
     *)
         echo "Unknown profile: $PROFILE" >&2
-        echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta" >&2
+        echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth" >&2
         exit 1
         ;;
 esac
diff --git a/train_gpt.py b/train_gpt.py
index 570f8747d..1e914e660 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -90,6 +90,8 @@ class Hyperparameters:
     mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.25))
     muon_update_balance = float(os.environ.get("MUON_UPDATE_BALANCE", 0.0))
     hybrid_delta_every = int(os.environ.get("HYBRID_DELTA_EVERY", 0))
+    shared_depth_n = int(os.environ.get("SHARED_DEPTH_N", 0))
+    shared_depth_gain = float(os.environ.get("SHARED_DEPTH_GAIN", 0.0))
 
     ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8))
     ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01))
@@ -749,6 +751,8 @@ def __init__(
         mtp_depth: int,
         mtp_loss_weight: float,
         hybrid_delta_every: int,
+        shared_depth_n: int,
+        shared_depth_gain: float,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
@@ -761,10 +765,14 @@ def __init__(
         self.mtp_loss_weight = mtp_loss_weight
         self.logit_softcap = logit_softcap
         self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        self.logical_num_layers = num_layers
+        shared_blocks = min(shared_depth_n, num_layers) if shared_depth_n > 0 else num_layers
         self.num_encoder_layers = num_layers // 2
         self.num_decoder_layers = num_layers - self.num_encoder_layers
         self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
         self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+        self.block_map = [i % shared_blocks for i in range(num_layers)]
+        self.pass_scales = nn.Parameter((1.0 + shared_depth_gain * torch.linspace(-1, 1, num_layers)[:, None]).expand(-1, model_dim).contiguous()) if shared_depth_n > 0 else None
         self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_depth)])
         self.blocks = nn.ModuleList(
             [
@@ -781,7 +789,7 @@ def __init__(
                     rope_scale,
                     hybrid_delta_every > 0 and (i + 1) % hybrid_delta_every == 0,
                 )
-                for i in range(num_layers)
+                for i in range(max(self.block_map) + 1)
             ]
         )
         self.final_norm = RMSNorm()
@@ -805,7 +813,6 @@ def _init_weights(self) -> None:
             phase = torch.sigmoid(torch.tensor(self.resid_mix_init_scale * (i / max(len(self.blocks) - 1, 1) - 0.5)))
             block.resid_mix.data[0].fill_(float(phase))
             block.resid_mix.data[1].fill_(float(1.0 - phase))
-
     def set_rope_scaling(self, scaling: str, scale: float) -> None:
         for block in self.blocks:
             if not block.use_delta:
@@ -818,17 +825,23 @@ def _forward_hidden(self, input_ids: Tensor, lora=None) -> Tensor:
         skips: list[Tensor] = []
 
         for i in range(self.num_encoder_layers):
+            block = self.blocks[self.block_map[i]]
             qd = lora.q_loras[i] if lora else None
             vd = lora.v_loras[i] if lora else None
-            x = self.blocks[i](x, x0, qd, vd)
+            x = block(x, x0, qd, vd)
+            if self.pass_scales is not None:
+                x = self.pass_scales[i].to(dtype=x.dtype)[None, None, :] * x
             skips.append(x)
         for i in range(self.num_decoder_layers):
             bi = self.num_encoder_layers + i
             if skips:
                 x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            block = self.blocks[self.block_map[bi]]
             qd = lora.q_loras[bi] if lora else None
             vd = lora.v_loras[bi] if lora else None
-            x = self.blocks[bi](x, x0, qd, vd)
+            x = block(x, x0, qd, vd)
+            if self.pass_scales is not None:
+                x = self.pass_scales[bi].to(dtype=x.dtype)[None, None, :] * x
         return self.final_norm(x)
 
     def _hidden_to_logits(self, x: Tensor, lora=None) -> Tensor:
@@ -862,9 +875,6 @@ def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor:
     def forward_logits(self, input_ids: Tensor) -> Tensor:
         x = self._forward_hidden(input_ids)
         return self._hidden_to_logits(x)
-
-
-
 BOS_ID = 1
 
 class BatchedLinearLoRA(nn.Module):
@@ -890,7 +900,6 @@ def reset(self) -> None:
 
 class NullLoRA(nn.Module):
     def forward(self, x: Tensor) -> int:
         return 0
-
 class BatchedTTTLoRA(nn.Module):
     """All LoRA adapters for one batch: LM head and Q/V per block."""
     def __init__(self, bsz: int, model: GPT, rank: int):
@@ -900,7 +909,8 @@ def __init__(self, bsz: int, model: GPT, rank: int):
         self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank)
         self.q_loras = nn.ModuleList()
         self.v_loras = nn.ModuleList()
-        for block in model.blocks:
+        for i in range(model.logical_num_layers):
+            block = model.blocks[model.block_map[i]]
             if getattr(block, "use_delta", False):
                 self.q_loras.append(NullLoRA())
                 self.v_loras.append(NullLoRA())
@@ -912,7 +922,6 @@ def reset(self) -> None:
         for m in self.modules():
             if isinstance(m, BatchedLinearLoRA):
                 m.reset()
-
def _reset_ttt_optimizer(opt):
    for group in opt.param_groups:
        for p in group['params']:
@@ -922,10 +931,8 @@ def _reset_ttt_optimizer(opt):
             s['exp_avg'].zero_()
             s['exp_avg_sq'].zero_()
             s['step'].fill_(0)
-
 def _build_ttt_optimizer(lora, args: Hyperparameters):
     return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10)
-
 def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]:
     """Return (start_offset, length) for each document, identified by BOS boundaries.
@@ -942,7 +949,6 @@ def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[
         assert end - start >= 2
         docs.append((start, end - start))
     return docs
-
 def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int):
     """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc."""
     chunk_start = ci * chunk_size
@@ -952,7 +958,6 @@ def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: i
     chunk_offset = chunk_start - win_start
     chunk_len = chunk_end - chunk_start
     return win_start, win_len, chunk_offset, chunk_len
-
 def _accumulate_bpb(
     ptl: Tensor, x: Tensor, y: Tensor, batch_i: int, chunk_offset: int, chunk_len: int,
@@ -968,7 +973,6 @@ def _accumulate_bpb(
     loss_sum += lbl.sum()
     byte_sum += tok_bytes.sum()
     token_count += chunk_len
-
 def eval_val_ttt_lora(
     args: Hyperparameters,
     base_model: GPT,
@@ -1072,16 +1076,12 @@ def eval_val_ttt_lora(
     val_loss = float(loss_sum.item() / token_count.item())
     val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item())
     return val_loss, val_bpb
-
-
 def main() -> None:
     global zeropower_via_newtonschulz5
     code = Path(__file__).read_text(encoding="utf-8")
     args = Hyperparameters()
     zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
-
-
     distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
     rank = int(os.environ.get("RANK", "0"))
     world_size = int(os.environ.get("WORLD_SIZE", "1"))
@@ -1124,7 +1124,6 @@ def log0(msg: str, console: bool = True) -> None:
         if logfile is not None:
             with open(logfile, "a", encoding="utf-8") as f:
                 print(msg, file=f)
-
     log0(code, console=False)
     log0("=" * 100, console=False)
     log0(f"Running Python {sys.version}", console=False)
@@ -1182,6 +1181,8 @@ def log0(msg: str, console: bool = True) -> None:
         mtp_depth=args.mtp_depth,
         mtp_loss_weight=args.mtp_loss_weight,
         hybrid_delta_every=args.hybrid_delta_every,
+        shared_depth_n=args.shared_depth_n,
+        shared_depth_gain=args.shared_depth_gain,
     ).to(device).bfloat16()
     for module in base_model.modules():
         if isinstance(module, CastedLinear):
@@ -1205,6 +1206,8 @@ def log0(msg: str, console: bool = True) -> None:
     ]
     if base_model.skip_weights.numel() > 0:
         scalar_params.append(base_model.skip_weights)
+    if base_model.pass_scales is not None:
+        scalar_params.append(base_model.pass_scales)
     token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
     optimizer_tok = torch.optim.Adam(
         [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
@@ -1263,7 +1266,8 @@ def log0(msg: str, console: bool = True) -> None:
         f"ttt_rope_scaling:{args.ttt_rope_scaling} muon_weight_decay:{args.muon_weight_decay} "
         f"muon_update_balance:{args.muon_update_balance} z_loss_coef:{args.z_loss_coef} "
         f"attn_twice_alpha:{args.attn_twice_alpha} attn_twice_alpha_slope:{args.attn_twice_alpha_slope} "
-        f"mtp_depth:{args.mtp_depth} hybrid_delta_every:{args.hybrid_delta_every} overtone_init_power:{args.overtone_init_power}"
+        f"mtp_depth:{args.mtp_depth} hybrid_delta_every:{args.hybrid_delta_every} shared_depth_n:{args.shared_depth_n} "
+        f"shared_depth_gain:{args.shared_depth_gain} overtone_init_power:{args.overtone_init_power}"
     )
     log0(
         f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
@@ -1271,8 +1275,6 @@ def log0(msg: str, console: bool = True) -> None:
         f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
     )
     log0(f"seed:{args.seed}")
-
-
     train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
 
     def zero_grad_all() -> None:
@@ -1490,10 +1492,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
             f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} "
             f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms"
         )
-
     if distributed:
         dist.destroy_process_group()
-
-
 if __name__ == "__main__":
     main()

From 307f2d671f850497509a25818e7cc10d0aa871e1 Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Sun, 22 Mar 2026 19:12:18 +0000
Subject: [PATCH 11/17] Trim shared-depth branch under line cap

---
 train_gpt.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/train_gpt.py b/train_gpt.py
index 1e914e660..8ee92e81e 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -797,7 +797,6 @@ def __init__(
         if self.lm_head is not None:
             self.lm_head._zero_init = True
         self._init_weights()
-
     def _init_weights(self) -> None:
         if self.tie_embeddings:
             nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
@@ -823,7 +822,6 @@ def _forward_hidden(self, input_ids: Tensor, lora=None) -> Tensor:
         x = F.rms_norm(x, (x.size(-1),))
         x0 = x
         skips: list[Tensor] = []
-
         for i in range(self.num_encoder_layers):
             block = self.blocks[self.block_map[i]]
             qd = lora.q_loras[i] if lora else None
@@ -1419,8 +1417,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
         f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
         f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
     )
-
-
     if master_process:
         torch.save(export_state_dict(base_model), "final_model.pt")
         model_bytes = os.path.getsize("final_model.pt")

From 38c4dbf3fa8777ff605b7aa58c4cecb0dc14cb18 Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Tue, 24 Mar 2026 15:41:22 +0200
Subject: [PATCH 12/17] Add midshare shared-depth profile

---
 scripts/run_remote_experiment.sh |  1 +
 scripts/run_remote_profile.sh    |  9 ++++++++-
 train_gpt.py                     | 11 ++++++++---
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/scripts/run_remote_experiment.sh b/scripts/run_remote_experiment.sh
index 162a891b0..69b464d87 100755
--- a/scripts/run_remote_experiment.sh
+++ b/scripts/run_remote_experiment.sh
@@ -46,6 +46,7 @@ export MTP_LOSS_WEIGHT="${MTP_LOSS_WEIGHT:-0.25}"
 export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-0}"
 export SHARED_DEPTH_N="${SHARED_DEPTH_N:-0}"
 export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.0}"
+export SHARED_DEPTH_EDGE_UNIQUE="${SHARED_DEPTH_EDGE_UNIQUE:-0}"
 export TTT_LORA_RANK="${TTT_LORA_RANK:-8}"
 export TTT_LORA_LR="${TTT_LORA_LR:-0.01}"
 export TTT_CHUNK_SIZE="${TTT_CHUNK_SIZE:-256}"
diff --git a/scripts/run_remote_profile.sh b/scripts/run_remote_profile.sh
index 9e65fac13..6ffd4666b 100755
--- a/scripts/run_remote_profile.sh
+++ b/scripts/run_remote_profile.sh
@@ -80,9 +80,16 @@ case "$PROFILE" in
         export SHARED_DEPTH_N="${SHARED_DEPTH_N:-4}"
         export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.15}"
         ;;
+    shared_depth_midshare)
+        export RUN_ID="${RUN_ID:-shared_depth_midshare}"
+        export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+        export SHARED_DEPTH_N="${SHARED_DEPTH_N:-4}"
+        export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.05}"
+        export SHARED_DEPTH_EDGE_UNIQUE="${SHARED_DEPTH_EDGE_UNIQUE:-2}"
+        ;;
     *)
         echo "Unknown profile: $PROFILE" >&2
-        echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth" >&2
+        echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth shared_depth_midshare" >&2
         exit 1
         ;;
 esac
diff --git a/train_gpt.py b/train_gpt.py
index 8ee92e81e..516e9f827 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -92,6 +92,7 @@ class Hyperparameters:
     hybrid_delta_every = int(os.environ.get("HYBRID_DELTA_EVERY", 0))
     shared_depth_n = int(os.environ.get("SHARED_DEPTH_N", 0))
     shared_depth_gain = float(os.environ.get("SHARED_DEPTH_GAIN", 0.0))
+    shared_depth_edge_unique = int(os.environ.get("SHARED_DEPTH_EDGE_UNIQUE", 0))
 
     ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8))
     ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01))
@@ -753,6 +754,7 @@ def __init__(
         hybrid_delta_every: int,
         shared_depth_n: int,
         shared_depth_gain: float,
+        shared_depth_edge_unique: int,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
@@ -766,12 +768,13 @@ def __init__(
         self.logit_softcap = logit_softcap
         self.tok_emb = nn.Embedding(vocab_size, model_dim)
         self.logical_num_layers = num_layers
-        shared_blocks = min(shared_depth_n, num_layers) if shared_depth_n > 0 else num_layers
+        edge = min(shared_depth_edge_unique, num_layers // 2) if shared_depth_n > 0 else 0
+        shared_blocks = min(shared_depth_n, max(num_layers - 2 * edge, 1)) if shared_depth_n > 0 else num_layers
         self.num_encoder_layers = num_layers // 2
         self.num_decoder_layers = num_layers - self.num_encoder_layers
         self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
         self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
-        self.block_map = [i % shared_blocks for i in range(num_layers)]
+        self.block_map = list(range(edge)) + [edge + (i % shared_blocks) for i in range(max(num_layers - 2 * edge, 0))] + [edge + shared_blocks + i for i in range(edge)] if shared_depth_n > 0 else list(range(num_layers))
         self.pass_scales = nn.Parameter((1.0 + shared_depth_gain * torch.linspace(-1, 1, num_layers)[:, None]).expand(-1, model_dim).contiguous()) if shared_depth_n > 0 else None
         self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_depth)])
         self.blocks = nn.ModuleList(
@@ -1181,6 +1184,7 @@ def log0(msg: str, console: bool = True) -> None:
         hybrid_delta_every=args.hybrid_delta_every,
         shared_depth_n=args.shared_depth_n,
         shared_depth_gain=args.shared_depth_gain,
+        shared_depth_edge_unique=args.shared_depth_edge_unique,
     ).to(device).bfloat16()
     for module in base_model.modules():
         if isinstance(module, CastedLinear):
@@ -1265,7 +1269,8 @@ def log0(msg: str, console: bool = True) -> None:
         f"muon_update_balance:{args.muon_update_balance} z_loss_coef:{args.z_loss_coef} "
         f"attn_twice_alpha:{args.attn_twice_alpha} attn_twice_alpha_slope:{args.attn_twice_alpha_slope} "
         f"mtp_depth:{args.mtp_depth} hybrid_delta_every:{args.hybrid_delta_every} shared_depth_n:{args.shared_depth_n} "
-        f"shared_depth_gain:{args.shared_depth_gain} overtone_init_power:{args.overtone_init_power}"
+        f"shared_depth_gain:{args.shared_depth_gain} shared_depth_edge_unique:{args.shared_depth_edge_unique} "
+        f"overtone_init_power:{args.overtone_init_power}"
     )
     log0(
         f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "

From 4477c99b339de71f68025be0184c73a0c4c1aaaa Mon Sep 17 00:00:00 2001
From: DeividasAQ22
Date: Fri, 27 Mar 2026 16:31:02 +0200
Subject: [PATCH 13/17] Add copycore_v1 leaderboard-aligned profile

---
 scripts/run_remote_experiment.sh |   5 ++
 scripts/run_remote_profile.sh    |  14 ++++-
 train_gpt.py                     | 101 ++++++++++++++-----------
 3 files changed, 64 insertions(+), 56 deletions(-)

diff --git a/scripts/run_remote_experiment.sh b/scripts/run_remote_experiment.sh
index 69b464d87..74cb4a447 100755
--- a/scripts/run_remote_experiment.sh
+++ b/scripts/run_remote_experiment.sh
@@ -47,6 +47,11 @@ export HYBRID_DELTA_EVERY="${HYBRID_DELTA_EVERY:-0}"
 export SHARED_DEPTH_N="${SHARED_DEPTH_N:-0}"
 export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.0}"
 export SHARED_DEPTH_EDGE_UNIQUE="${SHARED_DEPTH_EDGE_UNIQUE:-0}"
+export MLP_ACT="${MLP_ACT:-relu2}"
+export LEAKY_RELU_SLOPE="${LEAKY_RELU_SLOPE:-0.5}"
+export EMA_DECAY="${EMA_DECAY:-0.0}"
+export SWA_START_FRAC="${SWA_START_FRAC:-0.0}"
+export SWA_STRIDE="${SWA_STRIDE:-50}"
 export TTT_LORA_RANK="${TTT_LORA_RANK:-8}"
 export TTT_LORA_LR="${TTT_LORA_LR:-0.01}"
 export TTT_CHUNK_SIZE="${TTT_CHUNK_SIZE:-256}"
diff --git a/scripts/run_remote_profile.sh b/scripts/run_remote_profile.sh
index 6ffd4666b..a579b5a42 100755
--- a/scripts/run_remote_profile.sh
+++ b/scripts/run_remote_profile.sh
@@ -87,9 +87,21 @@ case "$PROFILE" in
         export SHARED_DEPTH_GAIN="${SHARED_DEPTH_GAIN:-0.05}"
         export SHARED_DEPTH_EDGE_UNIQUE="${SHARED_DEPTH_EDGE_UNIQUE:-2}"
         ;;
+    copycore_v1)
+        export RUN_ID="${RUN_ID:-copycore_v1}"
+        export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}"
+        export MLP_MULT="${MLP_MULT:-3}"
+        export MLP_ACT="${MLP_ACT:-leaky2}"
+        export LEAKY_RELU_SLOPE="${LEAKY_RELU_SLOPE:-0.5}"
+        export MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.04}"
+        export WARMDOWN_ITERS="${WARMDOWN_ITERS:-3500}"
+        export EMA_DECAY="${EMA_DECAY:-0.997}"
+        export SWA_START_FRAC="${SWA_START_FRAC:-0.6}"
+        export SWA_STRIDE="${SWA_STRIDE:-50}"
+        ;;
     *)
         echo "Unknown profile: $PROFILE" >&2
-        echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth shared_depth_midshare" >&2
+        echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth shared_depth_midshare copycore_v1" >&2
         exit 1
         ;;
 esac
diff --git a/train_gpt.py b/train_gpt.py
index 516e9f827..0c99ac4ac 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -93,6 +93,11 @@ class Hyperparameters:
     shared_depth_n = int(os.environ.get("SHARED_DEPTH_N", 0))
     shared_depth_gain = float(os.environ.get("SHARED_DEPTH_GAIN", 0.0))
     shared_depth_edge_unique = int(os.environ.get("SHARED_DEPTH_EDGE_UNIQUE", 0))
+    mlp_act = os.environ.get("MLP_ACT", "relu2").lower()
+    leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", 0.5))
+    ema_decay = float(os.environ.get("EMA_DECAY", 0.0))
+    swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.0))
+    swa_stride = int(os.environ.get("SWA_STRIDE", 50))
 
     ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8))
     ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01))
@@ -547,6 +552,9 @@ def load_exported_state_dict(module: nn.Module, state_dict: dict[str, Tensor]) -
     if bad_missing or unexpected:
         raise RuntimeError(f"Export reload mismatch missing={bad_missing} unexpected={list(unexpected)}")
 
+def clone_export_state(module: nn.Module) -> dict[str, Tensor]:
+    return {k: v.detach().clone() for k, v in export_state_dict(module).items()}
+
 class Rotary(nn.Module):
     def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, scaling: str = "ntk", scale: float = 1.0):
         super().__init__()
@@ -659,15 +667,16 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor:
 
 
 class MLP(nn.Module):
-    def __init__(self, dim: int, mlp_mult: int):
+    def __init__(self, dim: int, mlp_mult: int, act: str = "relu2", leaky_relu_slope: float = 0.5):
         super().__init__()
         hidden = mlp_mult * dim
+        self.act, self.leaky_relu_slope = act, leaky_relu_slope
         self.fc = CastedLinear(dim, hidden, bias=False)
         self.proj = CastedLinear(hidden, dim, bias=False)
         self.proj._zero_init = True
 
     def forward(self, x: Tensor) -> Tensor:
-        x = torch.relu(self.fc(x))
+        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) if self.act == "leaky2" else torch.relu(self.fc(x))
         return self.proj(x.square())
 
 class DeltaMixer(nn.Module):
     def __init__(self, dim: int):
@@ -698,6 +707,8 @@ def __init__(
         rope_scaling: str,
         rope_scale: float,
         hybrid_delta: bool,
+        mlp_act: str,
+        leaky_relu_slope: float,
     ):
         super().__init__()
         self.attn_norm = RMSNorm()
@@ -705,7 +716,7 @@ def __init__(
         self.use_delta = hybrid_delta
         self.attn = DeltaMixer(dim) if hybrid_delta else CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, rope_scaling, rope_scale)
         self.attn_twice_alpha = attn_twice_alpha
-        self.mlp = MLP(dim, mlp_mult)
+        self.mlp = MLP(dim, mlp_mult, mlp_act, leaky_relu_slope)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
@@ -725,8 +736,6 @@ def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Te
         x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
         x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
         return x
-
-
 class GPT(nn.Module):
     def __init__(
         self,
@@ -755,6 +764,8 @@ def __init__(
         shared_depth_n: int,
         shared_depth_gain: float,
         shared_depth_edge_unique: int,
+        mlp_act: str,
+        leaky_relu_slope: float,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
@@ -791,6 +802,8 @@ def __init__(
                     rope_scaling,
                     rope_scale,
                     hybrid_delta_every > 0 and (i + 1) % hybrid_delta_every == 0,
+                    mlp_act,
+                    leaky_relu_slope,
                 )
                 for i in range(max(self.block_map) + 1)
             ]
@@ -1185,6 +1198,8 @@ def log0(msg: str, console: bool = True) -> None:
         shared_depth_n=args.shared_depth_n,
         shared_depth_gain=args.shared_depth_gain,
         shared_depth_edge_unique=args.shared_depth_edge_unique,
+        mlp_act=args.mlp_act,
+        leaky_relu_slope=args.leaky_relu_slope,
     ).to(device).bfloat16()
     for module in base_model.modules():
         if isinstance(module, CastedLinear):
@@ -1211,44 +1226,18 @@ def log0(msg: str, console: bool = True) -> None:
     ]
     if base_model.skip_weights.numel() > 0:
         scalar_params.append(base_model.skip_weights)
     if base_model.pass_scales is not None:
         scalar_params.append(base_model.pass_scales)
     token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
-    optimizer_tok = torch.optim.Adam(
-        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
-        betas=(args.beta1, args.beta2),
-        eps=args.adam_eps,
-        fused=True,
-    )
-    optimizer_muon = Muon(
-        matrix_params,
-        lr=args.matrix_lr,
-        momentum=args.muon_momentum,
-        backend_steps=args.muon_backend_steps,
-        update_balance=args.muon_update_balance,
-    )
+    optimizer_tok = torch.optim.Adam([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
+    optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, update_balance=args.muon_update_balance)
     for group in optimizer_muon.param_groups:
         group["base_lr"] = args.matrix_lr
-    optimizer_scalar = torch.optim.Adam(
-        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
-        betas=(args.beta1, args.beta2),
-        eps=args.adam_eps,
-        fused=True,
-    )
+    optimizer_scalar = torch.optim.Adam([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
     optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
     if len(base_model.mtp_heads) > 0:
         mtp_lr = args.head_lr if args.head_lr > 0 else args.scalar_lr
-        optimizer_mtp = torch.optim.Adam(
-            [{"params": list(base_model.mtp_heads.parameters()), "lr": mtp_lr, "base_lr": mtp_lr}],
-            betas=(args.beta1, args.beta2),
-            eps=args.adam_eps,
-            fused=True,
-        )
+        optimizer_mtp = torch.optim.Adam([{"params": list(base_model.mtp_heads.parameters()), "lr": mtp_lr, "base_lr": mtp_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
         optimizers.append(optimizer_mtp)
     if base_model.lm_head is not None:
-        optimizer_head = torch.optim.Adam(
-            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
-            betas=(args.beta1, args.beta2),
-            eps=args.adam_eps,
-            fused=True,
-        )
+        optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
         optimizers.insert(1, optimizer_head)
 
     n_params = sum(p.numel() for p in base_model.parameters())
@@ -1285,6 +1274,9 @@ def zero_grad_all() -> None:
             opt.zero_grad(set_to_none=True)
 
     max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+    ema_state = clone_export_state(base_model) if args.ema_decay > 0 else None
+    swa_state, swa_count = None, 0
+    swa_start_step = int(args.iterations * args.swa_start_frac) if args.swa_start_frac > 0 else args.iterations + 1
 
     def lr_mul(step: int, elapsed_ms: float) -> float:
         if args.warmdown_iters <= 0:
@@ -1296,7 +1288,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
         warmdown_ms = args.warmdown_iters * step_ms
         remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
         return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
-
     if args.warmup_steps > 0:
         initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
         initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
@@ -1337,18 +1328,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
         if should_validate:
             torch.cuda.synchronize()
             training_time_ms += 1000.0 * (time.perf_counter() - t0)
-            val_loss, val_bpb = eval_val(
-                args,
-                model,
-                rank,
-                world_size,
-                device,
-                grad_accum_steps,
-                val_tokens,
-                base_bytes_lut,
-                has_leading_space_lut,
-                is_boundary_token_lut,
-            )
+            val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)
             log0(
                 f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                 f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
             )
@@ -1358,10 +1338,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
         if last_step:
             if stop_after_step is not None and step < args.iterations:
-                log0(
-                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
-                    f"step:{step}/{args.iterations}"
-                )
+                log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}")
             break
 
         elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
@@ -1396,6 +1373,19 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
             shrink = 1.0 - args.muon_weight_decay * optimizer_muon.param_groups[0]["lr"]
             for p in matrix_params:
                 p.mul_(shrink)
+        if ema_state is not None:
+            with torch.no_grad():
+                for name, tensor in export_state_dict(base_model).items():
+                    ema_state[name].lerp_(tensor.detach(), 1.0 - args.ema_decay)
+        if step >= swa_start_step and (step - swa_start_step) % max(args.swa_stride, 1) == 0:
+            cur = export_state_dict(base_model)
+            swa_count += 1
+            if swa_state is None:
+                swa_state = {k: v.detach().clone() for k, v in cur.items()}
+            else:
+                with torch.no_grad():
+                    for
name, tensor in cur.items(): + swa_state[name].add_(tensor.detach() - swa_state[name], alpha=1.0 / swa_count) zero_grad_all() step += 1 @@ -1422,15 +1412,16 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) + final_state = swa_state if swa_state is not None else ema_state if ema_state is not None else export_state_dict(base_model) if master_process: - torch.save(export_state_dict(base_model), "final_model.pt") + torch.save(final_state, "final_model.pt") model_bytes = os.path.getsize("final_model.pt") code_bytes = len(code.encode("utf-8")) log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") - quant_obj, quant_stats = quantize_state_dict_int8(export_state_dict(base_model)) + quant_obj, quant_stats = quantize_state_dict_int8(final_state) quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() From 681c41d7d5c8f1841dee8a3b1defb997b6a59c29 Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Fri, 27 Mar 2026 18:15:26 +0200 Subject: [PATCH 14/17] Add early quality gate runner --- scripts/run_autogate_profiles.sh | 24 ++++++++++++++++++++++++ scripts/run_remote_experiment.sh | 2 ++ train_gpt.py | 13 +++++++++++-- 3 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 scripts/run_autogate_profiles.sh diff --git a/scripts/run_autogate_profiles.sh b/scripts/run_autogate_profiles.sh new file mode 100644 index 000000000..65a4bf2b6 --- /dev/null +++ b/scripts/run_autogate_profiles.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" +cd "$ROOT_DIR" + +AUTO_STOP_STEP="${AUTO_STOP_STEP:-1000}" +AUTO_STOP_MAX_VAL_BPB="${AUTO_STOP_MAX_VAL_BPB:-1.405}" +RUN_ID_PREFIX="${RUN_ID_PREFIX:-gate}" +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" + +if [ "$#" -eq 0 ]; then + echo "Usage: $0 profile [profile ...]" >&2 + exit 1 +fi + +for profile in "$@"; do + export RUN_ID="${RUN_ID_PREFIX}_${profile}" + export AUTO_STOP_STEP AUTO_STOP_MAX_VAL_BPB + echo + echo "=== Auto-gate $profile (step=${AUTO_STOP_STEP}, max_val_bpb=${AUTO_STOP_MAX_VAL_BPB}) ===" + NPROC_PER_NODE="$NPROC_PER_NODE" bash scripts/run_remote_profile.sh "$profile" + tail -n 12 "logs/${RUN_ID}.txt" +done diff --git a/scripts/run_remote_experiment.sh b/scripts/run_remote_experiment.sh index 74cb4a447..3f18ed362 100755 --- a/scripts/run_remote_experiment.sh +++ b/scripts/run_remote_experiment.sh @@ -9,6 +9,8 @@ export DATA_PATH="${DATA_PATH:-./data/datasets/fineweb10B_sp1024}" export TOKENIZER_PATH="${TOKENIZER_PATH:-./data/tokenizers/fineweb_1024_bpe.model}" export VOCAB_SIZE="${VOCAB_SIZE:-1024}" export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export AUTO_STOP_STEP="${AUTO_STOP_STEP:-0}" +export AUTO_STOP_MAX_VAL_BPB="${AUTO_STOP_MAX_VAL_BPB:-0}" export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-1000}" export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" diff --git a/train_gpt.py b/train_gpt.py index 0c99ac4ac..65cd9015c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -48,6 +48,8 @@ class Hyperparameters: eval_stride = int(os.environ.get("ROUNDTRIP_EVAL_STRIDE", os.environ.get("EVAL_STRIDE", "0"))) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) skip_final_eval = bool(int(os.environ.get("SKIP_FINAL_EVAL", "0"))) + auto_stop_step = int(os.environ.get("AUTO_STOP_STEP", 0)) + auto_stop_max_val_bpb = float(os.environ.get("AUTO_STOP_MAX_VAL_BPB", 0.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) vocab_size = int(os.environ.get("VOCAB_SIZE", 
1024)) @@ -1266,6 +1268,8 @@ def log0(msg: str, console: bool = True) -> None: f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) + if args.auto_stop_step > 0 and args.auto_stop_max_val_bpb > 0: + log0(f"auto_stop_step:{args.auto_stop_step} auto_stop_max_val_bpb:{args.auto_stop_max_val_bpb:.4f}") log0(f"seed:{args.seed}") train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) @@ -1317,6 +1321,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: training_time_ms = 0.0 stop_after_step: int | None = None + stop_reason: str | None = None torch.cuda.synchronize() t0 = time.perf_counter() @@ -1333,12 +1338,16 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" ) + if stop_after_step is None and args.auto_stop_step > 0 and args.auto_stop_max_val_bpb > 0 and step >= args.auto_stop_step and val_bpb > args.auto_stop_max_val_bpb: + stop_after_step, stop_reason = step, "quality_gate" + log0(f"auto_stop_triggered: step:{step} val_bpb:{val_bpb:.4f} threshold:{args.auto_stop_max_val_bpb:.4f}") + last_step = True torch.cuda.synchronize() t0 = time.perf_counter() if last_step: if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + log0(f"stopping_early: {stop_reason or 'wallclock_cap'} train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") break elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) @@ -1406,7 +1415,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) reached_cap = bool(reached_cap_tensor.item()) if stop_after_step is None and reached_cap: - stop_after_step = step + stop_after_step, stop_reason = 
step, "wallclock_cap" log0( f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " From ad84ca93dcc176b4d7a5c0cc9849bc77f7297bcd Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Fri, 27 Mar 2026 18:17:55 +0200 Subject: [PATCH 15/17] Add budgeted auto-gate batch runner --- scripts/run_autogate_budget.sh | 35 ++++++++++++++++++++++++++++++++++ train_gpt.py | 13 ++++++------- 2 files changed, 41 insertions(+), 7 deletions(-) create mode 100644 scripts/run_autogate_budget.sh diff --git a/scripts/run_autogate_budget.sh b/scripts/run_autogate_budget.sh new file mode 100644 index 000000000..173e97baa --- /dev/null +++ b/scripts/run_autogate_budget.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +TOTAL_BUDGET_MINUTES="${TOTAL_BUDGET_MINUTES:-80}" +MIN_SECONDS_LEFT_TO_START="${MIN_SECONDS_LEFT_TO_START:-900}" +RUN_ID_PREFIX="${RUN_ID_PREFIX:-budget}" +AUTO_STOP_STEP="${AUTO_STOP_STEP:-1000}" +AUTO_STOP_MAX_VAL_BPB="${AUTO_STOP_MAX_VAL_BPB:-1.405}" +NPROC_PER_NODE="${NPROC_PER_NODE:-1}" +start_ts="$(date +%s)" +budget_seconds="$((TOTAL_BUDGET_MINUTES * 60))" + +if [ "$#" -eq 0 ]; then + echo "Usage: $0 profile [profile ...]" >&2 + exit 1 +fi + +for profile in "$@"; do + now_ts="$(date +%s)" + elapsed="$((now_ts - start_ts))" + remaining="$((budget_seconds - elapsed))" + if [ "$remaining" -lt "$MIN_SECONDS_LEFT_TO_START" ]; then + echo "=== Budget stop: remaining=${remaining}s is below minimum start window ${MIN_SECONDS_LEFT_TO_START}s ===" + exit 0 + fi + export RUN_ID="${RUN_ID_PREFIX}_${profile}" + export AUTO_STOP_STEP AUTO_STOP_MAX_VAL_BPB + echo + echo "=== Budget run $profile elapsed=${elapsed}s remaining=${remaining}s ===" + NPROC_PER_NODE="$NPROC_PER_NODE" bash scripts/run_remote_profile.sh "$profile" + tail -n 12 "logs/${RUN_ID}.txt" +done diff --git a/train_gpt.py b/train_gpt.py index 65cd9015c..30659765e 100644 --- a/train_gpt.py +++ 
b/train_gpt.py @@ -1268,8 +1268,7 @@ def log0(msg: str, console: bool = True) -> None: f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) - if args.auto_stop_step > 0 and args.auto_stop_max_val_bpb > 0: - log0(f"auto_stop_step:{args.auto_stop_step} auto_stop_max_val_bpb:{args.auto_stop_max_val_bpb:.4f}") + if args.auto_stop_step > 0 and args.auto_stop_max_val_bpb > 0: log0(f"auto_stop_step:{args.auto_stop_step} auto_stop_max_val_bpb:{args.auto_stop_max_val_bpb:.4f}") log0(f"seed:{args.seed}") train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) @@ -1318,7 +1317,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - training_time_ms = 0.0 stop_after_step: int | None = None stop_reason: str | None = None @@ -1417,10 +1415,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if stop_after_step is None and reached_cap: stop_after_step, stop_reason = step, "wallclock_cap" - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + if stop_reason == "quality_gate": + log0("quality_gate_exit: skipping export and final eval") + if distributed: dist.destroy_process_group() + return final_state = swa_state if swa_state is not None else ema_state if ema_state is not None else export_state_dict(base_model) if master_process: torch.save(final_state, "final_model.pt") From 7dba94e6464124a869f0afdeb1782390fcfbf830 Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Fri, 27 Mar 2026 20:09:45 +0200 Subject: [PATCH 16/17] Add winner pack and saved best result --- BEST_RESULTS.md | 30 
++++++++++++++ EXPERIMENT_PLAN.md | 65 +++++++++++++++++++++++++++++++ scripts/run_remote_profile.sh | 62 ++++++++++++++++++++++++++++- scripts/run_top10_patterns_80m.sh | 21 ++++++++++ 4 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 BEST_RESULTS.md create mode 100755 scripts/run_top10_patterns_80m.sh diff --git a/BEST_RESULTS.md b/BEST_RESULTS.md new file mode 100644 index 000000000..03be9cd2d --- /dev/null +++ b/BEST_RESULTS.md @@ -0,0 +1,30 @@ +# Best Results + +## Current Best Known Run + +- `budget_twice_eval2048_ttt1024` +- `final_int8_zlib_roundtrip_exact val_bpb: 1.38414876` +- `final_int8_ttt_lora val_bpb: 1.3962` +- `Total submission size int8+zlib: 11280566 bytes` + +This is the current baseline to protect. New experiment packs should be judged against this run first. + +## Restore Point + +Code branch: + +- `codex/runpod-2026-03-20-checkpoint` + +Runpod sync: + +```bash +cd /workspace/parameter-golf +git fetch myfork +git reset --hard myfork/codex/runpod-2026-03-20-checkpoint +``` + +Current best log: + +```bash +tail -n 20 logs/budget_twice_eval2048_ttt1024.txt +``` diff --git a/EXPERIMENT_PLAN.md b/EXPERIMENT_PLAN.md index b88b655fd..3a0baa794 100644 --- a/EXPERIMENT_PLAN.md +++ b/EXPERIMENT_PLAN.md @@ -9,6 +9,71 @@ This plan is optimized for limited budget and the challenge rules. - Stay under the `16,000,000` byte artifact cap - Avoid risky dataset changes until the safe path is exhausted +## Current Best + +- `budget_twice_eval2048_ttt1024` +- `final_int8_zlib_roundtrip_exact val_bpb: 1.38414876` +- `final_int8_ttt_lora val_bpb: 1.3962` +- `Total submission size int8+zlib: 11280566 bytes` + +See [BEST_RESULTS.md](/Users/deividasmataciunas/Desktop/research/openai_golf/BEST_RESULTS.md). 
+ +## Current Leaderboard Patterns + +From the official leaderboard and top record summaries, the repeated ingredients are: + +- `11` layers +- `MLP 3x` +- `EMA` +- `weight decay 0.04` +- `warmdown 3500` +- `BigramHash` +- `Partial RoPE` +- `LN scale` +- `QAT` / `GPTQ-lite` +- `LeakyReLU(0.5)^2` + +The features we can test immediately in this repo are: + +- `EMA` +- `SWA` +- `weight decay` +- `warmdown` +- `MLP 3x` +- `LeakyReLU(0.5)^2` +- modest LR changes + +The features that require more code work are: + +- `BigramHash` +- `Partial RoPE` +- `LN scale` +- `QAT` / `GPTQ-lite` + +## Next 80-Minute Batch + +Run the winner-adjacent profile pack with a strict `1000`-step gate: + +```bash +bash scripts/run_top10_patterns_80m.sh +``` + +This batch tests: + +1. `winner_locked` +2. `winner_ema_swa` +3. `winner_wd03` +4. `winner_wd04` +5. `winner_warm3500` +6. `winner_lr18` +7. `winner_wd03_ema` +8. `winner_mlp3` + +Use this as the default threshold: + +- `AUTO_STOP_STEP=1000` +- `AUTO_STOP_MAX_VAL_BPB=1.395` + ## 5-Run Moonshot Sequence Run these in order on remote GPUs, using the current branch and `TRAIN_SHARDS=1`: diff --git a/scripts/run_remote_profile.sh b/scripts/run_remote_profile.sh index a579b5a42..c01613267 100755 --- a/scripts/run_remote_profile.sh +++ b/scripts/run_remote_profile.sh @@ -99,9 +99,69 @@ case "$PROFILE" in export SWA_START_FRAC="${SWA_START_FRAC:-0.6}" export SWA_STRIDE="${SWA_STRIDE:-50}" ;; + winner_locked) + export RUN_ID="${RUN_ID:-winner_locked}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + ;; + winner_ema_swa) + export RUN_ID="${RUN_ID:-winner_ema_swa}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + export EMA_DECAY="${EMA_DECAY:-0.997}" + export SWA_START_FRAC="${SWA_START_FRAC:-0.6}" 
+ export SWA_STRIDE="${SWA_STRIDE:-50}" + ;; + winner_wd03) + export RUN_ID="${RUN_ID:-winner_wd03}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + export MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.03}" + ;; + winner_wd04) + export RUN_ID="${RUN_ID:-winner_wd04}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + export MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.04}" + ;; + winner_warm3500) + export RUN_ID="${RUN_ID:-winner_warm3500}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + export WARMDOWN_ITERS="${WARMDOWN_ITERS:-3500}" + ;; + winner_lr18) + export RUN_ID="${RUN_ID:-winner_lr18}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + export MATRIX_LR="${MATRIX_LR:-0.018}" + ;; + winner_wd03_ema) + export RUN_ID="${RUN_ID:-winner_wd03_ema}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + export MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.03}" + export EMA_DECAY="${EMA_DECAY:-0.997}" + export SWA_START_FRAC="${SWA_START_FRAC:-0.6}" + export SWA_STRIDE="${SWA_STRIDE:-50}" + ;; + winner_mlp3) + export RUN_ID="${RUN_ID:-winner_mlp3}" + export ATTN_TWICE_ALPHA="${ATTN_TWICE_ALPHA:-0.05}" + export ROUNDTRIP_EVAL_SEQ_LEN="${ROUNDTRIP_EVAL_SEQ_LEN:-2048}" + export TTT_EVAL_SEQ_LEN="${TTT_EVAL_SEQ_LEN:-1024}" + export MLP_MULT="${MLP_MULT:-3}" + ;; *) echo "Unknown profile: $PROFILE" >&2 - echo "Profiles: base10l zloss_low zloss_med twice_low 
twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth shared_depth_midshare copycore_v1" >&2 + echo "Profiles: base10l zloss_low zloss_med twice_low twice_layerwise zloss_twice eval2048 twice_eval2048_ttt1024 drope_eval yarn_eval mtp_low muon_balance hybrid_delta shared_depth shared_depth_midshare copycore_v1 winner_locked winner_ema_swa winner_wd03 winner_wd04 winner_warm3500 winner_lr18 winner_wd03_ema winner_mlp3" >&2 exit 1 ;; esac diff --git a/scripts/run_top10_patterns_80m.sh b/scripts/run_top10_patterns_80m.sh new file mode 100755 index 000000000..351a036a4 --- /dev/null +++ b/scripts/run_top10_patterns_80m.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export TOTAL_BUDGET_MINUTES="${TOTAL_BUDGET_MINUTES:-80}" +export MIN_SECONDS_LEFT_TO_START="${MIN_SECONDS_LEFT_TO_START:-900}" +export AUTO_STOP_STEP="${AUTO_STOP_STEP:-1000}" +export AUTO_STOP_MAX_VAL_BPB="${AUTO_STOP_MAX_VAL_BPB:-1.395}" +export RUN_ID_PREFIX="${RUN_ID_PREFIX:-top10}" + +bash scripts/run_autogate_budget.sh \ + winner_locked \ + winner_ema_swa \ + winner_wd03 \ + winner_wd04 \ + winner_warm3500 \ + winner_lr18 \ + winner_wd03_ema \ + winner_mlp3 From b0fd77231ca24fa07f5cf56a8b7d3c4acd98fba9 Mon Sep 17 00:00:00 2001 From: DeividasAQ22 Date: Fri, 27 Mar 2026 21:59:15 +0200 Subject: [PATCH 17/17] Update best result documentation --- BEST_RESULTS.md | 16 +++++++++++----- EXPERIMENT_PLAN.md | 27 ++++++++++++++++++++------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/BEST_RESULTS.md b/BEST_RESULTS.md index 03be9cd2d..5f9d67df6 100644 --- a/BEST_RESULTS.md +++ b/BEST_RESULTS.md @@ -2,13 +2,19 @@ ## Current Best Known Run -- `budget_twice_eval2048_ttt1024` -- `final_int8_zlib_roundtrip_exact val_bpb: 1.38414876` -- `final_int8_ttt_lora val_bpb: 1.3962` -- `Total submission size int8+zlib: 11280566 bytes` +- 
`top10_winner_wd04` +- `final_int8_zlib_roundtrip_exact val_bpb: 1.37843732` +- `final_int8_ttt_lora val_bpb: 1.3908` +- `Total submission size int8+zlib: 11219666 bytes` This is the current baseline to protect. New experiment packs should be judged against this run first. +## Key Takeaway + +- `MUON_WEIGHT_DECAY=0.04` is the strongest winner-adjacent change tested so far. +- Broad moonshots underperformed this branch. +- The next search should stay close to `winner_wd04`. + ## Restore Point Code branch: @@ -26,5 +32,5 @@ git reset --hard myfork/codex/runpod-2026-03-20-checkpoint Current best log: ```bash -tail -n 20 logs/budget_twice_eval2048_ttt1024.txt +tail -n 20 logs/top10_winner_wd04.txt ``` diff --git a/EXPERIMENT_PLAN.md b/EXPERIMENT_PLAN.md index 3a0baa794..804565d88 100644 --- a/EXPERIMENT_PLAN.md +++ b/EXPERIMENT_PLAN.md @@ -11,10 +11,10 @@ This plan is optimized for limited budget and the challenge rules. ## Current Best -- `budget_twice_eval2048_ttt1024` -- `final_int8_zlib_roundtrip_exact val_bpb: 1.38414876` -- `final_int8_ttt_lora val_bpb: 1.3962` -- `Total submission size int8+zlib: 11280566 bytes` +- `top10_winner_wd04` +- `final_int8_zlib_roundtrip_exact val_bpb: 1.37843732` +- `final_int8_ttt_lora val_bpb: 1.3908` +- `Total submission size int8+zlib: 11219666 bytes` See [BEST_RESULTS.md](/Users/deividasmataciunas/Desktop/research/openai_golf/BEST_RESULTS.md). @@ -52,13 +52,26 @@ The features that require more code work are: ## Next 80-Minute Batch +The latest tournament established `winner_wd04` as the new protected baseline. The next pack should stay tightly centered on this branch rather than reopening broader exploration. + +Recommended next batch: + +1. `winner_wd04` locked baseline +2. `winner_wd04 + warmdown 3500` +3. `winner_wd04 + matrix_lr 0.018` +4. `winner_wd04 + EMA/SWA` +5. `winner_wd04 + warmdown 3500 + matrix_lr 0.018` +6. `winner_wd04 + scalar_lr tweak` +7. `winner_wd04 + tied_embed_lr tweak` +8. 
`winner_wd04 + small z-loss` + Run the winner-adjacent profile pack with a strict `1000`-step gate: ```bash bash scripts/run_top10_patterns_80m.sh ``` -This batch tests: +The current committed pack tests: 1. `winner_locked` 2. `winner_ema_swa` @@ -69,10 +82,10 @@ This batch tests: 7. `winner_wd03_ema` 8. `winner_mlp3` -Use this as the default threshold: +Use this as the default threshold for winner-adjacent searches: - `AUTO_STOP_STEP=1000` -- `AUTO_STOP_MAX_VAL_BPB=1.395` +- `AUTO_STOP_MAX_VAL_BPB=1.390` ## 5-Run Moonshot Sequence
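Across the profile packs in this series, `EMA_DECAY`, `SWA_START_FRAC`, and `SWA_STRIDE` toggle the two running weight averages added to the training loop in the first patch. They reduce to two updates per step; a framework-agnostic sketch (plain floats stand in for tensors, and the decay/stride defaults are the ones the winner profiles use — this condenses the patch's logic, it is not the committed code):

```python
def update_averages(weights, ema, swa, swa_count, step,
                    ema_decay=0.997, swa_start_step=6000, swa_stride=50):
    """One post-optimizer-step averaging update.

    EMA: exponential moving average, ema <- ema + (1 - decay) * (w - ema),
    the float equivalent of the patch's tensor lerp_ toward the weights.
    SWA: running mean of snapshots taken every `swa_stride` steps once
    `step` reaches `swa_start_step`; zero-initialized swa state is fine
    because the first snapshot (count=1) overwrites it with the weights.
    Returns the updated snapshot count.
    """
    for name, w in weights.items():
        ema[name] += (1.0 - ema_decay) * (w - ema[name])
    if step >= swa_start_step and (step - swa_start_step) % swa_stride == 0:
        swa_count += 1
        for name, w in weights.items():
            # incremental mean: m += (x - m) / n
            swa[name] += (w - swa[name]) / swa_count
    return swa_count
```

At export time the patch prefers the SWA state when any snapshot was taken, then the EMA state, then the raw weights — so enabling SWA silently supersedes EMA for the saved `final_model.pt`.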