From 024b54eba27b8cb84b01d385ea872616c335e0c8 Mon Sep 17 00:00:00 2001 From: YaseenHQ Date: Sun, 3 May 2026 04:22:07 -0400 Subject: [PATCH] Add non-record RandProj384 PairMuonQK negative result --- .../README.md | 55 + .../lossless_caps.py | 833 +++ .../online_ngram_state.c | 433 ++ .../online_ngram_tilt.py | 386 ++ .../prepare_caseops_data.py | 194 + .../requirements.txt | 13 + .../submission.json | 10 + .../train_gpt.py | 4666 +++++++++++++++++ .../train_seed42.log | 331 ++ .../ttt_eval_seed42_fail.log | 339 ++ 10 files changed, 7260 insertions(+) create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/README.md create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/lossless_caps.py create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_state.c create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_tilt.py create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/prepare_caseops_data.py create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/requirements.txt create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/submission.json create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_gpt.py create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_seed42.log create mode 100644 records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/ttt_eval_seed42_fail.log diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/README.md b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/README.md new file mode 100644 index 0000000000..4a573e71c8 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/README.md @@ -0,0 +1,55 @@ +# Non-record: SP8192 + RandProj384 tied embeddings + Pairwise-QK Muon -- Single-seed negative result + +This is a non-record submission testing two new ideas on the SP8192 / CaseOps / legal-TTT stack: + +- random-projection tied embeddings (`RandProj384`) +- pairwise-head Muon orthogonalization for Q/K (`PairMuonQK`) + +This run completed training and quantization on 8xH100 SXM within the 10-minute training cap and produced a legal sub-16MB artifact, but it was not competitive with the frontier. I am submitting it as a negative result because the failure is clear and informative. + +## Single-seed result + +Seed: `42` + +- train steps: `1724` +- train wallclock: `599714 ms` +- in-run full validation: `val_loss 2.4662`, `val_bpb 1.1269` +- post-EMA diagnostic: `val_loss 2.47020597`, `val_bpb 1.12868936` +- quantized model size: `15,399,365` bytes +- total submission size: `15,438,770` bytes +- headroom to 16MB: `561,230` bytes + +## What happened + +The artifact fit comfortably under the size limit, but model quality regressed too far from the public frontier before quantization and before TTT could help. + +The post-training legal TTT eval path also did not complete robustly on this stack: + +- larger TTT batch hit OOM during adaptation +- smaller TTT batch progressed but was too slow to be practical + +Because of that, I am not claiming a final post-TTT score. + +## Why this is still useful + +This result directly constrains the design space: + +- aggressive latent tied-embedding compression was destructive (`1.1269` pre-TTT BPB, more than `0.06` worse than the strongest open public frontier) +- pairwise Q/K Muon orthogonalization did not preserve frontier behavior +- parameter savings alone are insufficient; pre-quantization quality matters + +## Why this is non-record + +- single seed only +- not competitive with current SOTA +- no successful final TTT evaluation +- submitted as an interesting negative result rather than a leaderboard claim + +## Included files + +- `train_gpt.py` +- `requirements.txt` +- helper files required by the run +- `train_seed42.log` +- `ttt_eval_seed42_fail.log` +- `submission.json` diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/lossless_caps.py b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/lossless_caps.py new file mode 100644 index 0000000000..98e472f824 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/lossless_caps.py @@ -0,0 +1,833 @@ +"""Lossless capitalization pre-encoding helpers. + +This module provides a narrow, reversible transform that only touches +ASCII capital letters `A-Z`. Each uppercase ASCII letter is rewritten as +``, where `sentinel` is a private-use Unicode +character that is escaped by doubling if it appears literally in the +input text. + +Example with the default sentinel `\\uE000`: + + "The NASA Launch" -> "\\uE000the \\uE000n\\uE000a\\uE000s\\uE000a \\uE000launch" + +The transform is intentionally simple for v1: + +- lowercase ASCII letters are unchanged +- uppercase ASCII letters become sentinel + lowercase letter +- non-ASCII characters are left untouched +- literal sentinel characters are escaped as sentinel + sentinel + +This makes the transform exactly invertible while allowing a downstream +tokenizer to reuse lowercase subwords across case variants. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Callable, Iterable + +LOSSLESS_CAPS_V1 = "lossless_caps_v1" +LOSSLESS_CAPS_V2 = "lossless_caps_v2" +LOSSLESS_CAPS_V3 = "lossless_caps_v3" +LOSSLESS_CAPS_V4 = "lossless_caps_v4" +LOSSLESS_CAPS_V5 = "lossless_caps_v5" +LOSSLESS_CAPS_V6 = "lossless_caps_v6" +LOSSLESS_CAPS_V7 = "lossless_caps_v7" +LOSSLESS_CAPS_CASEOPS_V1 = "lossless_caps_caseops_v1" +IDENTITY = "identity" +DEFAULT_SENTINEL = "\uE000" +DEFAULT_V2_TITLE = "\uE001" +DEFAULT_V2_ALLCAPS = "\uE002" +DEFAULT_V2_CAPNEXT = "\uE003" +DEFAULT_V2_ESC = "\uE004" +DEFAULT_V5_TITLE_MIN_LEN = 7 +DEFAULT_V6_ALLCAPS_MIN_LEN = 3 +DEFAULT_V7_ALLCAPS_MIN_LEN = 4 + + +class LosslessCapsError(ValueError): + """Raised when a transformed string is malformed.""" + + +def _is_ascii_upper(ch: str) -> bool: + return "A" <= ch <= "Z" + + +def _is_ascii_lower(ch: str) -> bool: + return "a" <= ch <= "z" + + +def _is_ascii_alpha(ch: str) -> bool: + return _is_ascii_lower(ch) or _is_ascii_upper(ch) + + +def _validate_distinct_single_chars(*chars: str) -> None: + if any(len(ch) != 1 for ch in chars): + raise ValueError("all control characters must be exactly one character") + if len(set(chars)) != len(chars): + raise ValueError("control characters must be distinct") + + +def encode_lossless_caps_v1(text: str, *, sentinel: str = DEFAULT_SENTINEL) -> str: + """Encode ASCII capitals reversibly using a one-character sentinel.""" + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + out: list[str] = [] + for ch in text: + if ch == sentinel: + out.append(sentinel) + out.append(sentinel) + elif _is_ascii_upper(ch): + out.append(sentinel) + out.append(ch.lower()) + else: + out.append(ch) + return "".join(out) + + +def decode_lossless_caps_v1(text: str, *, sentinel: str = DEFAULT_SENTINEL) -> str: + """Decode the `lossless_caps_v1` transform back to the original text.""" + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch != sentinel: + out.append(ch) + i += 1 + continue + if i + 1 >= n: + raise LosslessCapsError("dangling capitalization sentinel at end of string") + nxt = text[i + 1] + if nxt == sentinel: + out.append(sentinel) + elif _is_ascii_lower(nxt): + out.append(nxt.upper()) + else: + raise LosslessCapsError( + f"invalid sentinel escape sequence {sentinel + nxt!r}; " + "expected doubled sentinel or sentinel + lowercase ASCII letter" + ) + i += 2 + return "".join(out) + + +def encode_lossless_caps_v2( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + capnext: str = DEFAULT_V2_CAPNEXT, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode ASCII word capitalization with cheap word-level markers. + + Rules over maximal ASCII alphabetic runs: + - lowercase words stay unchanged + - TitleCase words become `title + lowercase(word)` + - ALLCAPS words become `allcaps + lowercase(word)` + - mixed-case words use: + - optional `title` when the first letter is uppercase + - `capnext + lowercase(letter)` for subsequent uppercase letters + - literal control characters are escaped as `esc + literal` + """ + _validate_distinct_single_chars(title, allcaps, capnext, esc) + controls = {title, allcaps, capnext, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + lower_word = word.lower() + + if word.islower(): + out.append(word) + elif len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(lower_word) + elif _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(lower_word) + else: + if _is_ascii_upper(word[0]): + out.append(title) + out.append(lower_word[0]) + for orig_ch, lower_ch in zip(word[1:], lower_word[1:], strict=True): + if _is_ascii_upper(orig_ch): + out.append(capnext) + out.append(lower_ch) + i = j + return "".join(out) + + +def decode_lossless_caps_v2( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + capnext: str = DEFAULT_V2_CAPNEXT, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v2` transform back to the original text.""" + _validate_distinct_single_chars(title, allcaps, capnext, esc) + out: list[str] = [] + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + pending_capnext = False + in_ascii_word = False + + for ch in text: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == title: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + if ch == capnext: + if pending_capnext: + raise LosslessCapsError("duplicate capnext marker") + pending_capnext = True + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + if pending_word_mode == "allcaps": + out.append(ch.upper()) + active_allcaps = True + elif pending_word_mode == "title": + out.append(ch.upper()) + elif pending_capnext: + out.append(ch.upper()) + else: + out.append(ch) + pending_word_mode = None + pending_capnext = False + in_ascii_word = True + continue + + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + if active_allcaps: + out.append(ch.upper()) + elif pending_capnext: + out.append(ch.upper()) + else: + out.append(ch) + pending_capnext = False + continue + + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("dangling capitalization marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v3( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode only common word-level capitalization patterns. + + Rules over maximal ASCII alphabetic runs: + - lowercase words stay unchanged + - TitleCase words become `title + lowercase(word)` + - ALLCAPS words become `allcaps + lowercase(word)` + - all other mixed-case words are left unchanged + - literal control characters are escaped as `esc + literal` + """ + _validate_distinct_single_chars(title, allcaps, esc) + controls = {title, allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + + if word.islower(): + out.append(word) + elif len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + elif _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v3( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v3` transform back to the original text.""" + _validate_distinct_single_chars(title, allcaps, esc) + out: list[str] = [] + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + in_ascii_word = False + + for ch in text: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == title: + if pending_word_mode is not None or in_ascii_word: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + if pending_word_mode == "allcaps": + out.append(ch.upper()) + active_allcaps = True + elif pending_word_mode == "title": + out.append(ch.upper()) + else: + out.append(ch) + pending_word_mode = None + in_ascii_word = True + continue + + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + out.append(ch.upper() if active_allcaps else ch) + continue + + if pending_word_mode is not None: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_word_mode is not None: + raise LosslessCapsError("dangling capitalization marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v4( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode only ALLCAPS ASCII words, leaving all other case untouched.""" + _validate_distinct_single_chars(allcaps, esc) + controls = {allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v4( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v4` transform back to the original text.""" + _validate_distinct_single_chars(allcaps, esc) + out: list[str] = [] + pending_escape = False + pending_allcaps = False + in_ascii_word = False + active_allcaps = False + + for ch in text: + if pending_escape: + if pending_allcaps and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending allcaps mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == allcaps: + if pending_allcaps or in_ascii_word: + raise LosslessCapsError("invalid allcaps marker placement") + pending_allcaps = True + continue + + if _is_ascii_alpha(ch): + if not in_ascii_word: + active_allcaps = pending_allcaps + pending_allcaps = False + in_ascii_word = True + out.append(ch.upper() if active_allcaps else ch) + continue + + if pending_allcaps: + raise LosslessCapsError("allcaps marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_allcaps: + raise LosslessCapsError("dangling allcaps marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v5( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + title_min_len: int = DEFAULT_V5_TITLE_MIN_LEN, +) -> str: + """Encode ALLCAPS words and only sufficiently long TitleCase words.""" + _validate_distinct_single_chars(title, allcaps, esc) + controls = {title, allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + elif len(word) >= title_min_len and _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v5( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v5` transform back to the original text.""" + return decode_lossless_caps_v3(text, title=title, allcaps=allcaps, esc=esc) + + +def encode_lossless_caps_v6( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + allcaps_min_len: int = DEFAULT_V6_ALLCAPS_MIN_LEN, +) -> str: + """Encode only ALLCAPS words with length >= allcaps_min_len.""" + _validate_distinct_single_chars(allcaps, esc) + controls = {allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= allcaps_min_len and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v6( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v6` transform back to the original text.""" + return decode_lossless_caps_v4(text, allcaps=allcaps, esc=esc) + + +def encode_lossless_caps_v7( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + allcaps_min_len: int = DEFAULT_V7_ALLCAPS_MIN_LEN, +) -> str: + """Encode only ALLCAPS words with length >= 4.""" + return encode_lossless_caps_v6( + text, + allcaps=allcaps, + esc=esc, + allcaps_min_len=allcaps_min_len, + ) + + +def decode_lossless_caps_v7( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v7` transform back to the original text.""" + return decode_lossless_caps_v6(text, allcaps=allcaps, esc=esc) + + +def get_text_transform(name: str | None) -> Callable[[str], str]: + """Return the forward text transform for the given config name.""" + normalized = IDENTITY if name in {None, "", IDENTITY} else str(name) + if normalized == IDENTITY: + return lambda text: text + if normalized == LOSSLESS_CAPS_V1: + return encode_lossless_caps_v1 + if normalized == LOSSLESS_CAPS_V2: + return encode_lossless_caps_v2 + if normalized == LOSSLESS_CAPS_V3: + return encode_lossless_caps_v3 + if normalized == LOSSLESS_CAPS_V4: + return encode_lossless_caps_v4 + if normalized == LOSSLESS_CAPS_V5: + return encode_lossless_caps_v5 + if normalized == LOSSLESS_CAPS_V6: + return encode_lossless_caps_v6 + if normalized == LOSSLESS_CAPS_V7: + return encode_lossless_caps_v7 + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return encode_lossless_caps_v2 + raise ValueError(f"unsupported text_transform={name!r}") + + +def get_text_inverse_transform(name: str | None) -> Callable[[str], str]: + """Return the inverse transform for the given config name.""" + normalized = IDENTITY if name in {None, "", IDENTITY} else str(name) + if normalized == IDENTITY: + return lambda text: text + if normalized == LOSSLESS_CAPS_V1: + return decode_lossless_caps_v1 + if normalized == LOSSLESS_CAPS_V2: + return decode_lossless_caps_v2 + if normalized == LOSSLESS_CAPS_V3: + return decode_lossless_caps_v3 + if normalized == LOSSLESS_CAPS_V4: + return decode_lossless_caps_v4 + if normalized == LOSSLESS_CAPS_V5: + return decode_lossless_caps_v5 + if normalized == LOSSLESS_CAPS_V6: + return decode_lossless_caps_v6 + if normalized == LOSSLESS_CAPS_V7: + return decode_lossless_caps_v7 + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return decode_lossless_caps_v2 + raise ValueError(f"unsupported text_transform={name!r}") + + +def normalize_text_transform_name(name: str | None) -> str: + """Normalize empty/None transform names to the identity transform.""" + return IDENTITY if name in {None, "", IDENTITY} else str(name) + + +def get_text_transform_control_symbols(name: str | None) -> list[str]: + """Return reserved control symbols used by a transform, if any.""" + normalized = normalize_text_transform_name(name) + if normalized == IDENTITY: + return [] + if normalized == LOSSLESS_CAPS_V1: + return [DEFAULT_SENTINEL] + if normalized == LOSSLESS_CAPS_V2: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_CAPNEXT, DEFAULT_V2_ESC] + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_CAPNEXT, DEFAULT_V2_ESC] + if normalized in {LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V5}: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_ESC] + if normalized in {LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7}: + return [DEFAULT_V2_ALLCAPS, DEFAULT_V2_ESC] + raise ValueError(f"unsupported text_transform={name!r}") + + +def infer_text_transform_from_manifest(tokenizer_path: str | Path) -> str: + """Best-effort lookup of a tokenizer's text transform from a local manifest.""" + tokenizer_path = Path(tokenizer_path).expanduser().resolve() + manifest_candidates = [ + tokenizer_path.parent.parent / "manifest.json", + tokenizer_path.parent / "manifest.json", + ] + for manifest_path in manifest_candidates: + if not manifest_path.is_file(): + continue + try: + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + continue + tokenizers = payload.get("tokenizers") + if not isinstance(tokenizers, list): + continue + for tokenizer_meta in tokenizers: + if not isinstance(tokenizer_meta, dict): + continue + model_path = tokenizer_meta.get("model_path") or tokenizer_meta.get("path") + if not model_path: + continue + candidate = (manifest_path.parent / str(model_path)).resolve() + if candidate == tokenizer_path: + return normalize_text_transform_name(tokenizer_meta.get("text_transform")) + return IDENTITY + + +def surface_piece_original_byte_counts( + surfaces: Iterable[str], + *, + text_transform_name: str | None = None, + sentinel: str = DEFAULT_SENTINEL, +) -> list[int]: + """Return exact original UTF-8 byte counts contributed by each surface piece. + + `surfaces` must be the exact decoded text fragments emitted by SentencePiece + in order, e.g. `piece.surface` from `encode_as_immutable_proto`. + """ + normalized = normalize_text_transform_name(text_transform_name) + if normalized == IDENTITY: + return [len(surface.encode("utf-8")) for surface in surfaces] + if normalized == LOSSLESS_CAPS_V1: + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + sentinel_bytes = len(sentinel.encode("utf-8")) + pending_sentinel = False + counts: list[int] = [] + for surface in surfaces: + piece_bytes = 0 + for ch in surface: + if pending_sentinel: + if ch == sentinel: + piece_bytes += sentinel_bytes + elif _is_ascii_lower(ch): + piece_bytes += 1 + else: + raise LosslessCapsError( + f"invalid continuation {ch!r} after capitalization sentinel" + ) + pending_sentinel = False + continue + if ch == sentinel: + pending_sentinel = True + else: + piece_bytes += len(ch.encode("utf-8")) + counts.append(piece_bytes) + if pending_sentinel: + raise LosslessCapsError("dangling capitalization sentinel across piece boundary") + return counts + if normalized not in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V5, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7, LOSSLESS_CAPS_CASEOPS_V1}: + raise ValueError(f"unsupported text_transform={text_transform_name!r}") + + title = DEFAULT_V2_TITLE + allcaps = DEFAULT_V2_ALLCAPS + capnext = DEFAULT_V2_CAPNEXT + esc = DEFAULT_V2_ESC + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_CASEOPS_V1}: + _validate_distinct_single_chars(title, allcaps, capnext, esc) + elif normalized in {LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7}: + _validate_distinct_single_chars(allcaps, esc) + else: + _validate_distinct_single_chars(title, allcaps, esc) + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + pending_capnext = False + in_ascii_word = False + counts: list[int] = [] + for surface in surfaces: + piece_bytes = 0 + for ch in surface: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + piece_bytes += len(ch.encode("utf-8")) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + if ch == esc: + pending_escape = True + continue + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V5, LOSSLESS_CAPS_CASEOPS_V1} and ch == title: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_CASEOPS_V1} and ch == capnext: + if pending_capnext: + raise LosslessCapsError("duplicate capnext marker") + pending_capnext = True + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + piece_bytes += 1 + active_allcaps = pending_word_mode == "allcaps" + pending_word_mode = None + pending_capnext = False + in_ascii_word = True + continue + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + piece_bytes += 1 + pending_capnext = False + continue + + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + piece_bytes += len(ch.encode("utf-8")) + in_ascii_word = False + active_allcaps = False + counts.append(piece_bytes) + if pending_escape: + raise LosslessCapsError("dangling escape marker across piece boundary") + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("dangling capitalization marker across piece boundary") + return counts diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_state.c b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_state.c new file mode 100644 index 0000000000..f8472a6f05 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_state.c @@ -0,0 +1,433 @@ +#include +#include +#include + +#define COEFF_COUNT 32 + +static const uint64_t ROLLING_COEFFS[COEFF_COUNT] = { + 36313ULL, 27191ULL, 51647ULL, 81929ULL, 131071ULL, 196613ULL, + 262147ULL, 393241ULL, 524309ULL, 655373ULL, 786433ULL, 917521ULL, + 1048583ULL, 1179653ULL, 1310729ULL, 1441801ULL, 1572869ULL, 1703941ULL, + 1835017ULL, 1966087ULL, 2097169ULL, 2228243ULL, 2359319ULL, 2490389ULL, + 2621471ULL, 2752549ULL, 2883617ULL, 3014687ULL, 3145757ULL, 3276833ULL, + 3407903ULL, 3538973ULL, +}; + +static const uint64_t PAIR_MIX = 1000003ULL; +static const uint64_t PREFIX_BASE = 1099511628211ULL; +static const uint64_t LEN_MIX = 0x9E3779B185EBCA87ULL; +static const uint64_t TABLE_MIX = 0x9e3779b97f4a7c15ULL; + +typedef struct { + uint64_t key; + uint32_t total; + uint32_t top_count; + uint16_t top_tok; + uint16_t _pad; +} CtxBucket; + +typedef struct { + uint64_t key; + uint32_t count; + uint32_t _pad; +} PairBucket; + +typedef struct { + int token_ctx_len; + int token_prefix_len; + int token_head; + uint16_t *token_ring; + + CtxBucket *token_ctx_tbl; + uint8_t *token_ctx_used; + size_t token_ctx_mask; + + PairBucket *token_pair_tbl; + uint8_t *token_pair_used; + size_t token_pair_mask; + + uint64_t within_hash; + uint32_t within_len; + + CtxBucket *within_ctx_tbl; + uint8_t *within_ctx_used; + size_t within_ctx_mask; + + PairBucket *within_pair_tbl; + uint8_t *within_pair_used; + size_t within_pair_mask; +} OnlineNgramState; + +static inline size_t mix_index(uint64_t key, size_t mask) { + return (size_t)((key * TABLE_MIX) & mask); +} + +static inline size_t find_ctx_slot( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline size_t find_pair_slot( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline uint64_t token_pair_key(uint64_t ctx_key, uint16_t tok, int ctx_len) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[(size_t)ctx_len % COEFF_COUNT]); +} + +static inline uint64_t within_pair_key(uint64_t ctx_key, uint16_t tok) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[0]); +} + +static inline uint64_t extend_prefix_hash(uint64_t current_hash, uint16_t tok, uint32_t pos) { + return (current_hash * PREFIX_BASE) ^ (((uint64_t)tok + 1ULL) * ROLLING_COEFFS[(size_t)pos % COEFF_COUNT]); +} + +static inline uint32_t pair_increment( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key +) { + int found = 0; + size_t idx = find_pair_slot(tbl, used, mask, key, &found); + if (found < 0) { + return 0U; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].count = 1U; + return 1U; + } + tbl[idx].count += 1U; + return tbl[idx].count; +} + +static inline int ctx_increment( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + uint16_t tok, + uint32_t pair_count +) { + int found = 0; + size_t idx = find_ctx_slot(tbl, used, mask, key, &found); + if (found < 0) { + return -1; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].total = 1U; + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + return 0; + } + tbl[idx].total += 1U; + if (pair_count > tbl[idx].top_count) { + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + } + return 0; +} + +static inline uint64_t token_context_hash(const OnlineNgramState *st) { + uint64_t h = 0ULL; + if (st->token_ctx_len <= 0) { + return h; + } + for (int j = 0; j < st->token_ctx_len; ++j) { + const int ring_idx = (st->token_head + j) % st->token_ctx_len; + h ^= ((uint64_t)st->token_ring[ring_idx]) * ROLLING_COEFFS[(size_t)j]; + } + return h; +} + +static inline void token_push(OnlineNgramState *st, uint16_t tok) { + if (st->token_ctx_len <= 0) { + return; + } + if (st->token_prefix_len < st->token_ctx_len) { + st->token_ring[st->token_prefix_len] = tok; + st->token_prefix_len += 1; + return; + } + st->token_ring[st->token_head] = tok; + st->token_head = (st->token_head + 1) % st->token_ctx_len; +} + +static void *xcalloc(size_t count, size_t size) { + if (count == 0 || size == 0) { + return NULL; + } + return calloc(count, size); +} + +static int alloc_tables( + size_t table_bits, + CtxBucket **ctx_tbl, + uint8_t **ctx_used, + size_t *ctx_mask, + PairBucket **pair_tbl, + uint8_t **pair_used, + size_t *pair_mask +) { + const size_t size = 1ULL << table_bits; + *ctx_tbl = (CtxBucket *)xcalloc(size, sizeof(CtxBucket)); + *ctx_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + *pair_tbl = (PairBucket *)xcalloc(size, sizeof(PairBucket)); + *pair_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + if (!*ctx_tbl || !*ctx_used || !*pair_tbl || !*pair_used) { + return -1; + } + *ctx_mask = size - 1U; + *pair_mask = size - 1U; + return 0; +} + +void *online_ngram_state_create( + int token_ctx_len, + int token_table_bits, + int within_table_bits +) { + if (token_ctx_len < 0 || token_table_bits <= 0 || within_table_bits <= 0) { + return NULL; + } + OnlineNgramState *st = (OnlineNgramState *)calloc(1, sizeof(OnlineNgramState)); + if (!st) { + return NULL; + } + st->token_ctx_len = token_ctx_len; + if (token_ctx_len > 0) { + st->token_ring = (uint16_t *)xcalloc((size_t)token_ctx_len, sizeof(uint16_t)); + if (!st->token_ring) { + free(st); + return NULL; + } + } + if (alloc_tables( + (size_t)token_table_bits, + &st->token_ctx_tbl, + &st->token_ctx_used, + &st->token_ctx_mask, + &st->token_pair_tbl, + &st->token_pair_used, + &st->token_pair_mask + ) != 0) { + free(st->token_ring); + free(st); + return NULL; + } + if (alloc_tables( + (size_t)within_table_bits, + &st->within_ctx_tbl, + &st->within_ctx_used, + &st->within_ctx_mask, + &st->within_pair_tbl, + &st->within_pair_used, + &st->within_pair_mask + ) != 0) { + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); + return NULL; + } + return (void *)st; +} + +void online_ngram_state_destroy(void *ptr) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + free(st->within_pair_used); + free(st->within_pair_tbl); + free(st->within_ctx_used); + free(st->within_ctx_tbl); + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); +} + +void online_ngram_state_seed_prefix_token(void *ptr, uint16_t tok) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + token_push(st, tok); +} + +int online_ngram_state_process_chunk( + void *ptr, + const uint16_t *tokens, + int64_t n_tokens, + const uint8_t *starts_new_word_lut, + const uint8_t *boundary_lut, + uint16_t *token_top_token, + float *token_top_prob, + uint16_t *within_top_token, + float *within_top_prob, + uint8_t *within_valid +) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st || !tokens || n_tokens < 0) { + return -1; + } + for (int64_t i = 0; i < n_tokens; ++i) { + const uint16_t tok = tokens[i]; + const uint8_t is_boundary = boundary_lut[tok]; + const uint8_t is_new_word = starts_new_word_lut[tok]; + + uint64_t token_ctx_key = 0ULL; + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + token_ctx_key = token_context_hash(st); + int found = 0; + size_t idx = find_ctx_slot( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + &found + ); + if (found > 0) { + token_top_token[i] = st->token_ctx_tbl[idx].top_tok; + token_top_prob[i] = + (float)st->token_ctx_tbl[idx].top_count / (float)st->token_ctx_tbl[idx].total; + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + + uint64_t within_ctx_key = 0ULL; + if (!is_boundary && !is_new_word && st->within_len > 0U) { + within_ctx_key = st->within_hash ^ ((uint64_t)st->within_len * LEN_MIX); + int found = 0; + size_t idx = find_ctx_slot( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + &found + ); + within_valid[i] = 1U; + if (found > 0) { + within_top_token[i] = st->within_ctx_tbl[idx].top_tok; + within_top_prob[i] = + (float)st->within_ctx_tbl[idx].top_count / (float)st->within_ctx_tbl[idx].total; + } else { + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + } else { + within_valid[i] = 0U; + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + const uint64_t pair_key = token_pair_key(token_ctx_key, tok, st->token_ctx_len); + const uint32_t pair_count = pair_increment( + st->token_pair_tbl, + st->token_pair_used, + st->token_pair_mask, + pair_key + ); + if (pair_count == 0U) { + return -2; + } + if (ctx_increment( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + tok, + pair_count + ) != 0) { + return -3; + } + } + token_push(st, tok); + + if (is_boundary) { + st->within_hash = 0ULL; + st->within_len = 0U; + continue; + } + if (is_new_word || st->within_len == 0U) { + st->within_hash = extend_prefix_hash(0ULL, tok, 0U); + st->within_len = 1U; + continue; + } + const uint32_t within_pair_count = pair_increment( + st->within_pair_tbl, + st->within_pair_used, + st->within_pair_mask, + within_pair_key(within_ctx_key, tok) + ); + if (within_pair_count == 0U) { + return -4; + } + if (ctx_increment( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + tok, + within_pair_count + ) != 0) { + return -5; + } + st->within_hash = extend_prefix_hash(st->within_hash, tok, st->within_len); + st->within_len += 1U; + } + return 0; +} diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_tilt.py b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_tilt.py new file mode 100644 index 0000000000..973c21866f --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/online_ngram_tilt.py @@ -0,0 +1,386 @@ +""" +Vendored online n-gram tilt helpers from PR #1145 (AnirudhRahul, valerio-endorsed). + +Provides causal, normalized, prefix-only n-gram experts that propose at most one +hinted token per scored position. Caller obtains q_t = p(h_t | x) from the model +(post-TTT-adapt logits) and applies multiplicative-boost-with-renorm: + + p'(a) = exp(beta * 1[a == h_t]) * p(a) / Z_t + Z_t = 1 - q_t + exp(beta) * q_t = 1 + q_t * (exp(beta) - 1) + -log p'(y_realized) = -log p(y) - beta * 1[y == h_t] + log Z_t + = ptl - beta * is_hit + log1p(q_t * (exp(beta) - 1)) + +Compliance: +- C1 causal: hint h_t computed from strict prefix (tokens 0..t-1 only) +- C2 normalized over Sigma: closed-form Z_t over full vocab softmax +- C3 score-before-update: hints precomputed in single L->R pass; loss uses prefix-only +- C4 single pass: process_chunk advances state monotonically + +Compatible with both #1934/#1855 base architectures via Hyperparameter env-var gates. +""" + +from __future__ import annotations + +import ctypes +import math +import os +import subprocess +from collections import deque +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch + + +SCRIPT_DIR = Path(__file__).resolve().parent +ONLINE_NGRAM_SRC = SCRIPT_DIR / "online_ngram_state.c" +ONLINE_NGRAM_LIB = SCRIPT_DIR / "libonline_ngram_state.so" + +WHITESPACE_BYTE_IDS = {9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 36} +EDGE_PUNCT = ".,:;!?()[]{}<>\"'`" + + +def normalize_word(text: str, mode: str) -> str: + text = text.strip() + if mode == "lower": + return text.lower() + if mode == "identity": + return text + if mode == "strip_punct_lower": + return text.strip(EDGE_PUNCT).lower() + raise ValueError(f"Unknown word normalization mode: {mode}") + + +def suggest_table_bits(expected_entries: int, load_factor: float) -> int: + if expected_entries <= 0: + return 16 + target = max(int(expected_entries / max(load_factor, 1e-6)), 1) + bits = max(int(math.ceil(math.log2(target))), 12) + return min(bits, 28) + + +def ensure_online_ngram_lib(log0=print) -> ctypes.CDLL: + needs_build = (not ONLINE_NGRAM_LIB.exists()) or ( + ONLINE_NGRAM_SRC.stat().st_mtime_ns > ONLINE_NGRAM_LIB.stat().st_mtime_ns + ) + if needs_build: + log0(f"ngram_tilt:building_native_helper src={ONLINE_NGRAM_SRC.name}") + subprocess.run( + [ + "gcc", "-O3", "-march=native", "-shared", "-fPIC", + "-o", str(ONLINE_NGRAM_LIB), + str(ONLINE_NGRAM_SRC), + ], + check=True, + ) + lib = ctypes.CDLL(str(ONLINE_NGRAM_LIB)) + lib.online_ngram_state_create.restype = ctypes.c_void_p + lib.online_ngram_state_create.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int] + lib.online_ngram_state_destroy.restype = None + lib.online_ngram_state_destroy.argtypes = [ctypes.c_void_p] + lib.online_ngram_state_seed_prefix_token.restype = None + lib.online_ngram_state_seed_prefix_token.argtypes = [ctypes.c_void_p, ctypes.c_uint16] + lib.online_ngram_state_process_chunk.restype = ctypes.c_int + lib.online_ngram_state_process_chunk.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_uint16), + ctypes.c_int64, + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint8), + ] + return lib + + +class OnlineNgramState: + def __init__( + self, *, lib, token_ctx_len, token_table_bits, within_table_bits, + starts_new_word_lut, boundary_lut, seed_prefix_token, + ): + self.lib = lib + self.state = lib.online_ngram_state_create(token_ctx_len, token_table_bits, within_table_bits) + if not self.state: + raise RuntimeError( + f"Native ngram state alloc failed token_table_bits={token_table_bits} within_table_bits={within_table_bits}" + ) + self.starts_new_word_lut = np.ascontiguousarray(starts_new_word_lut.astype(np.uint8, copy=False)) + self.boundary_lut = np.ascontiguousarray(boundary_lut.astype(np.uint8, copy=False)) + self.lib.online_ngram_state_seed_prefix_token(self.state, ctypes.c_uint16(int(seed_prefix_token))) + + def close(self): + if self.state: + self.lib.online_ngram_state_destroy(self.state) + self.state = None + + def __del__(self): + self.close() + + def process_chunk(self, chunk_tokens): + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + n = int(chunk_tokens.size) + token_top_token = np.zeros(n, dtype=np.uint16) + token_top_prob = np.zeros(n, dtype=np.float32) + within_top_token = np.zeros(n, dtype=np.uint16) + within_top_prob = np.zeros(n, dtype=np.float32) + within_valid = np.zeros(n, dtype=np.uint8) + rc = self.lib.online_ngram_state_process_chunk( + self.state, + chunk_tokens.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + ctypes.c_int64(n), + self.starts_new_word_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + self.boundary_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + token_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + token_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + within_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_valid.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + ) + if rc != 0: + raise RuntimeError(f"Native ngram process_chunk failed rc={rc}") + return token_top_token, token_top_prob, within_top_token, within_top_prob, within_valid.astype(bool) + + +class WordStartState: + def __init__(self, *, sp, order, normalize_mode): + self.sp = sp + self.ctx_w = max(order - 1, 0) + self.normalize_mode = normalize_mode + self.prev_word_ids: deque = deque(maxlen=self.ctx_w) + self.current_word_tokens: list = [] + self.word_to_id: dict = {} + self.next_word_id = 1 + self.ctx_total: dict = {} + self.pair_count: dict = {} + self.ctx_best_token: dict = {} + self.ctx_best_count: dict = {} + + def _flush_current_word(self): + if not self.current_word_tokens: + return + text = normalize_word(self.sp.decode(self.current_word_tokens), self.normalize_mode) + if text: + wid = self.word_to_id.get(text) + if wid is None: + wid = self.next_word_id + self.word_to_id[text] = wid + self.next_word_id += 1 + if self.ctx_w > 0: + self.prev_word_ids.append(wid) + self.current_word_tokens = [] + + def process_chunk(self, chunk_tokens, *, starts_new_word_lut, boundary_lut): + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + top_token = np.zeros(chunk_tokens.size, dtype=np.uint16) + top_prob = np.zeros(chunk_tokens.size, dtype=np.float32) + for i, tok_u16 in enumerate(chunk_tokens): + tok = int(tok_u16) + is_boundary = bool(boundary_lut[tok]) + is_word_start = bool(starts_new_word_lut[tok]) or not self.current_word_tokens + if is_boundary: + self._flush_current_word() + continue + if bool(starts_new_word_lut[tok]): + self._flush_current_word() + ctx_key = None + if is_word_start and len(self.prev_word_ids) >= self.ctx_w: + ctx_key = tuple(self.prev_word_ids) if self.ctx_w > 0 else () + total = self.ctx_total.get(ctx_key, 0) + if total > 0: + top_token[i] = np.uint16(self.ctx_best_token[ctx_key]) + top_prob[i] = np.float32(self.ctx_best_count[ctx_key] / total) + if is_word_start: + if ctx_key is not None: + pair_key = (ctx_key, tok) + pair = self.pair_count.get(pair_key, 0) + 1 + self.pair_count[pair_key] = pair + total = self.ctx_total.get(ctx_key, 0) + 1 + self.ctx_total[ctx_key] = total + best_count = self.ctx_best_count.get(ctx_key, 0) + if pair > best_count: + self.ctx_best_count[ctx_key] = pair + self.ctx_best_token[ctx_key] = tok + self.current_word_tokens = [tok] + else: + self.current_word_tokens.append(tok) + return top_token, top_prob + + +def build_piece_luts(*, tokenizer_path, vocab_size): + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + pieces = [sp.id_to_piece(i) for i in range(sp.vocab_size())] + starts_new_word_lut = np.zeros(vocab_size, dtype=np.uint8) + for i, piece in enumerate(pieces): + starts_new_word_lut[i] = 1 if piece.startswith("▁") else 0 + boundary_lut = np.zeros(vocab_size, dtype=np.uint8) + bos_id = sp.bos_id() + if bos_id >= 0 and bos_id < vocab_size: + boundary_lut[bos_id] = 1 + for tok in range(min(sp.vocab_size(), vocab_size)): + if sp.is_byte(tok) and tok in WHITESPACE_BYTE_IDS: + boundary_lut[tok] = 1 + return sp, starts_new_word_lut, boundary_lut + + +def build_hints_for_targets( + *, target_token_ids_np, tokenizer_path, vocab_size, log0=print, + token_order=16, token_threshold=0.800, token_boost=2.625, + within_tau=0.450, within_boost=0.750, + word_order=4, word_normalize="strip_punct_lower", + word_tau=0.650, word_boost=0.750, + agree_add_boost=0.500, +): + """Single L->R pass. Returns dict with hint_ids, gate_mask, boost_per_pos. + + target_token_ids_np: np.uint16 array of realized targets (length = total_targets). + Output arrays are aligned to target_token_ids_np indexing. + + For each scored position t we pick at most one hint h_t: + - prefer the expert with highest expected gain = p_top * boost - log1p(p_top * (exp(boost)-1)) + - if multiple experts agree on the same h_t, additive boost agree_add_boost + - gate (don't tilt) when no expert clears its threshold + + The realized loss formula used by the caller: + ptl' = ptl - beta * 1[y == h_t] + log1p(q_t * (exp(beta) - 1)) when gate_mask == True + ptl' = ptl when gate_mask == False + """ + sp, starts_new_word_lut, boundary_lut = build_piece_luts( + tokenizer_path=tokenizer_path, vocab_size=vocab_size + ) + total = int(target_token_ids_np.size) + if total == 0: + return { + "hint_ids": np.zeros(0, dtype=np.int64), + "gate_mask": np.zeros(0, dtype=bool), + "boost": np.zeros(0, dtype=np.float32), + "sp": sp, + "starts_new_word_lut": starts_new_word_lut, + "boundary_lut": boundary_lut, + } + + token_table_bits = suggest_table_bits(total, load_factor=0.55) + within_table_bits = suggest_table_bits(max(total // 2, 1), load_factor=0.60) + online_lib = ensure_online_ngram_lib(log0) + ngram_state = OnlineNgramState( + lib=online_lib, + token_ctx_len=max(token_order - 1, 0), + token_table_bits=token_table_bits, + within_table_bits=within_table_bits, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + seed_prefix_token=int(target_token_ids_np[0]), + ) + word_state = WordStartState(sp=sp, order=word_order, normalize_mode=word_normalize) + + token_top_tok, token_top_prob, within_top_tok, within_top_prob, within_valid = ( + ngram_state.process_chunk(target_token_ids_np) + ) + word_top_tok, word_top_prob = word_state.process_chunk( + target_token_ids_np, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + ) + + def _expected_gain(p_top, boost): + # E[ -log p'(y) under -log p(y)] when y ~ p + # = p_top * boost - log1p(p_top * (exp(boost) - 1)) + # Maximizing this over experts => pick the most informative hint. + log_norm = np.log1p(p_top * (math.exp(boost) - 1.0)) + return p_top * boost - log_norm + + token_gate = token_top_prob >= np.float32(token_threshold) + within_gate = within_valid & (within_top_prob >= np.float32(within_tau)) + word_gate = word_top_prob >= np.float32(word_tau) + + token_gain = np.where(token_gate, _expected_gain(token_top_prob.astype(np.float64), token_boost), -np.inf) + within_gain = np.where(within_gate, _expected_gain(within_top_prob.astype(np.float64), within_boost), -np.inf) + word_gain = np.where(word_gate, _expected_gain(word_top_prob.astype(np.float64), word_boost), -np.inf) + + stack = np.stack([token_gain, within_gain, word_gain], axis=1) + best_idx = np.argmax(stack, axis=1) + best_gain = np.max(stack, axis=1) + any_gate = best_gain > -np.inf + + hint_ids = np.zeros(total, dtype=np.int64) + boost = np.zeros(total, dtype=np.float32) + base_boost_per_expert = np.array([token_boost, within_boost, word_boost], dtype=np.float32) + hint_per_expert = np.stack([ + token_top_tok.astype(np.int64), + within_top_tok.astype(np.int64), + word_top_tok.astype(np.int64), + ], axis=1) + + rows = np.arange(total) + hint_ids[any_gate] = hint_per_expert[rows[any_gate], best_idx[any_gate]] + boost[any_gate] = base_boost_per_expert[best_idx[any_gate]] + + # Agreement bonus: if 2+ experts agree on the same hint as best, add agree_add_boost + gate_mask_each = np.stack([token_gate, within_gate, word_gate], axis=1) + expert_hints = hint_per_expert.copy() + expert_hints[~gate_mask_each] = -1 + agreements = (expert_hints == hint_ids[:, None]).sum(axis=1) + agreement_extra = np.where(agreements >= 2, np.float32(agree_add_boost), np.float32(0.0)) + boost = (boost + agreement_extra).astype(np.float32) + + log0( + f"ngram_tilt:hints total={total} gated={int(any_gate.sum())} " + f"token_gate={int(token_gate.sum())} within_gate={int(within_gate.sum())} word_gate={int(word_gate.sum())} " + f"agree2plus={int((agreements >= 2).sum())}" + ) + + return { + "hint_ids": hint_ids, + "gate_mask": any_gate, + "boost": boost, + "sp": sp, + "starts_new_word_lut": starts_new_word_lut, + "boundary_lut": boundary_lut, + } + + +def apply_tilt_to_ptl_torch( + ptl: torch.Tensor, + log_q_hint: torch.Tensor, + target_ids: torch.Tensor, + hint_ids: torch.Tensor, + gate_mask: torch.Tensor, + boost: torch.Tensor, +): + """Closed-form tilt applied to per-token NLL. + + All tensors same shape [..., L]. + ptl_tilted = ptl - beta * 1[y == h] + log1p(q * (exp(beta) - 1)) if gate else ptl + """ + boost64 = boost.to(torch.float64) + q = log_q_hint.to(torch.float64).clamp_(max=0.0).exp() + is_hit = (target_ids == hint_ids).to(torch.float64) + log_Z = torch.log1p(q * (torch.expm1(boost64))) + ptl_tilted = ptl.to(torch.float64) - boost64 * is_hit + log_Z + return torch.where(gate_mask, ptl_tilted, ptl.to(torch.float64)).to(ptl.dtype) + + +def apply_tilt_to_ptl_torch_fast( + ptl: torch.Tensor, + log_q_hint: torch.Tensor, + target_ids: torch.Tensor, + hint_ids: torch.Tensor, + gate_mask: torch.Tensor, + boost: torch.Tensor, +): + """fp32 variant of apply_tilt — cast removed where safe. + + BPB downstream accumulator is fp64, so per-token tilt computation in + fp32 has no impact on final precision. Saves ~10-15s per eval pass on + H100 (avoids fp64 ALU + double memory traffic). + """ + boost32 = boost.to(torch.float32) + q = log_q_hint.to(torch.float32).clamp_(max=0.0).exp() + is_hit = (target_ids == hint_ids).to(torch.float32) + log_Z = torch.log1p(q * (torch.expm1(boost32))) + ptl_f32 = ptl.to(torch.float32) + ptl_tilted = ptl_f32 - boost32 * is_hit + log_Z + return torch.where(gate_mask, ptl_tilted, ptl_f32).to(ptl.dtype) diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/prepare_caseops_data.py b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/prepare_caseops_data.py new file mode 100644 index 0000000000..e3a6123b68 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/prepare_caseops_data.py @@ -0,0 +1,194 @@ +"""Prepare CaseOps-tokenized FineWeb shards + per-token byte sidecar. + +CaseOps (``lossless_caps_caseops_v1``) is a bijective, character-level text +transform that introduces four operator tokens in place of explicit +capitalization: TITLE, ALLCAPS, CAPNEXT, ESC. The transform is fully +reversible — no information is lost relative to the untransformed UTF-8 +text, so BPB stays computable on TRUE byte counts. + +Forward pipeline: + 1. Read the canonical FineWeb-10B doc stream (``docs_selected.jsonl`` + produced by ``data/download_hf_docs_and_tokenize.py`` in the root repo). + 2. Apply ``encode_lossless_caps_v2`` (the caseops_v1 alias) to each doc. + 3. Tokenize with the shipped SP model + ``tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model`` + (reserves TITLE/ALLCAPS/CAPNEXT/ESC + sentinel as user_defined_symbols). + 4. Write uint16 train/val shards (``fineweb_{train,val}_XXXXXX.bin``). + 5. For the VAL stream only, emit per-token byte sidecar shards + (``fineweb_val_bytes_XXXXXX.bin``, uint16 parallel arrays) that record + each token's ORIGINAL pre-transform UTF-8 byte count. BPB is computed + from these canonical bytes so the score is on the untransformed text + (not the transformed representation). + +Output layout — matches what ``train_gpt.py`` expects under +``DATA_DIR=./data`` with ``CASEOPS_ENABLED=1``: + + data/datasets/fineweb10B_sp8192_caseops/datasets/ + tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/ + fineweb_train_000000.bin + fineweb_train_000001.bin + ... + fineweb_val_000000.bin + fineweb_val_bytes_000000.bin + +Usage: + + python3 prepare_caseops_data.py \\ + --docs ./fineweb10B_raw/docs_selected.jsonl \\ + --out ./data/datasets/fineweb10B_sp8192_caseops/datasets \\ + --sp ./tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + +Requirements: sentencepiece, numpy. CPU-only. Runs once; reused across seeds. +""" +from __future__ import annotations + +import argparse +import json +import pathlib +import struct +import sys + +import numpy as np +import sentencepiece as spm + +# Local import — lossless_caps.py ships next to this script. +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent)) +from lossless_caps import ( # noqa: E402 + LOSSLESS_CAPS_CASEOPS_V1, + encode_lossless_caps_v2, + surface_piece_original_byte_counts, +) + + +SHARD_MAGIC = 20240520 +SHARD_VERSION = 1 +SHARD_TOKENS = 10_000_000 # tokens per shard — matches the main pipeline +BOS_ID = 1 # SP model's control token; train_gpt.py:_find_docs requires BOS per doc + + +def _write_shard(out_path: pathlib.Path, arr: np.ndarray) -> None: + """Write a uint16 shard in the standard header-prefixed format.""" + assert arr.dtype == np.uint16 + header = np.zeros(256, dtype=np.int32) + header[0] = SHARD_MAGIC + header[1] = SHARD_VERSION + header[2] = int(arr.size) + with out_path.open("wb") as fh: + fh.write(header.tobytes()) + fh.write(arr.tobytes()) + + +def _iter_docs(docs_path: pathlib.Path, skip: int = 0): + """Yield doc strings from a jsonl file (one json object per line). + + If skip > 0, the first ``skip`` non-empty lines are discarded before + yielding begins. Use this to continue a prep run from a known offset + (e.g. after already writing N shards from the first M docs). + """ + with docs_path.open("r", encoding="utf-8") as fh: + skipped = 0 + for line in fh: + line = line.strip() + if not line: + continue + if skipped < skip: + skipped += 1 + continue + obj = json.loads(line) + # Support both {"text": ...} and raw strings. + yield obj["text"] if isinstance(obj, dict) else obj + + +def _token_original_byte_counts( + sp: spm.SentencePieceProcessor, + original_text: str, + transformed_text: str, +) -> np.ndarray: + """Per-token canonical (pre-transform) UTF-8 byte counts. + + Delegates to ``surface_piece_original_byte_counts`` in ``lossless_caps.py`` + — the canonical exporter used by the PR #1729 / HF-hosted CaseOps dataset. + Operator pieces (U+E001..U+E004) contribute 0 original bytes; letter pieces + contribute their pre-transform UTF-8 byte count. + """ + proto = sp.encode_as_immutable_proto(transformed_text) + byte_counts = surface_piece_original_byte_counts( + (piece.surface for piece in proto.pieces), + text_transform_name=LOSSLESS_CAPS_CASEOPS_V1, + ) + return np.asarray(list(byte_counts), dtype=np.uint16) + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--docs", required=True, type=pathlib.Path, help="Path to docs_selected.jsonl") + ap.add_argument("--out", required=True, type=pathlib.Path, help="Output datasets dir") + ap.add_argument("--sp", required=True, type=pathlib.Path, help="Path to CaseOps SP model") + ap.add_argument("--val-docs", type=int, default=10_000, help="Validation docs count") + ap.add_argument("--skip-docs", type=int, default=0, + help="Skip first N docs before processing (use to continue a partial run)") + ap.add_argument("--start-shard-train", type=int, default=0, + help="Start train shard numbering from this index (use to append to existing shards)") + ap.add_argument("--max-docs", type=int, default=0, + help="Stop after processing this many docs (0 = no limit)") + args = ap.parse_args() + + sp = spm.SentencePieceProcessor(model_file=str(args.sp)) + print(f"loaded sp: vocab={sp.vocab_size()}", flush=True) + + train_out = args.out / "datasets" / "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved" + train_out.mkdir(parents=True, exist_ok=True) + + val_buf_tokens: list[int] = [] + val_buf_bytes: list[int] = [] + train_buf: list[int] = [] + val_written = 0 + train_written = args.start_shard_train + n_docs = 0 + + for text in _iter_docs(args.docs, skip=args.skip_docs): + transformed = encode_lossless_caps_v2(text) + token_ids = [BOS_ID] + sp.encode(transformed, out_type=int) + if n_docs < args.val_docs: + # Validation doc — also compute byte sidecar + byte_counts = _token_original_byte_counts(sp, text, transformed) + val_buf_tokens.extend(token_ids) + val_buf_bytes.append(0) # BOS contributes 0 original bytes + val_buf_bytes.extend(int(b) for b in byte_counts) + if len(val_buf_tokens) >= SHARD_TOKENS: + _write_shard(train_out / f"fineweb_val_{val_written:06d}.bin", + np.array(val_buf_tokens[:SHARD_TOKENS], dtype=np.uint16)) + _write_shard(train_out / f"fineweb_val_bytes_{val_written:06d}.bin", + np.array(val_buf_bytes[:SHARD_TOKENS], dtype=np.uint16)) + val_buf_tokens = val_buf_tokens[SHARD_TOKENS:] + val_buf_bytes = val_buf_bytes[SHARD_TOKENS:] + val_written += 1 + else: + train_buf.extend(token_ids) + if len(train_buf) >= SHARD_TOKENS: + _write_shard(train_out / f"fineweb_train_{train_written:06d}.bin", + np.array(train_buf[:SHARD_TOKENS], dtype=np.uint16)) + train_buf = train_buf[SHARD_TOKENS:] + train_written += 1 + n_docs += 1 + if n_docs % 10_000 == 0: + print(f" processed {n_docs} docs train_shards={train_written} val_shards={val_written}", flush=True) + if args.max_docs > 0 and n_docs >= args.max_docs: + break + + # Flush tail buffers into final (possibly short) shards. + if val_buf_tokens: + _write_shard(train_out / f"fineweb_val_{val_written:06d}.bin", + np.array(val_buf_tokens, dtype=np.uint16)) + _write_shard(train_out / f"fineweb_val_bytes_{val_written:06d}.bin", + np.array(val_buf_bytes, dtype=np.uint16)) + if train_buf: + _write_shard(train_out / f"fineweb_train_{train_written:06d}.bin", + np.array(train_buf, dtype=np.uint16)) + + print(f"done. docs={n_docs} train_shards={train_written + (1 if train_buf else 0)} val_shards={val_written + (1 if val_buf_tokens else 0)}") + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/requirements.txt b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/requirements.txt new file mode 100644 index 0000000000..b6c55e13aa --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/requirements.txt @@ -0,0 +1,13 @@ +# Python deps. Install with: pip install -r requirements.txt +torch==2.9.1+cu128 +sentencepiece +brotli +huggingface_hub +numpy +python-minifier + +# FlashAttention 3 must be installed separately (not on PyPI): +# pip install --no-deps flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/ + +# System dep (apt): lrzip (used by per-group compressor) +# apt-get install -y lrzip diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/submission.json b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/submission.json new file mode 100644 index 0000000000..5d4cd890bc --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/submission.json @@ -0,0 +1,10 @@ +{ + "track": "non_record_16mb", + "date": "2026-05-03", + "name": "SP8192 + RandProj384 tied embeddings + Pairwise-QK Muon", + "author": "YaseenHQ", + "github_id": "YaseenHQ", + "val_bpb": 1.1269, + "val_loss": 2.4662, + "bytes_total": 15438770 +} diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_gpt.py b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_gpt.py new file mode 100644 index 0000000000..c767c5dc5c --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_gpt.py @@ -0,0 +1,4666 @@ +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +_VALID_ATTN_BACKENDS = {"fa3", "local_global", "delta_hybrid"} +_LEAKY_RELU_SQ_SLOPE = float(os.environ.get("LEAKY_RELU_SQ_SLOPE", "0.3")) + + +def normalize_attn_backend(name: str) -> str: + backend = name.strip().lower() + if backend not in _VALID_ATTN_BACKENDS: + valid = ", ".join(sorted(_VALID_ATTN_BACKENDS)) + raise ValueError(f"ATTN_BACKEND must be one of {{{valid}}}, got {name!r}") + return backend + + +def parse_local_global_pattern(pattern: str) -> list[str]: + pattern = pattern.strip() + if not pattern: + return [] + if "," in pattern or " " in pattern: + raw_tokens = [tok for tok in re.split(r"[\s,]+", pattern) if tok] + else: + raw_tokens = list(pattern) + modes = [] + for token in raw_tokens: + token_l = token.strip().lower() + if token_l in {"l", "local"}: + modes.append("local") + elif token_l in {"g", "global"}: + modes.append("global") + else: + raise ValueError( + f"Unsupported LOCAL_GLOBAL_PATTERN token {token!r}; use L/G or local/global" + ) + if not modes: + raise ValueError("LOCAL_GLOBAL_PATTERN must contain at least one layer mode") + return modes + + +def local_global_mode_for_layer( + layer_idx: int, + num_layers: int, + global_every_n: int, + pattern: list[str], +) -> str: + if pattern: + return pattern[layer_idx % len(pattern)] + if global_every_n <= 0: + return "global" + if layer_idx == num_layers - 1: + return "global" + return "global" if (layer_idx + 1) % global_every_n == 0 else "local" + + +def make_random_orthoproject(latent_dim: int, model_dim: int, seed: int) -> torch.Tensor: + """Return a fixed random orthoproject P with orthonormal rows. + + Shape is [latent_dim, model_dim]. This supports a JL-style tied embedding path: + token embeddings live in the smaller latent space, then project up via P; logits + project hidden states back down via P.T before the tied output matmul. + """ + if latent_dim <= 0 or model_dim <= 0: + raise ValueError( + f"Random projection dims must be positive, got latent_dim={latent_dim}, model_dim={model_dim}" + ) + if latent_dim > model_dim: + raise ValueError( + f"Random projection requires latent_dim <= model_dim, got {latent_dim} > {model_dim}" + ) + g = torch.Generator(device="cpu") + g.manual_seed(seed) + # QR on [model_dim, latent_dim] gives orthonormal columns; transpose to get + # orthonormal rows in [latent_dim, model_dim]. + base = torch.randn(model_dim, latent_dim, generator=g, dtype=torch.float32) + q, _ = torch.linalg.qr(base, mode="reduced") + return q.transpose(0, 1).contiguous() + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + random_proj_embed_enabled = bool(int(os.environ.get("RANDOM_PROJ_EMBED_ENABLED", "1"))) + random_proj_embed_dim = int(os.environ.get("RANDOM_PROJ_EMBED_DIM", 384)) + random_proj_seed = int(os.environ.get("RANDOM_PROJ_SEED", "1337")) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + asym_logit_rescale = bool(int(os.environ.get("ASYM_LOGIT_RESCALE", "0"))) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + attn_backend = os.environ.get("ATTN_BACKEND", "fa3").strip().lower() + local_attn_window = int(os.environ.get("LOCAL_ATTN_WINDOW", 1024)) + global_every_n = int(os.environ.get("GLOBAL_EVERY_N", 6)) + local_global_pattern = os.environ.get("LOCAL_GLOBAL_PATTERN", "").strip() + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25)) + # Layer 1: per-layer QK-Gain init schedule. Comma-separated floats, one per physical + # layer. Falls back to uniform qk_gain_init if empty. Schedule is an initialization + # choice — q_gain remains trainable and evolves from this starting point. + _qk_sched_raw = os.environ.get("QK_GAIN_INIT_SCHEDULE", "") + qk_gain_schedule = ( + [float(x) for x in _qk_sched_raw.split(",") if x.strip()] + if _qk_sched_raw.strip() else [] + ) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + muon_qk_pair_ortho = bool(int(os.environ.get("MUON_QK_PAIR_ORTHO", "1"))) + muon_qk_pair_size = int(os.environ.get("MUON_QK_PAIR_SIZE", "2")) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + # Layer 3: AdamHD Huber weight decay for Muon. Replaces L2 WD with Huber regularizer: + # quadratic below |w| <= delta (standard L2 behavior), linear above (clips large-weight + # gradient contribution). Specifically suppresses outlier weights that dominate int6 + # quantization error. delta=0 falls back to standard L2 decay. + muon_huber_wd = bool(int(os.environ.get("MUON_HUBER_WD", "0"))) + muon_huber_delta = float(os.environ.get("MUON_HUBER_DELTA", "0.1")) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_grad_steps_clean = bool(int(os.environ.get("TTT_GRAD_STEPS_CLEAN", "0"))) + # PR #1145 online n-gram tilt (AnirudhRahul, valerio-endorsed). Causal, + # normalized, prefix-only experts; closed-form multiplicative-boost-with-renorm + # applied to per-token NLL at scoring time only. See online_ngram_tilt.py. + ngram_tilt_enabled = bool(int(os.environ.get("NGRAM_TILT_ENABLED", "0"))) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + # within-doc and word-start experts gate on target_i properties (is_new_word, + # is_boundary applied to the token being SCORED), which violates C1 causality. + # Defaults 99.0 ensure their gates never fire. Token-only is the legal subset + # (confirmed by PR #1514 merge precedent). + within_tau = float(os.environ.get("WITHIN_TAU", "99.0")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.0")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "99.0")) + word_boost = float(os.environ.get("WORD_BOOST", "0.0")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + ngram_hint_precompute_outside = bool(int(os.environ.get("NGRAM_HINT_PRECOMPUTE_OUTSIDE", "1"))) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 1.0)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999)) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "brotli") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 4.0)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 1)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + # Layer 4: LaCT global TTT optimizer. "sgd" = original SGD (default). "muon" = Muon + # Newton-Schulz orthogonalized updates for matrix params, SGD for scalars/embeddings. + # LaCT (arxiv 2505.23884) uses large document chunks + Muon fast-weight updates to + # achieve 70% GPU utilization vs <5% for per-token TTT, enabling more TTT epochs. + global_ttt_optimizer = os.environ.get("GLOBAL_TTT_OPTIMIZER", "sgd") + global_ttt_muon_ns_steps = int(os.environ.get("GLOBAL_TTT_MUON_NS_STEPS", "5")) + global_ttt_muon_nesterov = bool(int(os.environ.get("GLOBAL_TTT_MUON_NESTEROV", "1"))) + # Layer 2: OptRot pre-quantization Hadamard rotation. Rotates (W_up, W_down) and + # (W_v, W_o) pairs via Hadamard matrices, redistributing outlier weights before GPTQ. + # Orthogonal transformation: model outputs are preserved exactly; quantization error + # is reduced 30-50% (paper: arxiv 2512.24124). Zero artifact cost (fused into weights). + optrot_enabled = bool(int(os.environ.get("OPTROT_ENABLED", "0"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 8)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 10.0)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "0"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "0"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "0"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 1.0)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + lqer_top_k = int(os.environ.get("LQER_TOP_K", 3)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "0"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + # CaseOps: when enabled, load per-token byte sidecar and stash it as a + # CPU tensor aligned 1:1 with self.val_tokens. eval_val/eval_val_ttt + # branches use this as the canonical raw-byte budget per token. + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._next_batch = None + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = buf[:-1].to(dtype=torch.int64).pin_memory() + targets = buf[1:].to(dtype=torch.int64).pin_memory() + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + if self._next_batch is not None: + inputs, targets, cu_seqlens, max_seqlen = self._next_batch.result() + else: + inputs, targets, cu_seqlens, max_seqlen = self._prepare_batch( + num_tokens_local, self.max_seq_len + ) + self._next_batch = self._batch_pool.submit( + self._prepare_batch, num_tokens_local, self.max_seq_len + ) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class DocStartSequenceLoader: + """GPTQ calibration loader yielding windows that begin at BOS positions. + + Matches ShuffledSequenceLoader.next_batch() interface exactly. + Used when GPTQ_CALIBRATION_MODE=doc_start to align calibration distribution + with eval (which processes document-structured data with BOS-prepended contexts). + """ + _N_SCAN_SHARDS = 8 + + def __init__(self, h, device, bos_id=1): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + rank_files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank + 7919)) + scan_files = rank_files[: self._N_SCAN_SHARDS] + self._windows = [] # (path, start_offset) pairs starting at BOS + t0 = time.perf_counter() + for path in scan_files: + mm = _get_shard_memmap(path) + n = len(mm) + positions = np.where(mm[:] == np.uint16(bos_id))[0] + valid = positions[positions + self.seq_len + 1 <= n] + for pos in valid.tolist(): + self._windows.append((path, pos)) + log(f"DocStartLoader: {len(self._windows)} BOS windows from {len(scan_files)} shards in {time.perf_counter()-t0:.1f}s") + if not self._windows: + raise RuntimeError("DocStartSequenceLoader: no valid BOS windows found — check BOS_ID and shard files") + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + idxs = self.rng.integers(0, len(self._windows), size=device_batch_size) + for bi, idx in enumerate(idxs): + path, start = self._windows[int(idx)] + mm = _get_shard_memmap(path) + window = torch.as_tensor( + np.array(mm[start : start + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + neg_slope, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + neg_grad = 2.0 * neg_slope * neg_slope + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, neg_grad * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, neg_grad * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, neg_slope * c0) + aux1 = tl.where(c1 > 0, c1, neg_slope * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 256, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + float(_LEAKY_RELU_SQ_SLOPE), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +@triton.jit +def fused_log_softmax_dual_gather_kernel( + logits_ptr, + target_ids_ptr, + hint_ids_ptr, + log_p_y_out_ptr, + log_q_h_out_ptr, + BT, + V, + BLOCK_V: tl.constexpr, +): + """Single pass over [BT, V] logits; extracts log p(target) and log p(hint).""" + pid = tl.program_id(0) + if pid >= BT: + return + target = tl.load(target_ids_ptr + pid) + hint = tl.load(hint_ids_ptr + pid) + row_offset = pid * V + target_logit = tl.load(logits_ptr + row_offset + target).to(tl.float32) + hint_logit = tl.load(logits_ptr + row_offset + hint).to(tl.float32) + NEG_INF = float("-inf") + max_val = NEG_INF + for v_start in tl.range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + chunk = tl.load(logits_ptr + row_offset + v_offsets, mask=mask, other=NEG_INF).to(tl.float32) + max_val = tl.maximum(max_val, tl.max(chunk, axis=0)) + sum_exp = tl.zeros((), dtype=tl.float32) + for v_start in tl.range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + chunk = tl.load(logits_ptr + row_offset + v_offsets, mask=mask, other=0.0).to(tl.float32) + sum_exp += tl.sum(tl.where(mask, tl.exp(chunk - max_val), 0.0), axis=0) + log_sum_exp = max_val + tl.log(sum_exp) + tl.store(log_p_y_out_ptr + pid, target_logit - log_sum_exp) + tl.store(log_q_h_out_ptr + pid, hint_logit - log_sum_exp) + + +def fused_log_softmax_dual_gather(logits, target_ids, hint_ids): + """Returns (log_p_y, log_q_h) where p = softmax(logits). No backward needed.""" + bsz, sl, V = logits.shape + BT = bsz * sl + logits_flat = logits.reshape(BT, V).contiguous() + log_p_y_out = torch.empty(BT, dtype=torch.float32, device=logits.device) + log_q_h_out = torch.empty(BT, dtype=torch.float32, device=logits.device) + fused_log_softmax_dual_gather_kernel[(BT,)]( + logits_flat, + target_ids.reshape(BT).contiguous(), + hint_ids.reshape(BT).contiguous(), + log_p_y_out, + log_q_h_out, + BT, V, BLOCK_V=1024, num_warps=8, + ) + return log_p_y_out.reshape(bsz, sl), log_q_h_out.reshape(bsz, sl) + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + self.attn_backend = "fa3" + self.local_attn_window = 1024 + self.local_global_mode = "global" + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def _resolve_window_size(self): + backend = normalize_attn_backend(self.attn_backend) + if backend == "fa3": + return (-1, -1) + if backend == "delta_hybrid": + raise NotImplementedError( + "ATTN_BACKEND=delta_hybrid is reserved for a later prototype" + ) + if self.local_global_mode == "global": + return (-1, -1) + if self.local_global_mode != "local": + raise ValueError(f"Unknown local/global attention mode {self.local_global_mode!r}") + if self.local_attn_window <= 0: + raise ValueError( + f"LOCAL_ATTN_WINDOW must be positive for local attention, got {self.local_attn_window}" + ) + return (self.local_attn_window - 1, 0) + + def _run_attention(self, q, k, v, cu_seqlens=None, max_seqlen=0): + window_size = self._resolve_window_size() + if cu_seqlens is not None: + return flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=window_size, + )[None] + return flash_attn_3_func(q, k, v, causal=True, window_size=window_size) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = self._run_attention(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = os.environ.get("FUSED_MLP_ENABLED", "1") == "1" + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu( + F.linear(x, up_w.to(x.dtype)), + negative_slope=_LEAKY_RELU_SQ_SLOPE, + ).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + attn_backend = normalize_attn_backend(h.attn_backend) + if attn_backend == "local_global" and h.local_attn_window <= 0: + raise ValueError( + f"LOCAL_ATTN_WINDOW must be positive for local_global attention, got {h.local_attn_window}" + ) + local_global_pattern = parse_local_global_pattern(h.local_global_pattern) + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.model_dim = h.model_dim + self.random_proj_embed_enabled = bool(h.random_proj_embed_enabled) + self.random_proj_embed_dim = h.random_proj_embed_dim + self.logit_softcap = h.logit_softcap + self.asym_logit_enabled = h.asym_logit_rescale + if self.asym_logit_enabled: + self.softcap_pos = nn.Parameter(torch.tensor(h.logit_softcap)) + self.softcap_neg = nn.Parameter(torch.tensor(h.logit_softcap)) + self.fused_ce_enabled = bool(h.fused_ce_enabled) + if self.random_proj_embed_enabled and not self.tie_embeddings: + raise ValueError("RANDOM_PROJ_EMBED_ENABLED requires TIE_EMBEDDINGS=1") + embed_dim = ( + h.random_proj_embed_dim if self.random_proj_embed_enabled else h.model_dim + ) + self.tok_emb = nn.Embedding(h.vocab_size, embed_dim) + if self.random_proj_embed_enabled: + proj = make_random_orthoproject(embed_dim, h.model_dim, h.random_proj_seed) + self.register_buffer("tok_emb_proj", proj, persistent=False) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + # Layer 1: per-layer QK-Gain init schedule. Use schedule[i] if provided, + # else uniform qk_gain_init. Schedule is initialization only — q_gain + # stays trainable and diverges from this starting point during training. + (h.qk_gain_schedule[i] if h.qk_gain_schedule and i < len(h.qk_gain_schedule) + else h.qk_gain_init), + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + for i, block in enumerate(self.blocks): + block.attn.attn_backend = attn_backend + block.attn.local_attn_window = h.local_attn_window + block.attn.local_global_mode = local_global_mode_for_layer( + i, h.num_layers, h.global_every_n, local_global_pattern + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _embed_input(self, input_ids): + x = self.tok_emb(input_ids) + if self.random_proj_embed_enabled: + x = x @ self.tok_emb_proj.to(dtype=x.dtype) + return x + + def _project_hidden_to_tied_space(self, hidden): + if self.random_proj_embed_enabled: + return hidden @ self.tok_emb_proj.transpose(0, 1).to(dtype=hidden.dtype) + return hidden + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self._embed_input(input_ids) + # SmearGate (PR #1667). Inline gate compute with .contiguous() on the slice fed + # to the projection so torch.compile fullgraph is happy. lam=0 + W=0 -> identity + # at init. This block runs unconditionally on the smear path; the cat keeps + # position 0 untouched so causality holds. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + if os.environ.get("SMEAR_GATE_BOS_FIX", "0") == "1": + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + else: + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1]], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + tied_hidden = self._project_hidden_to_tied_space(hidden) + return F.linear(tied_hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def _apply_asym_softcap(self, logits): + """Asymmetric softcap: independent pos/neg learned scalars (PR #1923). + Init: softcap_pos == softcap_neg == logit_softcap → identical to scalar at step 0.""" + sp = self.softcap_pos.to(logits.dtype) + sn = self.softcap_neg.to(logits.dtype) + return torch.where(logits >= 0, + sp * torch.tanh(logits / sp), + sn * torch.tanh(logits / sn)) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + if self.asym_logit_enabled: + return self._apply_asym_softcap(logits_proj) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + return softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + + def forward_ttt(self, input_ids, target_ids, lora, hint_ids=None): + x = self._embed_input(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + if os.environ.get("SMEAR_GATE_BOS_FIX", "0") == "1": + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + else: + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1]], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if hint_ids is None: + logits_proj = self._project_logits(x) + lora.lm_head_lora(x) + bsz, sl, V = logits_proj.shape + if not self.asym_logit_enabled: + return softcapped_cross_entropy( + logits_proj.reshape(-1, V), + target_ids.reshape(-1), + self.logit_softcap, + reduction="none", + ).reshape(bsz, sl) + logits = self._apply_asym_softcap(logits_proj) + # Avoid materializing a full float32 log-softmax tensor during TTT. + log_z = torch.logsumexp(logits, dim=-1).float() + target_logit = logits.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1).float() + return log_z - target_logit + logits = self._project_logits(x) + logits = logits + lora.lm_head_lora(x) + if self.asym_logit_enabled: + logits = self._apply_asym_softcap(logits) + else: + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + # PR #1145 tilt branch: return (per_tok_loss, log_q_hint) for scoring. + # TTT training backward uses requires_grad path (plain log_softmax); scoring + # uses Triton fused kernel (no autograd needed, saves memory + time). + if logits.requires_grad: + # Compute gathered log-probs via logsumexp so we avoid storing the + # full [bsz, seqlen, vocab] float32 log-softmax activation. + log_z = torch.logsumexp(logits, dim=-1).float() + log_p_y = ( + logits.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1).float() - log_z + ) + log_q_h = ( + logits.gather(-1, hint_ids.clamp(min=0).unsqueeze(-1)).squeeze(-1).float() - log_z + ) + return -log_p_y, log_q_h + log_p_y, log_q_h = fused_log_softmax_dual_gather( + logits, target_ids, hint_ids.clamp(min=0) + ) + return -log_p_y, log_q_h + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = attn._run_attention(q, k, v) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = attn._run_attention(q, k, v) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + self.v_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +def _apply_huber_wd(w, lr, wd, delta): + """Layer 3: Huber weight decay in-place. + Gradient: w (|w|<=delta), delta*sign(w) (|w|>delta). + Transitions from quadratic (L2-like) to linear (L1-like) at |w|=delta. + Bounds the decay rate on outlier weights, preventing GPTQ-damaging over-suppression. + """ + wf = w.float() + abs_w = wf.abs() + grad = torch.where(abs_w <= delta, wf, delta * wf.sign()) + w.add_(grad.to(w.dtype), alpha=-lr * wd) + + +def _muon_zeropower_by_head_pairs( + G: torch.Tensor, + num_heads: int, + head_dim: int, + pair_size: int, + steps: int, +): + if G.ndim != 2: + raise ValueError(f"Expected a 2D matrix for pairwise Muon orthogonalization, got {tuple(G.shape)}") + if pair_size <= 1 or num_heads % pair_size != 0 or G.shape[0] != num_heads * head_dim: + return zeropower_via_newtonschulz5(G, steps=steps) + grouped = G.reshape(num_heads // pair_size, pair_size * head_dim, G.shape[1]) + grouped = zeropower_via_newtonschulz5(grouped, steps=steps) + return grouped.reshape_as(G) + + +def _muon_specialized_zeropower( + update: torch.Tensor, + layout: dict | None, + steps: int, + global_start: int = 0, +): + if not layout: + return zeropower_via_newtonschulz5(update, steps=steps) + mats = update.unsqueeze(0) if update.ndim == 2 else update + out = torch.empty_like(mats) + active_prefix = int(layout.get("active_prefix_slices", 0)) + num_heads = int(layout.get("num_heads", 0)) + head_dim = int(layout.get("head_dim", 0)) + pair_size = int(layout.get("pair_size", 2)) + for local_idx in range(mats.shape[0]): + global_idx = global_start + local_idx + if global_idx < active_prefix: + out[local_idx] = _muon_zeropower_by_head_pairs( + mats[local_idx], + num_heads=num_heads, + head_dim=head_dim, + pair_size=pair_size, + steps=steps, + ) + else: + out[local_idx] = zeropower_via_newtonschulz5(mats[local_idx], steps=steps) + return out.squeeze(0) if update.ndim == 2 else out + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + "layout": getattr(p, "_muon_head_pair_layout", None), + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad.bfloat16()) + if pg.shape[0] > m["B"]: + pg[m["B"] :].zero_() + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + # Layer 3: read AdamHD Huber WD flags from param group + _huber_wd = group.get("huber_wd", False) + _huber_delta = group.get("huber_delta", 0.1) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + if _huber_wd: + _apply_huber_wd(pp.data, lr, wd, _huber_delta) + else: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + shard_start = self._rank * m["shard"].shape[0] if sharded else 0 + update = _muon_specialized_zeropower( + update, + m.get("layout"), + steps=backend_steps, + global_start=shard_start, + ) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + if _huber_wd: + _apply_huber_wd(p.data, lr, wd, _huber_delta) + else: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + if _huber_wd: + _apply_huber_wd(pp.data, lr, wd, _huber_delta) + else: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + if h.muon_qk_pair_ortho: + head_dim = h.model_dim // h.num_heads + base_model.qo_bank._muon_head_pair_layout = { + "active_prefix_slices": h.num_layers, + "num_heads": h.num_heads, + "head_dim": head_dim, + "pair_size": h.muon_qk_pair_size, + } + base_model.kv_bank._muon_head_pair_layout = { + "active_prefix_slices": h.num_layers, + "num_heads": h.num_kv_heads, + "head_dim": head_dim, + "pair_size": h.muon_qk_pair_size, + } + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + # Layer 3: AdamHD Huber WD flags injected into each param group. + group["huber_wd"] = bool(getattr(h, "muon_huber_wd", False)) + group["huber_delta"] = float(getattr(h, "muon_huber_delta", 0.1)) + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self.optimizer_tok.step() + self.optimizer_scalar.step() + self.optimizer_muon.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if getattr(model, "random_proj_embed_enabled", False): + x = model._project_hidden_to_tied_space(x) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + lqer_cands = {} + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + ret = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=clip_range + ) + q, s = ret + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + if lqer_on: + W_q = q.float() * s.float().view(-1, 1) + E = t.float() - W_q + lqer_cands[name] = (E, float(E.norm())) + if lqer_on and lqer_cands: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off : dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off : src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +# --------------------------------------------------------------------------- +# COMPRESSOR=pergroup — role-bucketed lrzip/ZPAQ compression +# +# Splits the quantized state dict into two sections before serialization: +# 1. Q section – the large int8 GPTQ weight tensors (.weight.q keys). +# Each tensor's rows are sorted by L1 nearest-neighbour similarity so +# adjacent rows are numerically close, then transposed for better run- +# length regularity. All sorted+transposed blobs are concatenated and +# compressed with lrzip ZPAQ (-z -L 9). If lrzip is absent the section +# falls back to brotli automatically — never crashes. +# 2. Remainder – scales, LQER factors, passthrough tensors, quant_meta. +# Serialized with torch.save + byte_shuffle + brotli-11 (same as the +# default brotli path). +# +# Frame format (little-endian): +# [4B] magic b"PGRP" +# [4B] uint32 version = 2 +# [4B] uint32 n_q_tensors +# Per Q tensor (sorted by name): +# [2B] uint16 name_len +# [name_len B] name (UTF-8) +# [4B] uint32 rows (original shape[0] before sort+transpose) +# [4B] uint32 cols (original shape[1] before sort+transpose) +# [rows*2 B] uint16 row-permutation indices +# [1B] uint8 q_method (0 = brotli fallback, 1 = lrzip) +# [4B] uint32 q_data_size +# [q_data_size B] compressed Q blob +# [4B] uint32 remainder_size +# [remainder_size B] brotli-compressed remainder +# --------------------------------------------------------------------------- + +import struct as _struct + +_PGRP_MAGIC = b"PGRP" +_PGRP_VERSION = 2 + +# Only the large int8 weight tensors go through the pergroup path. +_PGRP_Q_SUFFIXES = ( + ".mlp.fc.weight.q", + ".mlp.proj.weight.q", + ".attn.c_q.weight.q", + ".attn.proj.weight.q", + ".attn.c_k.weight.q", + ".attn.c_v.weight.q", + "tok_emb.weight.q", +) + + +def _similarity_sort_l1(W): + """Greedy L1 nearest-neighbour row sort. Returns uint16 permutation array. + + O(n²·cols) numpy – takes ~3-6 s on the largest tensors (2048 rows). + """ + n = W.shape[0] + if n <= 1: + return np.arange(n, dtype=np.uint16) + W16 = W.astype(np.int16) + used = np.zeros(n, dtype=bool) + order = np.empty(n, dtype=np.int32) + order[0] = 0 + used[0] = True + for i in range(1, n): + d = np.abs(W16 - W16[order[i - 1]]).sum(axis=1) + d[used] = 2 ** 30 + nxt = int(d.argmin()) + order[i] = nxt + used[nxt] = True + return order.astype(np.uint16) + + +def _lrzip_compress_bytes(data): + """Compress bytes with lrzip ZPAQ via temp files. Returns bytes or None.""" + import tempfile + in_fd, in_path = tempfile.mkstemp(suffix=".bin") + out_path = in_path + ".lrz" + try: + os.write(in_fd, data) + os.close(in_fd) + in_fd = -1 + r = subprocess.run( + ["lrzip", "-z", "-L", "9", "-o", out_path, in_path], + capture_output=True, timeout=300, + ) + if r.returncode == 0 and os.path.exists(out_path): + with open(out_path, "rb") as fh: + return fh.read() + log(f"pergroup:lrzip exit {r.returncode}: {r.stderr.decode()[:200]}") + except FileNotFoundError: + log("pergroup:lrzip not found — falling back to brotli for Q section") + except subprocess.TimeoutExpired: + log("pergroup:lrzip timed out — falling back to brotli for Q section") + except Exception as e: + log(f"pergroup:lrzip error ({e}) — falling back to brotli for Q section") + finally: + if in_fd != -1: + try: os.close(in_fd) + except OSError: pass + for p in (in_path, out_path): + try: os.unlink(p) + except OSError: pass + return None + + +def _lrzip_decompress_bytes(data): + """Decompress lrzip ZPAQ bytes via temp files.""" + import tempfile + lrz_fd, lrz_path = tempfile.mkstemp(suffix=".lrz") + # lrzip strips .lrz → output name is lrz_path[:-4] + out_path = lrz_path[:-4] + try: + os.write(lrz_fd, data) + os.close(lrz_fd) + lrz_fd = -1 + r = subprocess.run( + ["lrzip", "-d", "-o", out_path, lrz_path], + capture_output=True, timeout=300, + ) + if r.returncode != 0: + raise RuntimeError(f"lrzip -d failed (exit {r.returncode}): {r.stderr.decode()[:400]}") + with open(out_path, "rb") as fh: + return fh.read() + finally: + if lrz_fd != -1: + try: os.close(lrz_fd) + except OSError: pass + # lrzip -d removes the input .lrz by default; OSError caught if already gone + for p in (lrz_path, out_path): + try: os.unlink(p) + except OSError: pass + + +def _pack_pergroup(quant_result, quant_meta): + """Serialize quantized state dict using the PGRP frame format. + + Q tensors → similarity-sorted, transposed, lrzip ZPAQ (brotli fallback). + Everything else → torch.save + byte_shuffle + brotli-11. + """ + import brotli + + # --- split Q tensors from remainder --- + q_items = {} # name → int8 numpy array + remainder = {} # name → tensor (scales, LQER, passthrough) + for name, t in quant_result.items(): + if any(name.endswith(sfx) for sfx in _PGRP_Q_SUFFIXES): + q_items[name] = (t.numpy() if hasattr(t, "numpy") else np.asarray(t)) + else: + remainder[name] = t + + q_names = sorted(q_items.keys()) + + # --- similarity sort + transpose per Q tensor --- + q_perms = {} # name → uint16 perm array + q_shapes = {} # name → (rows, cols) + q_blobs = [] # sorted+transposed int8 bytes, concatenated later + + t_sort = time.perf_counter() + for name in q_names: + W = q_items[name] + assert W.ndim == 2, f"pergroup: Q tensor {name} is not 2D: {W.shape}" + rows, cols = W.shape + q_shapes[name] = (rows, cols) + perm = _similarity_sort_l1(W) + q_perms[name] = perm + # sort rows then transpose → (cols, rows); adjacent values more similar + W_st = W[perm.astype(np.int32)].T.astype(np.int8) + q_blobs.append(W_st.tobytes()) + log(f"pergroup:similarity sort done in {time.perf_counter()-t_sort:.1f}s ({len(q_names)} tensors)") + + q_data_raw = b"".join(q_blobs) + + # --- compress Q section (lrzip ZPAQ, brotli fallback) --- + q_compressed = _lrzip_compress_bytes(q_data_raw) + if q_compressed is not None: + q_method = 1 # lrzip + else: + q_compressed = brotli.compress(q_data_raw, quality=11) + q_method = 0 # brotli fallback + log( + f"pergroup:Q {len(q_data_raw)} raw → {len(q_compressed)} " + f"({'lrzip' if q_method else 'brotli'}) " + f"({100*len(q_compressed)/max(len(q_data_raw),1):.1f}%)" + ) + + # --- remainder section: torch.save + byte_shuffle + brotli --- + rem_buf = io.BytesIO() + torch.save({"w": remainder, "m": quant_meta}, rem_buf) + rem_compressed = brotli.compress(_byte_shuffle(rem_buf.getvalue()), quality=11) + log(f"pergroup:remainder {rem_buf.tell()} raw → {len(rem_compressed)} brotli") + + # --- assemble PGRP frame --- + out = io.BytesIO() + out.write(_PGRP_MAGIC) + out.write(_struct.pack("= 1, f"OptRot: W.shape[0]={n} must be power of 2" + Y = W.clone() + h_step = 1 + while h_step < n: + # Butterfly: view as (n_blocks, 2, h_step, ...) so dim-1 selects first/second half + # of each 2*h_step block. Pairs element i with element i+h_step within each block. + n_blocks = n // (h_step * 2) + Y2 = Y.view(n_blocks, 2, h_step, *Y.shape[1:]) + a = Y2[:, 0, ...].clone() + b = Y2[:, 1, ...].clone() + Y2[:, 0, ...] = a + b + Y2[:, 1, ...] = a - b + Y = Y2.view(n, *Y.shape[1:]) + h_step *= 2 + return Y / math.sqrt(n) + + +def _optrot_apply(sd, h): + """Layer 2: Apply per-layer Hadamard rotation to (W_up, W_down) and (W_v, W_o) pairs. + + Rotation R = FWHT / sqrt(n) is orthogonal and self-inverse (R^2 = I). + For the MLP: W_up' = R @ W_up, W_down' = W_down @ R. Product preserved: + W_down' @ W_up' = W_down @ R @ R @ W_up = W_down @ W_up. + For attention V→O (per kv-head, per query-head group): W_v' = R @ W_v per kv-head block; + W_o' = W_o @ R per corresponding query-head column block. Product preserved. + + Only applied to layers with hidden_dim and head_dim that are powers of 2 (all layers + in the default 11L 512d 8H/4KV architecture: hidden=2048, head_dim=64, both OK). + """ + num_layers = h.num_layers + num_heads = h.num_heads + num_kv_heads = h.num_kv_heads + model_dim = h.model_dim + head_dim = model_dim // num_heads + kv_group = num_heads // num_kv_heads + + for i in range(num_layers): + # --- MLP rotation --- + W_up = sd[f"blocks.{i}.mlp.fc.weight"].float() # (hidden_dim, model_dim) + W_down = sd[f"blocks.{i}.mlp.proj.weight"].float() # (model_dim, hidden_dim) + hidden_dim = W_up.shape[0] + if (hidden_dim & (hidden_dim - 1)) == 0: + # R @ W_up: apply FWHT to rows of W_up (along axis 0) + W_up_r = _fwht_along_0(W_up) + # W_down @ R = (R @ W_down.T).T: apply FWHT to rows of W_down.T, then transpose + W_down_r = _fwht_along_0(W_down.T).T + sd[f"blocks.{i}.mlp.fc.weight"] = W_up_r.to(sd[f"blocks.{i}.mlp.fc.weight"].dtype) + sd[f"blocks.{i}.mlp.proj.weight"] = W_down_r.to(sd[f"blocks.{i}.mlp.proj.weight"].dtype) + + # --- Attention V→O rotation (per kv-head, per query-head group) --- + W_v = sd[f"blocks.{i}.attn.c_v.weight"].float() # (kv_dim, model_dim) + W_o = sd[f"blocks.{i}.attn.proj.weight"].float() # (model_dim, model_dim) + kv_dim = W_v.shape[0] + if (head_dim & (head_dim - 1)) == 0: + # Reshape V to (num_kv_heads, head_dim, model_dim), rotate head_dim (axis 0 per head) + V = W_v.view(num_kv_heads, head_dim, model_dim) + V_r = torch.stack([_fwht_along_0(V[k]) for k in range(num_kv_heads)], dim=0) + sd[f"blocks.{i}.attn.c_v.weight"] = V_r.view(kv_dim, model_dim).to(W_v.dtype) + + # Reshape O to (model_dim, num_heads, head_dim). + # For each kv-head k: the corresponding query-head range uses the SAME V rotation. + # Apply R to the head_dim columns of O for each query-head in the group. + O = W_o.view(model_dim, num_heads, head_dim) + for k in range(num_kv_heads): + for g in range(kv_group): + h_idx = k * kv_group + g + # O_col block: (model_dim, head_dim). Apply R on head_dim (axis 0 of .T). + # W_o' @ (R @ y_h) = (W_o @ R) @ (R @ y_h) = W_o @ R^2 @ y_h = W_o @ y_h (R^2=I) + # So we need W_o' columns rotated by R: cols = O[:, h_idx, :], rotate last dim. + O[:, h_idx, :] = _fwht_along_0(O[:, h_idx, :].T).T + sd[f"blocks.{i}.attn.proj.weight"] = O.view(model_dim, model_dim).to(W_o.dtype) + + return sd + + +def _compressed_code_size(code): + code_raw = code.encode("utf-8") + minified = subprocess.run( + ["pyminify", "--no-rename-locals", "--no-hoist-literals", "--remove-literal-statements", "-"], + input=code_raw, capture_output=True, check=True, + ).stdout + compressed = lzma.compress(minified) + encoded = base64.b85encode(compressed) + wrapper = b'import lzma as L,base64 as B\nexec(L.decompress(B.b85decode("' + encoded + b'")))\n' + return len(code_raw), len(wrapper) + + +def serialize(h, base_model, code): + code_bytes_uncompressed, code_bytes = _compressed_code_size(code) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size (uncompressed): {code_bytes_uncompressed} bytes") + log(f"Code size (compressed): {code_bytes} bytes") + sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) + # Layer 2: OptRot — apply Hadamard rotation to (W_up, W_down) and (W_v, W_o) pairs + # BEFORE Hessian collection and GPTQ. The rotation is orthogonal and self-inverse: + # model outputs are preserved exactly, but weight distributions are more uniform, + # reducing int6 quantization error. The Hessian must be collected on the ROTATED model + # so that GPTQ sees the same input statistics as the rotated forward pass. + if getattr(h, "optrot_enabled", False): + log("OptRot: applying pre-GPTQ Hadamard rotation to (W_up,W_down) and (V,O) pairs...") + t_rot = time.perf_counter() + sd_cpu = _optrot_apply(sd_cpu, h) + # Reload rotated weights into base_model so Hessian collection sees rotated activations + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + rotated_banked = _rebank_state_dict(sd_cpu, h.num_layers, h.model_dim, kv_dim, hidden_dim) + base_model.load_state_dict(rotated_banked, strict=True) + log(f"OptRot: done in {time.perf_counter()-t_rot:.1f}s") + device = torch.device("cuda", h.local_rank) + t0 = time.perf_counter() + calib_mode = os.getenv("GPTQ_CALIBRATION_MODE", "random") + if calib_mode == "doc_start": + calib_loader = DocStartSequenceLoader(h, device, bos_id=BOS_ID if BOS_ID is not None else 1) + else: + calib_loader = ShuffledSequenceLoader(h, device) + log("GPTQ:collecting Hessians from calibration data...") + hessians = collect_hessians( + base_model, + calib_loader, + h, + device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter()-t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + if h.compressor == "pergroup": + quant_blob = _pack_pergroup(quant_result, quant_meta) + else: + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers) + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + if h.compressor == "pergroup": + quant_state = _unpack_pergroup(quant_blob_disk) + else: + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location="cpu" + ) + deq_flat = dequantize_mixed(quant_state["w"], quant_state["m"], flat_template) + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + + +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + if val_data.caseops_enabled and val_data.val_bytes is not None: + # CaseOps: read per-token byte budget from sidecar at the same + # global positions as the target tokens y. raw_start/raw_end + # span [raw_start, raw_end), x = local[:-1], y = local[1:], + # so y is at sidecar positions [raw_start + 1, raw_end). + sidecar_slice = val_data.val_bytes[raw_start + 1 : raw_end].to( + device=device, dtype=torch.int32, non_blocking=True + ) + val_byte_count += sidecar_slice.to(torch.float64).sum() + else: + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += ( + val_data.has_leading_space_lut[tgt_ids] + & ~val_data.is_boundary_token_lut[prev_ids] + ).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + # Layer 4: LaCT — use Muon as the fast-weight optimizer for global TTT. + # GLOBAL_TTT_OPTIMIZER=muon: Newton-Schulz orthogonalized updates for matrix params, + # SGD for scalars/vectors. Matches the training optimizer, enabling gradient norms + # compatible with learned LR scale — the key LaCT efficiency improvement. + _global_ttt_opt = getattr(h, "global_ttt_optimizer", "sgd") + if _global_ttt_opt == "muon": + _ttt_matrix_params = [p for p in ttt_params if p.ndim >= 2] + _ttt_scalar_params = [p for p in ttt_params if p.ndim < 2] + _ns_steps = int(getattr(h, "global_ttt_muon_ns_steps", 5)) + _nesterov = bool(getattr(h, "global_ttt_muon_nesterov", True)) + optimizer = Muon( + _ttt_matrix_params, + lr=h.global_ttt_lr, + momentum=h.global_ttt_momentum, + backend_steps=_ns_steps, + nesterov=_nesterov, + weight_decay=0.0, + ) + _scalar_opt = ( + torch.optim.SGD(_ttt_scalar_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum) + if _ttt_scalar_params else None + ) + def _ttt_zero_grad(): + optimizer.zero_grad(set_to_none=True) + if _scalar_opt: + _scalar_opt.zero_grad(set_to_none=True) + def _ttt_step(): + optimizer.step() + if _scalar_opt: + _scalar_opt.step() + else: + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + _scalar_opt = None + def _ttt_zero_grad(): + optimizer.zero_grad(set_to_none=True) + def _ttt_step(): + optimizer.step() + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + if _scalar_opt is not None: + for pg in _scalar_opt.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + _ttt_zero_grad() + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + _ttt_step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def _compute_ngram_hints_for_val(h, val_data, log0=print): + """Precompute n-gram hints before eval timer starts (Stage 1A from PR #1967). + + Returns (hint_global, gate_global, boost_global) CPU tensors, or None if tilt disabled. + Single L->R causal pass over val tokens only — compliant with C1/C3/C4 constraints. + """ + if not getattr(h, "ngram_tilt_enabled", False): + return None + from online_ngram_tilt import build_hints_for_targets + all_tokens = val_data.val_tokens + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np_all, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log0, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + log0( + f"ngram_tilt:precompute_outside_timer_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={hint_global.numel()}" + ) + return (hint_global, gate_global, boost_global) + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, precomputed_hints=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + TTT_LORA_EMA_DECAY = float(os.environ.get("TTT_LORA_EMA_DECAY", "0.0")) + ttt_lora_ema_enabled = TTT_LORA_EMA_DECAY > 0.0 + TTT_UPDATE_EVERY = int(os.environ.get("TTT_UPDATE_EVERY", "1")) + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + # === PR #1145 n-gram tilt: set up hint tensors (CPU) === + # hint_global[i] = hinted token id for predicting all_tokens[i+1] from prefix [:i+1]. + # gate_global[i] = True when any expert fires for position i. + # boost_global[i] = combined boost beta for position i. + ngram_hint_global = None + ngram_gate_global = None + ngram_boost_global = None + if precomputed_hints is not None: + ngram_hint_global, ngram_gate_global, ngram_boost_global = precomputed_hints + log( + f"ngram_tilt:using_precomputed_hints " + f"total_targets={ngram_hint_global.numel()} (precompute excluded from eval timer)" + ) + elif getattr(h, "ngram_tilt_enabled", False): + from online_ngram_tilt import build_hints_for_targets + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np_all, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + ngram_hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + ngram_gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + ngram_boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + log( + f"ngram_tilt:precompute_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={ngram_hint_global.numel()}" + ) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_ema_lora = ( + BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + if ttt_lora_ema_enabled else None + ) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + if ttt_lora_ema_enabled: + reusable_ema_lora.reset() + with torch.no_grad(): + for ema_p, raw_p in zip(reusable_ema_lora.parameters(), cur_lora.parameters()): + ema_p.data.copy_(raw_p.data) + cur_ema_lora = reusable_ema_lora + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + if ttt_lora_ema_enabled: + cur_ema_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + with torch.no_grad(): + for ema_p, raw_p in zip(cur_ema_lora.parameters(), cur_lora.parameters()): + ema_p.data.copy_(raw_p.data) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + # n-gram tilt: gather hints aligned to y for this chunk + hint_ids_gpu = None + gate_mask_gpu = None + boost_gpu = None + if ngram_hint_global is not None: + hint_idx_cpu = ( + tok_starts.unsqueeze(1) + col_idx[:context_size].unsqueeze(0) + ).clamp_(min=0, max=ngram_hint_global.numel() - 1) + hint_ids_gpu = ngram_hint_global[hint_idx_cpu].to( + device=device, dtype=torch.int64, non_blocking=True + ) + gate_mask_gpu = ngram_gate_global[hint_idx_cpu].to( + device=device, non_blocking=True + ) + boost_gpu = ngram_boost_global[hint_idx_cpu].to( + device=device, dtype=torch.float32, non_blocking=True + ) + hint_ids_gpu = torch.where(valid, hint_ids_gpu, torch.zeros_like(hint_ids_gpu)) + gate_mask_gpu = gate_mask_gpu & valid + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if hint_ids_gpu is not None: + per_tok_loss, log_q_hint = forward_ttt_train( + x, y, lora=cur_ema_lora if ttt_lora_ema_enabled else cur_lora, + hint_ids=hint_ids_gpu, + ) + else: + per_tok_loss = forward_ttt_train( + x, y, lora=cur_ema_lora if ttt_lora_ema_enabled else cur_lora + ) + log_q_hint = None + # Apply closed-form tilt to BPB accumulation only (not to TTT training objective). + if hint_ids_gpu is not None and log_q_hint is not None: + from online_ngram_tilt import apply_tilt_to_ptl_torch_fast + tilted_loss = apply_tilt_to_ptl_torch_fast( + ptl=per_tok_loss, + log_q_hint=log_q_hint, + target_ids=y, + hint_ids=hint_ids_gpu, + gate_mask=gate_mask_gpu, + boost=boost_gpu, + ) + else: + tilted_loss = per_tok_loss + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + with torch.no_grad(): + _accumulate_bpb( + tilted_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + is_last_trained_chunk = (ci == max_nc - 2) + is_update_step = (ci % TTT_UPDATE_EVERY == TTT_UPDATE_EVERY - 1) or is_last_trained_chunk + is_window_start = (ci % TTT_UPDATE_EVERY == 0) + for gi in range(h.ttt_grad_steps): + if gi > 0 or ttt_lora_ema_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + chunk_size + ].mean(dim=-1) + # Original path: zero_grad only at the start of an update window + # (gi==0). With TTT_GRAD_STEPS>1 this leaves gi=0's gradient in + # .grad when gi=1 runs, so step 2 sees (grad_gi0 + grad_gi1) — + # effectively ~2x gradient magnitude on the second update. + # Clean path (TTT_GRAD_STEPS_CLEAN=1): also zero_grad before every + # gi>0 step so each optimizer.step() sees only its own fresh gradient, + # giving true independent half-LR updates with different curvature. + if (is_window_start and gi == 0) or (h.ttt_grad_steps_clean and gi > 0): + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + if is_update_step: + cur_opt.step() + if ttt_lora_ema_enabled and is_update_step: + with torch.no_grad(): + decay = TTT_LORA_EMA_DECAY + for ema_p, raw_p in zip(cur_ema_lora.parameters(), cur_lora.parameters()): + ema_p.data.mul_(decay).add_(raw_p.data, alpha=1.0 - decay) + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + if ttt_lora_ema_enabled: + reusable_ema_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + if ttt_lora_ema_enabled: + del cur_ema_lora + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compile_enabled = os.environ.get("TORCH_COMPILE_ENABLED", "1") == "1" + compile_dynamic = os.environ.get("TORCH_COMPILE_DYNAMIC", "0") == "1" + compile_fullgraph = os.environ.get("TORCH_COMPILE_FULLGRAPH", "1") == "1" + if compile_enabled: + compiled_model = torch.compile( + base_model, dynamic=compile_dynamic, fullgraph=compile_fullgraph + ) + compiled_forward_logits = torch.compile( + base_model.forward_logits, + dynamic=compile_dynamic, + fullgraph=compile_fullgraph, + ) + else: + compiled_model = base_model + compiled_forward_logits = base_model.forward_logits + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + prequant_only = os.environ.get("PREQUANT_ONLY", "0") == "1" + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + frac = ( + min(step / h.muon_momentum_warmup_steps, 1.0) + if h.muon_momentum_warmup_steps > 0 + else 1.0 + ) + muon_momentum = ( + 1 - frac + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + ema_state = { + name: t.detach().float().clone() + for (name, t) in base_model.state_dict().items() + } + ema_decay = h.ema_decay + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + False if prequant_only else ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale) + with torch.no_grad(): + for (name, t) in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_( + t.detach().float(), alpha=1.0 - ema_decay + ) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + compile_enabled = os.environ.get("TORCH_COMPILE_ENABLED", "1") == "1" + compile_dynamic = os.environ.get("TORCH_COMPILE_DYNAMIC", "0") == "1" + compile_fullgraph = os.environ.get("TORCH_COMPILE_FULLGRAPH", "1") == "1" + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + else: + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + torch._dynamo.reset() + if os.environ.get("PREQUANT_ONLY", "0") == "1": + log("PREQUANT_ONLY=1 — skipping diagnostic pre-quant eval, serialize/GPTQ/post-quant eval/TTT") + return + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + if compile_enabled: + compiled_model = torch.compile( + eval_model, dynamic=compile_dynamic, fullgraph=compile_fullgraph + ) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, + dynamic=compile_dynamic, + fullgraph=compile_fullgraph, + ) + else: + compiled_model = eval_model + compiled_forward_logits = eval_model.forward_logits + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + def _fwd_ttt_inner_with_hints(input_ids, target_ids, lora, hint_ids): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora, hint_ids=hint_ids) + + _fwd_ttt_compiled_inner = None + _fwd_ttt_compiled_inner_hints = None + + def _fwd_ttt(input_ids, target_ids, lora, hint_ids=None): + nonlocal _fwd_ttt_compiled_inner, _fwd_ttt_compiled_inner_hints + if not compile_enabled: + if hint_ids is None: + return _fwd_ttt_inner(input_ids, target_ids, lora=lora) + return _fwd_ttt_inner_with_hints( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + if hint_ids is None: + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + if _fwd_ttt_compiled_inner_hints is None: + _fwd_ttt_compiled_inner_hints = torch.compile( + _fwd_ttt_inner_with_hints, dynamic=True + ) + return _fwd_ttt_compiled_inner_hints( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_grad_steps: {h.ttt_grad_steps} ttt_grad_steps_clean: {h.ttt_grad_steps_clean}") + if compile_enabled: + log(f"ttt_lora:warming up compile (random tokens, no val data)") + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + else: + log("ttt_lora:compile disabled, skipping compile warmup") + # v5 Stage 1A: precompute n-gram hints BEFORE eval timer (single causal pass, + # val tokens only — same compliance as inline). Saves ~168s of measured eval + # time for full tilt without any loss of tilt benefit. + precomputed_hints = None + if h.ngram_tilt_enabled and h.ngram_hint_precompute_outside: + log("ngram_tilt:precomputing hints OUTSIDE eval timer") + precomputed_hints = _compute_ngram_hints_for_val(h, val_data, log0=log) + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled, + precomputed_hints=precomputed_hints, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + del ttt_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + _cache_lim = os.environ.get("TORCH_DYNAMO_CACHE_SIZE_LIMIT", "").strip() + torch._dynamo.config.cache_size_limit = int(_cache_lim) if _cache_lim else 64 + _accum_cache_lim = os.environ.get( + "TORCH_DYNAMO_ACCUMULATED_CACHE_SIZE_LIMIT", "" + ).strip() + if _accum_cache_lim and hasattr(torch._dynamo.config, "accumulated_cache_size_limit"): + torch._dynamo.config.accumulated_cache_size_limit = int(_accum_cache_lim) + _recomp_lim = os.environ.get("TORCH_DYNAMO_RECOMPILE_LIMIT", "").strip() + if _recomp_lim and hasattr(torch._dynamo.config, "recompile_limit"): + torch._dynamo.config.recompile_limit = int(_recomp_lim) + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_seed42.log b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_seed42.log new file mode 100644 index 0000000000..d07eb7e831 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/train_seed42.log @@ -0,0 +1,331 @@ +W0503 07:05:50.736000 28930 torch/distributed/run.py:803] +W0503 07:05:50.736000 28930 torch/distributed/run.py:803] ***************************************** +W0503 07:05:50.736000 28930 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0503 07:05:50.736000 28930 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.0 + artifact_dir: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42 + asym_logit_rescale: True + attn_backend: fa3 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + beta1: 0.9 + beta2: 0.99 + caseops_enabled: True + compressor: pergroup + data_dir: ./data + datasets_dir: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + global_every_n: 6 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_muon_nesterov: True + global_ttt_muon_ns_steps: 5 + global_ttt_optimizer: sgd + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 32 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_attn_window: 1024 + local_global_pattern: + local_rank: 0 + logfile: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42/2135_randproj384_pairmuon_lrelu03_qkgain525_seed42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 32 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_rank: 4 + lqer_top_k: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.028 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42/final_model.pt + muon_backend_steps: 5 + muon_huber_delta: 0.1 + muon_huber_wd: False + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_qk_pair_ortho: True + muon_qk_pair_size: 2 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + optrot_enabled: False + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 2500 + qk_gain_init: 5.25 + qk_gain_schedule: [] + quantized_model_path: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42/final_model.int6.ptz + random_proj_embed_dim: 384 + random_proj_embed_enabled: True + random_proj_seed: 1337 + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: 2135_randproj384_pairmuon_lrelu03_qkgain525_seed42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/caseops_data/datasets/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_grad_steps_clean: False + ttt_k_lora: False + ttt_lora_lr: 8e-05 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 0 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.0 + within_tau: 99.0 + word_boost: 0.0 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 99.0 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +model_params:34897097 +gptq:reserving 0s, effective=599500ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +1/20000 train_loss: 8.9983 train_time: 0.0m tok/s: 3068114 +2/20000 train_loss: 11.9106 train_time: 0.0m tok/s: 2458438 +3/20000 train_loss: 9.5449 train_time: 0.0m tok/s: 2573313 +4/20000 train_loss: 8.2694 train_time: 0.0m tok/s: 2637615 +5/20000 train_loss: 7.7865 train_time: 0.0m tok/s: 2676234 +500/20000 train_loss: 2.5851 train_time: 2.3m tok/s: 2886632 +layer_loop:enabled step:770 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/20000 train_loss: 2.7498 train_time: 5.1m tok/s: 2589484 +1500/20000 train_loss: 2.4930 train_time: 8.5m tok/s: 2320947 +1724/20000 val_loss: 2.4662 val_bpb: 1.1269 +stopping_early: wallclock_cap train_time: 599714ms step: 1724/20000 +peak memory allocated: 59584 MiB reserved: 60268 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.47020597 val_bpb:1.12868936 eval_time:9217ms +Serialized model: 133320959 bytes +Code size (uncompressed): 202291 bytes +Code size (compressed): 39405 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 6.9s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int6)+lqer_asym: blocks.mlp.fc.weight + gptq (int7)+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.attn_gate_w, blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +pergroup:similarity sort done in 36.6s (67 tensors) +pergroup:Q 34865152 raw → 15164617 (lrzip) (43.5%) +pergroup:remainder 272035 raw → 125834 brotli +pergroup:total frame 15399365 bytes +Serialized model quantized+pergroup: 15399365 bytes +Total submission size quantized+pergroup: 15438770 bytes +[rank1]: Traceback (most recent call last): +[rank1]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank1]: main() +[rank1]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank1]: train_and_eval(h, device) +[rank1]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank1]: if compile_enabled: +[rank1]: ^^^^^^^^^^^^^^^ +[rank1]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank7]: Traceback (most recent call last): +[rank7]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank7]: main() +[rank7]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank7]: train_and_eval(h, device) +[rank7]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank7]: if compile_enabled: +[rank7]: ^^^^^^^^^^^^^^^ +[rank7]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank6]: Traceback (most recent call last): +[rank6]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank6]: main() +[rank6]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank6]: train_and_eval(h, device) +[rank6]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank6]: if compile_enabled: +[rank6]: ^^^^^^^^^^^^^^^ +[rank6]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank0]: Traceback (most recent call last): +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank0]: main() +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank0]: train_and_eval(h, device) +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank0]: if compile_enabled: +[rank0]: ^^^^^^^^^^^^^^^ +[rank0]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank2]: Traceback (most recent call last): +[rank2]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank2]: main() +[rank2]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank2]: train_and_eval(h, device) +[rank2]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank2]: if compile_enabled: +[rank2]: ^^^^^^^^^^^^^^^ +[rank2]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank4]: Traceback (most recent call last): +[rank4]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank4]: main() +[rank4]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank4]: train_and_eval(h, device) +[rank4]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank4]: if compile_enabled: +[rank4]: ^^^^^^^^^^^^^^^ +[rank4]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank3]: Traceback (most recent call last): +[rank3]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank3]: main() +[rank3]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank3]: train_and_eval(h, device) +[rank3]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank3]: if compile_enabled: +[rank3]: ^^^^^^^^^^^^^^^ +[rank3]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank5]: Traceback (most recent call last): +[rank5]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4647, in +[rank5]: main() +[rank5]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4641, in main +[rank5]: train_and_eval(h, device) +[rank5]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4454, in train_and_eval +[rank5]: if compile_enabled: +[rank5]: ^^^^^^^^^^^^^^^ +[rank5]: NameError: name 'compile_enabled' is not defined. Did you mean: 'compile_elapsed'? +[rank0]:[W503 07:18:57.629920259 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) +[rank1]:[W503 07:18:58.099625938 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[rank7]:[W503 07:18:58.110379517 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[rank6]:[W503 07:18:58.142569331 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[rank2]:[W503 07:18:58.551773028 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[rank4]:[W503 07:18:58.597566049 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[rank0]:[W503 07:18:58.672160143 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +W0503 07:18:59.388000 28930 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 28998 closing signal SIGTERM +W0503 07:18:59.391000 28930 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 29000 closing signal SIGTERM +W0503 07:18:59.391000 28930 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 29001 closing signal SIGTERM +W0503 07:18:59.392000 28930 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 29002 closing signal SIGTERM +W0503 07:18:59.393000 28930 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 29003 closing signal SIGTERM +W0503 07:18:59.394000 28930 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 29004 closing signal SIGTERM +W0503 07:18:59.395000 28930 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 29005 closing signal SIGTERM +E0503 07:19:01.280000 28930 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 1 (pid: 28999) of binary: /usr/local/bin/python +Traceback (most recent call last): + File "/usr/local/bin/torchrun", line 7, in + sys.exit(main()) + ^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main + run(args) + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run + elastic_launch( + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent + raise ChildFailedError( +torch.distributed.elastic.multiprocessing.errors.ChildFailedError: +============================================================ +/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py FAILED +------------------------------------------------------------ +Failures: + +------------------------------------------------------------ +Root Cause (first observed failure): +[0]: + time : 2026-05-03_07:18:59 + host : f6b497c1899c + rank : 1 (local_rank: 1) + exitcode : 1 (pid: 28999) + error_file: + traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html +============================================================ +[W503 07:19:01.439508924 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) diff --git a/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/ttt_eval_seed42_fail.log b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/ttt_eval_seed42_fail.log new file mode 100644 index 0000000000..9d181f7b78 --- /dev/null +++ b/records/track_non_record_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_NonRecord/ttt_eval_seed42_fail.log @@ -0,0 +1,339 @@ +W0503 07:51:02.865000 106615 torch/distributed/run.py:803] +W0503 07:51:02.865000 106615 torch/distributed/run.py:803] ***************************************** +W0503 07:51:02.865000 106615 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0503 07:51:02.865000 106615 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.0 + artifact_dir: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42 + asym_logit_rescale: True + attn_backend: fa3 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + beta1: 0.9 + beta2: 0.99 + caseops_enabled: True + compressor: pergroup + data_dir: ./data + datasets_dir: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2560 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + global_every_n: 6 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_muon_nesterov: True + global_ttt_muon_ns_steps: 5 + global_ttt_optimizer: sgd + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 32 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_attn_window: 1024 + local_global_pattern: + local_rank: 0 + logfile: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42/2135_randproj384_pairmuon_lrelu03_qkgain525_seed42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 32 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_rank: 4 + lqer_top_k: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.028 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42/final_model.pt + muon_backend_steps: 5 + muon_huber_delta: 0.1 + muon_huber_wd: False + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_qk_pair_ortho: True + muon_qk_pair_size: 2 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: False + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + optrot_enabled: False + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 2500 + qk_gain_init: 5.25 + qk_gain_schedule: [] + quantized_model_path: /workspace/parameter-golf/runs/2026-05-03_2135_randproj384_pairmuon_lrelu03_qkgain525/seed42/final_model.int6.ptz + random_proj_embed_dim: 384 + random_proj_embed_enabled: True + random_proj_seed: 1337 + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: 2135_randproj384_pairmuon_lrelu03_qkgain525_seed42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/caseops_data/datasets/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2560 + ttt_grad_steps: 1 + ttt_grad_steps_clean: False + ttt_k_lora: False + ttt_lora_lr: 8e-05 + ttt_lora_rank: 80 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 2.0 + val_batch_tokens: 524288 + val_bytes_files: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/caseops_data/datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 0 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.0 + within_tau: 99.0 + word_boost: 0.0 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 99.0 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval +ttt_lora_alpha: 144.0 +ttt_warm_start_a: True +ttt_weight_decay: 2.0 +ttt_grad_steps: 1 ttt_grad_steps_clean: False +ttt_lora:compile disabled, skipping compile warmup + +beginning TTT eval timer +ngram_tilt:hints total=47851520 gated=628130 token_gate=628130 within_gate=0 word_gate=0 agree2plus=0 +ngram_tilt:precompute_done elapsed=168.45s total_targets=47851520 +ttt_phased: total_docs:50000 prefix_docs:2500 suffix_docs:47500 num_phases:1 boundaries:[2500] +[rank1]: Traceback (most recent call last): +[rank1]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4666, in +[rank1]: main() +[rank1]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4660, in main +[rank1]: train_and_eval(h, device) +[rank1]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4585, in train_and_eval +[rank1]: ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( +[rank1]: ^^^^^^^^^^^^^^^^^^^^ +[rank1]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4054, in eval_val_ttt_phased +[rank1]: (per_doc * activate_chunk_mask).sum().backward() +[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward +[rank1]: torch.autograd.backward( +[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward +[rank1]: _engine_run_backward( +[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward +[rank1]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass +[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.75 GiB. GPU 1 has a total capacity of 79.19 GiB of which 3.04 GiB is free. Including non-PyTorch memory, this process has 76.14 GiB memory in use. Of the allocated memory 73.13 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +[rank4]: Traceback (most recent call last): +[rank4]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4666, in +[rank4]: main() +[rank4]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4660, in main +[rank4]: train_and_eval(h, device) +[rank4]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4585, in train_and_eval +[rank4]: ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( +[rank4]: ^^^^^^^^^^^^^^^^^^^^ +[rank4]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4054, in eval_val_ttt_phased +[rank4]: (per_doc * activate_chunk_mask).sum().backward() +[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward +[rank4]: torch.autograd.backward( +[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward +[rank4]: _engine_run_backward( +[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward +[rank4]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass +[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank4]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.75 GiB. GPU 4 has a total capacity of 79.19 GiB of which 3.04 GiB is free. Including non-PyTorch memory, this process has 76.14 GiB memory in use. Of the allocated memory 73.13 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +[rank2]: Traceback (most recent call last): +[rank2]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4666, in +[rank2]: main() +[rank2]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4660, in main +[rank2]: train_and_eval(h, device) +[rank2]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4585, in train_and_eval +[rank2]: ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( +[rank2]: ^^^^^^^^^^^^^^^^^^^^ +[rank2]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4054, in eval_val_ttt_phased +[rank2]: (per_doc * activate_chunk_mask).sum().backward() +[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward +[rank2]: torch.autograd.backward( +[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward +[rank2]: _engine_run_backward( +[rank2]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward +[rank2]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass +[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank2]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.75 GiB. GPU 2 has a total capacity of 79.19 GiB of which 3.04 GiB is free. Including non-PyTorch memory, this process has 76.14 GiB memory in use. Of the allocated memory 73.13 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +[rank0]: Traceback (most recent call last): +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4666, in +[rank0]: main() +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4660, in main +[rank0]: train_and_eval(h, device) +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4585, in train_and_eval +[rank0]: ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( +[rank0]: ^^^^^^^^^^^^^^^^^^^^ +[rank0]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4054, in eval_val_ttt_phased +[rank0]: (per_doc * activate_chunk_mask).sum().backward() +[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward +[rank0]: torch.autograd.backward( +[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward +[rank0]: _engine_run_backward( +[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward +[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass +[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.75 GiB. GPU 0 has a total capacity of 79.19 GiB of which 3.04 GiB is free. Including non-PyTorch memory, this process has 76.14 GiB memory in use. Of the allocated memory 73.13 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +[rank7]: Traceback (most recent call last): +[rank7]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4666, in +[rank7]: main() +[rank7]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4660, in main +[rank7]: train_and_eval(h, device) +[rank7]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4585, in train_and_eval +[rank7]: ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( +[rank7]: ^^^^^^^^^^^^^^^^^^^^ +[rank7]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4054, in eval_val_ttt_phased +[rank7]: (per_doc * activate_chunk_mask).sum().backward() +[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward +[rank7]: torch.autograd.backward( +[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward +[rank7]: _engine_run_backward( +[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward +[rank7]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass +[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank7]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.75 GiB. GPU 7 has a total capacity of 79.19 GiB of which 3.04 GiB is free. Including non-PyTorch memory, this process has 76.14 GiB memory in use. Of the allocated memory 73.13 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +[rank3]: Traceback (most recent call last): +[rank3]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4666, in +[rank3]: main() +[rank3]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4660, in main +[rank3]: train_and_eval(h, device) +[rank3]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4585, in train_and_eval +[rank3]: ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( +[rank3]: ^^^^^^^^^^^^^^^^^^^^ +[rank3]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4054, in eval_val_ttt_phased +[rank3]: (per_doc * activate_chunk_mask).sum().backward() +[rank3]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward +[rank3]: torch.autograd.backward( +[rank3]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward +[rank3]: _engine_run_backward( +[rank3]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward +[rank3]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass +[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank3]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.75 GiB. GPU 3 has a total capacity of 79.19 GiB of which 3.04 GiB is free. Including non-PyTorch memory, this process has 76.14 GiB memory in use. Of the allocated memory 73.13 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +[rank6]: Traceback (most recent call last): +[rank6]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4666, in +[rank6]: main() +[rank6]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4660, in main +[rank6]: train_and_eval(h, device) +[rank6]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4585, in train_and_eval +[rank6]: ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( +[rank6]: ^^^^^^^^^^^^^^^^^^^^ +[rank6]: File "/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py", line 4054, in eval_val_ttt_phased +[rank6]: (per_doc * activate_chunk_mask).sum().backward() +[rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward +[rank6]: torch.autograd.backward( +[rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward +[rank6]: _engine_run_backward( +[rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward +[rank6]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass +[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +[rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.75 GiB. GPU 6 has a total capacity of 79.19 GiB of which 3.04 GiB is free. Including non-PyTorch memory, this process has 76.14 GiB memory in use. Of the allocated memory 73.13 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) +[rank0]:[W503 07:55:14.353048190 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) +W0503 07:55:15.633000 106615 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 106684 closing signal SIGTERM +W0503 07:55:15.636000 106615 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 106685 closing signal SIGTERM +W0503 07:55:15.637000 106615 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 106686 closing signal SIGTERM +W0503 07:55:15.638000 106615 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 106687 closing signal SIGTERM +W0503 07:55:15.640000 106615 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 106688 closing signal SIGTERM +W0503 07:55:15.641000 106615 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 106689 closing signal SIGTERM +W0503 07:55:15.642000 106615 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 106690 closing signal SIGTERM +E0503 07:55:17.608000 106615 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 0 (pid: 106683) of binary: /usr/local/bin/python +Traceback (most recent call last): + File "/usr/local/bin/torchrun", line 7, in + sys.exit(main()) + ^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main + run(args) + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run + elastic_launch( + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent + raise ChildFailedError( +torch.distributed.elastic.multiprocessing.errors.ChildFailedError: +============================================================ +/workspace/parameter-golf/records/track_10min_16mb/2026-05-03_SP8192_RandProj384_PairMuonQK_LegalTTT/train_gpt.py FAILED +------------------------------------------------------------ +Failures: + +------------------------------------------------------------ +Root Cause (first observed failure): +[0]: + time : 2026-05-03_07:55:15 + host : f6b497c1899c + rank : 0 (local_rank: 0) + exitcode : 1 (pid: 106683) + error_file: + traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html +============================================================