diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/README.md b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/README.md new file mode 100644 index 0000000000..575e6216c6 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/README.md @@ -0,0 +1,76 @@ +This record captures a non-record 16MB submission centered on an SP8192 BPE run with **Mamba3 SSM hybrid architecture**, trained on a single H100 for 30 minutes. + +The key architecture contribution here is the SSM/attention hybrid: replacing every 4th transformer attention block with a Mamba3 state-space model layer, reducing parameter count while maintaining competitive BPB. With `ssm_every_n=4` (2 SSM blocks, 7 GQA attention blocks), the model achieves 18.31M params — saving ~2.2M params vs the all-attention variant. + +Configuration: +- Track: `non-record` +- Layout: `VOCAB_SIZE=8192 MODEL_DIM=448 NUM_LAYERS=9 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` +- SSM: `USE_SSM=1 SSM_EVERY_N=4 SSM_IMPL=mamba3 MAMBA3_HEAD_DIM=64` +- Tokenizer: SentencePiece BPE 8192 (`fineweb_8192_bpe.model`) +- Batching: `TRAIN_BATCH_TOKENS=65536 TRAIN_SEQ_LEN=1024` +- Eval: sliding-window validation with `EVAL_STRIDE_FRAC=0.5` +- Opt: Muon (matrix) + Adam (scalar), `SWA_ENABLED=1` +- Quant/export: GPTQ int8 + zstd + +Key metrics (from `train.log`): +- Timed training stopped at `12278/20000` steps due to 30min wallclock cap. +- Pre-quant eval at stop: `val_loss:3.2398`, `val_bpb:1.2542` +- Post-quant roundtrip eval: `val_loss:3.25624330`, `val_bpb:1.26060944` +- Train time: `1800080ms` (`step_avg:146.61ms`) +- Code size: `231880 bytes` + +SSM/attention hybrid notes: +- **Mamba3 SSM** (`mamba_ssm` official CUDA extension) used as a drop-in mixer replacement +- SSM blocks use `expand=2.0, d_state=128, head_dim=64, mimo_rank=4` — comparable throughput to GQA attention on H100 +- `ssm_every_n=4` means layers [2, 6] are SSM, rest are GQA attention — reduces params by ~11% vs all-attention + +Dataset/tokenizer requirement: +- This package expects an **SP8192 exported dataset** at: + - `./sp8192_data/datasets/fineweb10B_sp8192` +- And uses tokenizer assets in this folder by default: + - `./fineweb_8192_bpe.model` + - `./fineweb_8192_bpe.vocab` +- Build the dataset (includes mamba_ssm CUDA extension install): + - `bash ./setup_sp8192_data.sh` + +Note: `mamba-ssm` is the official Mamba CUDA extension from [state-spaces/mamba](https://github.com/state-spaces/mamba). +Install from GitHub source (requires CUDA toolkit): +```bash +MAMBA_FORCE_BUILD=TRUE pip install --no-cache-dir --force-reinstall \ + git+https://github.com/state-spaces/mamba.git --no-build-isolation +``` + +Run command (1-GPU): +```bash +OMP_NUM_THREADS=1 \ +TORCH_NCCL_ASYNC_ERROR_HANDLING=1 \ +RUN_ID=sp8192_bpe_mamba3_d448_ssm4_1xh30m_s1337 \ +DATA_PATH=./sp8192_data/datasets/fineweb10B_sp8192 \ +TOKENIZER_PATH=./fineweb_8192_bpe.model \ +VOCAB_SIZE=8192 \ +MODEL_DIM=448 \ +NUM_LAYERS=9 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=2 \ +TIE_EMBEDDINGS=1 \ +USE_SWIGLU=1 \ +USE_SSM=1 \ +SSM_EVERY_N=4 \ +MAMBA3_HEAD_DIM=64 \ +TRAIN_BATCH_TOKENS=65536 \ +MAX_WALLCLOCK_SECONDS=1800 \ +WARMUP_STEPS=20 \ +EVAL_STRIDE_FRAC=0.5 \ +QUANT_SCHEME=int8 \ +COMPRESSOR=zstd \ +GPTQ=1 GPTQ_NSAMPLES=128 GPTQ_BLOCKSIZE=128 GPTQ_PERCDAMP=0.01 \ +torchrun --standalone --nproc_per_node=1 ./train_gpt_mamba3.py +``` + +Included files: +- `train_gpt_mamba3.py` (code snapshot used for the run package) +- `train.log` (exact run log, source code + runtime output) +- `submission.json` (metadata) +- `reqs.txt` (dependencies) +- `fineweb_8192_bpe.model` and `fineweb_8192_bpe.vocab` (tokenizer assets) diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/fineweb_8192_bpe.model b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/fineweb_8192_bpe.model new file mode 100644 index 0000000000..d9669f269d Binary files /dev/null and b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/fineweb_8192_bpe.model differ diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/fineweb_8192_bpe.vocab b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/fineweb_8192_bpe.vocab new file mode 100644 index 0000000000..35526307b0 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/fineweb_8192_bpe.vocab @@ -0,0 +1,8192 @@ + 0 + 0 + 0 + 0 +<0x00> 0 +<0x01> 0 +<0x02> 0 +<0x03> 0 +<0x04> 0 +<0x05> 0 +<0x06> 0 +<0x07> 0 +<0x08> 0 +<0x09> 0 +<0x0A> 0 +<0x0B> 0 +<0x0C> 0 +<0x0D> 0 +<0x0E> 0 +<0x0F> 0 +<0x10> 0 +<0x11> 0 +<0x12> 0 +<0x13> 0 +<0x14> 0 +<0x15> 0 +<0x16> 0 +<0x17> 0 +<0x18> 0 +<0x19> 0 +<0x1A> 0 +<0x1B> 0 +<0x1C> 0 +<0x1D> 0 +<0x1E> 0 +<0x1F> 0 +<0x20> 0 +<0x21> 0 +<0x22> 0 +<0x23> 0 +<0x24> 0 +<0x25> 0 +<0x26> 0 +<0x27> 0 +<0x28> 0 +<0x29> 0 +<0x2A> 0 +<0x2B> 0 +<0x2C> 0 +<0x2D> 0 +<0x2E> 0 +<0x2F> 0 +<0x30> 0 +<0x31> 0 +<0x32> 0 +<0x33> 0 +<0x34> 0 +<0x35> 0 +<0x36> 0 +<0x37> 0 +<0x38> 0 +<0x39> 0 +<0x3A> 0 +<0x3B> 0 +<0x3C> 0 +<0x3D> 0 +<0x3E> 0 +<0x3F> 0 +<0x40> 0 +<0x41> 0 +<0x42> 0 +<0x43> 0 +<0x44> 0 +<0x45> 0 +<0x46> 0 +<0x47> 0 +<0x48> 0 +<0x49> 0 +<0x4A> 0 +<0x4B> 0 +<0x4C> 0 +<0x4D> 0 +<0x4E> 0 +<0x4F> 0 +<0x50> 0 +<0x51> 0 +<0x52> 0 +<0x53> 0 +<0x54> 0 +<0x55> 0 +<0x56> 0 +<0x57> 0 +<0x58> 0 +<0x59> 0 +<0x5A> 0 +<0x5B> 0 +<0x5C> 0 +<0x5D> 0 +<0x5E> 0 +<0x5F> 0 +<0x60> 0 +<0x61> 0 +<0x62> 0 +<0x63> 0 +<0x64> 0 +<0x65> 0 +<0x66> 0 +<0x67> 0 +<0x68> 0 +<0x69> 0 +<0x6A> 0 +<0x6B> 0 +<0x6C> 0 +<0x6D> 0 +<0x6E> 0 +<0x6F> 0 +<0x70> 0 +<0x71> 0 +<0x72> 0 +<0x73> 0 +<0x74> 0 +<0x75> 0 +<0x76> 0 +<0x77> 0 +<0x78> 0 +<0x79> 0 +<0x7A> 0 +<0x7B> 0 +<0x7C> 0 +<0x7D> 0 +<0x7E> 0 +<0x7F> 0 +<0x80> 0 +<0x81> 0 +<0x82> 0 +<0x83> 0 +<0x84> 0 +<0x85> 0 +<0x86> 0 +<0x87> 0 +<0x88> 0 +<0x89> 0 +<0x8A> 0 +<0x8B> 0 +<0x8C> 0 +<0x8D> 0 +<0x8E> 0 +<0x8F> 0 +<0x90> 0 +<0x91> 0 +<0x92> 0 +<0x93> 0 +<0x94> 0 +<0x95> 0 +<0x96> 0 +<0x97> 0 +<0x98> 0 +<0x99> 0 +<0x9A> 0 +<0x9B> 0 +<0x9C> 0 +<0x9D> 0 +<0x9E> 0 +<0x9F> 0 +<0xA0> 0 +<0xA1> 0 +<0xA2> 0 +<0xA3> 0 +<0xA4> 0 +<0xA5> 0 +<0xA6> 0 +<0xA7> 0 +<0xA8> 0 +<0xA9> 0 +<0xAA> 0 +<0xAB> 0 +<0xAC> 0 +<0xAD> 0 +<0xAE> 0 +<0xAF> 0 +<0xB0> 0 +<0xB1> 0 +<0xB2> 0 +<0xB3> 0 +<0xB4> 0 +<0xB5> 0 +<0xB6> 0 +<0xB7> 0 +<0xB8> 0 +<0xB9> 0 +<0xBA> 0 +<0xBB> 0 +<0xBC> 0 +<0xBD> 0 +<0xBE> 0 +<0xBF> 0 +<0xC0> 0 +<0xC1> 0 +<0xC2> 0 +<0xC3> 0 +<0xC4> 0 +<0xC5> 0 +<0xC6> 0 +<0xC7> 0 +<0xC8> 0 +<0xC9> 0 +<0xCA> 0 +<0xCB> 0 +<0xCC> 0 +<0xCD> 0 +<0xCE> 0 +<0xCF> 0 +<0xD0> 0 +<0xD1> 0 +<0xD2> 0 +<0xD3> 0 +<0xD4> 0 +<0xD5> 0 +<0xD6> 0 +<0xD7> 0 +<0xD8> 0 +<0xD9> 0 +<0xDA> 0 +<0xDB> 0 +<0xDC> 0 +<0xDD> 0 +<0xDE> 0 +<0xDF> 0 +<0xE0> 0 +<0xE1> 0 +<0xE2> 0 +<0xE3> 0 +<0xE4> 0 +<0xE5> 0 +<0xE6> 0 +<0xE7> 0 +<0xE8> 0 +<0xE9> 0 +<0xEA> 0 +<0xEB> 0 +<0xEC> 0 +<0xED> 0 +<0xEE> 0 +<0xEF> 0 +<0xF0> 0 +<0xF1> 0 +<0xF2> 0 +<0xF3> 0 +<0xF4> 0 +<0xF5> 0 +<0xF6> 0 +<0xF7> 0 +<0xF8> 0 +<0xF9> 0 +<0xFA> 0 +<0xFB> 0 +<0xFC> 0 +<0xFD> 0 +<0xFE> 0 +<0xFF> 0 +▁t -0 +▁a -1 +in -2 +he -3 +re -4 +on -5 +er -6 +▁the -7 +▁s -8 +▁w -9 +or -10 +at -11 +nd -12 +ou -13 +▁c -14 +it -15 +es -16 +▁f -17 +is -18 +en -19 +ing -20 +▁b -21 +▁p -22 +▁o -23 +an -24 +ed -25 +al -26 +▁to -27 +▁m -28 +ar -29 +▁and -30 +▁in -31 +▁of -32 +▁d -33 +le -34 +ic -35 +as -36 +om -37 +▁h -38 +ion -39 +▁th -40 +il -41 +▁T -42 +ent -43 +▁l -44 +ve -45 +▁y -46 +ro -47 +st -48 +▁I -49 +▁e -50 +▁re -51 +▁n -52 +▁S -53 +▁g -54 +et -55 +ct -56 +▁A -57 +▁you -58 +▁C -59 +ly -60 +▁for -61 +id -62 +▁is -63 +ay -64 +▁on -65 +▁be -66 +ot -67 +ow -68 +ol -69 +am -70 +ce -71 +ig -72 +us -73 +ad -74 +im -75 +▁M -76 +ch -77 +el -78 +ver -79 +ith -80 +ut -81 +▁st -82 +ation -83 +ur -84 +▁P -85 +▁with -86 +▁that -87 +ir -88 +▁B -89 +▁W -90 +▁The -91 +▁it -92 +▁he -93 +ra -94 +ill -95 +ers -96 +▁al -97 +un -98 +ul -99 +▁an -100 +▁D -101 +▁H -102 +▁F -103 +out -104 +▁pro -105 +▁as -106 +▁wh -107 +▁are -108 +ke -109 +se -110 +ter -111 +▁we -112 +if -113 +▁ha -114 +ge -115 +oo -116 +▁R -117 +our -118 +pp -119 +ck -120 +ate -121 +ess -122 +▁at -123 +▁con -124 +▁com -125 +▁or -126 +▁L -127 +est -128 +her -129 +ore -130 +ment -131 +▁fr -132 +ab -133 +igh -134 +▁- -135 +▁ne -136 +▁N -137 +ort -138 +▁se -139 +▁G -140 +▁your -141 +ld -142 +▁E -143 +ist -144 +ri -145 +op -146 +▁( -147 +▁ex -148 +ity -149 +ure -150 +▁O -151 +em -152 +▁v -153 +qu -154 +ant -155 +art -156 +ive -157 +ust -158 +um -159 +▁was -160 +▁have -161 +pe -162 +▁from -163 +▁this -164 +▁de -165 +▁r -166 +▁sh -167 +th -168 +ain -169 +ies -170 +▁can -171 +up -172 +▁will -173 +▁ch -174 +and -175 +▁by -176 +os -177 +ight -178 +nt -179 +ie -180 +▁us -181 +ome -182 +all -183 +ard -184 +▁not -185 +ud -186 +res -187 +▁le -188 +▁J -189 +ast -190 +▁pl -191 +ost -192 +▁su -193 +▁ab -194 +iv -195 +ear -196 +▁wor -197 +ide -198 +ial -199 +rou -200 +▁all -201 +gh -202 +od -203 +oc -204 +ak -205 +te -206 +ine -207 +ould -208 +▁j -209 +red -210 +ag -211 +▁has -212 +.. -213 +ice -214 +▁Th -215 +ell -216 +▁U -217 +age -218 +▁do -219 +▁k -220 +ack -221 +fe -222 +ook -223 +ac -224 +▁ad -225 +per -226 +▁In -227 +ip -228 +▁comp -229 +ake -230 +▁out -231 +ions -232 +ally -233 +▁up -234 +are -235 +▁but -236 +▁me -237 +▁whe -238 +pt -239 +lo -240 +ry -241 +able -242 +▁our -243 +▁“ -244 +one -245 +ind -246 +▁en -247 +▁more -248 +ail -249 +ite -250 +ther -251 +▁their -252 +▁Y -253 +ich -254 +▁so -255 +very -256 +ime -257 +cc -258 +ood -259 +ated -260 +ong -261 +▁K -262 +▁my -263 +▁sa -264 +for -265 +iz -266 +ame -267 +ber -268 +▁they -269 +▁St -270 +▁te -271 +so -272 +ous -273 +▁one -274 +ans -275 +act -276 +▁about -277 +ll -278 +ike -279 +du -280 +▁cont -281 +ase -282 +og -283 +▁V -284 +▁im -285 +ick -286 +▁cl -287 +ia -288 +ance -289 +▁work -290 +▁inc -291 +ign -292 +▁un -293 +ire -294 +ree -295 +▁off -296 +▁fe -297 +▁who -298 +▁man -299 +ue -300 +ace -301 +ach -302 +reat -303 +ub -304 +▁It -305 +ction -306 +▁go -307 +ne -308 +▁app -309 +▁year -310 +▁new -311 +ep -312 +ult -313 +ib -314 +ap -315 +▁his -316 +ays -317 +erv -318 +▁Ch -319 +▁We -320 +▁res -321 +und -322 +▁" -323 +▁sp -324 +ass -325 +ark -326 +ations -327 +ff -328 +▁qu -329 +ary -330 +▁per -331 +▁also -332 +ile -333 +▁which -334 +▁int -335 +▁time -336 +ove -337 +form -338 +ven -339 +ount -340 +▁get -341 +▁tr -342 +own -343 +▁like -344 +▁some -345 +▁other -346 +ond -347 +ents -348 +ings -349 +vel -350 +▁any -351 +ical -352 +ence -353 +▁part -354 +av -355 +▁been -356 +▁dis -357 +▁This -358 +▁over -359 +ition -360 +ress -361 +pl -362 +ors -363 +▁rec -364 +▁them -365 +▁He -366 +▁sc -367 +▁ar -368 +ild -369 +▁pe -370 +port -371 +ink -372 +low -373 +▁ag -374 +▁ro -375 +▁her -376 +▁when -377 +ound -378 +▁kn -379 +ord -380 +mer -381 +int -382 +▁need -383 +ish -384 +▁pr -385 +irst -386 +ens -387 +ough -388 +▁said -389 +ru -390 +▁pre -391 +▁spe -392 +▁just -393 +wn -394 +ren -395 +▁what -396 +▁there -397 +▁if -398 +▁acc -399 +▁than -400 +▁its -401 +ov -402 +▁Re -403 +day -404 +vers -405 +▁would -406 +ater -407 +fter -408 +▁had -409 +ade -410 +ning -411 +lud -412 +▁hel -413 +▁– -414 +▁were -415 +▁am -416 +old -417 +rough -418 +▁into -419 +▁des -420 +ory -421 +ople -422 +itt -423 +ang -424 +▁help -425 +▁tw -426 +▁how -427 +use -428 +lic -429 +ool -430 +▁bec -431 +▁add -432 +anc -433 +▁first -434 +ose -435 +▁make -436 +▁comm -437 +ons -438 +amp -439 +ob -440 +hed -441 +▁prov -442 +▁Wh -443 +▁tra -444 +... -445 +ft -446 +▁look -447 +▁You -448 +▁includ -449 +ual -450 +▁people -451 +les -452 +▁serv -453 +gr -454 +▁col -455 +ian -456 +ments -457 +ful -458 +▁know -459 +▁produ -460 +ates -461 +iew -462 +▁Ne -463 +▁em -464 +rent -465 +ious -466 +tern -467 +▁she -468 +round -469 +ek -470 +▁every -471 +▁through -472 +▁may -473 +ating -474 +▁no -475 +▁only -476 +pport -477 +▁back -478 +▁most -479 +ect -480 +▁bu -481 +▁want -482 +ict -483 +ices -484 +▁As -485 +▁If -486 +▁well -487 +ities -488 +▁ind -489 +we -490 +▁bet -491 +▁ph -492 +ise -493 +▁use -494 +▁two -495 +▁co -496 +xt -497 +ont -498 +com -499 +▁act -500 +▁und -501 +ph -502 +iness -503 +lect -504 +iss -505 +▁after -506 +oy -507 +▁Se -508 +ife -509 +ause -510 +▁play -511 +fect -512 +▁| -513 +oth -514 +▁& -515 +ily -516 +row -517 +ork -518 +enc -519 +▁exper -520 +ject -521 +▁cons -522 +hen -523 +cial -524 +urn -525 +ert -526 +▁years -527 +als -528 +▁these -529 +ank -530 +ting -531 +▁$ -532 +▁Com -533 +aw -534 +▁bus -535 +▁An -536 +▁Un -537 +▁stud -538 +any -539 +bs -540 +ange -541 +▁For -542 +ures -543 +vent -544 +▁good -545 +ational -546 +aking -547 +▁see -548 +▁ke -549 +ased -550 +ific -551 +▁Pro -552 +▁now -553 +fore -554 +▁under -555 +▁very -556 +▁many -557 +▁reg -558 +▁sm -559 +ward -560 +hing -561 +▁imp -562 +get -563 +oint -564 +▁dif -565 +▁ra -566 +▁way -567 +erson -568 +ience -569 +▁start -570 +ts -571 +pect -572 +▁fin -573 +▁great -574 +▁And -575 +yst -576 +uring -577 +▁De -578 +▁rel -579 +formation -580 +▁gu -581 +ility -582 +ible -583 +▁rem -584 +▁could -585 +oss -586 +hip -587 +▁dec -588 +uch -589 +▁even -590 +▁inv -591 +). -592 +ty -593 +ics -594 +rit -595 +ract -596 +▁own -597 +▁sec -598 +cess -599 +velop -600 +▁day -601 +▁where -602 +▁show -603 +ident -604 +elf -605 +hes -606 +alth -607 +▁high -608 +its -609 +▁loc -610 +air -611 +▁find -612 +olog -613 +▁ac -614 +ull -615 +nds -616 +▁Al -617 +▁don -618 +▁ass -619 +▁home -620 +▁should -621 +line -622 +ath -623 +▁ent -624 +▁best -625 +▁here -626 +▁down -627 +lease -628 +▁then -629 +▁Sh -630 +ied -631 +ble -632 +ular -633 +|| -634 +▁right -635 +The -636 +arch -637 +▁set -638 +chool -639 +ited -640 +▁car -641 +▁av -642 +▁read -643 +▁New -644 +▁mon -645 +gan -646 +▁min -647 +▁take -648 +▁business -649 +erm -650 +▁fam -651 +▁ins -652 +ner -653 +ix -654 +▁inst -655 +▁fl -656 +ys -657 +▁design -658 +▁att -659 +ystem -660 +▁br -661 +alk -662 +▁too -663 +.” -664 +▁che -665 +▁bl -666 +io -667 +▁long -668 +▁much -669 +ative -670 +▁information -671 +▁Be -672 +▁made -673 +▁last -674 +ollow -675 +ason -676 +other -677 +ues -678 +gram -679 +arket -680 +▁product -681 +omet -682 +▁because -683 +ock -684 +ax -685 +▁Fr -686 +), -687 +rib -688 +▁week -689 +▁call -690 +▁did -691 +▁before -692 +▁think -693 +▁Cl -694 +▁team -695 +▁world -696 +atch -697 +me -698 +▁cre -699 +ale -700 +pen -701 +oun -702 +▁again -703 +▁sur -704 +ower -705 +▁Ad -706 +▁vis -707 +ient -708 +▁But -709 +chn -710 +pr -711 +az -712 +ustom -713 +land -714 +▁requ -715 +▁art -716 +▁develop -717 +▁being -718 +▁diffe -719 +▁pres -720 +rest -721 +way -722 +▁person -723 +ng -724 +ener -725 +▁such -726 +▁Le -727 +▁inte -728 +▁mem -729 +▁disc -730 +▁him -731 +ces -732 +▁support -733 +▁life -734 +arn -735 +ug -736 +ving -737 +ced -738 +ouse -739 +unity -740 +ave -741 +ince -742 +irect -743 +▁med -744 +▁Ar -745 +▁does -746 +▁while -747 +▁those -748 +ins -749 +▁provid -750 +ash -751 +arm -752 +view -753 +▁sim -754 +ivers -755 +ros -756 +▁lead -757 +▁sk -758 +akes -759 +ality -760 +▁pol -761 +▁end -762 +▁mod -763 +▁used -764 +▁cur -765 +ives -766 +▁around -767 +ric -768 +led -769 +ier -770 +▁free -771 +ailable -772 +ually -773 +▁each -774 +▁care -775 +▁comple -776 +▁follow -777 +ional -778 +ublic -779 +▁det -780 +▁On -781 +ple -782 +read -783 +der -784 +▁ret -785 +ize -786 +▁trans -787 +ather -788 +▁love -789 +▁There -790 +ages -791 +▁post -792 +ines -793 +▁child -794 +▁system -795 +ars -796 +▁bo -797 +ene -798 +roup -799 +▁eas -800 +▁book -801 +▁num -802 +▁ed -803 +▁How -804 +▁ser -805 +,” -806 +imes -807 +▁Te -808 +▁really -809 +▁count -810 +ets -811 +▁gr -812 +▁str -813 +▁program -814 +▁custom -815 +ton -816 +▁top -817 +▁run -818 +▁del -819 +au -820 +▁All -821 +iet -822 +▁cour -823 +▁found -824 +ffect -825 +▁So -826 +▁place -827 +▁list -828 +ness -829 +ved -830 +iel -831 +▁form -832 +▁month -833 +▁prof -834 +▁char -835 +ah -836 +▁feel -837 +▁To -838 +ute -839 +▁available -840 +▁going -841 +▁inter -842 +ittle -843 +▁They -844 +▁sign -845 +▁sub -846 +gg -847 +▁market -848 +man -849 +ature -850 +ames -851 +▁fun -852 +▁cle -853 +▁still -854 +cept -855 +▁Pl -856 +ways -857 +▁somet -858 +▁different -859 +▁aut -860 +▁both -861 +▁three -862 +▁few -863 +orn -864 +▁health -865 +▁though -866 +▁Ex -867 +ital -868 +ired -869 +▁pur -870 +ering -871 +▁rep -872 +▁adv -873 +▁exp -874 +▁techn -875 +▁happ -876 +▁open -877 +▁lot -878 +▁report -879 +▁company -880 +ata -881 +ween -882 +▁keep -883 +meric -884 +▁Sc -885 +orth -886 +▁plan -887 +▁hand -888 +ining -889 +bers -890 +iqu -891 +▁She -892 +tt -893 +ants -894 +be -895 +▁ext -896 +▁lar -897 +▁game -898 +▁sol -899 +▁point -900 +▁Q -901 +ross -902 +ology -903 +▁say -904 +ves -905 +atur -906 +▁met -907 +▁import -908 +▁process -909 +▁fil -910 +▁frie -911 +▁including -912 +▁family -913 +▁ev -914 +▁using -915 +▁same -916 +work -917 +▁project -918 +ized -919 +uc -920 +oot -921 +▁school -922 +▁between -923 +▁What -924 +ling -925 +ik -926 +▁little -927 +ution -928 +att -929 +ott -930 +▁experience -931 +▁during -932 +." -933 +less -934 +▁state -935 +iving -936 +▁Col -937 +▁i -938 +▁next -939 +uss -940 +els -941 +▁service -942 +aint -943 +▁real -944 +ody -945 +oh -946 +▁build -947 +▁allow -948 +ms -949 +reen -950 +▁opt -951 +▁water -952 +ished -953 +▁things -954 +▁come -955 +▁contin -956 +thing -957 +▁Americ -958 +▁var -959 +▁Ph -960 +▁dri -961 +ists -962 +uck -963 +ever -964 +ern -965 +ield -966 +▁cent -967 +arly -968 +over -969 +rand -970 +▁small -971 +▁rece -972 +▁organ -973 +▁appro -974 +▁rest -975 +gy -976 +▁big -977 +self -978 +▁Ind -979 +▁ref -980 +ex -981 +▁always -982 +▁mus -983 +▁better -984 +▁sure -985 +▁With -986 +▁interest -987 +▁win -988 +aut -989 +loy -990 +▁full -991 +▁pat -992 +▁pass -993 +▁poss -994 +ery -995 +illion -996 +▁online -997 +▁pri -998 +▁iss -999 +▁ty -1000 +▁put -1001 +ined -1002 +cent -1003 +ware -1004 +▁When -1005 +▁result -1006 +▁gener -1007 +▁since -1008 +▁Bl -1009 +▁ve -1010 +ps -1011 +▁try -1012 +▁direct -1013 +▁quest -1014 +iversity -1015 +▁mov -1016 +▁stand -1017 +▁partic -1018 +▁days -1019 +▁perform -1020 +▁group -1021 +ok -1022 +▁val -1023 +▁pay -1024 +▁ide -1025 +▁head -1026 +▁special -1027 +▁bel -1028 +▁Tr -1029 +▁today -1030 +▁Chr -1031 +▁something -1032 +▁class -1033 +▁provide -1034 +ients -1035 +ours -1036 +▁tri -1037 +▁second -1038 +▁services -1039 +▁ann -1040 +▁Our -1041 +ared -1042 +▁Con -1043 +ccess -1044 +▁resp -1045 +joy -1046 +▁phot -1047 +▁conf -1048 +▁Is -1049 +ploy -1050 +▁Or -1051 +▁dist -1052 +▁hard -1053 +▁without -1054 +pping -1055 +con -1056 +▁Sp -1057 +▁number -1058 +▁Z -1059 +ER -1060 +▁bro -1061 +▁def -1062 +▁sl -1063 +▁cor -1064 +▁must -1065 +oney -1066 +▁blo -1067 +▁another -1068 +ision -1069 +▁vide -1070 +stand -1071 +eng -1072 +▁current -1073 +cl -1074 +outh -1075 +▁give -1076 +▁wom -1077 +▁old -1078 +aj -1079 +ically -1080 +▁access -1081 +▁able -1082 +▁webs -1083 +ards -1084 +▁important -1085 +ior -1086 +iver -1087 +," -1088 +▁cr -1089 +ately -1090 +ium -1091 +▁— -1092 +▁cost -1093 +sh -1094 +▁grow -1095 +▁ask -1096 +ope -1097 +ral -1098 +▁meet -1099 +▁fact -1100 +▁invest -1101 +▁At -1102 +▁area -1103 +ruct -1104 +▁Cent -1105 +▁public -1106 +▁got -1107 +raph -1108 +▁Res -1109 +▁wr -1110 +▁bre -1111 +▁soc -1112 +ote -1113 +▁visit -1114 +▁proble -1115 +ered -1116 +▁light -1117 +▁incre -1118 +▁US -1119 +ample -1120 +▁working -1121 +ems -1122 +▁ob -1123 +ense -1124 +▁data -1125 +▁unt -1126 +ann -1127 +rence -1128 +pped -1129 +br -1130 +▁level -1131 +▁proper -1132 +▁looking -1133 +▁never -1134 +▁sal -1135 +▁might -1136 +inal -1137 +▁No -1138 +ats -1139 +ffic -1140 +▁order -1141 +ential -1142 +ember -1143 +▁effect -1144 +ley -1145 +▁event -1146 +▁fac -1147 +▁students -1148 +▁rese -1149 +▁food -1150 +▁local -1151 +▁Man -1152 +ency -1153 +▁four -1154 +▁Comm -1155 +▁eng -1156 +▁profess -1157 +ird -1158 +▁let -1159 +▁That -1160 +ission -1161 +▁offer -1162 +▁inf -1163 +ww -1164 +▁enjoy -1165 +▁site -1166 +▁Pr -1167 +▁spec -1168 +▁season -1169 +▁check -1170 +▁addition -1171 +ertain -1172 +▁within -1173 +▁children -1174 +gin -1175 +▁oper -1176 +▁pos -1177 +▁test -1178 +ording -1179 +▁making -1180 +▁My -1181 +▁view -1182 +lection -1183 +▁room -1184 +▁sit -1185 +▁prom -1186 +▁power -1187 +ories -1188 +ney -1189 +▁expl -1190 +here -1191 +▁ca -1192 +load -1193 +ently -1194 +▁products -1195 +rol -1196 +▁night -1197 +▁past -1198 +▁community -1199 +▁pop -1200 +▁Mar -1201 +▁sing -1202 +▁against -1203 +let -1204 +ream -1205 +tend -1206 +▁until -1207 +ases -1208 +▁less -1209 +▁' -1210 +utes -1211 +▁el -1212 +ains -1213 +agement -1214 +▁est -1215 +med -1216 +ids -1217 +▁email -1218 +ieve -1219 +▁job -1220 +iron -1221 +ised -1222 +ator -1223 +▁quality -1224 +ivid -1225 +▁May -1226 +ina -1227 +▁intern -1228 +▁indust -1229 +to -1230 +ills -1231 +▁gl -1232 +▁website -1233 +▁prote -1234 +▁impro -1235 +▁law -1236 +ode -1237 +ks -1238 +orm -1239 +▁equ -1240 +▁App -1241 +▁turn -1242 +ified -1243 +enn -1244 +urs -1245 +co -1246 +ged -1247 +IN -1248 +▁Br -1249 +▁away -1250 +icle -1251 +▁air -1252 +▁Fe -1253 +▁contact -1254 +▁creat -1255 +▁toget -1256 +We -1257 +▁together -1258 +▁University -1259 +bo -1260 +istr -1261 +ique -1262 +pend -1263 +aring -1264 +▁supp -1265 +▁learn -1266 +▁success -1267 +▁pract -1268 +▁Co -1269 +▁dr -1270 +ury -1271 +▁complete -1272 +▁Can -1273 +▁leg -1274 +iday -1275 +▁applic -1276 +▁expect -1277 +▁needs -1278 +▁include -1279 +por -1280 +▁Christ -1281 +iety -1282 +ocus -1283 +atter -1284 +ider -1285 +▁Cont -1286 +▁. -1287 +▁detail -1288 +▁large -1289 +▁easy -1290 +▁la -1291 +▁Car -1292 +ability -1293 +ret -1294 +▁One -1295 +oci -1296 +▁along -1297 +irl -1298 +▁course -1299 +▁says -1300 +▁change -1301 +▁news -1302 +arent -1303 +aster -1304 +room -1305 +▁present -1306 +ger -1307 +▁offic -1308 +vern -1309 +▁name -1310 +▁chang -1311 +hor -1312 +ism -1313 +▁conc -1314 +yle -1315 +ym -1316 +atures -1317 +▁beaut -1318 +▁Am -1319 +▁Do -1320 +▁activ -1321 +pos -1322 +▁cap -1323 +part -1324 +lish -1325 +ump -1326 +ising -1327 +▁members -1328 +ries -1329 +▁Me -1330 +▁money -1331 +▁Ste -1332 +enef -1333 +min -1334 +iting -1335 +▁employ -1336 +rap -1337 +▁video -1338 +▁bas -1339 +▁times -1340 +the -1341 +▁talk -1342 +▁Eng -1343 +ify -1344 +▁buy -1345 +ec -1346 +augh -1347 +▁beh -1348 +▁music -1349 +itions -1350 +▁Ro -1351 +▁fav -1352 +▁These -1353 +▁house -1354 +une -1355 +▁pa -1356 +ift -1357 +nect -1358 +▁opport -1359 +▁dem -1360 +▁sw -1361 +side -1362 +▁/ -1363 +ane -1364 +▁hist -1365 +▁why -1366 +Th -1367 +▁En -1368 +▁dra -1369 +ably -1370 +▁cond -1371 +▁ce -1372 +▁case -1373 +▁please -1374 +▁treat -1375 +by -1376 +mber -1377 +ron -1378 +veral -1379 +ots -1380 +▁perfect -1381 +aff -1382 +rie -1383 +aterial -1384 +pecial -1385 +▁live -1386 +ready -1387 +fort -1388 +ten -1389 +▁govern -1390 +▁account -1391 +▁dev -1392 +▁short -1393 +ention -1394 +▁thing -1395 +ization -1396 +▁create -1397 +▁following -1398 +▁Che -1399 +▁story -1400 +ON -1401 +▁clo -1402 +▁left -1403 +book -1404 +▁const -1405 +ived -1406 +viron -1407 +▁review -1408 +▁below -1409 +▁trad -1410 +▁understand -1411 +▁hum -1412 +▁million -1413 +son -1414 +!! -1415 +▁side -1416 +itive -1417 +▁having -1418 +alf -1419 +▁Your -1420 +ored -1421 +▁After -1422 +▁hot -1423 +ohn -1424 +ows -1425 +sc -1426 +▁page -1427 +etwork -1428 +▁Med -1429 +▁Fl -1430 +▁based -1431 +▁focus -1432 +▁makes -1433 +of -1434 +▁word -1435 +AT -1436 +RE -1437 +▁research -1438 +▁move -1439 +▁writ -1440 +▁across -1441 +▁camp -1442 +▁personal -1443 +ienc -1444 +▁link -1445 +▁line -1446 +ances -1447 +▁kind -1448 +▁possible -1449 +▁cou -1450 +rop -1451 +▁ever -1452 +▁mar -1453 +▁pot -1454 +uture -1455 +ividual -1456 +▁getting -1457 +▁comes -1458 +▁already -1459 +uly -1460 +▁benef -1461 +ajor -1462 +▁elect -1463 +▁educ -1464 +vious -1465 +▁record -1466 +ured -1467 +uper -1468 +osp -1469 +▁country -1470 +▁become -1471 +▁soft -1472 +▁Rep -1473 +ination -1474 +oice -1475 +orts -1476 +▁often -1477 +▁share -1478 +▁friends -1479 +▁several -1480 +ush -1481 +▁Ass -1482 +▁done -1483 +iven -1484 +ister -1485 +▁social -1486 +▁Count -1487 +▁es -1488 +duct -1489 +▁pack -1490 +▁bit -1491 +wards -1492 +▁fund -1493 +ead -1494 +iam -1495 +▁enough -1496 +▁quick -1497 +▁mil -1498 +▁tre -1499 +ones -1500 +▁minutes -1501 +uro -1502 +▁Please -1503 +conom -1504 +fer -1505 +▁bring -1506 +▁Inst -1507 +inc -1508 +▁women -1509 +uff -1510 +▁development -1511 +▁vers -1512 +▁Serv -1513 +▁hours -1514 +▁Des -1515 +▁body -1516 +▁mult -1517 +unch -1518 +app -1519 +oose -1520 +ips -1521 +▁tell -1522 +ides -1523 +iful -1524 +▁John -1525 +vironment -1526 +▁return -1527 +▁purch -1528 +mend -1529 +▁: -1530 +aim -1531 +▁cut -1532 +▁men -1533 +ners -1534 +▁city -1535 +▁lo -1536 +arl -1537 +reet -1538 +ape -1539 +▁Intern -1540 +▁deal -1541 +▁X -1542 +oon -1543 +▁individual -1544 +AN -1545 +▁exc -1546 +▁won -1547 +ST -1548 +▁ens -1549 +▁young -1550 +ted -1551 +ateg -1552 +▁Here -1553 +▁material -1554 +▁hold -1555 +▁compet -1556 +ograph -1557 +▁sum -1558 +▁... -1559 +▁Comp -1560 +▁others -1561 +▁jo -1562 +yn -1563 +utions -1564 +▁Tw -1565 +▁started -1566 +▁called -1567 +▁industry -1568 +▁months -1569 +▁mom -1570 +▁term -1571 +▁non -1572 +▁orig -1573 +idd -1574 +ights -1575 +▁didn -1576 +ript -1577 +▁land -1578 +ee -1579 +ai -1580 +nder -1581 +▁Gu -1582 +▁walk -1583 +▁clean -1584 +▁future -1585 +▁rele -1586 +▁American -1587 +▁However -1588 +▁pie -1589 +., -1590 +▁City -1591 +▁far -1592 +▁commun -1593 +lished -1594 +ched -1595 +▁po -1596 +▁doing -1597 +▁major -1598 +ained -1599 +▁control -1600 +▁space -1601 +ource -1602 +fact -1603 +ball -1604 +urity -1605 +arr -1606 +osed -1607 +▁wa -1608 +▁low -1609 +ges -1610 +▁cover -1611 +▁Ab -1612 +▁store -1613 +anies -1614 +lement -1615 +ference -1616 +ford -1617 +▁occ -1618 +▁games -1619 +▁means -1620 +AR -1621 +lege -1622 +▁Not -1623 +▁mind -1624 +▁offers -1625 +oring -1626 +▁Tra -1627 +▁yet -1628 +▁bra -1629 +▁Dr -1630 +▁came -1631 +▁five -1632 +▁percent -1633 +▁chall -1634 +▁comb -1635 +▁Min -1636 +▁took -1637 +▁invol -1638 +▁doesn -1639 +sel -1640 +▁lim -1641 +orld -1642 +▁fore -1643 +ilities -1644 +▁* -1645 +▁customers -1646 +▁features -1647 +bal -1648 +▁State -1649 +▁least -1650 +▁strong -1651 +▁step -1652 +▁price -1653 +ches -1654 +▁heart -1655 +▁God -1656 +▁Ke -1657 +urther -1658 +▁range -1659 +▁specific -1660 +▁More -1661 +▁main -1662 +most -1663 +▁require -1664 +▁close -1665 +▁School -1666 +▁once -1667 +▁key -1668 +▁pict -1669 +sw -1670 +err -1671 +ler -1672 +▁upd -1673 +ilt -1674 +ither -1675 +▁mean -1676 +▁Bo -1677 +▁early -1678 +▁ey -1679 +▁cra -1680 +▁Jan -1681 +▁Now -1682 +▁tool -1683 +▁stay -1684 +▁discuss -1685 +▁government -1686 +illed -1687 +aces -1688 +af -1689 +▁series -1690 +▁tem -1691 +ources -1692 +▁hig -1693 +▁priv -1694 +▁Bro -1695 +▁ste -1696 +▁technology -1697 +pro -1698 +cle -1699 +▁install -1700 +▁charact -1701 +▁Im -1702 +atural -1703 +▁Ed -1704 +▁typ -1705 +▁United -1706 +▁redu -1707 +▁beautiful -1708 +atic -1709 +▁By -1710 +▁ago -1711 +▁went -1712 +▁begin -1713 +aken -1714 +// -1715 +▁announ -1716 +org -1717 +▁thought -1718 +▁Pe -1719 +▁pick -1720 +▁told -1721 +▁hope -1722 +▁appear -1723 +ancial -1724 +isk -1725 +It -1726 +resent -1727 +▁anal -1728 +▁happen -1729 +anks -1730 +rew -1731 +▁Gr -1732 +▁Em -1733 +irm -1734 +▁break -1735 +ille -1736 +▁wind -1737 +▁questions -1738 +resh -1739 +OR -1740 +▁York -1741 +▁x -1742 +▁Qu -1743 +come -1744 +▁Pre -1745 +▁content -1746 +▁certain -1747 +▁Add -1748 +oll -1749 +▁everything -1750 +▁prep -1751 +ourn -1752 +hers -1753 +:// -1754 +▁sn -1755 +ians -1756 +irt -1757 +gle -1758 +▁field -1759 +▁companies -1760 +▁travel -1761 +ony -1762 +▁Cal -1763 +▁enc -1764 +▁recom -1765 +▁single -1766 +▁known -1767 +▁added -1768 +▁favor -1769 +▁media -1770 +▁-- -1771 +cell -1772 +▁building -1773 +arning -1774 +▁manag -1775 +▁Park -1776 +aps -1777 +▁search -1778 +▁environment -1779 +▁friend -1780 +▁actually -1781 +aur -1782 +▁address -1783 +ief -1784 +▁tot -1785 +▁ener -1786 +de -1787 +▁study -1788 +▁mess -1789 +eral -1790 +▁vol -1791 +▁tax -1792 +▁press -1793 +▁problem -1794 +play -1795 +isc -1796 +▁later -1797 +▁connect -1798 +ino -1799 +▁works -1800 +ests -1801 +▁Sm -1802 +▁girl -1803 +icy -1804 +▁improve -1805 +gest -1806 +acy -1807 +ibr -1808 +▁taking -1809 +ew -1810 +▁South -1811 +▁ident -1812 +▁maint -1813 +▁sound -1814 +▁pub -1815 +ental -1816 +year -1817 +lebr -1818 +ural -1819 +▁Su -1820 +▁track -1821 +ided -1822 +▁training -1823 +▁watch -1824 +▁results -1825 +ster -1826 +▁staff -1827 +▁card -1828 +▁wond -1829 +abor -1830 +▁North -1831 +▁face -1832 +back -1833 +▁professional -1834 +nes -1835 +ensive -1836 +▁Mc -1837 +▁Just -1838 +ocu -1839 +gs -1840 +ES -1841 +▁film -1842 +▁provides -1843 +wh -1844 +atest -1845 +yl -1846 +▁seen -1847 +▁While -1848 +▁issues -1849 +▁someone -1850 +ama -1851 +▁Per -1852 +▁unique -1853 +▁host -1854 +▁half -1855 +▁front -1856 +▁official -1857 +cer -1858 +▁Euro -1859 +fully -1860 +▁near -1861 +opy -1862 +▁econom -1863 +▁relations -1864 +▁web -1865 +▁sell -1866 +▁particular -1867 +▁National -1868 +▁County -1869 +▁everyone -1870 +▁miss -1871 +▁port -1872 +AL -1873 +▁dig -1874 +urch -1875 +▁due -1876 +▁Aust -1877 +▁Some -1878 +go -1879 +▁recommend -1880 +▁network -1881 +hod -1882 +▁cook -1883 +▁Center -1884 +▁Don -1885 +lex -1886 +▁cred -1887 +▁office -1888 +▁respons -1889 +▁z -1890 +ued -1891 +▁Inc -1892 +▁Oct -1893 +▁simple -1894 +itted -1895 +▁Part -1896 +▁age -1897 +▁ant -1898 +ctor -1899 +ibility -1900 +▁aud -1901 +▁management -1902 +ging -1903 +▁click -1904 +not -1905 +roll -1906 +▁oil -1907 +▁Pol -1908 +▁particip -1909 +time -1910 +▁Dep -1911 +asing -1912 +▁whole -1913 +pecially -1914 +▁mot -1915 +▁bar -1916 +obile -1917 +iod -1918 +▁Acc -1919 +▁Pres -1920 +▁performance -1921 +▁areas -1922 +▁Apr -1923 +▁mor -1924 +▁ess -1925 +pper -1926 +▁fall -1927 +▁author -1928 +cing -1929 +▁given -1930 +ply -1931 +imate -1932 +▁bed -1933 +▁World -1934 +icult -1935 +nding -1936 +▁above -1937 +▁reason -1938 +▁protect -1939 +ites -1940 +▁events -1941 +In -1942 +ators -1943 +aining -1944 +▁among -1945 +▁eff -1946 +ables -1947 +umb -1948 +▁Will -1949 +ops -1950 +▁experienc -1951 +ask -1952 +▁Sec -1953 +▁history -1954 +EN -1955 +▁select -1956 +▁Stud -1957 +omes -1958 +▁black -1959 +ogn -1960 +ED -1961 +▁assist -1962 +▁size -1963 +▁energy -1964 +▁foot -1965 +ison -1966 +cy -1967 +ili -1968 +▁High -1969 +▁details -1970 +▁print -1971 +ledge -1972 +▁htt -1973 +▁Reg -1974 +▁glo -1975 +▁believe -1976 +▁flo -1977 +▁sex -1978 +crib -1979 +▁further -1980 +▁From -1981 +▁amount -1982 +▁Post -1983 +▁six -1984 +▁log -1985 +idence -1986 +ety -1987 +ulation -1988 +▁designed -1989 +▁includes -1990 +▁prob -1991 +▁Friday -1992 +astic -1993 +▁pain -1994 +ands -1995 +vert -1996 +▁cult -1997 +ufact -1998 +▁points -1999 +▁repl -2000 +▁parent -2001 +▁mag -2002 +▁red -2003 +▁Day -2004 +▁property -2005 +AS -2006 +▁Ge -2007 +ruction -2008 +▁Bar -2009 +▁continue -2010 +▁soon -2011 +nov -2012 +▁feature -2013 +▁Aug -2014 +▁value -2015 +urance -2016 +▁et -2017 +▁Mr -2018 +▁Europe -2019 +▁anything -2020 +▁text -2021 +▁various -2022 +itch -2023 +▁coming -2024 +▁question -2025 +▁popular -2026 +▁latest -2027 +itional -2028 +▁according -2029 +aily -2030 +▁lov -2031 +▁living -2032 +rodu -2033 +▁phys -2034 +▁forward -2035 +▁type -2036 +my -2037 +▁fre -2038 +uation -2039 +▁March -2040 +▁phone -2041 +itc -2042 +ouch -2043 +▁consider -2044 +cript -2045 +▁pret -2046 +▁whether -2047 +aturday -2048 +IC -2049 +IT -2050 +▁brand -2051 +▁entire -2052 +▁idea -2053 +ze -2054 +though -2055 +▁claim -2056 +▁white -2057 +edd -2058 +aching -2059 +▁celebr -2060 +▁weeks -2061 +▁gra -2062 +▁dou -2063 +▁needed -2064 +▁Bu -2065 +▁diff -2066 +▁consum -2067 +▁potential -2068 +▁opportunity -2069 +▁comput -2070 +▁deb -2071 +▁El -2072 +▁color -2073 +elt -2074 +▁taken -2075 +▁Us -2076 +▁June -2077 +▁wide -2078 +▁required -2079 +▁receive -2080 +▁par -2081 +▁date -2082 +▁Sept -2083 +▁extra -2084 +selves -2085 +▁Sund -2086 +ung -2087 +itter -2088 +▁docu -2089 +new -2090 +▁third -2091 +▁example -2092 +AC -2093 +▁relationship -2094 +▁safe -2095 +ival -2096 +▁bad -2097 +▁sent -2098 +▁ensure -2099 +This -2100 +itor -2101 +ises -2102 +▁ready -2103 +▁inj -2104 +▁Off -2105 +▁West -2106 +▁, -2107 +▁comfort -2108 +▁currently -2109 +ilar -2110 +amer -2111 +▁meas -2112 +ees -2113 +ires -2114 +▁financial -2115 +▁common -2116 +▁almost -2117 +ffe -2118 +▁sugg -2119 +▁fire -2120 +head -2121 +▁ach -2122 +▁April -2123 +val -2124 +uary -2125 +▁ways -2126 +▁human -2127 +▁kids -2128 +▁Read -2129 +▁Art -2130 +▁pretty -2131 +▁period -2132 +▁quite -2133 +▁Jo -2134 +▁options -2135 +▁final -2136 +▁skin -2137 +▁natural -2138 +▁yourself -2139 +▁especially -2140 +▁veh -2141 +irc -2142 +▁road -2143 +▁style -2144 +▁trying -2145 +▁park -2146 +▁sho -2147 +▁box -2148 +▁Health -2149 +▁Cor -2150 +ring -2151 +▁items -2152 +▁His -2153 +▁answ -2154 +▁paper -2155 +used -2156 +▁member -2157 +▁provided -2158 +▁either -2159 +ese -2160 +ana -2161 +ively -2162 +.... -2163 +▁Saturday -2164 +itting -2165 +onday -2166 +▁coll -2167 +▁engine -2168 +▁choose -2169 +▁hon -2170 +▁self -2171 +▁crit -2172 +▁held -2173 +▁throughout -2174 +▁happy -2175 +▁dam -2176 +▁fit -2177 +▁download -2178 +▁via -2179 +▁swe -2180 +▁attend -2181 +▁wanted -2182 +▁flow -2183 +▁clients -2184 +▁stra -2185 +ication -2186 +▁summer -2187 +▁Pa -2188 +▁recent -2189 +▁Fin -2190 +▁impact -2191 +▁Aut -2192 +▁users -2193 +ada -2194 +▁created -2195 +▁sales -2196 +▁tit -2197 +▁Af -2198 +icro -2199 +▁July -2200 +azing -2201 +▁blog -2202 +▁issue -2203 +▁previous -2204 +▁behind -2205 +▁takes -2206 +arter -2207 +oogle -2208 +▁recently -2209 +hel -2210 +▁TH -2211 +▁software -2212 +▁Dav -2213 +angu -2214 +gress -2215 +IS -2216 +do -2217 +▁init -2218 +cast -2219 +ams -2220 +ux -2221 +▁version -2222 +▁super -2223 +▁Get -2224 +▁Feb -2225 +ried -2226 +▁bott -2227 +▁seem -2228 +▁Up -2229 +▁couple -2230 +▁song -2231 +▁running -2232 +▁insp -2233 +▁hol -2234 +verage -2235 +ume -2236 +ober -2237 +▁clear -2238 +▁collect -2239 +▁problems -2240 +ades -2241 +apt -2242 +▁isn -2243 +▁education -2244 +▁received -2245 +▁method -2246 +oura -2247 +▁table -2248 +▁players -2249 +▁role -2250 +▁represent -2251 +▁reading -2252 +▁Val -2253 +uge -2254 +▁Direct -2255 +eth -2256 +▁Int -2257 +anced -2258 +itten -2259 +▁signific -2260 +atform -2261 +▁likely -2262 +eke -2263 +ole -2264 +earch -2265 +ification -2266 +▁Sw -2267 +par -2268 +▁shows -2269 +▁di -2270 +where -2271 +▁security -2272 +▁increase -2273 +▁accom -2274 +▁States -2275 +▁Mon -2276 +▁favorite -2277 +▁customer -2278 +▁stri -2279 +▁pan -2280 +▁party -2281 +reme -2282 +▁action -2283 +▁skills -2284 +▁regular -2285 +St -2286 +▁difficult -2287 +▁fast -2288 +▁simply -2289 +idge -2290 +OU -2291 +▁sle -2292 +▁else -2293 +▁Face -2294 +▁writing -2295 +▁ele -2296 +▁nice -2297 +aging -2298 +▁Sunday -2299 +▁Monday -2300 +oud -2301 +oid -2302 +▁position -2303 +overed -2304 +▁article -2305 +▁outside -2306 +▁original -2307 +▁Her -2308 +▁probably -2309 +▁cool -2310 +icles -2311 +aving -2312 +mit -2313 +▁cup -2314 +▁necess -2315 +▁inside -2316 +▁fresh -2317 +ID -2318 +istration -2319 +▁asked -2320 +▁wonder -2321 +▁goal -2322 +▁systems -2323 +.) -2324 +▁manufact -2325 +arth -2326 +aby -2327 +▁model -2328 +-- -2329 +▁House -2330 +li -2331 +▁morning -2332 +▁ground -2333 +▁President -2334 +icated -2335 +▁application -2336 +▁leave -2337 +ham -2338 +eter -2339 +▁ful -2340 +▁learning -2341 +▁anim -2342 +uit -2343 +aker -2344 +▁Associ -2345 +▁risk -2346 +▁Act -2347 +▁Black -2348 +▁knowledge -2349 +▁located -2350 +based -2351 +▁contrib -2352 +▁UK -2353 +▁release -2354 +▁projects -2355 +▁lives -2356 +▁changes -2357 +▁tour -2358 +▁Are -2359 +▁Bus -2360 +▁however -2361 +ox -2362 +▁Free -2363 +▁treatment -2364 +▁stop -2365 +medi -2366 +face -2367 +right -2368 +▁Austral -2369 +▁exist -2370 +▁mix -2371 +▁recogn -2372 +▁additional -2373 +▁polit -2374 +adem -2375 +▁Red -2376 +▁activities -2377 +▁private -2378 +▁abs -2379 +▁sat -2380 +▁career -2381 +iple -2382 +name -2383 +▁board -2384 +▁medical -2385 +▁Work -2386 +▁total -2387 +▁Mich -2388 +▁cal -2389 +▁anyone -2390 +▁hit -2391 +▁etc -2392 +artment -2393 +▁fail -2394 +▁ple -2395 +▁TV -2396 +▁accept -2397 +urg -2398 +▁town -2399 +▁Soc -2400 +ague -2401 +▁base -2402 +arget -2403 +aign -2404 +amed -2405 +bor -2406 +OT -2407 +hib -2408 +▁mark -2409 +▁former -2410 +▁contract -2411 +▁matter -2412 +▁included -2413 +▁America -2414 +ming -2415 +ounc -2416 +ules -2417 +▁mach -2418 +ession -2419 +▁Sal -2420 +iol -2421 +▁stock -2422 +▁match -2423 +▁autom -2424 +▁words -2425 +▁significant -2426 +izing -2427 +▁hair -2428 +ipment -2429 +▁saf -2430 +ecut -2431 +▁Ser -2432 +▁meeting -2433 +wood -2434 +▁Of -2435 +▁October -2436 +▁books -2437 +▁September -2438 +ovember -2439 +▁growth -2440 +▁Ac -2441 +▁playing -2442 +▁January -2443 +aced -2444 +▁leaders -2445 +empt -2446 +▁ball -2447 +▁worth -2448 +mon -2449 +irth -2450 +▁round -2451 +▁longer -2452 +▁drive -2453 +▁hy -2454 +▁character -2455 +▁variety -2456 +ny -2457 +▁concern -2458 +▁News -2459 +▁First -2460 +▁practice -2461 +ester -2462 +▁production -2463 +che -2464 +▁function -2465 +▁Sk -2466 +▁Wed -2467 +rict -2468 +▁looks -2469 +▁squ -2470 +ground -2471 +▁exam -2472 +▁late -2473 +reg -2474 +▁San -2475 +ude -2476 +▁lay -2477 +airs -2478 +▁Every -2479 +▁wall -2480 +mercial -2481 +pm -2482 +iff -2483 +▁sun -2484 +ursday -2485 +▁defin -2486 +adu -2487 +▁determ -2488 +na -2489 +▁Ag -2490 +▁August -2491 +▁suggest -2492 +ci -2493 +▁Har -2494 +elcome -2495 +▁worked -2496 +▁weeke -2497 +▁fig -2498 +ville -2499 +▁associ -2500 +uesday -2501 +▁Google -2502 +▁programs -2503 +▁death -2504 +imum -2505 +▁chance -2506 +▁platform -2507 +▁cand -2508 +▁screen -2509 +▁international -2510 +▁Then -2511 +iddle -2512 +▁Let -2513 +ipping -2514 +cks -2515 +rect -2516 +▁deg -2517 +▁true -2518 +▁Dis -2519 +▁nothing -2520 +Wh -2521 +▁challeng -2522 +itchen -2523 +▁loss -2524 +▁general -2525 +▁clos -2526 +▁rather -2527 +▁plans -2528 +arden -2529 +▁Facebook -2530 +▁purchase -2531 +▁estab -2532 +erc -2533 +▁amazing -2534 +▁credit -2535 +▁leading -2536 +▁subject -2537 +▁Department -2538 +▁regard -2539 +▁stat -2540 +cember -2541 +▁allows -2542 +ouncil -2543 +▁seems -2544 +olution -2545 +eds -2546 +▁built -2547 +▁arri -2548 +▁police -2549 +mas -2550 +▁similar -2551 +▁Mus -2552 +▁student -2553 +▁Sim -2554 +▁usually -2555 +▁infl -2556 +▁Pat -2557 +▁rate -2558 +▁quickly -2559 +▁Air -2560 +oke -2561 +▁November -2562 +▁teac -2563 +▁Also -2564 +lin -2565 +AM -2566 +▁Street -2567 +▁draw -2568 +▁national -2569 +ashing -2570 +▁touch -2571 +ought -2572 +▁providing -2573 +▁comment -2574 +▁International -2575 +oph -2576 +light -2577 +▁excell -2578 +▁deep -2579 +nesday -2580 +▁apply -2581 +▁higher -2582 +iter -2583 +iber -2584 +▁choice -2585 +▁photos -2586 +clus -2587 +▁Group -2588 +str -2589 +gar -2590 +▁tast -2591 +ING -2592 +▁respect -2593 +off -2594 +▁collection -2595 +▁safety -2596 +▁image -2597 +▁Out -2598 +▁Cons -2599 +now -2600 +▁hands -2601 +▁marketing -2602 +▁prior -2603 +ondon -2604 +▁ideas -2605 +▁integr -2606 +▁moment -2607 +▁movie -2608 +▁sil -2609 +▁encoura -2610 +▁easily -2611 +▁decision -2612 +example -2613 +▁ut -2614 +▁Cour -2615 +▁location -2616 +▁cell -2617 +▁bal -2618 +▁inde -2619 +▁dom -2620 +hern -2621 +▁rad -2622 +▁prevent -2623 +▁court -2624 +▁af -2625 +▁bud -2626 +▁Wind -2627 +▁op -2628 +▁released -2629 +▁decided -2630 +▁mass -2631 +▁ill -2632 +▁commit -2633 +▁Thursday -2634 +ached -2635 +▁digital -2636 +▁Home -2637 +put -2638 +▁Tuesday -2639 +ournal -2640 +▁emb -2641 +ha -2642 +▁reported -2643 +▁Well -2644 +▁benefits -2645 +▁Calif -2646 +▁file -2647 +ivery -2648 +▁exact -2649 +▁seek -2650 +▁December -2651 +▁introdu -2652 +▁wood -2653 +amb -2654 +▁La -2655 +▁cannot -2656 +ma -2657 +eal -2658 +▁campaign -2659 +▁lost -2660 +reng -2661 +▁display -2662 +▁Most -2663 +▁daily -2664 +▁partners -2665 +▁parents -2666 +▁ord -2667 +▁attack -2668 +▁Business -2669 +ishing -2670 +idents -2671 +hood -2672 +▁involved -2673 +▁agree -2674 +▁announced -2675 +▁cause -2676 +▁sche -2677 +▁effic -2678 +rown -2679 +▁sens -2680 +ructure -2681 +▁Gl -2682 +unities -2683 +▁drink -2684 +▁piece -2685 +▁center -2686 +▁Ang -2687 +ray -2688 +ospital -2689 +▁neg -2690 +atory -2691 +▁user -2692 +▁dest -2693 +OM -2694 +▁related -2695 +▁saw -2696 +▁Any -2697 +▁affect -2698 +▁expected -2699 +▁vict -2700 +ipe -2701 +▁Design -2702 +▁investig -2703 +▁ability -2704 +▁club -2705 +ederal -2706 +▁patients -2707 +▁Wednesday -2708 +▁ep -2709 +▁London -2710 +▁Click -2711 +ruary -2712 +EO -2713 +avy -2714 +▁rout -2715 +▁send -2716 +illing -2717 +▁ri -2718 +▁save -2719 +▁tick -2720 +ilies -2721 +▁modern -2722 +▁norm -2723 +just -2724 +ET -2725 +▁weekend -2726 +▁mobile -2727 +▁circ -2728 +sp -2729 +▁standard -2730 +▁langu -2731 +▁Prof -2732 +▁expert -2733 +▁option -2734 +ett -2735 +▁goes -2736 +▁boy -2737 +▁ded -2738 +▁immedi -2739 +▁green -2740 +▁enter -2741 +▁restaur -2742 +▁computer -2743 +▁Over -2744 +▁fight -2745 +▁War -2746 +▁aw -2747 +▁woman -2748 +▁bag -2749 +▁global -2750 +▁pers -2751 +istic -2752 +board -2753 +lim -2754 +▁target -2755 +▁mother -2756 +ivity -2757 +▁iP -2758 +▁emer -2759 +uel -2760 +▁sym -2761 +▁College -2762 +like -2763 +iring -2764 +▁serious -2765 +▁innov -2766 +▁parts -2767 +▁helps -2768 +▁huge -2769 +▁PM -2770 +▁costs -2771 +▁English -2772 +key -2773 +asons -2774 +oday -2775 +aves -2776 +▁gen -2777 +▁Check -2778 +zz -2779 +ellow -2780 +▁surpr -2781 +▁weight -2782 +▁http -2783 +▁earn -2784 +enge -2785 +uk -2786 +erve -2787 +▁rights -2788 +ara -2789 +▁bank -2790 +▁ones -2791 +ornia -2792 +▁legal -2793 +▁code -2794 +▁solutions -2795 +▁request -2796 +▁equipment -2797 +▁Sen -2798 +▁myself -2799 +▁gives -2800 +▁tools -2801 +▁Afric -2802 +▁warm -2803 +▁arch -2804 +▁Other -2805 +▁insurance -2806 +cription -2807 +raft -2808 +band -2809 +▁Del -2810 +ram -2811 +edding -2812 +▁feed -2813 +▁Hol -2814 +EC -2815 +▁approach -2816 +ault -2817 +▁conditions -2818 +▁played -2819 +▁giving -2820 +▁admin -2821 +▁dress -2822 +▁Ob -2823 +▁Techn -2824 +pri -2825 +▁Book -2826 +attle -2827 +▁attention -2828 +▁roll -2829 +OS -2830 +▁levels -2831 +▁sus -2832 +▁sett -2833 +▁resources -2834 +unt -2835 +▁award -2836 +▁Par -2837 +▁Brit -2838 +▁prim -2839 +hold -2840 +▁deliver -2841 +▁trust -2842 +ension -2843 +iction -2844 +atives -2845 +▁Service -2846 +▁note -2847 +▁sold -2848 +aged -2849 +bert -2850 +▁qual -2851 +▁remember -2852 +▁policy -2853 +▁February -2854 +▁interested -2855 +erous -2856 +▁Play -2857 +▁solution -2858 +▁door -2859 +▁Trans -2860 +▁businesses -2861 +▁capt -2862 +▁gets -2863 +▁planning -2864 +▁subs -2865 +▁highly -2866 +▁lab -2867 +aught -2868 +▁object -2869 +iding -2870 +pose -2871 +▁starting -2872 +▁opp -2873 +▁cases -2874 +partment -2875 +▁Law -2876 +ysis -2877 +▁Christmas -2878 +akers -2879 +▁lower -2880 +▁upon -2881 +▁instead -2882 +▁vac -2883 +▁write -2884 +▁hear -2885 +▁organization -2886 +▁materials -2887 +vey -2888 +▁express -2889 +▁themselves -2890 +▁published -2891 +EL -2892 +irit -2893 +▁California -2894 +ening -2895 +▁president -2896 +▁source -2897 +ica -2898 +▁reach -2899 +▁Gener -2900 +▁plant -2901 +▁condition -2902 +ples -2903 +mission -2904 +ashion -2905 +orge -2906 +urt -2907 +▁sense -2908 +▁fine -2909 +▁streng -2910 +apan -2911 +ibrary -2912 +www -2913 +▁dry -2914 +izes -2915 +▁effective -2916 +▁firm -2917 +▁sale -2918 +bum -2919 +▁mid -2920 +▁photo -2921 +▁written -2922 +▁types -2923 +AP -2924 +▁dise -2925 +▁average -2926 +▁interview -2927 +rup -2928 +urb -2929 +rom -2930 +▁consult -2931 +▁AM -2932 +▁Go -2933 +▁countries -2934 +▁Met -2935 +▁positive -2936 +ule -2937 +▁remov -2938 +▁multiple -2939 +wide -2940 +▁Rem -2941 +▁Services -2942 +iles -2943 +ida -2944 +gu -2945 +ael -2946 +▁lif -2947 +arant -2948 +▁Great -2949 +▁join -2950 +mm -2951 +▁Je -2952 +enty -2953 +unk -2954 +▁slow -2955 +▁Spe -2956 +▁India -2957 +▁trip -2958 +▁describ -2959 +ube -2960 +aches -2961 +ength -2962 +▁began -2963 +ato -2964 +▁interesting -2965 +▁imm -2966 +▁Mod -2967 +▁images -2968 +▁answer -2969 +▁prem -2970 +▁player -2971 +▁cat -2972 +add -2973 +▁viol -2974 +▁opportunities -2975 +urer -2976 +▁message -2977 +▁Cle -2978 +▁employees -2979 +▁dream -2980 +ography -2981 +▁heat -2982 +▁healthy -2983 +ager -2984 +▁Sch -2985 +▁Why -2986 +▁Thanks -2987 +▁sites -2988 +ration -2989 +▁directly -2990 +▁camer -2991 +▁hour -2992 +▁item -2993 +rel -2994 +rought -2995 +▁document -2996 +▁fans -2997 +▁According -2998 +bit -2999 +orage -3000 +press -3001 +▁necessary -3002 +itute -3003 +▁picture -3004 +▁achieve -3005 +▁David -3006 +IL -3007 +▁copy -3008 +▁Hot -3009 +▁Av -3010 +▁Program -3011 +▁essential -3012 +▁completely -3013 +▁lic -3014 +▁Sub -3015 +▁gift -3016 +▁Once -3017 +▁tele -3018 +▁band -3019 +▁families -3020 +▁stories -3021 +sy -3022 +▁prices -3023 +▁groups -3024 +duc -3025 +▁Year -3026 +olf -3027 +▁Phot -3028 +▁commercial -3029 +▁King -3030 +arlier -3031 +▁Rec -3032 +▁Whe -3033 +▁Found -3034 +▁Since -3035 +▁reve -3036 +elling -3037 +▁offe -3038 +▁goals -3039 +ocol -3040 +▁excellent -3041 +▁div -3042 +▁cert -3043 +▁East -3044 +▁Cr -3045 +▁promot -3046 +▁dru -3047 +▁Even -3048 +▁pull -3049 +▁successful -3050 +▁eye -3051 +▁Market -3052 +▁fully -3053 +▁www -3054 +▁growing -3055 +ares -3056 +itely -3057 +▁Mag -3058 +▁hor -3059 +▁led -3060 +▁itself -3061 +itation -3062 +▁Many -3063 +▁Loc -3064 +▁creating -3065 +▁fix -3066 +▁stru -3067 +iant -3068 +▁except -3069 +▁adult -3070 +▁traditional -3071 +▁White -3072 +▁comments -3073 +▁gold -3074 +▁paint -3075 +▁separ -3076 +oul -3077 +erved -3078 +▁Good -3079 +▁fab -3080 +▁aim -3081 +coming -3082 +▁neigh -3083 +▁broad -3084 +▁Germ -3085 +▁Russ -3086 +mb -3087 +▁Green -3088 +ancy -3089 +iable -3090 +▁birth -3091 +onse -3092 +▁propos -3093 +omen -3094 +▁fair -3095 +▁cy -3096 +ooth -3097 +▁gar -3098 +▁device -3099 +BC -3100 +▁reports -3101 +uses -3102 +anch -3103 +▁Best -3104 +▁block -3105 +▁mount -3106 +▁teams -3107 +▁terms -3108 +▁kitchen -3109 +▁cross -3110 +oms -3111 +udd -3112 +▁Spr -3113 +▁stuff -3114 +tee -3115 +▁extreme -3116 +▁dark -3117 +ffee -3118 +▁vehicle -3119 +▁Last -3120 +▁Jack -3121 +▁attempt -3122 +▁Each -3123 +▁glass -3124 +urning -3125 +▁wasn -3126 +▁applications -3127 +ores -3128 +venue -3129 +▁hop -3130 +▁saying -3131 +▁floor -3132 +hest -3133 +▁wrong -3134 +ey -3135 +▁baby -3136 +imately -3137 +▁Tex -3138 +▁dead -3139 +ties -3140 +uth -3141 +▁Bra -3142 +▁China -3143 +▁thinking -3144 +▁Port -3145 +▁rev -3146 +▁depend -3147 +▁shoot -3148 +▁Web -3149 +▁Ty -3150 +inner -3151 +ipped -3152 +▁blood -3153 +ashington -3154 +ecutive -3155 +▁bi -3156 +ald -3157 +oming -3158 +▁Twitter -3159 +▁Develop -3160 +OL -3161 +istry -3162 +▁mention -3163 +▁See -3164 +TM -3165 +”. -3166 +▁gave -3167 +▁Japan -3168 +aughter -3169 +▁Hall -3170 +▁smart -3171 +▁System -3172 +▁wait -3173 +inary -3174 +▁implement -3175 +pite -3176 +▁obs -3177 +rote -3178 +▁profession -3179 +▁speed -3180 +▁aware -3181 +▁serve -3182 +▁spend -3183 +▁attract -3184 +▁director -3185 +▁organiz -3186 +▁Bel -3187 +▁offering -3188 +iced -3189 +▁section -3190 +▁sen -3191 +▁budget -3192 +▁Association -3193 +▁became -3194 +▁farm -3195 +aries -3196 +ological -3197 +▁impress -3198 +▁distrib -3199 +Ch -3200 +rows -3201 +▁Office -3202 +▁ge -3203 +▁Mor -3204 +▁pictures -3205 +▁nation -3206 +▁college -3207 +▁wish -3208 +AD -3209 +▁Pri -3210 +▁correct -3211 +▁Sol -3212 +field -3213 +overn -3214 +▁Make -3215 +▁suit -3216 +▁IN -3217 +▁effort -3218 +▁Mem -3219 +▁developed -3220 +▁places -3221 +▁moving -3222 +▁conduct -3223 +▁coun -3224 +▁tal -3225 +▁carry -3226 +▁dog -3227 +▁limited -3228 +▁individuals -3229 +▁advice -3230 +ils -3231 +▁dro -3232 +vest -3233 +▁son -3234 +pre -3235 +▁rent -3236 +▁avoid -3237 +▁spent -3238 +yond -3239 +ications -3240 +zy -3241 +▁complex -3242 +▁Paul -3243 +▁defe -3244 +lock -3245 +▁bath -3246 +▁title -3247 +▁sleep -3248 +▁situation -3249 +▁Down -3250 +▁Road -3251 +idered -3252 +▁requirements -3253 +▁album -3254 +▁progress -3255 +▁delivery -3256 +ceed -3257 +▁Today -3258 +▁jud -3259 +▁Washington -3260 +▁cas -3261 +▁Vis -3262 +▁Educ -3263 +▁Inter -3264 +▁vot -3265 +▁construction -3266 +rench -3267 +riend -3268 +▁enh -3269 +▁Public -3270 +ibly -3271 +▁About -3272 +house -3273 +haps -3274 +▁ble -3275 +word -3276 +▁Canada -3277 +▁advant -3278 +▁wants -3279 +▁Top -3280 +▁statement -3281 +▁feet -3282 +▁Use -3283 +▁schools -3284 +▁Gold -3285 +▁war -3286 +down -3287 +▁race -3288 +useum -3289 +▁heard -3290 +▁convers -3291 +▁eat -3292 +▁Find -3293 +US -3294 +▁sometimes -3295 +▁sweet -3296 +▁Director -3297 +▁AN -3298 +▁nut -3299 +▁stress -3300 +▁billion -3301 +reci -3302 +▁Lear -3303 +▁quarter -3304 +▁physical -3305 +▁felt -3306 +ancing -3307 +▁hous -3308 +PS -3309 +▁Indian -3310 +▁hotel -3311 +▁Mac -3312 +itary -3313 +▁towards -3314 +▁consist -3315 +▁stage -3316 +▁spot -3317 +▁annual -3318 +▁shop -3319 +▁shot -3320 +▁strateg -3321 +▁Flor -3322 +▁wonderful -3323 +ports -3324 +porate -3325 +▁Open -3326 +▁loved -3327 +▁region -3328 +▁ing -3329 +▁path -3330 +▁Dem -3331 +▁feeling -3332 +▁owners -3333 +▁finish -3334 +▁ver -3335 +▁Pal -3336 +▁THE -3337 +▁aff -3338 +unte -3339 +▁mat -3340 +ari -3341 +▁eyes -3342 +▁pattern -3343 +▁Council -3344 +▁finally -3345 +isions -3346 +▁lik -3347 +ctions -3348 +▁ten -3349 +▁brought -3350 +ION -3351 +▁Texas -3352 +▁language -3353 +▁wife -3354 +▁Care -3355 +▁pet -3356 +▁interact -3357 +▁partner -3358 +▁sports -3359 +▁straight -3360 +rast -3361 +▁inform -3362 +▁Dan -3363 +▁nature -3364 +ads -3365 +▁investment -3366 +▁Club -3367 +roid -3368 +▁respond -3369 +▁concept -3370 +▁nearly -3371 +owl -3372 +dule -3373 +▁helping -3374 +▁Hel -3375 +▁Class -3376 +▁exerc -3377 +▁overall -3378 +▁star -3379 +▁Bre -3380 +▁categ -3381 +▁weather -3382 +▁ult -3383 +▁Apple -3384 +▁max -3385 +▁tried -3386 +▁guide -3387 +▁blue -3388 +▁William -3389 +end -3390 +▁temper -3391 +estival -3392 +▁pow -3393 +▁collabor -3394 +▁largest -3395 +▁Court -3396 +". -3397 +ened -3398 +▁demand -3399 +▁charge -3400 +▁independ -3401 +▁client -3402 +hips -3403 +▁Board -3404 +As -3405 +▁rock -3406 +▁Time -3407 +itect -3408 +ourney -3409 +▁wear -3410 +change -3411 +▁Oh -3412 +ament -3413 +▁pred -3414 +He -3415 +▁advert -3416 +▁definitely -3417 +mitted -3418 +▁appoint -3419 +▁wrote -3420 +▁candid -3421 +▁activity -3422 +▁gas -3423 +▁seven -3424 +▁Windows -3425 +rences -3426 +▁Ann -3427 +▁Ir -3428 +▁cold -3429 +rig -3430 +aly -3431 +▁benefit -3432 +ago -3433 +▁Internet -3434 +▁offered -3435 +inger -3436 +roud -3437 +asc -3438 +▁Australia -3439 +yd -3440 +▁acqu -3441 +▁influ -3442 +▁response -3443 +▁turned -3444 +▁Ant -3445 +wise -3446 +▁double -3447 +▁miles -3448 +▁Review -3449 +▁pieces -3450 +▁uses -3451 +▁Tom -3452 +last -3453 +ounds -3454 +▁earlier -3455 +▁devices -3456 +▁Fam -3457 +▁internet -3458 +uted -3459 +▁beginning -3460 +▁thous -3461 +ned -3462 +▁considered -3463 +▁ahead -3464 +lies -3465 +▁altern -3466 +▁appreci -3467 +ails -3468 +▁grand -3469 +▁reduce -3470 +▁exactly -3471 +▁Adv -3472 +▁histor -3473 +▁View -3474 +▁prec -3475 +▁Research -3476 +▁James -3477 +bon -3478 +▁wedding -3479 +▁active -3480 +▁homes -3481 +▁imag -3482 +▁entertain -3483 +arc -3484 +▁Michael -3485 +▁paid -3486 +ategy -3487 +▁doll -3488 +ustain -3489 +▁transport -3490 +▁difference -3491 +▁belie -3492 +▁Thank -3493 +icks -3494 +olute -3495 +▁political -3496 +▁IT -3497 +▁regul -3498 +▁challenge -3499 +▁served -3500 +▁supply -3501 +▁cho -3502 +more -3503 +▁surround -3504 +ampions -3505 +▁Micro -3506 +▁finished -3507 +▁Rich -3508 +▁Have -3509 +icate -3510 +OV -3511 +▁Big -3512 +umn -3513 +ading -3514 +You -3515 +agn -3516 +▁Rel -3517 +▁cash -3518 +▁Look -3519 +▁creative -3520 +cause -3521 +▁eight -3522 +estern -3523 +ston -3524 +▁understanding -3525 +▁retail -3526 +▁replace -3527 +▁Govern -3528 +icip -3529 +▁states -3530 +LE -3531 +ying -3532 +:|| -3533 +▁Cur -3534 +▁Mark -3535 +▁rates -3536 +orrow -3537 +mod -3538 +▁culture -3539 +▁Char -3540 +antly -3541 +ky -3542 +vin -3543 +oly -3544 +▁European -3545 +▁Super -3546 +▁lots -3547 +▁guarant -3548 +▁easier -3549 +▁experienced -3550 +▁ST -3551 +▁afford -3552 +▁Call -3553 +box -3554 +▁pages -3555 +▁Life -3556 +▁hus -3557 +dd -3558 +▁bottom -3559 +place -3560 +▁expand -3561 +iny -3562 +▁truly -3563 +sec -3564 +▁father -3565 +▁pressure -3566 +▁maybe -3567 +▁flav -3568 +hens -3569 +▁economic -3570 +ales -3571 +▁thank -3572 +▁reflect -3573 +inated -3574 +▁machine -3575 +ses -3576 +▁Company -3577 +error -3578 +rial -3579 +▁analysis -3580 +amic -3581 +icious -3582 +▁fat -3583 +▁IS -3584 +▁immediately -3585 +▁emot -3586 +▁named -3587 +alt -3588 +aled -3589 +▁gradu -3590 +▁numbers -3591 +sych -3592 +het -3593 +▁tom -3594 +▁Child -3595 +▁Det -3596 +▁Angel -3597 +▁demon -3598 +▁girls -3599 +▁exhib -3600 +rey -3601 +▁prot -3602 +▁comfortable -3603 +IP -3604 +erry -3605 +pa -3606 +▁assess -3607 +▁posted -3608 +▁satis -3609 +nown -3610 +▁degree -3611 +▁tips -3612 +chan -3613 +▁helped -3614 +▁damage -3615 +ivil -3616 +▁Ev -3617 +▁opening -3618 +▁Management -3619 +▁garden -3620 +▁dating -3621 +▁Bank -3622 +▁videos -3623 +▁contain -3624 +▁obt -3625 +▁wild -3626 +▁PC -3627 +ronic -3628 +care -3629 +▁storage -3630 +▁Bay -3631 +▁Ret -3632 +▁speak -3633 +▁behav -3634 +phone -3635 +▁subst -3636 +▁remain -3637 +force -3638 +anging -3639 +▁Plan -3640 +▁trade -3641 +▁launch -3642 +undred -3643 +rem -3644 +▁reviews -3645 +▁completed -3646 +▁Ins -3647 +▁II -3648 +ico -3649 +▁pool -3650 +▁Sun -3651 +▁Island -3652 +▁beyond -3653 +amm -3654 +▁lack -3655 +▁disease -3656 +asy -3657 +▁lock -3658 +▁Sing -3659 +▁Rock -3660 +set -3661 +▁threat -3662 +▁purpose -3663 +If -3664 +tion -3665 +▁Water -3666 +order -3667 +orial -3668 +▁cards -3669 +▁Contact -3670 +ado -3671 +▁adjust -3672 +▁Mart -3673 +dom -3674 +que -3675 +▁ter -3676 +▁spread -3677 +▁accur -3678 +▁existing -3679 +▁fashion -3680 +arily -3681 +▁knew -3682 +▁decor -3683 +▁Love -3684 +▁fant -3685 +▁Jes -3686 +▁highest -3687 +▁cancer -3688 +Re -3689 +lied -3690 +▁Florida -3691 +▁plus -3692 +OW -3693 +▁craft -3694 +▁jobs -3695 +soft -3696 +▁Although -3697 +met -3698 +▁conference -3699 +▁Rob -3700 +body -3701 +▁Win -3702 +▁responsible -3703 +▁increasing -3704 +▁Sur -3705 +▁During -3706 +▁allowed -3707 +aling -3708 +▁train -3709 +▁setting -3710 +▁excited -3711 +atever -3712 +▁prefer -3713 +rapy -3714 +▁driving -3715 +▁camera -3716 +▁proud -3717 +door -3718 +▁increased -3719 +▁Sa -3720 +▁sty -3721 +imal -3722 +▁welcome -3723 +▁lines -3724 +▁himself -3725 +▁middle -3726 +▁initial -3727 +▁appropri -3728 +▁Dec -3729 +▁proced -3730 +ona -3731 +aith -3732 +ences -3733 +▁fem -3734 +illa -3735 +▁Sum -3736 +▁Church -3737 +▁certainly -3738 +▁General -3739 +▁passion -3740 +▁frame -3741 +▁furn -3742 +▁coffee -3743 +cel -3744 +▁strugg -3745 +▁journey -3746 +▁Product -3747 +▁holiday -3748 +iling -3749 +▁files -3750 +▁Community -3751 +▁Camp -3752 +▁estate -3753 +▁effects -3754 +▁er -3755 +za -3756 +fl -3757 +▁husband -3758 +▁thanks -3759 +▁Back -3760 +▁frequ -3761 +▁cast -3762 +▁ingred -3763 +aming -3764 +▁steps -3765 +▁button -3766 +▁Republic -3767 +▁length -3768 +▁update -3769 +▁People -3770 +▁pen -3771 +▁Custom -3772 +▁born -3773 +ologies -3774 +▁normal -3775 +istics -3776 +▁efforts -3777 +▁selection -3778 +▁Two -3779 +▁Education -3780 +▁changed -3781 +ously -3782 +▁Mary -3783 +▁batter -3784 +▁Cong -3785 +net -3786 +▁secure -3787 +▁mission -3788 +vant -3789 +▁cru -3790 +anta -3791 +▁spirit -3792 +▁dedicated -3793 +▁bill -3794 +▁owner -3795 +▁clin -3796 +▁relax -3797 +▁surv -3798 +▁shopping -3799 +▁looked -3800 +lying -3801 +icken -3802 +ken -3803 +▁incred -3804 +▁occas -3805 +▁stream -3806 +ovel -3807 +▁moved -3808 +▁Show -3809 +ady -3810 +▁links -3811 +▁mis -3812 +omb -3813 +nection -3814 +▁Cap -3815 +▁science -3816 +ij -3817 +EM -3818 +▁aspect -3819 +▁protection -3820 +): -3821 +oma -3822 +▁haven -3823 +fit -3824 +▁wine -3825 +▁powerful -3826 +▁French -3827 +othing -3828 +▁extend -3829 +▁evening -3830 +▁demonstr -3831 +▁instruct -3832 +▁Take -3833 +▁meaning -3834 +▁background -3835 +▁Like -3836 +oos -3837 +ipp -3838 +▁occur -3839 +▁talking -3840 +▁patient -3841 +▁produce -3842 +IV -3843 +▁particularly -3844 +nded -3845 +▁USA -3846 +enance -3847 +▁aren -3848 +▁guys -3849 +porary -3850 +reed -3851 +friend -3852 +▁measure -3853 +▁Power -3854 +▁Sil -3855 +▁opin -3856 +▁basic -3857 +▁challenges -3858 +▁alone -3859 +ota -3860 +▁Under -3861 +▁Online -3862 +▁fan -3863 +DA -3864 +▁cream -3865 +ocr -3866 +▁payment -3867 +▁biggest -3868 +▁transfer -3869 +▁rules -3870 +▁Gra -3871 +▁doub -3872 +▁session -3873 +CC -3874 +itiz -3875 +▁shared -3876 +▁fill -3877 +leg -3878 +▁spring -3879 +▁fra -3880 +▁winter -3881 +▁sort -3882 +▁Project -3883 +range -3884 +▁runs -3885 +▁whose -3886 +▁letter -3887 +▁basis -3888 +▁couldn -3889 +IM -3890 +▁coach -3891 +▁federal -3892 +▁Information -3893 +▁Special -3894 +azine -3895 +annel -3896 +▁bur -3897 +▁schedule -3898 +▁liter -3899 +free -3900 +▁organizations -3901 +▁Pet -3902 +▁Because -3903 +▁manager -3904 +ios -3905 +istrict -3906 +▁leader -3907 +see -3908 +▁Phil -3909 +icing -3910 +▁drop -3911 +▁Who -3912 +▁models -3913 +▁electric -3914 +▁strength -3915 +▁Music -3916 +▁artist -3917 +acity -3918 +uing -3919 +▁church -3920 +isl -3921 +▁peace -3922 +▁reasons -3923 +uled -3924 +esome -3925 +▁Food -3926 +▁egg -3927 +▁Lake -3928 +▁slight -3929 +iques -3930 +▁absolute -3931 +▁capital -3932 +▁communities -3933 +▁sugar -3934 +▁volunte -3935 +▁extremely -3936 +▁Star -3937 +▁adding -3938 +▁competition -3939 +iture -3940 +▁exclus -3941 +▁guests -3942 +▁instit -3943 +▁onto -3944 +▁views -3945 +▁unit -3946 +▁mer -3947 +▁stick -3948 +▁British -3949 +▁shown -3950 +▁regarding -3951 +istered -3952 +▁Follow -3953 +vision -3954 +iation -3955 +▁residents -3956 +▁Sam -3957 +▁Ve -3958 +▁Thom -3959 +rief -3960 +gency -3961 +▁Profess -3962 +▁hundred -3963 +▁voice -3964 +▁conven -3965 +▁Miss -3966 +umber -3967 +hone -3968 +▁Enter -3969 +azon -3970 +la -3971 +▁seeing -3972 +▁River -3973 +▁chem -3974 +▁taste -3975 +▁ideal -3976 +▁strategy -3977 +apter -3978 +▁Mil -3979 +▁Yes -3980 +▁scient -3981 +▁followed -3982 +▁AP -3983 +▁Dri -3984 +▁Blue -3985 +ustr -3986 +▁daughter -3987 +▁Real -3988 +eria -3989 +▁colors -3990 +oyal -3991 +▁heavy -3992 +▁Institute -3993 +▁trou -3994 +▁compon -3995 +▁sched -3996 +▁Att -3997 +▁cry -3998 +osing -3999 +▁brother -4000 +▁gone -4001 +▁advantage -4002 +imb -4003 +▁notice -4004 +rian -4005 +▁Lou -4006 +▁guid -4007 +esterday -4008 +▁manage -4009 +oman -4010 +▁score -4011 +▁Matt -4012 +▁characters -4013 +▁virt -4014 +ags -4015 +standing -4016 +▁Fire -4017 +▁Police -4018 +▁Fore -4019 +iverse -4020 +▁traffic -4021 +asp -4022 +▁window -4023 +▁surface -4024 +▁ton -4025 +ocolate -4026 +term -4027 +▁Mount -4028 +▁experiences -4029 +▁Pay -4030 +▁smooth -4031 +ette -4032 +▁happened -4033 +▁Mal -4034 +▁reb -4035 +▁Ben -4036 +fast -4037 +▁graph -4038 +▁hom -4039 +▁Vol -4040 +▁names -4041 +▁identify -4042 +encies -4043 +▁shipping -4044 +▁pair -4045 +▁standards -4046 +▁senior -4047 +Sh -4048 +▁Wood -4049 +ech -4050 +icine -4051 +acing -4052 +gen -4053 +mark -4054 +▁talent -4055 +▁u -4056 +itude -4057 +▁District -4058 +BS -4059 +▁hospital -4060 +▁professionals -4061 +▁List -4062 +raw -4063 +▁initi -4064 +uce -4065 +▁breat -4066 +▁although -4067 +▁classic -4068 +▁workers -4069 +▁experts -4070 +ula -4071 +ixt -4072 +TS -4073 +▁luck -4074 +gn -4075 +▁Step -4076 +▁Hist -4077 +▁audience -4078 +▁covered -4079 +▁Est -4080 +▁laws -4081 +ero -4082 +▁Mot -4083 +▁Sign -4084 +▁passed -4085 +▁waiting -4086 +▁academ -4087 +▁guy -4088 +▁dang -4089 +▁beauty -4090 +rooms -4091 +▁fear -4092 +▁approx -4093 +▁continues -4094 +▁Development -4095 +▁finding -4096 +▁Team -4097 +▁snow -4098 +▁flex -4099 +▁efficient -4100 +orney -4101 +▁master -4102 +▁mail -4103 +▁associated -4104 +▁exciting -4105 +▁eval -4106 +▁Elect -4107 +inese -4108 +▁Exper -4109 +▁compared -4110 +inate -4111 +ga -4112 +▁larger -4113 +▁Chic -4114 +ss -4115 +▁critical -4116 +▁laun -4117 +sequ -4118 +▁cars -4119 +▁rob -4120 +▁Color -4121 +▁cab -4122 +▁technical -4123 +▁Family -4124 +▁trail -4125 +icon -4126 +▁ice -4127 +UR -4128 +▁shape -4129 +▁beg -4130 +▁district -4131 +▁keeping -4132 +▁TO -4133 +▁remind -4134 +▁solid -4135 +▁den -4136 +osh -4137 +▁Foundation -4138 +▁England -4139 +▁Science -4140 +▁facilities -4141 +▁boo -4142 +rees -4143 +▁wat -4144 +▁calls -4145 +▁restaurant -4146 +▁scene -4147 +▁maintain -4148 +▁greater -4149 +▁PR -4150 +▁Engine -4151 +▁sustain -4152 +▁officials -4153 +▁sy -4154 +mail -4155 +▁Alex -4156 +▁Bet -4157 +▁Sl -4158 +▁Jesus -4159 +▁posts -4160 +▁station -4161 +▁friendly -4162 +▁epis -4163 +▁Str -4164 +▁driver -4165 +▁sand -4166 +▁bul -4167 +▁listed -4168 +▁recipe -4169 +▁plenty -4170 +▁Glo -4171 +▁forget -4172 +odes -4173 +▁Vir -4174 +▁fish -4175 +▁older -4176 +illage -4177 +cul -4178 +▁rich -4179 +▁Start -4180 +▁continued -4181 +▁football -4182 +incip -4183 +▁package -4184 +▁developing -4185 +itors -4186 +log -4187 +▁Hum -4188 +▁established -4189 +yer -4190 +iller -4191 +▁Brown -4192 +rowd -4193 +▁income -4194 +▁useful -4195 +▁minute -4196 +▁truck -4197 +well -4198 +▁studies -4199 +▁advent -4200 +▁announce -4201 +oop -4202 +▁learned -4203 +ervation -4204 +▁Press -4205 +atically -4206 +▁disapp -4207 +▁tim -4208 +▁produced -4209 +win -4210 +▁motor -4211 +tra -4212 +▁League -4213 +using -4214 +▁rooms -4215 +unately -4216 +▁closed -4217 +▁beat -4218 +▁handle -4219 +▁appropriate -4220 +▁Whether -4221 +▁classes -4222 +unning -4223 +▁origin -4224 +▁military -4225 +ander -4226 +▁Central -4227 +▁artists -4228 +▁died -4229 +gal -4230 +▁Commission -4231 +▁explore -4232 +▁sup -4233 +▁placed -4234 +▁Offic -4235 +CA -4236 +▁economy -4237 +▁kept -4238 +▁thousands -4239 +night -4240 +▁knows -4241 +▁Franc -4242 +▁connection -4243 +▁winning -4244 +▁Smith -4245 +▁remove -4246 +▁pros -4247 +▁Social -4248 +▁evidence -4249 +▁force -4250 +▁primary -4251 +▁CEO -4252 +▁Media -4253 +▁adop -4254 +▁tree -4255 +▁repair -4256 +▁salt -4257 +▁Build -4258 +▁bright -4259 +aded -4260 +▁novel -4261 +▁testing -4262 +▁Download -4263 +iment -4264 +IG -4265 +▁Christian -4266 +▁operations -4267 +▁util -4268 +rael -4269 +▁status -4270 +▁opened -4271 +▁figure -4272 +▁requires -4273 +BA -4274 +▁street -4275 +▁discount -4276 +▁fol -4277 +There -4278 +▁Another -4279 +▁gun -4280 +▁communication -4281 +atab -4282 +ipes -4283 +▁presented -4284 +▁Grand -4285 +rd -4286 +▁decl -4287 +▁Beach -4288 +▁discover -4289 +ka -4290 +What -4291 +▁Obama -4292 +overy -4293 +▁ingredients -4294 +▁teaching -4295 +▁surg -4296 +▁medium -4297 +▁Network -4298 +▁injury -4299 +inn -4300 +▁Arch -4301 +semb -4302 +▁harm -4303 +▁starts -4304 +vention -4305 +oe -4306 +▁brain -4307 +bed -4308 +▁Carol -4309 +▁catch -4310 +▁contains -4311 +iled -4312 +▁selected -4313 +irection -4314 +▁shall -4315 +▁Mex -4316 +outhern -4317 +▁sharing -4318 +▁brings -4319 +look -4320 +action -4321 +▁butter -4322 +arge -4323 +▁doctor -4324 +idential -4325 +▁Disc -4326 +▁structure -4327 +▁advance -4328 +itar -4329 +ideo -4330 +▁poor -4331 +rehens -4332 +▁scen -4333 +men -4334 +▁famous -4335 +asure -4336 +▁pray -4337 +▁dinner -4338 +mp -4339 +▁arrest -4340 +apers -4341 +pective -4342 +▁Dig -4343 +▁prepared -4344 +olic -4345 +▁esc -4346 +▁Scott -4347 +▁Hill -4348 +▁manufacturer -4349 +▁suff -4350 +enses -4351 +▁Mad -4352 +▁Word -4353 +▁pm -4354 +▁serving -4355 +▁Microsoft -4356 +▁jump -4357 +▁Card -4358 +▁ship -4359 +▁loan -4360 +▁architect -4361 +▁Light -4362 +uries -4363 +▁Full -4364 +▁department -4365 +▁mo -4366 +▁remains -4367 +▁funds -4368 +▁Valley -4369 +▁vision -4370 +▁watching -4371 +▁secret -4372 +▁rank -4373 +atively -4374 +▁victim -4375 +PA -4376 +▁sto -4377 +▁Amazon -4378 +▁resist -4379 +▁Cup -4380 +ini -4381 +ctors -4382 +▁veget -4383 +▁gain -4384 +▁Chicago -4385 +aven -4386 +▁Their -4387 +noon -4388 +▁methods -4389 +▁balance -4390 +usion -4391 +lor -4392 +iers -4393 +▁agency -4394 +allery -4395 +▁updated -4396 +▁buying -4397 +▁movement -4398 +”, -4399 +riage -4400 +▁leaves -4401 +CH -4402 +▁Keep -4403 +▁Bill -4404 +▁drug -4405 +▁compl -4406 +▁Chinese -4407 +▁guess -4408 +▁Support -4409 +ooper -4410 +▁Net -4411 +RA -4412 +aked -4413 +▁encourage -4414 +▁Stand -4415 +▁spending -4416 +▁cloud -4417 +▁journal -4418 +▁map -4419 +▁OF -4420 +▁Week -4421 +▁reality -4422 +lands -4423 +▁Award -4424 +going -4425 +ption -4426 +ishes -4427 +▁Africa -4428 +LC -4429 +▁properties -4430 +okes -4431 +lastname -4432 +eless -4433 +▁beach -4434 +▁becoming -4435 +▁happens -4436 +▁Date -4437 +▁Ber -4438 +ellig -4439 +▁bought -4440 +top -4441 +▁sector -4442 +▁cleaning -4443 +▁Women -4444 +▁spons -4445 +▁RE -4446 +▁ID -4447 +▁Mel -4448 +▁leaving -4449 +▁sport -4450 +iency -4451 +▁relig -4452 +▁Commit -4453 +▁showing -4454 +antic -4455 +▁plants -4456 +itness -4457 +life -4458 +▁maintenance -4459 +▁https -4460 +▁facility -4461 +▁metal -4462 +▁Fort -4463 +▁Tor -4464 +ception -4465 +▁perhaps -4466 +▁dep -4467 +▁Times -4468 +essions -4469 +hem -4470 +ki -4471 +▁determine -4472 +ifts -4473 +▁leadership -4474 +▁Long -4475 +▁advanced -4476 +▁worksh -4477 +▁Israel -4478 +▁independent -4479 +▁stores -4480 +▁entry -4481 +▁Rad -4482 +▁Academ -4483 +▁Android -4484 +▁cris -4485 +▁mechan -4486 +▁fee -4487 +▁analy -4488 +▁Where -4489 +▁rain -4490 +berg -4491 +edy -4492 +▁upgr -4493 +▁rare -4494 +osure -4495 +▁unc -4496 +outs -4497 +▁cart -4498 +▁Que -4499 +▁exercise -4500 +▁wouldn -4501 +▁committed -4502 +abilities -4503 +ror -4504 +▁faith -4505 +itz -4506 +▁NY -4507 +▁meant -4508 +alls -4509 +▁vote -4510 +▁sem -4511 +▁iPhone -4512 +▁Mass -4513 +ograp -4514 +▁mist -4515 +▁bird -4516 +craft -4517 +▁Both -4518 +▁fabric -4519 +▁designs -4520 +▁Tim -4521 +▁numerous -4522 +▁ride -4523 +▁focused -4524 +▁anti -4525 +▁markets -4526 +▁Div -4527 +▁brows -4528 +▁Nov -4529 +▁ju -4530 +▁incor -4531 +▁Fil -4532 +fr -4533 +▁signed -4534 +agram -4535 +▁sources -4536 +▁Pub -4537 +▁records -4538 +** -4539 +▁funding -4540 +▁theme -4541 +▁actual -4542 +aturing -4543 +iest -4544 +▁establish -4545 +▁changing -4546 +▁chair -4547 +ae -4548 +▁visitors -4549 +▁steel -4550 +▁visual -4551 +▁multi -4552 +▁ir -4553 +For -4554 +estic -4555 +▁Next -4556 +MS -4557 +▁Los -4558 +▁forms -4559 +iences -4560 +▁crowd -4561 +iance -4562 +▁joined -4563 +▁Organ -4564 +isation -4565 +▁mill -4566 +▁coverage -4567 +▁elements -4568 +▁showed -4569 +rim -4570 +▁kick -4571 +▁selling -4572 +▁Watch -4573 +▁practices -4574 +▁animals -4575 +▁operating -4576 +▁obvious -4577 +fin -4578 +▁menu -4579 +▁busy -4580 +▁Nor -4581 +▁capacity -4582 +▁locations -4583 +▁grant -4584 +▁Medical -4585 +▁songs -4586 +▁fell -4587 +▁Set -4588 +▁neighbor -4589 +▁roof -4590 +▁refer -4591 +▁Head -4592 +isher -4593 +eared -4594 +▁George -4595 +oor -4596 +miss -4597 +▁memory -4598 +▁raised -4599 +▁Only -4600 +rics -4601 +▁worry -4602 +▁whatever -4603 +▁corner -4604 +▁ban -4605 +▁lose -4606 +▁allowing -4607 +igan -4608 +▁listen -4609 +IA -4610 +▁central -4611 +reek -4612 +▁plastic -4613 +▁society -4614 +▁accommod -4615 +gage -4616 +vere -4617 +▁relationships -4618 +SS -4619 +▁Tri -4620 +▁diet -4621 +igation -4622 +▁lux -4623 +▁diagn -4624 +▁thr -4625 +▁managed -4626 +▁Copy -4627 +OP -4628 +▁updates -4629 +▁limit -4630 +▁caused -4631 +▁estim -4632 +▁rap -4633 +▁parking -4634 +▁population -4635 +▁tables -4636 +▁Before -4637 +ya -4638 +▁Note -4639 +▁uns -4640 +", -4641 +fol -4642 +▁parties -4643 +▁decide -4644 +isco -4645 +uty -4646 +▁claims -4647 +▁articles -4648 +▁core -4649 +ano -4650 +▁survey -4651 +▁repe -4652 +▁Mer -4653 +ferences -4654 +▁assistance -4655 +amin -4656 +▁walking -4657 +▁tickets -4658 +▁Its -4659 +▁techniques -4660 +▁thoughts -4661 +ection -4662 +▁CD -4663 +rab -4664 +ivered -4665 +▁Sy -4666 +▁afternoon -4667 +▁colour -4668 +▁documents -4669 +▁wire -4670 +arrant -4671 +▁bowl -4672 +▁ended -4673 +▁transl -4674 +▁youth -4675 +▁brown -4676 +▁combination -4677 +▁vehicles -4678 +lines -4679 +▁flat -4680 +▁forum -4681 +▁yesterday -4682 +▁previously -4683 +▁Game -4684 +▁enjoyed -4685 +▁landsc -4686 +▁Society -4687 +▁profile -4688 +▁courses -4689 +iliar -4690 +▁launched -4691 +▁toward -4692 +▁appears -4693 +DF -4694 +▁eating -4695 +point -4696 +▁sea -4697 +▁Bur -4698 +▁Town -4699 +▁accident -4700 +▁Cre -4701 +▁awesome -4702 +▁filled -4703 +▁optim -4704 +▁teacher -4705 +coh -4706 +▁factors -4707 +bour -4708 +eed -4709 +▁Chris -4710 +▁Technology -4711 +▁temperature -4712 +rs -4713 +▁micro -4714 +▁mort -4715 +pan -4716 +▁psych -4717 +while -4718 +▁generally -4719 +▁putting -4720 +▁shel -4721 +▁charges -4722 +▁Learn -4723 +▁Mont -4724 +▁Trump -4725 +▁citiz -4726 +▁Atl -4727 +▁notes -4728 +▁smaller -4729 +▁Author -4730 +▁firstname -4731 +▁Pack -4732 +▁direction -4733 +▁values -4734 +▁task -4735 +no -4736 +rehensive -4737 +▁counter -4738 +▁Lord -4739 +▁Log -4740 +▁Wil -4741 +▁AL -4742 +▁outdoor -4743 +▁CA -4744 +▁Sand -4745 +▁earth -4746 +▁kid -4747 +▁teachers -4748 +▁panel -4749 +▁becomes -4750 +▁vs -4751 +▁tend -4752 +▁corporate -4753 +orthern -4754 +▁favour -4755 +ola -4756 +▁bon -4757 +▁Arts -4758 +▁Virgin -4759 +▁century -4760 +▁honest -4761 +▁separate -4762 +▁legisl -4763 +?? -4764 +▁cheese -4765 +▁Security -4766 +▁assign -4767 +yan -4768 +▁Congress -4769 +▁matt -4770 +On -4771 +▁sch -4772 +▁truth -4773 +▁purs -4774 +▁concerns -4775 +OD -4776 +▁situ -4777 +▁Committee -4778 +▁Main -4779 +istan -4780 +▁Data -4781 +▁helpful -4782 +▁dur -4783 +▁shut -4784 +▁Jew -4785 +New -4786 +▁swim -4787 +▁Centre -4788 +iration -4789 +▁missing -4790 +▁orders -4791 +▁fold -4792 +▁Jul -4793 +▁Frank -4794 +▁milk -4795 +rain -4796 +▁McC -4797 +een -4798 +▁Government -4799 +▁flu -4800 +▁throw -4801 +!!! -4802 +po -4803 +▁Ext -4804 +▁adapt -4805 +▁polic -4806 +▁innovative -4807 +▁installation -4808 +ownt -4809 +▁Aud -4810 +▁ur -4811 +▁south -4812 +▁relevant -4813 +▁Lo -4814 +▁tow -4815 +▁van -4816 +pet -4817 +ifying -4818 +olars -4819 +rical -4820 +▁Robert -4821 +SP -4822 +▁Museum -4823 +▁decisions -4824 +▁environmental -4825 +ye -4826 +▁discussion -4827 +▁despite -4828 +▁waste -4829 +▁AND -4830 +▁fourth -4831 +▁slightly -4832 +orter -4833 +▁Tur -4834 +oles -4835 +▁inspired -4836 +▁Mike -4837 +▁ang -4838 +▁dance -4839 +▁net -4840 +▁Tre -4841 +▁enhance -4842 +▁Den -4843 +▁apart -4844 +▁Prov -4845 +▁Wall -4846 +▁Jim -4847 +▁scr -4848 +▁spect -4849 +▁mental -4850 +▁Hotel -4851 +▁Old -4852 +▁fantastic -4853 +▁Land -4854 +▁pal -4855 +▁format -4856 +▁Somet -4857 +▁sav -4858 +▁joint -4859 +▁desk -4860 +ita -4861 +▁upcoming -4862 +▁ath -4863 +▁AC -4864 +▁spl -4865 +▁Lead -4866 +▁Dou -4867 +inct -4868 +▁emp -4869 +▁YOU -4870 +▁willing -4871 +rist -4872 +▁hearing -4873 +▁sounds -4874 +▁fuel -4875 +▁commitment -4876 +ups -4877 +▁consumers -4878 +▁appeal -4879 +▁raise -4880 +?” -4881 +▁Manager -4882 +▁civil -4883 +▁UN -4884 +kin -4885 +osen -4886 +▁Place -4887 +▁library -4888 +umin -4889 +SA -4890 +ensions -4891 +▁vir -4892 +▁north -4893 +▁Through -4894 +▁expertise -4895 +▁Report -4896 +▁promote -4897 +▁asking -4898 +▁absolutely -4899 +▁units -4900 +▁Contin -4901 +water -4902 +▁chocolate -4903 +cher -4904 +▁extensive -4905 +▁Louis -4906 +▁movies -4907 +▁delivered -4908 +▁Series -4909 +▁bask -4910 +▁delicious -4911 +▁Ill -4912 +Pro -4913 +▁eth -4914 +▁reached -4915 +▁sets -4916 +zen -4917 +Com -4918 +▁Vict -4919 +known -4920 +▁executive -4921 +uable -4922 +▁plays -4923 +▁agreement -4924 +ternal -4925 +▁Link -4926 +▁radio -4927 +nergy -4928 +▁Posted -4929 +▁Ma -4930 +▁foreign -4931 +▁alle -4932 +▁lunch -4933 +REE -4934 +▁transform -4935 +▁datab -4936 +aser -4937 +▁register -4938 +icians -4939 +▁emergency -4940 +▁thick -4941 +▁struct -4942 +▁trees -4943 +▁Angeles -4944 +▁Invest -4945 +list -4946 +eline -4947 +▁Ham -4948 +▁Lim -4949 +▁Const -4950 +▁Oper -4951 +▁provider -4952 +▁brief -4953 +▁NE -4954 +▁presence -4955 +text -4956 +▁Upd -4957 +▁combined -4958 +▁Fund -4959 +▁rid -4960 +!) -4961 +▁Admin -4962 +▁Fun -4963 +▁achie -4964 +prise -4965 +▁Gal -4966 +▁furniture -4967 +▁seeking -4968 +▁fruit -4969 +▁NOT -4970 +▁Hand -4971 +▁controll -4972 +▁Union -4973 +osition -4974 +▁connected -4975 +▁Join -4976 +bre -4977 +▁Jun -4978 +▁readers -4979 +▁expensive -4980 +▁adults -4981 +▁Person -4982 +▁Cook -4983 +▁Democr -4984 +reens -4985 +▁seconds -4986 +▁feels -4987 +▁poll -4988 +▁ON -4989 +uality -4990 +▁rat -4991 +▁generation -4992 +▁distance -4993 +▁edge -4994 +▁fees -4995 +▁mentioned -4996 +▁recommended -4997 +▁trial -4998 +▁chat -4999 +▁calling -5000 +▁har -5001 +▁nine -5002 +▁cities -5003 +▁chicken -5004 +▁approximately -5005 +▁Plus -5006 +atin -5007 +▁bringing -5008 +TH -5009 +▁consid -5010 +▁Access -5011 +▁Journal -5012 +▁Inte -5013 +▁wel -5014 +▁married -5015 +fortunately -5016 +▁Peter -5017 +▁prepare -5018 +▁websites -5019 +▁operation -5020 +▁alternative -5021 +▁confidence -5022 +▁server -5023 +▁dogs -5024 +IR -5025 +▁registered -5026 +▁stars -5027 +cean -5028 +LA -5029 +▁educational -5030 +▁Master -5031 +burg -5032 +▁Di -5033 +appy -5034 +▁Indust -5035 +▁photograph -5036 +▁restrict -5037 +ef -5038 +ruit -5039 +▁Chief -5040 +▁Ol -5041 +▁tight -5042 +My -5043 +▁Children -5044 +▁centre -5045 +hab -5046 +emporary -5047 +▁square -5048 +▁France -5049 +othes -5050 +▁Spring -5051 +▁tun -5052 +▁returned -5053 +▁lovely -5054 +▁minimum -5055 +▁category -5056 +OC -5057 +▁Live -5058 +azz -5059 +▁exchange -5060 +▁seat -5061 +irmed -5062 +▁stret -5063 +▁Prote -5064 +ears -5065 +▁topic -5066 +▁installed -5067 +▁tea -5068 +▁info -5069 +▁Rest -5070 +rag -5071 +▁tough -5072 +▁brands -5073 +asks -5074 +▁guest -5075 +▁princip -5076 +▁Way -5077 +bu -5078 +▁majority -5079 +▁researc -5080 +atre -5081 +inations -5082 +▁wearing -5083 +▁appearance -5084 +▁female -5085 +how -5086 +▁neck -5087 +▁Minister -5088 +▁colle -5089 +estyle -5090 +ship -5091 +orry -5092 +▁Cy -5093 +IF -5094 +When -5095 +ulated -5096 +aks -5097 +▁ven -5098 +▁accompl -5099 +▁therefore -5100 +▁mostly -5101 +▁instru -5102 +▁Canad -5103 +▁Ok -5104 +▁Price -5105 +elines -5106 +▁maximum -5107 +▁HD -5108 +▁winner -5109 +▁sauce -5110 +▁processes -5111 +▁academic -5112 +▁surgery -5113 +van -5114 +kins -5115 +▁measures -5116 +▁responsibility -5117 +▁Ver -5118 +ifications -5119 +▁leads -5120 +▁impl -5121 +▁teen -5122 +▁Mo -5123 +▁killed -5124 +▁Sup -5125 +▁approved -5126 +▁apps -5127 +▁anywhere -5128 +▁arrange -5129 +▁Max -5130 +nel -5131 +▁Men -5132 +osis -5133 +▁Sports -5134 +▁stre -5135 +▁Video -5136 +▁Hy -5137 +▁importance -5138 +▁Test -5139 +▁gather -5140 +▁ring -5141 +▁climate -5142 +▁Squ -5143 +alian -5144 +▁satisf -5145 +▁detailed -5146 +▁boost -5147 +▁signs -5148 +▁battery -5149 +An -5150 +▁nom -5151 +hi -5152 +▁battle -5153 +▁feedback -5154 +▁chief -5155 +▁veter -5156 +▁Festival -5157 +▁switch -5158 +▁Creat -5159 +mond -5160 +▁dyn -5161 +▁worldwide -5162 +▁featured -5163 +▁scheduled -5164 +▁cooking -5165 +▁disp -5166 +▁highlight -5167 +ius -5168 +lets -5169 +▁Wild -5170 +▁supporting -5171 +▁rise -5172 +ait -5173 +▁crim -5174 +▁Library -5175 +▁sympt -5176 +ulty -5177 +▁cheap -5178 +cohol -5179 +▁comprehensive -5180 +▁predict -5181 +▁participants -5182 +vis -5183 +▁Walk -5184 +▁Jud -5185 +arsh -5186 +▁Cat -5187 +ker -5188 +▁IP -5189 +▁Thomas -5190 +▁affordable -5191 +▁otherwise -5192 +paper -5193 +▁Bob -5194 +▁Tour -5195 +▁defense -5196 +▁Conference -5197 +alend -5198 +ters -5199 +Cl -5200 +cious -5201 +▁bike -5202 +▁Lab -5203 +roy -5204 +otten -5205 +▁properly -5206 +ician -5207 +▁animal -5208 +▁actions -5209 +▁Using -5210 +ulate -5211 +▁clearly -5212 +ena -5213 +▁performed -5214 +▁Earth -5215 +FL -5216 +▁Search -5217 +gl -5218 +▁mur -5219 +▁Pan -5220 +▁purchased -5221 +itable -5222 +bl -5223 +▁Those -5224 +idden -5225 +▁ourselves -5226 +iner -5227 +pected -5228 +oston -5229 +▁Bi -5230 +▁conv -5231 +▁joy -5232 +uts -5233 +▁Copyright -5234 +▁audio -5235 +iser -5236 +▁chemical -5237 +▁meal -5238 +▁vent -5239 +▁competitive -5240 +verse -5241 +anda -5242 +▁Johnson -5243 +▁appeared -5244 +▁windows -5245 +▁advertising -5246 +▁Global -5247 +▁applied -5248 +▁push -5249 +▁motiv -5250 +UT -5251 +bol -5252 +▁Prem -5253 +▁ment -5254 +▁Cam -5255 +▁doors -5256 +▁Soft -5257 +ENT -5258 +▁Party -5259 +▁sister -5260 +▁policies -5261 +gment -5262 +▁pump -5263 +▁mouth -5264 +oga -5265 +▁topics -5266 +▁Form -5267 +▁Jeff -5268 +erg -5269 +▁supported -5270 +▁valid -5271 +▁Bas -5272 +▁technologies -5273 +▁pregn -5274 +▁scale -5275 +▁flowers -5276 +▁rom -5277 +▁behavior -5278 +▁arm -5279 +▁African -5280 +▁sitting -5281 +rastructure -5282 +GB -5283 +MA -5284 +▁minor -5285 +▁writer -5286 +▁familiar -5287 +▁Jose -5288 +▁holding -5289 +▁entertainment -5290 +▁featuring -5291 +▁rub -5292 +▁Germany -5293 +▁episode -5294 +▁coord -5295 +but -5296 +▁bond -5297 +ushed -5298 +▁studio -5299 +▁Western -5300 +▁editor -5301 +▁Charl -5302 +▁opinion -5303 +▁Kore -5304 +▁elim -5305 +alog -5306 +▁Cost -5307 +▁participate -5308 +▁revenue -5309 +▁plug -5310 +▁Haw -5311 +tr -5312 +▁removed -5313 +▁faster -5314 +▁Connect -5315 +▁Fair -5316 +▁Help -5317 +▁Saf -5318 +▁sides -5319 +west -5320 +inch -5321 +▁strategies -5322 +▁Champions -5323 +▁coast -5324 +erts -5325 +▁jew -5326 +▁charged -5327 +▁depending -5328 +col -5329 +▁totally -5330 +prene -5331 +oration -5332 +▁birthday -5333 +▁reliable -5334 +▁visiting -5335 +▁quiet -5336 +▁begins -5337 +▁Martin -5338 +▁species -5339 +▁conversation -5340 +▁described -5341 +UN -5342 +inating -5343 +▁Energy -5344 +▁flight -5345 +orough -5346 +▁caught -5347 +▁Girl -5348 +▁Cert -5349 +▁ap -5350 +▁eventually -5351 +▁monthly -5352 +▁fif -5353 +▁consumer -5354 +hus -5355 +den -5356 +▁Hospital -5357 +tered -5358 +▁Sar -5359 +▁restaurants -5360 +▁tail -5361 +▁meat -5362 +▁housing -5363 +▁cells -5364 +▁dish -5365 +▁teach -5366 +▁MP -5367 +▁deals -5368 +▁inches -5369 +▁Digital -5370 +▁pu -5371 +▁television -5372 +otic -5373 +▁Mic -5374 +▁accounts -5375 +with -5376 +▁improved -5377 +reprene -5378 +ersey -5379 +▁German -5380 +▁Dev -5381 +▁nav -5382 +▁Orig -5383 +apes -5384 +▁Gen -5385 +▁labor -5386 +▁Australian -5387 +▁delight -5388 +inter -5389 +▁university -5390 +▁dim -5391 +▁Id -5392 +▁fly -5393 +▁Joe -5394 +▁officer -5395 +▁marriage -5396 +▁hundreds -5397 +▁neighborhood -5398 +▁campus -5399 +▁revealed -5400 +ario -5401 +▁shoes -5402 +▁employee -5403 +ste -5404 +▁cro -5405 +▁label -5406 +▁breakfast -5407 +ulous -5408 +▁ign -5409 +weight -5410 +▁CH -5411 +▁Ul -5412 +▁confirm -5413 +▁Penn -5414 +▁administration -5415 +▁typically -5416 +SE -5417 +▁occasion -5418 +▁Academy -5419 +▁introduced -5420 +▁celebrate -5421 +▁exclusive -5422 +How -5423 +▁election -5424 +▁covers -5425 +ht -5426 +▁Secret -5427 +▁essay -5428 +▁Mid -5429 +▁appointment -5430 +ighter -5431 +▁volume -5432 +▁Ce -5433 +▁unless -5434 +sm -5435 +▁Opt -5436 +hew -5437 +achel -5438 +▁discovered -5439 +▁specifically -5440 +▁amb -5441 +▁vary -5442 +hent -5443 +▁compar -5444 +iat -5445 +▁internal -5446 +▁indic -5447 +▁planned -5448 +Our -5449 +▁Hope -5450 +▁twe -5451 +▁debt -5452 +▁intended -5453 +NA -5454 +▁cultural -5455 +▁cutting -5456 +▁sessions -5457 +▁AT -5458 +▁Americans -5459 +▁Lt -5460 +▁aspects -5461 +▁manufacturing -5462 +▁remaining -5463 +▁Maybe -5464 +▁Young -5465 +eries -5466 +ushing -5467 +▁mel -5468 +▁sexual -5469 +▁SP -5470 +bur -5471 +ixture -5472 +igr -5473 +▁shares -5474 +edia -5475 +▁nor -5476 +▁Box -5477 +merce -5478 +▁Boy -5479 +▁Second -5480 +▁recovery -5481 +); -5482 +▁basket -5483 +▁fle -5484 +▁Boston -5485 +▁icon -5486 +▁chart -5487 +▁engineering -5488 +▁remote -5489 +▁trading -5490 +ords -5491 +▁concent -5492 +▁Ari -5493 +▁scored -5494 +▁Er -5495 +▁bread -5496 +▁incredible -5497 +▁partnership -5498 +▁Key -5499 +▁investigation -5500 +▁lights -5501 +▁edition -5502 +ournament -5503 +▁dining -5504 +▁Commun -5505 +uke -5506 +asts -5507 +▁industrial -5508 +▁Jon -5509 +▁guarantee -5510 +▁forg -5511 +▁detect -5512 +▁Mur -5513 +CE -5514 +▁invent -5515 +aren -5516 +▁Meet -5517 +cont -5518 +▁Carolina -5519 +▁drivers -5520 +gas -5521 +▁components -5522 +▁Japanese -5523 +▁negative -5524 +▁liqu -5525 +▁hyd -5526 +▁automatically -5527 +mosp -5528 +▁End -5529 +elly -5530 +▁resource -5531 +eper -5532 +▁depos -5533 +▁cake -5534 +ala -5535 +▁Pac -5536 +▁mir -5537 +▁freed -5538 +▁fields -5539 +lymp -5540 +▁burn -5541 +▁Virginia -5542 +odies -5543 +▁practical -5544 +berry -5545 +▁chain -5546 +▁Type -5547 +cm -5548 +▁choices -5549 +▁noted -5550 +rupt -5551 +▁Human -5552 +▁evalu -5553 +▁quot -5554 +▁pock -5555 +▁confirmed -5556 +inet -5557 +▁interior -5558 +▁dollars -5559 +▁seemed -5560 +▁Applic -5561 +otton -5562 +▁Lee -5563 +lywood -5564 +▁cop -5565 +▁victory -5566 +▁bedroom -5567 +▁Jones -5568 +itionally -5569 +▁thus -5570 +▁rule -5571 +idays -5572 +▁suitable -5573 +▁Wal -5574 +iability -5575 +▁argu -5576 +▁depart -5577 +▁arrived -5578 +cles -5579 +▁Brand -5580 +▁Quest -5581 +ua -5582 +unting -5583 +▁perfectly -5584 +Al -5585 +▁FREE -5586 +▁twice -5587 +tters -5588 +hand -5589 +uits -5590 +▁buildings -5591 +▁boys -5592 +Ex -5593 +away -5594 +▁teeth -5595 +▁Tem -5596 +aped -5597 +▁possibly -5598 +▁broken -5599 +▁warrant -5600 +▁Mult -5601 +▁Equ -5602 +king -5603 +abet -5604 +gers -5605 +▁symptoms -5606 +▁films -5607 +▁crew -5608 +▁honor -5609 +uous -5610 +▁shooting -5611 +▁elig -5612 +▁Italian -5613 +▁doubt -5614 +▁bathroom -5615 +▁Victor -5616 +arp -5617 +▁ticket -5618 +▁Know -5619 +▁anc -5620 +arks -5621 +No -5622 +!” -5623 +▁Gar -5624 +▁island -5625 +▁stated -5626 +▁issued -5627 +ailability -5628 +flow -5629 +▁DV -5630 +▁chosen -5631 +ilit -5632 +▁Cast -5633 +rier -5634 +▁considering -5635 +▁enable -5636 +▁commission -5637 +▁Mexico -5638 +▁Steve -5639 +▁Little -5640 +▁injuries -5641 +▁Trust -5642 +urban -5643 +▁candidates -5644 +poses -5645 +▁tests -5646 +related -5647 +otal -5648 +▁Williams -5649 +▁reference -5650 +▁desire -5651 +▁foods -5652 +▁rapid -5653 +▁keeps -5654 +▁corn -5655 +TC -5656 +▁bigger -5657 +ibilities -5658 +road -5659 +▁ris -5660 +▁missed -5661 +ipl -5662 +▁Instead -5663 +▁mode -5664 +▁paying -5665 +ulations -5666 +▁boat -5667 +▁picked -5668 +▁golf -5669 +▁contest -5670 +▁Does -5671 +iors -5672 +▁intellig -5673 +▁circum -5674 +▁Farm -5675 +acks -5676 +▁Students -5677 +▁Hard -5678 +▁appreciate -5679 +▁decades -5680 +▁premium -5681 +▁turns -5682 +▁tomorrow -5683 +▁sizes -5684 +iamond -5685 +▁trend -5686 +▁Games -5687 +▁valuable -5688 +gend -5689 +owntown -5690 +▁fro -5691 +▁settings -5692 +▁Coast -5693 +▁protected -5694 +ien -5695 +▁voc -5696 +▁Tit -5697 +▁Kn -5698 +▁presentation -5699 +▁soul -5700 +▁Mat -5701 +▁Mov -5702 +▁lived -5703 +▁Page -5704 +▁regularly -5705 +▁realize -5706 +mes -5707 +▁earned -5708 +atoes -5709 +▁Current -5710 +▁registration -5711 +▁nurs -5712 +▁Night -5713 +▁config -5714 +▁Ohio -5715 +▁attorney -5716 +▁magazine -5717 +▁citizens -5718 +▁quant -5719 +hetic -5720 +▁aid -5721 +▁failed -5722 +▁oven -5723 +▁AS -5724 +▁database -5725 +fection -5726 +ora -5727 +ris -5728 +▁spr -5729 +▁Assist -5730 +▁therapy -5731 +▁organic -5732 +ias -5733 +▁license -5734 +▁sequ -5735 +wing -5736 +▁Canadian -5737 +weet -5738 +▁Econom -5739 +▁agent -5740 +▁Michigan -5741 +▁surrounding -5742 +AY -5743 +▁mine -5744 +▁affected -5745 +▁greatest -5746 +▁resol -5747 +▁ends -5748 +▁providers -5749 +▁moments -5750 +oosing -5751 +▁ran -5752 +▁county -5753 +▁Olymp -5754 +▁tells -5755 +what -5756 +▁ec -5757 +▁dates -5758 +▁Span -5759 +PR -5760 +▁grown -5761 +▁Cross -5762 +▁reput -5763 +▁MS -5764 +▁athlet -5765 +▁Code -5766 +ev -5767 +▁surf -5768 +▁virtual -5769 +▁investors -5770 +▁Instagram -5771 +▁grade -5772 +spe -5773 +▁Pass -5774 +▁calcul -5775 +▁answers -5776 +.| -5777 +▁loves -5778 +▁shock -5779 +▁supports -5780 +▁painting -5781 +▁inn -5782 +▁draft -5783 +phas -5784 +▁influence -5785 +▁proposed -5786 +lights -5787 +▁agencies -5788 +oup -5789 +▁surprise -5790 +▁History -5791 +pass -5792 +▁Control -5793 +▁Kh -5794 +abled -5795 +▁hero -5796 +▁dial -5797 +▁poly -5798 +▁Sn -5799 +▁explain -5800 +▁weap -5801 +▁accurate -5802 +▁submit -5803 +▁degrees -5804 +▁renew -5805 +▁Bal -5806 +race -5807 +▁recorded -5808 +▁Executive -5809 +▁ages -5810 +▁Van -5811 +▁Point -5812 +oking -5813 +▁owned -5814 +▁convenient -5815 +▁Georg -5816 +▁AR -5817 +▁purposes -5818 +▁Share -5819 +vell -5820 +▁load -5821 +ria -5822 +which -5823 +▁Did -5824 +▁beer -5825 +▁yes -5826 +irms -5827 +▁whom -5828 +fficient -5829 +▁Inf -5830 +▁league -5831 +▁Federal -5832 +▁holds -5833 +▁processing -5834 +ella -5835 +▁Buy -5836 +▁Middle -5837 +TA -5838 +▁gro -5839 +TV -5840 +▁instructions -5841 +▁die -5842 +▁Cas -5843 +▁Asia -5844 +kes -5845 +▁interests -5846 +▁Jackson -5847 +▁Def -5848 +▁apparent -5849 +▁efficiency -5850 +▁pure -5851 +ansas -5852 +hors -5853 +▁jack -5854 +▁atmosp -5855 +▁effectively -5856 +▁Expl -5857 +mar -5858 +▁violence -5859 +luding -5860 +▁returns -5861 +alendar -5862 +▁Comple -5863 +▁Enjoy -5864 +▁element -5865 +▁pleased -5866 +▁awareness -5867 +▁goods -5868 +▁Paris -5869 +vy -5870 +real -5871 +▁messages -5872 +OVID -5873 +cking -5874 +▁pepper -5875 +▁channel -5876 +▁receiving -5877 +▁infrastructure -5878 +print -5879 +▁Ken -5880 +▁pod -5881 +rick -5882 +▁Three -5883 +▁electronic -5884 +▁Ire -5885 +▁occup -5886 +▁Made -5887 +▁forced -5888 +intage -5889 +▁officers -5890 +▁Size -5891 +▁facing -5892 +▁creation -5893 +ospit -5894 +▁musical -5895 +▁standing -5896 +▁Requ -5897 +▁researchers -5898 +▁Dom -5899 +▁sam -5900 +▁incident -5901 +▁Royal -5902 +▁perman -5903 +▁Columb -5904 +▁belong -5905 +▁closer -5906 +irty -5907 +▁lighting -5908 +▁everyday -5909 +▁Try -5910 +▁diverse -5911 +▁grad -5912 +▁Richard -5913 +▁route -5914 +▁Daily -5915 +profit -5916 +ban -5917 +▁Travel -5918 +▁ongoing -5919 +▁distribution -5920 +▁Photo -5921 +▁lit -5922 +▁Cred -5923 +▁causes -5924 +poration -5925 +made -5926 +▁trouble -5927 +▁Ell -5928 +▁thread -5929 +▁apartment -5930 +▁Sher -5931 +▁administr -5932 +▁advoc -5933 +▁usual -5934 +▁wheel -5935 +▁serves -5936 +▁Chair -5937 +▁Ut -5938 +rum -5939 +▁sad -5940 +▁Need -5941 +▁pun -5942 +anche -5943 +▁Store -5944 +▁du -5945 +▁mini -5946 +isters -5947 +▁obtain -5948 +▁kinds -5949 +▁ped -5950 +▁healthcare -5951 +▁favourite -5952 +hy -5953 +▁judge -5954 +▁silver -5955 +▁arts -5956 +▁wid -5957 +PM -5958 +GE -5959 +▁Cath -5960 +▁supposed -5961 +▁meetings -5962 +▁error -5963 +▁crime -5964 +equ -5965 +▁rough -5966 +▁spaces -5967 +▁yellow -5968 +▁knowing -5969 +rete -5970 +▁plate -5971 +▁affili -5972 +udden -5973 +ribe -5974 +▁disappoint -5975 +▁stopped -5976 +▁flour -5977 +▁enthus -5978 +▁fellow -5979 +▁WH -5980 +umes -5981 +▁Wi -5982 +▁bound -5983 +never -5984 +oses -5985 +▁collaboration -5986 +aration -5987 +▁manner -5988 +Tube -5989 +▁Rev -5990 +xy -5991 +▁designer -5992 +itage -5993 +▁licens -5994 +▁construct -5995 +▁concerned -5996 +actions -5997 +▁Andrew -5998 +▁monit -5999 +▁subscrib -6000 +▁massive -6001 +▁Ltd -6002 +person -6003 +anges -6004 +▁weekly -6005 +▁clothes -6006 +▁follows -6007 +ennis -6008 +uction -6009 +▁Low -6010 +▁tut -6011 +▁rot -6012 +▁Four -6013 +ancer -6014 +cue -6015 +sembly -6016 +▁Local -6017 +▁Daniel -6018 +arian -6019 +ello -6020 +▁prison -6021 +▁tur -6022 +▁household -6023 +▁Wr -6024 +yard -6025 +▁simpl -6026 +▁forces -6027 +▁Clean -6028 +▁reduced -6029 +▁regional -6030 +▁challenging -6031 +iveness -6032 +EE -6033 +astern -6034 +▁male -6035 +▁Mean -6036 +▁tack -6037 +▁Guide -6038 +▁functions -6039 +▁stone -6040 +▁Ra -6041 +▁agreed -6042 +pond -6043 +▁hang -6044 +▁Right -6045 +▁script -6046 +▁Room -6047 +▁Santa -6048 +▁Francisco -6049 +oti -6050 +▁Hen -6051 +▁lifestyle -6052 +▁Russian -6053 +▁moist -6054 +▁treated -6055 +orable -6056 +▁horse -6057 +▁debut -6058 +▁complic -6059 +▁Marketing -6060 +▁alcohol -6061 +ansion -6062 +▁assets -6063 +▁native -6064 +▁innovation -6065 +▁payments -6066 +▁sample -6067 +▁fixed -6068 +ml -6069 +▁reserved -6070 +▁successfully -6071 +▁impressive -6072 +Con -6073 +▁powder -6074 +▁crisis -6075 +▁emotional -6076 +▁explained -6077 +FC -6078 +DS -6079 +▁Ep -6080 +Ar -6081 +▁inspiration -6082 +▁cute -6083 +▁Job -6084 +All -6085 +▁Visit -6086 +Un -6087 +ache -6088 +▁witness -6089 +under -6090 +▁leather -6091 +▁spokes -6092 +▁row -6093 +▁Rights -6094 +writ -6095 +ench -6096 +▁fort -6097 +▁forest -6098 +▁password -6099 +ppers -6100 +▁matters -6101 +▁Brook -6102 +▁FOR -6103 +Pl -6104 +ani -6105 +▁identified -6106 +alled -6107 +▁luxury -6108 +▁employment -6109 +BI -6110 +▁photograp -6111 +Be -6112 +▁blogg -6113 +▁drugs -6114 +▁Pot -6115 +▁Summer -6116 +▁Hor -6117 +▁cock -6118 +▁extended -6119 +And -6120 +▁phil -6121 +▁iron -6122 +▁Die -6123 +shire -6124 +igration -6125 +erves -6126 +▁Area -6127 +lyn -6128 +▁determined -6129 +▁rand -6130 +▁accepted -6131 +▁grab -6132 +▁recognized -6133 +▁outstanding -6134 +▁prop -6135 +▁Blo -6136 +▁prompt -6137 +▁der -6138 +▁styles -6139 +▁resolution -6140 +▁Southern -6141 +▁tou -6142 +▁height -6143 +folio -6144 +▁walls -6145 +▁odd -6146 +▁gifts -6147 +▁Rose -6148 +▁clinical -6149 +▁casino -6150 +▁vacation -6151 +▁Name -6152 +▁decre -6153 +▁advis -6154 +▁Cra -6155 +▁accessible -6156 +▁context -6157 +▁nearby -6158 +▁graduate -6159 +liance -6160 +▁conducted -6161 +can -6162 +They -6163 +vate -6164 +▁happening -6165 +rip -6166 +▁Number -6167 +▁positions -6168 +▁worse -6169 +▁Small -6170 +▁dangerous -6171 +▁perspective -6172 +▁Awards -6173 +▁Financial -6174 +▁SH -6175 +▁freedom -6176 +▁gear -6177 +mary -6178 +▁carried -6179 +▁speaking -6180 +▁factor -6181 +letter -6182 +▁Ash -6183 +▁Turn -6184 +▁stunning -6185 +▁sustainable -6186 +▁speech -6187 +▁Colorado -6188 +cling -6189 +▁tag -6190 +▁Scot -6191 +▁folks -6192 +▁significantly -6193 +▁candidate -6194 +▁Oil -6195 +unction -6196 +▁telling -6197 +▁domestic -6198 +ulture -6199 +▁examples -6200 +anged -6201 +▁Avenue -6202 +▁constantly -6203 +rid -6204 +▁committee -6205 +▁emphas -6206 +▁Training -6207 +▁cable -6208 +▁Coll -6209 +▁likes -6210 +▁Lin -6211 +▁symbol -6212 +▁Kim -6213 +▁univers -6214 +▁hardware -6215 +▁mixed -6216 +▁Perform -6217 +ificate -6218 +▁originally -6219 +▁solar -6220 +▁Having -6221 +▁Account -6222 +▁hook -6223 +▁vit -6224 +ucle -6225 +▁Sometimes -6226 +▁Which -6227 +▁stands -6228 +emic -6229 +▁retire -6230 +▁Hon -6231 +▁conflic -6232 +▁awards -6233 +Don -6234 +ployment -6235 +▁adventure -6236 +▁contemporary -6237 +▁showc -6238 +LY -6239 +▁houses -6240 +▁involve -6241 +▁logo -6242 +▁village -6243 +▁fulf -6244 +▁Though -6245 +▁Cond -6246 +▁bless -6247 +▁Spanish -6248 +▁carefully -6249 +▁patterns -6250 +▁supplies -6251 +▁MA -6252 +▁Dub -6253 +▁Select -6254 +▁procedures -6255 +▁Print -6256 +▁DC -6257 +ingly -6258 +▁auto -6259 +▁programme -6260 +▁browser -6261 +▁imagine -6262 +▁Mobile -6263 +▁Despite -6264 +▁stretch -6265 +▁losing -6266 +▁confident -6267 +▁criminal -6268 +▁fitness -6269 +▁replacement -6270 +lete -6271 +▁routine -6272 +▁Available -6273 +▁illustr -6274 +▁adds -6275 +▁Ireland -6276 +▁procedure -6277 +▁engage -6278 +▁Rom -6279 +ca -6280 +▁circumst -6281 +▁Ryan -6282 +▁bottle -6283 +etime -6284 +▁Garden -6285 +▁crazy -6286 +utch -6287 +▁turning -6288 +▁YouTube -6289 +▁random -6290 +▁hosting -6291 +▁taught -6292 +▁rose -6293 +▁expectations -6294 +▁lift -6295 +state -6296 +▁Russia -6297 +▁command -6298 +▁recipes -6299 +▁Tay -6300 +front -6301 +▁Drive -6302 +secut -6303 +▁fo -6304 +▁improvement -6305 +▁alleged -6306 +▁excess -6307 +▁hur -6308 +▁tro -6309 +▁trained -6310 +▁sheet -6311 +▁noticed -6312 +▁mixture -6313 +▁festival -6314 +▁Bon -6315 +▁funny -6316 +illy -6317 +▁tech -6318 +▁OS -6319 +ATE -6320 +▁tab -6321 +▁shots -6322 +▁syn -6323 +▁flavor -6324 +▁reporting -6325 +▁passeng -6326 +▁guitar -6327 +▁ol -6328 +▁hoping -6329 +▁severe -6330 +▁entreprene -6331 +▁COVID -6332 +inder -6333 +▁suspect -6334 +▁eleg -6335 +ether -6336 +▁foundation -6337 +orgeous -6338 +▁Heart -6339 +ington -6340 +▁SU -6341 +▁upper -6342 +ossible -6343 +inem -6344 +anger -6345 +▁Building -6346 +▁Environment -6347 +▁blow -6348 +eration -6349 +▁clothing -6350 +▁scholars -6351 +▁publish -6352 +▁Non -6353 +▁ok -6354 +enced -6355 +anna -6356 +▁Italy -6357 +adium -6358 +▁authent -6359 +▁FA -6360 +▁climb -6361 +▁pink -6362 +comes -6363 +▁Pop -6364 +▁Senior -6365 +rad -6366 +iano -6367 +▁talks -6368 +▁kill -6369 +pat -6370 +▁grew -6371 +▁Son -6372 +▁pil -6373 +hered -6374 +▁Beaut -6375 +▁root -6376 +▁san -6377 +oster -6378 +▁landscape -6379 +tle -6380 +ayer -6381 +▁figures -6382 +▁millions -6383 +ERS -6384 +ums -6385 +▁machines -6386 +▁Country -6387 +ERE -6388 +So -6389 +iece -6390 +▁Jersey -6391 +iversary -6392 +▁Run -6393 +▁Sky -6394 +orders -6395 +▁tasks -6396 +▁vital -6397 +▁reward -6398 +▁attended -6399 +ikes -6400 +▁eggs -6401 +▁tall -6402 +▁identity -6403 +▁tested -6404 +▁hits -6405 +▁PS -6406 +▁Senate -6407 +▁coc -6408 +’. -6409 +▁integrated -6410 +▁champions -6411 +▁laugh -6412 +▁herself -6413 +▁trends -6414 +▁input -6415 +▁Division -6416 +▁Disney -6417 +forcement -6418 +▁vibr -6419 +▁anx -6420 +▁council -6421 +oral -6422 +▁? -6423 +▁Shop -6424 +▁Nick -6425 +▁chapter -6426 +▁Stock -6427 +▁Ref -6428 +HS -6429 +▁shift -6430 +▁mal -6431 +▁Jenn -6432 +▁guard -6433 +▁weak -6434 +▁dram -6435 +▁wealth -6436 +▁Dog -6437 +▁historical -6438 +▁Writ -6439 +▁fishing -6440 +▁incl -6441 +▁baking -6442 +.’ -6443 +▁airport -6444 +▁Proper -6445 +▁depth -6446 +▁AD -6447 +▁museum -6448 +▁improving -6449 +▁smile -6450 +▁invited -6451 +▁arrested -6452 +izz -6453 +host -6454 +RI -6455 +▁wash -6456 +luded -6457 +rition -6458 +▁accessories -6459 +dy -6460 +▁Professor -6461 +ampion -6462 +▁Safety -6463 +▁thin -6464 +▁profit -6465 +▁ease -6466 +▁unf -6467 +▁output -6468 +▁qualified -6469 +▁Ent -6470 +▁Ford -6471 +▁residential -6472 +rate -6473 +▁Want -6474 +riends -6475 +▁rear -6476 +▁upload -6477 +▁abuse -6478 +▁Ha -6479 +▁hire -6480 +▁authorities -6481 +▁tonight -6482 +▁carbon -6483 +▁Georgia -6484 +▁certified -6485 +▁skill -6486 +▁mountain -6487 +▁Fre -6488 +▁wet -6489 +ATION -6490 +▁Sales -6491 +remony -6492 +zil -6493 +▁ordered -6494 +pret -6495 +▁Far -6496 +▁bags -6497 +▁managing -6498 +▁instance -6499 +▁km -6500 +▁destination -6501 +▁Still -6502 +▁entered -6503 +▁thorough -6504 +▁Email -6505 +iana -6506 +▁sole -6507 +▁dropped -6508 +icial -6509 +▁entirely -6510 +▁recy -6511 +▁Bul -6512 +▁institutions -6513 +iami -6514 +▁terror -6515 +▁atmosphere -6516 +▁Silver -6517 +yers -6518 +▁Further -6519 +LS -6520 +▁Supp -6521 +▁Fed -6522 +▁Systems -6523 +▁Luc -6524 +▁Space -6525 +▁closely -6526 +▁sick -6527 +▁guidance -6528 +▁photography -6529 +PC -6530 +▁Stat -6531 +▁breast -6532 +▁Zeal -6533 +▁rating -6534 +ras -6535 +▁tiny -6536 +▁description -6537 +▁Tax -6538 +▁vend -6539 +▁Members -6540 +▁fuck -6541 +▁offices -6542 +▁scientific -6543 +▁transportation -6544 +▁layer -6545 +stone -6546 +▁printed -6547 +long -6548 +De -6549 +▁frequently -6550 +▁Fac -6551 +▁Dist -6552 +▁spin -6553 +eller -6554 +igned -6555 +va -6556 +agues -6557 +▁cooper -6558 +▁entr -6559 +▁EU -6560 +▁yards -6561 +▁shower -6562 +▁searching -6563 +▁cycle -6564 +▁dental -6565 +▁loans -6566 +▁delay -6567 +▁CO -6568 +▁Phone -6569 +▁failure -6570 +▁Pract -6571 +▁kne -6572 +▁medicine -6573 +MP -6574 +▁equal -6575 +▁lessons -6576 +izza -6577 +▁unable -6578 +▁protein -6579 +adow -6580 +ogue -6581 +▁broadcast -6582 +▁founded -6583 +sen -6584 +▁Aff -6585 +▁Finally -6586 +▁cm -6587 +▁column -6588 +▁flexible -6589 +quir -6590 +▁Tech -6591 +▁operate -6592 +▁bonus -6593 +▁typical -6594 +▁compens -6595 +▁Looking -6596 +▁rail -6597 +▁taxes -6598 +aduate -6599 +▁Hou -6600 +▁glad -6601 +▁Should -6602 +▁religious -6603 +▁Never -6604 +▁sac -6605 +▁Engineering -6606 +▁situations -6607 +▁vacc -6608 +▁awarded -6609 +▁bear -6610 +▁PDF -6611 +▁Ca -6612 +▁lad -6613 +▁Ball -6614 +▁Zealand -6615 +oes -6616 +▁Put -6617 +▁eligible -6618 +quality -6619 +▁Very -6620 +▁external -6621 +▁Mach -6622 +▁historic -6623 +▁Sat -6624 +▁alongside -6625 +icket -6626 +awn -6627 +UL -6628 +▁flood -6629 +▁strategic -6630 +▁OR -6631 +▁sudden -6632 +▁unlike -6633 +▁wra -6634 +▁DVD -6635 +worth -6636 +▁assessment -6637 +▁filed -6638 +▁Smart -6639 +osoph -6640 +ilst -6641 +▁networks -6642 +▁seriously -6643 +▁Sus -6644 +▁creates -6645 +▁workshop -6646 +Is -6647 +?" -6648 +umps -6649 +▁worst -6650 +▁rental -6651 +▁Unfortunately -6652 +xx -6653 +▁BE -6654 +▁Charles -6655 +▁transition -6656 +uting -6657 +▁fighting -6658 +▁critic -6659 +▁river -6660 +nam -6661 +▁membership -6662 +ircle -6663 +▁Mountain -6664 +oker -6665 +▁believes -6666 +asters -6667 +bi -6668 +▁platforms -6669 +omy -6670 +▁none -6671 +friendly -6672 +▁availability -6673 +▁attacks -6674 +▁versions -6675 +▁vul -6676 +▁Foot -6677 +▁tracks -6678 +class -6679 +uling -6680 +▁distinct -6681 +erman -6682 +▁younger -6683 +▁Es -6684 +tain -6685 +▁listening -6686 +osite -6687 +▁Fox -6688 +plate -6689 +▁faculty -6690 +▁motion -6691 +aturally -6692 +▁Ask -6693 +▁contribute -6694 +▁hasn -6695 +arrow -6696 +inos -6697 +!" -6698 +▁Professional -6699 +▁juice -6700 +II -6701 +▁proven -6702 +eding -6703 +▁Pacific -6704 +One -6705 +▁hopes -6706 +▁bab -6707 +onto -6708 +star -6709 +aze -6710 +With -6711 +▁joining -6712 +▁letters -6713 +irts -6714 +ucky -6715 +▁risks -6716 +▁performing -6717 +active -6718 +▁Ray -6719 +▁streets -6720 +car -6721 +▁soph -6722 +▁Ariz -6723 +ounter -6724 +you -6725 +▁developers -6726 +▁SC -6727 +▁conver -6728 +▁obl -6729 +▁cups -6730 +▁pounds -6731 +neys -6732 +Fi -6733 +▁cos -6734 +▁recording -6735 +▁Term -6736 +▁tip -6737 +ati -6738 +▁Tele -6739 +zer -6740 +▁Harr -6741 +▁Easy -6742 +▁lucky -6743 +▁Kent -6744 +▁informed -6745 +oured -6746 +▁choosing -6747 +▁surprised -6748 +ented -6749 +▁grass -6750 +▁facilit -6751 +▁meals -6752 +)| -6753 +▁mortgage -6754 +nic -6755 +▁Phys -6756 +obby -6757 +▁infect -6758 +▁capture -6759 +▁liquid -6760 +ican -6761 +▁banks -6762 +▁diss -6763 +▁tournament -6764 +▁PA -6765 +agon -6766 +▁Leg -6767 +▁kit -6768 +▁Fall -6769 +amps -6770 +▁LLC -6771 +▁anticip -6772 +elry -6773 +▁papers -6774 +▁Field -6775 +▁savings -6776 +earing -6777 +At -6778 +▁privacy -6779 +cers -6780 +▁discip -6781 +To -6782 +pons -6783 +uine -6784 +▁Event -6785 +aping -6786 +▁hurt -6787 +born -6788 +▁rein -6789 +▁regulations -6790 +▁Ram -6791 +▁Mom -6792 +▁Broad -6793 +▁inch -6794 +▁decade -6795 +ashed -6796 +law -6797 +ially -6798 +▁charm -6799 +▁Taylor -6800 +▁submitted -6801 +rency -6802 +celer -6803 +▁Kat -6804 +etic -6805 +▁arg -6806 +▁west -6807 +▁Northern -6808 +▁Ter -6809 +▁blend -6810 +▁ille -6811 +Le -6812 +▁reputation -6813 +▁LED -6814 +▁bat -6815 +Se -6816 +▁Po -6817 +▁suggested -6818 +▁monitor -6819 +▁hall -6820 +▁proceed -6821 +▁liked -6822 +▁relief -6823 +▁organized -6824 +▁filter -6825 +▁shops -6826 +▁domain -6827 +▁consequ -6828 +▁mic -6829 +▁Lind -6830 +▁belief -6831 +▁sight -6832 +▁engagement -6833 +entle -6834 +▁Cut -6835 +▁Source -6836 +▁Miami -6837 +bury -6838 +▁extract -6839 +▁pulled -6840 +Read -6841 +▁Radio -6842 +▁Come -6843 +▁Credit -6844 +▁gorgeous -6845 +days -6846 +▁justice -6847 +uter -6848 +pes -6849 +▁Cab -6850 +▁drawing -6851 +▁Sea -6852 +▁negoti -6853 +▁circumstances -6854 +▁capable -6855 +▁quote -6856 +▁Arab -6857 +▁) -6858 +▁tank -6859 +▁monitoring -6860 +ava -6861 +▁empt -6862 +▁crucial -6863 +rell -6864 +▁Think -6865 +▁legs -6866 +▁Order -6867 +▁portfolio -6868 +▁Bible -6869 +▁sky -6870 +bing -6871 +ulf -6872 +ographic -6873 +▁hate -6874 +▁immediate -6875 +▁increases -6876 +▁ads -6877 +▁arrive -6878 +▁exhibition -6879 +▁stir -6880 +▁Ms -6881 +bar -6882 +▁believed -6883 +foot -6884 +▁penal -6885 +▁moves -6886 +▁Insurance -6887 +▁linked -6888 +ta -6889 +athan -6890 +▁Continue -6891 +▁counsel -6892 +▁relatively -6893 +▁treatments -6894 +▁faces -6895 +▁attached -6896 +▁Pak -6897 +▁manual -6898 +faction -6899 +▁soil -6900 +▁crack -6901 +▁adm -6902 +▁defend -6903 +illiant -6904 +uis -6905 +▁mm -6906 +▁jun -6907 +ura -6908 +▁Mir -6909 +▁planet -6910 +resents -6911 +bles -6912 +Ad -6913 +▁technique -6914 +cknow -6915 +▁concert -6916 +▁enjoying -6917 +rowse -6918 +▁guidelines -6919 +▁listing -6920 +esides -6921 +▁directed -6922 +▁interface -6923 +▁injured -6924 +arters -6925 +▁vast -6926 +▁hosted -6927 +▁execut -6928 +▁dent -6929 +▁LA -6930 +▁ast -6931 +▁Conf -6932 +▁Rod -6933 +▁spark -6934 +▁garage -6935 +▁authors -6936 +▁hospit -6937 +▁memories -6938 +uration -6939 +rich -6940 +▁contrast -6941 +▁aside -6942 +▁volunteers -6943 +▁equipped -6944 +sey -6945 +▁Ron -6946 +ardens -6947 +▁Ur -6948 +▁normally -6949 +ppy -6950 +▁estimated -6951 +▁:) -6952 +▁promise -6953 +▁firms -6954 +▁Republican -6955 +▁dreams -6956 +▁Happy -6957 +▁Pow -6958 +onym -6959 +▁Jac -6960 +▁warn -6961 +▁trig -6962 +▁pin -6963 +hot -6964 +▁trick -6965 +▁phase -6966 +▁depress -6967 +▁rice -6968 +▁Remember -6969 +▁urban -6970 +▁illness -6971 +By -6972 +▁Being -6973 +▁Quality -6974 +iger -6975 +▁agents -6976 +▁Justice -6977 +▁acid -6978 +▁prove -6979 +ba -6980 +▁consistent -6981 +oty -6982 +▁dust -6983 +▁spoke -6984 +▁Airport -6985 +▁Houston -6986 +▁pitch -6987 +▁Bed -6988 +▁organis -6989 +▁pleasure -6990 +▁arms -6991 +holders -6992 +aints -6993 +▁matches -6994 +▁Medicine -6995 +AA -6996 +ults -6997 +Bl -6998 +%. -6999 +▁Ide -7000 +▁Talk -7001 +▁portion -7002 +▁Conc -7003 +▁index -7004 +▁Line -7005 +▁chances -7006 +ogether -7007 +▁Brazil -7008 +asant -7009 +▁fasc -7010 +▁Fact -7011 +.' -7012 +icit -7013 +▁lapt -7014 +▁newly -7015 +▁chose -7016 +▁Personal -7017 +▁objects -7018 +▁Carl -7019 +▁dynamic -7020 +ensity -7021 +▁breath -7022 +▁finance -7023 +rm -7024 +▁Arizona -7025 +▁refund -7026 +▁Asian -7027 +▁Living -7028 +▁Standard -7029 +▁Prom -7030 +▁proof -7031 +▁seed -7032 +SC -7033 +eling -7034 +▁passing -7035 +▁continuing -7036 +But -7037 +▁visited -7038 +▁represents -7039 +▁Officer -7040 +▁drinking -7041 +▁Give -7042 +site -7043 +ership -7044 +▁iPad -7045 +cket -7046 +▁formed -7047 +▁storm -7048 +▁ultimate -7049 +▁mile -7050 +pack -7051 +inois -7052 +alle -7053 +▁Brad -7054 +▁Mill -7055 +▁roles -7056 +▁border -7057 +▁Estate -7058 +▁forever -7059 +▁MO -7060 +▁discussed -7061 +▁superv -7062 +▁ceremony -7063 +▁Cru -7064 +annels -7065 +▁approval -7066 +iking -7067 +▁Las -7068 +▁zone -7069 +amber -7070 +▁Welcome -7071 +▁Army -7072 +▁Season -7073 +▁Student -7074 +▁id -7075 +▁suc -7076 +she -7077 +▁stim -7078 +▁exposure -7079 +▁recommendations -7080 +adel -7081 +▁gaming -7082 +▁dealing -7083 +stal -7084 +▁sending -7085 +ultural -7086 +▁Oak -7087 +▁Iran -7088 +▁stake -7089 +▁evol -7090 +▁Therefore -7091 +▁phones -7092 +MC -7093 +anes -7094 +▁Sav -7095 +▁Kevin -7096 +▁capabilities -7097 +▁teasp -7098 +▁division -7099 +▁gallery -7100 +▁Webs -7101 +uclear -7102 +Americ -7103 +whel -7104 +amsung -7105 +▁boxes -7106 +▁downtown -7107 +▁saving -7108 +▁presents -7109 +▁collected -7110 +▁holidays -7111 +respond -7112 +▁lawyer -7113 +▁possibility -7114 +▁fairly -7115 +▁Again -7116 +▁implementation -7117 +iki -7118 +▁vulner -7119 +▁pra -7120 +ainless -7121 +▁mand -7122 +▁susp -7123 +▁hat -7124 +GA -7125 +ja -7126 +▁ensuring -7127 +▁Choose -7128 +▁permanent -7129 +aper -7130 +▁attractive -7131 +▁pharm -7132 +▁smell -7133 +▁cookies -7134 +▁Administration -7135 +▁constit -7136 +▁flash -7137 +▁Site -7138 +▁industries -7139 +ih -7140 +▁tub -7141 +▁hidden -7142 +▁suggestions -7143 +▁scheme -7144 +aste -7145 +bro -7146 +▁trib -7147 +▁finds -7148 +lers -7149 +▁Experience -7150 +izer -7151 +▁porn -7152 +▁Natural -7153 +▁Brian -7154 +ione -7155 +wear -7156 +urse -7157 +▁recognize -7158 +▁Express -7159 +RS -7160 +▁Kenn -7161 +▁instrument -7162 +missions -7163 +▁facts -7164 +phy -7165 +▁Ju -7166 +▁theory -7167 +▁heads -7168 +▁vari -7169 +pot -7170 +▁priority -7171 +▁mainly -7172 +▁acknow -7173 +zes -7174 +▁($ -7175 +lessly -7176 +▁Meanwhile -7177 +Sc -7178 +▁legislation -7179 +ffered -7180 +rible -7181 +▁reader -7182 +▁Clin -7183 +▁Ros -7184 +▁Isl -7185 +▁bodies -7186 +▁Case -7187 +FA -7188 +▁butt -7189 +▁liber -7190 +▁categories -7191 +▁Chall -7192 +▁posting -7193 +▁realized -7194 +▁mut -7195 +▁Hollywood -7196 +anned -7197 +page -7198 +inson -7199 +▁Software -7200 +▁communications -7201 +▁Vers -7202 +▁Ba -7203 +▁solve -7204 +▁Own -7205 +▁bench -7206 +▁personally -7207 +▁Dun -7208 +▁garlic -7209 +▁Secretary -7210 +▁upgrade -7211 +da -7212 +▁bars -7213 +allas -7214 +▁Queen -7215 +boy -7216 +▁bridge -7217 +phones -7218 +▁Emer -7219 +Book -7220 +EA -7221 +▁Stay -7222 +▁incredibly -7223 +▁USB -7224 +then -7225 +▁ancient -7226 +▁Learning -7227 +▁Policy -7228 +CT -7229 +▁Create -7230 +▁reform -7231 +▁tradition -7232 +esy -7233 +▁|| -7234 +▁permission -7235 +▁hole -7236 +▁Bang -7237 +stra -7238 +ingu -7239 +▁tiss -7240 +osc -7241 +▁Prime -7242 +▁Anal -7243 +▁generate -7244 +▁Yet -7245 +odd -7246 +anny -7247 +ounce -7248 +▁Cand -7249 +▁exec -7250 +▁CN -7251 +▁copyright -7252 +▁packages -7253 +▁calendar -7254 +▁rum -7255 +odge -7256 +▁handling -7257 +tw -7258 +ials -7259 +▁substant -7260 +▁travell -7261 +▁pace -7262 +▁basketball -7263 +▁east -7264 +▁magic -7265 +▁Hold -7266 +▁debate -7267 +parent -7268 +OO -7269 +▁victims -7270 +▁raw -7271 +▁claimed -7272 +▁Level -7273 +That -7274 +▁Additionally -7275 +iti -7276 +▁celebration -7277 +▁clar -7278 +▁walked -7279 +▁orange -7280 +▁programming -7281 +▁Jr -7282 +▁doctors -7283 +▁MD -7284 +HA -7285 +ulpt -7286 +▁achieved -7287 +▁fest -7288 +▁giant -7289 +▁cotton -7290 +▁Toronto -7291 +▁absor -7292 +▁forth -7293 +▁purchasing -7294 +▁habit -7295 +onna -7296 +▁prospect -7297 +▁replaced -7298 +▁Cro -7299 +▁Stan -7300 +▁bare -7301 +▁Film -7302 +burgh -7303 +▁fifth -7304 +▁explains -7305 +uls -7306 +▁tooth -7307 +▁Illinois -7308 +▁desired -7309 +▁Studies -7310 +level -7311 +CD -7312 +zing -7313 +isa -7314 +▁king -7315 +▁Tool -7316 +▁manufacturers -7317 +▁spots -7318 +▁titles -7319 +▁gym -7320 +▁saved -7321 +▁Dar -7322 +▁seasons -7323 +▁cuts -7324 +season -7325 +▁somewhere -7326 +▁marked -7327 +▁Auto -7328 +▁proposal -7329 +▁Consult -7330 +▁insight -7331 +▁marks -7332 +▁hotels -7333 +▁initiative -7334 +uster -7335 +▁feelings -7336 +▁venue -7337 +▁slowly -7338 +RL -7339 +▁singer -7340 +▁specialist -7341 +▁suffering -7342 +▁Produ -7343 +▁Catholic -7344 +ila -7345 +▁NFL -7346 +▁expressed -7347 +▁Story -7348 +▁Capital -7349 +▁compat -7350 +▁requests -7351 +▁Irish -7352 +▁drinks -7353 +▁Material -7354 +imize -7355 +▁architecture -7356 +App -7357 +iot -7358 +▁vegetables -7359 +▁Save -7360 +▁Sep -7361 +aron -7362 +▁Agency -7363 +igate -7364 +esh -7365 +▁buyers -7366 +acon -7367 +aters -7368 +▁Joseph -7369 +▁merch -7370 +▁volunteer -7371 +▁gay -7372 +▁exceptional -7373 +▁impossible -7374 +▁stuck -7375 +▁Liber -7376 +▁Table -7377 +▁meets -7378 +▁enables -7379 +▁swimming -7380 +stream -7381 +▁combine -7382 +inton -7383 +▁murder -7384 +▁broke -7385 +bridge -7386 +▁publication -7387 +▁announcement -7388 +▁destroy -7389 +▁tie -7390 +▁extension -7391 +ylvan -7392 +▁causing -7393 +▁ultimately -7394 +▁enem -7395 +VER -7396 +▁consultation -7397 +▁encouraged -7398 +▁reducing -7399 +▁muscle -7400 +▁err -7401 +▁accomplish -7402 +▁Pakistan -7403 +▁Mess -7404 +regon -7405 +nesota -7406 +▁split -7407 +ologist -7408 +▁packaging -7409 +▁yard -7410 +▁surprising -7411 +▁Mix -7412 +▁lets -7413 +▁Pu -7414 +▁publ -7415 +▁Bell -7416 +ickets -7417 +▁magn -7418 +aid -7419 +▁Short -7420 +▁Vegas -7421 +▁Map -7422 +▁actor -7423 +▁rig -7424 +▁printing -7425 +▁Would -7426 +▁enterprise -7427 +▁engaged -7428 +▁Autom -7429 +▁pit -7430 +lements -7431 +▁describe -7432 +▁Camer -7433 +▁heav -7434 +▁massage -7435 +▁pricing -7436 +run -7437 +▁DI -7438 +bel -7439 +apore -7440 +des -7441 +aska -7442 +▁Motor -7443 +▁electrical -7444 +▁noise -7445 +▁mood -7446 +▁Location -7447 +▁widely -7448 +▁preparation -7449 +▁Kids -7450 +ifer -7451 +▁seeds -7452 +▁reasonable -7453 +▁talked -7454 +▁Pen -7455 +▁enroll -7456 +▁blocks -7457 +▁covering -7458 +▁performances -7459 +▁Labor -7460 +ns -7461 +▁Spain -7462 +▁breaking -7463 +▁expansion -7464 +bell -7465 +▁recognition -7466 +▁pill -7467 +olis -7468 +▁default -7469 +▁framework -7470 +eah -7471 +▁wins -7472 +▁Recent -7473 +▁genuine -7474 +▁overwhel -7475 +▁traveling -7476 +▁remark -7477 +▁blank -7478 +▁Forest -7479 +▁seats -7480 +rage -7481 +▁classroom -7482 +RC -7483 +▁agric -7484 +wan -7485 +▁knock -7486 +inator -7487 +cons -7488 +▁Ira -7489 +▁interactive -7490 +uct -7491 +▁concrete -7492 +▁neighb -7493 +▁Theatre -7494 +▁Ess -7495 +▁CB -7496 +iler -7497 +▁Adam -7498 +▁unw -7499 +▁pand -7500 +▁Gallery -7501 +)|| -7502 +▁Studio -7503 +▁birds -7504 +▁formal -7505 +▁Force -7506 +▁Pin -7507 +▁compr -7508 +▁dishes -7509 +▁Band -7510 +wich -7511 +▁Memorial -7512 +▁writers -7513 +▁Ice -7514 +▁franch -7515 +▁resistance -7516 +▁Following -7517 +▁gall -7518 +▁empty -7519 +▁Rs -7520 +▁Toy -7521 +gypt -7522 +▁brilliant -7523 +▁spray -7524 +▁consists -7525 +▁constant -7526 +ulum -7527 +▁scenes -7528 +▁increasingly -7529 +▁staying -7530 +▁compliance -7531 +proof -7532 +▁Square -7533 +▁incorpor -7534 +▁Mrs -7535 +▁resulting -7536 +▁acting -7537 +▁Davis -7538 +▁Annual -7539 +EP -7540 +▁duty -7541 +▁suggests -7542 +▁pic -7543 +▁dad -7544 +▁recover -7545 +ludes -7546 +▁managers -7547 +▁Fred -7548 +▁Member -7549 +▁experiment -7550 +nda -7551 +▁Treat -7552 +▁basically -7553 +▁spiritual -7554 +ateful -7555 +axy -7556 +ding -7557 +▁Things -7558 +▁professor -7559 +ifies -7560 +▁anyway -7561 +▁bow -7562 +▁Diego -7563 +▁nights -7564 +▁Paper -7565 +▁Mah -7566 +being -7567 +▁Spirit -7568 +▁mere -7569 +child -7570 +▁Eric -7571 +books -7572 +▁FL -7573 +leep -7574 +▁graphics -7575 +otted -7576 +▁Dam -7577 +▁lists -7578 +▁Partners -7579 +▁Jord -7580 +▁forecast -7581 +▁slic -7582 +▁slot -7583 +▁Solutions -7584 +▁scan -7585 +▁pride -7586 +▁deck -7587 +▁Samsung -7588 +▁Roman -7589 +abetes -7590 +’, -7591 +▁prize -7592 +▁authority -7593 +▁Shipping -7594 +▁producing -7595 +▁Ly -7596 +rated -7597 +▁Interest -7598 +ilton -7599 +alo -7600 +▁centers -7601 +▁clicking -7602 +▁Seattle -7603 +irus -7604 +▁Model -7605 +▁packed -7606 +una -7607 +▁wireless -7608 +▁Gro -7609 +erate -7610 +alse -7611 +▁Books -7612 +▁everywhere -7613 +▁aims -7614 +ghan -7615 +▁legend -7616 +acle -7617 +▁Golden -7618 +▁Minnesota -7619 +▁enthusi -7620 +ashes -7621 +▁whenever -7622 +▁expenses -7623 +vas -7624 +▁Pur -7625 +▁Age -7626 +▁indeed -7627 +▁healing -7628 +▁Limited -7629 +utional -7630 +▁interpret -7631 +▁closing -7632 +▁Cover -7633 +▁talented -7634 +▁singles -7635 +▁anniversary -7636 +▁succeed -7637 +▁inner -7638 +inding -7639 +▁Lew -7640 +making -7641 +▁involves -7642 +rome -7643 +▁Swed -7644 +▁pocket -7645 +ls -7646 +▁riding -7647 +▁unex -7648 +▁connections -7649 +▁Sound -7650 +▁GM -7651 +heast -7652 +▁channels -7653 +▁obtained -7654 +pends -7655 +▁narr -7656 +▁founder -7657 +▁vice -7658 +▁OK -7659 +ylvania -7660 +▁Magazine -7661 +▁Perhaps -7662 +▁displayed -7663 +▁Customer -7664 +▁Dream -7665 +▁bunch -7666 +▁assum -7667 +▁Total -7668 +▁opens -7669 +greg -7670 +▁Collection -7671 +▁delivering -7672 +▁Month -7673 +▁Bad -7674 +▁Dallas -7675 +▁designers -7676 +▁struggle -7677 +ureau -7678 +▁lemon -7679 +Press -7680 +▁trips -7681 +▁Based -7682 +▁Steel -7683 +▁attrib -7684 +▁differences -7685 +stein -7686 +▁acts -7687 +▁ending -7688 +▁Working -7689 +▁driven -7690 +▁Pict -7691 +lder -7692 +abeth -7693 +▁CP -7694 +nders -7695 +▁Station -7696 +ronics -7697 +▁defined -7698 +▁Mother -7699 +▁watched -7700 +▁complim -7701 +▁improvements -7702 +▁mob -7703 +▁Cloud -7704 +▁primarily -7705 +coin -7706 +▁CL -7707 +▁loving -7708 +▁vintage -7709 +bits -7710 +▁Action -7711 +▁gender -7712 +▁boss -7713 +sters -7714 +▁guaranteed -7715 +▁introduction -7716 +▁Rub -7717 +▁Oregon -7718 +▁booking -7719 +▁Dark -7720 +ambling -7721 +▁returning -7722 +▁Rand -7723 +oom -7724 +▁Sym -7725 +▁sensitive -7726 +▁fits -7727 +▁shouldn -7728 +▁Eastern -7729 +▁SS -7730 +▁podcast -7731 +Fr -7732 +▁apparently -7733 +▁Everyone -7734 +▁Anth -7735 +▁Base -7736 +▁politics -7737 +owa -7738 +▁officially -7739 +pool -7740 +issions -7741 +▁precise -7742 +oned -7743 +▁Common -7744 +▁rug -7745 +▁Products -7746 +rive -7747 +▁alive -7748 +▁headed -7749 +▁Bru -7750 +▁Return -7751 +AB -7752 +▁chopped -7753 +su -7754 +▁Miller -7755 +iders -7756 +▁fing -7757 +▁unus -7758 +▁Jay -7759 +▁Spec -7760 +▁Blog -7761 +▁coat -7762 +▁Change -7763 +▁narrow -7764 +▁highlights -7765 +▁protest -7766 +▁trim -7767 +▁recre -7768 +AND -7769 +▁potentially -7770 +▁honey -7771 +▁shell -7772 +▁Transport -7773 +ailing -7774 +▁percentage -7775 +▁authentic -7776 +▁Austin -7777 +▁filling -7778 +▁tape -7779 +▁maintaining -7780 +▁lin -7781 +▁Capt -7782 +▁analyst -7783 +▁retirement -7784 +▁Cry -7785 +▁casual -7786 +▁speaker -7787 +▁crash -7788 +pson -7789 +atics -7790 +riers -7791 +▁Among -7792 +▁assistant -7793 +▁charity -7794 +▁personality -7795 +▁Corporation -7796 +wart -7797 +▁acquis -7798 +▁scientists -7799 +jo -7800 +▁Kingdom -7801 +▁resident -7802 +▁Guard -7803 +▁falling -7804 +inent -7805 +lose -7806 +scribe -7807 +raid -7808 +▁plot -7809 +▁DO -7810 +▁elev -7811 +▁Iraq -7812 +pection -7813 +iac -7814 +▁bills -7815 +▁opinions -7816 +onut -7817 +▁Josh -7818 +▁Barb -7819 +▁strike -7820 +▁licensed -7821 +▁aircraft -7822 +▁heading -7823 +ali -7824 +▁CR -7825 +▁Nic -7826 +▁naturally -7827 +▁Dead -7828 +acher -7829 +raction -7830 +▁consumption -7831 +ydney -7832 +▁renov -7833 +▁Sarah -7834 +▁carrying -7835 +▁tired -7836 +▁gentle -7837 +arliam -7838 +▁colours -7839 +Cont -7840 +▁Jewish -7841 +▁Egypt -7842 +▁correspond -7843 +▁obviously -7844 +▁functional -7845 +▁preparing -7846 +▁ -7847 +e -7848 +t -7849 +a -7850 +o -7851 +i -7852 +n -7853 +s -7854 +r -7855 +h -7856 +l -7857 +d -7858 +c -7859 +u -7860 +m -7861 +p -7862 +g -7863 +f -7864 +y -7865 +w -7866 +b -7867 +. -7868 +v -7869 +, -7870 +k -7871 +T -7872 +S -7873 +I -7874 +A -7875 +- -7876 +C -7877 +0 -7878 +M -7879 +1 -7880 +P -7881 +x -7882 +B -7883 +2 -7884 +W -7885 +D -7886 +R -7887 +E -7888 +H -7889 +F -7890 +’ -7891 +L -7892 +N -7893 +O -7894 +: -7895 +' -7896 +G -7897 +j -7898 +) -7899 +( -7900 +z -7901 +3 -7902 +5 -7903 +q -7904 +4 -7905 +U -7906 +" -7907 +9 -7908 +J -7909 +8 -7910 +6 -7911 +V -7912 +Y -7913 +K -7914 +| -7915 +7 -7916 +! -7917 +/ -7918 +“ -7919 +” -7920 +? -7921 +– -7922 +; -7923 +& -7924 +$ -7925 +Q -7926 +% -7927 +— -7928 +X -7929 +Z -7930 +* -7931 diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/reqs.txt b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/reqs.txt new file mode 100644 index 0000000000..342c764ae7 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/reqs.txt @@ -0,0 +1,12 @@ +# mamba-ssm: install from GitHub source (requires CUDA toolkit): +# MAMBA_FORCE_BUILD=TRUE pip install --no-cache-dir --force-reinstall \ +# git+https://github.com/state-spaces/mamba.git --no-build-isolation +numpy +tqdm +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/setup_sp8192_data.sh b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/setup_sp8192_data.sh new file mode 100644 index 0000000000..e2c9ebf445 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/setup_sp8192_data.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Record-local data setup for this submission. +# Exports SP8192 dataset shards (80 train shards) into this record folder. +# +# Usage: +# bash ./setup_sp8192_data.sh +# Optional: +# HF_TOKEN=... bash ./setup_sp8192_data.sh +# VENV_DIR=.venv bash ./setup_sp8192_data.sh + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" + +VOCAB_SIZE="${VOCAB_SIZE:-8192}" +MAX_TRAIN_SHARDS="${MAX_TRAIN_SHARDS:-80}" +VENV_DIR="${VENV_DIR:-.venv}" +OUTPUT_ROOT="${SCRIPT_DIR}/sp8192_data" +TOKENIZER_MODEL="${SCRIPT_DIR}/fineweb_8192_bpe.model" +REQ_FILE="${SCRIPT_DIR}/reqs.txt" + +if [[ ! -f "${REPO_ROOT}/build_sp_dataset.sh" ]]; then + echo "ERROR: build_sp_dataset.sh not found at repo root: ${REPO_ROOT}" >&2 + exit 1 +fi + +if [[ ! -f "${TOKENIZER_MODEL}" ]]; then + echo "ERROR: tokenizer model not found: ${TOKENIZER_MODEL}" >&2 + exit 1 +fi + +if [[ ! -f "${REQ_FILE}" ]]; then + echo "ERROR: requirements file not found: ${REQ_FILE}" >&2 + exit 1 +fi + +mkdir -p "${OUTPUT_ROOT}" + +echo "[setup] repo_root=${REPO_ROOT}" +echo "[setup] output_root=${OUTPUT_ROOT}" +echo "[setup] vocab_size=${VOCAB_SIZE} max_train_shards=${MAX_TRAIN_SHARDS}" +echo "[setup] tokenizer_reuse=${TOKENIZER_MODEL}" +echo "[setup] venv=${VENV_DIR}" +echo "[setup] reqs=${REQ_FILE}" + +# Bootstrap venv + deps if needed so this script is self-contained. +if [[ ! -f "${REPO_ROOT}/${VENV_DIR}/bin/activate" ]]; then + echo "[setup] creating virtualenv at ${REPO_ROOT}/${VENV_DIR}" + python3 -m venv "${REPO_ROOT}/${VENV_DIR}" +fi + +# shellcheck disable=SC1090 +source "${REPO_ROOT}/${VENV_DIR}/bin/activate" +python3 -m pip install --upgrade pip wheel setuptools >/dev/null +python3 -m pip install -r "${REQ_FILE}" +# Install mamba-ssm CUDA extension from source (official GitHub repo) +echo "[setup] installing mamba-ssm from source..." +MAMBA_FORCE_BUILD=TRUE pip install --no-cache-dir --force-reinstall \ + git+https://github.com/state-spaces/mamba.git --no-build-isolation 2>/dev/null || \ + echo "[warn] mamba-ssm install failed — ensure CUDA toolkit is available" + +cd "${REPO_ROOT}" + +VOCAB_SIZE="${VOCAB_SIZE}" \ +VENV_DIR="${VENV_DIR}" \ +OUTPUT_ROOT="${OUTPUT_ROOT}" \ +MAX_TRAIN_SHARDS="${MAX_TRAIN_SHARDS}" \ +EXISTING_TOKENIZER_MODEL="${TOKENIZER_MODEL}" \ +bash ./build_sp_dataset.sh + +echo +echo "[done] dataset path:" +echo " ${OUTPUT_ROOT}/datasets/fineweb10B_sp8192" +echo "[done] train shards count:" +find "${OUTPUT_ROOT}/datasets/fineweb10B_sp8192" -maxdepth 1 -name 'fineweb_train_*.bin' | wc -l diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/submission.json b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/submission.json new file mode 100644 index 0000000000..12f51f8ad0 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/submission.json @@ -0,0 +1,17 @@ +{ + "name": "SP8192 BPE + Mamba3 SSM Hybrid (d448, ssm_every_n:4, 1xH100, 30min)", + "val_bpb": 1.26060944, + "val_loss": 3.25624330, + "pre_quant_val_bpb": 1.2542, + "pre_quant_val_loss": 3.2398, + "bytes_total": 17260594, + "bytes_model_int8_zstd": 17028714, + "bytes_code": 231880, + "step_stop": 12278, + "wallclock_seconds": 1800.080, + "track": "non-record-16mb", + "blurb": "SP8192 BPE non-record run on 1xH100 with GPTQ int8+zstd export. Hybrid architecture: 9-layer stacked transformer with Mamba3 SSM blocks every 4th layer (2 SSM, 7 GQA attention). Trained for 30 minutes (1800s wallclock cap) with 20 warmup steps, SWA, and Muon optimizer. Over 16MB budget by ~1.26MB — demonstrates SSM hybrid viability for future budget-constrained entries.", + "author": "Dex Hunter", + "github_id": "dexhunter", + "date": "2026-04-30" +} diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/train.log b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/train.log new file mode 100644 index 0000000000..18e761f1ab --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/train.log @@ -0,0 +1,4909 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import importlib +import io +import json +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +_MAMBA3_IMPORT_ERROR: Exception | None = None +try: + from mamba_ssm.modules.mamba3 import Mamba3 as _OfficialMamba3 +except Exception as exc: # pragma: no cover - depends on CUDA extension install + _MAMBA3_IMPORT_ERROR = exc + _OfficialMamba3 = None +# Increase dynamo cache limit to avoid recompilation fallback when training conditions change +# (e.g., distillation activation, rotary cache identity changes). Default is 8, which is too low. +torch._dynamo.config.cache_size_limit = 64 +# Workaround for torch 2.10.0 inductor bug in joint_graph `mul_softmax_pattern` that crashes +# with "Tried to erase Node mul_N but it still had 1 users" during mid-training recompiles. +# The keep-alive fallback (suppress_errors) kicks the *entire* forward into eager, which is +# catastrophic for step time — so we defuse the broken pattern at its source instead. +# +# Strategy: +# (1) Monkey-patch `mul_softmax_pattern` in the joint_graph module and in every PatternEntry +# handler slot that references it. Replace with a no-op that never rewrites the graph. +# (2) Keep suppress_errors=True only as a last-resort safety net, so if a different pattern +# fails during a mid-training recompile the specific subgraph falls back to eager instead +# of killing the whole run. +torch._dynamo.config.suppress_errors = True +def _pg_noop_mul_softmax_pattern(match, *args, **kwargs): # noqa: ANN001 + # No rewrite: leave the matched subgraph alone. Inductor will still lower it correctly + # through the generic softmax/mul path — we just give up this one fusion opportunity. + return +try: + from torch._inductor.fx_passes import joint_graph as _pg_joint_graph + # (a) Replace the module-level function so future imports resolve to the no-op. + if hasattr(_pg_joint_graph, "mul_softmax_pattern"): + _pg_joint_graph.mul_softmax_pattern = _pg_noop_mul_softmax_pattern + # (b) Walk the registered PatternMatcherPass and swap any PatternEntry whose handler is the + # buggy function. In torch 2.10, `patterns.patterns` is a defaultdict[key, list[entry]]. + _pg_patterns = getattr(_pg_joint_graph, "patterns", None) + if _pg_patterns is not None: + _pg_inner = getattr(_pg_patterns, "patterns", None) + if _pg_inner is not None: + # Handle both dict-of-list and plain-list shapes. + if isinstance(_pg_inner, dict): + _pg_iter = [_e for _lst in _pg_inner.values() for _e in _lst] + else: + _pg_iter = list(_pg_inner) + for _entry in _pg_iter: + _h = getattr(_entry, "handler", None) + if _h is None: + continue + _qn = getattr(_h, "__qualname__", "") or getattr(_h, "__name__", "") + if "mul_softmax_pattern" in _qn: + try: + _entry.handler = _pg_noop_mul_softmax_pattern + except Exception: + pass +except Exception: + # If torch's internal layout has shifted, fall through to the suppress_errors safety net. + pass +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/dual_bpe/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + # Optional cap for fast local smoke runs; 0 means full validation split. + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 200)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + use_swiglu = bool(int(os.environ.get("USE_SWIGLU", "1"))) + # Sliding window eval: only score tokens beyond prefix_len in each window. + # eval_stride_frac=0.5 means stride=seq_len//2 → each scored token has ≥seq_len//2 tokens of context. + # eval_stride_frac=1.0 (default) = original non-overlapping behaviour. + eval_stride_frac = float(os.environ.get("EVAL_STRIDE_FRAC", "0.5")) + # Long-context eval: evaluate at a longer sequence length than training. + # 0 = same as train_seq_len. Pair with NTK RoPE scaling (eval_rope_scale>1) for best results. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0")) + # NTK-aware RoPE scaling at eval: new_base = rope_base * eval_rope_scale^(head_dim/(head_dim-2)). + # Suggested: eval_rope_scale = (eval_seq_len / train_seq_len) ** 2 (≈4 for 2× context) + eval_rope_scale = float(os.environ.get("EVAL_ROPE_SCALE", "1.0")) + # Optional extra eval contexts to sweep at the end of a run. These do not affect the + # in-training validation path unless promoted to the primary eval context via EVAL_SEQ_LEN. + eval_sweep_seq_lens = os.environ.get("EVAL_SWEEP_SEQ_LENS", "").strip() + eval_sweep_rope_scales = os.environ.get("EVAL_SWEEP_ROPE_SCALES", "").strip() + # Multi-context eval blend: evaluate multiple contexts on the same scored token blocks and + # blend their token probabilities. Set FINAL_EVAL_MODE=blend to make this the official score. + eval_blend_seq_lens = os.environ.get("EVAL_BLEND_SEQ_LENS", "").strip() + eval_blend_rope_scales = os.environ.get("EVAL_BLEND_ROPE_SCALES", "").strip() + eval_blend_weights = os.environ.get("EVAL_BLEND_WEIGHTS", "").strip() + # 0 = inherit EVAL_STRIDE_FRAC. Otherwise, use this stride fraction for the common scored span. + eval_blend_stride_frac = float(os.environ.get("EVAL_BLEND_STRIDE_FRAC", "0.0")) + # Optional position-dependent blend ramp. Positive bias shifts weight from shorter contexts + # early in the scored span toward longer contexts later in the scored span. + eval_blend_position_bias = float(os.environ.get("EVAL_BLEND_POSITION_BIAS", "0.0")) + eval_blend_position_power = float(os.environ.get("EVAL_BLEND_POSITION_POWER", "1.0")) + # Eval-only continuous cache: mixes the base LM with a retrieval distribution over recent + # validation-history hidden states. This is eval-only and does not change the artifact. + eval_cont_cache_enabled = bool(int(os.environ.get("EVAL_CONT_CACHE_ENABLED", "0"))) + eval_cont_cache_window = int(os.environ.get("EVAL_CONT_CACHE_WINDOW", "8192")) + eval_cont_cache_topk = int(os.environ.get("EVAL_CONT_CACHE_TOPK", "64")) + eval_cont_cache_weight = float(os.environ.get("EVAL_CONT_CACHE_WEIGHT", "0.12")) + eval_cont_cache_logit_scale = float(os.environ.get("EVAL_CONT_CACHE_LOGIT_SCALE", "12.0")) + eval_cont_cache_conf_power = float(os.environ.get("EVAL_CONT_CACHE_CONF_POWER", "1.0")) + eval_cont_cache_batch_seqs = int(os.environ.get("EVAL_CONT_CACHE_BATCH_SEQS", "8")) + # primary | blend + final_eval_mode = os.environ.get("FINAL_EVAL_MODE", "primary").strip().lower() + # Low-rank bigram logit bias: learnable rank-r factored bigram table. + # bigram_bias[i] = bigram_right(bigram_left(prev_token[i])) added to logits before softcap. + # 0 = disabled. 32 costs ~64K int8 params (≈32 KB), well within the 164 KB headroom. + bigram_rank = int(os.environ.get("BIGRAM_RANK", "32")) + bigram_lr = float(os.environ.get("BIGRAM_LR", "0.04")) + # Residual n-gram modeling: mix neural logits with a lightweight n-gram baseline. + # total_prob = (1-gate)*P_neural + gate*P_ngram, where gate is learned per token. + # This lets the transformer focus more capacity on hard residual structure. + residual_ngram_enabled = bool(int(os.environ.get("RESIDUAL_NGRAM_ENABLED", "0"))) + residual_bigram_rank = int(os.environ.get("RESIDUAL_BIGRAM_RANK", "0")) + residual_trigram_rank = int(os.environ.get("RESIDUAL_TRIGRAM_RANK", "0")) + residual_ngram_lr = float(os.environ.get("RESIDUAL_NGRAM_LR", "0.04")) + residual_ngram_mix_init = float(os.environ.get("RESIDUAL_NGRAM_MIX_INIT", "-2.5")) + # Pointer-style local copy/cache head. + # P(next) = (1-gate) * P_model + gate * P_copy, where P_copy attends to recent context + # positions and copies their next-token targets into vocab space. + copy_cache_enabled = bool(int(os.environ.get("COPY_CACHE_ENABLED", "0"))) + copy_cache_window = int(os.environ.get("COPY_CACHE_WINDOW", "256")) + copy_cache_dim = int(os.environ.get("COPY_CACHE_DIM", "64")) + copy_cache_lr = float(os.environ.get("COPY_CACHE_LR", "0.02")) + copy_cache_gate_init = float(os.environ.get("COPY_CACHE_GATE_INIT", "-4.0")) + # Stochastic Weight Averaging: average weights during the warmdown phase. + # Takes the mean of snapshots every SWA_COLLECT_EVERY steps once LR starts decaying. + # Research-confirmed ~0.5-1.5% BPB improvement, especially helps quantization quality. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_collect_every = int(os.environ.get("SWA_COLLECT_EVERY", "10")) + # Optional train-side loss mask aligned to sliding-window eval. When enabled, only the + # suffix of each training chunk contributes loss, matching the eval metric more closely. + train_loss_mask_enabled = bool(int(os.environ.get("TRAIN_LOSS_MASK_ENABLED", "0"))) + # 0 = inherit EVAL_STRIDE_FRAC. + train_loss_mask_stride_frac = float(os.environ.get("TRAIN_LOSS_MASK_STRIDE_FRAC", "0.0")) + # Sequence length curriculum: ramp seq_len from curriculum_min_seq_len → train_seq_len + # over the first curriculum_steps training steps. Faster early convergence on local patterns. + curriculum_enabled = bool(int(os.environ.get("CURRICULUM_ENABLED", "0"))) + curriculum_min_seq_len = int(os.environ.get("CURRICULUM_MIN_SEQ_LEN", "256")) + curriculum_steps = int(os.environ.get("CURRICULUM_STEPS", "5000")) + # Multi-token prediction (MTP): auxiliary future-token losses used during training. + mtp_enabled = bool(int(os.environ.get("MTP_ENABLED", "0"))) + mtp_steps = int(os.environ.get("MTP_STEPS", "2")) + mtp_weight = float(os.environ.get("MTP_WEIGHT", "0.3")) + mtp_decay = float(os.environ.get("MTP_DECAY", "1.0")) + mtp_tie_embeddings = bool(int(os.environ.get("MTP_TIE_EMBEDDINGS", "1"))) + mtp_lr = float(os.environ.get("MTP_LR", "0.02")) + # On-the-fly distillation (EMA teacher) in the late training tail. + distill_enabled = bool(int(os.environ.get("DISTILL_ENABLED", "0"))) + distill_start_frac = float(os.environ.get("DISTILL_START_FRAC", "0.7")) + # Optional overrides for wallclock-capped runs. DISTILL_START_STEP wins over frac. + # DISTILL_START_WALLCLOCK_FRAC keys distillation off elapsed/max_wallclock instead of ITERATIONS. + distill_start_step = int(os.environ.get("DISTILL_START_STEP", "-1")) + distill_start_wallclock_frac = float(os.environ.get("DISTILL_START_WALLCLOCK_FRAC", "-1.0")) + distill_weight = float(os.environ.get("DISTILL_WEIGHT", "0.08")) + distill_temp = float(os.environ.get("DISTILL_TEMP", "2.0")) + distill_ema_decay = float(os.environ.get("DISTILL_EMA_DECAY", "0.999")) + # JPCR: JEPA Predictive Coding Recurrence. Replaces Ouroboros controllers with + # representation predictors trained via JEPA loss (MSE) against EMA teacher intermediates. + # Each predictor learns to predict the "ideal" hidden state at this depth, then blends + # that prediction into the recurrence input — transforming blind repetition into + # JEPA-guided iterative refinement. Progressive depth targeting: pass s of block i + # targets teacher's block (i+s) output, teaching the recurrence to "look ahead". + # At inference, predictors run as part of the model (no teacher needed). + jpcr_enabled = bool(int(os.environ.get("JPCR_ENABLED", "0"))) + jpcr_hidden = int(os.environ.get("JPCR_HIDDEN", "128")) # predictor MLP hidden dim + jpcr_proj_dim = int(os.environ.get("JPCR_PROJ_DIM", str(jpcr_hidden))) + jpcr_weight = float(os.environ.get("JPCR_WEIGHT", "0.1")) # JEPA MSE loss weight + jpcr_blend_init = float(os.environ.get("JPCR_BLEND_INIT", "-2.0")) # logit for sigmoid gate init (~0.12) + jpcr_lr = float(os.environ.get("JPCR_LR", "0.02")) # predictor learning rate + jpcr_warmup_steps = int(os.environ.get("JPCR_WARMUP_STEPS", "200")) # ramp JPCR loss weight over this many steps after activation + # Distillation/JPCR application cadence. 1 = apply every step. + # When >1, distill+JPCR are applied every Nth step (no stale-target reuse). + _jpcr_apply_every_env = os.environ.get("JPCR_APPLY_EVERY", os.environ.get("JPCR_TEACHER_EVERY", "1")) + jpcr_apply_every = max(1, int(_jpcr_apply_every_env)) + # Dual-head objective: auxiliary coarse-structure prediction head. + # Classes are derived from token properties (boundary/space/byte-length) and trained + # with a small coefficient so the main LM head can focus on harder entropy. + dual_head_enabled = bool(int(os.environ.get("DUAL_HEAD_ENABLED", "0"))) + dual_head_weight = float(os.environ.get("DUAL_HEAD_WEIGHT", "0.05")) + dual_head_start_frac = float(os.environ.get("DUAL_HEAD_START_FRAC", "0.0")) + dual_head_lr = float(os.environ.get("DUAL_HEAD_LR", "0.02")) + # Logit range regularization on pre-softcap logits for quantization robustness. + logit_reg_weight = float(os.environ.get("LOGIT_REG_WEIGHT", "0.0")) + # Sandwich norm: apply post-sublayer RMSNorm (before residual add) for each block. + # Controls residual stream norm growth; used by Gemma 2. + use_sandwich_norm = bool(int(os.environ.get("USE_SANDWICH_NORM", "0"))) + # Embedding scale: multiply token embeddings by sqrt(model_dim) after lookup. + # Aligns embedding magnitude with residual stream scale. Used by Gemma, T5, PaLM. + embed_scale = bool(int(os.environ.get("EMBED_SCALE", "0"))) + # Byte-weighted training loss (align objective closer to tokenizer-agnostic BPB). + byte_weighted_loss_enabled = bool(int(os.environ.get("BYTE_WEIGHTED_LOSS_ENABLED", "0"))) + byte_weighted_loss_alpha = float(os.environ.get("BYTE_WEIGHTED_LOSS_ALPHA", "1.0")) + # Hybrid SSM blocks: periodically replace attention blocks with a mixer. + # In this experiment file the default is official CUDA-backed Mamba-3. + use_ssm = bool(int(os.environ.get("USE_SSM", "0"))) + ssm_every_n = int(os.environ.get("SSM_EVERY_N", "2")) + ssm_expand = float(os.environ.get("SSM_EXPAND", "2.0")) + ssm_kernel = int(os.environ.get("SSM_KERNEL", "4")) + ssm_impl = os.environ.get("SSM_IMPL", "mamba3").strip().lower() + mamba3_d_state = int(os.environ.get("MAMBA3_D_STATE", "128")) + # 0 = auto-pick a divisor of MODEL_DIM near 64. + mamba3_head_dim = int(os.environ.get("MAMBA3_HEAD_DIM", "0")) + mamba3_is_mimo = bool(int(os.environ.get("MAMBA3_IS_MIMO", "1"))) + mamba3_mimo_rank = int(os.environ.get("MAMBA3_MIMO_RANK", "4")) + mamba3_chunk_size = int(os.environ.get("MAMBA3_CHUNK_SIZE", "16")) + mamba3_outproj_norm = bool(int(os.environ.get("MAMBA3_OUTPROJ_NORM", "0"))) + # Quantization-Aware Training: fake-quantise weights during forward to teach the model + # to tolerate quantisation noise, dramatically reducing the roundtrip BPB penalty. + # QAT_SCHEME: "none" | "int8" | "int5" | "int4" (should match QUANT_SCHEME at export) + # QAT_START_STEP/QAT_END_STEP: step-based QAT schedule. + # QAT_START_WALLCLOCK_FRAC/QAT_END_WALLCLOCK_FRAC: optional wallclock-based + # schedule for capped runs; when start frac is >= 0 and max wallclock is set, + # it wins over the step schedule. + qat_scheme = os.environ.get("QAT_SCHEME", "none").strip().lower() + qat_start_step = int(os.environ.get("QAT_START_STEP", "9000")) + qat_end_step = int(os.environ.get("QAT_END_STEP", "0")) + qat_start_wallclock_frac = float(os.environ.get("QAT_START_WALLCLOCK_FRAC", "-1.0")) + qat_end_wallclock_frac = float(os.environ.get("QAT_END_WALLCLOCK_FRAC", "1.0")) + # QAT_LSQ=1 enables Learned Step-Size Quantization: per-row learnable log-scale + # replaces the max-abs scale in fake-quant, reducing int4 roundtrip penalty by + # letting the model optimise the clip threshold per output row via backprop (STE). + qat_lsq = bool(int(os.environ.get("QAT_LSQ", "0"))) + + # GPTQ post-training quantization (replaces naive round-to-nearest at export). + gptq_enabled = bool(int(os.environ.get("GPTQ", "1"))) + gptq_nsamples = int(os.environ.get("GPTQ_NSAMPLES", "128")) + gptq_blocksize = int(os.environ.get("GPTQ_BLOCKSIZE", "128")) + gptq_percdamp = float(os.environ.get("GPTQ_PERCDAMP", "0.01")) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + recurrent_core_layers = int(os.environ.get("RECURRENT_CORE_LAYERS", 0)) + recurrent_steps = int(os.environ.get("RECURRENT_STEPS", 0)) + share_ffn_across_blocks = bool(int(os.environ.get("SHARE_FFN_ACROSS_BLOCKS", "0"))) + # Intra-layer recurrence: run layers [intra_loop_start..intra_loop_end] intra_loop_steps times. + # All blocks remain unique (no weight sharing), so parameter count is unchanged. + # Research (arXiv:2505.01855) shows front-loading repetitions on early layers maximises BPB gain. + # Example: INTRA_LOOP_START=0 INTRA_LOOP_END=2 INTRA_LOOP_STEPS=3 on a 9L model gives + # effective depth 9 + 2*3 = 15 with zero extra parameters. + intra_loop_start = int(os.environ.get("INTRA_LOOP_START", "3")) # -1 = disabled + intra_loop_end = int(os.environ.get("INTRA_LOOP_END", "4")) + intra_loop_steps = int(os.environ.get("INTRA_LOOP_STEPS", "2")) + # Parallel residuals: attn and MLP read same pre-norm input, outputs summed. + # One norm per block instead of two; improved gradient flow. Leaderboard PR #1477. + use_parallel_residual = bool(int(os.environ.get("PARALLEL_RESIDUAL", "0"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + # Mixture of Experts (MoE): replace dense MLPs with sparse expert routing. + # MOE_NUM_EXPERTS=0 → disabled (dense MLP as usual) + # MOE_NUM_EXPERTS=2 → 2 experts per MoE layer, Expert Choice routing + # MOE_EVERY_N=1 → all layers are MoE; =2 → alternating (even layers); =3 → every 3rd + # MOE_CAPACITY_FACTOR: each expert sees int(cf * S / E) tokens (1.0 = perfect balance) + # MOE_AUX_LOSS_COEFF: weight on router Z-loss (stabilises routing, prevents collapse) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", "0")) + moe_every_n = int(os.environ.get("MOE_EVERY_N", "2")) + moe_capacity_factor = float(os.environ.get("MOE_CAPACITY_FACTOR", "1.0")) + moe_aux_loss_coeff = float(os.environ.get("MOE_AUX_LOSS_COEFF", "1e-3")) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + # Decoupled softcap for the ngram residual branch (0 = inherit LOGIT_SOFTCAP). + # Letting the ngram branch push harder than the neural head often helps when the + # residual ngram is well-trained (small but sharp tables). + ngram_softcap = float(os.environ.get("NGRAM_SOFTCAP", "0.0")) + # Entropy-conditioned ngram gate: gate also sees a confidence signal (lse - max logit, + # a cheap proxy for -log max_prob of the neural head) so ngram can dominate when the + # neural model is unsure. Adds one scalar input per gate. + ngram_entropy_gate = bool(int(os.environ.get("NGRAM_ENTROPY_GATE", "0"))) + # Test-time training (competition-compliant): after scoring each eval batch, take one + # SGD step on the scored positions' CE loss. Only ngram/gate/scale params update; the + # base transformer is frozen. Params are snapshotted before eval and restored after, + # so intermediate val checkpoints are unaffected. Only activated in the final eval + # suite. Default off so existing runs are bit-identical. + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", "1e-3")) + ttt_steps = int(os.environ.get("TTT_STEPS", "1")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + # Export / compression controls. + quant_scheme = os.environ.get("QUANT_SCHEME", "int8").strip().lower() + compressor = os.environ.get("COMPRESSOR", "zlib").strip().lower() + compress_level = int(os.environ.get("COMPRESS_LEVEL", "-1")) + weight_order = os.environ.get("WEIGHT_ORDER", "none").strip().lower() + mixed_low_precision_scheme = os.environ.get("MIXED_LOW_PRECISION_SCHEME", "int8").strip().lower() + # If 0, skip the post-quantization roundtrip eval pass (saves one full val sweep). + final_roundtrip_eval = bool( + int(os.environ.get("FINAL_ROUNDTRIP_EVAL", os.environ.get("FINAL_INT8_ROUNDTRIP_EVAL", "1"))) + ) + final_int8_roundtrip_eval = final_roundtrip_eval + submission_size_budget_bytes = int(os.environ.get("SUBMISSION_SIZE_BUDGET_BYTES", str(16 * 1024 * 1024))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.to(dtype=torch.bfloat16 if G.is_cuda else torch.float32) + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_dtype = torch.bfloat16 if params[0].device.type == "cuda" else torch.float32 + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=updates_dtype) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + # MuonEq-R: row equilibration before Newton-Schulz + # (removes marginal row-scale mismatch, arxiv 2603.28254) + if g.ndim == 2: + g = g / g.norm(dim=1, keepdim=True).clamp(min=1e-8) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def parse_csv_ints(raw: str) -> list[int]: + values: list[int] = [] + for part in raw.split(","): + item = part.strip() + if item: + values.append(int(item)) + return values + + +def parse_csv_floats(raw: str) -> list[float]: + values: list[float] = [] + for part in raw.split(","): + item = part.strip() + if item: + values.append(float(item)) + return values + + +def default_eval_rope_scale(seq_len: int, train_seq_len: int) -> float: + if seq_len == train_seq_len: + return 1.0 + return float(seq_len / train_seq_len) ** 2 + + +def resolve_seq_len(raw_seq_len: int, train_seq_len: int) -> int: + return train_seq_len if raw_seq_len <= 0 else raw_seq_len + + +def resolve_stride(seq_len: int, stride_frac: float) -> int: + frac = stride_frac if stride_frac > 0.0 else 1.0 + return max(1, min(seq_len, int(seq_len * frac))) + + +def build_loss_mask_cpu(seq_len: int, stride_frac: float) -> tuple[Tensor, int, int]: + stride = resolve_stride(seq_len, stride_frac) + prefix_len = seq_len - stride + loss_mask_cpu = torch.zeros(seq_len, dtype=torch.float32) + loss_mask_cpu[prefix_len:] = 1.0 + return loss_mask_cpu, prefix_len, stride + + +def format_float_tag(value: float) -> str: + text = f"{value:.4f}".rstrip("0").rstrip(".") + return text.replace("-", "m").replace(".", "p") if text else "0" + + +def make_eval_spec_name(seq_len: int, rope_scale: float) -> str: + return f"seq{seq_len}_rope{format_float_tag(rope_scale)}" + + +def resolve_primary_eval_spec(args: Hyperparameters) -> tuple[str, int, float]: + seq_len = resolve_seq_len(args.eval_seq_len, args.train_seq_len) + rope_scale = float(args.eval_rope_scale) + return "primary", seq_len, rope_scale + + +def resolve_eval_sweep_specs(args: Hyperparameters) -> list[tuple[str, int, float]]: + specs: list[tuple[str, int, float]] = [] + seen: set[tuple[int, int]] = set() + + def add_spec(name: str, seq_len: int, rope_scale: float) -> None: + key = (seq_len, int(round(rope_scale * 1_000_000))) + if key in seen: + return + seen.add(key) + specs.append((name, seq_len, rope_scale)) + + primary_name, primary_seq_len, primary_rope_scale = resolve_primary_eval_spec(args) + add_spec(primary_name, primary_seq_len, primary_rope_scale) + + sweep_seq_lens = parse_csv_ints(args.eval_sweep_seq_lens) + sweep_rope_scales = parse_csv_floats(args.eval_sweep_rope_scales) + if sweep_rope_scales and len(sweep_rope_scales) != len(sweep_seq_lens): + raise ValueError( + "EVAL_SWEEP_ROPE_SCALES must have the same number of entries as EVAL_SWEEP_SEQ_LENS" + ) + for idx, raw_seq_len in enumerate(sweep_seq_lens): + seq_len = resolve_seq_len(raw_seq_len, args.train_seq_len) + rope_scale = ( + sweep_rope_scales[idx] + if sweep_rope_scales + else default_eval_rope_scale(seq_len, args.train_seq_len) + ) + add_spec(make_eval_spec_name(seq_len, rope_scale), seq_len, float(rope_scale)) + return specs + + +def resolve_eval_blend_specs(args: Hyperparameters) -> tuple[list[tuple[str, int, float]], list[float]]: + blend_seq_lens = parse_csv_ints(args.eval_blend_seq_lens) + if not blend_seq_lens: + return [], [] + blend_rope_scales = parse_csv_floats(args.eval_blend_rope_scales) + if blend_rope_scales and len(blend_rope_scales) != len(blend_seq_lens): + raise ValueError( + "EVAL_BLEND_ROPE_SCALES must have the same number of entries as EVAL_BLEND_SEQ_LENS" + ) + blend_weights = parse_csv_floats(args.eval_blend_weights) + if blend_weights and len(blend_weights) != len(blend_seq_lens): + raise ValueError( + "EVAL_BLEND_WEIGHTS must have the same number of entries as EVAL_BLEND_SEQ_LENS" + ) + + specs: list[tuple[str, int, float]] = [] + for idx, raw_seq_len in enumerate(blend_seq_lens): + seq_len = resolve_seq_len(raw_seq_len, args.train_seq_len) + rope_scale = ( + blend_rope_scales[idx] + if blend_rope_scales + else default_eval_rope_scale(seq_len, args.train_seq_len) + ) + specs.append((make_eval_spec_name(seq_len, float(rope_scale)), seq_len, float(rope_scale))) + + if not blend_weights: + blend_weights = [1.0] * len(specs) + total_weight = sum(blend_weights) + if total_weight <= 0.0: + raise ValueError("EVAL_BLEND_WEIGHTS must sum to a positive value") + normalized = [w / total_weight for w in blend_weights] + return specs, normalized + + +def resolve_max_eval_seq_len( + args: Hyperparameters, + sweep_specs: list[tuple[str, int, float]], + blend_specs: list[tuple[str, int, float]], +) -> int: + max_seq_len = args.train_seq_len + for _, seq_len, _ in sweep_specs: + max_seq_len = max(max_seq_len, seq_len) + for _, seq_len, _ in blend_specs: + max_seq_len = max(max_seq_len, seq_len) + return max_seq_len + + +def resolve_train_loss_mask_stride_frac(args: Hyperparameters) -> float: + return args.train_loss_mask_stride_frac if args.train_loss_mask_stride_frac > 0.0 else args.eval_stride_frac + + +def resolve_distill_start_step(args: Hyperparameters) -> int: + if args.distill_start_step >= 0: + return args.distill_start_step + if args.distill_start_frac < 0.0: + return args.iterations + 1 # Never trigger via fraction if negative + return int(max(0.0, min(1.0, args.distill_start_frac)) * args.iterations) + + +def distill_is_active( + args: Hyperparameters, + step: int, + elapsed_ms: float, + max_wallclock_ms: float | None, + distill_start_step: int, +) -> bool: + if args.distill_start_step >= 0: + return step >= args.distill_start_step + if args.distill_start_wallclock_frac >= 0.0 and max_wallclock_ms is not None and max_wallclock_ms > 0.0: + start_frac = max(0.0, min(1.0, args.distill_start_wallclock_frac)) + return elapsed_ms >= start_frac * max_wallclock_ms + return step >= distill_start_step + + +def qat_target_levels( + args: Hyperparameters, + step: int, + elapsed_ms: float, + max_wallclock_ms: float | None, +) -> tuple[int, str]: + if args.qat_scheme == "none": + return 0, "off" + + use_wallclock = ( + args.qat_start_wallclock_frac >= 0.0 + and max_wallclock_ms is not None + and max_wallclock_ms > 0.0 + ) + if use_wallclock: + start_frac = max(0.0, min(1.0, args.qat_start_wallclock_frac)) + end_frac = max(start_frac + 1e-6, min(1.0, args.qat_end_wallclock_frac)) + start_pos = start_frac * max_wallclock_ms + end_pos = end_frac * max_wallclock_ms + current_pos = elapsed_ms + mode = f"wallclock_frac:{start_frac:.4f}->{end_frac:.4f}" + else: + start_pos = float(args.qat_start_step) + end_step = args.qat_end_step if args.qat_end_step > args.qat_start_step else args.iterations + end_pos = float(end_step) + current_pos = float(step) + mode = f"step:{args.qat_start_step}->{int(end_pos)}" + + if current_pos < start_pos: + return 0, mode + if args.qat_scheme == "int8": + return 256, mode + + frac = (current_pos - start_pos) / max(end_pos - start_pos, 1.0) + frac = max(0.0, min(1.0, frac)) + if args.qat_scheme == "int5": + return (256 if frac < 0.33 else (64 if frac < 0.67 else 32)), mode + return (256 if frac < 0.33 else (64 if frac < 0.67 else 16)), mode + + +def build_blend_position_log_weights( + args: Hyperparameters, + blend_specs: list[tuple[str, int, float]], + blend_weights: list[float], + blend_stride: int, + device: torch.device, +) -> Tensor: + base_log_weights = torch.log(torch.tensor(blend_weights, device=device, dtype=torch.float32).clamp_min(1e-12)) + if args.eval_blend_position_bias == 0.0 or len(blend_specs) <= 1: + return base_log_weights[:, None].expand(-1, blend_stride) + + seq_lens = torch.tensor([seq_len for _, seq_len, _ in blend_specs], device=device, dtype=torch.float32) + centered = seq_lens - seq_lens.mean() + centered = centered / centered.abs().max().clamp_min(1e-6) + pos = torch.linspace(0.0, 1.0, steps=blend_stride, device=device, dtype=torch.float32) + signed_pos = 2.0 * pos - 1.0 + power = max(float(args.eval_blend_position_power), 1e-6) + if power != 1.0: + signed_pos = signed_pos.sign() * signed_pos.abs().pow(power) + logits = base_log_weights[:, None] + float(args.eval_blend_position_bias) * centered[:, None] * signed_pos[None, :] + return F.log_softmax(logits, dim=0) + + +def apply_eval_continuous_cache( + args: Hyperparameters, + scored_log_probs: Tensor, + scored_hidden: Tensor, + scored_targets: Tensor, + cache_state: tuple[Tensor, Tensor] | None, +) -> tuple[Tensor, tuple[Tensor, Tensor] | None]: + if not args.eval_cont_cache_enabled: + return scored_log_probs, cache_state + + flat_log_probs = scored_log_probs.reshape(-1, scored_log_probs.size(-1)).float() + flat_hidden = F.normalize(scored_hidden.reshape(-1, scored_hidden.size(-1)).float(), dim=-1) + flat_targets = scored_targets.reshape(-1).to(dtype=torch.int64) + mixed_log_probs = flat_log_probs + + if cache_state is not None and cache_state[0].numel() > 0: + cache_keys, cache_values = cache_state + scores = torch.matmul(flat_hidden, cache_keys.transpose(0, 1)) * float(args.eval_cont_cache_logit_scale) + topk = min(max(int(args.eval_cont_cache_topk), 0), cache_keys.size(0)) + if topk > 0 and topk < cache_keys.size(0): + scores, top_idx = torch.topk(scores, k=topk, dim=-1) + retrieved_ids = cache_values[top_idx] + else: + retrieved_ids = cache_values.unsqueeze(0).expand(scores.size(0), -1) + attn = F.softmax(scores, dim=-1) + cache_probs = torch.zeros_like(mixed_log_probs) + cache_probs.scatter_add_(1, retrieved_ids, attn) + cache_log_probs = torch.log(cache_probs.clamp_min(1e-9)) + mix = torch.full( + (mixed_log_probs.size(0),), + float(args.eval_cont_cache_weight), + device=mixed_log_probs.device, + dtype=torch.float32, + ) + if args.eval_cont_cache_conf_power >= 0.0: + cache_conf = cache_probs.max(dim=-1).values.clamp_(0.0, 1.0) + mix = mix * cache_conf.pow(float(args.eval_cont_cache_conf_power)) + mix = mix.clamp(min=1e-5, max=1.0 - 1e-5) + mixed_log_probs = torch.logaddexp( + torch.log1p(-mix).unsqueeze(-1) + mixed_log_probs, + torch.log(mix).unsqueeze(-1) + cache_log_probs, + ) + + window = max(1, int(args.eval_cont_cache_window)) + new_keys = flat_hidden.detach()[-window:] + new_values = flat_targets.detach()[-window:] + if cache_state is None or cache_state[0].numel() == 0: + updated_state = (new_keys, new_values) + else: + cache_keys, cache_values = cache_state + cache_keys = torch.cat((cache_keys, new_keys), dim=0) + cache_values = torch.cat((cache_values, new_values), dim=0) + if cache_keys.size(0) > window: + cache_keys = cache_keys[-window:] + cache_values = cache_values[-window:] + updated_state = (cache_keys.detach(), cache_values.detach()) + return mixed_log_probs.reshape_as(scored_log_probs).to(dtype=scored_log_probs.dtype), updated_state + + +def get_eval_model(model: nn.Module) -> nn.Module: + raw_model = model.module if hasattr(model, "module") else model + if hasattr(raw_model, "forward_hidden_and_output"): + return raw_model + if hasattr(raw_model, "_orig_mod") and hasattr(raw_model._orig_mod, "forward_hidden_and_output"): + return raw_model._orig_mod + if hasattr(raw_model, "forward_logits"): + return raw_model + if hasattr(raw_model, "_orig_mod") and hasattr(raw_model._orig_mod, "forward_logits"): + return raw_model._orig_mod + raise AttributeError("Could not find a forward_logits-capable model for evaluation") + + +TTT_PARAM_NAME_MATCH = ( + "residual_bigram_", + "residual_trigram_", + "residual_ngram_", + "bigram_left", + "bigram_right", + "bigram_scale", + "copy_gate", +) + + +def collect_ttt_params(raw_model: nn.Module) -> list[tuple[str, nn.Parameter]]: + # Keep TTT scoped to the small adaptive heads/tables. Residual n-gram + # predictors are named residual_bigram_* / residual_trigram_*, not only + # residual_ngram_*, so include all of those prefixes. + params: list[tuple[str, nn.Parameter]] = [] + for name, p in raw_model.named_parameters(): + leaf = name.rsplit(".", 1)[-1] + if any(name.startswith(pref) or leaf.startswith(pref) for pref in TTT_PARAM_NAME_MATCH): + params.append((name, p)) + return params + + +def apply_eval_rope_scaling( + model: nn.Module, + args: Hyperparameters, + seq_len: int, + rope_scale: float, +) -> list[tuple[object, Tensor]]: + if rope_scale == 1.0 and seq_len == args.train_seq_len: + return [] + head_dim = args.model_dim // args.num_heads + ntk_factor = rope_scale ** (head_dim / max(head_dim - 2, 1)) + raw_model = get_eval_model(model) + if not hasattr(raw_model, "blocks"): + return [] + orig_rope_bases: list[tuple[object, Tensor]] = [] + for block in raw_model.blocks: + attn = getattr(block, "attn", None) + rot = getattr(attn, "rotary", None) + if rot is None: + continue + orig_rope_bases.append((rot, rot.inv_freq.clone())) + new_base = args.rope_base * ntk_factor + new_inv_freq = 1.0 / ( + new_base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=rot.inv_freq.device) / head_dim) + ) + rot.inv_freq = new_inv_freq + rot._cos_cached = None + return orig_rope_bases + + +def restore_eval_rope_scaling(orig_rope_bases: list[tuple[object, Tensor]]) -> None: + for rot, orig_inv_freq in orig_rope_bases: + rot.inv_freq = orig_inv_freq + rot._cos_cached = None + + +def forward_eval_outputs( + args: Hyperparameters, + model: nn.Module, + x: Tensor, + seq_len: int, + rope_scale: float, + autocast_enabled: bool, +) -> tuple[Tensor, Tensor]: + eval_model = get_eval_model(model) + orig_rope_bases = apply_eval_rope_scaling(model, args, seq_len, rope_scale) + try: + jpcr_runtime_active = bool(getattr(eval_model, "jpcr_enabled", False)) + if autocast_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + hidden, logits, logits_are_log_probs = eval_model.forward_hidden_and_output( + x, jpcr_runtime_active=jpcr_runtime_active + ) + else: + hidden, logits, logits_are_log_probs = eval_model.forward_hidden_and_output( + x, jpcr_runtime_active=jpcr_runtime_active + ) + finally: + restore_eval_rope_scaling(orig_rope_bases) + log_probs = logits.float().reshape(x.size(0), x.size(1), -1) + if not logits_are_log_probs: + log_probs = F.log_softmax(log_probs, dim=-1) + return log_probs, hidden.float() + + +def eval_val_single( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + rope_scale: float, + stride_frac: float, + ttt_enabled: bool = False, + ttt_lr: float = 0.0, + ttt_steps: int = 1, + ttt_momentum: float = 0.9, +) -> tuple[float, float]: + _, prefix_len, stride = build_loss_mask_cpu(seq_len, stride_frac) + if args.eval_cont_cache_enabled and world_size != 1: + raise ValueError("EVAL_CONT_CACHE_ENABLED currently requires WORLD_SIZE=1 for deterministic eval order") + + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // seq_len) + if args.eval_cont_cache_enabled: + local_batch_seqs = min(local_batch_seqs, max(1, args.eval_cont_cache_batch_seqs)) + total_wins = max(1, (val_tokens.numel() - seq_len - 1) // stride) + win_start = (total_wins * rank) // world_size + win_end = (total_wins * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # --- TTT setup (competition-compliant online update) ----------------------------- + # We snapshot the chosen param subset before eval starts, do SGD steps after each + # scored batch, then restore the snapshot before returning. This keeps the stored + # model state untouched so subsequent eval passes / quantization see clean weights. + ttt_active = bool(ttt_enabled) and float(ttt_lr) > 0.0 + ttt_params: list[tuple[str, nn.Parameter]] = [] + ttt_snapshots: list[Tensor] = [] + ttt_prev_requires_grad: dict[int, bool] = {} + ttt_optim: torch.optim.Optimizer | None = None + raw_model = get_eval_model(model) if ttt_active else None + if ttt_active and raw_model is not None: + # Scope: ngram + pointer-gate + small learned scales. Base transformer stays frozen. + ttt_params = collect_ttt_params(raw_model) + ttt_prev_requires_grad = {id(p): p.requires_grad for p in raw_model.parameters()} + for p in raw_model.parameters(): + p.requires_grad_(False) + for _, p in ttt_params: + p.requires_grad_(True) + ttt_snapshots.append(p.detach().clone()) + if ttt_params: + ttt_optim = torch.optim.SGD( + [p for _, p in ttt_params], lr=float(ttt_lr), momentum=float(ttt_momentum) + ) + else: + ttt_active = False # nothing to update + # --------------------------------------------------------------------------------- + + model.eval() + cache_state: tuple[Tensor, Tensor] | None = None + + eval_ctx = torch.enable_grad() if ttt_active else torch.inference_mode() + with eval_ctx: + for batch_win_start in range(win_start, win_end, local_batch_seqs): + batch_win_end = min(batch_win_start + local_batch_seqs, win_end) + xs, ys = [], [] + for w in range(batch_win_start, batch_win_end): + s = w * stride + xs.append(val_tokens[s : s + seq_len]) + ys.append(val_tokens[s + 1 : s + seq_len + 1]) + x = torch.stack(xs).to(device=device, dtype=torch.int64, non_blocking=True) + y = torch.stack(ys).to(device=device, dtype=torch.int64, non_blocking=True) + log_probs, hidden = forward_eval_outputs(args, model, x, seq_len, rope_scale, autocast_enabled) + scored_log_probs = log_probs[:, prefix_len:, :] + scored_hidden = hidden[:, prefix_len:, :] + scored_targets = y[:, prefix_len:] + scored_log_probs, cache_state = apply_eval_continuous_cache( + args, + scored_log_probs, + scored_hidden, + scored_targets, + cache_state, + ) + target_log_probs = scored_log_probs.gather(-1, scored_targets.unsqueeze(-1)).squeeze(-1) + + # Accumulate BPB stats (always detached from the TTT graph). + tlp_detached = target_log_probs.detach() + val_loss_sum += (-tlp_detached).sum(dtype=torch.float64) + val_token_count += tlp_detached.numel() + + prev_ids = x[:, prefix_len:].reshape(-1) + tgt_ids = scored_targets.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # TTT update: CE on the scored suffix. This is competition-compliant because + # the update happens AFTER emitting the BPB for this batch, and only uses + # tokens whose predictions are already recorded (online learning). + if ttt_active and ttt_optim is not None: + ttt_loss = -target_log_probs.mean() + ttt_loss.backward() + ttt_optim.step() + ttt_optim.zero_grad(set_to_none=True) + for _ in range(max(0, int(ttt_steps) - 1)): + # Additional steps re-run forward on the same batch. Kept behind + # an explicit env knob; default TTT_STEPS=1 skips this branch. + log_probs2, _h2 = forward_eval_outputs(args, model, x, seq_len, rope_scale, autocast_enabled) + slp2 = log_probs2[:, prefix_len:, :] + tlp2 = slp2.gather(-1, scored_targets.unsqueeze(-1)).squeeze(-1) + (-tlp2.mean()).backward() + ttt_optim.step() + ttt_optim.zero_grad(set_to_none=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + # Restore TTT param snapshots and prior requires_grad flags so the underlying + # model is bitwise unchanged after this function returns. + if ttt_active and raw_model is not None: + with torch.no_grad(): + for (_, p), snap in zip(ttt_params, ttt_snapshots): + p.data.copy_(snap) + for p in raw_model.parameters(): + p.requires_grad_(ttt_prev_requires_grad.get(id(p), False)) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_blend( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + blend_specs: list[tuple[str, int, float]], + blend_weights: list[float], +) -> tuple[float, float]: + if not blend_specs: + raise ValueError("eval_val_blend requires at least one blend spec") + if args.eval_cont_cache_enabled and world_size != 1: + raise ValueError("EVAL_CONT_CACHE_ENABLED currently requires WORLD_SIZE=1 for deterministic eval order") + + blend_stride_frac = args.eval_blend_stride_frac if args.eval_blend_stride_frac > 0.0 else args.eval_stride_frac + min_seq_len = min(seq_len for _, seq_len, _ in blend_specs) + max_seq_len = max(seq_len for _, seq_len, _ in blend_specs) + blend_stride = resolve_stride(min_seq_len, blend_stride_frac) + max_prefix_len = max(seq_len - blend_stride for _, seq_len, _ in blend_specs) + first_target_pos = max_prefix_len + 1 + max_target_start = val_tokens.numel() - blend_stride + if max_target_start < first_target_pos: + raise ValueError( + f"Validation split is too short for blend eval: first_target_pos={first_target_pos}, " + f"max_target_start={max_target_start}" + ) + + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_chunks = max(1, local_batch_tokens // max(max_seq_len * len(blend_specs), 1)) + if args.eval_cont_cache_enabled: + local_batch_chunks = min(local_batch_chunks, max(1, args.eval_cont_cache_batch_seqs)) + total_chunks = ((max_target_start - first_target_pos) // blend_stride) + 1 + chunk_start = (total_chunks * rank) // world_size + chunk_end = (total_chunks * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + cache_states: list[tuple[Tensor, Tensor] | None] = [None] * len(blend_specs) + with torch.inference_mode(): + for batch_chunk_start in range(chunk_start, chunk_end, local_batch_chunks): + batch_chunk_end = min(batch_chunk_start + local_batch_chunks, chunk_end) + target_starts = [first_target_pos + idx * blend_stride for idx in range(batch_chunk_start, batch_chunk_end)] + pos_log_weights = build_blend_position_log_weights( + args, + blend_specs, + blend_weights, + blend_stride, + device, + ) + + common_prev_ids = torch.stack( + [val_tokens[target_pos - 1 : target_pos + blend_stride - 1] for target_pos in target_starts] + ).to(device=device, dtype=torch.int64, non_blocking=True) + common_target_ids = torch.stack( + [val_tokens[target_pos : target_pos + blend_stride] for target_pos in target_starts] + ).to(device=device, dtype=torch.int64, non_blocking=True) + + blend_log_probs: Tensor | None = None + for spec_idx, (spec_name, seq_len, rope_scale) in enumerate(blend_specs): + del spec_name + prefix_len = seq_len - blend_stride + xs = [] + for target_pos in target_starts: + s = target_pos - prefix_len - 1 + xs.append(val_tokens[s : s + seq_len]) + x = torch.stack(xs).to(device=device, dtype=torch.int64, non_blocking=True) + log_probs, hidden = forward_eval_outputs(args, model, x, seq_len, rope_scale, autocast_enabled) + scored_log_probs = log_probs[:, prefix_len:, :] + scored_hidden = hidden[:, prefix_len:, :] + scored_log_probs, cache_states[spec_idx] = apply_eval_continuous_cache( + args, + scored_log_probs, + scored_hidden, + common_target_ids, + cache_states[spec_idx], + ) + weighted_log_probs = scored_log_probs + pos_log_weights[spec_idx][None, :, None] + blend_log_probs = ( + weighted_log_probs + if blend_log_probs is None + else torch.logaddexp(blend_log_probs, weighted_log_probs) + ) + + if blend_log_probs is None: + raise RuntimeError("blend_log_probs should have been populated") + target_log_probs = blend_log_probs.gather(-1, common_target_ids.unsqueeze(-1)).squeeze(-1) + val_loss_sum += (-target_log_probs).sum(dtype=torch.float64) + val_token_count += target_log_probs.numel() + + prev_ids = common_prev_ids.reshape(-1) + tgt_ids = common_target_ids.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + _, seq_len, rope_scale = resolve_primary_eval_spec(args) + return eval_val_single( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + seq_len, + rope_scale, + args.eval_stride_frac, + ) + + +def run_final_eval_suite( + args: Hyperparameters, + roundtrip_tag: str, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + sweep_specs: list[tuple[str, int, float]], + blend_specs: list[tuple[str, int, float]], + blend_weights: list[float], + log0, +) -> tuple[float, float]: + primary_name, primary_seq_len, primary_rope_scale = resolve_primary_eval_spec(args) + ttt_param_count = 0 + if args.ttt_enabled and args.ttt_lr > 0.0: + try: + ttt_param_count = len(collect_ttt_params(get_eval_model(model))) + except AttributeError: + ttt_param_count = 0 + ttt_effective = bool(args.ttt_enabled and args.ttt_lr > 0.0 and ttt_param_count > 0) + primary_val_loss, primary_val_bpb = eval_val_single( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + primary_seq_len, + primary_rope_scale, + args.eval_stride_frac, + ttt_enabled=ttt_effective, + ttt_lr=args.ttt_lr, + ttt_steps=args.ttt_steps, + ttt_momentum=args.ttt_momentum, + ) + log0( + f"{roundtrip_tag}_ctx_exact name:{primary_name} seq_len:{primary_seq_len} " + f"rope_scale:{primary_rope_scale:.4f} stride_frac:{args.eval_stride_frac:.4f} " + f"ttt:{1 if ttt_effective else 0} ttt_params:{ttt_param_count} " + f"ttt_lr:{args.ttt_lr} ttt_steps:{args.ttt_steps} " + f"val_loss:{primary_val_loss:.8f} val_bpb:{primary_val_bpb:.8f}" + ) + + for sweep_name, sweep_seq_len, sweep_rope_scale in sweep_specs[1:]: + sweep_val_loss, sweep_val_bpb = eval_val_single( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + sweep_seq_len, + sweep_rope_scale, + args.eval_stride_frac, + ) + log0( + f"{roundtrip_tag}_ctx_exact name:{sweep_name} seq_len:{sweep_seq_len} " + f"rope_scale:{sweep_rope_scale:.4f} stride_frac:{args.eval_stride_frac:.4f} " + f"val_loss:{sweep_val_loss:.8f} val_bpb:{sweep_val_bpb:.8f}" + ) + + blend_result: tuple[float, float] | None = None + if blend_specs: + blend_stride_frac = args.eval_blend_stride_frac if args.eval_blend_stride_frac > 0.0 else args.eval_stride_frac + blend_val_loss, blend_val_bpb = eval_val_blend( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + blend_specs, + blend_weights, + ) + blend_specs_log = ",".join( + f"{name}:{seq_len}@{rope_scale:.4f}" + for name, seq_len, rope_scale in blend_specs + ) + blend_weights_log = ",".join(f"{weight:.6f}" for weight in blend_weights) + log0( + f"{roundtrip_tag}_blend_exact stride_frac:{blend_stride_frac:.4f} specs:{blend_specs_log} " + f"weights:{blend_weights_log} position_bias:{args.eval_blend_position_bias:.4f} " + f"position_power:{args.eval_blend_position_power:.4f} " + f"val_loss:{blend_val_loss:.8f} val_bpb:{blend_val_bpb:.8f}" + ) + blend_result = (blend_val_loss, blend_val_bpb) + + if args.final_eval_mode == "primary": + return primary_val_loss, primary_val_bpb + if args.final_eval_mode == "blend": + if blend_result is None: + raise ValueError("FINAL_EVAL_MODE=blend requires EVAL_BLEND_SEQ_LENS to be set") + return blend_result + raise ValueError(f"Unsupported FINAL_EVAL_MODE={args.final_eval_mode!r}; expected 'primary' or 'blend'") + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +QUANT_SCALE_EPS = float(os.environ.get("QUANT_SCALE_EPS", "1e-8")) +INT4_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT4_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT4_KEEP_FLOAT_MAX_NUMEL = int(os.environ.get("INT4_KEEP_FLOAT_MAX_NUMEL", 65_536)) +INT4_PER_ROW_SCALE_DTYPE = torch.float16 +INT4_CLIP_PERCENTILE = float(os.environ.get("INT4_CLIP_PERCENTILE", 99.995)) +INT4_CLIP_Q = INT4_CLIP_PERCENTILE / 100.0 +INT4_GROUP_SIZE = int(os.environ.get("INT4_GROUP_SIZE", "128")) # 0 = per-row (legacy) +INT5_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT5_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT5_KEEP_FLOAT_MAX_NUMEL = int(os.environ.get("INT5_KEEP_FLOAT_MAX_NUMEL", 65_536)) +INT5_PER_ROW_SCALE_DTYPE = torch.float16 +INT5_CLIP_PERCENTILE = float(os.environ.get("INT5_CLIP_PERCENTILE", 99.997)) +INT5_CLIP_Q = INT5_CLIP_PERCENTILE / 100.0 + +# NF4 lookup table: 16 quantiles of N(0,1), information-theoretically optimal for normal weights. +# Index 0..15 maps to these fixed float values. Quantize: find nearest, store index. +NF4_ENABLED = bool(int(os.environ.get("NF4_ENABLED", "1"))) +NF4_LUT = torch.tensor([ + -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0, + 0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0, +], dtype=torch.float32) +MIXED_KEEP_FLOAT_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "MIXED_KEEP_FLOAT_NAME_PATTERNS", + "tok_emb,lm_head,final_norm,norm," + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +MIXED_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "MIXED_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +MIXED_KEEP_FLOAT_MAX_NUMEL = int(os.environ.get("MIXED_KEEP_FLOAT_MAX_NUMEL", 65_536)) +SUPPORTED_QUANT_SCHEMES = {"int8", "int5", "int4", "mixed"} +SUPPORTED_COMPRESSORS = {"zlib", "zstd", "auto"} +SUPPORTED_WEIGHT_ORDERS = {"none", "name", "size_desc", "dtype_name"} + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor( + name: str, + t: Tensor, + passthrough_orig_dtypes: dict[str, str], + fp32_name_patterns: tuple[str, ...], +) -> Tensor: + if any(pattern in name for pattern in fp32_name_patterns): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def ordered_state_dict_items(state_dict: dict[str, Tensor], mode: str) -> list[tuple[str, Tensor]]: + items = list(state_dict.items()) + if mode == "none": + return items + if mode == "name": + return sorted(items, key=lambda kv: kv[0]) + if mode == "size_desc": + return sorted(items, key=lambda kv: (-int(kv[1].numel()), kv[0])) + if mode == "dtype_name": + return sorted(items, key=lambda kv: (str(kv[1].dtype), kv[0])) + raise ValueError(f"Unsupported WEIGHT_ORDER={mode!r}; expected one of {sorted(SUPPORTED_WEIGHT_ORDERS)}") + +def quantize_float_tensor_int8( + t: Tensor, precomputed_scale: Tensor | None = None +) -> tuple[Tensor, Tensor, dict[str, object] | None]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + if precomputed_scale is not None: + # LSQ-learned scale: use directly, skip the quantile clip computation. + scale = precomputed_scale.float().clamp_min(QUANT_SCALE_EPS) + else: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + scale = (clip_abs / 127.0).clamp_min(QUANT_SCALE_EPS) + q = torch.clamp(torch.round(t32 / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous(), {"scheme": "int8_per_row", "axis": 0} + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale, {"scheme": "int8_per_tensor", "orig_shape": list(t32.shape)} + +def pack_int4_signed(q_signed: Tensor) -> Tensor: + flat = q_signed.reshape(-1).to(dtype=torch.int16) + if flat.numel() % 2: + flat = torch.cat([flat, torch.zeros((1,), dtype=torch.int16)], dim=0) + uint = (flat + 8).to(torch.uint8) + packed = (uint[0::2] & 0x0F) | ((uint[1::2] & 0x0F) << 4) + return packed.contiguous() + +def unpack_int4_signed(packed: Tensor, numel: int) -> Tensor: + p = packed.reshape(-1).to(dtype=torch.uint8) + low = (p & 0x0F).to(dtype=torch.int16) - 8 + high = ((p >> 4) & 0x0F).to(dtype=torch.int16) - 8 + out = torch.empty((p.numel() * 2,), dtype=torch.int16) + out[0::2] = low + out[1::2] = high + return out[:numel].to(dtype=torch.int8).contiguous() + +def pack_int5_signed(q_signed: Tensor) -> Tensor: + """Pack int5 values (range [-16,15]) stored as int8 into 5 bytes per 8 values (40 bits).""" + flat = q_signed.reshape(-1).to(dtype=torch.int32) + pad = (8 - flat.numel() % 8) % 8 + if pad: + flat = torch.cat([flat, torch.zeros(pad, dtype=torch.int32)]) + u = (flat + 16).to(torch.uint8).reshape(-1, 8) # unsigned [0,31] + # 8 x uint5 → 5 bytes + b0 = (u[:, 0] ) | ((u[:, 1] & 0x07) << 5) + b1 = (u[:, 1] >> 3 ) | ( u[:, 2] << 2) | ((u[:, 3] & 0x01) << 7) + b2 = (u[:, 3] >> 1 ) | ((u[:, 4] & 0x0F) << 4) + b3 = (u[:, 4] >> 4 ) | ( u[:, 5] << 1) | ((u[:, 6] & 0x03) << 6) + b4 = (u[:, 6] >> 2 ) | ( u[:, 7] << 3) + packed = torch.stack([b0, b1, b2, b3, b4], dim=1).reshape(-1).to(torch.uint8) + return packed.contiguous() + +def unpack_int5_signed(packed: Tensor, numel: int) -> Tensor: + """Unpack int5 values from 5-bytes-per-8-values layout back to int8 [-16,15].""" + p = packed.reshape(-1, 5).to(torch.int32) + b0, b1, b2, b3, b4 = p[:, 0], p[:, 1], p[:, 2], p[:, 3], p[:, 4] + v0 = b0 & 0x1F + v1 = ((b0 >> 5) & 0x07) | ((b1 & 0x03) << 3) + v2 = ( b1 >> 2) & 0x1F + v3 = ((b1 >> 7) & 0x01) | ((b2 & 0x0F) << 1) + v4 = ((b2 >> 4) & 0x0F) | ((b3 & 0x01) << 4) + v5 = ( b3 >> 1) & 0x1F + v6 = ((b3 >> 6) & 0x03) | ((b4 & 0x07) << 2) + v7 = ( b4 >> 3) & 0x1F + out = torch.stack([v0, v1, v2, v3, v4, v5, v6, v7], dim=1).reshape(-1) + return (out[:numel] - 16).to(torch.int8).contiguous() + +def quantize_float_tensor_int5( + t: Tensor, precomputed_scale: Tensor | None = None +) -> tuple[Tensor, Tensor, dict[str, object]]: + t32 = t.float() + if t32.ndim == 2: + if precomputed_scale is not None: + scale = precomputed_scale.float().clamp_min(QUANT_SCALE_EPS) + else: + clip_abs = ( + torch.quantile(t32.abs(), INT5_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + scale = (clip_abs / 15.0).clamp_min(QUANT_SCALE_EPS) + q = torch.clamp(torch.round(t32 / scale[:, None]), -16, 15).to(torch.int8) + packed = pack_int5_signed(q) + return ( + packed, + scale.to(dtype=INT5_PER_ROW_SCALE_DTYPE).contiguous(), + {"scheme": "int5_per_row", "axis": 0, "orig_shape": [int(t32.shape[0]), int(t32.shape[1])]}, + ) + clip_abs = float(torch.quantile(t32.abs().flatten(), INT5_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 15.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -16, 15).to(torch.int8) + packed = pack_int5_signed(q) + return packed, scale, {"scheme": "int5_per_tensor", "orig_shape": list(t32.shape)} + +def quantize_float_tensor_int4( + t: Tensor, precomputed_scale: Tensor | None = None +) -> tuple[Tensor, Tensor, dict[str, object]]: + t32 = t.float() + if t32.ndim == 2: + if precomputed_scale is not None: + # LSQ-learned scale: skip quantile, use directly. + scale = precomputed_scale.float().clamp_min(QUANT_SCALE_EPS) + else: + clip_abs = ( + torch.quantile(t32.abs(), INT4_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + scale = (clip_abs / 7.0).clamp_min(QUANT_SCALE_EPS) + q = torch.clamp(torch.round(t32 / scale[:, None]), -8, 7).to(torch.int8) + packed = pack_int4_signed(q) + return ( + packed, + scale.to(dtype=INT4_PER_ROW_SCALE_DTYPE).contiguous(), + {"scheme": "int4_per_row", "axis": 0, "orig_shape": [int(t32.shape[0]), int(t32.shape[1])]}, + ) + clip_abs = float(torch.quantile(t32.abs().flatten(), INT4_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 7.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -8, 7).to(torch.int8) + packed = pack_int4_signed(q) + return packed, scale, {"scheme": "int4_per_tensor", "orig_shape": list(t32.shape)} + +def quantize_state_dict( + state_dict: dict[str, Tensor], + scheme: str = "int8", + weight_order: str = "none", + mixed_low_precision_scheme: str = "int8", + precomputed_scales: dict[str, Tensor] | None = None, + gptq_results: dict[str, tuple[Tensor, Tensor]] | None = None, +): + if scheme not in SUPPORTED_QUANT_SCHEMES: + raise ValueError(f"Unsupported QUANT_SCHEME={scheme!r}; expected one of {sorted(SUPPORTED_QUANT_SCHEMES)}") + if weight_order not in SUPPORTED_WEIGHT_ORDERS: + raise ValueError(f"Unsupported WEIGHT_ORDER={weight_order!r}; expected one of {sorted(SUPPORTED_WEIGHT_ORDERS)}") + if mixed_low_precision_scheme not in {"int8", "int5", "int4"}: + raise ValueError( + f"Unsupported MIXED_LOW_PRECISION_SCHEME={mixed_low_precision_scheme!r}; expected 'int8', 'int5', or 'int4'" + ) + + active_scheme = mixed_low_precision_scheme if scheme == "mixed" else scheme + if active_scheme == "int8": + format_name = f"{scheme}_clean_per_row_v1" + elif active_scheme == "int5": + format_name = f"{scheme}_clean_per_row_int5_v1" + else: + format_name = f"{scheme}_clean_per_row_int4_v1" + # Single supported clean-script export formats: + # - per-row low precision for 2D float tensors + # - per-tensor low precision for other float tensors + # - exact passthrough for non-floats + # - passthrough for selected float tensors, stored as fp16/fp32 + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "payload_bytes"), + 0, + ) + keep_patterns = ( + MIXED_KEEP_FLOAT_NAME_PATTERNS + if scheme == "mixed" + else ( + INT8_KEEP_FLOAT_FP32_NAME_PATTERNS + if active_scheme == "int8" + else (INT5_KEEP_FLOAT_FP32_NAME_PATTERNS if active_scheme == "int5" else INT4_KEEP_FLOAT_FP32_NAME_PATTERNS) + ) + ) + force_fp32_patterns = ( + MIXED_KEEP_FLOAT_FP32_NAME_PATTERNS + if scheme == "mixed" + else ( + INT8_KEEP_FLOAT_FP32_NAME_PATTERNS + if active_scheme == "int8" + else (INT5_KEEP_FLOAT_FP32_NAME_PATTERNS if active_scheme == "int5" else INT4_KEEP_FLOAT_FP32_NAME_PATTERNS) + ) + ) + keep_max_numel = ( + MIXED_KEEP_FLOAT_MAX_NUMEL + if scheme == "mixed" + else (INT8_KEEP_FLOAT_MAX_NUMEL if active_scheme == "int8" else (INT5_KEEP_FLOAT_MAX_NUMEL if active_scheme == "int5" else INT4_KEEP_FLOAT_MAX_NUMEL)) + ) + + for name, tensor in ordered_state_dict_items(state_dict, weight_order): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["payload_bytes"] += tensor_nbytes(t) + continue + + should_keep_float = ( + t.numel() <= keep_max_numel + or (scheme == "mixed" and any(pattern in name for pattern in keep_patterns)) + ) + if should_keep_float: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes, force_fp32_patterns) + passthrough[name] = kept + stats["payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + + # GPTQ fast path: use pre-quantized (Q, scale) from Hessian-aware quantization + if gptq_results is not None and name in gptq_results and t.ndim == 2: + gq, gs = gptq_results[name] + if active_scheme == "int5": + packed = pack_int5_signed(gq) + meta = {"scheme": "int5_per_row", "axis": 0, "orig_shape": [int(t.shape[0]), int(t.shape[1])]} + quantized[name] = packed + scales[name] = gs.to(dtype=INT5_PER_ROW_SCALE_DTYPE).contiguous() + elif active_scheme == "int4": + packed = pack_int4_signed(gq) + if gs.ndim == 2: + # Per-group scales: [rows, num_groups] + scheme_name = "int4_per_group_nf4" if NF4_ENABLED else "int4_per_group" + meta = {"scheme": scheme_name, "axis": 0, + "orig_shape": [int(t.shape[0]), int(t.shape[1])], + "group_size": INT4_GROUP_SIZE} + else: + meta = {"scheme": "int4_per_row", "axis": 0, "orig_shape": [int(t.shape[0]), int(t.shape[1])]} + quantized[name] = packed + scales[name] = gs.to(dtype=INT4_PER_ROW_SCALE_DTYPE).contiguous() + else: + meta = {"scheme": "int8_per_row", "axis": 0} + quantized[name] = gq.contiguous() + scales[name] = gs.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + qmeta[name] = meta + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + continue + + pre_scale = None + if precomputed_scales is not None and t.ndim == 2: + pre_scale = precomputed_scales.get(name) + if pre_scale is not None and pre_scale.shape[0] != t.shape[0]: + pre_scale = None # shape mismatch → fall back to quantile + if active_scheme == "int8": + q, s, meta = quantize_float_tensor_int8(t, precomputed_scale=pre_scale) + elif active_scheme == "int5": + q, s, meta = quantize_float_tensor_int5(t, precomputed_scale=pre_scale) + else: + q, s, meta = quantize_float_tensor_int4(t, precomputed_scale=pre_scale) + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": format_name, + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + "export_order_mode": weight_order, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + # Backward-compatible alias for existing log paths. + stats["int8_payload_bytes"] = stats["payload_bytes"] + return obj, stats + +# ---- GPTQ: Accurate Post-Training Quantization (Frantar et al., 2022) ---- + +@torch.no_grad() +def _nf4_quantize(w: Tensor, scale: Tensor) -> Tensor: + """Quantize values to NF4: find nearest NF4 level, return index in [-8, 7].""" + nf4 = NF4_LUT.to(w.device) # [16] + normalized = w / scale.clamp(min=1e-8) # normalized to ~[-1, 1] + # Find nearest NF4 level for each value + # nf4 has 16 values, indices 0..15, we store as signed [-8..7] + dists = (normalized.unsqueeze(-1) - nf4.unsqueeze(0)).abs() # [rows, 16] + indices = dists.argmin(dim=-1) # [rows] -> 0..15 + return (indices - 8).to(torch.int8) # shift to [-8, 7] for packing + + +def _nf4_dequantize(q_signed: Tensor, scale: Tensor) -> Tensor: + """Dequantize NF4: index into LUT, multiply by scale.""" + nf4 = NF4_LUT.to(q_signed.device) + indices = (q_signed.to(torch.int16) + 8).clamp(0, 15).long() + return nf4[indices] * scale + + +def gptq_quantize_weight( + W: Tensor, + H: Tensor, + bits: int = 4, + percdamp: float = 0.01, + blocksize: int = 128, + group_size: int = 0, + use_nf4: bool = False, + act_order: bool = True, +) -> tuple[Tensor, Tensor]: + """GPTQ-quantize a single weight matrix using Hessian information. + + Args: + W: [out_features, in_features] weight matrix + H: [in_features, in_features] Hessian proxy (X^T X / n) + bits: 4 or 8 + percdamp: damping fraction of mean diagonal + blocksize: column block size for lazy batch updates + group_size: columns per quantization group (0 = per-row) + use_nf4: use NF4 quantile levels instead of uniform (only for bits=4) + act_order: reorder columns by Hessian diagonal (importance) for lower error + + Returns: + (Q_int8, scale) where Q_int8 holds the quantized integers [-8..7] or [-127..127] + and scale is [rows] (per-row) or [rows, num_groups] (per-group). + """ + device = W.device + rows, cols = W.shape + W = W.clone().float() + H = H.clone().float().to(device) + + if bits == 4: + maxq, minq, sym_max = 7, -8, 7.0 + elif bits == 5: + maxq, minq, sym_max = 15, -16, 15.0 + else: + maxq, minq, sym_max = 127, -127, 127.0 + use_nf4 = use_nf4 and bits == 4 # NF4 only for 4-bit + use_groups = group_size > 0 and bits == 4 + + # Dead columns (no activation energy) → zero out weight and fix Hessian + dead = torch.diag(H) == 0 + H[dead, dead] = 1.0 + W[:, dead] = 0.0 + + # Damping for numerical stability + damp = percdamp * torch.mean(torch.diag(H)).item() + diag_idx = torch.arange(cols, device=device) + H[diag_idx, diag_idx] += damp + + # Act-order: sort columns by Hessian diagonal (most important first) + # Only use act-order without groups (act-order + groups is complex) + if act_order and bits == 4 and not use_groups: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + else: + perm = None + + # Compute H^{-1} via Cholesky for stability + try: + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + except torch.linalg.LinAlgError: + H[diag_idx, diag_idx] += 10 * damp + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + + # Compute scales: per-row or per-group (dynamically recomputed per group) + if use_groups: + num_groups = (cols + group_size - 1) // group_size + scale = torch.zeros(rows, num_groups, device=device) + else: + num_groups = 0 + scale = W.abs().amax(dim=1).clamp(min=1e-8) / sym_max + + Q = torch.zeros(rows, cols, dtype=torch.int8, device=device) + + for i1 in range(0, cols, blocksize): + i2 = min(i1 + blocksize, cols) + Err1 = torch.zeros(rows, i2 - i1, device=device) + + # Dynamically compute group scale at group boundary from current W + if use_groups: + g = i1 // group_size + if i1 % group_size == 0: + c0 = g * group_size + c1 = min(c0 + group_size, cols) + scale[:, g] = W[:, c0:c1].abs().amax(dim=1).clamp(min=1e-8) + if not use_nf4: + scale[:, g] /= sym_max + + for j in range(i2 - i1): + col = i1 + j + w = W[:, col] + d = Hinv[col, col].clamp(min=1e-10) + + # Recompute group scale at group boundary within a block + if use_groups and col > i1 and col % group_size == 0: + g = col // group_size + c0 = g * group_size + c1 = min(c0 + group_size, cols) + scale[:, g] = W[:, c0:c1].abs().amax(dim=1).clamp(min=1e-8) + if not use_nf4: + scale[:, g] /= sym_max + + # Get the scale for this column + if use_groups: + col_scale = scale[:, col // group_size] + else: + col_scale = scale + + if use_nf4: + q = _nf4_quantize(w, col_scale) + Q[:, col] = q + w_hat = _nf4_dequantize(q, col_scale) + else: + q = torch.clamp(torch.round(w / col_scale), minq, maxq) + Q[:, col] = q.to(torch.int8) + w_hat = q * col_scale + + err = (w - w_hat) / d + Err1[:, j] = err + + W[:, col] = w_hat # replace with dequantized + if j + 1 < i2 - i1: + W[:, col + 1 : i2] -= err.unsqueeze(1) * Hinv[col, col + 1 : i2].unsqueeze(0) + + # Lazy batch update: propagate accumulated error to remaining columns + if i2 < cols: + W[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + + # Un-permute back to original column order (act-order only, no groups) + if perm is not None: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + + return Q, scale + + +@torch.no_grad() +def collect_gptq_hessians( + model: nn.Module, + val_tokens: Tensor, + device: torch.device, + seq_len: int = 1024, + nsamples: int = 128, +) -> dict[str, Tensor]: + """Collect H = (1/n) X^T X for each CastedLinear by running calibration data.""" + hessians: dict[str, Tensor] = {} + sample_counts: dict[str, int] = {} + hooks = [] + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear): + key = name + ".weight" + hessians[key] = torch.zeros(module.in_features, module.in_features, device=device) + sample_counts[key] = 0 + + def make_hook(k: str): + def hook_fn(mod, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[k].addmm_(x.T, x) + sample_counts[k] += x.shape[0] + return hook_fn + + hooks.append(module.register_forward_hook(make_hook(key))) + + # Tied embeddings use F.linear(hidden, tok_emb.weight) instead of a CastedLinear + # module, so hook the final normalized hidden states as calibration inputs for + # tok_emb.weight. This matters most at large vocab sizes where the tied + # embedding/output matrix dominates both parameters and quantization error. + if getattr(model, "tie_embeddings", False) and hasattr(model, "tok_emb") and hasattr(model, "final_norm"): + key = "tok_emb.weight" + emb = getattr(model, "tok_emb") + embed_dim = int(getattr(emb, "embedding_dim", 0)) + if embed_dim > 0 and key not in hessians: + hessians[key] = torch.zeros(embed_dim, embed_dim, device=device) + sample_counts[key] = 0 + + def tied_embedding_hook(_mod, _inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[key].addmm_(x.T, x) + sample_counts[key] += x.shape[0] + + hooks.append(model.final_norm.register_forward_hook(tied_embedding_hook)) + + # Disable QAT fake-quant during calibration + saved_qat_levels = CastedLinear.qat_levels + CastedLinear.qat_levels = 0 + + model.eval() + total_tokens = val_tokens.numel() - 1 + tokens_used = 0 + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for i in range(0, total_tokens - seq_len, seq_len): + if tokens_used >= nsamples * seq_len: + break + x = val_tokens[i : i + seq_len].unsqueeze(0).to(device=device, dtype=torch.int64) + y = val_tokens[i + 1 : i + seq_len + 1].unsqueeze(0).to(device=device, dtype=torch.int64) + model(x, y) + tokens_used += seq_len + + CastedLinear.qat_levels = saved_qat_levels + + for h in hooks: + h.remove() + + # Normalize: H = (1/n) * X^T X + for key in hessians: + n = max(sample_counts[key], 1) + hessians[key] /= n + + return hessians + + +@torch.no_grad() +def gptq_quantize_state_dict( + model: nn.Module, + state_dict: dict[str, Tensor], + hessians: dict[str, Tensor], + bits: int = 4, + percdamp: float = 0.01, + blocksize: int = 128, + group_size: int = 0, + use_nf4: bool = False, +) -> dict[str, tuple[Tensor, Tensor]]: + """Apply GPTQ to all CastedLinear weights that have Hessians. + + Returns {state_dict_key: (Q_int8, scale)} for quantized 2D tensors. + scale is [rows] (per-row) or [rows, num_groups] (per-group). + """ + device = next(model.parameters()).device + results: dict[str, tuple[Tensor, Tensor]] = {} + for name in sorted(hessians.keys()): + if name not in state_dict: + continue + W = state_dict[name].to(device) + if W.ndim != 2: + continue + H = hessians[name] + Q, scale = gptq_quantize_weight( + W, H, bits=bits, percdamp=percdamp, blocksize=blocksize, + group_size=group_size, use_nf4=use_nf4, + ) + results[name] = (Q.cpu(), scale.cpu()) + return results + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + format_name = str(obj.get("__quant_format__", "")) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + meta = qmeta.get(name, {}) + meta_scheme = str(meta.get("scheme", "")) + if meta_scheme in {"int5_per_row", "int5_per_tensor"}: + orig_shape = tuple(int(v) for v in meta.get("orig_shape", q.shape)) + numel = math.prod(orig_shape) + unpacked = unpack_int5_signed(q, numel) + if meta_scheme == "int5_per_row": + rows, cols = orig_shape + scale_row = s.to(dtype=torch.float32).view(rows, 1) + out[name] = (unpacked.float().view(rows, cols) * scale_row).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (unpacked.float().view(orig_shape) * scale).to(dtype=dtype).contiguous() + continue + if meta_scheme in {"int4_per_row", "int4_per_tensor", "int4_per_group", "int4_per_group_nf4"}: + orig_shape = tuple(int(v) for v in meta.get("orig_shape", q.shape)) + numel = math.prod(orig_shape) + unpacked = unpack_int4_signed(q, numel) + if meta_scheme in {"int4_per_group", "int4_per_group_nf4"}: + rows, cols = orig_shape + group_size = int(meta.get("group_size", 128)) + s_f = s.to(dtype=torch.float32) # [rows, num_groups] + q_mat = unpacked.view(rows, cols) + if meta_scheme == "int4_per_group_nf4": + # NF4 dequantization: index into LUT, then multiply by group scale + nf4 = NF4_LUT # [16] + indices = (q_mat.to(torch.int16) + 8).clamp(0, 15).long() + nf4_vals = nf4[indices] # [rows, cols] in [-1, 1] + # Expand group scales to per-column + group_idx = torch.arange(cols) // group_size + group_idx = group_idx.clamp(max=s_f.shape[1] - 1) + col_scales = s_f[:, group_idx] # [rows, cols] + out[name] = (nf4_vals * col_scales).to(dtype=dtype).contiguous() + else: + # Uniform int4 per-group dequantization + group_idx = torch.arange(cols) // group_size + group_idx = group_idx.clamp(max=s_f.shape[1] - 1) + col_scales = s_f[:, group_idx] # [rows, cols] + out[name] = (unpacked.float().view(rows, cols) * col_scales).to(dtype=dtype).contiguous() + elif meta_scheme == "int4_per_row": + rows, cols = orig_shape + scale_row = s.to(dtype=torch.float32).view(rows, 1) + out[name] = (unpacked.float().view(rows, cols) * scale_row).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (unpacked.float().view(orig_shape) * scale).to(dtype=dtype).contiguous() + continue + if meta_scheme in {"int8_per_row", "per_row"} or (s.ndim > 0 and "int4" not in format_name): + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +def resolve_compressor(requested: str) -> tuple[str, str | None]: + if requested not in SUPPORTED_COMPRESSORS: + raise ValueError(f"Unsupported COMPRESSOR={requested!r}; expected one of {sorted(SUPPORTED_COMPRESSORS)}") + if requested == "zlib": + return "zlib", None + if requested == "zstd": + if importlib.util.find_spec("zstandard") is None: + raise RuntimeError( + "COMPRESSOR=zstd requested, but the `zstandard` package is not installed. " + "Install it with `pip install zstandard` or use COMPRESSOR=zlib." + ) + return "zstd", None + # auto mode + if importlib.util.find_spec("zstandard") is not None: + return "zstd", "COMPRESSOR=auto selected zstd (package available)" + return "zlib", "COMPRESSOR=auto fell back to zlib (zstandard package not installed)" + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + zlib_level = 9 if level < 0 else max(0, min(level, 9)) + return zlib.compress(data, level=zlib_level) + if compressor == "zstd": + import zstandard as zstd # type: ignore + + zstd_level = 19 if level < 0 else level + return zstd.ZstdCompressor(level=zstd_level).compress(data) + raise ValueError(f"Unsupported compressor={compressor!r}") + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + import zstandard as zstd # type: ignore + + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported compressor={compressor!r}") + +def export_artifact_name(quant_scheme: str, compressor: str) -> str: + if quant_scheme == "int8" and compressor == "zlib": + return "final_model.int8.ptz" + return f"final_model.{quant_scheme}.{compressor}.ptc" + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +def _fake_quantize_row(w: Tensor, levels: int) -> Tensor: + """Per-row fake-quantise a 2D weight with a straight-through estimator (STE). + + Matches the per-row clipping used by quantize_float_tensor_int8/int4 at export, + but uses amax instead of quantile for speed in the hot forward path. + levels=256 → int8 symmetric (range −127…127) + levels=16 → int4 symmetric (range −7…7) + """ + half = float(levels // 2 - (1 if levels in (16, 32) else 0)) # 127 for int8, 15 for int5, 7 for int4 + w32 = w.float() + clip_abs = w32.abs().amax(dim=1).clamp_min(1e-6) # per-row max scale + scale = clip_abs / half + w_scaled = (w32 / scale.unsqueeze(1)).clamp(-half, half) + # STE: round in forward, identity in backward + w_ste = w_scaled + (w_scaled.round() - w_scaled).detach() + return (w_ste * scale.unsqueeze(1)).to(w.dtype) + + +def _fake_quantize_row_lsq(w: Tensor, levels: int, log_scale: Tensor) -> Tensor: + """LSQ variant: per-row learnable step-size quantisation with STE. + + Based on "Learned Step Size Quantization" (Esser et al., 2019). + log_scale is a learnable 1D parameter [out_features] optimised via backprop. + Gradient on log_scale is scaled by g = 1/sqrt(numel_per_row * half) per the LSQ paper, + which keeps the scale-gradient magnitude commensurate with weight-gradient magnitude. + + Compared to max-abs fake-quant, LSQ lets the model adapt the clip threshold per row, + reducing int4 quantisation error by ~30-50% on typical models. + """ + half = float(levels // 2 - (1 if levels in (16, 32) else 0)) + w32 = w.float() + # LSQ gradient scaling trick: effective gradient on log_scale is g * d_loss/d_scale. + numel_per_row = float(w32.shape[1]) + g = 1.0 / math.sqrt(max(numel_per_row * half, 1.0)) + ls_grad_scaled = log_scale * g + (log_scale - log_scale * g).detach() + # Convert log-scale to positive scale via exp (auto-positive, stable). + scale = ls_grad_scaled.float().exp().clamp_min(1e-8) + w_scaled = (w32 / scale.unsqueeze(1)).clamp(-half, half) + w_ste = w_scaled + (w_scaled.round() - w_scaled).detach() + return (w_ste * scale.unsqueeze(1)).to(w.dtype) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # QAT: set qat_levels to 256 (int8), 32 (int5), or 16 (int4) to enable fake-quantisation. + qat_levels: int = 0 # class-level switch updated from the training loop + # LSQ: when True, CastedLinear instances allocate a learnable per-row log-scale parameter + # used in place of the max-abs scale. Must be set BEFORE model construction. + qat_lsq_enabled: bool = False + + def __init__(self, in_features: int, out_features: int, bias: bool = True, **kwargs) -> None: + super().__init__(in_features, out_features, bias=bias, **kwargs) + if __class__.qat_lsq_enabled: + # Per-row log-scale. Zeros → scale=1.0 placeholder; re-initialised from actual + # weight stats at the step QAT first activates (see init_lsq_scales below). + self.qat_log_scale = nn.Parameter(torch.zeros(out_features)) + else: + self.qat_log_scale = None + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if __class__.qat_levels > 0 and w.ndim == 2: + if self.qat_log_scale is not None: + w = _fake_quantize_row_lsq(w, __class__.qat_levels, self.qat_log_scale) + else: + w = _fake_quantize_row(w, __class__.qat_levels) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def init_lsq_scales(model: nn.Module, levels: int) -> int: + """Initialise LSQ per-row log-scales from current weight statistics. + + Called once when QAT first activates. Sets each log_scale to + log(max_abs_per_row / half), matching the initial value a max-abs fake-quant would use. + Returns the number of CastedLinear modules initialised. + """ + half = float(levels // 2 - (1 if levels in (16, 32) else 0)) + count = 0 + with torch.no_grad(): + for m in model.modules(): + if isinstance(m, CastedLinear) and m.qat_log_scale is not None and m.weight.ndim == 2: + w32 = m.weight.detach().float() + scale_val = (w32.abs().amax(dim=1).clamp_min(1e-6) / max(half, 1.0)) + m.qat_log_scale.data.copy_(scale_val.log().to(m.qat_log_scale.dtype)) + count += 1 + return count + + +def collect_lsq_scales(model: nn.Module, prefix: str = "") -> dict[str, Tensor]: + """Walk the model and return a dict of {state_dict_weight_name: exp(log_scale)}. + + Used at export time to plumb LSQ-learned scales into quantize_float_tensor_int4/int8 + via the precomputed_scales dict. + """ + scales: dict[str, Tensor] = {} + for name, m in model.named_modules(prefix=prefix): + if isinstance(m, CastedLinear) and m.qat_log_scale is not None and m.weight.ndim == 2: + key = f"{name}.weight" if name else "weight" + scales[key] = m.qat_log_scale.detach().float().exp().clamp_min(1e-8).cpu() + return scales + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if num_heads <= 0: + raise ValueError(f"num_heads must be positive, got {num_heads}") + if num_kv_heads <= 0: + raise ValueError(f"num_kv_heads must be positive, got {num_kv_heads}") + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Expand KV heads to match Q heads for GQA (handles older PyTorch without enable_gqa) + if self.num_kv_heads != self.num_heads: + groups = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(groups, dim=1) + v = v.repeat_interleave(groups, dim=1) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, use_swiglu: bool = False): + super().__init__() + self.use_swiglu = use_swiglu + if use_swiglu: + # SwiGLU with the same parameter budget as relu²: + # relu² uses 2 matrices of (dim × mlp_mult*dim) = 2*mlp_mult*dim² params. + # SwiGLU uses 3 matrices of (dim × h): 3*h*dim params. + # Equating: h = (2/3)*mlp_mult*dim. Round down to multiple of 64 for hardware alignment. + hidden = max(64, (2 * mlp_mult * dim // 3 // 64) * 64) + self.gate = CastedLinear(dim, hidden, bias=False) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + else: + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.use_swiglu: + return self.proj(F.silu(self.gate(x)) * self.fc(x)) + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class MoEMLP(nn.Module): + """Sparse Mixture-of-Experts MLP with Expert Choice routing. + + Design goals + ============ + 1. **torch.compile(fullgraph=True) compatible** — Expert Choice routing gives + every expert a statically-shaped slice of tokens [capacity, D], avoiding the + dynamic-shape issues of token-choice top-k dispatch. + 2. **QAT-aware** — all expert weights are CastedLinear, so the class-level + CastedLinear.qat_levels switch applies uniformly to router and experts. + 3. **Muon-trained** — CastedLinear parameters are automatically picked up by + the existing Muon parameter-group logic (2-D weight matrices). + 4. **Load-balanced by construction** — each expert always processes exactly + `capacity` tokens, so no explicit load-balance loss is required. + 5. **Router stability via Z-loss** — a small penalty on router logit magnitudes + prevents collapse (all tokens always sent to one expert). + + Expert Choice routing (Zhou et al., 2022) + ========================================== + Instead of each token selecting its top-k experts (token choice), each expert + selects the top `capacity` tokens it wants to process: + + capacity = max(1, int(capacity_factor * S / E)) # S = B*T, E = num_experts + + router_probs [S, E] = softmax(router_logits) + top_scores [E, cap] \\ + top_indices [E, cap] / = router_probs.T.topk(capacity, dim=1) + + For each expert i: + expert_input = x_flat[top_indices[i]] # [cap, D] — gather + expert_out = expert_mlp_i(expert_input) # [cap, D] + expert_out *= top_scores[i] # weighted by routing prob + output += scatter(expert_out, top_indices[i]) # accumulate + + Every tensor shape is statically determined → fullgraph compile succeeds. + + Args: + dim : model hidden dimension + mlp_mult : MLP width multiplier (identical to base MLP) + num_experts : number of expert MLPs (E); must be ≥ 2 + capacity_factor : fraction of tokens each expert sees; 1.0 = perfect coverage + use_swiglu : SwiGLU activation (matching the base MLP choice) + """ + + def __init__( + self, + dim: int, + mlp_mult: int, + num_experts: int, + capacity_factor: float = 1.0, + use_swiglu: bool = False, + ): + super().__init__() + if num_experts < 2: + raise ValueError(f"MoEMLP requires num_experts >= 2, got {num_experts}") + self.num_experts = num_experts + self.capacity_factor = capacity_factor + self.use_swiglu = use_swiglu + + # Router: linear map from hidden dim to expert scores. + # CastedLinear → participates in QAT and Muon automatically. + self.router = CastedLinear(dim, num_experts, bias=False) + + # Per-expert weight matrices stored as ModuleLists of CastedLinear. + # This is intentionally verbose (vs stacked tensors) so that: + # a) Each expert participates in QAT via CastedLinear.qat_levels + # b) Muon picks them up as standard 2-D parameters + # c) Zero-init of proj layers is handled naturally via _zero_init flag + if use_swiglu: + hidden = max(64, (2 * mlp_mult * dim // 3 // 64) * 64) + self.expert_gates = nn.ModuleList([CastedLinear(dim, hidden, bias=False) for _ in range(num_experts)]) + self.expert_fcs = nn.ModuleList([CastedLinear(dim, hidden, bias=False) for _ in range(num_experts)]) + self.expert_projs = nn.ModuleList([CastedLinear(hidden, dim, bias=False) for _ in range(num_experts)]) + for m in self.expert_projs: + m._zero_init = True + else: + hidden = mlp_mult * dim + self.expert_gates = nn.ModuleList() # unused for relu²; kept for uniform attr + self.expert_fcs = nn.ModuleList([CastedLinear(dim, hidden, bias=False) for _ in range(num_experts)]) + self.expert_projs = nn.ModuleList([CastedLinear(hidden, dim, bias=False) for _ in range(num_experts)]) + for m in self.expert_projs: + m._zero_init = True + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """ + Args: + x : [B, T, D] + Returns: + output : [B, T, D] — same shape as input + z_loss : scalar — router Z-loss; add to training loss via moe_aux_loss_coeff + """ + B, T, D = x.shape + S = B * T + x_flat = x.reshape(S, D) + + # ── Router ────────────────────────────────────────────────────────── + router_logits = self.router(x_flat) # [S, E] (bfloat16) + + # Z-loss (Zoph et al., 2022 "ST-MoE"): + # z_loss = mean( log(∑_e exp(router_logits))² ) + # Keeps router logits from growing large → prevents routing collapse. + z_loss: Tensor = torch.logsumexp(router_logits.float(), dim=-1).square().mean() + + router_probs = torch.softmax(router_logits.float(), dim=-1) # [S, E] + + # ── Expert Choice: each expert picks its top-capacity tokens ───────── + # capacity is a Python int → static shape → fullgraph-compile friendly + capacity = max(1, int(self.capacity_factor * S / self.num_experts)) + + # router_probs.T is [E, S]; topk over dim=1 selects the top-capacity token + # indices per expert. Both outputs have static shape [E, capacity]. + top_scores, top_indices = router_probs.T.topk(capacity, dim=1) # [E, cap] + + # ── Expert forward + weighted scatter ──────────────────────────────── + output = torch.zeros_like(x_flat) # [S, D] + + for i in range(self.num_experts): + # Gather the tokens this expert selected. Shape: [cap, D] + expert_in = x_flat[top_indices[i]] + weights = top_scores[i].to(expert_in.dtype) # [cap] + + # Expert MLP forward (SwiGLU or relu²) + if self.use_swiglu: + h = F.silu(self.expert_gates[i](expert_in)) * self.expert_fcs[i](expert_in) + expert_out = self.expert_projs[i](h) + else: + h = torch.relu(self.expert_fcs[i](expert_in)) + expert_out = self.expert_projs[i](h.square()) + + # Scale by routing probability (gradient flows through weights here) + expert_out = expert_out * weights.unsqueeze(-1) + + # Scatter-add back into the output buffer at the positions this expert owns. + # top_indices[i] has static shape [cap]; unsqueeze(-1).expand gives [cap, D]. + output.scatter_add_( + 0, + top_indices[i].unsqueeze(-1).expand(-1, D), + expert_out, + ) + + return output.reshape(B, T, D), z_loss + + +class SSMMixer(nn.Module): + """SSM mixer used by SSM blocks. + + `impl="mamba3"` wraps the official CUDA-backed Mamba-3 block from + `mamba_ssm.modules.mamba3`. `impl="conv"` keeps the older lightweight causal + depthwise-conv mixer available for ablations. + """ + + def __init__( + self, + dim: int, + expand: float = 2.0, + kernel_size: int = 4, + impl: str = "mamba3", + mamba3_d_state: int = 128, + mamba3_head_dim: int = 64, + mamba3_is_mimo: bool = True, + mamba3_mimo_rank: int = 4, + mamba3_chunk_size: int = 16, + mamba3_outproj_norm: bool = False, + ): + super().__init__() + self.impl = impl.strip().lower() + if self.impl not in {"mamba3", "conv"}: + raise ValueError(f"Unsupported SSM_IMPL={impl!r}; expected 'mamba3' or 'conv'") + if self.impl == "mamba3": + if _OfficialMamba3 is None: + raise ImportError( + "SSM_IMPL=mamba3 requires the source build of mamba-ssm with Mamba3. " + "Install with: MAMBA_FORCE_BUILD=TRUE pip install --no-cache-dir " + "--force-reinstall git+https://github.com/state-spaces/mamba.git --no-build-isolation" + ) from _MAMBA3_IMPORT_ERROR + if mamba3_head_dim <= 0: + preferred = [128, 64, 32] + mamba3_head_dim = next((h for h in preferred if dim % h == 0), 0) + if mamba3_head_dim <= 0: + raise ValueError( + f"MAMBA3_HEAD_DIM=0 could not auto-pick a tested Mamba-3 headdim " + f"for MODEL_DIM={dim}; use a MODEL_DIM divisible by one of {preferred} " + f"(for example 448 or 512), or explicitly set MAMBA3_HEAD_DIM at your own risk." + ) + if dim % mamba3_head_dim != 0: + raise ValueError( + f"MODEL_DIM={dim} must be divisible by MAMBA3_HEAD_DIM={mamba3_head_dim}" + ) + self.mamba3_head_dim = int(mamba3_head_dim) + if mamba3_d_state <= 0: + raise ValueError(f"MAMBA3_D_STATE must be positive, got {mamba3_d_state}") + if mamba3_is_mimo and mamba3_mimo_rank <= 0: + raise ValueError(f"MAMBA3_MIMO_RANK must be positive, got {mamba3_mimo_rank}") + if mamba3_chunk_size <= 0: + raise ValueError(f"MAMBA3_CHUNK_SIZE must be positive, got {mamba3_chunk_size}") + kwargs = dict( + d_model=dim, + d_state=mamba3_d_state, + headdim=mamba3_head_dim, + is_mimo=bool(mamba3_is_mimo), + chunk_size=mamba3_chunk_size, + is_outproj_norm=bool(mamba3_outproj_norm), + ) + if mamba3_is_mimo: + kwargs["mimo_rank"] = mamba3_mimo_rank + self.mamba3 = _OfficialMamba3(**kwargs) + return + + if kernel_size < 2: + raise ValueError(f"SSM kernel must be >= 2, got {kernel_size}") + hidden = max(64, int(dim * expand) // 64 * 64) + self.in_proj = CastedLinear(dim, hidden * 2, bias=False) + # Depthwise causal conv over time (implemented via left crop after padding). + self.dw_conv = nn.Conv1d( + hidden, + hidden, + kernel_size=kernel_size, + groups=hidden, + bias=False, + padding=kernel_size - 1, + ) + self.out_proj = CastedLinear(hidden, dim, bias=False) + self.out_proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + # x: [B, T, D] + if self.impl == "mamba3": + return self.mamba3(x) + bsz, seqlen, _ = x.shape + uv = self.in_proj(x) + u, v = uv.chunk(2, dim=-1) + u = F.silu(u) + y = self.dw_conv(u.transpose(1, 2))[..., :seqlen].transpose(1, 2).contiguous() + y = y * torch.sigmoid(v) + return self.out_proj(y) + + +class MTPBranch(nn.Module): + """Per-horizon residual branch for multi-token prediction.""" + + def __init__(self, dim: int): + super().__init__() + self.norm = RMSNorm() + self.proj = CastedLinear(dim, dim, bias=False) + self.scale = nn.Parameter(torch.ones(1, dtype=torch.float32)) + + def forward(self, h: Tensor) -> Tensor: + return h + self.scale.to(dtype=h.dtype) * self.proj(self.norm(h)) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_swiglu: bool = False, + use_ssm: bool = False, + ssm_expand: float = 2.0, + ssm_kernel: int = 4, + ssm_impl: str = "mamba3", + mamba3_d_state: int = 128, + mamba3_head_dim: int = 64, + mamba3_is_mimo: bool = True, + mamba3_mimo_rank: int = 4, + mamba3_chunk_size: int = 16, + mamba3_outproj_norm: bool = False, + moe_num_experts: int = 0, + moe_capacity_factor: float = 1.0, + use_parallel_residual: bool = False, + use_sandwich_norm: bool = False, + ): + super().__init__() + self.use_ssm = use_ssm + self.use_sandwich_norm = use_sandwich_norm and not use_parallel_residual + # Parallel residual: one shared pre-norm feeds both attn and MLP simultaneously. + # Saves one RMSNorm, improves gradient flow; validated by leaderboard PRs. + self.use_parallel_residual = use_parallel_residual and not use_ssm + if use_parallel_residual and not use_ssm: + self.norm = RMSNorm() # single shared norm + self.attn_norm = self.norm # alias for compat + self.mlp_norm = self.norm # alias for compat + else: + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + if use_ssm: + self.attn = None + self.ssm = SSMMixer( + dim, + expand=ssm_expand, + kernel_size=ssm_kernel, + impl=ssm_impl, + mamba3_d_state=mamba3_d_state, + mamba3_head_dim=mamba3_head_dim, + mamba3_is_mimo=mamba3_is_mimo, + mamba3_mimo_rank=mamba3_mimo_rank, + mamba3_chunk_size=mamba3_chunk_size, + mamba3_outproj_norm=mamba3_outproj_norm, + ) + else: + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.ssm = None + # MoE or dense MLP — is_moe is a Python bool, resolved at compile time. + self.is_moe: bool = moe_num_experts >= 2 + if self.is_moe: + self.mlp: MLP | MoEMLP = MoEMLP(dim, mlp_mult, moe_num_experts, moe_capacity_factor, use_swiglu) + else: + self.mlp = MLP(dim, mlp_mult, use_swiglu=use_swiglu) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + # Sandwich norm: post-sublayer norms (Gemma 2 style). Applied before residual add. + if self.use_sandwich_norm: + self.attn_post_norm = RMSNorm() + self.mlp_post_norm = RMSNorm() + + def forward(self, x: Tensor, x0: Tensor) -> tuple[Tensor, Tensor]: + """Returns (hidden_state, moe_z_loss). + moe_z_loss is a zero scalar for non-MoE blocks so callers can always + accumulate unconditionally without a Python-level branch.""" + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + if self.use_ssm: + if self.ssm is None: + raise RuntimeError("SSM block is enabled but mixer is missing") + mix_out = self.ssm(self.attn_norm(x)) + if self.use_sandwich_norm: + mix_out = self.attn_post_norm(mix_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * mix_out + if self.is_moe: + mlp_out, z_loss = self.mlp(self.mlp_norm(x)) + else: + mlp_out = self.mlp(self.mlp_norm(x)) + z_loss = x.new_zeros(()) + if self.use_sandwich_norm: + mlp_out = self.mlp_post_norm(mlp_out) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + elif self.use_parallel_residual: + # Parallel: both attn and MLP read the same pre-norm input, outputs added together. + if self.attn is None: + raise RuntimeError("Attention block is enabled but attention module is missing") + h = self.norm(x) + attn_out = self.attn(h) + if self.is_moe: + mlp_out, z_loss = self.mlp(h) + else: + mlp_out = self.mlp(h) + z_loss = x.new_zeros(()) + x = (x + + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out) + else: + if self.attn is None: + raise RuntimeError("Attention block is enabled but attention module is missing") + mix_out = self.attn(self.attn_norm(x)) + if self.use_sandwich_norm: + mix_out = self.attn_post_norm(mix_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * mix_out + if self.is_moe: + mlp_out, z_loss = self.mlp(self.mlp_norm(x)) + else: + mlp_out = self.mlp(self.mlp_norm(x)) + z_loss = x.new_zeros(()) + if self.use_sandwich_norm: + mlp_out = self.mlp_post_norm(mlp_out) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + return x, z_loss + + +class JPCRPredictor(nn.Module): + """JEPA Predictive Coding Recurrence predictor (v2 — BYOL/data2vec-inspired). + + Per-token MLP that predicts "where the hidden state should be" at this depth. + Trained with cosine similarity loss against instance-normalized EMA teacher + intermediates projected into a smaller space (BYOL-style). + + Architecture: + Blend path: RMSNorm → Linear(dim, hidden) → SiLU → Linear(hidden, dim) → residual + Loss path: shared Linear(dim, proj_dim) on prediction and normalized target, cosine loss + + The blend path modifies the recurrence input at inference (no teacher needed). + The loss path trains the predictor — projects to proj_dim for stable, bounded loss. + """ + + def __init__(self, model_dim: int, hidden_dim: int = 128, proj_dim: int = 128, + blend_init: float = -2.0): + super().__init__() + self.model_dim = model_dim + self.proj_dim = proj_dim + # Blend path: predicts delta to add to x + self.proj_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.proj_out = nn.Linear(hidden_dim, model_dim, bias=True) + # Learnable blend gate (logit space). sigmoid(-2.0) ≈ 0.12 → conservative start. + self.blend_gate = nn.Parameter(torch.tensor(blend_init, dtype=torch.float32)) + # Zero-init output → identity at start of training (delta = 0) + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) + # Loss projection heads (BYOL-style): project to smaller space for loss + self.student_proj = nn.Linear(model_dim, proj_dim, bias=False) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (predicted_target, gate_value). No loss computation here.""" + h = F.rms_norm(x, (self.model_dim,)) + h = F.silu(self.proj_in(h)) + delta = self.proj_out(h) + predicted_target = x + delta + gate = torch.sigmoid(self.blend_gate.to(x.dtype)) + return predicted_target, gate + + def compute_loss(self, predicted_target: Tensor, teacher_target: Tensor) -> Tensor: + """Cosine similarity loss in projected space with instance-normalized targets. + + Returns scalar loss in [0, 2] (0 = perfect alignment, 2 = opposite). + Uses data2vec-style instance normalization + BYOL-style projection. + """ + # Instance-normalize teacher target (data2vec): zero-mean, unit-var per token + t = teacher_target.float() + t = (t - t.mean(dim=-1, keepdim=True)) / (t.std(dim=-1, keepdim=True) + 1e-6) + # Project both to smaller space with shared projector, detach target branch. + s_proj = self.student_proj(predicted_target.float()) + t_proj = self.student_proj(t).detach() + # Cosine similarity loss: 1 - cos_sim, bounded [0, 2] + s_norm = F.normalize(s_proj, dim=-1) + t_norm = F.normalize(t_proj, dim=-1) + return (1.0 - (s_norm * t_norm).sum(dim=-1)).mean() + + +def _run_ctrl_safe(ctrl: nn.Sequential, x: Tensor, loop_steps: int, model_dim: int) -> Tensor: + """Run Ouroboros controller with explicit dtype handling to avoid autocast/compile issues.""" + d = x.dtype + h = x.mean(dim=1) # [B, dim] + # Functional forward through controller: Linear -> SiLU -> Linear + h = F.linear(h, ctrl[0].weight.to(d), ctrl[0].bias.to(d)) + h = F.silu(h) + h = F.linear(h, ctrl[2].weight.to(d), ctrl[2].bias.to(d)) + return h.view(x.shape[0], loop_steps, 2, model_dim) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + recurrent_core_layers: int = 0, + recurrent_steps: int = 0, + share_ffn_across_blocks: bool = False, + intra_loop_start: int = -1, + intra_loop_end: int = -1, + intra_loop_steps: int = 3, + use_parallel_residual: bool = False, + use_swiglu: bool = False, + bigram_rank: int = 0, + mtp_enabled: bool = False, + mtp_steps: int = 2, + mtp_weight: float = 0.3, + mtp_decay: float = 1.0, + mtp_tie_embeddings: bool = True, + use_ssm: bool = False, + ssm_every_n: int = 2, + ssm_expand: float = 2.0, + ssm_kernel: int = 4, + ssm_impl: str = "mamba3", + mamba3_d_state: int = 128, + mamba3_head_dim: int = 64, + mamba3_is_mimo: bool = True, + mamba3_mimo_rank: int = 4, + mamba3_chunk_size: int = 16, + mamba3_outproj_norm: bool = False, + residual_ngram_enabled: bool = False, + residual_bigram_rank: int = 0, + residual_trigram_rank: int = 0, + residual_ngram_mix_init: float = -2.5, + ngram_softcap: float = 0.0, + ngram_entropy_gate: bool = False, + copy_cache_enabled: bool = False, + copy_cache_window: int = 256, + copy_cache_dim: int = 64, + copy_cache_gate_init: float = -4.0, + moe_num_experts: int = 0, + moe_every_n: int = 2, + moe_capacity_factor: float = 1.0, + moe_aux_loss_coeff: float = 1e-3, + dual_head_enabled: bool = False, + dual_head_num_classes: int = 4, + jpcr_enabled: bool = False, + jpcr_hidden: int = 128, + jpcr_proj_dim: int = 128, + jpcr_blend_init: float = -2.0, + use_sandwich_norm: bool = False, + embed_scale: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + if (recurrent_core_layers > 0) != (recurrent_steps > 0): + raise ValueError( + "RECURRENT_CORE_LAYERS and RECURRENT_STEPS must both be > 0 for recurrence mode, " + f"got RECURRENT_CORE_LAYERS={recurrent_core_layers}, RECURRENT_STEPS={recurrent_steps}" + ) + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.use_recurrence = recurrent_core_layers > 0 and recurrent_steps > 0 + self.recurrent_core_layers = recurrent_core_layers + self.recurrent_steps = recurrent_steps + self.share_ffn_across_blocks = share_ffn_across_blocks + # Partial depth recurrence: loop layers [intra_loop_start..intra_loop_end] N times. + # Middle layers are optimal (see Universal Transformers; leaderboard PR #1394). + # Loop-position embeddings (shape [n_looped_blocks, steps, dim], init=0) let the + # model distinguish iteration 0 from iteration 1, learned via Adam at scalar_lr. + _intra_active = (intra_loop_start >= 0 and intra_loop_end >= intra_loop_start + and intra_loop_steps > 1 and not self.use_recurrence) + self.intra_loop_start = int(intra_loop_start) if _intra_active else -1 + self.intra_loop_end = int(intra_loop_end) if _intra_active else -1 + self.intra_loop_steps = int(intra_loop_steps) if _intra_active else 1 + self.use_ssm = use_ssm + self.ssm_every_n = ssm_every_n + self.ssm_expand = ssm_expand + self.ssm_kernel = ssm_kernel + self.ssm_impl = ssm_impl + self.mamba3_d_state = mamba3_d_state + self.mamba3_head_dim = mamba3_head_dim + self.mamba3_is_mimo = mamba3_is_mimo + self.mamba3_mimo_rank = mamba3_mimo_rank + self.mamba3_chunk_size = mamba3_chunk_size + self.mamba3_outproj_norm = mamba3_outproj_norm + self.mtp_enabled = mtp_enabled and mtp_steps > 0 + self.mtp_steps = max(0, mtp_steps) + self.mtp_weight = max(0.0, mtp_weight) + self.mtp_decay = mtp_decay + self.mtp_tie_embeddings = mtp_tie_embeddings + self.residual_bigram_rank = max(0, residual_bigram_rank) + self.residual_trigram_rank = max(0, residual_trigram_rank) + self.residual_ngram_enabled = residual_ngram_enabled and ( + self.residual_bigram_rank > 0 or self.residual_trigram_rank > 0 + ) + self.residual_ngram_mix_init = residual_ngram_mix_init + # 0.0 means "inherit logit_softcap"; >0 decouples the ngram branch cap. + self.ngram_softcap = float(ngram_softcap) if ngram_softcap > 0.0 else 0.0 + self.ngram_entropy_gate = bool(ngram_entropy_gate) and self.residual_ngram_enabled + self.copy_cache_enabled = copy_cache_enabled + self.copy_cache_window = max(1, int(copy_cache_window)) + self.copy_cache_dim = max(8, int(copy_cache_dim)) + self.copy_cache_gate_init = copy_cache_gate_init + self.dual_head_enabled = bool(dual_head_enabled) + self.dual_head_num_classes = max(2, int(dual_head_num_classes)) + if self.use_recurrence: + self.total_effective_layers = recurrent_core_layers * recurrent_steps + elif self.intra_loop_start >= 0: + n_looped = self.intra_loop_end - self.intra_loop_start + 1 + self.total_effective_layers = num_layers + n_looped * (self.intra_loop_steps - 1) + else: + self.total_effective_layers = num_layers + + # MoE config stored on model (used in forward() to gate the aux loss) + self.moe_aux_loss_coeff = float(moe_aux_loss_coeff) + self._has_moe = moe_num_experts >= 2 and moe_every_n > 0 + + def is_ssm_block(idx: int) -> bool: + return self.use_ssm and self.ssm_every_n > 0 and ((idx + 1) % self.ssm_every_n == 0) + + def is_moe_block(idx: int) -> bool: + return moe_num_experts >= 2 and moe_every_n > 0 and idx % moe_every_n == 0 + + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.embed_scale = embed_scale + self._embed_scale_factor = model_dim ** 0.5 if embed_scale else 1.0 + if self.use_recurrence: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + # In recurrence mode skip_weights are unused; keep as buffer so DDP + # doesn't expect gradients for an empty parameter tensor. + self.register_buffer("skip_weights", torch.ones(0, model_dim, dtype=torch.float32), persistent=False) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + use_swiglu=use_swiglu, + use_ssm=is_ssm_block(i), + ssm_expand=ssm_expand, + ssm_kernel=ssm_kernel, + ssm_impl=ssm_impl, + mamba3_d_state=mamba3_d_state, + mamba3_head_dim=mamba3_head_dim, + mamba3_is_mimo=mamba3_is_mimo, + mamba3_mimo_rank=mamba3_mimo_rank, + mamba3_chunk_size=mamba3_chunk_size, + mamba3_outproj_norm=mamba3_outproj_norm, + moe_num_experts=moe_num_experts if is_moe_block(i) else 0, + moe_capacity_factor=moe_capacity_factor, + use_parallel_residual=use_parallel_residual and not is_ssm_block(i), + use_sandwich_norm=use_sandwich_norm, + ) + for i in range(recurrent_core_layers) + ] + ) + # SHARE_FFN_ACROSS_BLOCKS is incompatible with MoE (different experts per layer). + if share_ffn_across_blocks and len(self.blocks) > 1 and not self._has_moe: + shared_mlp = self.blocks[0].mlp + for i in range(1, len(self.blocks)): + self.blocks[i].mlp = shared_mlp + else: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + use_swiglu=use_swiglu, + use_ssm=is_ssm_block(i), + ssm_expand=ssm_expand, + ssm_kernel=ssm_kernel, + ssm_impl=ssm_impl, + mamba3_d_state=mamba3_d_state, + mamba3_head_dim=mamba3_head_dim, + mamba3_is_mimo=mamba3_is_mimo, + mamba3_mimo_rank=mamba3_mimo_rank, + mamba3_chunk_size=mamba3_chunk_size, + mamba3_outproj_norm=mamba3_outproj_norm, + moe_num_experts=moe_num_experts if is_moe_block(i) else 0, + moe_capacity_factor=moe_capacity_factor, + use_sandwich_norm=use_sandwich_norm, + ) + for i in range(num_layers) + ] + ) + if share_ffn_across_blocks and len(self.blocks) > 1 and not self._has_moe: + shared_mlp = self.blocks[0].mlp + for i in range(1, len(self.blocks)): + self.blocks[i].mlp = shared_mlp + self.num_ssm_blocks = sum(1 for block in self.blocks if block.use_ssm) + self.num_moe_blocks = sum(1 for block in self.blocks if block.is_moe) + self.num_attn_blocks = len(self.blocks) - self.num_ssm_blocks + # JPCR (JEPA Predictive Coding Recurrence) or Ouroboros loop conditioning. + # JPCR: per-token MLP predictors trained with JEPA MSE loss against teacher intermediates. + # Each predictor predicts the ideal hidden state; a learned gate blends this prediction + # into the recurrence input. Progressive depth targeting across loop iterations. + # Ouroboros: per-looped-block tiny hypernetwork generating (scale, shift) from mean(x). + self.jpcr_enabled = bool(jpcr_enabled) and _intra_active + if self.jpcr_enabled: + n_looped = self.intra_loop_end - self.intra_loop_start + 1 + predictors = [] + for _ in range(n_looped): + predictors.append(JPCRPredictor(model_dim, jpcr_hidden, jpcr_proj_dim, jpcr_blend_init)) + self.jpcr_predictors = nn.ModuleList(predictors) + self.intra_loop_controllers = nn.ModuleList([]) # not used with JPCR + self._intra_model_dim = model_dim + elif _intra_active: + self.jpcr_predictors = nn.ModuleList([]) + n_looped = self.intra_loop_end - self.intra_loop_start + 1 + _ctrl_hidden = 32 + # One controller per looped block; each outputs [steps, 2, dim] + controllers = [] + for _ in range(n_looped): + net = nn.Sequential( + nn.Linear(model_dim, _ctrl_hidden, bias=True), + nn.SiLU(), + nn.Linear(_ctrl_hidden, self.intra_loop_steps * 2 * model_dim, bias=True), + ) + # Zero-init output layer → identity transform at start of training + nn.init.zeros_(net[-1].weight) + nn.init.zeros_(net[-1].bias) + controllers.append(net) + self.intra_loop_controllers = nn.ModuleList(controllers) + self._intra_model_dim = model_dim + else: + self.jpcr_predictors = nn.ModuleList([]) + self.intra_loop_controllers = nn.ModuleList([]) + self._intra_model_dim = model_dim + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + self.dual_head = CastedLinear(model_dim, self.dual_head_num_classes, bias=True) if self.dual_head_enabled else None + if self.lm_head is not None: + self.lm_head._zero_init = True + if self.mtp_enabled: + self.mtp_branches = nn.ModuleList([MTPBranch(model_dim) for _ in range(self.mtp_steps)]) + if self.mtp_tie_embeddings and self.tie_embeddings: + self.mtp_heads = None + else: + self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, vocab_size, bias=False) for _ in range(self.mtp_steps)]) + self.register_buffer( + "mtp_step_weights", + torch.tensor([self.mtp_decay**i for i in range(self.mtp_steps)], dtype=torch.float32), + persistent=False, + ) + else: + self.mtp_branches = None + self.mtp_heads = None + self.register_buffer("mtp_step_weights", torch.zeros((0,), dtype=torch.float32), persistent=False) + # Low-rank bigram logit bias. At position i, adds bigram_right(bigram_left(input[i])) to logits. + # This gives the model a cheap, learned n-gram prior on top of the contextual representations. + self.bigram_rank = bigram_rank + if bigram_rank > 0: + self.bigram_left = nn.Embedding(vocab_size, bigram_rank) + self.bigram_right = CastedLinear(bigram_rank, vocab_size, bias=False) + self.bigram_right._zero_init = True # starts contributing nothing; learns when useful + self.bigram_scale = nn.Parameter(torch.ones(1, dtype=torch.float32)) + if self.residual_ngram_enabled: + if self.residual_bigram_rank > 0: + self.residual_bigram_left = nn.Embedding(vocab_size, self.residual_bigram_rank) + self.residual_bigram_right = CastedLinear(self.residual_bigram_rank, vocab_size, bias=False) + self.residual_bigram_right._zero_init = True + if self.residual_trigram_rank > 0: + self.residual_trigram_prev1 = nn.Embedding(vocab_size, self.residual_trigram_rank) + self.residual_trigram_prev2 = nn.Embedding(vocab_size, self.residual_trigram_rank) + self.residual_trigram_right = CastedLinear(self.residual_trigram_rank, vocab_size, bias=False) + self.residual_trigram_right._zero_init = True + self.residual_ngram_scale = nn.Parameter(torch.ones(1, dtype=torch.float32)) + gate_in_dim = model_dim + (1 if self.ngram_entropy_gate else 0) + self.residual_ngram_gate = CastedLinear(gate_in_dim, 1, bias=True) + if self.copy_cache_enabled: + self.copy_q = CastedLinear(model_dim, self.copy_cache_dim, bias=False) + self.copy_k = CastedLinear(model_dim, self.copy_cache_dim, bias=False) + self.copy_gate = CastedLinear(model_dim, 1, bias=True) + self._init_weights() + if self.residual_ngram_enabled: + nn.init.zeros_(self.residual_ngram_gate.weight) + if self.residual_ngram_gate.bias is not None: + nn.init.constant_(self.residual_ngram_gate.bias, self.residual_ngram_mix_init) + if self.copy_cache_enabled: + nn.init.zeros_(self.copy_gate.weight) + if self.copy_gate.bias is not None: + nn.init.constant_(self.copy_gate.bias, self.copy_cache_gate_init) + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_residual_ngram_logits(self, input_ids: Tensor) -> Tensor | None: + if not self.residual_ngram_enabled: + return None + prev1 = input_ids.reshape(-1) + ngram_logits: Tensor | None = None + if self.residual_bigram_rank > 0: + bg = self.residual_bigram_right(self.residual_bigram_left(prev1)) + ngram_logits = bg + if self.residual_trigram_rank > 0: + prev2_ids = torch.cat((input_ids[:, :1], input_ids[:, :-1]), dim=1).reshape(-1) + tri_feat = self.residual_trigram_prev1(prev1) * self.residual_trigram_prev2(prev2_ids) + tri = self.residual_trigram_right(tri_feat) + ngram_logits = tri if ngram_logits is None else (ngram_logits + tri) + if ngram_logits is None: + return None + return self.residual_ngram_scale * ngram_logits + + def _build_copy_cache_log_probs(self, hidden: Tensor, input_ids: Tensor, source_next_ids: Tensor) -> Tensor: + # hidden: [B, T, D], input_ids/source_next_ids: [B, T] + bsz, seqlen, _ = hidden.shape + q = self.copy_q(hidden).float() + k = self.copy_k(hidden).float() + scale = 1.0 / math.sqrt(float(self.copy_cache_dim)) + att = torch.matmul(q, k.transpose(1, 2)) * scale # [B, T, T] + + pos = torch.arange(seqlen, device=hidden.device) + t_pos = pos.view(1, seqlen, 1) + j_pos = pos.view(1, 1, seqlen) + causal = j_pos < t_pos + within = (t_pos - j_pos) <= self.copy_cache_window + mask = causal & within + att = att.masked_fill(~mask, float("-inf")) + no_source = ~mask.any(dim=-1, keepdim=True) + att = torch.where(no_source, torch.zeros_like(att), att) + att_prob = F.softmax(att, dim=-1).masked_fill(no_source, 0.0) + + copy_probs = torch.zeros((bsz, seqlen, self.tok_emb.num_embeddings), device=hidden.device, dtype=torch.float32) + copy_probs.scatter_add_( + 2, + source_next_ids.unsqueeze(1).expand(-1, seqlen, -1), + att_prob, + ) + return torch.log(copy_probs.clamp_min(1e-9)) + + def _compose_output_logits( + self, + logits_proj: Tensor, + input_ids: Tensor, + hidden: Tensor, + source_next_ids: Tensor | None = None, + ) -> tuple[Tensor, bool]: + neural_logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ngram_logits = self._compute_residual_ngram_logits(input_ids) + composed = neural_logits + if ngram_logits is not None: + # Stable residual composition in logit space. + flat_h = hidden.reshape(-1, hidden.size(-1)) + if self.ngram_entropy_gate: + # Cheap confidence signal: (logsumexp - max) = -log max_prob. Larger = less confident. + # Detached so the gate signal is stop-grad wrt the neural head (keeps semantics simple). + with torch.no_grad(): + n_logits_f = neural_logits.float() + lse = torch.logsumexp(n_logits_f, dim=-1, keepdim=True) + max_logit = n_logits_f.max(dim=-1, keepdim=True).values + neg_max_log_prob = (lse - max_logit).to(dtype=flat_h.dtype) + gate_input = torch.cat([flat_h, neg_max_log_prob], dim=-1) + gate = torch.sigmoid(self.residual_ngram_gate(gate_input)) + else: + gate = torch.sigmoid(self.residual_ngram_gate(flat_h)) + cap = self.ngram_softcap if self.ngram_softcap > 0.0 else self.logit_softcap + ngram_logits = cap * torch.tanh(ngram_logits / cap) + composed = composed + gate.to(dtype=composed.dtype) * ngram_logits.to(dtype=composed.dtype) + + if not self.copy_cache_enabled: + return composed, False + + if source_next_ids is None: + source_next_ids = torch.cat((input_ids[:, 1:], input_ids[:, -1:]), dim=1) + copy_log_probs = self._build_copy_cache_log_probs(hidden, input_ids, source_next_ids) + model_log_probs = F.log_softmax(composed.float().reshape(input_ids.size(0), input_ids.size(1), -1), dim=-1) + gate = torch.sigmoid(self.copy_gate(hidden).float()).clamp(min=1e-4, max=1.0 - 1e-4) + mixed_log_probs = torch.logaddexp( + torch.log1p(-gate) + model_log_probs, + torch.log(gate) + copy_log_probs, + ) + return mixed_log_probs.reshape(-1, mixed_log_probs.size(-1)).to(dtype=composed.dtype), True + + def _apply_loop_conditioning(self, x: Tensor, block_idx: int, step: int) -> Tensor: + """Apply JPCR blend or Ouroboros conditioning before a looped block execution.""" + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + predictor = self.jpcr_predictors[block_idx - self.intra_loop_start] + predicted_target, gate = predictor(x) + # Blend: nudge current state toward predicted target + x = x + gate * (predicted_target - x) + elif len(self.intra_loop_controllers) > 0: + ctrl = self.intra_loop_controllers[block_idx - self.intra_loop_start] + out = _run_ctrl_safe(ctrl, x, self.intra_loop_steps, self._intra_model_dim) + scale = out[:, step, 0, :].unsqueeze(1).to(dtype=x.dtype) + shift = out[:, step, 1, :].unsqueeze(1).to(dtype=x.dtype) + x = x * (1.0 + scale.tanh()) + shift + return x + + def _forward_hidden(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> Tensor: + x = self.tok_emb(input_ids) + if self.embed_scale: + x = x * self._embed_scale_factor + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + jpcr_runtime_active = self.jpcr_enabled if jpcr_runtime_active is None else bool(jpcr_runtime_active) + if self.use_recurrence: + for _ in range(self.recurrent_steps): + for block in self.blocks: + x, _ = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + n_rep = self.intra_loop_steps if (jpcr_runtime_active and self.intra_loop_start <= i <= self.intra_loop_end) else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + x = self._apply_loop_conditioning(x, i, s) + x, _ = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + j = self.num_encoder_layers + i + n_rep = self.intra_loop_steps if (jpcr_runtime_active and self.intra_loop_start <= j <= self.intra_loop_end) else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + x = self._apply_loop_conditioning(x, j, s) + x, _ = self.blocks[j](x, x0) + return self.final_norm(x) + + def _forward_hidden_with_intermediates(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> tuple[Tensor, list[Tensor]]: + """Forward pass capturing hidden states ONLY for looped blocks (NO loop, NO conditioning). + + Used by the EMA teacher to provide clean JEPA targets for JPCR predictors. + Runs each block exactly once — the teacher represents the "ideal" single-pass model. + Only captures intermediates for blocks in [intra_loop_start, intra_loop_end] to save memory. + Returns (final_hidden_after_norm, list_of_looped_block_hidden_states). + """ + x = self.tok_emb(input_ids) + if self.embed_scale: + x = x * self._embed_scale_factor + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + intermediates: list[Tensor] = [] + jpcr_runtime_active = self.jpcr_enabled if jpcr_runtime_active is None else bool(jpcr_runtime_active) + if self.use_recurrence: + for _ in range(self.recurrent_steps): + for block in self.blocks: + x, _ = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, _ = self.blocks[i](x, x0) + if jpcr_runtime_active and self.intra_loop_start <= i <= self.intra_loop_end: + intermediates.append(x) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + j = self.num_encoder_layers + i + x, _ = self.blocks[j](x, x0) + if jpcr_runtime_active and self.intra_loop_start <= j <= self.intra_loop_end: + intermediates.append(x) + return self.final_norm(x), intermediates + + def forward_hidden_and_output(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> tuple[Tensor, Tensor, bool]: + h = self._forward_hidden(input_ids, jpcr_runtime_active=jpcr_runtime_active) + flat_h = h.reshape(-1, h.size(-1)) + if self.tie_embeddings: + logits_proj = F.linear(flat_h, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(flat_h) + if self.bigram_rank > 0: + bg = self.bigram_right(self.bigram_left(input_ids.reshape(-1))) # [B*T, vocab] + logits_proj = logits_proj + self.bigram_scale * bg + logits, logits_are_log_probs = self._compose_output_logits(logits_proj, input_ids, h) + return h, logits, logits_are_log_probs + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning logits. NOTE: when self.copy_cache_enabled is True, + the returned tensor is log-probabilities (already log_softmax'd), not raw logits. + Callers that feed this into distillation must rely on student's logits_are_log_probs + flag to interpret format consistently (student and teacher share config).""" + _, logits, _ = self.forward_hidden_and_output(input_ids) + return logits + + def forward_logits_and_intermediates(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> tuple[Tensor, list[Tensor]]: + """Forward pass returning logits AND per-block hidden states for JPCR teacher. + Same format caveat as forward_logits: log-probs when copy_cache is enabled.""" + h, intermediates = self._forward_hidden_with_intermediates(input_ids, jpcr_runtime_active=jpcr_runtime_active) + flat_h = h.reshape(-1, h.size(-1)) + if self.tie_embeddings: + logits_proj = F.linear(flat_h, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(flat_h) + if self.bigram_rank > 0: + bg = self.bigram_right(self.bigram_left(input_ids.reshape(-1))) + logits_proj = logits_proj + self.bigram_scale * bg + logits, _ = self._compose_output_logits(logits_proj, input_ids, h) + return logits, intermediates + + def forward( + self, + input_ids: Tensor, + target_ids: Tensor, + loss_mask: Tensor | None = None, + per_token_weights: Tensor | None = None, + aux_targets: Tensor | None = None, + aux_weight: float = 0.0, + distill_teacher_logits: Tensor | None = None, + distill_weight: float = 0.0, + distill_temp: float = 1.0, + logit_reg_weight: float = 0.0, + jpcr_teacher_intermediates: list[Tensor] | None = (), + jpcr_weight: float = 0.0, + jpcr_runtime_active: bool = False, + ) -> Tensor: + if jpcr_teacher_intermediates is None: + jpcr_teacher_intermediates = () + x = self.tok_emb(input_ids) + if self.embed_scale: + x = x * self._embed_scale_factor + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + moe_z_loss: Tensor = x.new_zeros(()) # accumulates router Z-losses from all MoE blocks + jpcr_loss: Tensor = x.new_zeros(()) # accumulates JEPA MSE losses from JPCR predictors + jpcr_count: int = 0 # number of JPCR predictions for averaging + if self.use_recurrence: + for _ in range(self.recurrent_steps): + for block in self.blocks: + x, zl = block(x, x0) + moe_z_loss = moe_z_loss + zl + else: + skips: list[Tensor] = [] + # Only enable repeated intra-loop passes when loop conditioning is active. + # For JPCR this means post-distill runtime activation; for Ouroboros + # (controllers present) this remains active whenever configured. + loop_active = jpcr_runtime_active or len(self.intra_loop_controllers) > 0 + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + n_rep = (self.intra_loop_steps if self.intra_loop_start <= i <= self.intra_loop_end else 1) if loop_active else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + if jpcr_runtime_active: + predictor = self.jpcr_predictors[i - self.intra_loop_start] + predicted_target, gate = predictor(x) + # Always compute JPCR loss when teacher targets exist. + # jpcr_weight=0 before distill → no gradient impact. + # No branch on len(intermediates) to avoid torch.compile retrace. + target_idx = (i + s) - self.intra_loop_start + if target_idx < len(jpcr_teacher_intermediates): + teacher_target = jpcr_teacher_intermediates[target_idx] + jpcr_loss = jpcr_loss + predictor.compute_loss(predicted_target, teacher_target) + jpcr_count += 1 + x = x + gate * (predicted_target - x) + elif len(self.intra_loop_controllers) > 0: + ctrl = self.intra_loop_controllers[i - self.intra_loop_start] + out = _run_ctrl_safe(ctrl, x, self.intra_loop_steps, self._intra_model_dim) + scale = out[:, s, 0, :].unsqueeze(1).to(dtype=x.dtype) + shift = out[:, s, 1, :].unsqueeze(1).to(dtype=x.dtype) + x = x * (1.0 + scale.tanh()) + shift + x, zl = self.blocks[i](x, x0) + moe_z_loss = moe_z_loss + zl + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + j = self.num_encoder_layers + i + n_rep = (self.intra_loop_steps if self.intra_loop_start <= j <= self.intra_loop_end else 1) if loop_active else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + if jpcr_runtime_active: + predictor = self.jpcr_predictors[j - self.intra_loop_start] + predicted_target, gate = predictor(x) + target_idx = (j + s) - self.intra_loop_start + if target_idx < len(jpcr_teacher_intermediates): + teacher_target = jpcr_teacher_intermediates[target_idx] + jpcr_loss = jpcr_loss + predictor.compute_loss(predicted_target, teacher_target) + jpcr_count += 1 + x = x + gate * (predicted_target - x) + elif len(self.intra_loop_controllers) > 0: + ctrl = self.intra_loop_controllers[j - self.intra_loop_start] + out = _run_ctrl_safe(ctrl, x, self.intra_loop_steps, self._intra_model_dim) + scale = out[:, s, 0, :].unsqueeze(1).to(dtype=x.dtype) + shift = out[:, s, 1, :].unsqueeze(1).to(dtype=x.dtype) + x = x * (1.0 + scale.tanh()) + shift + x, zl = self.blocks[j](x, x0) + moe_z_loss = moe_z_loss + zl + + h = self.final_norm(x) + flat_h = h.reshape(-1, h.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(flat_h, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(flat_h) + # Low-rank bigram bias: cheap learned n-gram prior on top of contextual representation. + if self.bigram_rank > 0: + bg = self.bigram_right(self.bigram_left(input_ids.reshape(-1))) # [B*T, vocab] + logits_proj = logits_proj + self.bigram_scale * bg + logits, logits_are_log_probs = self._compose_output_logits( + logits_proj, + input_ids, + h, + source_next_ids=target_ids, + ) + if logits_are_log_probs: + base_per_token = F.nll_loss(logits.float(), targets, reduction="none") # [B*T] + else: + base_per_token = F.cross_entropy(logits.float(), targets, reduction="none") # [B*T] + weighted = base_per_token + norm = torch.ones((), device=base_per_token.device, dtype=base_per_token.dtype) * base_per_token.numel() + if per_token_weights is not None: + token_w = per_token_weights.reshape(-1).to(base_per_token.dtype) + weighted = weighted * token_w + norm = token_w.sum().clamp(min=1) + if loss_mask is not None: + mask = loss_mask.reshape(-1).to(base_per_token.dtype) + weighted = weighted * mask + if per_token_weights is None: + norm = mask.sum().clamp(min=1) + else: + norm = (token_w * mask).sum().clamp(min=1) + base_loss = weighted.sum() / norm + + total_loss = base_loss + + if self.dual_head is not None and aux_targets is not None and aux_weight > 0.0: + aux_logits = self.dual_head(flat_h) # [B*T, C] + aux_flat_targets = aux_targets.reshape(-1) + aux_per_token = F.cross_entropy(aux_logits.float(), aux_flat_targets, reduction="none") + if loss_mask is not None: + mask = loss_mask.reshape(-1).to(aux_per_token.dtype) + aux_loss = (aux_per_token * mask).sum() / mask.sum().clamp(min=1) + else: + aux_loss = aux_per_token.mean() + total_loss = total_loss + float(aux_weight) * aux_loss + elif self.dual_head is not None: + # Safety touch keeps dual-head params in graph when auxiliary loss is inactive. + total_loss = total_loss + 0.0 * ( + self.dual_head.weight.reshape(-1)[0].float() + + (self.dual_head.bias.reshape(-1)[0].float() if self.dual_head.bias is not None else 0.0) + ) + + if logit_reg_weight > 0.0: + total_loss = total_loss + float(logit_reg_weight) * logits_proj.float().pow(2).mean() + + if distill_teacher_logits is not None and distill_teacher_logits.numel() > 0 and distill_weight > 0.0: + temp = max(float(distill_temp), 1e-4) + if logits_are_log_probs: + # Both student and teacher share config (EMA teacher). When copy_cache is + # enabled, both emit log-probs, so teacher must be exp()'d to probs. + # Temperature scaling is skipped (would need renormalization in prob space). + student_log_probs = logits.float() + teacher_probs = distill_teacher_logits.float().exp() + else: + student = (logits.float() / temp) + teacher = (distill_teacher_logits.float() / temp) + student_log_probs = F.log_softmax(student, dim=-1) + teacher_probs = F.softmax(teacher, dim=-1) + if loss_mask is not None: + mask = loss_mask.reshape(-1) > 0 + student_log_probs = student_log_probs[mask] + teacher_probs = teacher_probs[mask] + kl = F.kl_div( + student_log_probs, + teacher_probs, + reduction="batchmean", + ) * (temp * temp if not logits_are_log_probs else 1.0) + total_loss = total_loss + float(distill_weight) * kl + + # JPCR (JEPA Predictive Coding Recurrence) loss: average MSE across all predictor outputs. + # Always add the term (no branch on jpcr_weight) to keep torch.compile graph constant. + # When jpcr_weight=0.0 (before distill), the multiplication zeros out the gradient. + if jpcr_count > 0: + total_loss = total_loss + float(jpcr_weight) * (jpcr_loss / jpcr_count) + total_loss = total_loss + 0.0 * jpcr_loss + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + # Safety touch keeps ALL JPCR params in graph every step (zero gradient where unused). + # This supports DDP find_unused_parameters=False with conditional JPCR execution. + for p in self.jpcr_predictors.parameters(): + total_loss = total_loss + 0.0 * p.reshape(-1)[0].float() + + # MoE router Z-loss — only during training (loss_mask is None means no sliding-window eval mask). + # Follows the same pattern as MTP (excluded during eval to keep val_bpb clean). + if self._has_moe and self.moe_aux_loss_coeff > 0.0 and loss_mask is None: + total_loss = total_loss + self.moe_aux_loss_coeff * moe_z_loss + + # Keep eval metric comparable by applying MTP only when loss_mask is not provided. + if not self.mtp_enabled or self.mtp_weight <= 0.0 or loss_mask is not None: + return total_loss + + _, seqlen = target_ids.shape + weighted_aux = torch.zeros((), device=base_loss.device, dtype=base_loss.dtype) + weight_sum = torch.zeros((), device=base_loss.device, dtype=base_loss.dtype) + if self.mtp_branches is not None: + for step_idx in range(self.mtp_steps): + horizon = step_idx + 1 # 1 predicts token at t+2, 2 predicts t+3, ... + if seqlen - horizon <= 0: + continue + branch_h = self.mtp_branches[step_idx](h[:, : seqlen - horizon, :]) + branch_flat_h = branch_h.reshape(-1, branch_h.size(-1)) + future_targets = target_ids[:, horizon:].reshape(-1) + if self.mtp_heads is None: + aux_logits_proj = F.linear(branch_flat_h, self.tok_emb.weight) + else: + aux_logits_proj = self.mtp_heads[step_idx](branch_flat_h) + aux_logits = self.logit_softcap * torch.tanh(aux_logits_proj / self.logit_softcap) + aux_loss = F.cross_entropy(aux_logits.float(), future_targets, reduction="mean") + w = self.mtp_step_weights[step_idx].to(dtype=weighted_aux.dtype) + weighted_aux = weighted_aux + aux_loss.to(weighted_aux.dtype) * w + weight_sum = weight_sum + w + + aux_loss = weighted_aux / weight_sum.clamp_min(1e-12) + return total_loss + self.mtp_weight * aux_loss + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.quant_scheme not in SUPPORTED_QUANT_SCHEMES: + raise ValueError(f"Unsupported QUANT_SCHEME={args.quant_scheme!r}; expected one of {sorted(SUPPORTED_QUANT_SCHEMES)}") + if args.compressor not in SUPPORTED_COMPRESSORS: + raise ValueError(f"Unsupported COMPRESSOR={args.compressor!r}; expected one of {sorted(SUPPORTED_COMPRESSORS)}") + if args.weight_order not in SUPPORTED_WEIGHT_ORDERS: + raise ValueError(f"Unsupported WEIGHT_ORDER={args.weight_order!r}; expected one of {sorted(SUPPORTED_WEIGHT_ORDERS)}") + if args.mixed_low_precision_scheme not in {"int8", "int5", "int4"}: + raise ValueError( + f"Unsupported MIXED_LOW_PRECISION_SCHEME={args.mixed_low_precision_scheme!r}; expected 'int8', 'int5', or 'int4'" + ) + sweep_specs = resolve_eval_sweep_specs(args) + blend_specs, blend_weights = resolve_eval_blend_specs(args) + max_eval_seq_len = resolve_max_eval_seq_len(args, sweep_specs, blend_specs) + train_loss_mask_stride_frac = resolve_train_loss_mask_stride_frac(args) + if args.final_eval_mode not in {"primary", "blend"}: + raise ValueError(f"Unsupported FINAL_EVAL_MODE={args.final_eval_mode!r}; expected 'primary' or 'blend'") + if args.final_eval_mode == "blend" and not blend_specs: + raise ValueError("FINAL_EVAL_MODE=blend requires EVAL_BLEND_SEQ_LENS to be set") + + # ----------------------------- + # DISTRIBUTED + DEVICE SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + device_override = os.environ.get("DEVICE", "").strip().lower() + grad_accum_override = os.environ.get("GRAD_ACCUM_STEPS", "").strip() + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if grad_accum_override: + grad_accum_steps = int(grad_accum_override) + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + else: + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 for default grad accumulation; " + "set GRAD_ACCUM_STEPS explicitly to override" + ) + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + tokens_per_microstep = world_size * grad_accum_steps * args.train_seq_len + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE*GRAD_ACCUM_STEPS*TRAIN_SEQ_LEN; " + f"got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + if device_override: + if device_override == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("DEVICE=cuda requested but CUDA is unavailable") + if device_override not in {"cpu", "cuda"}: + raise ValueError(f"Unsupported DEVICE={device_override!r}; expected 'cpu' or 'cuda'") + device = torch.device(device_override, local_rank) if device_override == "cuda" else torch.device("cpu") + else: + device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu") + if device.type == "cuda": + torch.cuda.set_device(device) + autocast_enabled = device.type == "cuda" + use_compile = bool(int(os.environ.get("USE_TORCH_COMPILE", "1" if device.type == "cuda" else "0"))) + compile_dynamic_mode_raw = os.environ.get("TORCH_COMPILE_DYNAMIC", "true").strip().lower() + if compile_dynamic_mode_raw in {"1", "true", "yes", "on"}: + compile_dynamic: bool | None = True + elif compile_dynamic_mode_raw in {"0", "false", "no", "off"}: + compile_dynamic = False + elif compile_dynamic_mode_raw in {"none", "auto", "default", ""}: + compile_dynamic = None + else: + raise ValueError( + f"Unsupported TORCH_COMPILE_DYNAMIC={compile_dynamic_mode_raw!r}; expected true|false|none" + ) + if use_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + if distributed: + if device.type == "cuda": + dist.init_process_group(backend="nccl", device_id=device) + else: + dist.init_process_group(backend="gloo") + dist.barrier() + master_process = rank == 0 + + sdp_backends_log = "cpu" + if device.type == "cuda": + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + # Some consumer GPUs and GQA configs do not support flash-only SDPA. + # Default to "auto" so CUDA kernels can fall back to math/mem-efficient. + sdp_backend_mode = os.environ.get("SDP_BACKEND_MODE", "auto").strip().lower() + if sdp_backend_mode == "flash": + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + sdp_backends_log = "cudnn=False flash=True mem_efficient=False math=False mode=flash" + elif sdp_backend_mode == "math": + enable_cudnn_sdp(False) + enable_flash_sdp(False) + enable_mem_efficient_sdp(False) + enable_math_sdp(True) + sdp_backends_log = "cudnn=False flash=False mem_efficient=False math=True mode=math" + elif sdp_backend_mode == "auto": + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + sdp_backends_log = "cudnn=False flash=True mem_efficient=True math=True mode=auto" + else: + raise ValueError( + f"Unsupported SDP_BACKEND_MODE={sdp_backend_mode!r}; expected 'auto', 'flash', or 'math'" + ) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + f"device:{device} distributed:{distributed} use_torch_compile:{use_compile} " + f"torch_compile_dynamic:{compile_dynamic}", + console=False, + ) + if device.type == "cuda": + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, max_eval_seq_len) + if args.val_max_tokens > 0: + usable = (min(args.val_max_tokens, val_tokens.numel() - 1) // max_eval_seq_len) * max_eval_seq_len + if usable <= 0: + raise ValueError( + f"VAL_MAX_TOKENS={args.val_max_tokens} is too small for MAX_EVAL_SEQ_LEN={max_eval_seq_len}" + ) + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0( + f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1} " + f"val_max_tokens:{args.val_max_tokens if args.val_max_tokens > 0 else 'full'}" + ) + _, primary_eval_seq_len, primary_eval_rope_scale = resolve_primary_eval_spec(args) + log0( + f"eval_primary: seq_len:{primary_eval_seq_len} rope_scale:{primary_eval_rope_scale:.4f} " + f"stride_frac:{args.eval_stride_frac:.4f} final_eval_mode:{args.final_eval_mode}" + ) + if len(sweep_specs) > 1: + sweep_specs_log = ",".join( + f"{name}:{seq_len}@{rope_scale:.4f}" + for name, seq_len, rope_scale in sweep_specs[1:] + ) + log0(f"eval_sweep: specs:{sweep_specs_log}") + if blend_specs: + blend_stride_frac = args.eval_blend_stride_frac if args.eval_blend_stride_frac > 0.0 else args.eval_stride_frac + blend_specs_log = ",".join( + f"{name}:{seq_len}@{rope_scale:.4f}" + for name, seq_len, rope_scale in blend_specs + ) + blend_weights_log = ",".join(f"{weight:.6f}" for weight in blend_weights) + log0( + f"eval_blend: stride_frac:{blend_stride_frac:.4f} specs:{blend_specs_log} " + f"weights:{blend_weights_log} position_bias:{args.eval_blend_position_bias:.4f} " + f"position_power:{args.eval_blend_position_power:.4f}" + ) + log0( + f"eval_cont_cache: enabled:{int(args.eval_cont_cache_enabled)} " + f"window:{args.eval_cont_cache_window} topk:{args.eval_cont_cache_topk} " + f"weight:{args.eval_cont_cache_weight:.4f} logit_scale:{args.eval_cont_cache_logit_scale:.4f} " + f"conf_power:{args.eval_cont_cache_conf_power:.4f} batch_seqs:{args.eval_cont_cache_batch_seqs}" + ) + log0( + f"train_loss_mask: enabled:{int(args.train_loss_mask_enabled)} " + f"stride_frac:{train_loss_mask_stride_frac:.4f}" + ) + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + # Enable LSQ fake-quant allocation on CastedLinear BEFORE model construction so + # each CastedLinear gains a per-row learnable qat_log_scale parameter automatically. + CastedLinear.qat_lsq_enabled = bool(args.qat_lsq) + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + recurrent_core_layers=args.recurrent_core_layers, + recurrent_steps=args.recurrent_steps, + share_ffn_across_blocks=args.share_ffn_across_blocks, + intra_loop_start=args.intra_loop_start, + intra_loop_end=args.intra_loop_end, + intra_loop_steps=args.intra_loop_steps, + use_parallel_residual=args.use_parallel_residual, + use_swiglu=args.use_swiglu, + bigram_rank=args.bigram_rank, + mtp_enabled=args.mtp_enabled, + mtp_steps=args.mtp_steps, + mtp_weight=args.mtp_weight, + mtp_decay=args.mtp_decay, + mtp_tie_embeddings=args.mtp_tie_embeddings, + use_ssm=args.use_ssm, + ssm_every_n=args.ssm_every_n, + ssm_expand=args.ssm_expand, + ssm_kernel=args.ssm_kernel, + ssm_impl=args.ssm_impl, + mamba3_d_state=args.mamba3_d_state, + mamba3_head_dim=args.mamba3_head_dim, + mamba3_is_mimo=args.mamba3_is_mimo, + mamba3_mimo_rank=args.mamba3_mimo_rank, + mamba3_chunk_size=args.mamba3_chunk_size, + mamba3_outproj_norm=args.mamba3_outproj_norm, + residual_ngram_enabled=args.residual_ngram_enabled, + residual_bigram_rank=args.residual_bigram_rank, + residual_trigram_rank=args.residual_trigram_rank, + residual_ngram_mix_init=args.residual_ngram_mix_init, + ngram_softcap=args.ngram_softcap, + ngram_entropy_gate=args.ngram_entropy_gate, + copy_cache_enabled=args.copy_cache_enabled, + copy_cache_window=args.copy_cache_window, + copy_cache_dim=args.copy_cache_dim, + copy_cache_gate_init=args.copy_cache_gate_init, + moe_num_experts=args.moe_num_experts, + moe_every_n=args.moe_every_n, + moe_capacity_factor=args.moe_capacity_factor, + moe_aux_loss_coeff=args.moe_aux_loss_coeff, + dual_head_enabled=args.dual_head_enabled, + dual_head_num_classes=4, + jpcr_enabled=args.jpcr_enabled, + jpcr_hidden=args.jpcr_hidden, + jpcr_proj_dim=args.jpcr_proj_dim, + jpcr_blend_init=args.jpcr_blend_init, + use_sandwich_norm=args.use_sandwich_norm, + embed_scale=args.embed_scale, + ).to(device=device, dtype=torch.bfloat16 if autocast_enabled else torch.float32) + if autocast_enabled: + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if _OfficialMamba3 is not None and isinstance(module, _OfficialMamba3): + module.float() + restore_low_dim_params_to_fp32(base_model) + if use_compile: + # Disable DDPOptimizer: it splits compiled graphs at DDP bucket boundaries and + # crashes with `AttributeError: 'int' object has no attribute 'meta'` when plain + # Python int instance attrs (num_heads, head_dim) are captured as symbolic inputs + # to a subgraph. With world_size=1 the optimisation is a no-op anyway. + torch._dynamo.config.optimize_ddp = False + compiled_model = torch.compile(base_model, dynamic=compile_dynamic) if use_compile else base_model + model: nn.Module + if distributed: + ddp_find_unused_override = os.environ.get("DDP_FIND_UNUSED_PARAMETERS", "").strip().lower() + # find_unused_parameters=True is required when QAT_LSQ=1 because + # qat_log_scale params are registered but sit idle until QAT activates. + # Dual-head and JPCR are safety-touched in loss so they remain in graph with zero grads. + if ddp_find_unused_override in {"1", "true", "yes", "on"}: + _ddp_find_unused = True + elif ddp_find_unused_override in {"0", "false", "no", "off"}: + _ddp_find_unused = False + elif ddp_find_unused_override in {"", "auto", "default"}: + _ddp_find_unused = bool(args.qat_lsq) + else: + raise ValueError( + f"Unsupported DDP_FIND_UNUSED_PARAMETERS={ddp_find_unused_override!r}; expected true|false|auto" + ) + log0(f"ddp_find_unused_parameters:{int(_ddp_find_unused)}", console=False) + model = ( + DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=_ddp_find_unused) + if device.type == "cuda" + else DDP(compiled_model, broadcast_buffers=False, find_unused_parameters=_ddp_find_unused) + ) + else: + model = compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if (p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) + and not name.endswith("qat_log_scale") + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if args.bigram_rank > 0: + bigram_params = [base_model.bigram_left.weight, base_model.bigram_right.weight, base_model.bigram_scale] + optimizer_bigram = torch.optim.Adam( + [{"params": bigram_params, "lr": args.bigram_lr, "base_lr": args.bigram_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_bigram) + if args.residual_ngram_enabled and getattr(base_model, "residual_ngram_enabled", False): + residual_params: list[nn.Parameter] = [ + base_model.residual_ngram_scale, + base_model.residual_ngram_gate.weight, + ] + if base_model.residual_ngram_gate.bias is not None: + residual_params.append(base_model.residual_ngram_gate.bias) + if base_model.residual_bigram_rank > 0: + residual_params.extend([base_model.residual_bigram_left.weight, base_model.residual_bigram_right.weight]) + if base_model.residual_trigram_rank > 0: + residual_params.extend( + [ + base_model.residual_trigram_prev1.weight, + base_model.residual_trigram_prev2.weight, + base_model.residual_trigram_right.weight, + ] + ) + optimizer_residual = torch.optim.Adam( + [{"params": residual_params, "lr": args.residual_ngram_lr, "base_lr": args.residual_ngram_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_residual) + if args.copy_cache_enabled and getattr(base_model, "copy_cache_enabled", False): + copy_params: list[nn.Parameter] = [ + base_model.copy_q.weight, + base_model.copy_k.weight, + base_model.copy_gate.weight, + ] + if base_model.copy_gate.bias is not None: + copy_params.append(base_model.copy_gate.bias) + optimizer_copy = torch.optim.Adam( + [{"params": copy_params, "lr": args.copy_cache_lr, "base_lr": args.copy_cache_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_copy) + if args.dual_head_enabled and getattr(base_model, "dual_head", None) is not None: + dual_params = [base_model.dual_head.weight] + if base_model.dual_head.bias is not None: + dual_params.append(base_model.dual_head.bias) + optimizer_dual = torch.optim.Adam( + [{"params": dual_params, "lr": args.dual_head_lr, "base_lr": args.dual_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_dual) + if args.mtp_enabled and base_model.mtp_branches is not None: + mtp_params: list[nn.Parameter] = [] + for branch in base_model.mtp_branches: + mtp_params.extend(list(branch.parameters())) + if base_model.mtp_heads is not None: + for head in base_model.mtp_heads: + mtp_params.extend(list(head.parameters())) + if mtp_params: + optimizer_mtp = torch.optim.Adam( + [{"params": mtp_params, "lr": args.mtp_lr, "base_lr": args.mtp_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_mtp) + # JPCR predictor optimizer (also covers Ouroboros controllers if used) + if base_model.jpcr_enabled and len(base_model.jpcr_predictors) > 0: + jpcr_params: list[nn.Parameter] = list(base_model.jpcr_predictors.parameters()) + if jpcr_params: + optimizer_jpcr = torch.optim.Adam( + [{"params": jpcr_params, "lr": args.jpcr_lr, "base_lr": args.jpcr_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_jpcr) + elif len(base_model.intra_loop_controllers) > 0: + # Ouroboros controllers need an optimizer too (was missing before!) + ctrl_params: list[nn.Parameter] = list(base_model.intra_loop_controllers.parameters()) + if ctrl_params: + optimizer_ctrl = torch.optim.Adam( + [{"params": ctrl_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_ctrl) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.insert(1, optimizer_head) + + # Dedicated optimizer for LSQ per-row log_scale parameters across the WHOLE model. + # These are 1D learnable steps inside every CastedLinear (blocks + lm_head + bigram + ...), + # not all of which would otherwise land in scalar_params (which only walks blocks). + optimizer_lsq: torch.optim.Optimizer | None = None + if args.qat_lsq: + lsq_params: list[nn.Parameter] = [ + m.qat_log_scale + for m in base_model.modules() + if isinstance(m, CastedLinear) and m.qat_log_scale is not None + ] + if lsq_params: + lsq_lr = float(os.environ.get("QAT_LSQ_LR", str(args.scalar_lr))) + optimizer_lsq = torch.optim.Adam( + [{"params": lsq_params, "lr": lsq_lr, "base_lr": lsq_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_lsq) + if master_process: + log0(f"qat_lsq: optimizer params={len(lsq_params)} lr={lsq_lr}") + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:{sdp_backends_log}") + attention_mode = "mha" if args.num_kv_heads == args.num_heads else "gqa" + log0( + f"attention_mode:{attention_mode} num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads} " + f"use_swiglu:{args.use_swiglu} use_ssm:{args.use_ssm} ssm_every_n:{args.ssm_every_n} " + f"ssm_impl:{args.ssm_impl} ssm_expand:{args.ssm_expand} ssm_kernel:{args.ssm_kernel} " + f"mamba3_d_state:{args.mamba3_d_state} mamba3_head_dim:{args.mamba3_head_dim} " + f"mamba3_is_mimo:{args.mamba3_is_mimo} mamba3_mimo_rank:{args.mamba3_mimo_rank} " + f"mamba3_chunk_size:{args.mamba3_chunk_size} mamba3_outproj_norm:{args.mamba3_outproj_norm} " + f"mtp_enabled:{args.mtp_enabled} mtp_steps:{args.mtp_steps} mtp_weight:{args.mtp_weight} " + f"mtp_decay:{args.mtp_decay} mtp_tie_embeddings:{args.mtp_tie_embeddings} " + f"distill_enabled:{args.distill_enabled} distill_start_frac:{args.distill_start_frac} " + f"distill_start_step:{args.distill_start_step} distill_start_wallclock_frac:{args.distill_start_wallclock_frac} " + f"distill_weight:{args.distill_weight} distill_temp:{args.distill_temp} distill_ema_decay:{args.distill_ema_decay} " + f"jpcr_apply_every:{args.jpcr_apply_every} " + f"logit_reg_weight:{args.logit_reg_weight} byte_weighted_loss:{args.byte_weighted_loss_enabled} " + f"byte_weighted_loss_alpha:{args.byte_weighted_loss_alpha} " + f"residual_ngram_enabled:{args.residual_ngram_enabled} residual_bigram_rank:{args.residual_bigram_rank} " + f"residual_trigram_rank:{args.residual_trigram_rank} residual_ngram_lr:{args.residual_ngram_lr} " + f"residual_ngram_mix_init:{args.residual_ngram_mix_init} " + f"ngram_softcap:{args.ngram_softcap} ngram_entropy_gate:{args.ngram_entropy_gate} " + f"ttt_enabled:{args.ttt_enabled} ttt_lr:{args.ttt_lr} ttt_steps:{args.ttt_steps} ttt_momentum:{args.ttt_momentum} " + f"copy_cache_enabled:{args.copy_cache_enabled} copy_cache_window:{args.copy_cache_window} " + f"copy_cache_dim:{args.copy_cache_dim} copy_cache_lr:{args.copy_cache_lr} " + f"copy_cache_gate_init:{args.copy_cache_gate_init} " + f"dual_head_enabled:{args.dual_head_enabled} dual_head_weight:{args.dual_head_weight} " + f"dual_head_start_frac:{args.dual_head_start_frac} dual_head_lr:{args.dual_head_lr} " + f"qat_scheme:{args.qat_scheme} qat_start_step:{args.qat_start_step} qat_end_step:{args.qat_end_step} " + f"qat_start_wallclock_frac:{args.qat_start_wallclock_frac} qat_end_wallclock_frac:{args.qat_end_wallclock_frac} " + f"moe_num_experts:{args.moe_num_experts} moe_every_n:{args.moe_every_n} " + f"moe_capacity_factor:{args.moe_capacity_factor} moe_aux_loss_coeff:{args.moe_aux_loss_coeff} " + f"num_moe_blocks:{base_model.num_moe_blocks}" + ) + if base_model.use_recurrence: + log0( + f"architecture:recurrent core_layers:{base_model.recurrent_core_layers} " + f"recurrent_steps:{base_model.recurrent_steps} " + f"effective_layers:{base_model.total_effective_layers} " + f"ssm_blocks:{base_model.num_ssm_blocks} attn_blocks:{base_model.num_attn_blocks} " + f"share_ffn_across_blocks:{base_model.share_ffn_across_blocks}" + ) + else: + intra_info = ( + f" intra_loop:[{base_model.intra_loop_start}-{base_model.intra_loop_end}]x{base_model.intra_loop_steps}" + f" effective_layers:{base_model.total_effective_layers}" + if base_model.intra_loop_start >= 0 else "" + ) + jpcr_info = ( + f" jpcr:hidden={args.jpcr_hidden},weight={args.jpcr_weight},blend_init={args.jpcr_blend_init},lr={args.jpcr_lr}" + if base_model.jpcr_enabled else "" + ) + log0( + f"architecture:stacked num_layers:{args.num_layers} " + f"encoder_layers:{base_model.num_encoder_layers} decoder_layers:{base_model.num_decoder_layers} " + f"ssm_blocks:{base_model.num_ssm_blocks} attn_blocks:{base_model.num_attn_blocks}" + f"{intra_info}{jpcr_info}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} mtp_lr:{args.mtp_lr if args.mtp_enabled else 0.0} " + f"copy_cache_lr:{args.copy_cache_lr if args.copy_cache_enabled else 0.0} " + f"dual_head_lr:{args.dual_head_lr if args.dual_head_enabled else 0.0}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + log0("Initializing DistributedTokenLoader...") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + train_loss_mask_cache: dict[int, Tensor] = {} + + def build_train_loss_mask(batch_size: int, seq_len: int) -> Tensor | None: + if not args.train_loss_mask_enabled: + return None + mask_cpu = train_loss_mask_cache.get(seq_len) + if mask_cpu is None: + mask_cpu, _, _ = build_loss_mask_cpu(seq_len, train_loss_mask_stride_frac) + train_loss_mask_cache[seq_len] = mask_cpu + return mask_cpu.unsqueeze(0).expand(batch_size, -1).to(device=device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + log0("Saving initial model and optimizer states for warmup...") + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + warmup_reason = "torch.compile/TileLang" if use_compile else "TileLang/custom kernels" + log0(f"Starting warmup loop ({args.warmup_steps} steps). The first step may compile {warmup_reason} kernels...") + # Pre-build dummy tensors matching the main training loop signature so that + # torch.compile traces the correct graph during warmup (no re-trace at step 1). + _warmup_n_jpcr = (base_model.intra_loop_end - base_model.intra_loop_start + 1) if base_model.jpcr_enabled else 0 + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + warmup_loss_mask = build_train_loss_mask(x.size(0), args.train_seq_len) + # Use the same kwargs signature as the main loop so compile doesn't retrace later. + _wu_teacher_logits: Tensor = torch.empty(0, device=device) + _wu_intermediates: list[Tensor] = [ + torch.zeros(x.size(0), args.train_seq_len, args.model_dim, device=device, dtype=torch.bfloat16) + for _ in range(_warmup_n_jpcr) + ] if _warmup_n_jpcr > 0 else [] + # Dummy per_token_weights / aux_targets so warmup traces the same graph + # as the main loop (some configs pass non-None here — traced branches + # differ, so include them unconditionally to avoid retracing on step 1). + _wu_token_weights = torch.ones_like(y, dtype=torch.float32) if args.byte_weighted_loss_enabled else None + _wu_aux_targets = torch.zeros_like(y, dtype=torch.long) if args.dual_head_enabled else None + _wu_aux_weight = 0.0 + if autocast_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model( + x, y, + loss_mask=warmup_loss_mask, + per_token_weights=_wu_token_weights, + aux_targets=_wu_aux_targets, + aux_weight=_wu_aux_weight, + distill_teacher_logits=_wu_teacher_logits, + distill_weight=0.0, + distill_temp=args.distill_temp, + logit_reg_weight=0.0, + jpcr_teacher_intermediates=_wu_intermediates, + jpcr_weight=0.0, + jpcr_runtime_active=False, + ) + else: + warmup_loss = model( + x, y, + loss_mask=warmup_loss_mask, + per_token_weights=_wu_token_weights, + aux_targets=_wu_aux_targets, + aux_weight=_wu_aux_weight, + distill_teacher_logits=_wu_teacher_logits, + distill_weight=0.0, + distill_temp=args.distill_temp, + logit_reg_weight=0.0, + jpcr_teacher_intermediates=_wu_intermediates, + jpcr_weight=0.0, + jpcr_runtime_active=False, + ) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if warmup_step == 0 or args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + distill_start_step = resolve_distill_start_step(args) + dual_head_start_step = int(max(0.0, min(1.0, args.dual_head_start_frac)) * args.iterations) + ema_teacher: GPT | None = None + if args.distill_enabled and args.distill_weight > 0.0: + ema_teacher = copy.deepcopy(base_model) + ema_teacher.eval() + for p in ema_teacher.parameters(): + p.requires_grad_(False) + if args.distill_enabled and args.distill_weight > 0.0: + distill_mode = ( + f"step:{args.distill_start_step}" + if args.distill_start_step >= 0 + else ( + f"wallclock_frac:{max(0.0, min(1.0, args.distill_start_wallclock_frac)):.4f}" + if args.distill_start_wallclock_frac >= 0.0 and max_wallclock_ms is not None + else f"iter_frac:{max(0.0, min(1.0, args.distill_start_frac)):.4f}" + ) + ) + log0(f"distill_start: mode:{distill_mode} resolved_step:{distill_start_step}") + if args.jpcr_apply_every > 1: + log0(f"jpcr_apply_every:{args.jpcr_apply_every} (distill+JPCR applied every Nth step)") + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + # SWA state: accumulated on CPU to avoid GPU memory pressure. + swa_state: dict[str, torch.Tensor] | None = None + swa_count = 0 + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + if device.type == "cuda": + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + # Load SWA-averaged weights before eval + export (better generalization + quantization). + if args.swa_enabled and swa_state is not None: + log0(f"swa: loading averaged weights from {swa_count} snapshots") + cur_dtypes = {k: v.dtype for k, v in base_model.state_dict().items()} + swa_load = {k: v.to(device=device, dtype=cur_dtypes[k]) for k, v in swa_state.items() if k in cur_dtypes} + # strict=False because qat_log_scale entries are intentionally excluded from swa_state. + base_model.load_state_dict(swa_load, strict=not args.qat_lsq) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # SWA: once warmdown begins (scale < 1), start averaging weights on CPU every N steps. + # qat_log_scale params are intentionally excluded: SWA would corrupt them by averaging + # scales from different QAT level regimes (256/64/16). The final trained scales are kept. + if args.swa_enabled and scale < 1.0 and step % args.swa_collect_every == 0: + swa_snapshot = { + k: v.detach().cpu().float().clone() + for k, v in base_model.state_dict().items() + if not k.endswith(".qat_log_scale") + } + if swa_state is None: + swa_state = swa_snapshot + swa_count = 1 + else: + inv = 1.0 / (swa_count + 1) + for k, v in swa_snapshot.items(): + if k in swa_state: + swa_state[k].mul_(1.0 - inv).add_(v, alpha=inv) + swa_count += 1 + + # QAT: enable fake-quantisation once model has partially converged. + # int8: single stage at qat_start_step (levels=256). + # int4: 3-stage progressive schedule starting at qat_start_step: + # stage 0 (<33% of QAT window): levels=256 (gentle, int8-equivalent) + # stage 1 (33-67% of QAT window): levels=64 + # stage 2 (>67% of QAT window): levels=16 (true int4) + # Progressive avoids the catastrophic loss spike from jumping straight to 16 levels. + if args.qat_scheme != "none": + target_levels, qat_mode = qat_target_levels(args, step, elapsed_ms, max_wallclock_ms) + if CastedLinear.qat_levels != target_levels: + prev_levels = CastedLinear.qat_levels + CastedLinear.qat_levels = target_levels + log0( + f"qat: {'enabled' if target_levels > 0 else 'disabled'} levels:{target_levels} " + f"step:{step} elapsed_ms:{elapsed_ms:.0f} mode:{qat_mode}" + ) + # LSQ: on the transition from 0 → nonzero, seed per-row log-scales from + # the current weight statistics (max-abs / half). Also reseed on each + # progressive level change so the learned scales start from a valid grid + # for the new quantisation resolution. + if args.qat_lsq and target_levels > 0 and prev_levels != target_levels: + n_lsq = init_lsq_scales(base_model, target_levels) + log0(f"qat: lsq_init count:{n_lsq} levels:{target_levels}") + # Clear stale Adam momentum/variance from the previous level regime + # so the fresh scale values get unbiased gradient updates. + if optimizer_lsq is not None: + optimizer_lsq.state.clear() + log0(f"qat: lsq_state_reset levels:{target_levels}") + + # Sequence length curriculum: ramp from curriculum_min_seq_len → train_seq_len. + if args.curriculum_enabled and step < args.curriculum_steps: + frac_c = step / max(args.curriculum_steps, 1) + curr_seq_len = args.curriculum_min_seq_len + int((args.train_seq_len - args.curriculum_min_seq_len) * frac_c) + curr_seq_len = 1 << int(math.log2(max(64, curr_seq_len))) + else: + curr_seq_len = args.train_seq_len + + distill_active = ( + ema_teacher is not None + and args.distill_weight > 0.0 + and distill_is_active(args, step, elapsed_ms, max_wallclock_ms, distill_start_step) + ) + apply_distill_this_step = bool(distill_active and (step % args.jpcr_apply_every == 0)) + jpcr_runtime_active = bool(base_model.jpcr_enabled and apply_distill_this_step) + # JPCR loss warmup: ramp weight from 0 → full over jpcr_warmup_steps after distill activates. + # Also freeze blend gates for first 300 steps so predictors learn via loss before affecting forward pass. + if distill_active and base_model.jpcr_enabled: + if not hasattr(main, "_jpcr_distill_start_step"): + main._jpcr_distill_start_step = step # type: ignore[attr-defined] + jpcr_steps_since = step - main._jpcr_distill_start_step # type: ignore[attr-defined] + jpcr_ramp = min(jpcr_steps_since / max(args.jpcr_warmup_steps, 1), 1.0) + jpcr_active_weight = args.jpcr_weight * jpcr_ramp + # Freeze/unfreeze blend gates: let predictor learn before gate opens + gate_frozen = jpcr_steps_since < 300 + else: + jpcr_active_weight = 0.0 + gate_frozen = False + dual_head_active_weight = ( + float(args.dual_head_weight) + if args.dual_head_enabled and step >= dual_head_start_step and args.dual_head_weight > 0.0 + else 0.0 + ) + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, curr_seq_len, grad_accum_steps) + # Always pass consistent types AND shapes to forward() to avoid torch.compile + # retracing when distillation activates. JPCR is only enabled once distill is on. + teacher_logits: Tensor = torch.empty(0, device=device) + if jpcr_runtime_active and args.jpcr_weight > 0.0: + _n_jpcr = (base_model.intra_loop_end - base_model.intra_loop_start + 1) + teacher_intermediates: list[Tensor] = [ + torch.zeros(x.size(0), curr_seq_len, args.model_dim, device=device, dtype=torch.bfloat16) + for _ in range(_n_jpcr) + ] + else: + teacher_intermediates = [] + token_weights: Tensor | None = None + aux_targets: Tensor | None = None + train_loss_mask = build_train_loss_mask(x.size(0), curr_seq_len) + if apply_distill_this_step and ema_teacher is not None: + # Use no_grad (not inference_mode) because inference tensors can error when + # downstream ops save them for backward (e.g., KL in distillation under compile). + # Wrap in autocast to match training dtype (bf16) — teacher weights are bf16. + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=autocast_enabled): + if jpcr_runtime_active and args.jpcr_weight > 0.0: + # Capture both logits and per-block intermediates for JPCR. + teacher_logits, teacher_intermediates = ema_teacher.forward_logits_and_intermediates( + x, jpcr_runtime_active=True + ) + teacher_logits = teacher_logits.detach() + teacher_intermediates = [h.detach() for h in teacher_intermediates] + else: + teacher_logits = ema_teacher.forward_logits(x).detach() + if args.byte_weighted_loss_enabled: + with torch.no_grad(): + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.float32) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.float32) + mean_bytes = token_bytes.mean().clamp_min(1e-6) + rel = token_bytes / mean_bytes + alpha = float(args.byte_weighted_loss_alpha) + rel = (1.0 - alpha) + alpha * rel + token_weights = rel.reshape_as(y) + if dual_head_active_weight > 0.0: + with torch.no_grad(): + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + is_boundary = is_boundary_token_lut[tgt_ids] + has_space = has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + is_long = base_bytes_lut[tgt_ids] >= 4 + cls = torch.zeros_like(tgt_ids, dtype=torch.long) + cls = torch.where(has_space, torch.ones_like(cls), cls) # class 1: leading-space continuation + cls = torch.where(is_long, torch.full_like(cls, 2), cls) # class 2: long piece (4+ bytes) + cls = torch.where(is_boundary, torch.full_like(cls, 3), cls) # class 3: boundary/special + aux_targets = cls.reshape_as(y) + if autocast_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model( + x, + y, + loss_mask=train_loss_mask, + per_token_weights=token_weights, + aux_targets=aux_targets, + aux_weight=dual_head_active_weight, + distill_teacher_logits=teacher_logits, + distill_weight=args.distill_weight if apply_distill_this_step else 0.0, + distill_temp=args.distill_temp, + logit_reg_weight=args.logit_reg_weight, + jpcr_teacher_intermediates=teacher_intermediates, + jpcr_weight=jpcr_active_weight, + jpcr_runtime_active=jpcr_runtime_active, + ) + else: + loss = model( + x, + y, + loss_mask=train_loss_mask, + per_token_weights=token_weights, + aux_targets=aux_targets, + aux_weight=dual_head_active_weight, + distill_teacher_logits=teacher_logits, + distill_weight=args.distill_weight if apply_distill_this_step else 0.0, + distill_temp=args.distill_temp, + logit_reg_weight=args.logit_reg_weight, + jpcr_teacher_intermediates=teacher_intermediates, + jpcr_weight=jpcr_active_weight, + jpcr_runtime_active=jpcr_runtime_active, + ) + train_loss += loss.detach() + (loss * grad_scale).backward() + if gate_frozen: + for p in base_model.jpcr_predictors: + if p.blend_gate.grad is not None: + p.blend_gate.grad = None + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + if ema_teacher is not None: + with torch.no_grad(): + decay = float(args.distill_ema_decay) + for p_t, p_s in zip(ema_teacher.parameters(), base_model.parameters(), strict=True): + p_t.mul_(decay).add_(p_s, alpha=1.0 - decay) + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + if device.type == "cuda": + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # a compressed quantized artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + raw_total_submission = model_bytes + code_bytes + raw_budget_delta = args.submission_size_budget_bytes - raw_total_submission + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {raw_total_submission} bytes") + if raw_budget_delta >= 0: + log0( + f"submission_budget raw_total:{raw_total_submission} budget:{args.submission_size_budget_bytes} " + f"headroom_bytes:{raw_budget_delta}" + ) + else: + log0( + f"submission_budget raw_total:{raw_total_submission} budget:{args.submission_size_budget_bytes} " + f"over_bytes:{-raw_budget_delta}" + ) + + resolved_compressor, compressor_note = resolve_compressor(args.compressor) + + export_state_dict = base_model.state_dict() + qat_export_levels = CastedLinear.qat_levels + if master_process and args.qat_scheme != "none" and qat_export_levels <= 0: + log0( + f"qat_warning: QAT_SCHEME={args.qat_scheme} was requested but fake-quant never enabled before export; " + f"step:{step} qat_start_step:{args.qat_start_step} qat_end_step:{args.qat_end_step} " + f"qat_start_wallclock_frac:{args.qat_start_wallclock_frac} " + f"qat_end_wallclock_frac:{args.qat_end_wallclock_frac} iterations:{args.iterations}" + ) + elif master_process and args.qat_scheme != "none": + log0(f"qat_export: active_levels:{qat_export_levels}") + + # LSQ export plumbing (if enabled): collect learned per-row scales and strip + # the log_scale parameters from the state_dict. + lsq_scales_export: dict[str, Tensor] | None = None + if args.qat_lsq: + lsq_scales_export = collect_lsq_scales(base_model) + export_state_dict = { + k: v for k, v in export_state_dict.items() if not k.endswith(".qat_log_scale") + } + if master_process: + log0(f"qat_lsq: collected {len(lsq_scales_export)} per-row scales for export") + + # GPTQ: Hessian-aware post-training quantization (replaces naive round-to-nearest). + gptq_results: dict[str, tuple[Tensor, Tensor]] | None = None + if args.gptq_enabled: + active_scheme = args.mixed_low_precision_scheme if args.quant_scheme == "mixed" else args.quant_scheme + gptq_bits = 4 if active_scheme == "int4" else (5 if active_scheme == "int5" else 8) + if master_process: + log0(f"gptq: collecting Hessians from {args.gptq_nsamples} calibration samples...") + CastedLinear.qat_levels = 0 # disable fake-quant for calibration + hessians = collect_gptq_hessians( + base_model, val_tokens, device, + seq_len=args.train_seq_len, + nsamples=args.gptq_nsamples, + ) + if master_process: + log0(f"gptq: collected {len(hessians)} Hessians, quantizing with bits={gptq_bits}...") + gptq_results = gptq_quantize_state_dict( + base_model, export_state_dict, hessians, + bits=gptq_bits, + percdamp=args.gptq_percdamp, + blocksize=args.gptq_blocksize, + group_size=INT4_GROUP_SIZE if gptq_bits == 4 else 0, + use_nf4=NF4_ENABLED if gptq_bits == 4 else False, + ) + if master_process: + log0(f"gptq: quantized {len(gptq_results)} weight matrices") + + quant_obj, quant_stats = quantize_state_dict( + export_state_dict, + scheme=args.quant_scheme, + weight_order=args.weight_order, + mixed_low_precision_scheme=args.mixed_low_precision_scheme, + precomputed_scales=lsq_scales_export, + gptq_results=gptq_results, + ) + artifact_name = export_artifact_name(args.quant_scheme, resolved_compressor) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, resolved_compressor, args.compress_level) + quant_raw_bytes = len(quant_raw) + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["payload_bytes"], 1) + if compressor_note: + log0(f"export_note:{compressor_note}") + log0( + f"export_config quant_scheme:{args.quant_scheme} mixed_low_precision_scheme:{args.mixed_low_precision_scheme} " + f"compressor:{resolved_compressor} weight_order:{args.weight_order} compress_level:{args.compress_level}" + ) + log0( + f"Serialized model {args.quant_scheme}+{resolved_compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + quant_total_submission = quant_file_bytes + code_bytes + quant_budget_delta = args.submission_size_budget_bytes - quant_total_submission + log0(f"Total submission size {args.quant_scheme}+{resolved_compressor}: {quant_total_submission} bytes") + if quant_budget_delta >= 0: + log0( + f"submission_budget {args.quant_scheme}+{resolved_compressor} total:{quant_total_submission} " + f"budget:{args.submission_size_budget_bytes} headroom_bytes:{quant_budget_delta}" + ) + else: + log0( + f"submission_budget {args.quant_scheme}+{resolved_compressor} total:{quant_total_submission} " + f"budget:{args.submission_size_budget_bytes} over_bytes:{-quant_budget_delta}" + ) + with open("final_export_manifest.json", "w", encoding="utf-8") as f: + json.dump( + { + "quant_scheme": args.quant_scheme, + "mixed_low_precision_scheme": args.mixed_low_precision_scheme, + "compressor_requested": args.compressor, + "compressor_resolved": resolved_compressor, + "compress_level": args.compress_level, + "weight_order": args.weight_order, + "artifact_name": artifact_name, + "artifact_bytes": quant_file_bytes, + "code_bytes": code_bytes, + "total_submission_bytes": quant_total_submission, + "submission_size_budget_bytes": args.submission_size_budget_bytes, + "budget_headroom_bytes": quant_budget_delta, + "baseline_tensor_bytes": quant_stats["baseline_tensor_bytes"], + "payload_bytes": quant_stats["payload_bytes"], + "raw_torch_bytes": quant_raw_bytes, + "payload_ratio": ratio, + "quant_format": quant_obj.get("__quant_format__", ""), + }, + f, + indent=2, + sort_keys=True, + ) + + if args.final_roundtrip_eval: + if distributed: + dist.barrier() + # Disable QAT fake-quant during roundtrip eval so loaded dequantized + # weights are not re-fake-quantized through stale LSQ scales. + CastedLinear.qat_levels = 0 + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(decompress_blob(quant_blob_disk, resolved_compressor)), + map_location="cpu", + weights_only=True, + ) + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=False) + if device.type == "cuda": + torch.cuda.synchronize() + t_qeval = time.perf_counter() + roundtrip_tag = f"final_{args.quant_scheme}_{resolved_compressor}_roundtrip" + q_val_loss, q_val_bpb = run_final_eval_suite( + args, + roundtrip_tag, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + sweep_specs, + blend_specs, + blend_weights, + log0, + ) + if device.type == "cuda": + torch.cuda.synchronize() + log0( + f"{roundtrip_tag} val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms mode:{args.final_eval_mode}" + ) + log0( + f"{roundtrip_tag}_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f} " + f"mode:{args.final_eval_mode}" + ) + else: + log0("final_roundtrip skipped FINAL_ROUNDTRIP_EVAL=0") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +device:cuda:0 distributed:True use_torch_compile:False torch_compile_dynamic:True +Thu Apr 30 17:08:00 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1185MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 8775 C /usr/bin/python3.12 1176MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe.model +train_loader:dataset:fineweb10B_sp8192 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/dual_bpe/datasets/fineweb10B_sp8192/fineweb_val_*.bin tokens:40541184 val_max_tokens:full +eval_primary: seq_len:1024 rope_scale:1.0000 stride_frac:0.5000 final_eval_mode:primary +eval_cont_cache: enabled:0 window:8192 topk:64 weight:0.1200 logit_scale:12.0000 conf_power:1.0000 batch_seqs:8 +train_loss_mask: enabled:0 stride_frac:0.5000 +ddp_find_unused_parameters:0 +model_params:18313072 +world_size:1 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=True math=True mode=auto +attention_mode:gqa num_heads:8 num_kv_heads:4 use_swiglu:True use_ssm:True ssm_every_n:4 ssm_impl:mamba3 ssm_expand:2.0 ssm_kernel:4 mamba3_d_state:128 mamba3_head_dim:64 mamba3_is_mimo:True mamba3_mimo_rank:4 mamba3_chunk_size:16 mamba3_outproj_norm:False mtp_enabled:False mtp_steps:2 mtp_weight:0.3 mtp_decay:1.0 mtp_tie_embeddings:True distill_enabled:False distill_start_frac:-1.0 distill_start_step:-1 distill_start_wallclock_frac:-1.0 distill_weight:0.1 distill_temp:1.5 distill_ema_decay:0.999 jpcr_apply_every:1 logit_reg_weight:0.0 byte_weighted_loss:False byte_weighted_loss_alpha:1.0 residual_ngram_enabled:False residual_bigram_rank:0 residual_trigram_rank:0 residual_ngram_lr:0.04 residual_ngram_mix_init:-2.5 ngram_softcap:0.0 ngram_entropy_gate:False ttt_enabled:False ttt_lr:0.001 ttt_steps:1 ttt_momentum:0.9 copy_cache_enabled:False copy_cache_window:256 copy_cache_dim:64 copy_cache_lr:0.02 copy_cache_gate_init:-4.0 dual_head_enabled:False dual_head_weight:0.05 dual_head_start_frac:0.0 dual_head_lr:0.02 qat_scheme:none qat_start_step:9000 qat_end_step:0 qat_start_wallclock_frac:-1.0 qat_end_wallclock_frac:1.0 moe_num_experts:0 moe_every_n:2 moe_capacity_factor:1.0 moe_aux_loss_coeff:0.001 num_moe_blocks:0 +architecture:stacked num_layers:9 encoder_layers:4 decoder_layers:5 ssm_blocks:2 attn_blocks:7 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 mtp_lr:0.0 copy_cache_lr:0.0 dual_head_lr:0.0 +train_batch_tokens:65536 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:1800.000 +seed:1337 +Initializing DistributedTokenLoader... +Saving initial model and optimizer states for warmup... +Starting warmup loop (20 steps). The first step may compile TileLang/custom kernels kernels... +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:9.0106 train_time:161ms step_avg:161.09ms +step:2/20000 train_loss:8.7251 train_time:429ms step_avg:214.51ms +step:3/20000 train_loss:8.0297 train_time:691ms step_avg:230.40ms +step:4/20000 train_loss:8.7973 train_time:928ms step_avg:231.88ms +step:5/20000 train_loss:9.6566 train_time:1172ms step_avg:234.49ms +step:6/20000 train_loss:9.1696 train_time:1409ms step_avg:234.87ms +step:7/20000 train_loss:8.5339 train_time:1677ms step_avg:239.52ms +step:8/20000 train_loss:8.2636 train_time:1931ms step_avg:241.40ms +step:9/20000 train_loss:7.9307 train_time:2176ms step_avg:241.78ms +step:10/20000 train_loss:7.6718 train_time:2433ms step_avg:243.26ms +step:200/20000 train_loss:4.8204 train_time:29107ms step_avg:145.53ms +step:400/20000 train_loss:4.2743 train_time:57228ms step_avg:143.07ms +step:600/20000 train_loss:4.0320 train_time:85353ms step_avg:142.25ms +step:800/20000 train_loss:4.0121 train_time:113380ms step_avg:141.72ms +step:1000/20000 train_loss:4.0565 train_time:141437ms step_avg:141.44ms +step:1200/20000 train_loss:4.0025 train_time:169519ms step_avg:141.27ms +step:1400/20000 train_loss:3.8015 train_time:197585ms step_avg:141.13ms +step:1600/20000 train_loss:3.7708 train_time:225889ms step_avg:141.18ms +step:1800/20000 train_loss:3.7069 train_time:253926ms step_avg:141.07ms +step:2000/20000 train_loss:3.7772 train_time:282003ms step_avg:141.00ms +step:2200/20000 train_loss:3.7300 train_time:310037ms step_avg:140.93ms +step:2400/20000 train_loss:3.6245 train_time:338101ms step_avg:140.88ms +step:2600/20000 train_loss:3.7231 train_time:366169ms step_avg:140.83ms +step:2800/20000 train_loss:3.8833 train_time:394270ms step_avg:140.81ms +step:3000/20000 train_loss:3.2707 train_time:422332ms step_avg:140.78ms +step:3200/20000 train_loss:3.6679 train_time:450690ms step_avg:140.84ms +step:3400/20000 train_loss:3.6953 train_time:478757ms step_avg:140.81ms +step:3600/20000 train_loss:3.5502 train_time:506815ms step_avg:140.78ms +step:3800/20000 train_loss:3.5074 train_time:534855ms step_avg:140.75ms +step:4000/20000 train_loss:3.4725 train_time:562915ms step_avg:140.73ms +step:4200/20000 train_loss:3.3046 train_time:590974ms step_avg:140.71ms +step:4400/20000 train_loss:3.5560 train_time:619068ms step_avg:140.70ms +step:4600/20000 train_loss:3.4089 train_time:659431ms step_avg:143.35ms +step:4800/20000 train_loss:3.4861 train_time:687539ms step_avg:143.24ms +step:5000/20000 train_loss:3.4997 train_time:715654ms step_avg:143.13ms +step:5200/20000 train_loss:3.5006 train_time:743723ms step_avg:143.02ms +step:5400/20000 train_loss:3.5224 train_time:771734ms step_avg:142.91ms +step:5600/20000 train_loss:3.4293 train_time:799807ms step_avg:142.82ms +step:5800/20000 train_loss:3.5951 train_time:827902ms step_avg:142.74ms +step:6000/20000 train_loss:3.4159 train_time:855941ms step_avg:142.66ms +step:6200/20000 train_loss:3.4465 train_time:886671ms step_avg:143.01ms +step:6400/20000 train_loss:3.5107 train_time:914736ms step_avg:142.93ms +step:6600/20000 train_loss:3.5548 train_time:942789ms step_avg:142.85ms +step:6800/20000 train_loss:3.4912 train_time:970834ms step_avg:142.77ms +step:7000/20000 train_loss:3.4542 train_time:998875ms step_avg:142.70ms +step:7200/20000 train_loss:3.3893 train_time:1026899ms step_avg:142.62ms +step:7400/20000 train_loss:3.5671 train_time:1054951ms step_avg:142.56ms +step:7600/20000 train_loss:3.4850 train_time:1082994ms step_avg:142.50ms +step:7800/20000 train_loss:3.4146 train_time:1121404ms step_avg:143.77ms +step:8000/20000 train_loss:3.4183 train_time:1149479ms step_avg:143.68ms +step:8200/20000 train_loss:3.4433 train_time:1177509ms step_avg:143.60ms +step:8400/20000 train_loss:3.4722 train_time:1205562ms step_avg:143.52ms +step:8600/20000 train_loss:3.5405 train_time:1233610ms step_avg:143.44ms +step:8800/20000 train_loss:3.4164 train_time:1261663ms step_avg:143.37ms +step:9000/20000 train_loss:3.3632 train_time:1289705ms step_avg:143.30ms +step:9200/20000 train_loss:3.4606 train_time:1323722ms step_avg:143.88ms +step:9400/20000 train_loss:3.4833 train_time:1351766ms step_avg:143.80ms +step:9600/20000 train_loss:3.4175 train_time:1380908ms step_avg:143.84ms +step:9800/20000 train_loss:3.5235 train_time:1410960ms step_avg:143.98ms +step:10000/20000 train_loss:3.1892 train_time:1440958ms step_avg:144.10ms +step:10200/20000 train_loss:3.4338 train_time:1470958ms step_avg:144.21ms +step:10400/20000 train_loss:3.4435 train_time:1500858ms step_avg:144.31ms +step:10600/20000 train_loss:3.3058 train_time:1530861ms step_avg:144.42ms +step:10800/20000 train_loss:3.6694 train_time:1571658ms step_avg:145.52ms +step:11000/20000 train_loss:3.2885 train_time:1601505ms step_avg:145.59ms +step:11200/20000 train_loss:3.3120 train_time:1631259ms step_avg:145.65ms +step:11400/20000 train_loss:3.3545 train_time:1661059ms step_avg:145.71ms +step:11600/20000 train_loss:3.2881 train_time:1691158ms step_avg:145.79ms +step:11800/20000 train_loss:3.3444 train_time:1720756ms step_avg:145.83ms +step:12000/20000 train_loss:3.3739 train_time:1750472ms step_avg:145.87ms +step:12200/20000 train_loss:3.4419 train_time:1780435ms step_avg:145.94ms +step:12278/20000 val_loss:3.2398 val_bpb:1.2542 train_time:1800080ms step_avg:146.61ms +stopping_early: wallclock_cap train_time:1800080ms step:12278/20000 +swa: loading averaged weights from 275 snapshots +peak memory allocated: 18470 MiB reserved: 20162 MiB +Serialized model: 65953401 bytes +Code size: 231880 bytes +Total submission size: 66185281 bytes +submission_budget raw_total:66185281 budget:16000000 over_bytes:50185281 +gptq: collecting Hessians from 128 calibration samples... +gptq: collected 56 Hessians, quantizing with bits=8... +gptq: quantized 56 weight matrices +export_config quant_scheme:int8 mixed_low_precision_scheme:int8 compressor:zstd weight_order:none compress_level:-1 +Serialized model int8+zstd: 15847612 bytes (payload:17329844 raw_torch:17386102 payload_ratio:3.46x) +Total submission size int8+zstd: 15860231 bytes +submission_budget int8+zstd total:15860231 budget:16000000 under_bytes:139769 +final_int8_zstd_roundtrip_ctx_exact name:primary seq_len:1024 rope_scale:1.0000 stride_frac:0.5000 ttt:0 ttt_params:0 ttt_lr:0.001 ttt_steps:1 val_loss:3.25624330 val_bpb:1.26060944 +final_int8_zstd_roundtrip val_loss:3.2562 val_bpb:1.2606 eval_time:46872ms mode:primary +final_int8_zstd_roundtrip_exact val_loss:3.25624330 val_bpb:1.26060944 mode:primary diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/train_gpt.py b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/train_gpt.py new file mode 100644 index 0000000000..01dec28cf9 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_BPE_Mamba3_d448_ssm4_1xH100/train_gpt.py @@ -0,0 +1,4755 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import importlib +import io +import json +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +_MAMBA3_IMPORT_ERROR: Exception | None = None +try: + from mamba_ssm.modules.mamba3 import Mamba3 as _OfficialMamba3 +except Exception as exc: # pragma: no cover - depends on CUDA extension install + _MAMBA3_IMPORT_ERROR = exc + _OfficialMamba3 = None +# Increase dynamo cache limit to avoid recompilation fallback when training conditions change +# (e.g., distillation activation, rotary cache identity changes). Default is 8, which is too low. +torch._dynamo.config.cache_size_limit = 64 +# Workaround for torch 2.10.0 inductor bug in joint_graph `mul_softmax_pattern` that crashes +# with "Tried to erase Node mul_N but it still had 1 users" during mid-training recompiles. +# The keep-alive fallback (suppress_errors) kicks the *entire* forward into eager, which is +# catastrophic for step time — so we defuse the broken pattern at its source instead. +# +# Strategy: +# (1) Monkey-patch `mul_softmax_pattern` in the joint_graph module and in every PatternEntry +# handler slot that references it. Replace with a no-op that never rewrites the graph. +# (2) Keep suppress_errors=True only as a last-resort safety net, so if a different pattern +# fails during a mid-training recompile the specific subgraph falls back to eager instead +# of killing the whole run. +torch._dynamo.config.suppress_errors = True +def _pg_noop_mul_softmax_pattern(match, *args, **kwargs): # noqa: ANN001 + # No rewrite: leave the matched subgraph alone. Inductor will still lower it correctly + # through the generic softmax/mul path — we just give up this one fusion opportunity. + return +try: + from torch._inductor.fx_passes import joint_graph as _pg_joint_graph + # (a) Replace the module-level function so future imports resolve to the no-op. + if hasattr(_pg_joint_graph, "mul_softmax_pattern"): + _pg_joint_graph.mul_softmax_pattern = _pg_noop_mul_softmax_pattern + # (b) Walk the registered PatternMatcherPass and swap any PatternEntry whose handler is the + # buggy function. In torch 2.10, `patterns.patterns` is a defaultdict[key, list[entry]]. + _pg_patterns = getattr(_pg_joint_graph, "patterns", None) + if _pg_patterns is not None: + _pg_inner = getattr(_pg_patterns, "patterns", None) + if _pg_inner is not None: + # Handle both dict-of-list and plain-list shapes. + if isinstance(_pg_inner, dict): + _pg_iter = [_e for _lst in _pg_inner.values() for _e in _lst] + else: + _pg_iter = list(_pg_inner) + for _entry in _pg_iter: + _h = getattr(_entry, "handler", None) + if _h is None: + continue + _qn = getattr(_h, "__qualname__", "") or getattr(_h, "__name__", "") + if "mul_softmax_pattern" in _qn: + try: + _entry.handler = _pg_noop_mul_softmax_pattern + except Exception: + pass +except Exception: + # If torch's internal layout has shifted, fall through to the suppress_errors safety net. + pass +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/dual_bpe/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + # Optional cap for fast local smoke runs; 0 means full validation split. + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 200)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + use_swiglu = bool(int(os.environ.get("USE_SWIGLU", "1"))) + # Sliding window eval: only score tokens beyond prefix_len in each window. + # eval_stride_frac=0.5 means stride=seq_len//2 → each scored token has ≥seq_len//2 tokens of context. + # eval_stride_frac=1.0 (default) = original non-overlapping behaviour. + eval_stride_frac = float(os.environ.get("EVAL_STRIDE_FRAC", "0.5")) + # Long-context eval: evaluate at a longer sequence length than training. + # 0 = same as train_seq_len. Pair with NTK RoPE scaling (eval_rope_scale>1) for best results. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0")) + # NTK-aware RoPE scaling at eval: new_base = rope_base * eval_rope_scale^(head_dim/(head_dim-2)). + # Suggested: eval_rope_scale = (eval_seq_len / train_seq_len) ** 2 (≈4 for 2× context) + eval_rope_scale = float(os.environ.get("EVAL_ROPE_SCALE", "1.0")) + # Optional extra eval contexts to sweep at the end of a run. These do not affect the + # in-training validation path unless promoted to the primary eval context via EVAL_SEQ_LEN. + eval_sweep_seq_lens = os.environ.get("EVAL_SWEEP_SEQ_LENS", "").strip() + eval_sweep_rope_scales = os.environ.get("EVAL_SWEEP_ROPE_SCALES", "").strip() + # Multi-context eval blend: evaluate multiple contexts on the same scored token blocks and + # blend their token probabilities. Set FINAL_EVAL_MODE=blend to make this the official score. + eval_blend_seq_lens = os.environ.get("EVAL_BLEND_SEQ_LENS", "").strip() + eval_blend_rope_scales = os.environ.get("EVAL_BLEND_ROPE_SCALES", "").strip() + eval_blend_weights = os.environ.get("EVAL_BLEND_WEIGHTS", "").strip() + # 0 = inherit EVAL_STRIDE_FRAC. Otherwise, use this stride fraction for the common scored span. + eval_blend_stride_frac = float(os.environ.get("EVAL_BLEND_STRIDE_FRAC", "0.0")) + # Optional position-dependent blend ramp. Positive bias shifts weight from shorter contexts + # early in the scored span toward longer contexts later in the scored span. + eval_blend_position_bias = float(os.environ.get("EVAL_BLEND_POSITION_BIAS", "0.0")) + eval_blend_position_power = float(os.environ.get("EVAL_BLEND_POSITION_POWER", "1.0")) + # Eval-only continuous cache: mixes the base LM with a retrieval distribution over recent + # validation-history hidden states. This is eval-only and does not change the artifact. + eval_cont_cache_enabled = bool(int(os.environ.get("EVAL_CONT_CACHE_ENABLED", "0"))) + eval_cont_cache_window = int(os.environ.get("EVAL_CONT_CACHE_WINDOW", "8192")) + eval_cont_cache_topk = int(os.environ.get("EVAL_CONT_CACHE_TOPK", "64")) + eval_cont_cache_weight = float(os.environ.get("EVAL_CONT_CACHE_WEIGHT", "0.12")) + eval_cont_cache_logit_scale = float(os.environ.get("EVAL_CONT_CACHE_LOGIT_SCALE", "12.0")) + eval_cont_cache_conf_power = float(os.environ.get("EVAL_CONT_CACHE_CONF_POWER", "1.0")) + eval_cont_cache_batch_seqs = int(os.environ.get("EVAL_CONT_CACHE_BATCH_SEQS", "8")) + # primary | blend + final_eval_mode = os.environ.get("FINAL_EVAL_MODE", "primary").strip().lower() + # Low-rank bigram logit bias: learnable rank-r factored bigram table. + # bigram_bias[i] = bigram_right(bigram_left(prev_token[i])) added to logits before softcap. + # 0 = disabled. 32 costs ~64K int8 params (≈32 KB), well within the 164 KB headroom. + bigram_rank = int(os.environ.get("BIGRAM_RANK", "32")) + bigram_lr = float(os.environ.get("BIGRAM_LR", "0.04")) + # Residual n-gram modeling: mix neural logits with a lightweight n-gram baseline. + # total_prob = (1-gate)*P_neural + gate*P_ngram, where gate is learned per token. + # This lets the transformer focus more capacity on hard residual structure. + residual_ngram_enabled = bool(int(os.environ.get("RESIDUAL_NGRAM_ENABLED", "0"))) + residual_bigram_rank = int(os.environ.get("RESIDUAL_BIGRAM_RANK", "0")) + residual_trigram_rank = int(os.environ.get("RESIDUAL_TRIGRAM_RANK", "0")) + residual_ngram_lr = float(os.environ.get("RESIDUAL_NGRAM_LR", "0.04")) + residual_ngram_mix_init = float(os.environ.get("RESIDUAL_NGRAM_MIX_INIT", "-2.5")) + # Pointer-style local copy/cache head. + # P(next) = (1-gate) * P_model + gate * P_copy, where P_copy attends to recent context + # positions and copies their next-token targets into vocab space. + copy_cache_enabled = bool(int(os.environ.get("COPY_CACHE_ENABLED", "0"))) + copy_cache_window = int(os.environ.get("COPY_CACHE_WINDOW", "256")) + copy_cache_dim = int(os.environ.get("COPY_CACHE_DIM", "64")) + copy_cache_lr = float(os.environ.get("COPY_CACHE_LR", "0.02")) + copy_cache_gate_init = float(os.environ.get("COPY_CACHE_GATE_INIT", "-4.0")) + # Stochastic Weight Averaging: average weights during the warmdown phase. + # Takes the mean of snapshots every SWA_COLLECT_EVERY steps once LR starts decaying. + # Research-confirmed ~0.5-1.5% BPB improvement, especially helps quantization quality. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_collect_every = int(os.environ.get("SWA_COLLECT_EVERY", "10")) + # Optional train-side loss mask aligned to sliding-window eval. When enabled, only the + # suffix of each training chunk contributes loss, matching the eval metric more closely. + train_loss_mask_enabled = bool(int(os.environ.get("TRAIN_LOSS_MASK_ENABLED", "0"))) + # 0 = inherit EVAL_STRIDE_FRAC. + train_loss_mask_stride_frac = float(os.environ.get("TRAIN_LOSS_MASK_STRIDE_FRAC", "0.0")) + # Sequence length curriculum: ramp seq_len from curriculum_min_seq_len → train_seq_len + # over the first curriculum_steps training steps. Faster early convergence on local patterns. + curriculum_enabled = bool(int(os.environ.get("CURRICULUM_ENABLED", "0"))) + curriculum_min_seq_len = int(os.environ.get("CURRICULUM_MIN_SEQ_LEN", "256")) + curriculum_steps = int(os.environ.get("CURRICULUM_STEPS", "5000")) + # Multi-token prediction (MTP): auxiliary future-token losses used during training. + mtp_enabled = bool(int(os.environ.get("MTP_ENABLED", "0"))) + mtp_steps = int(os.environ.get("MTP_STEPS", "2")) + mtp_weight = float(os.environ.get("MTP_WEIGHT", "0.3")) + mtp_decay = float(os.environ.get("MTP_DECAY", "1.0")) + mtp_tie_embeddings = bool(int(os.environ.get("MTP_TIE_EMBEDDINGS", "1"))) + mtp_lr = float(os.environ.get("MTP_LR", "0.02")) + # On-the-fly distillation (EMA teacher) in the late training tail. + distill_enabled = bool(int(os.environ.get("DISTILL_ENABLED", "0"))) + distill_start_frac = float(os.environ.get("DISTILL_START_FRAC", "0.7")) + # Optional overrides for wallclock-capped runs. DISTILL_START_STEP wins over frac. + # DISTILL_START_WALLCLOCK_FRAC keys distillation off elapsed/max_wallclock instead of ITERATIONS. + distill_start_step = int(os.environ.get("DISTILL_START_STEP", "-1")) + distill_start_wallclock_frac = float(os.environ.get("DISTILL_START_WALLCLOCK_FRAC", "-1.0")) + distill_weight = float(os.environ.get("DISTILL_WEIGHT", "0.08")) + distill_temp = float(os.environ.get("DISTILL_TEMP", "2.0")) + distill_ema_decay = float(os.environ.get("DISTILL_EMA_DECAY", "0.999")) + # JPCR: JEPA Predictive Coding Recurrence. Replaces Ouroboros controllers with + # representation predictors trained via JEPA loss (MSE) against EMA teacher intermediates. + # Each predictor learns to predict the "ideal" hidden state at this depth, then blends + # that prediction into the recurrence input — transforming blind repetition into + # JEPA-guided iterative refinement. Progressive depth targeting: pass s of block i + # targets teacher's block (i+s) output, teaching the recurrence to "look ahead". + # At inference, predictors run as part of the model (no teacher needed). + jpcr_enabled = bool(int(os.environ.get("JPCR_ENABLED", "0"))) + jpcr_hidden = int(os.environ.get("JPCR_HIDDEN", "128")) # predictor MLP hidden dim + jpcr_proj_dim = int(os.environ.get("JPCR_PROJ_DIM", str(jpcr_hidden))) + jpcr_weight = float(os.environ.get("JPCR_WEIGHT", "0.1")) # JEPA MSE loss weight + jpcr_blend_init = float(os.environ.get("JPCR_BLEND_INIT", "-2.0")) # logit for sigmoid gate init (~0.12) + jpcr_lr = float(os.environ.get("JPCR_LR", "0.02")) # predictor learning rate + jpcr_warmup_steps = int(os.environ.get("JPCR_WARMUP_STEPS", "200")) # ramp JPCR loss weight over this many steps after activation + # Distillation/JPCR application cadence. 1 = apply every step. + # When >1, distill+JPCR are applied every Nth step (no stale-target reuse). + _jpcr_apply_every_env = os.environ.get("JPCR_APPLY_EVERY", os.environ.get("JPCR_TEACHER_EVERY", "1")) + jpcr_apply_every = max(1, int(_jpcr_apply_every_env)) + # Dual-head objective: auxiliary coarse-structure prediction head. + # Classes are derived from token properties (boundary/space/byte-length) and trained + # with a small coefficient so the main LM head can focus on harder entropy. + dual_head_enabled = bool(int(os.environ.get("DUAL_HEAD_ENABLED", "0"))) + dual_head_weight = float(os.environ.get("DUAL_HEAD_WEIGHT", "0.05")) + dual_head_start_frac = float(os.environ.get("DUAL_HEAD_START_FRAC", "0.0")) + dual_head_lr = float(os.environ.get("DUAL_HEAD_LR", "0.02")) + # Logit range regularization on pre-softcap logits for quantization robustness. + logit_reg_weight = float(os.environ.get("LOGIT_REG_WEIGHT", "0.0")) + # Sandwich norm: apply post-sublayer RMSNorm (before residual add) for each block. + # Controls residual stream norm growth; used by Gemma 2. + use_sandwich_norm = bool(int(os.environ.get("USE_SANDWICH_NORM", "0"))) + # Embedding scale: multiply token embeddings by sqrt(model_dim) after lookup. + # Aligns embedding magnitude with residual stream scale. Used by Gemma, T5, PaLM. + embed_scale = bool(int(os.environ.get("EMBED_SCALE", "0"))) + # Byte-weighted training loss (align objective closer to tokenizer-agnostic BPB). + byte_weighted_loss_enabled = bool(int(os.environ.get("BYTE_WEIGHTED_LOSS_ENABLED", "0"))) + byte_weighted_loss_alpha = float(os.environ.get("BYTE_WEIGHTED_LOSS_ALPHA", "1.0")) + # Hybrid SSM blocks: periodically replace attention blocks with a mixer. + # In this experiment file the default is official CUDA-backed Mamba-3. + use_ssm = bool(int(os.environ.get("USE_SSM", "0"))) + ssm_every_n = int(os.environ.get("SSM_EVERY_N", "2")) + ssm_expand = float(os.environ.get("SSM_EXPAND", "2.0")) + ssm_kernel = int(os.environ.get("SSM_KERNEL", "4")) + ssm_impl = os.environ.get("SSM_IMPL", "mamba3").strip().lower() + mamba3_d_state = int(os.environ.get("MAMBA3_D_STATE", "128")) + # 0 = auto-pick a divisor of MODEL_DIM near 64. + mamba3_head_dim = int(os.environ.get("MAMBA3_HEAD_DIM", "0")) + mamba3_is_mimo = bool(int(os.environ.get("MAMBA3_IS_MIMO", "1"))) + mamba3_mimo_rank = int(os.environ.get("MAMBA3_MIMO_RANK", "4")) + mamba3_chunk_size = int(os.environ.get("MAMBA3_CHUNK_SIZE", "16")) + mamba3_outproj_norm = bool(int(os.environ.get("MAMBA3_OUTPROJ_NORM", "0"))) + # Quantization-Aware Training: fake-quantise weights during forward to teach the model + # to tolerate quantisation noise, dramatically reducing the roundtrip BPB penalty. + # QAT_SCHEME: "none" | "int8" | "int5" | "int4" (should match QUANT_SCHEME at export) + # QAT_START_STEP/QAT_END_STEP: step-based QAT schedule. + # QAT_START_WALLCLOCK_FRAC/QAT_END_WALLCLOCK_FRAC: optional wallclock-based + # schedule for capped runs; when start frac is >= 0 and max wallclock is set, + # it wins over the step schedule. + qat_scheme = os.environ.get("QAT_SCHEME", "none").strip().lower() + qat_start_step = int(os.environ.get("QAT_START_STEP", "9000")) + qat_end_step = int(os.environ.get("QAT_END_STEP", "0")) + qat_start_wallclock_frac = float(os.environ.get("QAT_START_WALLCLOCK_FRAC", "-1.0")) + qat_end_wallclock_frac = float(os.environ.get("QAT_END_WALLCLOCK_FRAC", "1.0")) + # QAT_LSQ=1 enables Learned Step-Size Quantization: per-row learnable log-scale + # replaces the max-abs scale in fake-quant, reducing int4 roundtrip penalty by + # letting the model optimise the clip threshold per output row via backprop (STE). + qat_lsq = bool(int(os.environ.get("QAT_LSQ", "0"))) + + # GPTQ post-training quantization (replaces naive round-to-nearest at export). + gptq_enabled = bool(int(os.environ.get("GPTQ", "1"))) + gptq_nsamples = int(os.environ.get("GPTQ_NSAMPLES", "128")) + gptq_blocksize = int(os.environ.get("GPTQ_BLOCKSIZE", "128")) + gptq_percdamp = float(os.environ.get("GPTQ_PERCDAMP", "0.01")) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + recurrent_core_layers = int(os.environ.get("RECURRENT_CORE_LAYERS", 0)) + recurrent_steps = int(os.environ.get("RECURRENT_STEPS", 0)) + share_ffn_across_blocks = bool(int(os.environ.get("SHARE_FFN_ACROSS_BLOCKS", "0"))) + # Intra-layer recurrence: run layers [intra_loop_start..intra_loop_end] intra_loop_steps times. + # All blocks remain unique (no weight sharing), so parameter count is unchanged. + # Research (arXiv:2505.01855) shows front-loading repetitions on early layers maximises BPB gain. + # Example: INTRA_LOOP_START=0 INTRA_LOOP_END=2 INTRA_LOOP_STEPS=3 on a 9L model gives + # effective depth 9 + 2*3 = 15 with zero extra parameters. + intra_loop_start = int(os.environ.get("INTRA_LOOP_START", "3")) # -1 = disabled + intra_loop_end = int(os.environ.get("INTRA_LOOP_END", "4")) + intra_loop_steps = int(os.environ.get("INTRA_LOOP_STEPS", "2")) + # Parallel residuals: attn and MLP read same pre-norm input, outputs summed. + # One norm per block instead of two; improved gradient flow. Leaderboard PR #1477. + use_parallel_residual = bool(int(os.environ.get("PARALLEL_RESIDUAL", "0"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + # Mixture of Experts (MoE): replace dense MLPs with sparse expert routing. + # MOE_NUM_EXPERTS=0 → disabled (dense MLP as usual) + # MOE_NUM_EXPERTS=2 → 2 experts per MoE layer, Expert Choice routing + # MOE_EVERY_N=1 → all layers are MoE; =2 → alternating (even layers); =3 → every 3rd + # MOE_CAPACITY_FACTOR: each expert sees int(cf * S / E) tokens (1.0 = perfect balance) + # MOE_AUX_LOSS_COEFF: weight on router Z-loss (stabilises routing, prevents collapse) + moe_num_experts = int(os.environ.get("MOE_NUM_EXPERTS", "0")) + moe_every_n = int(os.environ.get("MOE_EVERY_N", "2")) + moe_capacity_factor = float(os.environ.get("MOE_CAPACITY_FACTOR", "1.0")) + moe_aux_loss_coeff = float(os.environ.get("MOE_AUX_LOSS_COEFF", "1e-3")) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + # Decoupled softcap for the ngram residual branch (0 = inherit LOGIT_SOFTCAP). + # Letting the ngram branch push harder than the neural head often helps when the + # residual ngram is well-trained (small but sharp tables). + ngram_softcap = float(os.environ.get("NGRAM_SOFTCAP", "0.0")) + # Entropy-conditioned ngram gate: gate also sees a confidence signal (lse - max logit, + # a cheap proxy for -log max_prob of the neural head) so ngram can dominate when the + # neural model is unsure. Adds one scalar input per gate. + ngram_entropy_gate = bool(int(os.environ.get("NGRAM_ENTROPY_GATE", "0"))) + # Test-time training (competition-compliant): after scoring each eval batch, take one + # SGD step on the scored positions' CE loss. Only ngram/gate/scale params update; the + # base transformer is frozen. Params are snapshotted before eval and restored after, + # so intermediate val checkpoints are unaffected. Only activated in the final eval + # suite. Default off so existing runs are bit-identical. + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", "1e-3")) + ttt_steps = int(os.environ.get("TTT_STEPS", "1")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + # Export / compression controls. + quant_scheme = os.environ.get("QUANT_SCHEME", "int8").strip().lower() + compressor = os.environ.get("COMPRESSOR", "zlib").strip().lower() + compress_level = int(os.environ.get("COMPRESS_LEVEL", "-1")) + weight_order = os.environ.get("WEIGHT_ORDER", "none").strip().lower() + mixed_low_precision_scheme = os.environ.get("MIXED_LOW_PRECISION_SCHEME", "int8").strip().lower() + # If 0, skip the post-quantization roundtrip eval pass (saves one full val sweep). + final_roundtrip_eval = bool( + int(os.environ.get("FINAL_ROUNDTRIP_EVAL", os.environ.get("FINAL_INT8_ROUNDTRIP_EVAL", "1"))) + ) + final_int8_roundtrip_eval = final_roundtrip_eval + submission_size_budget_bytes = int(os.environ.get("SUBMISSION_SIZE_BUDGET_BYTES", str(16 * 1024 * 1024))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.to(dtype=torch.bfloat16 if G.is_cuda else torch.float32) + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_dtype = torch.bfloat16 if params[0].device.type == "cuda" else torch.float32 + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=updates_dtype) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + # MuonEq-R: row equilibration before Newton-Schulz + # (removes marginal row-scale mismatch, arxiv 2603.28254) + if g.ndim == 2: + g = g / g.norm(dim=1, keepdim=True).clamp(min=1e-8) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def parse_csv_ints(raw: str) -> list[int]: + values: list[int] = [] + for part in raw.split(","): + item = part.strip() + if item: + values.append(int(item)) + return values + + +def parse_csv_floats(raw: str) -> list[float]: + values: list[float] = [] + for part in raw.split(","): + item = part.strip() + if item: + values.append(float(item)) + return values + + +def default_eval_rope_scale(seq_len: int, train_seq_len: int) -> float: + if seq_len == train_seq_len: + return 1.0 + return float(seq_len / train_seq_len) ** 2 + + +def resolve_seq_len(raw_seq_len: int, train_seq_len: int) -> int: + return train_seq_len if raw_seq_len <= 0 else raw_seq_len + + +def resolve_stride(seq_len: int, stride_frac: float) -> int: + frac = stride_frac if stride_frac > 0.0 else 1.0 + return max(1, min(seq_len, int(seq_len * frac))) + + +def build_loss_mask_cpu(seq_len: int, stride_frac: float) -> tuple[Tensor, int, int]: + stride = resolve_stride(seq_len, stride_frac) + prefix_len = seq_len - stride + loss_mask_cpu = torch.zeros(seq_len, dtype=torch.float32) + loss_mask_cpu[prefix_len:] = 1.0 + return loss_mask_cpu, prefix_len, stride + + +def format_float_tag(value: float) -> str: + text = f"{value:.4f}".rstrip("0").rstrip(".") + return text.replace("-", "m").replace(".", "p") if text else "0" + + +def make_eval_spec_name(seq_len: int, rope_scale: float) -> str: + return f"seq{seq_len}_rope{format_float_tag(rope_scale)}" + + +def resolve_primary_eval_spec(args: Hyperparameters) -> tuple[str, int, float]: + seq_len = resolve_seq_len(args.eval_seq_len, args.train_seq_len) + rope_scale = float(args.eval_rope_scale) + return "primary", seq_len, rope_scale + + +def resolve_eval_sweep_specs(args: Hyperparameters) -> list[tuple[str, int, float]]: + specs: list[tuple[str, int, float]] = [] + seen: set[tuple[int, int]] = set() + + def add_spec(name: str, seq_len: int, rope_scale: float) -> None: + key = (seq_len, int(round(rope_scale * 1_000_000))) + if key in seen: + return + seen.add(key) + specs.append((name, seq_len, rope_scale)) + + primary_name, primary_seq_len, primary_rope_scale = resolve_primary_eval_spec(args) + add_spec(primary_name, primary_seq_len, primary_rope_scale) + + sweep_seq_lens = parse_csv_ints(args.eval_sweep_seq_lens) + sweep_rope_scales = parse_csv_floats(args.eval_sweep_rope_scales) + if sweep_rope_scales and len(sweep_rope_scales) != len(sweep_seq_lens): + raise ValueError( + "EVAL_SWEEP_ROPE_SCALES must have the same number of entries as EVAL_SWEEP_SEQ_LENS" + ) + for idx, raw_seq_len in enumerate(sweep_seq_lens): + seq_len = resolve_seq_len(raw_seq_len, args.train_seq_len) + rope_scale = ( + sweep_rope_scales[idx] + if sweep_rope_scales + else default_eval_rope_scale(seq_len, args.train_seq_len) + ) + add_spec(make_eval_spec_name(seq_len, rope_scale), seq_len, float(rope_scale)) + return specs + + +def resolve_eval_blend_specs(args: Hyperparameters) -> tuple[list[tuple[str, int, float]], list[float]]: + blend_seq_lens = parse_csv_ints(args.eval_blend_seq_lens) + if not blend_seq_lens: + return [], [] + blend_rope_scales = parse_csv_floats(args.eval_blend_rope_scales) + if blend_rope_scales and len(blend_rope_scales) != len(blend_seq_lens): + raise ValueError( + "EVAL_BLEND_ROPE_SCALES must have the same number of entries as EVAL_BLEND_SEQ_LENS" + ) + blend_weights = parse_csv_floats(args.eval_blend_weights) + if blend_weights and len(blend_weights) != len(blend_seq_lens): + raise ValueError( + "EVAL_BLEND_WEIGHTS must have the same number of entries as EVAL_BLEND_SEQ_LENS" + ) + + specs: list[tuple[str, int, float]] = [] + for idx, raw_seq_len in enumerate(blend_seq_lens): + seq_len = resolve_seq_len(raw_seq_len, args.train_seq_len) + rope_scale = ( + blend_rope_scales[idx] + if blend_rope_scales + else default_eval_rope_scale(seq_len, args.train_seq_len) + ) + specs.append((make_eval_spec_name(seq_len, float(rope_scale)), seq_len, float(rope_scale))) + + if not blend_weights: + blend_weights = [1.0] * len(specs) + total_weight = sum(blend_weights) + if total_weight <= 0.0: + raise ValueError("EVAL_BLEND_WEIGHTS must sum to a positive value") + normalized = [w / total_weight for w in blend_weights] + return specs, normalized + + +def resolve_max_eval_seq_len( + args: Hyperparameters, + sweep_specs: list[tuple[str, int, float]], + blend_specs: list[tuple[str, int, float]], +) -> int: + max_seq_len = args.train_seq_len + for _, seq_len, _ in sweep_specs: + max_seq_len = max(max_seq_len, seq_len) + for _, seq_len, _ in blend_specs: + max_seq_len = max(max_seq_len, seq_len) + return max_seq_len + + +def resolve_train_loss_mask_stride_frac(args: Hyperparameters) -> float: + return args.train_loss_mask_stride_frac if args.train_loss_mask_stride_frac > 0.0 else args.eval_stride_frac + + +def resolve_distill_start_step(args: Hyperparameters) -> int: + if args.distill_start_step >= 0: + return args.distill_start_step + if args.distill_start_frac < 0.0: + return args.iterations + 1 # Never trigger via fraction if negative + return int(max(0.0, min(1.0, args.distill_start_frac)) * args.iterations) + + +def distill_is_active( + args: Hyperparameters, + step: int, + elapsed_ms: float, + max_wallclock_ms: float | None, + distill_start_step: int, +) -> bool: + if args.distill_start_step >= 0: + return step >= args.distill_start_step + if args.distill_start_wallclock_frac >= 0.0 and max_wallclock_ms is not None and max_wallclock_ms > 0.0: + start_frac = max(0.0, min(1.0, args.distill_start_wallclock_frac)) + return elapsed_ms >= start_frac * max_wallclock_ms + return step >= distill_start_step + + +def qat_target_levels( + args: Hyperparameters, + step: int, + elapsed_ms: float, + max_wallclock_ms: float | None, +) -> tuple[int, str]: + if args.qat_scheme == "none": + return 0, "off" + + use_wallclock = ( + args.qat_start_wallclock_frac >= 0.0 + and max_wallclock_ms is not None + and max_wallclock_ms > 0.0 + ) + if use_wallclock: + start_frac = max(0.0, min(1.0, args.qat_start_wallclock_frac)) + end_frac = max(start_frac + 1e-6, min(1.0, args.qat_end_wallclock_frac)) + start_pos = start_frac * max_wallclock_ms + end_pos = end_frac * max_wallclock_ms + current_pos = elapsed_ms + mode = f"wallclock_frac:{start_frac:.4f}->{end_frac:.4f}" + else: + start_pos = float(args.qat_start_step) + end_step = args.qat_end_step if args.qat_end_step > args.qat_start_step else args.iterations + end_pos = float(end_step) + current_pos = float(step) + mode = f"step:{args.qat_start_step}->{int(end_pos)}" + + if current_pos < start_pos: + return 0, mode + if args.qat_scheme == "int8": + return 256, mode + + frac = (current_pos - start_pos) / max(end_pos - start_pos, 1.0) + frac = max(0.0, min(1.0, frac)) + if args.qat_scheme == "int5": + return (256 if frac < 0.33 else (64 if frac < 0.67 else 32)), mode + return (256 if frac < 0.33 else (64 if frac < 0.67 else 16)), mode + + +def build_blend_position_log_weights( + args: Hyperparameters, + blend_specs: list[tuple[str, int, float]], + blend_weights: list[float], + blend_stride: int, + device: torch.device, +) -> Tensor: + base_log_weights = torch.log(torch.tensor(blend_weights, device=device, dtype=torch.float32).clamp_min(1e-12)) + if args.eval_blend_position_bias == 0.0 or len(blend_specs) <= 1: + return base_log_weights[:, None].expand(-1, blend_stride) + + seq_lens = torch.tensor([seq_len for _, seq_len, _ in blend_specs], device=device, dtype=torch.float32) + centered = seq_lens - seq_lens.mean() + centered = centered / centered.abs().max().clamp_min(1e-6) + pos = torch.linspace(0.0, 1.0, steps=blend_stride, device=device, dtype=torch.float32) + signed_pos = 2.0 * pos - 1.0 + power = max(float(args.eval_blend_position_power), 1e-6) + if power != 1.0: + signed_pos = signed_pos.sign() * signed_pos.abs().pow(power) + logits = base_log_weights[:, None] + float(args.eval_blend_position_bias) * centered[:, None] * signed_pos[None, :] + return F.log_softmax(logits, dim=0) + + +def apply_eval_continuous_cache( + args: Hyperparameters, + scored_log_probs: Tensor, + scored_hidden: Tensor, + scored_targets: Tensor, + cache_state: tuple[Tensor, Tensor] | None, +) -> tuple[Tensor, tuple[Tensor, Tensor] | None]: + if not args.eval_cont_cache_enabled: + return scored_log_probs, cache_state + + flat_log_probs = scored_log_probs.reshape(-1, scored_log_probs.size(-1)).float() + flat_hidden = F.normalize(scored_hidden.reshape(-1, scored_hidden.size(-1)).float(), dim=-1) + flat_targets = scored_targets.reshape(-1).to(dtype=torch.int64) + mixed_log_probs = flat_log_probs + + if cache_state is not None and cache_state[0].numel() > 0: + cache_keys, cache_values = cache_state + scores = torch.matmul(flat_hidden, cache_keys.transpose(0, 1)) * float(args.eval_cont_cache_logit_scale) + topk = min(max(int(args.eval_cont_cache_topk), 0), cache_keys.size(0)) + if topk > 0 and topk < cache_keys.size(0): + scores, top_idx = torch.topk(scores, k=topk, dim=-1) + retrieved_ids = cache_values[top_idx] + else: + retrieved_ids = cache_values.unsqueeze(0).expand(scores.size(0), -1) + attn = F.softmax(scores, dim=-1) + cache_probs = torch.zeros_like(mixed_log_probs) + cache_probs.scatter_add_(1, retrieved_ids, attn) + cache_log_probs = torch.log(cache_probs.clamp_min(1e-9)) + mix = torch.full( + (mixed_log_probs.size(0),), + float(args.eval_cont_cache_weight), + device=mixed_log_probs.device, + dtype=torch.float32, + ) + if args.eval_cont_cache_conf_power >= 0.0: + cache_conf = cache_probs.max(dim=-1).values.clamp_(0.0, 1.0) + mix = mix * cache_conf.pow(float(args.eval_cont_cache_conf_power)) + mix = mix.clamp(min=1e-5, max=1.0 - 1e-5) + mixed_log_probs = torch.logaddexp( + torch.log1p(-mix).unsqueeze(-1) + mixed_log_probs, + torch.log(mix).unsqueeze(-1) + cache_log_probs, + ) + + window = max(1, int(args.eval_cont_cache_window)) + new_keys = flat_hidden.detach()[-window:] + new_values = flat_targets.detach()[-window:] + if cache_state is None or cache_state[0].numel() == 0: + updated_state = (new_keys, new_values) + else: + cache_keys, cache_values = cache_state + cache_keys = torch.cat((cache_keys, new_keys), dim=0) + cache_values = torch.cat((cache_values, new_values), dim=0) + if cache_keys.size(0) > window: + cache_keys = cache_keys[-window:] + cache_values = cache_values[-window:] + updated_state = (cache_keys.detach(), cache_values.detach()) + return mixed_log_probs.reshape_as(scored_log_probs).to(dtype=scored_log_probs.dtype), updated_state + + +def get_eval_model(model: nn.Module) -> nn.Module: + raw_model = model.module if hasattr(model, "module") else model + if hasattr(raw_model, "forward_hidden_and_output"): + return raw_model + if hasattr(raw_model, "_orig_mod") and hasattr(raw_model._orig_mod, "forward_hidden_and_output"): + return raw_model._orig_mod + if hasattr(raw_model, "forward_logits"): + return raw_model + if hasattr(raw_model, "_orig_mod") and hasattr(raw_model._orig_mod, "forward_logits"): + return raw_model._orig_mod + raise AttributeError("Could not find a forward_logits-capable model for evaluation") + + +TTT_PARAM_NAME_MATCH = ( + "residual_bigram_", + "residual_trigram_", + "residual_ngram_", + "bigram_left", + "bigram_right", + "bigram_scale", + "copy_gate", +) + + +def collect_ttt_params(raw_model: nn.Module) -> list[tuple[str, nn.Parameter]]: + # Keep TTT scoped to the small adaptive heads/tables. Residual n-gram + # predictors are named residual_bigram_* / residual_trigram_*, not only + # residual_ngram_*, so include all of those prefixes. + params: list[tuple[str, nn.Parameter]] = [] + for name, p in raw_model.named_parameters(): + leaf = name.rsplit(".", 1)[-1] + if any(name.startswith(pref) or leaf.startswith(pref) for pref in TTT_PARAM_NAME_MATCH): + params.append((name, p)) + return params + + +def apply_eval_rope_scaling( + model: nn.Module, + args: Hyperparameters, + seq_len: int, + rope_scale: float, +) -> list[tuple[object, Tensor]]: + if rope_scale == 1.0 and seq_len == args.train_seq_len: + return [] + head_dim = args.model_dim // args.num_heads + ntk_factor = rope_scale ** (head_dim / max(head_dim - 2, 1)) + raw_model = get_eval_model(model) + if not hasattr(raw_model, "blocks"): + return [] + orig_rope_bases: list[tuple[object, Tensor]] = [] + for block in raw_model.blocks: + attn = getattr(block, "attn", None) + rot = getattr(attn, "rotary", None) + if rot is None: + continue + orig_rope_bases.append((rot, rot.inv_freq.clone())) + new_base = args.rope_base * ntk_factor + new_inv_freq = 1.0 / ( + new_base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=rot.inv_freq.device) / head_dim) + ) + rot.inv_freq = new_inv_freq + rot._cos_cached = None + return orig_rope_bases + + +def restore_eval_rope_scaling(orig_rope_bases: list[tuple[object, Tensor]]) -> None: + for rot, orig_inv_freq in orig_rope_bases: + rot.inv_freq = orig_inv_freq + rot._cos_cached = None + + +def forward_eval_outputs( + args: Hyperparameters, + model: nn.Module, + x: Tensor, + seq_len: int, + rope_scale: float, + autocast_enabled: bool, +) -> tuple[Tensor, Tensor]: + eval_model = get_eval_model(model) + orig_rope_bases = apply_eval_rope_scaling(model, args, seq_len, rope_scale) + try: + jpcr_runtime_active = bool(getattr(eval_model, "jpcr_enabled", False)) + if autocast_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + hidden, logits, logits_are_log_probs = eval_model.forward_hidden_and_output( + x, jpcr_runtime_active=jpcr_runtime_active + ) + else: + hidden, logits, logits_are_log_probs = eval_model.forward_hidden_and_output( + x, jpcr_runtime_active=jpcr_runtime_active + ) + finally: + restore_eval_rope_scaling(orig_rope_bases) + log_probs = logits.float().reshape(x.size(0), x.size(1), -1) + if not logits_are_log_probs: + log_probs = F.log_softmax(log_probs, dim=-1) + return log_probs, hidden.float() + + +def eval_val_single( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + rope_scale: float, + stride_frac: float, + ttt_enabled: bool = False, + ttt_lr: float = 0.0, + ttt_steps: int = 1, + ttt_momentum: float = 0.9, +) -> tuple[float, float]: + _, prefix_len, stride = build_loss_mask_cpu(seq_len, stride_frac) + if args.eval_cont_cache_enabled and world_size != 1: + raise ValueError("EVAL_CONT_CACHE_ENABLED currently requires WORLD_SIZE=1 for deterministic eval order") + + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = max(1, local_batch_tokens // seq_len) + if args.eval_cont_cache_enabled: + local_batch_seqs = min(local_batch_seqs, max(1, args.eval_cont_cache_batch_seqs)) + total_wins = max(1, (val_tokens.numel() - seq_len - 1) // stride) + win_start = (total_wins * rank) // world_size + win_end = (total_wins * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # --- TTT setup (competition-compliant online update) ----------------------------- + # We snapshot the chosen param subset before eval starts, do SGD steps after each + # scored batch, then restore the snapshot before returning. This keeps the stored + # model state untouched so subsequent eval passes / quantization see clean weights. + ttt_active = bool(ttt_enabled) and float(ttt_lr) > 0.0 + ttt_params: list[tuple[str, nn.Parameter]] = [] + ttt_snapshots: list[Tensor] = [] + ttt_prev_requires_grad: dict[int, bool] = {} + ttt_optim: torch.optim.Optimizer | None = None + raw_model = get_eval_model(model) if ttt_active else None + if ttt_active and raw_model is not None: + # Scope: ngram + pointer-gate + small learned scales. Base transformer stays frozen. + ttt_params = collect_ttt_params(raw_model) + ttt_prev_requires_grad = {id(p): p.requires_grad for p in raw_model.parameters()} + for p in raw_model.parameters(): + p.requires_grad_(False) + for _, p in ttt_params: + p.requires_grad_(True) + ttt_snapshots.append(p.detach().clone()) + if ttt_params: + ttt_optim = torch.optim.SGD( + [p for _, p in ttt_params], lr=float(ttt_lr), momentum=float(ttt_momentum) + ) + else: + ttt_active = False # nothing to update + # --------------------------------------------------------------------------------- + + model.eval() + cache_state: tuple[Tensor, Tensor] | None = None + + eval_ctx = torch.enable_grad() if ttt_active else torch.inference_mode() + with eval_ctx: + for batch_win_start in range(win_start, win_end, local_batch_seqs): + batch_win_end = min(batch_win_start + local_batch_seqs, win_end) + xs, ys = [], [] + for w in range(batch_win_start, batch_win_end): + s = w * stride + xs.append(val_tokens[s : s + seq_len]) + ys.append(val_tokens[s + 1 : s + seq_len + 1]) + x = torch.stack(xs).to(device=device, dtype=torch.int64, non_blocking=True) + y = torch.stack(ys).to(device=device, dtype=torch.int64, non_blocking=True) + log_probs, hidden = forward_eval_outputs(args, model, x, seq_len, rope_scale, autocast_enabled) + scored_log_probs = log_probs[:, prefix_len:, :] + scored_hidden = hidden[:, prefix_len:, :] + scored_targets = y[:, prefix_len:] + scored_log_probs, cache_state = apply_eval_continuous_cache( + args, + scored_log_probs, + scored_hidden, + scored_targets, + cache_state, + ) + target_log_probs = scored_log_probs.gather(-1, scored_targets.unsqueeze(-1)).squeeze(-1) + + # Accumulate BPB stats (always detached from the TTT graph). + tlp_detached = target_log_probs.detach() + val_loss_sum += (-tlp_detached).sum(dtype=torch.float64) + val_token_count += tlp_detached.numel() + + prev_ids = x[:, prefix_len:].reshape(-1) + tgt_ids = scored_targets.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # TTT update: CE on the scored suffix. This is competition-compliant because + # the update happens AFTER emitting the BPB for this batch, and only uses + # tokens whose predictions are already recorded (online learning). + if ttt_active and ttt_optim is not None: + ttt_loss = -target_log_probs.mean() + ttt_loss.backward() + ttt_optim.step() + ttt_optim.zero_grad(set_to_none=True) + for _ in range(max(0, int(ttt_steps) - 1)): + # Additional steps re-run forward on the same batch. Kept behind + # an explicit env knob; default TTT_STEPS=1 skips this branch. + log_probs2, _h2 = forward_eval_outputs(args, model, x, seq_len, rope_scale, autocast_enabled) + slp2 = log_probs2[:, prefix_len:, :] + tlp2 = slp2.gather(-1, scored_targets.unsqueeze(-1)).squeeze(-1) + (-tlp2.mean()).backward() + ttt_optim.step() + ttt_optim.zero_grad(set_to_none=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + # Restore TTT param snapshots and prior requires_grad flags so the underlying + # model is bitwise unchanged after this function returns. + if ttt_active and raw_model is not None: + with torch.no_grad(): + for (_, p), snap in zip(ttt_params, ttt_snapshots): + p.data.copy_(snap) + for p in raw_model.parameters(): + p.requires_grad_(ttt_prev_requires_grad.get(id(p), False)) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_blend( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + blend_specs: list[tuple[str, int, float]], + blend_weights: list[float], +) -> tuple[float, float]: + if not blend_specs: + raise ValueError("eval_val_blend requires at least one blend spec") + if args.eval_cont_cache_enabled and world_size != 1: + raise ValueError("EVAL_CONT_CACHE_ENABLED currently requires WORLD_SIZE=1 for deterministic eval order") + + blend_stride_frac = args.eval_blend_stride_frac if args.eval_blend_stride_frac > 0.0 else args.eval_stride_frac + min_seq_len = min(seq_len for _, seq_len, _ in blend_specs) + max_seq_len = max(seq_len for _, seq_len, _ in blend_specs) + blend_stride = resolve_stride(min_seq_len, blend_stride_frac) + max_prefix_len = max(seq_len - blend_stride for _, seq_len, _ in blend_specs) + first_target_pos = max_prefix_len + 1 + max_target_start = val_tokens.numel() - blend_stride + if max_target_start < first_target_pos: + raise ValueError( + f"Validation split is too short for blend eval: first_target_pos={first_target_pos}, " + f"max_target_start={max_target_start}" + ) + + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_chunks = max(1, local_batch_tokens // max(max_seq_len * len(blend_specs), 1)) + if args.eval_cont_cache_enabled: + local_batch_chunks = min(local_batch_chunks, max(1, args.eval_cont_cache_batch_seqs)) + total_chunks = ((max_target_start - first_target_pos) // blend_stride) + 1 + chunk_start = (total_chunks * rank) // world_size + chunk_end = (total_chunks * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + cache_states: list[tuple[Tensor, Tensor] | None] = [None] * len(blend_specs) + with torch.inference_mode(): + for batch_chunk_start in range(chunk_start, chunk_end, local_batch_chunks): + batch_chunk_end = min(batch_chunk_start + local_batch_chunks, chunk_end) + target_starts = [first_target_pos + idx * blend_stride for idx in range(batch_chunk_start, batch_chunk_end)] + pos_log_weights = build_blend_position_log_weights( + args, + blend_specs, + blend_weights, + blend_stride, + device, + ) + + common_prev_ids = torch.stack( + [val_tokens[target_pos - 1 : target_pos + blend_stride - 1] for target_pos in target_starts] + ).to(device=device, dtype=torch.int64, non_blocking=True) + common_target_ids = torch.stack( + [val_tokens[target_pos : target_pos + blend_stride] for target_pos in target_starts] + ).to(device=device, dtype=torch.int64, non_blocking=True) + + blend_log_probs: Tensor | None = None + for spec_idx, (spec_name, seq_len, rope_scale) in enumerate(blend_specs): + del spec_name + prefix_len = seq_len - blend_stride + xs = [] + for target_pos in target_starts: + s = target_pos - prefix_len - 1 + xs.append(val_tokens[s : s + seq_len]) + x = torch.stack(xs).to(device=device, dtype=torch.int64, non_blocking=True) + log_probs, hidden = forward_eval_outputs(args, model, x, seq_len, rope_scale, autocast_enabled) + scored_log_probs = log_probs[:, prefix_len:, :] + scored_hidden = hidden[:, prefix_len:, :] + scored_log_probs, cache_states[spec_idx] = apply_eval_continuous_cache( + args, + scored_log_probs, + scored_hidden, + common_target_ids, + cache_states[spec_idx], + ) + weighted_log_probs = scored_log_probs + pos_log_weights[spec_idx][None, :, None] + blend_log_probs = ( + weighted_log_probs + if blend_log_probs is None + else torch.logaddexp(blend_log_probs, weighted_log_probs) + ) + + if blend_log_probs is None: + raise RuntimeError("blend_log_probs should have been populated") + target_log_probs = blend_log_probs.gather(-1, common_target_ids.unsqueeze(-1)).squeeze(-1) + val_loss_sum += (-target_log_probs).sum(dtype=torch.float64) + val_token_count += target_log_probs.numel() + + prev_ids = common_prev_ids.reshape(-1) + tgt_ids = common_target_ids.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + _, seq_len, rope_scale = resolve_primary_eval_spec(args) + return eval_val_single( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + seq_len, + rope_scale, + args.eval_stride_frac, + ) + + +def run_final_eval_suite( + args: Hyperparameters, + roundtrip_tag: str, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + autocast_enabled: bool, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + sweep_specs: list[tuple[str, int, float]], + blend_specs: list[tuple[str, int, float]], + blend_weights: list[float], + log0, +) -> tuple[float, float]: + primary_name, primary_seq_len, primary_rope_scale = resolve_primary_eval_spec(args) + ttt_param_count = 0 + if args.ttt_enabled and args.ttt_lr > 0.0: + try: + ttt_param_count = len(collect_ttt_params(get_eval_model(model))) + except AttributeError: + ttt_param_count = 0 + ttt_effective = bool(args.ttt_enabled and args.ttt_lr > 0.0 and ttt_param_count > 0) + primary_val_loss, primary_val_bpb = eval_val_single( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + primary_seq_len, + primary_rope_scale, + args.eval_stride_frac, + ttt_enabled=ttt_effective, + ttt_lr=args.ttt_lr, + ttt_steps=args.ttt_steps, + ttt_momentum=args.ttt_momentum, + ) + log0( + f"{roundtrip_tag}_ctx_exact name:{primary_name} seq_len:{primary_seq_len} " + f"rope_scale:{primary_rope_scale:.4f} stride_frac:{args.eval_stride_frac:.4f} " + f"ttt:{1 if ttt_effective else 0} ttt_params:{ttt_param_count} " + f"ttt_lr:{args.ttt_lr} ttt_steps:{args.ttt_steps} " + f"val_loss:{primary_val_loss:.8f} val_bpb:{primary_val_bpb:.8f}" + ) + + for sweep_name, sweep_seq_len, sweep_rope_scale in sweep_specs[1:]: + sweep_val_loss, sweep_val_bpb = eval_val_single( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + sweep_seq_len, + sweep_rope_scale, + args.eval_stride_frac, + ) + log0( + f"{roundtrip_tag}_ctx_exact name:{sweep_name} seq_len:{sweep_seq_len} " + f"rope_scale:{sweep_rope_scale:.4f} stride_frac:{args.eval_stride_frac:.4f} " + f"val_loss:{sweep_val_loss:.8f} val_bpb:{sweep_val_bpb:.8f}" + ) + + blend_result: tuple[float, float] | None = None + if blend_specs: + blend_stride_frac = args.eval_blend_stride_frac if args.eval_blend_stride_frac > 0.0 else args.eval_stride_frac + blend_val_loss, blend_val_bpb = eval_val_blend( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + blend_specs, + blend_weights, + ) + blend_specs_log = ",".join( + f"{name}:{seq_len}@{rope_scale:.4f}" + for name, seq_len, rope_scale in blend_specs + ) + blend_weights_log = ",".join(f"{weight:.6f}" for weight in blend_weights) + log0( + f"{roundtrip_tag}_blend_exact stride_frac:{blend_stride_frac:.4f} specs:{blend_specs_log} " + f"weights:{blend_weights_log} position_bias:{args.eval_blend_position_bias:.4f} " + f"position_power:{args.eval_blend_position_power:.4f} " + f"val_loss:{blend_val_loss:.8f} val_bpb:{blend_val_bpb:.8f}" + ) + blend_result = (blend_val_loss, blend_val_bpb) + + if args.final_eval_mode == "primary": + return primary_val_loss, primary_val_bpb + if args.final_eval_mode == "blend": + if blend_result is None: + raise ValueError("FINAL_EVAL_MODE=blend requires EVAL_BLEND_SEQ_LENS to be set") + return blend_result + raise ValueError(f"Unsupported FINAL_EVAL_MODE={args.final_eval_mode!r}; expected 'primary' or 'blend'") + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +QUANT_SCALE_EPS = float(os.environ.get("QUANT_SCALE_EPS", "1e-8")) +INT4_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT4_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT4_KEEP_FLOAT_MAX_NUMEL = int(os.environ.get("INT4_KEEP_FLOAT_MAX_NUMEL", 65_536)) +INT4_PER_ROW_SCALE_DTYPE = torch.float16 +INT4_CLIP_PERCENTILE = float(os.environ.get("INT4_CLIP_PERCENTILE", 99.995)) +INT4_CLIP_Q = INT4_CLIP_PERCENTILE / 100.0 +INT4_GROUP_SIZE = int(os.environ.get("INT4_GROUP_SIZE", "128")) # 0 = per-row (legacy) +INT5_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT5_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT5_KEEP_FLOAT_MAX_NUMEL = int(os.environ.get("INT5_KEEP_FLOAT_MAX_NUMEL", 65_536)) +INT5_PER_ROW_SCALE_DTYPE = torch.float16 +INT5_CLIP_PERCENTILE = float(os.environ.get("INT5_CLIP_PERCENTILE", 99.997)) +INT5_CLIP_Q = INT5_CLIP_PERCENTILE / 100.0 + +# NF4 lookup table: 16 quantiles of N(0,1), information-theoretically optimal for normal weights. +# Index 0..15 maps to these fixed float values. Quantize: find nearest, store index. +NF4_ENABLED = bool(int(os.environ.get("NF4_ENABLED", "1"))) +NF4_LUT = torch.tensor([ + -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0, + 0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0, +], dtype=torch.float32) +MIXED_KEEP_FLOAT_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "MIXED_KEEP_FLOAT_NAME_PATTERNS", + "tok_emb,lm_head,final_norm,norm," + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +MIXED_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "MIXED_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +MIXED_KEEP_FLOAT_MAX_NUMEL = int(os.environ.get("MIXED_KEEP_FLOAT_MAX_NUMEL", 65_536)) +SUPPORTED_QUANT_SCHEMES = {"int8", "int5", "int4", "mixed"} +SUPPORTED_COMPRESSORS = {"zlib", "zstd", "auto"} +SUPPORTED_WEIGHT_ORDERS = {"none", "name", "size_desc", "dtype_name"} + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor( + name: str, + t: Tensor, + passthrough_orig_dtypes: dict[str, str], + fp32_name_patterns: tuple[str, ...], +) -> Tensor: + if any(pattern in name for pattern in fp32_name_patterns): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def ordered_state_dict_items(state_dict: dict[str, Tensor], mode: str) -> list[tuple[str, Tensor]]: + items = list(state_dict.items()) + if mode == "none": + return items + if mode == "name": + return sorted(items, key=lambda kv: kv[0]) + if mode == "size_desc": + return sorted(items, key=lambda kv: (-int(kv[1].numel()), kv[0])) + if mode == "dtype_name": + return sorted(items, key=lambda kv: (str(kv[1].dtype), kv[0])) + raise ValueError(f"Unsupported WEIGHT_ORDER={mode!r}; expected one of {sorted(SUPPORTED_WEIGHT_ORDERS)}") + +def quantize_float_tensor_int8( + t: Tensor, precomputed_scale: Tensor | None = None +) -> tuple[Tensor, Tensor, dict[str, object] | None]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + if precomputed_scale is not None: + # LSQ-learned scale: use directly, skip the quantile clip computation. + scale = precomputed_scale.float().clamp_min(QUANT_SCALE_EPS) + else: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + scale = (clip_abs / 127.0).clamp_min(QUANT_SCALE_EPS) + q = torch.clamp(torch.round(t32 / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous(), {"scheme": "int8_per_row", "axis": 0} + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale, {"scheme": "int8_per_tensor", "orig_shape": list(t32.shape)} + +def pack_int4_signed(q_signed: Tensor) -> Tensor: + flat = q_signed.reshape(-1).to(dtype=torch.int16) + if flat.numel() % 2: + flat = torch.cat([flat, torch.zeros((1,), dtype=torch.int16)], dim=0) + uint = (flat + 8).to(torch.uint8) + packed = (uint[0::2] & 0x0F) | ((uint[1::2] & 0x0F) << 4) + return packed.contiguous() + +def unpack_int4_signed(packed: Tensor, numel: int) -> Tensor: + p = packed.reshape(-1).to(dtype=torch.uint8) + low = (p & 0x0F).to(dtype=torch.int16) - 8 + high = ((p >> 4) & 0x0F).to(dtype=torch.int16) - 8 + out = torch.empty((p.numel() * 2,), dtype=torch.int16) + out[0::2] = low + out[1::2] = high + return out[:numel].to(dtype=torch.int8).contiguous() + +def pack_int5_signed(q_signed: Tensor) -> Tensor: + """Pack int5 values (range [-16,15]) stored as int8 into 5 bytes per 8 values (40 bits).""" + flat = q_signed.reshape(-1).to(dtype=torch.int32) + pad = (8 - flat.numel() % 8) % 8 + if pad: + flat = torch.cat([flat, torch.zeros(pad, dtype=torch.int32)]) + u = (flat + 16).to(torch.uint8).reshape(-1, 8) # unsigned [0,31] + # 8 x uint5 → 5 bytes + b0 = (u[:, 0] ) | ((u[:, 1] & 0x07) << 5) + b1 = (u[:, 1] >> 3 ) | ( u[:, 2] << 2) | ((u[:, 3] & 0x01) << 7) + b2 = (u[:, 3] >> 1 ) | ((u[:, 4] & 0x0F) << 4) + b3 = (u[:, 4] >> 4 ) | ( u[:, 5] << 1) | ((u[:, 6] & 0x03) << 6) + b4 = (u[:, 6] >> 2 ) | ( u[:, 7] << 3) + packed = torch.stack([b0, b1, b2, b3, b4], dim=1).reshape(-1).to(torch.uint8) + return packed.contiguous() + +def unpack_int5_signed(packed: Tensor, numel: int) -> Tensor: + """Unpack int5 values from 5-bytes-per-8-values layout back to int8 [-16,15].""" + p = packed.reshape(-1, 5).to(torch.int32) + b0, b1, b2, b3, b4 = p[:, 0], p[:, 1], p[:, 2], p[:, 3], p[:, 4] + v0 = b0 & 0x1F + v1 = ((b0 >> 5) & 0x07) | ((b1 & 0x03) << 3) + v2 = ( b1 >> 2) & 0x1F + v3 = ((b1 >> 7) & 0x01) | ((b2 & 0x0F) << 1) + v4 = ((b2 >> 4) & 0x0F) | ((b3 & 0x01) << 4) + v5 = ( b3 >> 1) & 0x1F + v6 = ((b3 >> 6) & 0x03) | ((b4 & 0x07) << 2) + v7 = ( b4 >> 3) & 0x1F + out = torch.stack([v0, v1, v2, v3, v4, v5, v6, v7], dim=1).reshape(-1) + return (out[:numel] - 16).to(torch.int8).contiguous() + +def quantize_float_tensor_int5( + t: Tensor, precomputed_scale: Tensor | None = None +) -> tuple[Tensor, Tensor, dict[str, object]]: + t32 = t.float() + if t32.ndim == 2: + if precomputed_scale is not None: + scale = precomputed_scale.float().clamp_min(QUANT_SCALE_EPS) + else: + clip_abs = ( + torch.quantile(t32.abs(), INT5_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + scale = (clip_abs / 15.0).clamp_min(QUANT_SCALE_EPS) + q = torch.clamp(torch.round(t32 / scale[:, None]), -16, 15).to(torch.int8) + packed = pack_int5_signed(q) + return ( + packed, + scale.to(dtype=INT5_PER_ROW_SCALE_DTYPE).contiguous(), + {"scheme": "int5_per_row", "axis": 0, "orig_shape": [int(t32.shape[0]), int(t32.shape[1])]}, + ) + clip_abs = float(torch.quantile(t32.abs().flatten(), INT5_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 15.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -16, 15).to(torch.int8) + packed = pack_int5_signed(q) + return packed, scale, {"scheme": "int5_per_tensor", "orig_shape": list(t32.shape)} + +def quantize_float_tensor_int4( + t: Tensor, precomputed_scale: Tensor | None = None +) -> tuple[Tensor, Tensor, dict[str, object]]: + t32 = t.float() + if t32.ndim == 2: + if precomputed_scale is not None: + # LSQ-learned scale: skip quantile, use directly. + scale = precomputed_scale.float().clamp_min(QUANT_SCALE_EPS) + else: + clip_abs = ( + torch.quantile(t32.abs(), INT4_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + scale = (clip_abs / 7.0).clamp_min(QUANT_SCALE_EPS) + q = torch.clamp(torch.round(t32 / scale[:, None]), -8, 7).to(torch.int8) + packed = pack_int4_signed(q) + return ( + packed, + scale.to(dtype=INT4_PER_ROW_SCALE_DTYPE).contiguous(), + {"scheme": "int4_per_row", "axis": 0, "orig_shape": [int(t32.shape[0]), int(t32.shape[1])]}, + ) + clip_abs = float(torch.quantile(t32.abs().flatten(), INT4_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 7.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -8, 7).to(torch.int8) + packed = pack_int4_signed(q) + return packed, scale, {"scheme": "int4_per_tensor", "orig_shape": list(t32.shape)} + +def quantize_state_dict( + state_dict: dict[str, Tensor], + scheme: str = "int8", + weight_order: str = "none", + mixed_low_precision_scheme: str = "int8", + precomputed_scales: dict[str, Tensor] | None = None, + gptq_results: dict[str, tuple[Tensor, Tensor]] | None = None, +): + if scheme not in SUPPORTED_QUANT_SCHEMES: + raise ValueError(f"Unsupported QUANT_SCHEME={scheme!r}; expected one of {sorted(SUPPORTED_QUANT_SCHEMES)}") + if weight_order not in SUPPORTED_WEIGHT_ORDERS: + raise ValueError(f"Unsupported WEIGHT_ORDER={weight_order!r}; expected one of {sorted(SUPPORTED_WEIGHT_ORDERS)}") + if mixed_low_precision_scheme not in {"int8", "int5", "int4"}: + raise ValueError( + f"Unsupported MIXED_LOW_PRECISION_SCHEME={mixed_low_precision_scheme!r}; expected 'int8', 'int5', or 'int4'" + ) + + active_scheme = mixed_low_precision_scheme if scheme == "mixed" else scheme + if active_scheme == "int8": + format_name = f"{scheme}_clean_per_row_v1" + elif active_scheme == "int5": + format_name = f"{scheme}_clean_per_row_int5_v1" + else: + format_name = f"{scheme}_clean_per_row_int4_v1" + # Single supported clean-script export formats: + # - per-row low precision for 2D float tensors + # - per-tensor low precision for other float tensors + # - exact passthrough for non-floats + # - passthrough for selected float tensors, stored as fp16/fp32 + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "payload_bytes"), + 0, + ) + keep_patterns = ( + MIXED_KEEP_FLOAT_NAME_PATTERNS + if scheme == "mixed" + else ( + INT8_KEEP_FLOAT_FP32_NAME_PATTERNS + if active_scheme == "int8" + else (INT5_KEEP_FLOAT_FP32_NAME_PATTERNS if active_scheme == "int5" else INT4_KEEP_FLOAT_FP32_NAME_PATTERNS) + ) + ) + force_fp32_patterns = ( + MIXED_KEEP_FLOAT_FP32_NAME_PATTERNS + if scheme == "mixed" + else ( + INT8_KEEP_FLOAT_FP32_NAME_PATTERNS + if active_scheme == "int8" + else (INT5_KEEP_FLOAT_FP32_NAME_PATTERNS if active_scheme == "int5" else INT4_KEEP_FLOAT_FP32_NAME_PATTERNS) + ) + ) + keep_max_numel = ( + MIXED_KEEP_FLOAT_MAX_NUMEL + if scheme == "mixed" + else (INT8_KEEP_FLOAT_MAX_NUMEL if active_scheme == "int8" else (INT5_KEEP_FLOAT_MAX_NUMEL if active_scheme == "int5" else INT4_KEEP_FLOAT_MAX_NUMEL)) + ) + + for name, tensor in ordered_state_dict_items(state_dict, weight_order): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["payload_bytes"] += tensor_nbytes(t) + continue + + should_keep_float = ( + t.numel() <= keep_max_numel + or (scheme == "mixed" and any(pattern in name for pattern in keep_patterns)) + ) + if should_keep_float: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes, force_fp32_patterns) + passthrough[name] = kept + stats["payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + + # GPTQ fast path: use pre-quantized (Q, scale) from Hessian-aware quantization + if gptq_results is not None and name in gptq_results and t.ndim == 2: + gq, gs = gptq_results[name] + if active_scheme == "int5": + packed = pack_int5_signed(gq) + meta = {"scheme": "int5_per_row", "axis": 0, "orig_shape": [int(t.shape[0]), int(t.shape[1])]} + quantized[name] = packed + scales[name] = gs.to(dtype=INT5_PER_ROW_SCALE_DTYPE).contiguous() + elif active_scheme == "int4": + packed = pack_int4_signed(gq) + if gs.ndim == 2: + # Per-group scales: [rows, num_groups] + scheme_name = "int4_per_group_nf4" if NF4_ENABLED else "int4_per_group" + meta = {"scheme": scheme_name, "axis": 0, + "orig_shape": [int(t.shape[0]), int(t.shape[1])], + "group_size": INT4_GROUP_SIZE} + else: + meta = {"scheme": "int4_per_row", "axis": 0, "orig_shape": [int(t.shape[0]), int(t.shape[1])]} + quantized[name] = packed + scales[name] = gs.to(dtype=INT4_PER_ROW_SCALE_DTYPE).contiguous() + else: + meta = {"scheme": "int8_per_row", "axis": 0} + quantized[name] = gq.contiguous() + scales[name] = gs.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + qmeta[name] = meta + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["payload_bytes"] += tensor_nbytes(quantized[name]) + tensor_nbytes(scales[name]) + continue + + pre_scale = None + if precomputed_scales is not None and t.ndim == 2: + pre_scale = precomputed_scales.get(name) + if pre_scale is not None and pre_scale.shape[0] != t.shape[0]: + pre_scale = None # shape mismatch → fall back to quantile + if active_scheme == "int8": + q, s, meta = quantize_float_tensor_int8(t, precomputed_scale=pre_scale) + elif active_scheme == "int5": + q, s, meta = quantize_float_tensor_int5(t, precomputed_scale=pre_scale) + else: + q, s, meta = quantize_float_tensor_int4(t, precomputed_scale=pre_scale) + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": format_name, + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + "export_order_mode": weight_order, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + # Backward-compatible alias for existing log paths. + stats["int8_payload_bytes"] = stats["payload_bytes"] + return obj, stats + +# ---- GPTQ: Accurate Post-Training Quantization (Frantar et al., 2022) ---- + +@torch.no_grad() +def _nf4_quantize(w: Tensor, scale: Tensor) -> Tensor: + """Quantize values to NF4: find nearest NF4 level, return index in [-8, 7].""" + nf4 = NF4_LUT.to(w.device) # [16] + normalized = w / scale.clamp(min=1e-8) # normalized to ~[-1, 1] + # Find nearest NF4 level for each value + # nf4 has 16 values, indices 0..15, we store as signed [-8..7] + dists = (normalized.unsqueeze(-1) - nf4.unsqueeze(0)).abs() # [rows, 16] + indices = dists.argmin(dim=-1) # [rows] -> 0..15 + return (indices - 8).to(torch.int8) # shift to [-8, 7] for packing + + +def _nf4_dequantize(q_signed: Tensor, scale: Tensor) -> Tensor: + """Dequantize NF4: index into LUT, multiply by scale.""" + nf4 = NF4_LUT.to(q_signed.device) + indices = (q_signed.to(torch.int16) + 8).clamp(0, 15).long() + return nf4[indices] * scale + + +def gptq_quantize_weight( + W: Tensor, + H: Tensor, + bits: int = 4, + percdamp: float = 0.01, + blocksize: int = 128, + group_size: int = 0, + use_nf4: bool = False, + act_order: bool = True, +) -> tuple[Tensor, Tensor]: + """GPTQ-quantize a single weight matrix using Hessian information. + + Args: + W: [out_features, in_features] weight matrix + H: [in_features, in_features] Hessian proxy (X^T X / n) + bits: 4 or 8 + percdamp: damping fraction of mean diagonal + blocksize: column block size for lazy batch updates + group_size: columns per quantization group (0 = per-row) + use_nf4: use NF4 quantile levels instead of uniform (only for bits=4) + act_order: reorder columns by Hessian diagonal (importance) for lower error + + Returns: + (Q_int8, scale) where Q_int8 holds the quantized integers [-8..7] or [-127..127] + and scale is [rows] (per-row) or [rows, num_groups] (per-group). + """ + device = W.device + rows, cols = W.shape + W = W.clone().float() + H = H.clone().float().to(device) + + if bits == 4: + maxq, minq, sym_max = 7, -8, 7.0 + elif bits == 5: + maxq, minq, sym_max = 15, -16, 15.0 + else: + maxq, minq, sym_max = 127, -127, 127.0 + use_nf4 = use_nf4 and bits == 4 # NF4 only for 4-bit + use_groups = group_size > 0 and bits == 4 + + # Dead columns (no activation energy) → zero out weight and fix Hessian + dead = torch.diag(H) == 0 + H[dead, dead] = 1.0 + W[:, dead] = 0.0 + + # Damping for numerical stability + damp = percdamp * torch.mean(torch.diag(H)).item() + diag_idx = torch.arange(cols, device=device) + H[diag_idx, diag_idx] += damp + + # Act-order: sort columns by Hessian diagonal (most important first) + # Only use act-order without groups (act-order + groups is complex) + if act_order and bits == 4 and not use_groups: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + else: + perm = None + + # Compute H^{-1} via Cholesky for stability + try: + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + except torch.linalg.LinAlgError: + H[diag_idx, diag_idx] += 10 * damp + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + + # Compute scales: per-row or per-group (dynamically recomputed per group) + if use_groups: + num_groups = (cols + group_size - 1) // group_size + scale = torch.zeros(rows, num_groups, device=device) + else: + num_groups = 0 + scale = W.abs().amax(dim=1).clamp(min=1e-8) / sym_max + + Q = torch.zeros(rows, cols, dtype=torch.int8, device=device) + + for i1 in range(0, cols, blocksize): + i2 = min(i1 + blocksize, cols) + Err1 = torch.zeros(rows, i2 - i1, device=device) + + # Dynamically compute group scale at group boundary from current W + if use_groups: + g = i1 // group_size + if i1 % group_size == 0: + c0 = g * group_size + c1 = min(c0 + group_size, cols) + scale[:, g] = W[:, c0:c1].abs().amax(dim=1).clamp(min=1e-8) + if not use_nf4: + scale[:, g] /= sym_max + + for j in range(i2 - i1): + col = i1 + j + w = W[:, col] + d = Hinv[col, col].clamp(min=1e-10) + + # Recompute group scale at group boundary within a block + if use_groups and col > i1 and col % group_size == 0: + g = col // group_size + c0 = g * group_size + c1 = min(c0 + group_size, cols) + scale[:, g] = W[:, c0:c1].abs().amax(dim=1).clamp(min=1e-8) + if not use_nf4: + scale[:, g] /= sym_max + + # Get the scale for this column + if use_groups: + col_scale = scale[:, col // group_size] + else: + col_scale = scale + + if use_nf4: + q = _nf4_quantize(w, col_scale) + Q[:, col] = q + w_hat = _nf4_dequantize(q, col_scale) + else: + q = torch.clamp(torch.round(w / col_scale), minq, maxq) + Q[:, col] = q.to(torch.int8) + w_hat = q * col_scale + + err = (w - w_hat) / d + Err1[:, j] = err + + W[:, col] = w_hat # replace with dequantized + if j + 1 < i2 - i1: + W[:, col + 1 : i2] -= err.unsqueeze(1) * Hinv[col, col + 1 : i2].unsqueeze(0) + + # Lazy batch update: propagate accumulated error to remaining columns + if i2 < cols: + W[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + + # Un-permute back to original column order (act-order only, no groups) + if perm is not None: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + + return Q, scale + + +@torch.no_grad() +def collect_gptq_hessians( + model: nn.Module, + val_tokens: Tensor, + device: torch.device, + seq_len: int = 1024, + nsamples: int = 128, +) -> dict[str, Tensor]: + """Collect H = (1/n) X^T X for each CastedLinear by running calibration data.""" + hessians: dict[str, Tensor] = {} + sample_counts: dict[str, int] = {} + hooks = [] + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear): + key = name + ".weight" + hessians[key] = torch.zeros(module.in_features, module.in_features, device=device) + sample_counts[key] = 0 + + def make_hook(k: str): + def hook_fn(mod, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[k].addmm_(x.T, x) + sample_counts[k] += x.shape[0] + return hook_fn + + hooks.append(module.register_forward_hook(make_hook(key))) + + # Tied embeddings use F.linear(hidden, tok_emb.weight) instead of a CastedLinear + # module, so hook the final normalized hidden states as calibration inputs for + # tok_emb.weight. This matters most at large vocab sizes where the tied + # embedding/output matrix dominates both parameters and quantization error. + if getattr(model, "tie_embeddings", False) and hasattr(model, "tok_emb") and hasattr(model, "final_norm"): + key = "tok_emb.weight" + emb = getattr(model, "tok_emb") + embed_dim = int(getattr(emb, "embedding_dim", 0)) + if embed_dim > 0 and key not in hessians: + hessians[key] = torch.zeros(embed_dim, embed_dim, device=device) + sample_counts[key] = 0 + + def tied_embedding_hook(_mod, _inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[key].addmm_(x.T, x) + sample_counts[key] += x.shape[0] + + hooks.append(model.final_norm.register_forward_hook(tied_embedding_hook)) + + # Disable QAT fake-quant during calibration + saved_qat_levels = CastedLinear.qat_levels + CastedLinear.qat_levels = 0 + + model.eval() + total_tokens = val_tokens.numel() - 1 + tokens_used = 0 + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for i in range(0, total_tokens - seq_len, seq_len): + if tokens_used >= nsamples * seq_len: + break + x = val_tokens[i : i + seq_len].unsqueeze(0).to(device=device, dtype=torch.int64) + y = val_tokens[i + 1 : i + seq_len + 1].unsqueeze(0).to(device=device, dtype=torch.int64) + model(x, y) + tokens_used += seq_len + + CastedLinear.qat_levels = saved_qat_levels + + for h in hooks: + h.remove() + + # Normalize: H = (1/n) * X^T X + for key in hessians: + n = max(sample_counts[key], 1) + hessians[key] /= n + + return hessians + + +@torch.no_grad() +def gptq_quantize_state_dict( + model: nn.Module, + state_dict: dict[str, Tensor], + hessians: dict[str, Tensor], + bits: int = 4, + percdamp: float = 0.01, + blocksize: int = 128, + group_size: int = 0, + use_nf4: bool = False, +) -> dict[str, tuple[Tensor, Tensor]]: + """Apply GPTQ to all CastedLinear weights that have Hessians. + + Returns {state_dict_key: (Q_int8, scale)} for quantized 2D tensors. + scale is [rows] (per-row) or [rows, num_groups] (per-group). + """ + device = next(model.parameters()).device + results: dict[str, tuple[Tensor, Tensor]] = {} + for name in sorted(hessians.keys()): + if name not in state_dict: + continue + W = state_dict[name].to(device) + if W.ndim != 2: + continue + H = hessians[name] + Q, scale = gptq_quantize_weight( + W, H, bits=bits, percdamp=percdamp, blocksize=blocksize, + group_size=group_size, use_nf4=use_nf4, + ) + results[name] = (Q.cpu(), scale.cpu()) + return results + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + format_name = str(obj.get("__quant_format__", "")) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + meta = qmeta.get(name, {}) + meta_scheme = str(meta.get("scheme", "")) + if meta_scheme in {"int5_per_row", "int5_per_tensor"}: + orig_shape = tuple(int(v) for v in meta.get("orig_shape", q.shape)) + numel = math.prod(orig_shape) + unpacked = unpack_int5_signed(q, numel) + if meta_scheme == "int5_per_row": + rows, cols = orig_shape + scale_row = s.to(dtype=torch.float32).view(rows, 1) + out[name] = (unpacked.float().view(rows, cols) * scale_row).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (unpacked.float().view(orig_shape) * scale).to(dtype=dtype).contiguous() + continue + if meta_scheme in {"int4_per_row", "int4_per_tensor", "int4_per_group", "int4_per_group_nf4"}: + orig_shape = tuple(int(v) for v in meta.get("orig_shape", q.shape)) + numel = math.prod(orig_shape) + unpacked = unpack_int4_signed(q, numel) + if meta_scheme in {"int4_per_group", "int4_per_group_nf4"}: + rows, cols = orig_shape + group_size = int(meta.get("group_size", 128)) + s_f = s.to(dtype=torch.float32) # [rows, num_groups] + q_mat = unpacked.view(rows, cols) + if meta_scheme == "int4_per_group_nf4": + # NF4 dequantization: index into LUT, then multiply by group scale + nf4 = NF4_LUT # [16] + indices = (q_mat.to(torch.int16) + 8).clamp(0, 15).long() + nf4_vals = nf4[indices] # [rows, cols] in [-1, 1] + # Expand group scales to per-column + group_idx = torch.arange(cols) // group_size + group_idx = group_idx.clamp(max=s_f.shape[1] - 1) + col_scales = s_f[:, group_idx] # [rows, cols] + out[name] = (nf4_vals * col_scales).to(dtype=dtype).contiguous() + else: + # Uniform int4 per-group dequantization + group_idx = torch.arange(cols) // group_size + group_idx = group_idx.clamp(max=s_f.shape[1] - 1) + col_scales = s_f[:, group_idx] # [rows, cols] + out[name] = (unpacked.float().view(rows, cols) * col_scales).to(dtype=dtype).contiguous() + elif meta_scheme == "int4_per_row": + rows, cols = orig_shape + scale_row = s.to(dtype=torch.float32).view(rows, 1) + out[name] = (unpacked.float().view(rows, cols) * scale_row).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (unpacked.float().view(orig_shape) * scale).to(dtype=dtype).contiguous() + continue + if meta_scheme in {"int8_per_row", "per_row"} or (s.ndim > 0 and "int4" not in format_name): + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +def resolve_compressor(requested: str) -> tuple[str, str | None]: + if requested not in SUPPORTED_COMPRESSORS: + raise ValueError(f"Unsupported COMPRESSOR={requested!r}; expected one of {sorted(SUPPORTED_COMPRESSORS)}") + if requested == "zlib": + return "zlib", None + if requested == "zstd": + if importlib.util.find_spec("zstandard") is None: + raise RuntimeError( + "COMPRESSOR=zstd requested, but the `zstandard` package is not installed. " + "Install it with `pip install zstandard` or use COMPRESSOR=zlib." + ) + return "zstd", None + # auto mode + if importlib.util.find_spec("zstandard") is not None: + return "zstd", "COMPRESSOR=auto selected zstd (package available)" + return "zlib", "COMPRESSOR=auto fell back to zlib (zstandard package not installed)" + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + zlib_level = 9 if level < 0 else max(0, min(level, 9)) + return zlib.compress(data, level=zlib_level) + if compressor == "zstd": + import zstandard as zstd # type: ignore + + zstd_level = 19 if level < 0 else level + return zstd.ZstdCompressor(level=zstd_level).compress(data) + raise ValueError(f"Unsupported compressor={compressor!r}") + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + import zstandard as zstd # type: ignore + + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported compressor={compressor!r}") + +def export_artifact_name(quant_scheme: str, compressor: str) -> str: + if quant_scheme == "int8" and compressor == "zlib": + return "final_model.int8.ptz" + return f"final_model.{quant_scheme}.{compressor}.ptc" + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +def _fake_quantize_row(w: Tensor, levels: int) -> Tensor: + """Per-row fake-quantise a 2D weight with a straight-through estimator (STE). + + Matches the per-row clipping used by quantize_float_tensor_int8/int4 at export, + but uses amax instead of quantile for speed in the hot forward path. + levels=256 → int8 symmetric (range −127…127) + levels=16 → int4 symmetric (range −7…7) + """ + half = float(levels // 2 - (1 if levels in (16, 32) else 0)) # 127 for int8, 15 for int5, 7 for int4 + w32 = w.float() + clip_abs = w32.abs().amax(dim=1).clamp_min(1e-6) # per-row max scale + scale = clip_abs / half + w_scaled = (w32 / scale.unsqueeze(1)).clamp(-half, half) + # STE: round in forward, identity in backward + w_ste = w_scaled + (w_scaled.round() - w_scaled).detach() + return (w_ste * scale.unsqueeze(1)).to(w.dtype) + + +def _fake_quantize_row_lsq(w: Tensor, levels: int, log_scale: Tensor) -> Tensor: + """LSQ variant: per-row learnable step-size quantisation with STE. + + Based on "Learned Step Size Quantization" (Esser et al., 2019). + log_scale is a learnable 1D parameter [out_features] optimised via backprop. + Gradient on log_scale is scaled by g = 1/sqrt(numel_per_row * half) per the LSQ paper, + which keeps the scale-gradient magnitude commensurate with weight-gradient magnitude. + + Compared to max-abs fake-quant, LSQ lets the model adapt the clip threshold per row, + reducing int4 quantisation error by ~30-50% on typical models. + """ + half = float(levels // 2 - (1 if levels in (16, 32) else 0)) + w32 = w.float() + # LSQ gradient scaling trick: effective gradient on log_scale is g * d_loss/d_scale. + numel_per_row = float(w32.shape[1]) + g = 1.0 / math.sqrt(max(numel_per_row * half, 1.0)) + ls_grad_scaled = log_scale * g + (log_scale - log_scale * g).detach() + # Convert log-scale to positive scale via exp (auto-positive, stable). + scale = ls_grad_scaled.float().exp().clamp_min(1e-8) + w_scaled = (w32 / scale.unsqueeze(1)).clamp(-half, half) + w_ste = w_scaled + (w_scaled.round() - w_scaled).detach() + return (w_ste * scale.unsqueeze(1)).to(w.dtype) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # QAT: set qat_levels to 256 (int8), 32 (int5), or 16 (int4) to enable fake-quantisation. + qat_levels: int = 0 # class-level switch updated from the training loop + # LSQ: when True, CastedLinear instances allocate a learnable per-row log-scale parameter + # used in place of the max-abs scale. Must be set BEFORE model construction. + qat_lsq_enabled: bool = False + + def __init__(self, in_features: int, out_features: int, bias: bool = True, **kwargs) -> None: + super().__init__(in_features, out_features, bias=bias, **kwargs) + if __class__.qat_lsq_enabled: + # Per-row log-scale. Zeros → scale=1.0 placeholder; re-initialised from actual + # weight stats at the step QAT first activates (see init_lsq_scales below). + self.qat_log_scale = nn.Parameter(torch.zeros(out_features)) + else: + self.qat_log_scale = None + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if __class__.qat_levels > 0 and w.ndim == 2: + if self.qat_log_scale is not None: + w = _fake_quantize_row_lsq(w, __class__.qat_levels, self.qat_log_scale) + else: + w = _fake_quantize_row(w, __class__.qat_levels) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def init_lsq_scales(model: nn.Module, levels: int) -> int: + """Initialise LSQ per-row log-scales from current weight statistics. + + Called once when QAT first activates. Sets each log_scale to + log(max_abs_per_row / half), matching the initial value a max-abs fake-quant would use. + Returns the number of CastedLinear modules initialised. + """ + half = float(levels // 2 - (1 if levels in (16, 32) else 0)) + count = 0 + with torch.no_grad(): + for m in model.modules(): + if isinstance(m, CastedLinear) and m.qat_log_scale is not None and m.weight.ndim == 2: + w32 = m.weight.detach().float() + scale_val = (w32.abs().amax(dim=1).clamp_min(1e-6) / max(half, 1.0)) + m.qat_log_scale.data.copy_(scale_val.log().to(m.qat_log_scale.dtype)) + count += 1 + return count + + +def collect_lsq_scales(model: nn.Module, prefix: str = "") -> dict[str, Tensor]: + """Walk the model and return a dict of {state_dict_weight_name: exp(log_scale)}. + + Used at export time to plumb LSQ-learned scales into quantize_float_tensor_int4/int8 + via the precomputed_scales dict. + """ + scales: dict[str, Tensor] = {} + for name, m in model.named_modules(prefix=prefix): + if isinstance(m, CastedLinear) and m.qat_log_scale is not None and m.weight.ndim == 2: + key = f"{name}.weight" if name else "weight" + scales[key] = m.qat_log_scale.detach().float().exp().clamp_min(1e-8).cpu() + return scales + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if num_heads <= 0: + raise ValueError(f"num_heads must be positive, got {num_heads}") + if num_kv_heads <= 0: + raise ValueError(f"num_kv_heads must be positive, got {num_kv_heads}") + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Expand KV heads to match Q heads for GQA (handles older PyTorch without enable_gqa) + if self.num_kv_heads != self.num_heads: + groups = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(groups, dim=1) + v = v.repeat_interleave(groups, dim=1) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, use_swiglu: bool = False): + super().__init__() + self.use_swiglu = use_swiglu + if use_swiglu: + # SwiGLU with the same parameter budget as relu²: + # relu² uses 2 matrices of (dim × mlp_mult*dim) = 2*mlp_mult*dim² params. + # SwiGLU uses 3 matrices of (dim × h): 3*h*dim params. + # Equating: h = (2/3)*mlp_mult*dim. Round down to multiple of 64 for hardware alignment. + hidden = max(64, (2 * mlp_mult * dim // 3 // 64) * 64) + self.gate = CastedLinear(dim, hidden, bias=False) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + else: + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.use_swiglu: + return self.proj(F.silu(self.gate(x)) * self.fc(x)) + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class MoEMLP(nn.Module): + """Sparse Mixture-of-Experts MLP with Expert Choice routing. + + Design goals + ============ + 1. **torch.compile(fullgraph=True) compatible** — Expert Choice routing gives + every expert a statically-shaped slice of tokens [capacity, D], avoiding the + dynamic-shape issues of token-choice top-k dispatch. + 2. **QAT-aware** — all expert weights are CastedLinear, so the class-level + CastedLinear.qat_levels switch applies uniformly to router and experts. + 3. **Muon-trained** — CastedLinear parameters are automatically picked up by + the existing Muon parameter-group logic (2-D weight matrices). + 4. **Load-balanced by construction** — each expert always processes exactly + `capacity` tokens, so no explicit load-balance loss is required. + 5. **Router stability via Z-loss** — a small penalty on router logit magnitudes + prevents collapse (all tokens always sent to one expert). + + Expert Choice routing (Zhou et al., 2022) + ========================================== + Instead of each token selecting its top-k experts (token choice), each expert + selects the top `capacity` tokens it wants to process: + + capacity = max(1, int(capacity_factor * S / E)) # S = B*T, E = num_experts + + router_probs [S, E] = softmax(router_logits) + top_scores [E, cap] \\ + top_indices [E, cap] / = router_probs.T.topk(capacity, dim=1) + + For each expert i: + expert_input = x_flat[top_indices[i]] # [cap, D] — gather + expert_out = expert_mlp_i(expert_input) # [cap, D] + expert_out *= top_scores[i] # weighted by routing prob + output += scatter(expert_out, top_indices[i]) # accumulate + + Every tensor shape is statically determined → fullgraph compile succeeds. + + Args: + dim : model hidden dimension + mlp_mult : MLP width multiplier (identical to base MLP) + num_experts : number of expert MLPs (E); must be ≥ 2 + capacity_factor : fraction of tokens each expert sees; 1.0 = perfect coverage + use_swiglu : SwiGLU activation (matching the base MLP choice) + """ + + def __init__( + self, + dim: int, + mlp_mult: int, + num_experts: int, + capacity_factor: float = 1.0, + use_swiglu: bool = False, + ): + super().__init__() + if num_experts < 2: + raise ValueError(f"MoEMLP requires num_experts >= 2, got {num_experts}") + self.num_experts = num_experts + self.capacity_factor = capacity_factor + self.use_swiglu = use_swiglu + + # Router: linear map from hidden dim to expert scores. + # CastedLinear → participates in QAT and Muon automatically. + self.router = CastedLinear(dim, num_experts, bias=False) + + # Per-expert weight matrices stored as ModuleLists of CastedLinear. + # This is intentionally verbose (vs stacked tensors) so that: + # a) Each expert participates in QAT via CastedLinear.qat_levels + # b) Muon picks them up as standard 2-D parameters + # c) Zero-init of proj layers is handled naturally via _zero_init flag + if use_swiglu: + hidden = max(64, (2 * mlp_mult * dim // 3 // 64) * 64) + self.expert_gates = nn.ModuleList([CastedLinear(dim, hidden, bias=False) for _ in range(num_experts)]) + self.expert_fcs = nn.ModuleList([CastedLinear(dim, hidden, bias=False) for _ in range(num_experts)]) + self.expert_projs = nn.ModuleList([CastedLinear(hidden, dim, bias=False) for _ in range(num_experts)]) + for m in self.expert_projs: + m._zero_init = True + else: + hidden = mlp_mult * dim + self.expert_gates = nn.ModuleList() # unused for relu²; kept for uniform attr + self.expert_fcs = nn.ModuleList([CastedLinear(dim, hidden, bias=False) for _ in range(num_experts)]) + self.expert_projs = nn.ModuleList([CastedLinear(hidden, dim, bias=False) for _ in range(num_experts)]) + for m in self.expert_projs: + m._zero_init = True + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """ + Args: + x : [B, T, D] + Returns: + output : [B, T, D] — same shape as input + z_loss : scalar — router Z-loss; add to training loss via moe_aux_loss_coeff + """ + B, T, D = x.shape + S = B * T + x_flat = x.reshape(S, D) + + # ── Router ────────────────────────────────────────────────────────── + router_logits = self.router(x_flat) # [S, E] (bfloat16) + + # Z-loss (Zoph et al., 2022 "ST-MoE"): + # z_loss = mean( log(∑_e exp(router_logits))² ) + # Keeps router logits from growing large → prevents routing collapse. + z_loss: Tensor = torch.logsumexp(router_logits.float(), dim=-1).square().mean() + + router_probs = torch.softmax(router_logits.float(), dim=-1) # [S, E] + + # ── Expert Choice: each expert picks its top-capacity tokens ───────── + # capacity is a Python int → static shape → fullgraph-compile friendly + capacity = max(1, int(self.capacity_factor * S / self.num_experts)) + + # router_probs.T is [E, S]; topk over dim=1 selects the top-capacity token + # indices per expert. Both outputs have static shape [E, capacity]. + top_scores, top_indices = router_probs.T.topk(capacity, dim=1) # [E, cap] + + # ── Expert forward + weighted scatter ──────────────────────────────── + output = torch.zeros_like(x_flat) # [S, D] + + for i in range(self.num_experts): + # Gather the tokens this expert selected. Shape: [cap, D] + expert_in = x_flat[top_indices[i]] + weights = top_scores[i].to(expert_in.dtype) # [cap] + + # Expert MLP forward (SwiGLU or relu²) + if self.use_swiglu: + h = F.silu(self.expert_gates[i](expert_in)) * self.expert_fcs[i](expert_in) + expert_out = self.expert_projs[i](h) + else: + h = torch.relu(self.expert_fcs[i](expert_in)) + expert_out = self.expert_projs[i](h.square()) + + # Scale by routing probability (gradient flows through weights here) + expert_out = expert_out * weights.unsqueeze(-1) + + # Scatter-add back into the output buffer at the positions this expert owns. + # top_indices[i] has static shape [cap]; unsqueeze(-1).expand gives [cap, D]. + output.scatter_add_( + 0, + top_indices[i].unsqueeze(-1).expand(-1, D), + expert_out, + ) + + return output.reshape(B, T, D), z_loss + + +class SSMMixer(nn.Module): + """SSM mixer used by SSM blocks. + + `impl="mamba3"` wraps the official CUDA-backed Mamba-3 block from + `mamba_ssm.modules.mamba3`. `impl="conv"` keeps the older lightweight causal + depthwise-conv mixer available for ablations. + """ + + def __init__( + self, + dim: int, + expand: float = 2.0, + kernel_size: int = 4, + impl: str = "mamba3", + mamba3_d_state: int = 128, + mamba3_head_dim: int = 64, + mamba3_is_mimo: bool = True, + mamba3_mimo_rank: int = 4, + mamba3_chunk_size: int = 16, + mamba3_outproj_norm: bool = False, + ): + super().__init__() + self.impl = impl.strip().lower() + if self.impl not in {"mamba3", "conv"}: + raise ValueError(f"Unsupported SSM_IMPL={impl!r}; expected 'mamba3' or 'conv'") + if self.impl == "mamba3": + if _OfficialMamba3 is None: + raise ImportError( + "SSM_IMPL=mamba3 requires the source build of mamba-ssm with Mamba3. " + "Install with: MAMBA_FORCE_BUILD=TRUE pip install --no-cache-dir " + "--force-reinstall git+https://github.com/state-spaces/mamba.git --no-build-isolation" + ) from _MAMBA3_IMPORT_ERROR + if mamba3_head_dim <= 0: + preferred = [128, 64, 32] + mamba3_head_dim = next((h for h in preferred if dim % h == 0), 0) + if mamba3_head_dim <= 0: + raise ValueError( + f"MAMBA3_HEAD_DIM=0 could not auto-pick a tested Mamba-3 headdim " + f"for MODEL_DIM={dim}; use a MODEL_DIM divisible by one of {preferred} " + f"(for example 448 or 512), or explicitly set MAMBA3_HEAD_DIM at your own risk." + ) + if dim % mamba3_head_dim != 0: + raise ValueError( + f"MODEL_DIM={dim} must be divisible by MAMBA3_HEAD_DIM={mamba3_head_dim}" + ) + self.mamba3_head_dim = int(mamba3_head_dim) + if mamba3_d_state <= 0: + raise ValueError(f"MAMBA3_D_STATE must be positive, got {mamba3_d_state}") + if mamba3_is_mimo and mamba3_mimo_rank <= 0: + raise ValueError(f"MAMBA3_MIMO_RANK must be positive, got {mamba3_mimo_rank}") + if mamba3_chunk_size <= 0: + raise ValueError(f"MAMBA3_CHUNK_SIZE must be positive, got {mamba3_chunk_size}") + kwargs = dict( + d_model=dim, + d_state=mamba3_d_state, + headdim=mamba3_head_dim, + is_mimo=bool(mamba3_is_mimo), + chunk_size=mamba3_chunk_size, + is_outproj_norm=bool(mamba3_outproj_norm), + ) + if mamba3_is_mimo: + kwargs["mimo_rank"] = mamba3_mimo_rank + self.mamba3 = _OfficialMamba3(**kwargs) + return + + if kernel_size < 2: + raise ValueError(f"SSM kernel must be >= 2, got {kernel_size}") + hidden = max(64, int(dim * expand) // 64 * 64) + self.in_proj = CastedLinear(dim, hidden * 2, bias=False) + # Depthwise causal conv over time (implemented via left crop after padding). + self.dw_conv = nn.Conv1d( + hidden, + hidden, + kernel_size=kernel_size, + groups=hidden, + bias=False, + padding=kernel_size - 1, + ) + self.out_proj = CastedLinear(hidden, dim, bias=False) + self.out_proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + # x: [B, T, D] + if self.impl == "mamba3": + return self.mamba3(x) + bsz, seqlen, _ = x.shape + uv = self.in_proj(x) + u, v = uv.chunk(2, dim=-1) + u = F.silu(u) + y = self.dw_conv(u.transpose(1, 2))[..., :seqlen].transpose(1, 2).contiguous() + y = y * torch.sigmoid(v) + return self.out_proj(y) + + +class MTPBranch(nn.Module): + """Per-horizon residual branch for multi-token prediction.""" + + def __init__(self, dim: int): + super().__init__() + self.norm = RMSNorm() + self.proj = CastedLinear(dim, dim, bias=False) + self.scale = nn.Parameter(torch.ones(1, dtype=torch.float32)) + + def forward(self, h: Tensor) -> Tensor: + return h + self.scale.to(dtype=h.dtype) * self.proj(self.norm(h)) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_swiglu: bool = False, + use_ssm: bool = False, + ssm_expand: float = 2.0, + ssm_kernel: int = 4, + ssm_impl: str = "mamba3", + mamba3_d_state: int = 128, + mamba3_head_dim: int = 64, + mamba3_is_mimo: bool = True, + mamba3_mimo_rank: int = 4, + mamba3_chunk_size: int = 16, + mamba3_outproj_norm: bool = False, + moe_num_experts: int = 0, + moe_capacity_factor: float = 1.0, + use_parallel_residual: bool = False, + use_sandwich_norm: bool = False, + ): + super().__init__() + self.use_ssm = use_ssm + self.use_sandwich_norm = use_sandwich_norm and not use_parallel_residual + # Parallel residual: one shared pre-norm feeds both attn and MLP simultaneously. + # Saves one RMSNorm, improves gradient flow; validated by leaderboard PRs. + self.use_parallel_residual = use_parallel_residual and not use_ssm + if use_parallel_residual and not use_ssm: + self.norm = RMSNorm() # single shared norm + self.attn_norm = self.norm # alias for compat + self.mlp_norm = self.norm # alias for compat + else: + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + if use_ssm: + self.attn = None + self.ssm = SSMMixer( + dim, + expand=ssm_expand, + kernel_size=ssm_kernel, + impl=ssm_impl, + mamba3_d_state=mamba3_d_state, + mamba3_head_dim=mamba3_head_dim, + mamba3_is_mimo=mamba3_is_mimo, + mamba3_mimo_rank=mamba3_mimo_rank, + mamba3_chunk_size=mamba3_chunk_size, + mamba3_outproj_norm=mamba3_outproj_norm, + ) + else: + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.ssm = None + # MoE or dense MLP — is_moe is a Python bool, resolved at compile time. + self.is_moe: bool = moe_num_experts >= 2 + if self.is_moe: + self.mlp: MLP | MoEMLP = MoEMLP(dim, mlp_mult, moe_num_experts, moe_capacity_factor, use_swiglu) + else: + self.mlp = MLP(dim, mlp_mult, use_swiglu=use_swiglu) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + # Sandwich norm: post-sublayer norms (Gemma 2 style). Applied before residual add. + if self.use_sandwich_norm: + self.attn_post_norm = RMSNorm() + self.mlp_post_norm = RMSNorm() + + def forward(self, x: Tensor, x0: Tensor) -> tuple[Tensor, Tensor]: + """Returns (hidden_state, moe_z_loss). + moe_z_loss is a zero scalar for non-MoE blocks so callers can always + accumulate unconditionally without a Python-level branch.""" + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + if self.use_ssm: + if self.ssm is None: + raise RuntimeError("SSM block is enabled but mixer is missing") + mix_out = self.ssm(self.attn_norm(x)) + if self.use_sandwich_norm: + mix_out = self.attn_post_norm(mix_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * mix_out + if self.is_moe: + mlp_out, z_loss = self.mlp(self.mlp_norm(x)) + else: + mlp_out = self.mlp(self.mlp_norm(x)) + z_loss = x.new_zeros(()) + if self.use_sandwich_norm: + mlp_out = self.mlp_post_norm(mlp_out) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + elif self.use_parallel_residual: + # Parallel: both attn and MLP read the same pre-norm input, outputs added together. + if self.attn is None: + raise RuntimeError("Attention block is enabled but attention module is missing") + h = self.norm(x) + attn_out = self.attn(h) + if self.is_moe: + mlp_out, z_loss = self.mlp(h) + else: + mlp_out = self.mlp(h) + z_loss = x.new_zeros(()) + x = (x + + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out) + else: + if self.attn is None: + raise RuntimeError("Attention block is enabled but attention module is missing") + mix_out = self.attn(self.attn_norm(x)) + if self.use_sandwich_norm: + mix_out = self.attn_post_norm(mix_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * mix_out + if self.is_moe: + mlp_out, z_loss = self.mlp(self.mlp_norm(x)) + else: + mlp_out = self.mlp(self.mlp_norm(x)) + z_loss = x.new_zeros(()) + if self.use_sandwich_norm: + mlp_out = self.mlp_post_norm(mlp_out) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + return x, z_loss + + +class JPCRPredictor(nn.Module): + """JEPA Predictive Coding Recurrence predictor (v2 — BYOL/data2vec-inspired). + + Per-token MLP that predicts "where the hidden state should be" at this depth. + Trained with cosine similarity loss against instance-normalized EMA teacher + intermediates projected into a smaller space (BYOL-style). + + Architecture: + Blend path: RMSNorm → Linear(dim, hidden) → SiLU → Linear(hidden, dim) → residual + Loss path: shared Linear(dim, proj_dim) on prediction and normalized target, cosine loss + + The blend path modifies the recurrence input at inference (no teacher needed). + The loss path trains the predictor — projects to proj_dim for stable, bounded loss. + """ + + def __init__(self, model_dim: int, hidden_dim: int = 128, proj_dim: int = 128, + blend_init: float = -2.0): + super().__init__() + self.model_dim = model_dim + self.proj_dim = proj_dim + # Blend path: predicts delta to add to x + self.proj_in = nn.Linear(model_dim, hidden_dim, bias=True) + self.proj_out = nn.Linear(hidden_dim, model_dim, bias=True) + # Learnable blend gate (logit space). sigmoid(-2.0) ≈ 0.12 → conservative start. + self.blend_gate = nn.Parameter(torch.tensor(blend_init, dtype=torch.float32)) + # Zero-init output → identity at start of training (delta = 0) + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) + # Loss projection heads (BYOL-style): project to smaller space for loss + self.student_proj = nn.Linear(model_dim, proj_dim, bias=False) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Returns (predicted_target, gate_value). No loss computation here.""" + h = F.rms_norm(x, (self.model_dim,)) + h = F.silu(self.proj_in(h)) + delta = self.proj_out(h) + predicted_target = x + delta + gate = torch.sigmoid(self.blend_gate.to(x.dtype)) + return predicted_target, gate + + def compute_loss(self, predicted_target: Tensor, teacher_target: Tensor) -> Tensor: + """Cosine similarity loss in projected space with instance-normalized targets. + + Returns scalar loss in [0, 2] (0 = perfect alignment, 2 = opposite). + Uses data2vec-style instance normalization + BYOL-style projection. + """ + # Instance-normalize teacher target (data2vec): zero-mean, unit-var per token + t = teacher_target.float() + t = (t - t.mean(dim=-1, keepdim=True)) / (t.std(dim=-1, keepdim=True) + 1e-6) + # Project both to smaller space with shared projector, detach target branch. + s_proj = self.student_proj(predicted_target.float()) + t_proj = self.student_proj(t).detach() + # Cosine similarity loss: 1 - cos_sim, bounded [0, 2] + s_norm = F.normalize(s_proj, dim=-1) + t_norm = F.normalize(t_proj, dim=-1) + return (1.0 - (s_norm * t_norm).sum(dim=-1)).mean() + + +def _run_ctrl_safe(ctrl: nn.Sequential, x: Tensor, loop_steps: int, model_dim: int) -> Tensor: + """Run Ouroboros controller with explicit dtype handling to avoid autocast/compile issues.""" + d = x.dtype + h = x.mean(dim=1) # [B, dim] + # Functional forward through controller: Linear -> SiLU -> Linear + h = F.linear(h, ctrl[0].weight.to(d), ctrl[0].bias.to(d)) + h = F.silu(h) + h = F.linear(h, ctrl[2].weight.to(d), ctrl[2].bias.to(d)) + return h.view(x.shape[0], loop_steps, 2, model_dim) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + recurrent_core_layers: int = 0, + recurrent_steps: int = 0, + share_ffn_across_blocks: bool = False, + intra_loop_start: int = -1, + intra_loop_end: int = -1, + intra_loop_steps: int = 3, + use_parallel_residual: bool = False, + use_swiglu: bool = False, + bigram_rank: int = 0, + mtp_enabled: bool = False, + mtp_steps: int = 2, + mtp_weight: float = 0.3, + mtp_decay: float = 1.0, + mtp_tie_embeddings: bool = True, + use_ssm: bool = False, + ssm_every_n: int = 2, + ssm_expand: float = 2.0, + ssm_kernel: int = 4, + ssm_impl: str = "mamba3", + mamba3_d_state: int = 128, + mamba3_head_dim: int = 64, + mamba3_is_mimo: bool = True, + mamba3_mimo_rank: int = 4, + mamba3_chunk_size: int = 16, + mamba3_outproj_norm: bool = False, + residual_ngram_enabled: bool = False, + residual_bigram_rank: int = 0, + residual_trigram_rank: int = 0, + residual_ngram_mix_init: float = -2.5, + ngram_softcap: float = 0.0, + ngram_entropy_gate: bool = False, + copy_cache_enabled: bool = False, + copy_cache_window: int = 256, + copy_cache_dim: int = 64, + copy_cache_gate_init: float = -4.0, + moe_num_experts: int = 0, + moe_every_n: int = 2, + moe_capacity_factor: float = 1.0, + moe_aux_loss_coeff: float = 1e-3, + dual_head_enabled: bool = False, + dual_head_num_classes: int = 4, + jpcr_enabled: bool = False, + jpcr_hidden: int = 128, + jpcr_proj_dim: int = 128, + jpcr_blend_init: float = -2.0, + use_sandwich_norm: bool = False, + embed_scale: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + if (recurrent_core_layers > 0) != (recurrent_steps > 0): + raise ValueError( + "RECURRENT_CORE_LAYERS and RECURRENT_STEPS must both be > 0 for recurrence mode, " + f"got RECURRENT_CORE_LAYERS={recurrent_core_layers}, RECURRENT_STEPS={recurrent_steps}" + ) + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.use_recurrence = recurrent_core_layers > 0 and recurrent_steps > 0 + self.recurrent_core_layers = recurrent_core_layers + self.recurrent_steps = recurrent_steps + self.share_ffn_across_blocks = share_ffn_across_blocks + # Partial depth recurrence: loop layers [intra_loop_start..intra_loop_end] N times. + # Middle layers are optimal (see Universal Transformers; leaderboard PR #1394). + # Loop-position embeddings (shape [n_looped_blocks, steps, dim], init=0) let the + # model distinguish iteration 0 from iteration 1, learned via Adam at scalar_lr. + _intra_active = (intra_loop_start >= 0 and intra_loop_end >= intra_loop_start + and intra_loop_steps > 1 and not self.use_recurrence) + self.intra_loop_start = int(intra_loop_start) if _intra_active else -1 + self.intra_loop_end = int(intra_loop_end) if _intra_active else -1 + self.intra_loop_steps = int(intra_loop_steps) if _intra_active else 1 + self.use_ssm = use_ssm + self.ssm_every_n = ssm_every_n + self.ssm_expand = ssm_expand + self.ssm_kernel = ssm_kernel + self.ssm_impl = ssm_impl + self.mamba3_d_state = mamba3_d_state + self.mamba3_head_dim = mamba3_head_dim + self.mamba3_is_mimo = mamba3_is_mimo + self.mamba3_mimo_rank = mamba3_mimo_rank + self.mamba3_chunk_size = mamba3_chunk_size + self.mamba3_outproj_norm = mamba3_outproj_norm + self.mtp_enabled = mtp_enabled and mtp_steps > 0 + self.mtp_steps = max(0, mtp_steps) + self.mtp_weight = max(0.0, mtp_weight) + self.mtp_decay = mtp_decay + self.mtp_tie_embeddings = mtp_tie_embeddings + self.residual_bigram_rank = max(0, residual_bigram_rank) + self.residual_trigram_rank = max(0, residual_trigram_rank) + self.residual_ngram_enabled = residual_ngram_enabled and ( + self.residual_bigram_rank > 0 or self.residual_trigram_rank > 0 + ) + self.residual_ngram_mix_init = residual_ngram_mix_init + # 0.0 means "inherit logit_softcap"; >0 decouples the ngram branch cap. + self.ngram_softcap = float(ngram_softcap) if ngram_softcap > 0.0 else 0.0 + self.ngram_entropy_gate = bool(ngram_entropy_gate) and self.residual_ngram_enabled + self.copy_cache_enabled = copy_cache_enabled + self.copy_cache_window = max(1, int(copy_cache_window)) + self.copy_cache_dim = max(8, int(copy_cache_dim)) + self.copy_cache_gate_init = copy_cache_gate_init + self.dual_head_enabled = bool(dual_head_enabled) + self.dual_head_num_classes = max(2, int(dual_head_num_classes)) + if self.use_recurrence: + self.total_effective_layers = recurrent_core_layers * recurrent_steps + elif self.intra_loop_start >= 0: + n_looped = self.intra_loop_end - self.intra_loop_start + 1 + self.total_effective_layers = num_layers + n_looped * (self.intra_loop_steps - 1) + else: + self.total_effective_layers = num_layers + + # MoE config stored on model (used in forward() to gate the aux loss) + self.moe_aux_loss_coeff = float(moe_aux_loss_coeff) + self._has_moe = moe_num_experts >= 2 and moe_every_n > 0 + + def is_ssm_block(idx: int) -> bool: + return self.use_ssm and self.ssm_every_n > 0 and ((idx + 1) % self.ssm_every_n == 0) + + def is_moe_block(idx: int) -> bool: + return moe_num_experts >= 2 and moe_every_n > 0 and idx % moe_every_n == 0 + + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.embed_scale = embed_scale + self._embed_scale_factor = model_dim ** 0.5 if embed_scale else 1.0 + if self.use_recurrence: + self.num_encoder_layers = 0 + self.num_decoder_layers = 0 + self.num_skip_weights = 0 + # In recurrence mode skip_weights are unused; keep as buffer so DDP + # doesn't expect gradients for an empty parameter tensor. + self.register_buffer("skip_weights", torch.ones(0, model_dim, dtype=torch.float32), persistent=False) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + use_swiglu=use_swiglu, + use_ssm=is_ssm_block(i), + ssm_expand=ssm_expand, + ssm_kernel=ssm_kernel, + ssm_impl=ssm_impl, + mamba3_d_state=mamba3_d_state, + mamba3_head_dim=mamba3_head_dim, + mamba3_is_mimo=mamba3_is_mimo, + mamba3_mimo_rank=mamba3_mimo_rank, + mamba3_chunk_size=mamba3_chunk_size, + mamba3_outproj_norm=mamba3_outproj_norm, + moe_num_experts=moe_num_experts if is_moe_block(i) else 0, + moe_capacity_factor=moe_capacity_factor, + use_parallel_residual=use_parallel_residual and not is_ssm_block(i), + use_sandwich_norm=use_sandwich_norm, + ) + for i in range(recurrent_core_layers) + ] + ) + # SHARE_FFN_ACROSS_BLOCKS is incompatible with MoE (different experts per layer). + if share_ffn_across_blocks and len(self.blocks) > 1 and not self._has_moe: + shared_mlp = self.blocks[0].mlp + for i in range(1, len(self.blocks)): + self.blocks[i].mlp = shared_mlp + else: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + use_swiglu=use_swiglu, + use_ssm=is_ssm_block(i), + ssm_expand=ssm_expand, + ssm_kernel=ssm_kernel, + ssm_impl=ssm_impl, + mamba3_d_state=mamba3_d_state, + mamba3_head_dim=mamba3_head_dim, + mamba3_is_mimo=mamba3_is_mimo, + mamba3_mimo_rank=mamba3_mimo_rank, + mamba3_chunk_size=mamba3_chunk_size, + mamba3_outproj_norm=mamba3_outproj_norm, + moe_num_experts=moe_num_experts if is_moe_block(i) else 0, + moe_capacity_factor=moe_capacity_factor, + use_sandwich_norm=use_sandwich_norm, + ) + for i in range(num_layers) + ] + ) + if share_ffn_across_blocks and len(self.blocks) > 1 and not self._has_moe: + shared_mlp = self.blocks[0].mlp + for i in range(1, len(self.blocks)): + self.blocks[i].mlp = shared_mlp + self.num_ssm_blocks = sum(1 for block in self.blocks if block.use_ssm) + self.num_moe_blocks = sum(1 for block in self.blocks if block.is_moe) + self.num_attn_blocks = len(self.blocks) - self.num_ssm_blocks + # JPCR (JEPA Predictive Coding Recurrence) or Ouroboros loop conditioning. + # JPCR: per-token MLP predictors trained with JEPA MSE loss against teacher intermediates. + # Each predictor predicts the ideal hidden state; a learned gate blends this prediction + # into the recurrence input. Progressive depth targeting across loop iterations. + # Ouroboros: per-looped-block tiny hypernetwork generating (scale, shift) from mean(x). + self.jpcr_enabled = bool(jpcr_enabled) and _intra_active + if self.jpcr_enabled: + n_looped = self.intra_loop_end - self.intra_loop_start + 1 + predictors = [] + for _ in range(n_looped): + predictors.append(JPCRPredictor(model_dim, jpcr_hidden, jpcr_proj_dim, jpcr_blend_init)) + self.jpcr_predictors = nn.ModuleList(predictors) + self.intra_loop_controllers = nn.ModuleList([]) # not used with JPCR + self._intra_model_dim = model_dim + elif _intra_active: + self.jpcr_predictors = nn.ModuleList([]) + n_looped = self.intra_loop_end - self.intra_loop_start + 1 + _ctrl_hidden = 32 + # One controller per looped block; each outputs [steps, 2, dim] + controllers = [] + for _ in range(n_looped): + net = nn.Sequential( + nn.Linear(model_dim, _ctrl_hidden, bias=True), + nn.SiLU(), + nn.Linear(_ctrl_hidden, self.intra_loop_steps * 2 * model_dim, bias=True), + ) + # Zero-init output layer → identity transform at start of training + nn.init.zeros_(net[-1].weight) + nn.init.zeros_(net[-1].bias) + controllers.append(net) + self.intra_loop_controllers = nn.ModuleList(controllers) + self._intra_model_dim = model_dim + else: + self.jpcr_predictors = nn.ModuleList([]) + self.intra_loop_controllers = nn.ModuleList([]) + self._intra_model_dim = model_dim + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + self.dual_head = CastedLinear(model_dim, self.dual_head_num_classes, bias=True) if self.dual_head_enabled else None + if self.lm_head is not None: + self.lm_head._zero_init = True + if self.mtp_enabled: + self.mtp_branches = nn.ModuleList([MTPBranch(model_dim) for _ in range(self.mtp_steps)]) + if self.mtp_tie_embeddings and self.tie_embeddings: + self.mtp_heads = None + else: + self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, vocab_size, bias=False) for _ in range(self.mtp_steps)]) + self.register_buffer( + "mtp_step_weights", + torch.tensor([self.mtp_decay**i for i in range(self.mtp_steps)], dtype=torch.float32), + persistent=False, + ) + else: + self.mtp_branches = None + self.mtp_heads = None + self.register_buffer("mtp_step_weights", torch.zeros((0,), dtype=torch.float32), persistent=False) + # Low-rank bigram logit bias. At position i, adds bigram_right(bigram_left(input[i])) to logits. + # This gives the model a cheap, learned n-gram prior on top of the contextual representations. + self.bigram_rank = bigram_rank + if bigram_rank > 0: + self.bigram_left = nn.Embedding(vocab_size, bigram_rank) + self.bigram_right = CastedLinear(bigram_rank, vocab_size, bias=False) + self.bigram_right._zero_init = True # starts contributing nothing; learns when useful + self.bigram_scale = nn.Parameter(torch.ones(1, dtype=torch.float32)) + if self.residual_ngram_enabled: + if self.residual_bigram_rank > 0: + self.residual_bigram_left = nn.Embedding(vocab_size, self.residual_bigram_rank) + self.residual_bigram_right = CastedLinear(self.residual_bigram_rank, vocab_size, bias=False) + self.residual_bigram_right._zero_init = True + if self.residual_trigram_rank > 0: + self.residual_trigram_prev1 = nn.Embedding(vocab_size, self.residual_trigram_rank) + self.residual_trigram_prev2 = nn.Embedding(vocab_size, self.residual_trigram_rank) + self.residual_trigram_right = CastedLinear(self.residual_trigram_rank, vocab_size, bias=False) + self.residual_trigram_right._zero_init = True + self.residual_ngram_scale = nn.Parameter(torch.ones(1, dtype=torch.float32)) + gate_in_dim = model_dim + (1 if self.ngram_entropy_gate else 0) + self.residual_ngram_gate = CastedLinear(gate_in_dim, 1, bias=True) + if self.copy_cache_enabled: + self.copy_q = CastedLinear(model_dim, self.copy_cache_dim, bias=False) + self.copy_k = CastedLinear(model_dim, self.copy_cache_dim, bias=False) + self.copy_gate = CastedLinear(model_dim, 1, bias=True) + self._init_weights() + if self.residual_ngram_enabled: + nn.init.zeros_(self.residual_ngram_gate.weight) + if self.residual_ngram_gate.bias is not None: + nn.init.constant_(self.residual_ngram_gate.bias, self.residual_ngram_mix_init) + if self.copy_cache_enabled: + nn.init.zeros_(self.copy_gate.weight) + if self.copy_gate.bias is not None: + nn.init.constant_(self.copy_gate.bias, self.copy_cache_gate_init) + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _compute_residual_ngram_logits(self, input_ids: Tensor) -> Tensor | None: + if not self.residual_ngram_enabled: + return None + prev1 = input_ids.reshape(-1) + ngram_logits: Tensor | None = None + if self.residual_bigram_rank > 0: + bg = self.residual_bigram_right(self.residual_bigram_left(prev1)) + ngram_logits = bg + if self.residual_trigram_rank > 0: + prev2_ids = torch.cat((input_ids[:, :1], input_ids[:, :-1]), dim=1).reshape(-1) + tri_feat = self.residual_trigram_prev1(prev1) * self.residual_trigram_prev2(prev2_ids) + tri = self.residual_trigram_right(tri_feat) + ngram_logits = tri if ngram_logits is None else (ngram_logits + tri) + if ngram_logits is None: + return None + return self.residual_ngram_scale * ngram_logits + + def _build_copy_cache_log_probs(self, hidden: Tensor, input_ids: Tensor, source_next_ids: Tensor) -> Tensor: + # hidden: [B, T, D], input_ids/source_next_ids: [B, T] + bsz, seqlen, _ = hidden.shape + q = self.copy_q(hidden).float() + k = self.copy_k(hidden).float() + scale = 1.0 / math.sqrt(float(self.copy_cache_dim)) + att = torch.matmul(q, k.transpose(1, 2)) * scale # [B, T, T] + + pos = torch.arange(seqlen, device=hidden.device) + t_pos = pos.view(1, seqlen, 1) + j_pos = pos.view(1, 1, seqlen) + causal = j_pos < t_pos + within = (t_pos - j_pos) <= self.copy_cache_window + mask = causal & within + att = att.masked_fill(~mask, float("-inf")) + no_source = ~mask.any(dim=-1, keepdim=True) + att = torch.where(no_source, torch.zeros_like(att), att) + att_prob = F.softmax(att, dim=-1).masked_fill(no_source, 0.0) + + copy_probs = torch.zeros((bsz, seqlen, self.tok_emb.num_embeddings), device=hidden.device, dtype=torch.float32) + copy_probs.scatter_add_( + 2, + source_next_ids.unsqueeze(1).expand(-1, seqlen, -1), + att_prob, + ) + return torch.log(copy_probs.clamp_min(1e-9)) + + def _compose_output_logits( + self, + logits_proj: Tensor, + input_ids: Tensor, + hidden: Tensor, + source_next_ids: Tensor | None = None, + ) -> tuple[Tensor, bool]: + neural_logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ngram_logits = self._compute_residual_ngram_logits(input_ids) + composed = neural_logits + if ngram_logits is not None: + # Stable residual composition in logit space. + flat_h = hidden.reshape(-1, hidden.size(-1)) + if self.ngram_entropy_gate: + # Cheap confidence signal: (logsumexp - max) = -log max_prob. Larger = less confident. + # Detached so the gate signal is stop-grad wrt the neural head (keeps semantics simple). + with torch.no_grad(): + n_logits_f = neural_logits.float() + lse = torch.logsumexp(n_logits_f, dim=-1, keepdim=True) + max_logit = n_logits_f.max(dim=-1, keepdim=True).values + neg_max_log_prob = (lse - max_logit).to(dtype=flat_h.dtype) + gate_input = torch.cat([flat_h, neg_max_log_prob], dim=-1) + gate = torch.sigmoid(self.residual_ngram_gate(gate_input)) + else: + gate = torch.sigmoid(self.residual_ngram_gate(flat_h)) + cap = self.ngram_softcap if self.ngram_softcap > 0.0 else self.logit_softcap + ngram_logits = cap * torch.tanh(ngram_logits / cap) + composed = composed + gate.to(dtype=composed.dtype) * ngram_logits.to(dtype=composed.dtype) + + if not self.copy_cache_enabled: + return composed, False + + if source_next_ids is None: + source_next_ids = torch.cat((input_ids[:, 1:], input_ids[:, -1:]), dim=1) + copy_log_probs = self._build_copy_cache_log_probs(hidden, input_ids, source_next_ids) + model_log_probs = F.log_softmax(composed.float().reshape(input_ids.size(0), input_ids.size(1), -1), dim=-1) + gate = torch.sigmoid(self.copy_gate(hidden).float()).clamp(min=1e-4, max=1.0 - 1e-4) + mixed_log_probs = torch.logaddexp( + torch.log1p(-gate) + model_log_probs, + torch.log(gate) + copy_log_probs, + ) + return mixed_log_probs.reshape(-1, mixed_log_probs.size(-1)).to(dtype=composed.dtype), True + + def _apply_loop_conditioning(self, x: Tensor, block_idx: int, step: int) -> Tensor: + """Apply JPCR blend or Ouroboros conditioning before a looped block execution.""" + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + predictor = self.jpcr_predictors[block_idx - self.intra_loop_start] + predicted_target, gate = predictor(x) + # Blend: nudge current state toward predicted target + x = x + gate * (predicted_target - x) + elif len(self.intra_loop_controllers) > 0: + ctrl = self.intra_loop_controllers[block_idx - self.intra_loop_start] + out = _run_ctrl_safe(ctrl, x, self.intra_loop_steps, self._intra_model_dim) + scale = out[:, step, 0, :].unsqueeze(1).to(dtype=x.dtype) + shift = out[:, step, 1, :].unsqueeze(1).to(dtype=x.dtype) + x = x * (1.0 + scale.tanh()) + shift + return x + + def _forward_hidden(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> Tensor: + x = self.tok_emb(input_ids) + if self.embed_scale: + x = x * self._embed_scale_factor + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + jpcr_runtime_active = self.jpcr_enabled if jpcr_runtime_active is None else bool(jpcr_runtime_active) + if self.use_recurrence: + for _ in range(self.recurrent_steps): + for block in self.blocks: + x, _ = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + n_rep = self.intra_loop_steps if (jpcr_runtime_active and self.intra_loop_start <= i <= self.intra_loop_end) else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + x = self._apply_loop_conditioning(x, i, s) + x, _ = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + j = self.num_encoder_layers + i + n_rep = self.intra_loop_steps if (jpcr_runtime_active and self.intra_loop_start <= j <= self.intra_loop_end) else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + x = self._apply_loop_conditioning(x, j, s) + x, _ = self.blocks[j](x, x0) + return self.final_norm(x) + + def _forward_hidden_with_intermediates(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> tuple[Tensor, list[Tensor]]: + """Forward pass capturing hidden states ONLY for looped blocks (NO loop, NO conditioning). + + Used by the EMA teacher to provide clean JEPA targets for JPCR predictors. + Runs each block exactly once — the teacher represents the "ideal" single-pass model. + Only captures intermediates for blocks in [intra_loop_start, intra_loop_end] to save memory. + Returns (final_hidden_after_norm, list_of_looped_block_hidden_states). + """ + x = self.tok_emb(input_ids) + if self.embed_scale: + x = x * self._embed_scale_factor + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + intermediates: list[Tensor] = [] + jpcr_runtime_active = self.jpcr_enabled if jpcr_runtime_active is None else bool(jpcr_runtime_active) + if self.use_recurrence: + for _ in range(self.recurrent_steps): + for block in self.blocks: + x, _ = block(x, x0) + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, _ = self.blocks[i](x, x0) + if jpcr_runtime_active and self.intra_loop_start <= i <= self.intra_loop_end: + intermediates.append(x) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + j = self.num_encoder_layers + i + x, _ = self.blocks[j](x, x0) + if jpcr_runtime_active and self.intra_loop_start <= j <= self.intra_loop_end: + intermediates.append(x) + return self.final_norm(x), intermediates + + def forward_hidden_and_output(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> tuple[Tensor, Tensor, bool]: + h = self._forward_hidden(input_ids, jpcr_runtime_active=jpcr_runtime_active) + flat_h = h.reshape(-1, h.size(-1)) + if self.tie_embeddings: + logits_proj = F.linear(flat_h, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(flat_h) + if self.bigram_rank > 0: + bg = self.bigram_right(self.bigram_left(input_ids.reshape(-1))) # [B*T, vocab] + logits_proj = logits_proj + self.bigram_scale * bg + logits, logits_are_log_probs = self._compose_output_logits(logits_proj, input_ids, h) + return h, logits, logits_are_log_probs + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning logits. NOTE: when self.copy_cache_enabled is True, + the returned tensor is log-probabilities (already log_softmax'd), not raw logits. + Callers that feed this into distillation must rely on student's logits_are_log_probs + flag to interpret format consistently (student and teacher share config).""" + _, logits, _ = self.forward_hidden_and_output(input_ids) + return logits + + def forward_logits_and_intermediates(self, input_ids: Tensor, *, jpcr_runtime_active: bool | None = None) -> tuple[Tensor, list[Tensor]]: + """Forward pass returning logits AND per-block hidden states for JPCR teacher. + Same format caveat as forward_logits: log-probs when copy_cache is enabled.""" + h, intermediates = self._forward_hidden_with_intermediates(input_ids, jpcr_runtime_active=jpcr_runtime_active) + flat_h = h.reshape(-1, h.size(-1)) + if self.tie_embeddings: + logits_proj = F.linear(flat_h, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(flat_h) + if self.bigram_rank > 0: + bg = self.bigram_right(self.bigram_left(input_ids.reshape(-1))) + logits_proj = logits_proj + self.bigram_scale * bg + logits, _ = self._compose_output_logits(logits_proj, input_ids, h) + return logits, intermediates + + def forward( + self, + input_ids: Tensor, + target_ids: Tensor, + loss_mask: Tensor | None = None, + per_token_weights: Tensor | None = None, + aux_targets: Tensor | None = None, + aux_weight: float = 0.0, + distill_teacher_logits: Tensor | None = None, + distill_weight: float = 0.0, + distill_temp: float = 1.0, + logit_reg_weight: float = 0.0, + jpcr_teacher_intermediates: list[Tensor] | None = (), + jpcr_weight: float = 0.0, + jpcr_runtime_active: bool = False, + ) -> Tensor: + if jpcr_teacher_intermediates is None: + jpcr_teacher_intermediates = () + x = self.tok_emb(input_ids) + if self.embed_scale: + x = x * self._embed_scale_factor + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + moe_z_loss: Tensor = x.new_zeros(()) # accumulates router Z-losses from all MoE blocks + jpcr_loss: Tensor = x.new_zeros(()) # accumulates JEPA MSE losses from JPCR predictors + jpcr_count: int = 0 # number of JPCR predictions for averaging + if self.use_recurrence: + for _ in range(self.recurrent_steps): + for block in self.blocks: + x, zl = block(x, x0) + moe_z_loss = moe_z_loss + zl + else: + skips: list[Tensor] = [] + # Only enable repeated intra-loop passes when loop conditioning is active. + # For JPCR this means post-distill runtime activation; for Ouroboros + # (controllers present) this remains active whenever configured. + loop_active = jpcr_runtime_active or len(self.intra_loop_controllers) > 0 + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + n_rep = (self.intra_loop_steps if self.intra_loop_start <= i <= self.intra_loop_end else 1) if loop_active else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + if jpcr_runtime_active: + predictor = self.jpcr_predictors[i - self.intra_loop_start] + predicted_target, gate = predictor(x) + # Always compute JPCR loss when teacher targets exist. + # jpcr_weight=0 before distill → no gradient impact. + # No branch on len(intermediates) to avoid torch.compile retrace. + target_idx = (i + s) - self.intra_loop_start + if target_idx < len(jpcr_teacher_intermediates): + teacher_target = jpcr_teacher_intermediates[target_idx] + jpcr_loss = jpcr_loss + predictor.compute_loss(predicted_target, teacher_target) + jpcr_count += 1 + x = x + gate * (predicted_target - x) + elif len(self.intra_loop_controllers) > 0: + ctrl = self.intra_loop_controllers[i - self.intra_loop_start] + out = _run_ctrl_safe(ctrl, x, self.intra_loop_steps, self._intra_model_dim) + scale = out[:, s, 0, :].unsqueeze(1).to(dtype=x.dtype) + shift = out[:, s, 1, :].unsqueeze(1).to(dtype=x.dtype) + x = x * (1.0 + scale.tanh()) + shift + x, zl = self.blocks[i](x, x0) + moe_z_loss = moe_z_loss + zl + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + j = self.num_encoder_layers + i + n_rep = (self.intra_loop_steps if self.intra_loop_start <= j <= self.intra_loop_end else 1) if loop_active else 1 + for s in range(n_rep): + if n_rep > 1 and s > 0: + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + if jpcr_runtime_active: + predictor = self.jpcr_predictors[j - self.intra_loop_start] + predicted_target, gate = predictor(x) + target_idx = (j + s) - self.intra_loop_start + if target_idx < len(jpcr_teacher_intermediates): + teacher_target = jpcr_teacher_intermediates[target_idx] + jpcr_loss = jpcr_loss + predictor.compute_loss(predicted_target, teacher_target) + jpcr_count += 1 + x = x + gate * (predicted_target - x) + elif len(self.intra_loop_controllers) > 0: + ctrl = self.intra_loop_controllers[j - self.intra_loop_start] + out = _run_ctrl_safe(ctrl, x, self.intra_loop_steps, self._intra_model_dim) + scale = out[:, s, 0, :].unsqueeze(1).to(dtype=x.dtype) + shift = out[:, s, 1, :].unsqueeze(1).to(dtype=x.dtype) + x = x * (1.0 + scale.tanh()) + shift + x, zl = self.blocks[j](x, x0) + moe_z_loss = moe_z_loss + zl + + h = self.final_norm(x) + flat_h = h.reshape(-1, h.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(flat_h, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(flat_h) + # Low-rank bigram bias: cheap learned n-gram prior on top of contextual representation. + if self.bigram_rank > 0: + bg = self.bigram_right(self.bigram_left(input_ids.reshape(-1))) # [B*T, vocab] + logits_proj = logits_proj + self.bigram_scale * bg + logits, logits_are_log_probs = self._compose_output_logits( + logits_proj, + input_ids, + h, + source_next_ids=target_ids, + ) + if logits_are_log_probs: + base_per_token = F.nll_loss(logits.float(), targets, reduction="none") # [B*T] + else: + base_per_token = F.cross_entropy(logits.float(), targets, reduction="none") # [B*T] + weighted = base_per_token + norm = torch.ones((), device=base_per_token.device, dtype=base_per_token.dtype) * base_per_token.numel() + if per_token_weights is not None: + token_w = per_token_weights.reshape(-1).to(base_per_token.dtype) + weighted = weighted * token_w + norm = token_w.sum().clamp(min=1) + if loss_mask is not None: + mask = loss_mask.reshape(-1).to(base_per_token.dtype) + weighted = weighted * mask + if per_token_weights is None: + norm = mask.sum().clamp(min=1) + else: + norm = (token_w * mask).sum().clamp(min=1) + base_loss = weighted.sum() / norm + + total_loss = base_loss + + if self.dual_head is not None and aux_targets is not None and aux_weight > 0.0: + aux_logits = self.dual_head(flat_h) # [B*T, C] + aux_flat_targets = aux_targets.reshape(-1) + aux_per_token = F.cross_entropy(aux_logits.float(), aux_flat_targets, reduction="none") + if loss_mask is not None: + mask = loss_mask.reshape(-1).to(aux_per_token.dtype) + aux_loss = (aux_per_token * mask).sum() / mask.sum().clamp(min=1) + else: + aux_loss = aux_per_token.mean() + total_loss = total_loss + float(aux_weight) * aux_loss + elif self.dual_head is not None: + # Safety touch keeps dual-head params in graph when auxiliary loss is inactive. + total_loss = total_loss + 0.0 * ( + self.dual_head.weight.reshape(-1)[0].float() + + (self.dual_head.bias.reshape(-1)[0].float() if self.dual_head.bias is not None else 0.0) + ) + + if logit_reg_weight > 0.0: + total_loss = total_loss + float(logit_reg_weight) * logits_proj.float().pow(2).mean() + + if distill_teacher_logits is not None and distill_teacher_logits.numel() > 0 and distill_weight > 0.0: + temp = max(float(distill_temp), 1e-4) + if logits_are_log_probs: + # Both student and teacher share config (EMA teacher). When copy_cache is + # enabled, both emit log-probs, so teacher must be exp()'d to probs. + # Temperature scaling is skipped (would need renormalization in prob space). + student_log_probs = logits.float() + teacher_probs = distill_teacher_logits.float().exp() + else: + student = (logits.float() / temp) + teacher = (distill_teacher_logits.float() / temp) + student_log_probs = F.log_softmax(student, dim=-1) + teacher_probs = F.softmax(teacher, dim=-1) + if loss_mask is not None: + mask = loss_mask.reshape(-1) > 0 + student_log_probs = student_log_probs[mask] + teacher_probs = teacher_probs[mask] + kl = F.kl_div( + student_log_probs, + teacher_probs, + reduction="batchmean", + ) * (temp * temp if not logits_are_log_probs else 1.0) + total_loss = total_loss + float(distill_weight) * kl + + # JPCR (JEPA Predictive Coding Recurrence) loss: average MSE across all predictor outputs. + # Always add the term (no branch on jpcr_weight) to keep torch.compile graph constant. + # When jpcr_weight=0.0 (before distill), the multiplication zeros out the gradient. + if jpcr_count > 0: + total_loss = total_loss + float(jpcr_weight) * (jpcr_loss / jpcr_count) + total_loss = total_loss + 0.0 * jpcr_loss + if self.jpcr_enabled and len(self.jpcr_predictors) > 0: + # Safety touch keeps ALL JPCR params in graph every step (zero gradient where unused). + # This supports DDP find_unused_parameters=False with conditional JPCR execution. + for p in self.jpcr_predictors.parameters(): + total_loss = total_loss + 0.0 * p.reshape(-1)[0].float() + + # MoE router Z-loss — only during training (loss_mask is None means no sliding-window eval mask). + # Follows the same pattern as MTP (excluded during eval to keep val_bpb clean). + if self._has_moe and self.moe_aux_loss_coeff > 0.0 and loss_mask is None: + total_loss = total_loss + self.moe_aux_loss_coeff * moe_z_loss + + # Keep eval metric comparable by applying MTP only when loss_mask is not provided. + if not self.mtp_enabled or self.mtp_weight <= 0.0 or loss_mask is not None: + return total_loss + + _, seqlen = target_ids.shape + weighted_aux = torch.zeros((), device=base_loss.device, dtype=base_loss.dtype) + weight_sum = torch.zeros((), device=base_loss.device, dtype=base_loss.dtype) + if self.mtp_branches is not None: + for step_idx in range(self.mtp_steps): + horizon = step_idx + 1 # 1 predicts token at t+2, 2 predicts t+3, ... + if seqlen - horizon <= 0: + continue + branch_h = self.mtp_branches[step_idx](h[:, : seqlen - horizon, :]) + branch_flat_h = branch_h.reshape(-1, branch_h.size(-1)) + future_targets = target_ids[:, horizon:].reshape(-1) + if self.mtp_heads is None: + aux_logits_proj = F.linear(branch_flat_h, self.tok_emb.weight) + else: + aux_logits_proj = self.mtp_heads[step_idx](branch_flat_h) + aux_logits = self.logit_softcap * torch.tanh(aux_logits_proj / self.logit_softcap) + aux_loss = F.cross_entropy(aux_logits.float(), future_targets, reduction="mean") + w = self.mtp_step_weights[step_idx].to(dtype=weighted_aux.dtype) + weighted_aux = weighted_aux + aux_loss.to(weighted_aux.dtype) * w + weight_sum = weight_sum + w + + aux_loss = weighted_aux / weight_sum.clamp_min(1e-12) + return total_loss + self.mtp_weight * aux_loss + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.quant_scheme not in SUPPORTED_QUANT_SCHEMES: + raise ValueError(f"Unsupported QUANT_SCHEME={args.quant_scheme!r}; expected one of {sorted(SUPPORTED_QUANT_SCHEMES)}") + if args.compressor not in SUPPORTED_COMPRESSORS: + raise ValueError(f"Unsupported COMPRESSOR={args.compressor!r}; expected one of {sorted(SUPPORTED_COMPRESSORS)}") + if args.weight_order not in SUPPORTED_WEIGHT_ORDERS: + raise ValueError(f"Unsupported WEIGHT_ORDER={args.weight_order!r}; expected one of {sorted(SUPPORTED_WEIGHT_ORDERS)}") + if args.mixed_low_precision_scheme not in {"int8", "int5", "int4"}: + raise ValueError( + f"Unsupported MIXED_LOW_PRECISION_SCHEME={args.mixed_low_precision_scheme!r}; expected 'int8', 'int5', or 'int4'" + ) + sweep_specs = resolve_eval_sweep_specs(args) + blend_specs, blend_weights = resolve_eval_blend_specs(args) + max_eval_seq_len = resolve_max_eval_seq_len(args, sweep_specs, blend_specs) + train_loss_mask_stride_frac = resolve_train_loss_mask_stride_frac(args) + if args.final_eval_mode not in {"primary", "blend"}: + raise ValueError(f"Unsupported FINAL_EVAL_MODE={args.final_eval_mode!r}; expected 'primary' or 'blend'") + if args.final_eval_mode == "blend" and not blend_specs: + raise ValueError("FINAL_EVAL_MODE=blend requires EVAL_BLEND_SEQ_LENS to be set") + + # ----------------------------- + # DISTRIBUTED + DEVICE SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + device_override = os.environ.get("DEVICE", "").strip().lower() + grad_accum_override = os.environ.get("GRAD_ACCUM_STEPS", "").strip() + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if grad_accum_override: + grad_accum_steps = int(grad_accum_override) + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + else: + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 for default grad accumulation; " + "set GRAD_ACCUM_STEPS explicitly to override" + ) + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + tokens_per_microstep = world_size * grad_accum_steps * args.train_seq_len + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE*GRAD_ACCUM_STEPS*TRAIN_SEQ_LEN; " + f"got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + if device_override: + if device_override == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("DEVICE=cuda requested but CUDA is unavailable") + if device_override not in {"cpu", "cuda"}: + raise ValueError(f"Unsupported DEVICE={device_override!r}; expected 'cpu' or 'cuda'") + device = torch.device(device_override, local_rank) if device_override == "cuda" else torch.device("cpu") + else: + device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu") + if device.type == "cuda": + torch.cuda.set_device(device) + autocast_enabled = device.type == "cuda" + use_compile = bool(int(os.environ.get("USE_TORCH_COMPILE", "1" if device.type == "cuda" else "0"))) + compile_dynamic_mode_raw = os.environ.get("TORCH_COMPILE_DYNAMIC", "true").strip().lower() + if compile_dynamic_mode_raw in {"1", "true", "yes", "on"}: + compile_dynamic: bool | None = True + elif compile_dynamic_mode_raw in {"0", "false", "no", "off"}: + compile_dynamic = False + elif compile_dynamic_mode_raw in {"none", "auto", "default", ""}: + compile_dynamic = None + else: + raise ValueError( + f"Unsupported TORCH_COMPILE_DYNAMIC={compile_dynamic_mode_raw!r}; expected true|false|none" + ) + if use_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + if distributed: + if device.type == "cuda": + dist.init_process_group(backend="nccl", device_id=device) + else: + dist.init_process_group(backend="gloo") + dist.barrier() + master_process = rank == 0 + + sdp_backends_log = "cpu" + if device.type == "cuda": + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + # Some consumer GPUs and GQA configs do not support flash-only SDPA. + # Default to "auto" so CUDA kernels can fall back to math/mem-efficient. + sdp_backend_mode = os.environ.get("SDP_BACKEND_MODE", "auto").strip().lower() + if sdp_backend_mode == "flash": + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + sdp_backends_log = "cudnn=False flash=True mem_efficient=False math=False mode=flash" + elif sdp_backend_mode == "math": + enable_cudnn_sdp(False) + enable_flash_sdp(False) + enable_mem_efficient_sdp(False) + enable_math_sdp(True) + sdp_backends_log = "cudnn=False flash=False mem_efficient=False math=True mode=math" + elif sdp_backend_mode == "auto": + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + sdp_backends_log = "cudnn=False flash=True mem_efficient=True math=True mode=auto" + else: + raise ValueError( + f"Unsupported SDP_BACKEND_MODE={sdp_backend_mode!r}; expected 'auto', 'flash', or 'math'" + ) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + f"device:{device} distributed:{distributed} use_torch_compile:{use_compile} " + f"torch_compile_dynamic:{compile_dynamic}", + console=False, + ) + if device.type == "cuda": + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, max_eval_seq_len) + if args.val_max_tokens > 0: + usable = (min(args.val_max_tokens, val_tokens.numel() - 1) // max_eval_seq_len) * max_eval_seq_len + if usable <= 0: + raise ValueError( + f"VAL_MAX_TOKENS={args.val_max_tokens} is too small for MAX_EVAL_SEQ_LEN={max_eval_seq_len}" + ) + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0( + f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1} " + f"val_max_tokens:{args.val_max_tokens if args.val_max_tokens > 0 else 'full'}" + ) + _, primary_eval_seq_len, primary_eval_rope_scale = resolve_primary_eval_spec(args) + log0( + f"eval_primary: seq_len:{primary_eval_seq_len} rope_scale:{primary_eval_rope_scale:.4f} " + f"stride_frac:{args.eval_stride_frac:.4f} final_eval_mode:{args.final_eval_mode}" + ) + if len(sweep_specs) > 1: + sweep_specs_log = ",".join( + f"{name}:{seq_len}@{rope_scale:.4f}" + for name, seq_len, rope_scale in sweep_specs[1:] + ) + log0(f"eval_sweep: specs:{sweep_specs_log}") + if blend_specs: + blend_stride_frac = args.eval_blend_stride_frac if args.eval_blend_stride_frac > 0.0 else args.eval_stride_frac + blend_specs_log = ",".join( + f"{name}:{seq_len}@{rope_scale:.4f}" + for name, seq_len, rope_scale in blend_specs + ) + blend_weights_log = ",".join(f"{weight:.6f}" for weight in blend_weights) + log0( + f"eval_blend: stride_frac:{blend_stride_frac:.4f} specs:{blend_specs_log} " + f"weights:{blend_weights_log} position_bias:{args.eval_blend_position_bias:.4f} " + f"position_power:{args.eval_blend_position_power:.4f}" + ) + log0( + f"eval_cont_cache: enabled:{int(args.eval_cont_cache_enabled)} " + f"window:{args.eval_cont_cache_window} topk:{args.eval_cont_cache_topk} " + f"weight:{args.eval_cont_cache_weight:.4f} logit_scale:{args.eval_cont_cache_logit_scale:.4f} " + f"conf_power:{args.eval_cont_cache_conf_power:.4f} batch_seqs:{args.eval_cont_cache_batch_seqs}" + ) + log0( + f"train_loss_mask: enabled:{int(args.train_loss_mask_enabled)} " + f"stride_frac:{train_loss_mask_stride_frac:.4f}" + ) + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + # Enable LSQ fake-quant allocation on CastedLinear BEFORE model construction so + # each CastedLinear gains a per-row learnable qat_log_scale parameter automatically. + CastedLinear.qat_lsq_enabled = bool(args.qat_lsq) + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + recurrent_core_layers=args.recurrent_core_layers, + recurrent_steps=args.recurrent_steps, + share_ffn_across_blocks=args.share_ffn_across_blocks, + intra_loop_start=args.intra_loop_start, + intra_loop_end=args.intra_loop_end, + intra_loop_steps=args.intra_loop_steps, + use_parallel_residual=args.use_parallel_residual, + use_swiglu=args.use_swiglu, + bigram_rank=args.bigram_rank, + mtp_enabled=args.mtp_enabled, + mtp_steps=args.mtp_steps, + mtp_weight=args.mtp_weight, + mtp_decay=args.mtp_decay, + mtp_tie_embeddings=args.mtp_tie_embeddings, + use_ssm=args.use_ssm, + ssm_every_n=args.ssm_every_n, + ssm_expand=args.ssm_expand, + ssm_kernel=args.ssm_kernel, + ssm_impl=args.ssm_impl, + mamba3_d_state=args.mamba3_d_state, + mamba3_head_dim=args.mamba3_head_dim, + mamba3_is_mimo=args.mamba3_is_mimo, + mamba3_mimo_rank=args.mamba3_mimo_rank, + mamba3_chunk_size=args.mamba3_chunk_size, + mamba3_outproj_norm=args.mamba3_outproj_norm, + residual_ngram_enabled=args.residual_ngram_enabled, + residual_bigram_rank=args.residual_bigram_rank, + residual_trigram_rank=args.residual_trigram_rank, + residual_ngram_mix_init=args.residual_ngram_mix_init, + ngram_softcap=args.ngram_softcap, + ngram_entropy_gate=args.ngram_entropy_gate, + copy_cache_enabled=args.copy_cache_enabled, + copy_cache_window=args.copy_cache_window, + copy_cache_dim=args.copy_cache_dim, + copy_cache_gate_init=args.copy_cache_gate_init, + moe_num_experts=args.moe_num_experts, + moe_every_n=args.moe_every_n, + moe_capacity_factor=args.moe_capacity_factor, + moe_aux_loss_coeff=args.moe_aux_loss_coeff, + dual_head_enabled=args.dual_head_enabled, + dual_head_num_classes=4, + jpcr_enabled=args.jpcr_enabled, + jpcr_hidden=args.jpcr_hidden, + jpcr_proj_dim=args.jpcr_proj_dim, + jpcr_blend_init=args.jpcr_blend_init, + use_sandwich_norm=args.use_sandwich_norm, + embed_scale=args.embed_scale, + ).to(device=device, dtype=torch.bfloat16 if autocast_enabled else torch.float32) + if autocast_enabled: + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if _OfficialMamba3 is not None and isinstance(module, _OfficialMamba3): + module.float() + restore_low_dim_params_to_fp32(base_model) + if use_compile: + # Disable DDPOptimizer: it splits compiled graphs at DDP bucket boundaries and + # crashes with `AttributeError: 'int' object has no attribute 'meta'` when plain + # Python int instance attrs (num_heads, head_dim) are captured as symbolic inputs + # to a subgraph. With world_size=1 the optimisation is a no-op anyway. + torch._dynamo.config.optimize_ddp = False + compiled_model = torch.compile(base_model, dynamic=compile_dynamic) if use_compile else base_model + model: nn.Module + if distributed: + ddp_find_unused_override = os.environ.get("DDP_FIND_UNUSED_PARAMETERS", "").strip().lower() + # find_unused_parameters=True is required when QAT_LSQ=1 because + # qat_log_scale params are registered but sit idle until QAT activates. + # Dual-head and JPCR are safety-touched in loss so they remain in graph with zero grads. + if ddp_find_unused_override in {"1", "true", "yes", "on"}: + _ddp_find_unused = True + elif ddp_find_unused_override in {"0", "false", "no", "off"}: + _ddp_find_unused = False + elif ddp_find_unused_override in {"", "auto", "default"}: + _ddp_find_unused = bool(args.qat_lsq) + else: + raise ValueError( + f"Unsupported DDP_FIND_UNUSED_PARAMETERS={ddp_find_unused_override!r}; expected true|false|auto" + ) + log0(f"ddp_find_unused_parameters:{int(_ddp_find_unused)}", console=False) + model = ( + DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=_ddp_find_unused) + if device.type == "cuda" + else DDP(compiled_model, broadcast_buffers=False, find_unused_parameters=_ddp_find_unused) + ) + else: + model = compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if (p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) + and not name.endswith("qat_log_scale") + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if args.bigram_rank > 0: + bigram_params = [base_model.bigram_left.weight, base_model.bigram_right.weight, base_model.bigram_scale] + optimizer_bigram = torch.optim.Adam( + [{"params": bigram_params, "lr": args.bigram_lr, "base_lr": args.bigram_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_bigram) + if args.residual_ngram_enabled and getattr(base_model, "residual_ngram_enabled", False): + residual_params: list[nn.Parameter] = [ + base_model.residual_ngram_scale, + base_model.residual_ngram_gate.weight, + ] + if base_model.residual_ngram_gate.bias is not None: + residual_params.append(base_model.residual_ngram_gate.bias) + if base_model.residual_bigram_rank > 0: + residual_params.extend([base_model.residual_bigram_left.weight, base_model.residual_bigram_right.weight]) + if base_model.residual_trigram_rank > 0: + residual_params.extend( + [ + base_model.residual_trigram_prev1.weight, + base_model.residual_trigram_prev2.weight, + base_model.residual_trigram_right.weight, + ] + ) + optimizer_residual = torch.optim.Adam( + [{"params": residual_params, "lr": args.residual_ngram_lr, "base_lr": args.residual_ngram_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_residual) + if args.copy_cache_enabled and getattr(base_model, "copy_cache_enabled", False): + copy_params: list[nn.Parameter] = [ + base_model.copy_q.weight, + base_model.copy_k.weight, + base_model.copy_gate.weight, + ] + if base_model.copy_gate.bias is not None: + copy_params.append(base_model.copy_gate.bias) + optimizer_copy = torch.optim.Adam( + [{"params": copy_params, "lr": args.copy_cache_lr, "base_lr": args.copy_cache_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_copy) + if args.dual_head_enabled and getattr(base_model, "dual_head", None) is not None: + dual_params = [base_model.dual_head.weight] + if base_model.dual_head.bias is not None: + dual_params.append(base_model.dual_head.bias) + optimizer_dual = torch.optim.Adam( + [{"params": dual_params, "lr": args.dual_head_lr, "base_lr": args.dual_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_dual) + if args.mtp_enabled and base_model.mtp_branches is not None: + mtp_params: list[nn.Parameter] = [] + for branch in base_model.mtp_branches: + mtp_params.extend(list(branch.parameters())) + if base_model.mtp_heads is not None: + for head in base_model.mtp_heads: + mtp_params.extend(list(head.parameters())) + if mtp_params: + optimizer_mtp = torch.optim.Adam( + [{"params": mtp_params, "lr": args.mtp_lr, "base_lr": args.mtp_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_mtp) + # JPCR predictor optimizer (also covers Ouroboros controllers if used) + if base_model.jpcr_enabled and len(base_model.jpcr_predictors) > 0: + jpcr_params: list[nn.Parameter] = list(base_model.jpcr_predictors.parameters()) + if jpcr_params: + optimizer_jpcr = torch.optim.Adam( + [{"params": jpcr_params, "lr": args.jpcr_lr, "base_lr": args.jpcr_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_jpcr) + elif len(base_model.intra_loop_controllers) > 0: + # Ouroboros controllers need an optimizer too (was missing before!) + ctrl_params: list[nn.Parameter] = list(base_model.intra_loop_controllers.parameters()) + if ctrl_params: + optimizer_ctrl = torch.optim.Adam( + [{"params": ctrl_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_ctrl) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.insert(1, optimizer_head) + + # Dedicated optimizer for LSQ per-row log_scale parameters across the WHOLE model. + # These are 1D learnable steps inside every CastedLinear (blocks + lm_head + bigram + ...), + # not all of which would otherwise land in scalar_params (which only walks blocks). + optimizer_lsq: torch.optim.Optimizer | None = None + if args.qat_lsq: + lsq_params: list[nn.Parameter] = [ + m.qat_log_scale + for m in base_model.modules() + if isinstance(m, CastedLinear) and m.qat_log_scale is not None + ] + if lsq_params: + lsq_lr = float(os.environ.get("QAT_LSQ_LR", str(args.scalar_lr))) + optimizer_lsq = torch.optim.Adam( + [{"params": lsq_params, "lr": lsq_lr, "base_lr": lsq_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=autocast_enabled, + ) + optimizers.append(optimizer_lsq) + if master_process: + log0(f"qat_lsq: optimizer params={len(lsq_params)} lr={lsq_lr}") + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"sdp_backends:{sdp_backends_log}") + attention_mode = "mha" if args.num_kv_heads == args.num_heads else "gqa" + log0( + f"attention_mode:{attention_mode} num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads} " + f"use_swiglu:{args.use_swiglu} use_ssm:{args.use_ssm} ssm_every_n:{args.ssm_every_n} " + f"ssm_impl:{args.ssm_impl} ssm_expand:{args.ssm_expand} ssm_kernel:{args.ssm_kernel} " + f"mamba3_d_state:{args.mamba3_d_state} mamba3_head_dim:{args.mamba3_head_dim} " + f"mamba3_is_mimo:{args.mamba3_is_mimo} mamba3_mimo_rank:{args.mamba3_mimo_rank} " + f"mamba3_chunk_size:{args.mamba3_chunk_size} mamba3_outproj_norm:{args.mamba3_outproj_norm} " + f"mtp_enabled:{args.mtp_enabled} mtp_steps:{args.mtp_steps} mtp_weight:{args.mtp_weight} " + f"mtp_decay:{args.mtp_decay} mtp_tie_embeddings:{args.mtp_tie_embeddings} " + f"distill_enabled:{args.distill_enabled} distill_start_frac:{args.distill_start_frac} " + f"distill_start_step:{args.distill_start_step} distill_start_wallclock_frac:{args.distill_start_wallclock_frac} " + f"distill_weight:{args.distill_weight} distill_temp:{args.distill_temp} distill_ema_decay:{args.distill_ema_decay} " + f"jpcr_apply_every:{args.jpcr_apply_every} " + f"logit_reg_weight:{args.logit_reg_weight} byte_weighted_loss:{args.byte_weighted_loss_enabled} " + f"byte_weighted_loss_alpha:{args.byte_weighted_loss_alpha} " + f"residual_ngram_enabled:{args.residual_ngram_enabled} residual_bigram_rank:{args.residual_bigram_rank} " + f"residual_trigram_rank:{args.residual_trigram_rank} residual_ngram_lr:{args.residual_ngram_lr} " + f"residual_ngram_mix_init:{args.residual_ngram_mix_init} " + f"ngram_softcap:{args.ngram_softcap} ngram_entropy_gate:{args.ngram_entropy_gate} " + f"ttt_enabled:{args.ttt_enabled} ttt_lr:{args.ttt_lr} ttt_steps:{args.ttt_steps} ttt_momentum:{args.ttt_momentum} " + f"copy_cache_enabled:{args.copy_cache_enabled} copy_cache_window:{args.copy_cache_window} " + f"copy_cache_dim:{args.copy_cache_dim} copy_cache_lr:{args.copy_cache_lr} " + f"copy_cache_gate_init:{args.copy_cache_gate_init} " + f"dual_head_enabled:{args.dual_head_enabled} dual_head_weight:{args.dual_head_weight} " + f"dual_head_start_frac:{args.dual_head_start_frac} dual_head_lr:{args.dual_head_lr} " + f"qat_scheme:{args.qat_scheme} qat_start_step:{args.qat_start_step} qat_end_step:{args.qat_end_step} " + f"qat_start_wallclock_frac:{args.qat_start_wallclock_frac} qat_end_wallclock_frac:{args.qat_end_wallclock_frac} " + f"moe_num_experts:{args.moe_num_experts} moe_every_n:{args.moe_every_n} " + f"moe_capacity_factor:{args.moe_capacity_factor} moe_aux_loss_coeff:{args.moe_aux_loss_coeff} " + f"num_moe_blocks:{base_model.num_moe_blocks}" + ) + if base_model.use_recurrence: + log0( + f"architecture:recurrent core_layers:{base_model.recurrent_core_layers} " + f"recurrent_steps:{base_model.recurrent_steps} " + f"effective_layers:{base_model.total_effective_layers} " + f"ssm_blocks:{base_model.num_ssm_blocks} attn_blocks:{base_model.num_attn_blocks} " + f"share_ffn_across_blocks:{base_model.share_ffn_across_blocks}" + ) + else: + intra_info = ( + f" intra_loop:[{base_model.intra_loop_start}-{base_model.intra_loop_end}]x{base_model.intra_loop_steps}" + f" effective_layers:{base_model.total_effective_layers}" + if base_model.intra_loop_start >= 0 else "" + ) + jpcr_info = ( + f" jpcr:hidden={args.jpcr_hidden},weight={args.jpcr_weight},blend_init={args.jpcr_blend_init},lr={args.jpcr_lr}" + if base_model.jpcr_enabled else "" + ) + log0( + f"architecture:stacked num_layers:{args.num_layers} " + f"encoder_layers:{base_model.num_encoder_layers} decoder_layers:{base_model.num_decoder_layers} " + f"ssm_blocks:{base_model.num_ssm_blocks} attn_blocks:{base_model.num_attn_blocks}" + f"{intra_info}{jpcr_info}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} mtp_lr:{args.mtp_lr if args.mtp_enabled else 0.0} " + f"copy_cache_lr:{args.copy_cache_lr if args.copy_cache_enabled else 0.0} " + f"dual_head_lr:{args.dual_head_lr if args.dual_head_enabled else 0.0}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + log0("Initializing DistributedTokenLoader...") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + train_loss_mask_cache: dict[int, Tensor] = {} + + def build_train_loss_mask(batch_size: int, seq_len: int) -> Tensor | None: + if not args.train_loss_mask_enabled: + return None + mask_cpu = train_loss_mask_cache.get(seq_len) + if mask_cpu is None: + mask_cpu, _, _ = build_loss_mask_cpu(seq_len, train_loss_mask_stride_frac) + train_loss_mask_cache[seq_len] = mask_cpu + return mask_cpu.unsqueeze(0).expand(batch_size, -1).to(device=device) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + log0("Saving initial model and optimizer states for warmup...") + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + warmup_reason = "torch.compile/TileLang" if use_compile else "TileLang/custom kernels" + log0(f"Starting warmup loop ({args.warmup_steps} steps). The first step may compile {warmup_reason} kernels...") + # Pre-build dummy tensors matching the main training loop signature so that + # torch.compile traces the correct graph during warmup (no re-trace at step 1). + _warmup_n_jpcr = (base_model.intra_loop_end - base_model.intra_loop_start + 1) if base_model.jpcr_enabled else 0 + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + warmup_loss_mask = build_train_loss_mask(x.size(0), args.train_seq_len) + # Use the same kwargs signature as the main loop so compile doesn't retrace later. + _wu_teacher_logits: Tensor = torch.empty(0, device=device) + _wu_intermediates: list[Tensor] = [ + torch.zeros(x.size(0), args.train_seq_len, args.model_dim, device=device, dtype=torch.bfloat16) + for _ in range(_warmup_n_jpcr) + ] if _warmup_n_jpcr > 0 else [] + # Dummy per_token_weights / aux_targets so warmup traces the same graph + # as the main loop (some configs pass non-None here — traced branches + # differ, so include them unconditionally to avoid retracing on step 1). + _wu_token_weights = torch.ones_like(y, dtype=torch.float32) if args.byte_weighted_loss_enabled else None + _wu_aux_targets = torch.zeros_like(y, dtype=torch.long) if args.dual_head_enabled else None + _wu_aux_weight = 0.0 + if autocast_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model( + x, y, + loss_mask=warmup_loss_mask, + per_token_weights=_wu_token_weights, + aux_targets=_wu_aux_targets, + aux_weight=_wu_aux_weight, + distill_teacher_logits=_wu_teacher_logits, + distill_weight=0.0, + distill_temp=args.distill_temp, + logit_reg_weight=0.0, + jpcr_teacher_intermediates=_wu_intermediates, + jpcr_weight=0.0, + jpcr_runtime_active=False, + ) + else: + warmup_loss = model( + x, y, + loss_mask=warmup_loss_mask, + per_token_weights=_wu_token_weights, + aux_targets=_wu_aux_targets, + aux_weight=_wu_aux_weight, + distill_teacher_logits=_wu_teacher_logits, + distill_weight=0.0, + distill_temp=args.distill_temp, + logit_reg_weight=0.0, + jpcr_teacher_intermediates=_wu_intermediates, + jpcr_weight=0.0, + jpcr_runtime_active=False, + ) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if warmup_step == 0 or args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + distill_start_step = resolve_distill_start_step(args) + dual_head_start_step = int(max(0.0, min(1.0, args.dual_head_start_frac)) * args.iterations) + ema_teacher: GPT | None = None + if args.distill_enabled and args.distill_weight > 0.0: + ema_teacher = copy.deepcopy(base_model) + ema_teacher.eval() + for p in ema_teacher.parameters(): + p.requires_grad_(False) + if args.distill_enabled and args.distill_weight > 0.0: + distill_mode = ( + f"step:{args.distill_start_step}" + if args.distill_start_step >= 0 + else ( + f"wallclock_frac:{max(0.0, min(1.0, args.distill_start_wallclock_frac)):.4f}" + if args.distill_start_wallclock_frac >= 0.0 and max_wallclock_ms is not None + else f"iter_frac:{max(0.0, min(1.0, args.distill_start_frac)):.4f}" + ) + ) + log0(f"distill_start: mode:{distill_mode} resolved_step:{distill_start_step}") + if args.jpcr_apply_every > 1: + log0(f"jpcr_apply_every:{args.jpcr_apply_every} (distill+JPCR applied every Nth step)") + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + # SWA state: accumulated on CPU to avoid GPU memory pressure. + swa_state: dict[str, torch.Tensor] | None = None + swa_count = 0 + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + if device.type == "cuda": + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + # Load SWA-averaged weights before eval + export (better generalization + quantization). + if args.swa_enabled and swa_state is not None: + log0(f"swa: loading averaged weights from {swa_count} snapshots") + cur_dtypes = {k: v.dtype for k, v in base_model.state_dict().items()} + swa_load = {k: v.to(device=device, dtype=cur_dtypes[k]) for k, v in swa_state.items() if k in cur_dtypes} + # strict=False because qat_log_scale entries are intentionally excluded from swa_state. + base_model.load_state_dict(swa_load, strict=not args.qat_lsq) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # SWA: once warmdown begins (scale < 1), start averaging weights on CPU every N steps. + # qat_log_scale params are intentionally excluded: SWA would corrupt them by averaging + # scales from different QAT level regimes (256/64/16). The final trained scales are kept. + if args.swa_enabled and scale < 1.0 and step % args.swa_collect_every == 0: + swa_snapshot = { + k: v.detach().cpu().float().clone() + for k, v in base_model.state_dict().items() + if not k.endswith(".qat_log_scale") + } + if swa_state is None: + swa_state = swa_snapshot + swa_count = 1 + else: + inv = 1.0 / (swa_count + 1) + for k, v in swa_snapshot.items(): + if k in swa_state: + swa_state[k].mul_(1.0 - inv).add_(v, alpha=inv) + swa_count += 1 + + # QAT: enable fake-quantisation once model has partially converged. + # int8: single stage at qat_start_step (levels=256). + # int4: 3-stage progressive schedule starting at qat_start_step: + # stage 0 (<33% of QAT window): levels=256 (gentle, int8-equivalent) + # stage 1 (33-67% of QAT window): levels=64 + # stage 2 (>67% of QAT window): levels=16 (true int4) + # Progressive avoids the catastrophic loss spike from jumping straight to 16 levels. + if args.qat_scheme != "none": + target_levels, qat_mode = qat_target_levels(args, step, elapsed_ms, max_wallclock_ms) + if CastedLinear.qat_levels != target_levels: + prev_levels = CastedLinear.qat_levels + CastedLinear.qat_levels = target_levels + log0( + f"qat: {'enabled' if target_levels > 0 else 'disabled'} levels:{target_levels} " + f"step:{step} elapsed_ms:{elapsed_ms:.0f} mode:{qat_mode}" + ) + # LSQ: on the transition from 0 → nonzero, seed per-row log-scales from + # the current weight statistics (max-abs / half). Also reseed on each + # progressive level change so the learned scales start from a valid grid + # for the new quantisation resolution. + if args.qat_lsq and target_levels > 0 and prev_levels != target_levels: + n_lsq = init_lsq_scales(base_model, target_levels) + log0(f"qat: lsq_init count:{n_lsq} levels:{target_levels}") + # Clear stale Adam momentum/variance from the previous level regime + # so the fresh scale values get unbiased gradient updates. + if optimizer_lsq is not None: + optimizer_lsq.state.clear() + log0(f"qat: lsq_state_reset levels:{target_levels}") + + # Sequence length curriculum: ramp from curriculum_min_seq_len → train_seq_len. + if args.curriculum_enabled and step < args.curriculum_steps: + frac_c = step / max(args.curriculum_steps, 1) + curr_seq_len = args.curriculum_min_seq_len + int((args.train_seq_len - args.curriculum_min_seq_len) * frac_c) + curr_seq_len = 1 << int(math.log2(max(64, curr_seq_len))) + else: + curr_seq_len = args.train_seq_len + + distill_active = ( + ema_teacher is not None + and args.distill_weight > 0.0 + and distill_is_active(args, step, elapsed_ms, max_wallclock_ms, distill_start_step) + ) + apply_distill_this_step = bool(distill_active and (step % args.jpcr_apply_every == 0)) + jpcr_runtime_active = bool(base_model.jpcr_enabled and apply_distill_this_step) + # JPCR loss warmup: ramp weight from 0 → full over jpcr_warmup_steps after distill activates. + # Also freeze blend gates for first 300 steps so predictors learn via loss before affecting forward pass. + if distill_active and base_model.jpcr_enabled: + if not hasattr(main, "_jpcr_distill_start_step"): + main._jpcr_distill_start_step = step # type: ignore[attr-defined] + jpcr_steps_since = step - main._jpcr_distill_start_step # type: ignore[attr-defined] + jpcr_ramp = min(jpcr_steps_since / max(args.jpcr_warmup_steps, 1), 1.0) + jpcr_active_weight = args.jpcr_weight * jpcr_ramp + # Freeze/unfreeze blend gates: let predictor learn before gate opens + gate_frozen = jpcr_steps_since < 300 + else: + jpcr_active_weight = 0.0 + gate_frozen = False + dual_head_active_weight = ( + float(args.dual_head_weight) + if args.dual_head_enabled and step >= dual_head_start_step and args.dual_head_weight > 0.0 + else 0.0 + ) + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, curr_seq_len, grad_accum_steps) + # Always pass consistent types AND shapes to forward() to avoid torch.compile + # retracing when distillation activates. JPCR is only enabled once distill is on. + teacher_logits: Tensor = torch.empty(0, device=device) + if jpcr_runtime_active and args.jpcr_weight > 0.0: + _n_jpcr = (base_model.intra_loop_end - base_model.intra_loop_start + 1) + teacher_intermediates: list[Tensor] = [ + torch.zeros(x.size(0), curr_seq_len, args.model_dim, device=device, dtype=torch.bfloat16) + for _ in range(_n_jpcr) + ] + else: + teacher_intermediates = [] + token_weights: Tensor | None = None + aux_targets: Tensor | None = None + train_loss_mask = build_train_loss_mask(x.size(0), curr_seq_len) + if apply_distill_this_step and ema_teacher is not None: + # Use no_grad (not inference_mode) because inference tensors can error when + # downstream ops save them for backward (e.g., KL in distillation under compile). + # Wrap in autocast to match training dtype (bf16) — teacher weights are bf16. + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=autocast_enabled): + if jpcr_runtime_active and args.jpcr_weight > 0.0: + # Capture both logits and per-block intermediates for JPCR. + teacher_logits, teacher_intermediates = ema_teacher.forward_logits_and_intermediates( + x, jpcr_runtime_active=True + ) + teacher_logits = teacher_logits.detach() + teacher_intermediates = [h.detach() for h in teacher_intermediates] + else: + teacher_logits = ema_teacher.forward_logits(x).detach() + if args.byte_weighted_loss_enabled: + with torch.no_grad(): + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.float32) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.float32) + mean_bytes = token_bytes.mean().clamp_min(1e-6) + rel = token_bytes / mean_bytes + alpha = float(args.byte_weighted_loss_alpha) + rel = (1.0 - alpha) + alpha * rel + token_weights = rel.reshape_as(y) + if dual_head_active_weight > 0.0: + with torch.no_grad(): + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + is_boundary = is_boundary_token_lut[tgt_ids] + has_space = has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + is_long = base_bytes_lut[tgt_ids] >= 4 + cls = torch.zeros_like(tgt_ids, dtype=torch.long) + cls = torch.where(has_space, torch.ones_like(cls), cls) # class 1: leading-space continuation + cls = torch.where(is_long, torch.full_like(cls, 2), cls) # class 2: long piece (4+ bytes) + cls = torch.where(is_boundary, torch.full_like(cls, 3), cls) # class 3: boundary/special + aux_targets = cls.reshape_as(y) + if autocast_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model( + x, + y, + loss_mask=train_loss_mask, + per_token_weights=token_weights, + aux_targets=aux_targets, + aux_weight=dual_head_active_weight, + distill_teacher_logits=teacher_logits, + distill_weight=args.distill_weight if apply_distill_this_step else 0.0, + distill_temp=args.distill_temp, + logit_reg_weight=args.logit_reg_weight, + jpcr_teacher_intermediates=teacher_intermediates, + jpcr_weight=jpcr_active_weight, + jpcr_runtime_active=jpcr_runtime_active, + ) + else: + loss = model( + x, + y, + loss_mask=train_loss_mask, + per_token_weights=token_weights, + aux_targets=aux_targets, + aux_weight=dual_head_active_weight, + distill_teacher_logits=teacher_logits, + distill_weight=args.distill_weight if apply_distill_this_step else 0.0, + distill_temp=args.distill_temp, + logit_reg_weight=args.logit_reg_weight, + jpcr_teacher_intermediates=teacher_intermediates, + jpcr_weight=jpcr_active_weight, + jpcr_runtime_active=jpcr_runtime_active, + ) + train_loss += loss.detach() + (loss * grad_scale).backward() + if gate_frozen: + for p in base_model.jpcr_predictors: + if p.blend_gate.grad is not None: + p.blend_gate.grad = None + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + if ema_teacher is not None: + with torch.no_grad(): + decay = float(args.distill_ema_decay) + for p_t, p_s in zip(ema_teacher.parameters(), base_model.parameters(), strict=True): + p_t.mul_(decay).add_(p_s, alpha=1.0 - decay) + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + if device.type == "cuda": + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # a compressed quantized artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + raw_total_submission = model_bytes + code_bytes + raw_budget_delta = args.submission_size_budget_bytes - raw_total_submission + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {raw_total_submission} bytes") + if raw_budget_delta >= 0: + log0( + f"submission_budget raw_total:{raw_total_submission} budget:{args.submission_size_budget_bytes} " + f"headroom_bytes:{raw_budget_delta}" + ) + else: + log0( + f"submission_budget raw_total:{raw_total_submission} budget:{args.submission_size_budget_bytes} " + f"over_bytes:{-raw_budget_delta}" + ) + + resolved_compressor, compressor_note = resolve_compressor(args.compressor) + + export_state_dict = base_model.state_dict() + qat_export_levels = CastedLinear.qat_levels + if master_process and args.qat_scheme != "none" and qat_export_levels <= 0: + log0( + f"qat_warning: QAT_SCHEME={args.qat_scheme} was requested but fake-quant never enabled before export; " + f"step:{step} qat_start_step:{args.qat_start_step} qat_end_step:{args.qat_end_step} " + f"qat_start_wallclock_frac:{args.qat_start_wallclock_frac} " + f"qat_end_wallclock_frac:{args.qat_end_wallclock_frac} iterations:{args.iterations}" + ) + elif master_process and args.qat_scheme != "none": + log0(f"qat_export: active_levels:{qat_export_levels}") + + # LSQ export plumbing (if enabled): collect learned per-row scales and strip + # the log_scale parameters from the state_dict. + lsq_scales_export: dict[str, Tensor] | None = None + if args.qat_lsq: + lsq_scales_export = collect_lsq_scales(base_model) + export_state_dict = { + k: v for k, v in export_state_dict.items() if not k.endswith(".qat_log_scale") + } + if master_process: + log0(f"qat_lsq: collected {len(lsq_scales_export)} per-row scales for export") + + # GPTQ: Hessian-aware post-training quantization (replaces naive round-to-nearest). + gptq_results: dict[str, tuple[Tensor, Tensor]] | None = None + if args.gptq_enabled: + active_scheme = args.mixed_low_precision_scheme if args.quant_scheme == "mixed" else args.quant_scheme + gptq_bits = 4 if active_scheme == "int4" else (5 if active_scheme == "int5" else 8) + if master_process: + log0(f"gptq: collecting Hessians from {args.gptq_nsamples} calibration samples...") + CastedLinear.qat_levels = 0 # disable fake-quant for calibration + hessians = collect_gptq_hessians( + base_model, val_tokens, device, + seq_len=args.train_seq_len, + nsamples=args.gptq_nsamples, + ) + if master_process: + log0(f"gptq: collected {len(hessians)} Hessians, quantizing with bits={gptq_bits}...") + gptq_results = gptq_quantize_state_dict( + base_model, export_state_dict, hessians, + bits=gptq_bits, + percdamp=args.gptq_percdamp, + blocksize=args.gptq_blocksize, + group_size=INT4_GROUP_SIZE if gptq_bits == 4 else 0, + use_nf4=NF4_ENABLED if gptq_bits == 4 else False, + ) + if master_process: + log0(f"gptq: quantized {len(gptq_results)} weight matrices") + + quant_obj, quant_stats = quantize_state_dict( + export_state_dict, + scheme=args.quant_scheme, + weight_order=args.weight_order, + mixed_low_precision_scheme=args.mixed_low_precision_scheme, + precomputed_scales=lsq_scales_export, + gptq_results=gptq_results, + ) + artifact_name = export_artifact_name(args.quant_scheme, resolved_compressor) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, resolved_compressor, args.compress_level) + quant_raw_bytes = len(quant_raw) + if master_process: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["payload_bytes"], 1) + if compressor_note: + log0(f"export_note:{compressor_note}") + log0( + f"export_config quant_scheme:{args.quant_scheme} mixed_low_precision_scheme:{args.mixed_low_precision_scheme} " + f"compressor:{resolved_compressor} weight_order:{args.weight_order} compress_level:{args.compress_level}" + ) + log0( + f"Serialized model {args.quant_scheme}+{resolved_compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + quant_total_submission = quant_file_bytes + code_bytes + quant_budget_delta = args.submission_size_budget_bytes - quant_total_submission + log0(f"Total submission size {args.quant_scheme}+{resolved_compressor}: {quant_total_submission} bytes") + if quant_budget_delta >= 0: + log0( + f"submission_budget {args.quant_scheme}+{resolved_compressor} total:{quant_total_submission} " + f"budget:{args.submission_size_budget_bytes} headroom_bytes:{quant_budget_delta}" + ) + else: + log0( + f"submission_budget {args.quant_scheme}+{resolved_compressor} total:{quant_total_submission} " + f"budget:{args.submission_size_budget_bytes} over_bytes:{-quant_budget_delta}" + ) + with open("final_export_manifest.json", "w", encoding="utf-8") as f: + json.dump( + { + "quant_scheme": args.quant_scheme, + "mixed_low_precision_scheme": args.mixed_low_precision_scheme, + "compressor_requested": args.compressor, + "compressor_resolved": resolved_compressor, + "compress_level": args.compress_level, + "weight_order": args.weight_order, + "artifact_name": artifact_name, + "artifact_bytes": quant_file_bytes, + "code_bytes": code_bytes, + "total_submission_bytes": quant_total_submission, + "submission_size_budget_bytes": args.submission_size_budget_bytes, + "budget_headroom_bytes": quant_budget_delta, + "baseline_tensor_bytes": quant_stats["baseline_tensor_bytes"], + "payload_bytes": quant_stats["payload_bytes"], + "raw_torch_bytes": quant_raw_bytes, + "payload_ratio": ratio, + "quant_format": quant_obj.get("__quant_format__", ""), + }, + f, + indent=2, + sort_keys=True, + ) + + if args.final_roundtrip_eval: + if distributed: + dist.barrier() + # Disable QAT fake-quant during roundtrip eval so loaded dequantized + # weights are not re-fake-quantized through stale LSQ scales. + CastedLinear.qat_levels = 0 + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(decompress_blob(quant_blob_disk, resolved_compressor)), + map_location="cpu", + weights_only=True, + ) + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=False) + if device.type == "cuda": + torch.cuda.synchronize() + t_qeval = time.perf_counter() + roundtrip_tag = f"final_{args.quant_scheme}_{resolved_compressor}_roundtrip" + q_val_loss, q_val_bpb = run_final_eval_suite( + args, + roundtrip_tag, + model, + rank, + world_size, + device, + autocast_enabled, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + sweep_specs, + blend_specs, + blend_weights, + log0, + ) + if device.type == "cuda": + torch.cuda.synchronize() + log0( + f"{roundtrip_tag} val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms mode:{args.final_eval_mode}" + ) + log0( + f"{roundtrip_tag}_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f} " + f"mode:{args.final_eval_mode}" + ) + else: + log0("final_roundtrip skipped FINAL_ROUNDTRIP_EVAL=0") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()