diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..0c079c0b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +target/ +output/ +.git/ +.claude/ +knowledge/ +experiments/ +docs/ +*.vindex +/tmp/ diff --git a/.github/workflows/bench-regress.yml b/.github/workflows/bench-regress.yml new file mode 100644 index 00000000..8f4dcb91 --- /dev/null +++ b/.github/workflows/bench-regress.yml @@ -0,0 +1,98 @@ +# Bench regression detector — runs `make bench-check` on every PR +# against a baseline saved on `main`. Fails the workflow if any cell +# in the criterion bench suite regresses past Criterion's noise +# threshold. +# +# Surface covered (`make bench` = `make bench-quant + bench-matmul + bench-linalg`): +# - `quant_matvec`: Q4_0 / Q4_K / Q4_KF / Q6_K × 3 shapes × cpu/metal +# - `matmul`: f32 matmul + f32_gemv (lm-head) — cpu vs metal +# - `linalg`: cholesky + ridge solve (cpu only) +# +# That's the surface where the next throughput cliff would show up +# first. The 75 %-row drop in `q4_matvec_v4` would have shown as a 4× +# regression at `quant_matvec_q4_0/metal/lm_head_262144` weeks before +# goldens caught it. + +name: bench-regress + +on: + push: + branches: [main] + pull_request: + branches: [main] + # Manual trigger so a maintainer can re-baseline after intentional + # perf changes without waiting for the next merge to main. + workflow_dispatch: {} + +jobs: + bench: + # macos-14 = Apple Silicon (M1+). Required for the metal cells — + # without it, drop --features metal from FEATURES to skip them + # and run only the CPU surface on any runner. + runs-on: macos-14 + timeout-minutes: 90 + + steps: + - uses: actions/checkout@v4 + + # Cargo deps are big and stable across PRs — separate cache. + - name: Cache cargo deps + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-bench-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-bench- + + # Criterion baselines: write-through on main, read-only on PRs. + # Keyed by the run number so each main push refreshes the cache. + - name: Cache criterion baseline (main only) + if: github.ref == 'refs/heads/main' + uses: actions/cache@v4 + with: + path: target/criterion + key: ${{ runner.os }}-criterion-baseline-${{ github.run_number }} + restore-keys: | + ${{ runner.os }}-criterion-baseline- + + - name: Restore criterion baseline (PRs only) + if: github.event_name == 'pull_request' + uses: actions/cache/restore@v4 + with: + path: target/criterion + key: ${{ runner.os }}-criterion-baseline- + restore-keys: | + ${{ runner.os }}-criterion-baseline- + + - name: Save baseline (main only) + if: github.ref == 'refs/heads/main' + run: make bench-save + + - name: Check vs baseline (PRs + manual) + if: github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' + run: | + # Cold cache → bench-check prints "no baseline found" and + # exits 2. Treat as neutral: the first PR after CI is stood + # up shouldn't fail just because there's no baseline yet. + set +e + make bench-check + rc=$? + set -e + if [ "$rc" -eq 2 ]; then + echo "::warning::no criterion baseline cached; skipping regression check" + exit 0 + fi + exit "$rc" + + # On regression, attach the criterion HTML report so reviewers + # can see the per-cell delta without re-running locally. + - name: Upload criterion report on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: criterion-report + path: target/criterion/ + retention-days: 14 diff --git a/.github/workflows/larql-models.yml b/.github/workflows/larql-models.yml new file mode 100644 index 00000000..de8f7866 --- /dev/null +++ b/.github/workflows/larql-models.yml @@ -0,0 +1,68 @@ +# larql-models cross-platform CI +# +# Runs check + clippy + tests + bench test-mode on Linux, Windows, and macOS +# for every change to the larql-models crate. Validates cross-platform compatibility: +# - Linux (x86_64-unknown-linux-gnu) +# - Windows (x86_64-pc-windows-msvc) — HF cache path, mmap, path separators +# - macOS (aarch64-apple-darwin) — NEON SIMD paths + +name: larql-models + +on: + push: + branches: [main] + paths: + - 'crates/larql-models/**' + - 'Cargo.toml' + - 'Cargo.lock' + - '.github/workflows/larql-models.yml' + pull_request: + branches: [main] + paths: + - 'crates/larql-models/**' + - 'Cargo.toml' + - 'Cargo.lock' + - '.github/workflows/larql-models.yml' + workflow_dispatch: {} + +jobs: + test: + name: test · ${{ matrix.os }} + runs-on: ${{ matrix.os }} + timeout-minutes: 20 + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-14] + + steps: + - uses: actions/checkout@v4 + + - name: Install stable Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo registry + build artefacts + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-models-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-models- + + - name: Check (all targets) + run: cargo check -p larql-models --all-targets + + - name: Clippy (warnings as errors) + run: cargo clippy -p larql-models --all-targets -- -D warnings + + - name: Test + run: cargo test -p larql-models + + - name: Test benches + run: cargo test -p larql-models --benches diff --git a/Makefile b/Makefile index 06cd7a57..13122def 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build release test check clean fmt lint demos +.PHONY: build release test test-fast test-full test-integration test-models check clean fmt lint demos bench bench-save bench-check coverage coverage-summary # Build build: @@ -8,9 +8,31 @@ release: cargo build --release -p larql-cli # Test -test: +# +# Default test target is intentionally fast: no integration binaries, no +# model-backed ignored tests. Use `test-full` for the historical full +# workspace run, and `test-models` for real-model/vindex checks. +test: test-fast + +test-fast: + cargo test --workspace --lib --bins + +test-full: cargo test --workspace +test-integration: + cargo test --workspace --tests + +test-models: + cargo test -p larql-inference --test test_arch_golden -- --ignored + cargo test -p larql-inference --test test_logits_goldens -- --ignored + cargo test -p larql-inference --test test_gemma3_smoke -- --ignored + cargo test -p larql-inference --test test_generate_q4k_cpu -- --ignored + cargo test -p larql-inference --test bench_probe_latency -- --ignored --nocapture + cargo test -p larql-inference --test test_llm_dispatch -- --ignored --nocapture + cargo test -p larql-inference --test test_constrained_dispatch -- --ignored --nocapture + cargo test -p larql-inference --test test_trie_dispatch -- --ignored --nocapture + # Check (compile without building) check: cargo check --workspace @@ -26,12 +48,29 @@ lint: cargo clippy --workspace --tests -- -D warnings # All quality checks -ci: fmt-check lint test +ci: fmt-check lint test-full # Clean clean: cargo clean +# Benchmarks +# +# `bench` runs the full quant_matvec suite and writes HTML reports under +# `target/criterion/`. `bench-save` records a baseline named `main`; +# `bench-check` re-runs and fails if any cell regresses past Criterion's +# default noise threshold. Plug `bench-check` into CI to catch the next +# 4× throughput cliff (the kind the q4_matvec_v4 row-drop bug caused) at +# PR time, not at goldens-fail time weeks later. +bench: + cargo bench -p larql-compute --bench quant_matvec --features metal + +bench-save: + bash scripts/bench-regress.sh save + +bench-check: + bash scripts/bench-regress.sh check + # Demos demos: cargo run --release -p larql-models --example architecture_demo @@ -52,7 +91,43 @@ bench-core: bench-inference: cargo run --release -p larql-inference --example bench_inference -bench-all: bench-core bench-inference +# Vindex micro-benches — synthetic, fast, safe under load. +bench-vindex: + cargo bench -p larql-vindex --bench vindex_ops + +# Vindex production-dim scaling bench. Refuses if larql-server / router +# are alive (they distort 1-2 GB matmuls). Run alone, on a cool host; +# results feed PERFORMANCE.md. +bench-vindex-scaling: + @if pgrep -fl 'larql-(server|router)' >/dev/null 2>&1; then \ + echo "Refusing bench-vindex-scaling: larql daemons running. Stop them first."; \ + pgrep -fl 'larql-(server|router)'; \ + exit 2; \ + fi + cargo bench -p larql-vindex --bench vindex_scaling + +bench-all: bench-core bench-inference bench-vindex + +# Coverage — uses cargo-llvm-cov (install with `cargo install cargo-llvm-cov`). +# Writes an HTML report to coverage/ that can be opened in a browser. +# Scoped to larql-vindex by default since the audit owner cares about +# that crate; pass CRATE=… to scope elsewhere. +COVERAGE_CRATE ?= larql-vindex +coverage: + @if ! command -v cargo-llvm-cov >/dev/null 2>&1; then \ + echo "cargo-llvm-cov not installed. Install with:"; \ + echo " cargo install cargo-llvm-cov"; \ + exit 1; \ + fi + cargo llvm-cov --package $(COVERAGE_CRATE) --html --output-dir coverage + @echo "Report: coverage/html/index.html" + +coverage-summary: + @if ! command -v cargo-llvm-cov >/dev/null 2>&1; then \ + echo "cargo-llvm-cov not installed."; \ + exit 1; \ + fi + cargo llvm-cov --package $(COVERAGE_CRATE) --summary-only # Python extension (managed via uv) python-setup: diff --git a/README.md b/README.md index b54f4bdc..ebc35996 100644 --- a/README.md +++ b/README.md @@ -120,9 +120,77 @@ larql run gemma4-31b.client.vindex --ffn http://server.local:8080 \ ``` Other presets: `browse` (DESCRIBE/WALK only, no forward pass), `router` -(MoE router only, ADR-0003), `all` (full clone). See `larql slice --help` +(MoE router weights only), `expert-server` (MoE expert weights for remote +CPU serving — see below), `all` (full clone). See `larql slice --help` for the explicit part list. +### MoE expert sharding — experts on CPU-only remote machines + +For Mixture-of-Experts models (Gemma 4 26B A4B, Mixtral, etc.), the expert +bank can be served from **CPU-only machines with no GPU and no VRAM**. The +laptop runs attention and the router (hot path); the expert servers hold the +dormant majority as memory-mapped data. + +```bash +# Carve the client slice (attn + embed + router — 2.1 GB for 26B A4B Q4_K) +larql slice gemma4-26b-a4b.vindex --preset expert-server \ + -o gemma4-26b-a4b.expert-server.vindex + +# Two expert servers — experts 0-63 on one machine, 64-127 on another +larql serve gemma4-26b-a4b.vindex --port 8081 --experts 0-63 +larql serve gemma4-26b-a4b.vindex --port 8082 --experts 64-127 + +# Client dispatches expert calls directly +larql run gemma4-26b-a4b.vindex \ + --moe-shards "0-63=http://expert-a:8081,64-127=http://expert-b:8082" \ + "The capital of France is" +``` + +The `expert-server` preset includes everything the server needs to boot and +serve `POST /v1/expert/batch` calls: embeddings, norms, the interleaved Q4K +dense FFN, the per-layer expert weights (`layers/`), tokenizer, and manifest. + +**Single server** (simplest — one machine holds all experts): + +```bash +larql serve gemma4-26b-a4b.vindex --port 8080 +larql run gemma4-26b-a4b.vindex --moe-shards "0-127=http://server:8080" "..." +``` + +**2D layer × expert grid.** Layer shards can themselves fan out to expert +servers, so both axes scale independently: + +```bash +# Layer shard — runs attention for layers 0-14, delegates experts to CPU tier +larql serve gemma4-26b-a4b.vindex --port 8091 --layers 0-14 \ + --moe-shards "0-63=http://expert-a:8081,64-127=http://expert-b:8082" + +# larql-router routes by layer range; client just sends --ffn to the router +larql-router --port 9090 \ + --shards "0-14=http://layer-a:8091,15-29=http://layer-b:8092" + +larql run gemma4-26b-a4b.vindex --ffn http://router:9090 "..." +``` + +**Deploy expert servers to fly.io** (CPU-only, no GPU, tested): + +```bash +# Publish the expert-server slice to HuggingFace first +larql publish gemma4-26b-a4b.expert-server.vindex \ + --repo myorg/gemma-4-26b-a4b-vindex-expert-server --slices none + +# Then deploy — start.sh auto-downloads the vindex on first boot +fly deploy --app larql-expert-server --config deploy/fly/fly.toml --remote-only +``` + +See [`deploy/fly/`](deploy/fly/) for the Dockerfile, `fly.toml`, and startup +script. First boot downloads the vindex from HuggingFace to the persistent +volume (~2 min on fly's network); subsequent restarts are instant. + +Live demo: `https://larql-expert-server.fly.dev` serves +`hf://chrishayuk/gemma-4-26b-a4b-it-vindex-expert-server` — a real CPU-only +expert server on fly.io that you can point `--moe-shards` at. + **3-tier topology (ADR-0008).** When laptop RAM matters, split the embedding table out to its own server: @@ -269,7 +337,7 @@ larql-models Model config, architecture traits, weight loading, quant/dequa larql-vindex Vindex lifecycle: extract, load, query, mutate, patch, save ↓ larql-core Graph algorithms, merge, diff -larql-inference Forward pass, BLAS-fused attention, Metal GPU, WalkFfn +larql-inference Forward pass, BLAS-fused attention, Metal GPU (macOS), WalkFfn ↓ larql-lql LQL parser, executor, REPL, USE REMOTE client ↓ @@ -449,20 +517,22 @@ Dense and full-precision MoE models support all operations (DESCRIBE, WALK, INFE | Operation | Latency | tok/s | |---|---|---| -| **GPU Q4K decode (Metal, 34L, KV cache)** | **15.6ms** | **64** | +| **GPU Q4K decode (Metal, 34L, KV cache)** | **12.0ms** | **83.2** | | Walk prediction (CPU, no attention) | 33ms | 30 | | INFER walk (CPU, with attention, mmap FFN) | 517ms | 1.9 | | INFER dense (CPU, all matmul) | 535ms | 1.9 | | DESCRIBE (knowledge browse) | 33ms | — | -GPU decode per-stage breakdown: +GPU decode per-stage breakdown (post 2026-05-02 dispatch geometry fix): | Component | Time | % of total | |---|---|---| -| GPU forward (34 layers, Q4K/Q6K) | 14.1ms | 86% | -| LM head (Q4_0 synthesized from f16 embeddings) | 2.0ms | 12% | +| GPU forward (34 layers, Q4K/Q6K) | 11.16 ms | 86% | +| LM head (Q4_K stride-32 + correctness fix) | 1.85 ms | 14% | | Embed + norm + detokenize | <0.1ms | <1% | +vs ollama gemma3:4b on the same machine: 99 tok/s steady → **gap 1.18×**, was 1.30× before the fix. + CPU walk breakdown: | Component | Time | % of total | @@ -471,7 +541,29 @@ CPU walk breakdown: | FFN × 34 layers (walk) | 194ms | 36% | | Attention × 34 layers | 84ms | 16% | -Walk is **faster than dense** (517ms vs 535ms). GPU Q4K decode is **16× faster** than CPU walk. FFN down projection in walk reads from mmap'd vindex (zero-copy BLAS). Walk only needs ~3.5GB of model weights (attention + embeddings), not 16.6GB. No quantization. See [docs/ffn-graph-layer.md](docs/ffn-graph-layer.md) for architecture and [docs/inference-engine.md](docs/inference-engine.md) for engine details. +Walk is **faster than dense** (517ms vs 535ms). GPU Q4K decode is **23× faster** than CPU walk. FFN down projection in walk reads from mmap'd vindex (zero-copy BLAS). Walk only needs ~3.5GB of model weights (attention + embeddings), not 16.6GB. No quantization. See [docs/ffn-graph-layer.md](docs/ffn-graph-layer.md) for architecture and [docs/inference-engine.md](docs/inference-engine.md) for engine details. + +### MoE / grid (Gemma 4 26B A4B, M3 Max) + +| Topology | tok/s | Notes | +|---|---|---| +| **Local Metal MoE** | **18.9** | Measured 2026-05-04; MoE experts on CPU NEON. | +| 1-shard CPU/grid (loopback) | 18.3 | NEON Q4_K matvec on shard server, gRPC fan-in | +| 2-shard CPU/grid (loopback) | 17.3 | Parallel collect + parallel fire (`std::thread::scope` + `rayon::par_iter`) | +| SKIP_MOE ceiling | 56.8 | Attention + dense FFN only; theoretical max | + +### Dense remote-FFN (Gemma 4 31B Q4K, M3 Max, localhost) + +| Topology | tok/s | Notes | +|---|---|---| +| **Remote-FFN batch, Metal GPU server** | **6.5** | `larql bench --ffn URL --ffn-dispatch batch`; `--features metal-experts` on server. 153ms/tok: 92ms attn local + 60ms FFN remote. | +| Remote-FFN batch, CPU server | 1.6 | Same path, server uses CPU NEON instead of Metal. | +| Remote-FFN streaming (60 sequential HTTP) | 0.6 | Q8K wire format via `/v1/walk-ffn-q8k`, NEON down projection. | +| Local Metal | blocked | Heterogeneous attention (L5/L11/…/L59 head_dim=512 vs sliding head_dim=256) — A1-A3 roadmap. Est. ~12-15 tok/s after fix. | + +**Metal GPU FFN server** (`larql serve --ffn-only --features metal-experts`): pre-loads Q4K weight bytes into Metal buffers at startup via zero-copy mmap; dispatches `q4k_ffn_gate_up_8sg` + `geglu_gelu_tanh` + `q4k_matvec` per Q8K batch request — same shaders as local decode. **Build separation required**: `larql-cli` must be built WITHOUT `--features metal-experts` (adding it causes a 10.7 vs 18.9 tok/s regression on Gemma 4 26B-A4B due to Metal pipeline init overhead in the standard decode path). Only the server binary uses that flag. + +The grid path is the load-bearing primitive for the **"split large models in grids"** axis — Kimi K2.6 / DeepSeek V4-class models (1T params, ~600 GB Q4_K) only fit on a multi-shard deployment. See [`crates/larql-server/ROADMAP.md` §G-SCALE](crates/larql-server/ROADMAP.md) for the path forward. ## Residual Stream Trace @@ -528,6 +620,65 @@ store.residual(42) # zero-copy from mmap See [docs/residual-trace.md](docs/residual-trace.md) for the full writeup. +## Mechanistic interpretability surface + +LARQL exposes a programmatic forward-hook system for capture, ablation, +steering, activation patching, logit lens, and KV-cache surgery — the +primitives lazarus-style MCP servers (e.g. `chuk-mcp-lazarus`) build on +top of. All of it works on real models and on synthetic weights, with +zero overhead when no hook is registered. + +```rust +use larql_inference::forward::{ + RecordHook, SteerHook, ZeroAblateHook, trace_forward_full_hooked, + capture_donor_state, patch_and_trace, logit_lens_topk, embedding_neighbors, +}; + +// 1. Capture residuals at chosen layers (read-only). +let mut record = RecordHook::for_layers([12, 18, 24]); +trace_forward_full_hooked(&weights, &tokens, &[12, 18, 24], + /*activations=*/ false, 0, /*attention=*/ false, &ffn, &mut record); +let residual_at_18 = record.post_layer.get(&18).unwrap(); + +// 2. Logit lens at any layer — top-k, single-token tracking, full race. +let top_k = logit_lens_topk(&weights, residual_at_18.row(0).as_slice().unwrap(), 5); +let neighbors = embedding_neighbors(&weights, &query_vec, 10); + +// 3. Ablate or steer mid-forward. +let mut ablate = ZeroAblateHook::for_layers([14usize]); +let mut steer = SteerHook::new().add(20, steer_vec, 0.5); + +// 4. Activation patching — donor → recipient at chosen (layer, position) coords. +let donor = capture_donor_state(&weights, &donor_tokens, &[(10, 4)]); +let patched = patch_and_trace(&weights, &recipient_tokens, &donor, &[28]); +``` + +From Python via `larql._native.WalkModel`: +`capture_residuals`, `forward_with_capture`, `forward_ablate`, +`forward_steer`, `patch_activations`, `logit_lens`, `track_token_at`, +`track_race`, `embedding_neighbors`, `project_through_unembed`, +`embedding_for`, `unembedding_for`, `generate_with_hooks`. Returned +tensors are numpy arrays. + +**Backend split.** Hooks during single-forward (`trace_forward_full_hooked`, +all the capture/ablate/steer/patch primitives above) are zero-cost when +no hook is registered and run on the existing CPU forward path. Hooks +during **multi-token generation** (`generate_cached_hooked` / +`WalkModel.generate_with_hooks`) also use the CPU KV-cache path — the +Metal-fast `predict` is hook-free by design (kernels are fused; threading +hooks through would split the fast path even when unused). Mech-interp +tools want correctness over throughput, so the CPU-when-hooks-active +trade is the right one. + +End-to-end walkthrough on synthetic weights (no vindex required): + +```bash +cargo run --release -p larql-inference --example mech_interp_demo +``` + +The full surface is documented in `crates/larql-inference/ROADMAP.md` § +"P0: Mechanistic hooks (lazarus parity)". + ## Documentation | Doc | Description | @@ -542,14 +693,24 @@ See [docs/residual-trace.md](docs/residual-trace.md) for the full writeup. | [docs/ffn-graph-layer.md](docs/ffn-graph-layer.md) | FFN graph layer — mmap walk faster than dense (517ms vs 535ms), all 34 layers | | [docs/walk-boundary-sweep.md](docs/walk-boundary-sweep.md) | Walk boundary sweep — correctness proof across all layer boundaries | | [docs/residual-trace.md](docs/residual-trace.md) | Residual stream trace — decomposition, storage, tiered context | +| [docs/mech-interp.md](docs/mech-interp.md) | Mechanistic interp surface — hooks, lens, vocab proj, patching, KV surgery (Rust + Python) | | [docs/specs/trace-format-spec.md](docs/specs/trace-format-spec.md) | Trace file format specification (.bin, .bndx, .ctxt) | +## Platform Support + +| Platform | Compiles | GPU | BLAS | +|----------|----------|-----|------| +| macOS arm64 (M-series) | ✓ | Metal (`--features metal`) | Accelerate | +| Linux arm64 / x86_64 | ✓ | — (CPU fallback) | OpenBLAS | +| Windows arm64 / x86_64 | ✓ | — (CPU fallback) | OpenBLAS | + +macOS gets Metal GPU acceleration. Linux and Windows run the same CPU path (BLAS-fused attention + mmap walk FFN). All platforms require OpenBLAS on Linux/Windows — install via your system package manager (`apt install libopenblas-dev`, `vcpkg install openblas`). + ## Building & Testing -(Needs Openblas under Linux) ```bash cargo build --release # optimised build -cargo build --release --features metal # with Metal GPU backend +cargo build --release --features metal # with Metal GPU backend (macOS only) cargo test # all tests across all crates cargo test -p larql-inference # inference engine tests (109 tests) cargo test -p larql-inference --features metal # + Metal GPU tests (115 tests) @@ -558,6 +719,7 @@ cargo test -p larql-vindex # vindex storage + patch tests (104 tes # Inference engine examples cargo run --release -p larql-inference --example attention_demo # fused attention demo +cargo run --release -p larql-inference --example mech_interp_demo # capture / lens / ablate / steer / patch (synthetic — no vindex) cargo run --release -p larql-inference --example bench_attention # attention benchmarks cargo run --release -p larql-inference --example backend_demo --features metal # backend demo cargo run --release -p larql-inference --example bench_backend --features metal # backend benchmarks @@ -570,6 +732,11 @@ cargo run --release -p larql-vindex --example build_up_features -- path/to/vinde # Server (walk inference over HTTP) cargo run --release -p larql-server -- path/to/vindex --port 8080 +cargo run -p larql-server --example server_demo # synthetic HTTP surface demo +cargo run -p larql-server --example embed_demo # synthetic embed/logits/token demo +cargo run --release -p larql-server --example server_bench # synthetic server operation benchmark +cargo run --release -p larql-server --example bench_embed_server -- path/to/vindex +cargo test -p larql-router # static router + grid route-table checks # Vindex and LQL demos (synthetic — run in CI) cargo run -p larql-vindex --example demo_features # vindex feature showcase diff --git a/ROADMAP.md b/ROADMAP.md index d11828b3..2c9d4b46 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1,638 +1,254 @@ # LARQL Roadmap -Top-level plan of record. Per-crate specifics live in -`crates//ROADMAP.md`; this file tracks user-visible features, -the demo narrative, and cross-crate work. - -## Current state - -- **490 tests passing** across 14 suites, 0 build warnings. -- **Primary CLI verbs** in place: `run`, `chat`, `pull`, `list`, `show`, - `rm`, `link`, `serve`. Legacy research commands under `larql dev - ` with argv trampoline for backwards-compat. -- **Dual cache** (HuggingFace hub + `~/.cache/larql/local/`) with - shorthand resolution (`larql run gemma3-4b-it-vindex …`). -- **Remote FFN path (Phase 0 — dense):** `POST /v1/walk-ffn` - `full_output: true` returns hidden-size output vectors per layer; - `RemoteWalkBackend` in `larql-inference` drops into `predict_with_ffn` - unchanged; `larql run --ffn URL` + `larql serve --ffn-only` wire it - end-to-end. gRPC mirror also landed. -- **Vindex size reductions:** `--compact` (drops - `up_weights.bin`/`down_weights.bin`), `--drop-gate-vectors` (rebuilds - gate from `interleaved_q4k.bin` at load), `--quant q4k` implies f16 - on side-channel tensors. Combined: a new 31B q4k extract is **~22 GB - vs 52 GB before** (~60% smaller). +Top-level plan. Per-crate detail lives in each crate's own `ROADMAP.md`. +This file tracks the demo narrative, the critical path, and cross-crate sequencing. --- -## P0 — Act 2 of the demo: "The experts live elsewhere" - -### Phase 1 — MoE inference path (blocks Act 2) - -The whole Act 2 story is MoE-distributed. - -- [x] **Gemma 4 MoE architecture hooks** in - `crates/larql-models/src/architectures/gemma4.rs` — `is_hybrid_moe`, - `num_experts`, `num_experts_per_token`, `moe_router_key`, - `packed_experts_gate_up_key`, `packed_experts_down_key`, per-layer - norms (`pre_feedforward_layernorm_2`, `post_feedforward_layernorm_2`), - `moe_router_per_expert_scale_key`, `layer_scalar_key`. -- [x] **CPU MoE forward pass** (`crates/larql-compute/src/cpu/ops/moe.rs`): - BF16 expert dequant, router softmax, top-K selection, per-expert - gated FFN (gate_proj + up_proj + SiLU + down_proj), weighted sum, - post-experts RMSNorm. Wired into `decode_token` via GPU/CPU interleave. -- [x] **Metal decode with CPU MoE interleave** — GPU runs dense FFN per - layer, CPU reads `h_post_attn` (unified memory), runs MoE, adds - output to `new_h`. Layer scalar correctly applied only to the - combined FFN+MoE delta (`h_post_attn + scalar * (dense + moe)`), - not to the full residual. -- [x] **Gemma 4 26B A4B coherent output** — first end-to-end working - Metal inference (2026-04-24). The four fixes that had to land together: - 1. **Row-padded Q4_K/Q6_K storage** for matrices whose inner dim - isn't a multiple of 256 (26B A4B's dense `intermediate_size=2112` - → 8.25 super-blocks per row). Old extraction stored contiguously, - shader read wrong bytes for every `down_proj` row past 0. See - `pad_rows_to_256` in `crates/larql-vindex/src/format/weights/write.rs` - + `inter_padded` dispatch in `metal/decode/mod.rs`. - 2. **Parameter-free router RMSNorm** — HF's `Gemma4TextRouter.norm` - is `with_scale=False` (no tensor on disk). Added arch trait - `moe_router_norm_parameter_free()` and the `rms_norm_no_weight` - branch in `cpu/ops/moe/forward.rs`. - 3. **Outer `post_feedforward_layernorm.weight`** (un-suffixed) - extracted + applied to `(h1 + h2)` before the residual add — - distinct from the `_1` dense-branch norm. - 4. **`layer_scalar` scales the whole layer output** (`new_h *= - layer_scalar`) not the FFN delta — matches HF's final - `hidden_states *= self.layer_scalar` in `DecoderLayer.forward`. - Validated end-to-end by residual-diff against HF bf16 (see - Correctness infrastructure below): L0 `layer_out` cos improved from - 0.7018 → 0.9998; L29 cos from −0.27 → 0.93. -- [ ] **Batched MoE prefill** — current MoE prefill uses token-by-token - `decode_token` calls (correct, but O(seq_len) serial GPU dispatches - per layer). Replace with a batched prefill that processes all prompt - positions in one pass, interleaving GPU dense FFN and CPU MoE at each - layer. See `crates/larql-compute/src/metal/trait_impl.rs::prefill_q4` - and `full_pipeline.rs::dispatch_full_pipeline`. -- [ ] **Fix `dispatch_full_pipeline` layer_scalar** — currently scales - the full residual including `h_post_attn` instead of applying - `new_h *= layer_scalar` at the end of the layer (HF-accurate). The - decode path now does this correctly via `apply_whole_layer_scalar` - in `metal/decode/moe_combine.rs`; prefill path (only matters for - seq_len>1 with non-MoE `layer_scalar` models) still needs the same. -- [ ] **Chat-template-aware prompting** — 26B A4B is instruct-tuned - and answers trivia confidently only via the chat template. On raw - prompts it wanders (HF top-1 on "The capital of France is" is - `' CAP'`, not `' Paris'`). The architecture regression test now - asserts against what HF actually produces, but the `run` CLI should - auto-apply the template for IT models — see P1 "Chat template" below. -- [ ] **MoE-aware forward pass on CPU path** — `predict_q4k` / - `WeightFfn::forward` has no MoE. The non-Metal CPU path produces - wrong output on Gemma 4 26B. Wire `cpu_moe_forward` into - `larql-inference/src/forward/layer.rs`. -- [ ] Wire `RouterIndex` (already exists at - `crates/larql-vindex/src/index/router.rs`) into the client-side - forward pass so the router runs locally. - -### Phase 2 — Remote expert protocol (Act 2 wire format) - -- [ ] `POST /v1/expert/{layer}/{expert_id}` — input residual, output - residual delta (hidden-size). -- [ ] `POST /v1/expert/batch` — list of `{layer, expert_id, residual}`, - returns list of deltas. Collapses a layer's K experts into one HTTP - round trip per server. -- [ ] `--experts 0-31` flag on `larql serve` — load + serve a subset - of expert IDs so experts can be sharded across machines. -- [ ] `RemoteExpertBackend` in `larql-inference` — MoE-path analog of - `RemoteWalkBackend`. Handles the sharding map (expert ID range → - URL), parallel per-layer dispatch, per-expert error handling. - -### Phase 3 — LQL / CLI ergonomics - -- [ ] `USE "..." WALK ONLY WITH EXPERTS REMOTE { "range": "url", ... };` - grammar. Extend `crates/larql-lql/src/parser/lifecycle.rs` + executor. -- [ ] `RESHARD EXPERTS { ... };` statement for live redistribution - (for the "kill one shard, rewire on the fly" proof shot). -- [ ] `larql run --experts '0-31=URL1,32-63=URL2'` CLI flag (MoE - counterpart to `--ffn`). - -### Phase 4 — Data prep - -- [ ] `larql slice --parts attn,embed,norms,router,index,tokenizer` - (new subcommand) — carve an attention-only / router-only vindex out - of a full one without re-extracting from the source model. - -### Phase 5 — Deferred until film - -- [ ] GPU attention on the client side. `run_attention_block_gpu` - already exists in `crates/larql-inference/src/attention/gpu.rs` but - isn't the default path in `forward/layer.rs`. Wire Metal/CUDA into - the walk-only forward pass so client-side attention runs on GPU - while FFN/experts go remote. +## Crate roadmaps + +| Crate | Owns | +|---|---| +| [larql-compute](crates/larql-compute/ROADMAP.md) | Metal GPU kernels, MoE prefill, platform expansion | +| [larql-inference](crates/larql-inference/ROADMAP.md) | Forward pass, generation quality, KV engines | +| [larql-server](crates/larql-server/ROADMAP.md) | HTTP API, gRPC grid, remote expert protocol | +| [larql-cli](crates/larql-cli/ROADMAP.md) | CLI UX, sampling flags, streaming display | +| [larql-lql](crates/larql-lql/ROADMAP.md) | LQL grammar, INSERT/SELECT/USE extensions | +| [larql-core](crates/larql-core/ROADMAP.md) | Graph data model, algorithms, serialization | +| [larql-vindex](crates/larql-vindex/ROADMAP.md) | Vindex format, storage, extraction | +| [larql-models](crates/larql-models/ROADMAP.md) | Architecture definitions, model loading | --- -## P1 — Generation UX (chat template, sampling, stopping) +## Current state (2026-05-02) -The current `larql run` output loops ("ParisatthecapitalofFranceis...") because -three standard inference features are missing. All are independent and any one -improves the experience. +- **2,000+ tests passing** across the workspace, 0 build warnings. +- **Primary CLI verbs** in place: `run`, `chat`, `pull`, `list`, `show`, `rm`, `link`, `serve`, `bench`. +- **Gemma 3 4B Metal**: **83–84 tok/s** (Ollama steady: 98.5–99.7). **Gap: 1.18×** (was 1.30× before the 2026-05-02 dispatch-geometry fix). +- **Gemma 4 26B A4B Metal**: **19.4 tok/s** (was 5.1 — bug-locked under the same dispatch-geometry mismatch; correct multilingual output now). +- **Grid (CPU MoE on remote shards)**: 18.3 tok/s 1-shard / 17.3 tok/s 2-shard local-loopback, both with parallel collect (`std::thread::scope`) and parallel fire (`rayon::par_iter`). Multi-host LAN/cross-region scaling unblocked by F-COLLECT in `crates/larql-server/ROADMAP.md`. +- **Remote FFN (dense)**: `larql run --ffn URL` + `larql serve --ffn-only` wired end-to-end. +- **gRPC grid**: 2-shard self-assembling grid live-validated on 26B A4B. +- **4 KV-cache engines**: MarkovRS (287×), UnlimitedContext (254×), TurboQuant (4×), Apollo (20,000×) — all at ~95 tok/s on Gemma 3 4B Metal. -### Chat template -**Status**: Not started -**Impact**: High — instruction-tuned models (Gemma 3/4 IT, Mistral-Instruct) -loop or produce garbage without their expected prompt format. +--- -`larql run` sends raw text to the model. IT models expect a structured -turn format, e.g. Gemma 4: -``` -user -The capital of France is -model +## Demo narrative + +### Act 1 — "The model is the database" +Run Gemma 3 4B or 4 26B locally. The vindex is the model; `larql run` queries it. +Show: latency, footprint, `larql walk` tracing a fact through layers. + +**Status**: Works end-to-end. Needs chat-template + EOS fix so it doesn't loop. + +### Act 2 — "The experts live elsewhere" +Split a MoE model across machines. Client holds attention weights; each shard +holds a subset of expert IDs. The forward pass fans out to shards per token. + +**Status**: Server-side grid works. Missing: remote expert endpoints (`/v1/expert/*`), +`RemoteExpertBackend` client, chat-template-aware prompting. + +### Act 3 — "Replace an expert" +Swap expert 42 at layer 18 for a custom one. Observe the model's behaviour change. + +**Status**: Expert ID selection TBD. Requires Act 2 first. + +--- + +## P0 — Mechanistic surface (lazarus parity) + +Driver: replace the chuk-mlx engine in `chuk-mcp-lazarus` with larql. Lazarus +exposes ~77 inference-time MCP tools (capture, ablate, patch, steer, probe, +DLA, KV-surgery). Larql is currently strong on weight-level edits (MEMIT, KNN, +LQL) and weak on inference-time inspection/intervention. The 77 tools collapse +to one missing primitive: a **programmatic forward-hook system**. Once that +lands the rest is mostly Python wrappers. + +| # | Item | Crate | Status | +|---|------|-------|--------| +| M1 | `LayerHook` trait + CPU plumbing (read + write) | larql-inference | shipped | +| M2 | `RecordHook`, `ZeroAblateHook`, `SteerHook`, `CompositeHook` | larql-inference | shipped | +| M3 | Activation patching (cross-prompt residual swap) | larql-inference | shipped | +| M4 | Full logit lens — `logit_lens_topk`, `track_token`, `track_race` | larql-inference | shipped | +| M5 | `KvCache::{get_layer, set_layer, clear_layer, clone_layer_from, clone_layer_position_range}` | larql-inference | shipped | +| M6 | Hooks during multi-token generation (`generate_cached_hooked` on CPU; Metal `generate` stays fast by design) | larql-inference | shipped | +| M7 | `W_E` / `W_U` + `embedding_neighbors` + `project_through_unembed` | larql-inference | shipped | +| M8 | pyo3 `PyWalkModel` mech-interp methods (capture / ablate / steer / patch / lens / generate_with_hooks) | larql-python | shipped | + +Detail in `larql-inference/ROADMAP.md` § Mechanistic hooks (lazarus parity). + +--- + +## P0 — Best-in-class mechanistic interpretability engine + +Driver: make LARQL's executed mechanisms queryable, attributable, patchable, +and reproducible. This is the layer above lazarus parity: not just hooks, but +evidence-grade traces and causal operators over the actual vindex-backed +inference path. + +| # | Item | Crate | Status | +|---|------|-------|--------| +| MI0 | Faithful residual DAG: TRACE uses the canonical layer runner and pins additive reconstruction | larql-inference | shipped | +| MI1 | Python `WalkModel.trace()` / `patch_activations()` use `WalkFfn` instead of dense fallback | larql-python + larql-inference | shipped | +| MI2 | Backend-parametric donor capture and activation patching | larql-inference | shipped | +| MI3 | Strict trace artifacts: complete ordered chains, exact file length, `TRACE SAVE` requires `POSITIONS ALL` | larql-inference + larql-lql | shipped | +| MI4 | Golden parity: TRACE final residual/logits match canonical forward; extend to WalkFfn, patched vindex, Q4K, MoE | larql-inference | partial — dense/custom backend pinned | +| MI5 | Rich attribution objects: attention-head writes, FFN feature activations, router/expert decisions, provenance | larql-inference + larql-python | planned | +| MI6 | Causal operators beyond residual replacement: head/feature/router/expert/KV patching | larql-inference + larql-python | planned | +| MI7 | Q4K/MoE trace and patch parity with explicit precision caveats | larql-inference + larql-vindex | planned | +| MI8 | Python experiment ergonomics: batched prompts, donor/recipient alignment, causal metrics, reproducibility metadata | larql-python | planned | + +Near-term order: finish MI4 parity coverage, then add attribution records where +the forward path already exposes data, then expand patching operators one +mechanism at a time. + +--- + +## P1 — Research stack promotion: OV/RD → engine primitives + +Driver: make LARQL one of the strongest practical mechanistic +interpretability stacks by promoting reusable experiment plumbing into +stable engine APIs, while leaving fast-moving hypotheses in +`larql dev ov-rd` and Python artifact analysis. + +| # | Item | Crate | Status | +|---|------|-------|--------| +| R1 | Promote Q4K per-layer tensor insertion/removal from `ov_rd` into `larql-inference::vindex` | larql-inference | shipped | +| R2 | Add Q4K hidden forward with `LayerHook`/intervention support | larql-inference | shipped | +| R3 | Add pre-W_O capture/replacement hook adapters so experiments stop manually driving full layer loops | larql-inference | shipped | +| R4 | Define a compact research trace artifact contract for prompt ids, tokens, layer inputs, pre-W_O rows, oracle codes, logits, and metrics | larql-inference + larql-cli | planned | +| R5 | Keep PQ/address/codebook experiments in `larql dev ov-rd`; move only stable runtime contracts into engines | larql-cli | ongoing | + +Rule of thumb: engine code owns reusable capture/intervention/runtime +primitives; `ov_rd` owns experiment orchestration, PQ variants, address +probes, and report schemas until a runtime contract survives repeated +experiments. + +--- + +## P0 — Interpretability truthfulness + commit semantics + +Driver: make the current edit model honest before the demo, then earn the +stronger "INSERT commits into weights" story. Today default `INSERT MODE KNN` +is a retrieval overlay persisted in `knn_store.bin`; `COMPILE INTO VINDEX` +bakes compose/MEMIT overlays but carries that KNN sidecar forward. That is a +snapshot/package operation, not a mechanical commit of the journal into FFN +features. + +| # | Item | Crate | Status | +|---|------|-------|--------| +| T1 | Tag KNN overrides visibly in `INFER`, `EXPLAIN INFER`, and `TRACE` as post-logits retrieval events, including the model's unoverridden top-1 | larql-lql + larql-inference | planned | +| T2 | Fix decomposed `TRACE` to route through the shared layer sequence, including PLE/layer-scalar deltas or equivalent captured intermediates | larql-inference | shipped | +| T3 | Make Python `WalkModel.trace()` use the vindex `WalkFfn`/patch overlay rather than dense `WeightFfn` | larql-python + larql-inference | shipped | +| T4 | Replace gate-KNN absolute-dot feature ranking in interpretability displays with post-activation magnitude, or filter ghost negative gates after activation | larql-vindex + larql-inference | planned | +| T5 | Fix L1 FFN cache activation capture: cache activations with outputs or bypass cache when activations are requested | larql-inference | planned | +| T6 | Rename residual-capture embedding-neighbor fields (`top_token`) or add separate true logit-lens fields | larql-inference + larql-models | planned | +| T7 | Pin TRACE evidence with final residual/logit parity tests across dense, custom backend, WalkFfn, patched vindex, Q4K, and MoE paths | larql-inference | partial | +| C1 | Add explicit compile modes: default commit/materialize semantics vs `SNAPSHOT` preserving `knn_store.bin` | larql-lql + larql-vindex | design | +| C2 | Implement KNN materialization by lowering retrieval entries into compose/MEMIT/FFN edits, then dropping or marking committed sidecar entries | larql-lql + larql-vindex + larql-inference | planned | +| C3 | Add acceptance tests: session KNN equivalence, trace conversion, and generalization beyond stored prompts | larql-lql + larql-inference | planned | + +Acceptance target for materialization: + +```text +INFER(session_with_knn, q) == INFER(materialized_vindex, q) ``` -Without it, the model sees a bare continuation task and loops greedily. - -Fix: read `tokenizer_config.json` from the vindex (already present for -HF-extracted models — lives next to `config.json`). Parse the -`chat_template` Jinja field. Apply it in `larql run` before tokenising. -`minijinja` crate is the standard Rust choice. `larql chat` should always -apply the template; `larql run` can expose `--no-chat-template` for raw use. - -### EOS detection and stop strings -**Status**: Partial — `generate.rs` checks for ``, ``, -`<|endoftext|>` but Gemma 4 uses `` which is not in that list. -**Impact**: High — without EOS stopping, greedy decode runs to `--max-tokens`. - -Fix: read `eos_token_id` (and `eos_token_ids` list) from `config.json`; -also read `stop_strings` from `generation_config.json` (Gemma 4 lists -`` there). Check decoded token string + token ID at every -step in `generate.rs`. `run_cmd.rs` could expose `--stop STRING` for -overrides. - -### Token spacing / detokenisation display -**Status**: Not started -**Impact**: Medium — "Paris at the capital..." prints as "Parisatthecapital". - -HuggingFace tokenizers use a leading-space convention (`▁Paris`) — the -`tokenizers` crate's `decode` already handles this when -`skip_special_tokens = true`. The bug is likely that `tokenizer.decode` -is called per-token with `false` (keeps `▁` prefix stripped) instead of -accumulating and decoding the full sequence, or that `trim()` is stripping -the leading space. Fix in `generate.rs` decode loop: `decode(&[tid], false)` -and keep the raw string; only trim the very first token. - -### Sampling (temperature / top-p / top-k) -**Status**: Not started -**Impact**: Medium for quality, needed for non-deterministic output. - -Current path is always greedy (argmax). Add `--temperature F`, `--top-p F`, -`--top-k N` flags to `run_cmd.rs`. Sampling happens after the lm_head -scores are computed in `generate.rs` — no GPU changes required. - -### Repetition penalty -**Status**: Not started -**Impact**: Medium — practical fix for the greedy looping problem without -requiring a full chat template. Useful for raw-prompt (`larql run`) and -base models where no chat template exists. - -Add `--repetition-penalty F` (default 1.0 = off). Before argmax / sampling, -divide each token's logit by the penalty if that token appears in the -recently generated window. Standard implementation: logit ÷ penalty for -tokens in the last N generated positions. No GPU changes required — purely -a logits post-processing step in `generate.rs`. - -### Multi-turn conversation state -**Status**: Not started — `larql chat` resets KV cache per turn today. -**Impact**: High — "chat" implies the model remembers what it said. Without -this, each line in chat mode is an independent cold-start forward pass. - -Fix: maintain a running `token_ids` buffer across turns in `run_cmd.rs`. -After each model response, append the response token IDs to the buffer -before the next user turn. Wrap each turn pair in the chat template -(`user … model …`) incrementally. Pass the full buffer -to `generate()` so the KV cache grows across turns. Expose `--max-context N` -to bound memory (evict oldest turns when the context window fills). - -### Token streaming - -### Long context / dynamic KV cache -**Status**: Hard-capped at 4096 tokens today. -**Impact**: High — Gemma 4's headline feature is 1M context. 4096 is a -non-starter for long conversations and the demo's "database" framing. - -Two parts: -1. **Configurable max** — expose `--max-context N` (default 8192). - `KVCache::new_per_layer` already takes `max_seq`; thread `N` through - `prefill_q4` / `decode_token` call sites in `generate.rs`. -2. **Dynamic growth** — when `current_len` reaches `max_seq`, either - evict the oldest window (sliding, already implemented as - `--kv-cache markov-bounded`) or double the buffer. The Metal KV - cache buffers are pre-allocated; growth requires a realloc + copy on - the GPU side. A simpler interim: warn and truncate at `max_seq`, - document as a known limit. -**Status**: Not started -**Impact**: High for UX — without streaming, the CLI is silent until all -`--max-tokens` are done. A 64-token run on Gemma 4 26B takes ~10s with no -output; streaming makes it feel interactive immediately. - -Fix: `generate.rs` currently collects tokens into a `Vec` and returns. -Change to accept a `on_token: impl FnMut(&str, f64)` callback (or a -`std::sync::mpsc::Sender`). In `run_cmd.rs`, the callback prints each token -to stdout and flushes. The `larql serve` OpenAI-compatible path (`/v1/chat/completions` -with `stream: true`) would use SSE chunks from the same callback. -Chat mode in `run_cmd.rs` already flushes stdout per turn — streaming -just moves the flush inside the generate loop. - -### OpenAI-compatible `/v1/chat/completions` -**Status**: Not started — `larql serve` has custom endpoints but no -OpenAI-compatible chat surface. -**Impact**: High for adoption — makes LARQL a drop-in backend for -Continue.dev, Open WebUI, LiteLLM, and any tool that speaks the -OpenAI API. The "you can do this too" demo moment needs a working URL. - -With chat template + streaming landing, this is largely wiring: -- `POST /v1/chat/completions` — accept `{model, messages, stream, - temperature, max_tokens}`, apply the model's chat template to the - `messages` array, call `generate()`, return `ChatCompletionResponse` - (non-stream) or SSE `data: {"choices":[{"delta":...}]}` chunks (stream). -- `GET /v1/models` — return the loaded vindex name so clients can - enumerate available models. -- Wire into `larql-server/src/routes/` alongside the existing endpoints. - -### Auto-extract on `larql run hf://` -**Status**: Not started. -**Impact**: High for adoption — the current flow is `larql extract` → -`larql link` → `larql run`. Three commands before inference starts. -The "you can do this too" moment needs one. - -Fix: in `cache::resolve_model`, if the shorthand looks like `hf://owner/name` -and no cached vindex matches, offer to run `larql extract` inline -(with a confirmation prompt or `--yes` flag). Download the safetensors -from HuggingFace, stream-extract to a temp directory, move to the -local cache, then proceed with inference. Re-uses the existing -`larql extract` pipeline — the new code is only in the cache resolver -and a progress display wrapper. - -### Gemma 3 4B regression smoke test -**Status**: Not started — no CI check verifies correctness after -compute / inference changes. -**Impact**: Medium — after the MoE and layer_scalar changes, nothing -formally verifies Gemma 3 4B still produces "Paris" at expected -probability. One bad merge could silently break the most-used model. - -Fix: add a `tests/integration/` test (or `larql-cli` example) that -loads `gemma3-4b-q4k-streaming` (already in the local cache), runs -`larql run "The capital of France is" -n 1 --metal`, and asserts the -first token is "Paris". Gate on `CI_INTEGRATION=1` so it doesn't run -on every PR but does run before release branches. + +for affected canonical prompts, plus a stronger trace/generalization check: +session trace reports pending retrieval; materialized trace shows residual/FFN +evidence; nearby unstored prompts behave through the materialized edit rather +than through a lookup sidecar. + +Until C1-C3 ship, video language should distinguish three mechanisms: +KNN journal/retrieval overlay, compose FFN overlay, and compiled/baked weights. --- -## P1 — Autoregressive generation quality - -### CPU KV cache for autoregressive generation — **SHIPPED** - -Two-phase autoregressive decoder in `larql-inference/src/forward/kv_generate.rs`: - -- **Prefill** uses `run_attention_with_kv` to capture post-RoPE K and - post-V-norm V per layer into a `KvCache`. -- **Decode** step in `crates/larql-inference/src/attention/decode.rs`: - `run_attention_block_decode_step` takes the new token's hidden + - the layer's existing cache, computes Q/K/V for just that row with - `apply_rope_partial_at(position=cached_len)`, concatenates the new - K/V onto the cache, runs `gqa_attention_decode_step` (O(cached_len) - per head), returns updated cache. - -Backend-agnostic via `FfnBackend` — works with `WalkFfn` (local) and -`RemoteWalkBackend` (FFN over HTTP). Measured on Gemma 3 4B f32: - -- **Local, no cache (before):** ~1.2 s per decode step, O(N²) growing -- **Local, KV-cached (now):** ~0.6 s/token steady -- **Remote FFN, KV-cached (now):** ~0.5-0.6 s/token steady — same - protocol as the no-cache version, just many fewer tokens re-shipped - -Limitations: -- Skips Gemma 4 E2B per-layer embeddings (PLE) and layer-scalar - application in the decode loop. Fine for Gemma 3. For full - Gemma 4 correctness wire `apply_per_layer_embedding` + `apply_layer_scalar` - into `generate_cached`'s decode layer. -- Q4K CPU path still uses its own no-cache loop (`run_q4k_generate_cpu`). - Q4K + Metal shader `generate()` remains the fast Q4K path. - -### KV cache strategy selector — **SHIPPED (partial)** - -`larql run --kv-cache ` selects how past-token state is kept: - -- `standard` *(default)* — full FP32 K/V, unbounded. Shipped. -- `markov-bounded` — sliding window (StreamingLLM-style). Shipped. - Pass `--context-window N` for the window size. Older tokens drop - off; memory stays O(window) regardless of generation length. -- `none` — re-run full forward per decode step. O(N²). Shipped as - correctness fallback. - -Not yet wired into the live decode path (all in `crates/kv-cache-benchmark/`): - -- `markov-full` — active residual window + cold-tier reconstruction - via checkpoint layers. Compressed storage via residuals not K/V. - See `crates/kv-cache-benchmark/src/markov_residual/`. Needs a - reconstruction primitive that rehydrates K/V for cold-tier - positions from `token_ids + checkpoint_residual`. -- `turboquant` — per-tensor Q4/Q8 compression of cached K/V. See - `crates/kv-cache-benchmark/src/turboquant/`. Needs per-step - quantize/dequantize around the cache append. -- `graph-walk` — experimental, unclear production viability. - -### Shader attention + remote FFN - -### Metal speedup for non-Q4K decode - -**Status:** backend is auto-detected and threaded through -`generate_cached_backend`, but in practice **single-token decode -matmuls stay on CPU** because they fall below the Metal backend's -calibrated FLOP threshold (~500M). Per-layer projections on 4B are -only 5-7M FLOP each — far under the break-even point where GPU -dispatch overhead is worth paying. - -**What this means today:** -- `larql run` on f16/f32 vindexes uses CPU BLAS projections regardless - of `--metal` availability. The KV cache is still the decisive win - (~6× speedup vs no-cache). -- `larql run --metal` on a **Q4K vindex** routes to - `larql_inference::layer_graph::generate` (the shader - `full_pipeline_q4` — all layers fused in one command buffer, KV- - cached decode on GPU). This is the real GPU path. - -**What would actually win on f16/f32:** -1. **Fused f16 full_pipeline shader** — same structure as Q4K's - `full_pipeline` but with f16 weights. Multi-day shader work. -2. **Batched / speculative decode** — emit N tokens per forward pass - (draft model, Medusa heads, or speculative sampling). N×M FLOP - per matmul would clear the threshold. Compatible with remote FFN - if the batching happens client-side. - -See `crates/larql-compute/benches/{linalg,matmul}.rs` and the -many `crates/larql-compute/examples/profile_*.rs` for the measured -GPU-vs-CPU break-even curves — the threshold isn't arbitrary. - -### Shader attention + remote FFN (Act 2 endgame) - -Q4K + Metal + remote FFN — the ultimate Act 2 configuration. The -shader pipeline (`full_pipeline_q4` / `decode_token`) currently -dispatches attention AND FFN as fused GPU kernels reading from the -Q4K mmap. For remote FFN we'd need to decompose per-layer into: -attention-only GPU kernel → copy residual to host → HTTP round trip -→ copy FFN output back to GPU → next layer's attention. Per-layer -host+network hop kills throughput unless we batch across layers or -use async pipelining. - -Worth doing for the Act 2 demo but non-trivial. See -`larql-inference/src/layer_graph/{generate,pipeline_layer,prefill}.rs` -— the fused paths need splitting at the attention/FFN seam. - -## P1 — Loose ends in shipped features - -### `--compact` loader reconstruction — WalkFfn-only today - -`larql extract --compact` drops `up_weights.bin` + `down_weights.bin` -from the extract. `WalkFfn` (the production inference path) works fine -— it reads feature-major `{up,down}_features.bin` directly. The dense -ground-truth path (`WeightFfn`, used by `larql dev walk --compare` for -validation) panics with a clear message. - -**Why deferred.** The naive fix is to reconstitute -`Array2` tensors in `ModelWeights.tensors` at load time. For -`down_proj` this requires a transpose (feature-major `[intermediate, -hidden]` → safetensors `[hidden, intermediate]`) which means an owned -copy — **~27 GB of extra heap on 31B**, not viable. - -**Proper fix.** Refactor `WeightFfn::forward` (or `ModelWeights`) to -accept feature-major views and pass the transpose flag through to BLAS -gemm. Cross-cutting change: `crates/larql-inference/src/ffn/weight.rs`, -`crates/larql-inference/src/model.rs`, and the `dot_proj` helpers. ~1 -focused session. - -**Impact.** Unblocks `--compact --compare` for validation workflows. -Does not affect `larql run` or the demo. - -### MoE compact mode — refused today - -`larql extract --compact` on an MoE architecture refuses with: -> *"ffn_compact not yet supported for MoE architectures — per-expert -> feature-major files don't exist yet"* - -**Why deferred.** Two blockers: - -1. **Router lives in `up_weights.bin`.** The MoE write path stuffs - per-expert up weights *and* the router matrix together into - `up_weights.bin`. Skipping that file loses the router, so the model - can't dispatch to experts at all. Fix: split the router into its - own file (`router_weights.bin` already exists as the intended home - — see `crates/larql-vindex/src/index/router.rs`). -2. **No per-expert feature-major files.** `up_features.bin` / - `down_features.bin` are single-matrix-per-layer. MoE-compact would - need per-expert equivalents (~N× the file count or a new layout), - plus a tool that produces them. No consumer exists yet. - -**When to do it.** Pairs naturally with Phase 1 (MoE inference path) -and Phase 2 (per-expert server endpoint). Building those requires a -per-expert-addressable storage layout anyway; compact-MoE falls out of -it. - -### `larql dev walk --compact` compatibility - -`larql dev walk --compare` against a `--compact` vindex panics (see -above). The panic message points at `WalkFfn` but doesn't explain -`--compare` is the specific operation that's blocked. Improve the -error or disable the `--compare` flag at arg-parse time when the -target vindex is compact. - -### Cross-vindex dedup (tokenizer, down_meta) - -Tokenizer (~32 MB) and `down_meta.bin` (~30 MB) are identical across -different-precision extracts of the same base model. With ~7 linked -vindexes in the local cache that's ~200 MB of duplicate data. Low -priority — worth doing as a content-addressed store if the cache -grows, otherwise skip. +## P1 — Model architecture independence hardening + +Driver: keep LARQL from becoming "Gemma-shaped with exceptions." The core +`ModelArchitecture` trait is the right boundary, but several production paths +still infer family from strings, pass scalar attention geometry through +per-layer pipelines, or advertise architectures whose extraction/inference +contracts are incomplete. + +| # | Item | Crate | Status | +|---|------|-------|--------| +| AI1 | Gate supported architecture families by executable contracts: extraction, vindex weight writing, forward/decode, trace, and prompt rendering | larql-models + larql-vindex + larql-inference | planned | +| AI2 | Implement or explicitly reject MLA architectures in vindex writers and inference; DeepSeek is detected today but `mla_*` tensors are not consumed outside `larql-models` | larql-models + larql-vindex + larql-inference | planned | +| AI3 | Remove scalar attention-geometry fallbacks from backend decode APIs; allocate KV/cache/scratch from `FullPipelineLayer` per-layer shapes everywhere | larql-compute + larql-inference | planned | +| AI4 | Replace vector-only extraction's model-name family guesses with explicit metadata or validated architecture input | larql-vindex | planned | +| AI5 | Roll validated loading/detection through inference, extraction, CLI, and server entry points where missing config should fail fast | larql-models consumers | planned | +| AI6 | Harden vindex extraction/write paths with explicit capability gates, named manifest/tensor tags, and tests proving unsupported attention layouts fail before writing partial indexes | larql-vindex + larql-models | next | + +Acceptance target: adding a new transformer architecture should require changes +inside `larql-models::architectures/*` and explicit capability decisions at +storage/forward boundaries, not incidental string matches or hidden Gemma/Llama +defaults in extraction and decode. --- -## P2 — Demo production - -### Pre-film checklist for the Gemma 4 MoE video - -- [ ] Confirm Gemma 4 26B A4B config once the model card is public: - expert count per layer, top-K, exact active-param figure, GQA ratio. - Every `~` figure in `docs/demo-script-gemma4-moe.md` needs a real - number before recording. -- [ ] Measure real footprint + latency on `google/gemma-4-31b-it` for - Act 1. Replace every `~` in the Act 1 section. -- [ ] Reliability pass on `RemoteWalkBackend` (timeouts, retries, - mid-layer failure, partial shard outage). A hung HTTP call during - recording kills the take. -- [ ] `RemoteExpertBackend` (doesn't exist yet — see Phase 2) same - pass. -- [ ] Decide the repo-public date. `cargo install larql-cli && larql - serve` should be live the week the video drops so "you can do this - too" lands with a working command. -- [ ] Pick expert IDs for the Video 3 teaser swap — one that fires on - medical prompts, one that doesn't — so the "replace expert 42 at - layer 18" shot lands concretely. - -### Memory-footprint `--ffn-only` on the server - -`larql serve --ffn-only` today is an operating-mode declaration — it -disables `/v1/infer`, advertises `mode: ffn-service` in `/v1/stats`, -but still loads full `ModelWeights` into RAM. A real FFN-service -doesn't need attention weights resident. - -Add `load_model_weights_ffn_only` to `larql-vindex` that skips -attention tensors on the server side. Payoff: serve an MoE without -the attention weights taking a third of RAM. +## Critical path (P0 — what blocks the demo) + +Items in order. Each depends on the one above it. + +| # | Item | Crate | Status | +|---|------|-------|--------| +| 1 | Chat template + EOS stop | larql-inference + larql-cli | not started | +| 2 | Token streaming | larql-inference + larql-cli | not started | +| 3 | **Per-layer FFN format** (`layers/`, GPU dispatch) Phase 2: pre-alloc buffers | larql-vindex + larql-compute | shipped — `MoeScratch` pre-allocates once per decode call; combined with the 2026-05-02 dispatch-geometry fix, 26B A4B Metal now runs at **19.4 tok/s** (was bug-locked at 5.1) | +| 4 | MoE-aware CPU forward pass (non-Metal fallback) | larql-inference | not started | +| 5 | Wire `RouterIndex` client-side | larql-inference | not started | +| 6 | `POST /v1/expert/{layer}/{expert_id}` | larql-server | not started | +| 7 | `POST /v1/expert/batch` | larql-server | not started | +| 8 | `--experts 0-31` flag on `larql serve` | larql-server | not started | +| 9 | `RemoteExpertBackend` client | larql-inference | not started | +| 10 | Reliability pass (timeouts, retries) | larql-server | not started | + +Items 1–2 are needed for Act 1. Item 3's MoE performance gate landed +2026-05-02: 26B A4B Metal now runs at 19.4 tok/s (was 5.1, bug-locked +under the dispatch-geometry mismatch in `moe_dispatch.rs`). SKIP_MOE +ceiling 56.8 tok/s — remaining headroom is real expert-dispatch work, +not allocation. Items 4–10 are needed for Act 2. See +`larql-vindex/ROADMAP.md P0` and `larql-server/ROADMAP.md` (F-COLLECT, +F-LOCAL-MOE, G-SCALE) for the next levers. --- -## Done (ship log) - -### Gemma 4 26B A4B end-to-end correctness (2026-04-24) -Closed four independent gaps that together produced garbage output on -the hybrid-MoE 26B A4B model; aligned non-MoE models (Gemma 3 4B, -Gemma 4 31B, Mistral 7B) were unaffected and continue to pass. See -`crates/larql-compute/ROADMAP.md` P0.5 for full per-fix detail. - -- **Q4_K/Q6_K row alignment** — 26B A4B's `intermediate_size=2112` - isn't a multiple of 256, breaking `down_proj` matvec on any - matrix whose inner dim isn't super-block-aligned. Fix: per-row - zero-pad during extraction (`pad_rows_to_256`), dispatch with - `K = inter_padded`. Future vindexes with any non-256 inner dim - now work automatically. -- **Parameter-free router RMSNorm** — Gemma 4's `Gemma4TextRouter.norm` - has no learned weight. Added arch flag + `rms_norm_no_weight`. -- **Outer `post_feedforward_layernorm`** extracted and wired — was - being conflated with the `_1` dense-branch norm. -- **`layer_scalar` applied to whole layer output** not the FFN - delta — matches HF's `hidden_states *= self.layer_scalar`. - -### Correctness infrastructure (2026-04-24) -Tooling to keep the above from regressing, and to localise any -future cross-model forward-pass bug to the right layer / block: - -- **Architecture regression suite** — - `crates/larql-inference/tests/test_arch_golden.rs` runs one - `#[test]` per `(arch × backend)`. Skip-if-missing for vindex - cache, so CI stays green but local runs catch breakage - immediately. Covers Gemma 3, Gemma 4 dense, Gemma 4 hybrid MoE, - Llama 2 base, Mistral 7B base across GPU + CPU backends. -- **HF-reference residual diff** — `LARQL_DUMP_RESIDUALS=` - writes every layer's `layer_in` / `h_post_attn` / `layer_out` in - a binary format symmetric with `/tmp/hf_residuals.py` (hooks - `Gemma4TextDecoderLayer` in HF transformers). `/tmp/diff_residuals.py` - prints per-layer cosine + RMS-delta and points at the first - layer where attention vs FFN diverges. Caught the row-alignment - bug by bisecting L0 sub-components (attention matched at - cos=0.9989; down_proj matvec dropped to 0.023). -- **L0 intermediate dumps** (`LARQL_DUMP_L0=`) — writes - gate_out, up_out, GEGLU act, down_out, h1, moe_out for the first - layer. `/tmp/diff_l0_gate_up.py` computes HF's manual MLP from - the captured pre-norm input and diffs each projection. -- **Vindex surgical patcher** — - `crates/larql-cli/examples/patch_down_proj.rs` re-quantises - `layers.N.mlp.down_proj.weight` entries with row-padding from an - existing vindex. Avoids a ~hour-long 42 GB re-extract when only - one tensor class needs redoing. - -### CLI redesign (primary / dev split) -- New verbs: `run`, `chat`, `pull`, `list`, `show`, `rm`, `link`. -- Research commands moved under `larql dev `; legacy names - transparently trampolined. -- Dual cache (HuggingFace hub + `~/.cache/larql/local/`) with - shorthand resolution and source disambiguation. -- `larql serve --ffn-only` flag propagated through CLI → server → - `/v1/stats`. - -### Phase 0 — dense remote FFN baseline -- `POST /v1/walk-ffn` extended with `full_output: true` + - `seq_len: N`. Server runs the architecture-correct `WalkFfn`, - returns `[seq_len × hidden]` row-major. -- gRPC mirror (`WalkFfnRequest` / `WalkFfnLayerResult` proto fields). -- `RemoteWalkBackend` in `larql-inference` implements `FfnBackend`, - slots into `predict_with_ffn` unchanged. -- `larql run --ffn URL` + `larql dev walk --ffn-remote URL` CLI flags. -- `examples/remote_walk_parity.rs` localhost parity probe. - -### Vindex size reductions -- `--quant q4k` defaults gate_vectors + embeddings to f16 (previously - f32 — silent ~32% bloat on every q4k extract). -- `--compact` skips `up_weights.bin` + `down_weights.bin` (saves 3.4 - GB on 4B f16 / ~14 GB proportionally on 31B non-Q4K). -- `--drop-gate-vectors` skips `gate_vectors.bin` on Q4K extracts; - loader reconstructs from `interleaved_q4k.bin` at load time. 2.3 s - on 4B / ~12 s on 31B cost, saves 1.7 GB / 13.9 GB respectively. - Measured via `crates/larql-vindex/examples/bench_gate_dequant.rs`. - -### Decoupled-inference memory asymmetry (real, pre-load filtered) -- `LoadWeightsOptions { skip_attn, skip_ffn, skip_lm_head, skip_embed }` - filters weight manifest entries before mmap+decode — peak RSS - reflects only what the caller wanted (no allocator-pooling lie). -- Server `--ffn-only`: skips attn + ffn + lm_head + embed at load. - Walk-ffn endpoint uses `walk_ffn_full_mmap` which reads - feature-major mmap, not heap tensors. -- Client `--ffn URL`: skips FFN tensors at load. Attention + embed + - norms + lm_head only on heap. -- Measured on Gemma 3 4B f32 (`gemma3-4b-v2.vindex`): - - Server RSS: 12.8 GB idle → **12.8 GB through inference** (never grew) - - Client load: 22.5 s → **7.9 s** (2.8× faster) - - Forward pass: 3.83 s → **0.83 s** (4.6× faster — no FFN tensor - touches on the client) - - Paris @ 80.66% — bit-identical to local unlimited-K walk -- Drop-post-load helpers (`ModelWeights::drop_{attn,ffn,lm_head,embed}_weights`) - still exist but Rust's system allocator pools freed memory — - post-load drops reduce heap accounting but not process RSS. - Superseded by the pre-load filter for the demo path. -- `larql serve` now resolves cache shorthands (`larql serve gemma4-31b-q4k` - works, not just full paths) via the same `cache::resolve_model` - logic `larql run` uses. -- `larql run` / `larql dev walk` default `--top-k` to `usize::MAX` - (unlimited). The old `top-k=10` default silently produced garbage - on stale/low-K vindexes; removing the cap matches the server's - `WalkFfn::new_unlimited` behavior. - -### Extract tiers + default flip -- New `ExtractLevel::Attention` tier sits between `Browse` and - `Inference`: includes attention + norms but not FFN. This is the - first-class way to carve a client-side vindex for the Act 2 demo - (`larql extract --level attention`). No more ad-hoc slicing. -- Strict `Browse < Attention < Inference < All` ordering + helper - methods (`writes_attn()` / `writes_ffn()` / `writes_lm_head()`) - drive what each tier writes. Writers now actually honor the - boundaries — previously only Browse was meaningfully different from - non-Browse. -- **Default flip.** `larql extract` now defaults to `--level inference` - + f16. The common case (`larql extract -o x.vindex`) produces - an inference-ready vindex out of the box, no flags needed. `--f32` - opts out of f16 for the rare case someone wants it. - -### Gemma 4 config plumbing -- Fixed three missing `final_logit_softcapping` initializers - (pre-existing compile break on the `architecture-b` branch). -- Dropped an unused `mut` on a closure binding in - `format/weights/write.rs`. - -### Test coverage -- **490 tests across 14 suites**, zero warnings. -- New: cache resolution (19), argv trampoline (8), - `RemoteWalkBackend` wire format + config + error shape (10), server - validation + stats mode advertisement (7), local-cache scan - end-to-end. +## P1 — Generation UX (parallel to critical path) + +Details in `larql-inference/ROADMAP.md` and `larql-cli/ROADMAP.md`. + +- Sampling: `--temperature`, `--top-p`, `--top-k`, `--repetition-penalty` +- Multi-turn state: running KV across `larql chat` turns +- Long context: `--max-context N`, dynamic KV buffer growth +- OpenAI-compatible `/v1/chat/completions` (after streaming lands) +- Auto-extract on `larql run hf://owner/name` +- Gemma 3 4B regression smoke test (gate on `CI_INTEGRATION=1`) + +--- + +## P2 — Film checklist + +- [ ] Confirm Gemma 4 26B A4B public config (expert count, top-K, active-param figure, GQA ratio). Replace every `~` in `docs/demo-script-gemma4-moe.md`. +- [ ] Measure real footprint + latency on `google/gemma-4-31b-it` for Act 1. +- [ ] Reliability pass on `RemoteWalkBackend` (timeouts, retries, partial shard outage). +- [ ] `RemoteExpertBackend` same reliability pass. +- [ ] Decide repo-public date. `cargo install larql-cli && larql serve` must be live the week the video drops. +- [ ] Pick expert IDs for the Act 3 swap shot — one that fires on medical prompts, one that doesn't. --- -## Non-goals - -- **Not a general model-serving framework.** LARQL's pitch is "the - model is the database"; inference is a vehicle for the interpretable - vindex, not the product. We optimize for composability, editability, - and the demo narrative — not raw throughput against vLLM/TensorRT. -- **Not a training system.** `COMPILE` writes into weights; that's - patch-level edits, not gradient descent. Stays out of scope. -- **Not HF-compatible on the output side.** We extract *from* HF - models but the vindex format is our own. A vindex is not meant to be - loadable by `transformers.AutoModel`. +## Loose ends (shipped features with open follow-ups) + +| Item | Crate | Detail | +|---|---|---| +| `KernelHandle` spread to 9 remaining tiled shaders | larql-compute | Mechanical, same pattern as q4_matvec_v4 | +| `dispatch_full_pipeline` 30+ params | larql-compute | Bundle into `FullPipelineRefs<'_>` context | +| `QuantFormat` match spread (14 files) | larql-compute | Introduce `FormatRoute` enum | +| `ProfileTimings` producer | larql-compute | Wire commit/wait boundaries into decode_token | +| Benches in CI | larql-compute | GHA workflow written, needs trigger merged | +| `--compact` loader for non-MoE models | larql-vindex | `WeightFfn::forward` panics on compact vindex | +| MoE compact mode | larql-vindex | Blocked on per-expert feature-major files | +| Fix `dispatch_full_pipeline` layer_scalar (dense) | larql-compute | Non-urgent: Gemma 3 4B has scalar=0 | +| Cross-vindex dedup (tokenizer, down_meta) | larql-vindex | Low priority, ~200 MB duplicated at 7 vindexes | diff --git a/crates/kv-cache-benchmark/Cargo.toml b/crates/kv-cache-benchmark/Cargo.toml index 748be72a..2e1ec169 100644 --- a/crates/kv-cache-benchmark/Cargo.toml +++ b/crates/kv-cache-benchmark/Cargo.toml @@ -10,7 +10,7 @@ description = "KV cache benchmark: Standard KV vs TurboQuant vs Markov RS vs Gra [features] default = [] -real-model = ["larql-inference", "larql-vindex", "larql-models", "larql-compute", "ndarray", "tokenizers", "zip"] +real-model = ["larql-vindex", "larql-models", "ndarray", "tokenizers", "zip"] [dependencies] serde.workspace = true @@ -19,11 +19,13 @@ thiserror.workspace = true rand = "0.8" rand_distr = "0.4" -# Optional: real model integration (Phase 2) -larql-inference = { path = "../larql-inference", optional = true } +# Always available: needed for the criterion bench (accuracy metrics, engine_kind). +larql-inference = { path = "../larql-inference" } +larql-compute = { path = "../larql-compute" } + +# Optional: full real-model integration (real weights, vindex, tokenizer). larql-vindex = { path = "../larql-vindex", optional = true } larql-models = { path = "../larql-models", optional = true } -larql-compute = { path = "../larql-compute", optional = true } ndarray = { version = "0.16", optional = true } tokenizers = { version = "0.21", optional = true } # `zip` for reading the .npz container in apollo11_store (uncompressed archives). diff --git a/crates/kv-cache-benchmark/README.md b/crates/kv-cache-benchmark/README.md index 2289b3b5..7e25385d 100644 --- a/crates/kv-cache-benchmark/README.md +++ b/crates/kv-cache-benchmark/README.md @@ -34,14 +34,40 @@ The rungs are not interchangeable — they answer different questions: ## Implementation status -| Strategy | End-to-end real | Synthetic encode/decode | -|---|---|---| -| Standard KV | ✓ `real_model::kv_capture` + `standard_kv` | ✓ | -| TurboQuant | ✓ `real_model::turboquant_layer` + `turboquant` | ✓ | -| Markov RS (W=512) | ✓ `real_model::markov_layer` (`rs_prefill`, `rs_decode_step`) — proven bit-perfect end-to-end (Tier 1 / variant iv-dense) | ✓ | -| `UnlimitedContextEngine` (Tier 2) | ✓ `unlimited_context::` — Rust port of `chuk-mlx/.../unlimited_engine.py`; integration tests `tests/test_unlimited_context.rs` | — | -| `ApolloEngine` (Tier 3) | ✓ full end-to-end pipeline on real apollo11_store + Gemma 3 4B. **Four entry points** (`query_greedy`, `query_greedy_compressed`, `query_generate_uncompressed`, `query_generate_compressed` — detailed under Row 5 notes below). Positional-proximity retrieval + answer-only injection produces `" John"` as top-1 for "Who won the porridge eating contest?" on both the uncompressed and compressed paths. | — | -| Graph Walk | partial — `real_model::graph_walk_layer` + memory accounting via `graph_walk::GraphWalk`; does not implement `KvStrategy` (no K/V reconstruction without cracked attention) | — | +All engines now live in `larql_inference::engines::kv_engines/`. This crate +re-exports from there; the implementations are no longer duplicated here. + +| Strategy | Lives in | End-to-end real | Synthetic | +|---|---|---|---| +| Standard KV | `real_model::kv_capture` | ✓ | ✓ `standard_kv` | +| TurboQuant | `larql_inference::engines::kv_engines::turbo_quant` | ✓ (~95 tok/s Metal) | ✓ | +| Markov RS | `larql_inference::engines::kv_engines::markov_residual` | ✓ (~95 tok/s Metal, bit-perfect) | ✓ | +| UnlimitedContext | `larql_inference::engines::kv_engines::unlimited_context` | ✓ (~94 tok/s Metal) | ✓ | +| ApolloEngine | `larql_inference::engines::kv_engines::apollo` | ✓ (compressed path via `forward_from_layer`) | ✓ | +| Graph Walk | `graph_walk::GraphWalk` (memory accounting only) | partial | — | + +### Speed (Gemma 3 4B, Metal Q4K, 2026-04-26) + +All engines use `prefill_q4k`/`decode_step_q4k` → Metal `decode_token` pipeline: + +``` +Backend prefill ms/tok tok/s +larql-metal (standard) 58ms 13ms 76.7 +markov-rs (Q4K Metal) 294ms 10.5ms 95.2 +unlimited-context (Q4K Metal) 208ms 10.6ms 94.3 +turbo-quant 4-bit (Q4K Metal) 203ms 10.6ms 94.8 +turbo-quant 3-bit (Q4K Metal) 201ms 10.6ms 94.3 +``` + +Apollo runs on the CPU compressed path (4 layers via `forward_from_layer`). + +### Criterion benchmarks + +``` +cargo bench -p kv-cache-benchmark --bench kv_strategies +``` + +30 benchmarks across 6 groups: encode, wht, memory_sweep, accuracy, engine_kind, engine_memory. ### Latest measured run — 2026-04-23, Gemma 3 4B (q4k vindex) diff --git a/crates/kv-cache-benchmark/benches/kv_strategies.rs b/crates/kv-cache-benchmark/benches/kv_strategies.rs index ff8d4c7f..69b046c2 100644 --- a/crates/kv-cache-benchmark/benches/kv_strategies.rs +++ b/crates/kv-cache-benchmark/benches/kv_strategies.rs @@ -1,9 +1,9 @@ -use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId}; -use kv_cache_benchmark::*; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use kv_cache_benchmark::markov_residual::MarkovResidual; use kv_cache_benchmark::model_config::ModelConfig; use kv_cache_benchmark::standard_kv::StandardKv; use kv_cache_benchmark::turboquant::TurboQuant; -use kv_cache_benchmark::markov_residual::MarkovResidual; +use kv_cache_benchmark::*; use rand::prelude::*; fn bench_encode(c: &mut Criterion) { @@ -24,17 +24,14 @@ fn bench_encode(c: &mut Criterion) { let s = StandardKv; b.iter(|| s.encode(&keys, &values)) }); - group.bench_function("turboquant_4bit", |b| { let s = TurboQuant::new(4); b.iter(|| s.encode(&keys, &values)) }); - group.bench_function("turboquant_3bit", |b| { let s = TurboQuant::new(3); b.iter(|| s.encode(&keys, &values)) }); - group.bench_function("markov_residual", |b| { let s = MarkovResidual::new(512); b.iter(|| s.encode(&keys, &values)) @@ -45,14 +42,14 @@ fn bench_encode(c: &mut Criterion) { fn bench_wht(c: &mut Criterion) { let mut group = c.benchmark_group("wht"); - for dim in [128, 256] { - let x: Vec = (0..dim).map(|i| (i as f32 - dim as f32 / 2.0) / 100.0).collect(); + let x: Vec = (0..dim) + .map(|i| (i as f32 - dim as f32 / 2.0) / 100.0) + .collect(); group.bench_with_input(BenchmarkId::new("wht", dim), &x, |b, x| { b.iter(|| kv_cache_benchmark::turboquant::rotation::wht(x)) }); } - group.finish(); } @@ -61,14 +58,156 @@ fn bench_memory_sweep(c: &mut Criterion) { let standard = StandardKv; let tq4 = TurboQuant::new(4); let markov = MarkovResidual::new(512); - let strategies: Vec<&dyn KvStrategy> = vec![&standard, &tq4, &markov]; let lengths = benchmark::CONTEXT_LENGTHS; - c.bench_function("memory_sweep", |b| { b.iter(|| benchmark::memory_sweep(&config, &strategies, lengths)) }); } -criterion_group!(benches, bench_encode, bench_wht, bench_memory_sweep); +/// Accuracy metric microbenchmarks — no model weights required. +/// +/// These measure the overhead of the accuracy helpers that validate engine +/// hidden-state correctness (cosine, KL, softmax). Useful for understanding +/// how much the correctness checks add to a real-model test run. +fn bench_accuracy_metrics(c: &mut Criterion) { + use larql_inference::engines::accuracy::{ + cosine_similarity, js_divergence, kl_divergence, mse, softmax, + }; + + let hidden = 2560usize; // Gemma 3 4B hidden_dim + let mut rng = StdRng::seed_from_u64(99); + let a: Vec = (0..hidden) + .map(|_| rng.gen_range(-1.0f32..1.0f32)) + .collect(); + let b: Vec = (0..hidden) + .map(|_| rng.gen_range(-1.0f32..1.0f32)) + .collect(); + + let mut group = c.benchmark_group("accuracy"); + group.throughput(Throughput::Elements(hidden as u64)); + + group.bench_function("cosine_similarity/2560", |bench| { + bench.iter(|| cosine_similarity(&a, &b)) + }); + group.bench_function("mse/2560", |bench| bench.iter(|| mse(&a, &b))); + + // Softmax + KL on a 1K-token subset (fast enough for CI) + let vocab = 1000usize; + let logits: Vec = (0..vocab).map(|i| (i as f32) * 0.01).collect(); + let p = softmax(&logits); + let raw_q: Vec = (0..vocab).map(|_| rng.gen_range(0.0f32..1.0f32)).collect(); + let q_sum: f32 = raw_q.iter().sum(); + let q: Vec = raw_q.iter().map(|x| x / q_sum).collect(); + + group.bench_function("softmax/1k_vocab", |bench| bench.iter(|| softmax(&logits))); + group.bench_function("kl_divergence/1k_vocab", |bench| { + bench.iter(|| kl_divergence(&p, &q)) + }); + group.bench_function("js_divergence/1k_vocab", |bench| { + bench.iter(|| js_divergence(&p, &q)) + }); + + group.finish(); +} + +/// EngineKind dispatch overhead — construction, parsing, and engine creation. +/// Measures the metadata / dispatch path without a forward pass. +fn bench_engine_kind(c: &mut Criterion) { + use larql_inference::engines::EngineKind; + + let mut group = c.benchmark_group("engine_kind"); + + group.bench_function("from_name/markov-rs", |b| { + b.iter(|| EngineKind::from_name("markov-rs")) + }); + group.bench_function("from_name/unlimited-context", |b| { + b.iter(|| EngineKind::from_name("unlimited-context")) + }); + group.bench_function("build/markov_rs_W512", |b| { + b.iter(|| { + EngineKind::MarkovResidual { + window_size: Some(512), + } + .build(larql_compute::cpu_backend()) + }) + }); + group.bench_function("build/unlimited_context_W512", |b| { + b.iter(|| { + EngineKind::UnlimitedContext { window_size: 512 }.build(larql_compute::cpu_backend()) + }) + }); + + group.finish(); +} + +/// Memory accounting at different context lengths. +/// Models how fast engines can report their state size as context grows — +/// relevant for multi-turn systems that need to decide when to evict. +fn bench_engine_memory_accounting(c: &mut Criterion) { + // Gemma 3 4B geometry + let layers = 34usize; + let kv_heads = 4usize; + let head_dim = 256usize; + let kv_dim = kv_heads * head_dim; + let hidden = 2560usize; + + let mut group = c.benchmark_group("engine_memory"); + + for &seq_len in &[512usize, 4096, 32768, 131072, 370_000] { + let window = seq_len.min(512); + + group.bench_with_input( + BenchmarkId::new("markov_rs_hot_bytes", seq_len), + &seq_len, + |b, _| { + b.iter(|| { + // Hot-window bytes: W × layers × hidden_dim × 4 (f32) + window * layers * hidden * 4 + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("standard_kv_bytes_fp16", seq_len), + &seq_len, + |b, _| { + b.iter(|| { + // Standard KV (FP16): seq × layers × 2 × kv_dim × 2 bytes + seq_len * layers * 2 * kv_dim * 2 + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("compression_ratio", seq_len), + &seq_len, + |b, _| { + b.iter(|| { + let std_kv = seq_len * layers * 2 * kv_dim * 2; + let markov_hot = window * layers * hidden * 4; + let markov_cold = seq_len.saturating_sub(window) * 4; // 4B/token cold + let markov_total = markov_hot + markov_cold; + if markov_total > 0 { + std_kv as f64 / markov_total as f64 + } else { + 0.0 + } + }) + }, + ); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_encode, + bench_wht, + bench_memory_sweep, + bench_accuracy_metrics, + bench_engine_kind, + bench_engine_memory_accounting, +); criterion_main!(benches); diff --git a/crates/kv-cache-benchmark/examples/accuracy_suite.rs b/crates/kv-cache-benchmark/examples/accuracy_suite.rs index effb98ee..5a2a3e17 100644 --- a/crates/kv-cache-benchmark/examples/accuracy_suite.rs +++ b/crates/kv-cache-benchmark/examples/accuracy_suite.rs @@ -19,16 +19,17 @@ fn main() { let quick = args.iter().any(|a| a == "--quick"); // Load model - let model_name = args.get(1) + let model_name = args + .get(1) .filter(|a| !a.starts_with('-')) .map(|s| s.as_str()) .unwrap_or("google/gemma-3-4b-it"); println!("Loading model: {model_name}"); - let model = larql_inference::InferenceModel::load(model_name) - .expect("Failed to load model"); + let model = larql_inference::InferenceModel::load(model_name).expect("Failed to load model"); // Load vindex (second arg or next non-flag arg) - let vindex_path = args.iter() + let vindex_path = args + .iter() .skip(1) .filter(|a| !a.starts_with('-')) .nth(1) @@ -37,7 +38,8 @@ fn main() { let index = larql_vindex::VectorIndex::load_vindex( std::path::Path::new(vindex_path), &mut larql_vindex::SilentLoadCallbacks, - ).expect("Failed to load vindex"); + ) + .expect("Failed to load vindex"); let backend = larql_inference::default_backend(); @@ -47,9 +49,8 @@ fn main() { // ── Test 1: Paris test ── println!("--- Test 1: Paris Test (pass/fail) ---\n"); - let paris_results = runner::test_paris( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let paris_results = + runner::test_paris(model.weights(), model.tokenizer(), &index, backend.as_ref()); for (strategy, pass) in &paris_results { let mark = if *pass { "PASS" } else { "FAIL" }; println!(" {strategy:<30} {mark}"); @@ -65,7 +66,10 @@ fn main() { }; let prompt_results = runner::test_top1_match_rate( - model.weights(), model.tokenizer(), &index, backend.as_ref(), + model.weights(), + model.tokenizer(), + &index, + backend.as_ref(), &test_prompts, ); @@ -76,7 +80,8 @@ fn main() { // ── Test 4: Generation stability ── println!("\n--- Test 4: Generation Stability (20 tokens) ---\n"); let gen_results = runner::test_generation_stability( - model.weights(), model.tokenizer(), + model.weights(), + model.tokenizer(), "The capital of France is Paris. France is a country in", 20, ); @@ -93,7 +98,10 @@ fn main() { // Write JSON let json = serde_json::to_string_pretty(&prompt_results).unwrap(); - let _ = std::fs::write("crates/kv-cache-benchmark/results/accuracy_suite.json", &json); + let _ = std::fs::write( + "crates/kv-cache-benchmark/results/accuracy_suite.json", + &json, + ); println!("Results written to results/accuracy_suite.json"); } diff --git a/crates/kv-cache-benchmark/examples/bit_budget_additivity_q4k.rs b/crates/kv-cache-benchmark/examples/bit_budget_additivity_q4k.rs new file mode 100644 index 00000000..e674e4bd --- /dev/null +++ b/crates/kv-cache-benchmark/examples/bit_budget_additivity_q4k.rs @@ -0,0 +1,409 @@ +//! Exp37 q4k slot-bit additivity runner. +//! +//! Scores the object slot for each row in the Exp37 design matrix using exact +//! target log-probabilities from the low-memory q4k walk path, then computes +//! pairwise additivity interactions. + +#[cfg(feature = "real-model")] +fn main() -> Result<(), Box> { + runner::run() +} + +#[cfg(not(feature = "real-model"))] +fn main() { + eprintln!("This example requires the 'real-model' feature."); + std::process::exit(1); +} + +#[cfg(feature = "real-model")] +mod runner { + use std::collections::HashMap; + use std::fs::File; + use std::io::{BufRead, BufReader, Write}; + use std::path::PathBuf; + + use larql_inference::vindex::{predict_q4k_hidden_with_ffn, WalkFfn}; + use larql_inference::{encode_prompt, hidden_to_raw_logits, open_inference_vindex}; + use larql_vindex::{load_model_weights_q4k, load_vindex_tokenizer}; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + #[derive(Debug)] + struct Args { + vindex: PathBuf, + design: PathBuf, + out_json: PathBuf, + scored_csv: PathBuf, + interactions_csv: PathBuf, + top_k: usize, + feature_top_k: usize, + } + + #[derive(Clone, Debug, Deserialize, Serialize)] + struct Cell { + source_id: String, + relation: String, + cell: String, + axes: String, + template: String, + subject: String, + object: String, + text: String, + object_span_start: usize, + object_span_end: usize, + } + + #[derive(Clone, Debug, Serialize)] + struct ScoredCell { + #[serde(flatten)] + cell: Cell, + prefix: String, + slot_bits_total: f64, + slot_bits_per_token: f64, + object_n_tokens: usize, + clipped_tokens: usize, + token_bits: Vec, + token_probs: Vec, + token_ids: Vec, + } + + #[derive(Clone, Debug, Serialize)] + struct Interaction { + source_id: String, + axis_a: String, + axis_b: String, + joint_cell: String, + slot_bits_delta_a: f64, + slot_bits_delta_b: f64, + slot_bits_observed_joint_delta: f64, + slot_bits_predicted_joint_delta: f64, + slot_bits_interaction_bits: f64, + } + + pub fn run() -> Result<(), Box> { + let args = parse_args(); + std::fs::create_dir_all(args.out_json.parent().unwrap())?; + std::fs::create_dir_all(args.scored_csv.parent().unwrap())?; + std::fs::create_dir_all(args.interactions_csv.parent().unwrap())?; + + let cells = load_design(&args.design)?; + println!("Loading q4k vindex {}", args.vindex.display()); + let mut cb = larql_vindex::SilentLoadCallbacks; + let mut weights = load_model_weights_q4k(&args.vindex, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.vindex)?; + let index = open_inference_vindex(&args.vindex)?; + + let mut scored = Vec::new(); + for (idx, cell) in cells.iter().enumerate() { + scored.push(score_cell( + &mut weights, + &tokenizer, + &index, + cell, + args.top_k, + args.feature_top_k, + )?); + if (idx + 1) % 10 == 0 { + println!("scored {}/{}", idx + 1, cells.len()); + } + } + let interactions = compute_interactions(&scored); + + std::fs::write( + &args.out_json, + serde_json::to_string_pretty(&json!({ + "experiment": "37_bit_budget_additivity", + "path": "q4k", + "scoring": "exact_target_logprob", + "vindex": args.vindex, + "top_k_predictions": args.top_k, + "feature_top_k": args.feature_top_k, + "n_cells": scored.len(), + "cells": scored, + "interactions": interactions, + }))?, + )?; + write_scored_csv(&args.scored_csv, &scored)?; + write_interactions_csv(&args.interactions_csv, &interactions)?; + println!("wrote {}", args.out_json.display()); + println!("wrote {}", args.scored_csv.display()); + println!("wrote {}", args.interactions_csv.display()); + Ok(()) + } + + fn parse_args() -> Args { + let mut args = Args { + vindex: PathBuf::from("output/gemma3-4b-q4k-v2.vindex"), + design: PathBuf::from("experiments/37_bit_budget_additivity/results/design_matrix.csv"), + out_json: PathBuf::from( + "experiments/37_bit_budget_additivity/results/q4k_scored_cells.json", + ), + scored_csv: PathBuf::from( + "experiments/37_bit_budget_additivity/results/q4k_scored_cells.csv", + ), + interactions_csv: PathBuf::from( + "experiments/37_bit_budget_additivity/results/q4k_interactions.csv", + ), + top_k: 2048, + feature_top_k: 2048, + }; + let raw: Vec = std::env::args().collect(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--vindex" => { + i += 1; + args.vindex = PathBuf::from(&raw[i]); + } + "--design" => { + i += 1; + args.design = PathBuf::from(&raw[i]); + } + "--out-json" => { + i += 1; + args.out_json = PathBuf::from(&raw[i]); + } + "--scored-csv" => { + i += 1; + args.scored_csv = PathBuf::from(&raw[i]); + } + "--interactions-csv" => { + i += 1; + args.interactions_csv = PathBuf::from(&raw[i]); + } + "--top-k" => { + i += 1; + args.top_k = raw[i].parse().expect("--top-k must be usize"); + } + "--feature-top-k" => { + i += 1; + args.feature_top_k = raw[i].parse().expect("--feature-top-k must be usize"); + } + other => { + eprintln!("unknown arg: {other}"); + std::process::exit(2); + } + } + i += 1; + } + args + } + + fn load_design(path: &PathBuf) -> Result, Box> { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut lines = reader.lines(); + let header = lines.next().ok_or("empty design csv")??; + let headers: Vec<&str> = header.split(',').collect(); + let mut out = Vec::new(); + for line in lines { + let line = line?; + if line.trim().is_empty() { + continue; + } + let values: Vec<&str> = line.split(',').collect(); + if values.len() != headers.len() { + return Err(format!("unsupported csv row with commas: {line}").into()); + } + let mut row = HashMap::new(); + for (key, value) in headers.iter().zip(values.iter()) { + row.insert(*key, *value); + } + out.push(Cell { + source_id: get(&row, "source_id")?.to_string(), + relation: get(&row, "relation")?.to_string(), + cell: get(&row, "cell")?.to_string(), + axes: get(&row, "axes")?.to_string(), + template: get(&row, "template")?.to_string(), + subject: get(&row, "subject")?.to_string(), + object: get(&row, "object")?.to_string(), + text: get(&row, "text")?.to_string(), + object_span_start: get(&row, "object_span_start")?.parse()?, + object_span_end: get(&row, "object_span_end")?.parse()?, + }); + } + Ok(out) + } + + fn get<'a>( + row: &'a HashMap<&str, &str>, + key: &str, + ) -> Result<&'a str, Box> { + row.get(key) + .copied() + .ok_or_else(|| format!("missing csv field {key}").into()) + } + + fn score_cell( + weights: &mut larql_models::ModelWeights, + tokenizer: &tokenizers::Tokenizer, + index: &larql_vindex::VectorIndex, + cell: &Cell, + _top_k: usize, + feature_top_k: usize, + ) -> Result> { + let prefix = cell.text[..cell.object_span_start].to_string(); + let mut context_ids = encode_prompt(tokenizer, &*weights.arch, &prefix)?; + let object_surface = if prefix.ends_with(char::is_whitespace) { + cell.object.clone() + } else { + format!(" {}", cell.object) + }; + let object_ids = tokenizer + .encode(object_surface.as_str(), false) + .map_err(|e| format!("tokenize object {:?}: {e}", cell.object))? + .get_ids() + .to_vec(); + let mut token_bits = Vec::new(); + let mut token_probs = Vec::new(); + let clipped = 0usize; + for &target_id in &object_ids { + let prob = exact_target_prob( + weights, + index, + &context_ids, + target_id as usize, + feature_top_k, + ); + token_probs.push(prob); + token_bits.push(-prob.log2()); + context_ids.push(target_id); + } + let total = token_bits.iter().sum::(); + Ok(ScoredCell { + cell: cell.clone(), + prefix, + slot_bits_total: total, + slot_bits_per_token: total / object_ids.len().max(1) as f64, + object_n_tokens: object_ids.len(), + clipped_tokens: clipped, + token_bits, + token_probs, + token_ids: object_ids, + }) + } + + fn exact_target_prob( + weights: &mut larql_models::ModelWeights, + index: &larql_vindex::VectorIndex, + token_ids: &[u32], + target_id: usize, + feature_top_k: usize, + ) -> f64 { + let weights_ref: &larql_models::ModelWeights = + unsafe { &*(weights as *const larql_models::ModelWeights) }; + let walk_ffn = WalkFfn::new(weights_ref, index, feature_top_k); + let h = predict_q4k_hidden_with_ffn(weights, token_ids, index, &walk_ffn); + let seq_len = h.shape()[0]; + let h_last = h.slice(ndarray::s![seq_len - 1..seq_len, ..]).to_owned(); + let logits = hidden_to_raw_logits(weights, &h_last); + let target = logits[target_id] as f64; + let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max) as f64; + let exp_sum: f64 = logits.iter().map(|&l| ((l as f64) - max_logit).exp()).sum(); + let logsumexp = max_logit + exp_sum.ln(); + (target - logsumexp).exp().max(f64::MIN_POSITIVE) + } + + fn compute_interactions(scored: &[ScoredCell]) -> Vec { + let mut by_source: HashMap> = HashMap::new(); + for row in scored { + by_source + .entry(row.cell.source_id.clone()) + .or_default() + .insert(row.cell.cell.clone(), row); + } + let pairs = [ + ("syntax", "fact", "syntax_fact"), + ("syntax", "style", "syntax_style"), + ("fact", "style", "fact_style"), + ]; + let mut out = Vec::new(); + for (source_id, cells) in by_source { + let Some(base) = cells.get("base") else { + continue; + }; + for (axis_a, axis_b, joint) in pairs { + let (Some(a), Some(b), Some(ab)) = + (cells.get(axis_a), cells.get(axis_b), cells.get(joint)) + else { + continue; + }; + let delta_a = a.slot_bits_total - base.slot_bits_total; + let delta_b = b.slot_bits_total - base.slot_bits_total; + let observed = ab.slot_bits_total - base.slot_bits_total; + let predicted = delta_a + delta_b; + out.push(Interaction { + source_id: source_id.clone(), + axis_a: axis_a.to_string(), + axis_b: axis_b.to_string(), + joint_cell: joint.to_string(), + slot_bits_delta_a: delta_a, + slot_bits_delta_b: delta_b, + slot_bits_observed_joint_delta: observed, + slot_bits_predicted_joint_delta: predicted, + slot_bits_interaction_bits: observed - predicted, + }); + } + } + out.sort_by(|a, b| { + (&a.source_id, &a.axis_a, &a.axis_b).cmp(&(&b.source_id, &b.axis_a, &b.axis_b)) + }); + out + } + + fn write_scored_csv( + path: &PathBuf, + rows: &[ScoredCell], + ) -> Result<(), Box> { + let mut f = File::create(path)?; + writeln!( + f, + "source_id,relation,cell,axes,subject,object,prefix,slot_bits_total,slot_bits_per_token,object_n_tokens,clipped_tokens" + )?; + for row in rows { + writeln!( + f, + "{},{},{},{},{},{},{},{:.6},{:.6},{},{}", + row.cell.source_id, + row.cell.relation, + row.cell.cell, + row.cell.axes, + row.cell.subject, + row.cell.object, + row.prefix, + row.slot_bits_total, + row.slot_bits_per_token, + row.object_n_tokens, + row.clipped_tokens + )?; + } + Ok(()) + } + + fn write_interactions_csv( + path: &PathBuf, + rows: &[Interaction], + ) -> Result<(), Box> { + let mut f = File::create(path)?; + writeln!( + f, + "source_id,axis_a,axis_b,joint_cell,slot_bits_delta_a,slot_bits_delta_b,slot_bits_observed_joint_delta,slot_bits_predicted_joint_delta,slot_bits_interaction_bits" + )?; + for row in rows { + writeln!( + f, + "{},{},{},{},{:.6},{:.6},{:.6},{:.6},{:.6}", + row.source_id, + row.axis_a, + row.axis_b, + row.joint_cell, + row.slot_bits_delta_a, + row.slot_bits_delta_b, + row.slot_bits_observed_joint_delta, + row.slot_bits_predicted_joint_delta, + row.slot_bits_interaction_bits + )?; + } + Ok(()) + } +} diff --git a/crates/kv-cache-benchmark/examples/decode_bench.rs b/crates/kv-cache-benchmark/examples/decode_bench.rs index 110423ff..e9a31e1e 100644 --- a/crates/kv-cache-benchmark/examples/decode_bench.rs +++ b/crates/kv-cache-benchmark/examples/decode_bench.rs @@ -41,22 +41,25 @@ #[cfg(feature = "real-model")] fn main() { use kv_cache_benchmark::real_model::decode_comparison::{ - run_decode_comparison, format_comparison, format_window_sweep, - QueryType, parametric_prompts, in_context_prompts, DecodeComparisonResult, + format_comparison, format_window_sweep, in_context_prompts, parametric_prompts, + run_decode_comparison, DecodeComparisonResult, QueryType, }; let args: Vec = std::env::args().collect(); - let model_name = args.get(1).map(|s| s.as_str()).unwrap_or("google/gemma-3-4b-it"); + let model_name = args + .get(1) + .map(|s| s.as_str()) + .unwrap_or("google/gemma-3-4b-it"); let decode_steps = 8; // Parse window sizes from optional third argument, or use defaults. - let windows: Vec = args.get(3) + let windows: Vec = args + .get(3) .map(|s| s.split(',').filter_map(|w| w.trim().parse().ok()).collect()) .unwrap_or_else(|| vec![1, 2, 4, 6, 12, 24]); println!("Loading model: {model_name}"); - let model = larql_inference::InferenceModel::load(model_name) - .expect("Failed to load model"); + let model = larql_inference::InferenceModel::load(model_name).expect("Failed to load model"); let weights = model.weights(); let tokenizer = model.tokenizer(); @@ -73,15 +76,21 @@ fn main() { for prompt_str in parametric_prompts() { let token_ids: Vec = tokenizer - .encode(prompt_str, true).expect("tokenize") - .get_ids().to_vec(); + .encode(prompt_str, true) + .expect("tokenize") + .get_ids() + .to_vec(); println!("\nPrompt: {:?} ({} tokens)", prompt_str, token_ids.len()); for &window in &windows { let result = run_decode_comparison( - weights, tokenizer, &token_ids, - QueryType::Parametric, window, decode_steps, + weights, + tokenizer, + &token_ids, + QueryType::Parametric, + window, + decode_steps, ); println!("{}", format_comparison(&result)); all_results.push(result); @@ -96,15 +105,25 @@ fn main() { for prompt_str in in_context_prompts() { let token_ids: Vec = tokenizer - .encode(prompt_str.as_str(), true).expect("tokenize") - .get_ids().to_vec(); + .encode(prompt_str.as_str(), true) + .expect("tokenize") + .get_ids() + .to_vec(); - println!("\nPrompt: {:?} ({} tokens)", &prompt_str[..60.min(prompt_str.len())], token_ids.len()); + println!( + "\nPrompt: {:?} ({} tokens)", + &prompt_str[..60.min(prompt_str.len())], + token_ids.len() + ); for &window in &windows { let result = run_decode_comparison( - weights, tokenizer, &token_ids, - QueryType::InContext, window, decode_steps, + weights, + tokenizer, + &token_ids, + QueryType::InContext, + window, + decode_steps, ); println!("{}", format_comparison(&result)); all_results.push(result); @@ -116,9 +135,14 @@ fn main() { println!("{}", format_window_sweep(&all_results)); let total = all_results.len(); - let perfect = all_results.iter().filter(|r| r.first_divergence.is_none()).count(); - println!("Overall: {perfect}/{total} runs with zero divergence ({:.1}%)", - perfect as f64 / total as f64 * 100.0); + let perfect = all_results + .iter() + .filter(|r| r.first_divergence.is_none()) + .count(); + println!( + "Overall: {perfect}/{total} runs with zero divergence ({:.1}%)", + perfect as f64 / total as f64 * 100.0 + ); let json = serde_json::to_string_pretty(&all_results).unwrap(); let out_path = "crates/kv-cache-benchmark/results/decode_comparison.json"; diff --git a/crates/kv-cache-benchmark/examples/ffn_coverage.rs b/crates/kv-cache-benchmark/examples/ffn_coverage.rs index d6cb6273..cc0fb917 100644 --- a/crates/kv-cache-benchmark/examples/ffn_coverage.rs +++ b/crates/kv-cache-benchmark/examples/ffn_coverage.rs @@ -61,7 +61,11 @@ mod ffn_coverage { match raw[i].as_str() { "--k" => { let v = raw.get(i + 1).cloned().unwrap_or_else(|| "full".into()); - k = if v == "full" { None } else { Some(v.parse().expect("--k must be int or 'full'")) }; + k = if v == "full" { + None + } else { + Some(v.parse().expect("--k must be int or 'full'")) + }; raw.drain(i..i + 2); } "--output" | "-o" => { @@ -69,7 +73,11 @@ mod ffn_coverage { raw.drain(i..i + 2); } "--limit" => { - limit = Some(raw.get(i + 1).and_then(|s| s.parse().ok()).expect("--limit needs int")); + limit = Some( + raw.get(i + 1) + .and_then(|s| s.parse().ok()) + .expect("--limit needs int"), + ); raw.drain(i..i + 2); } _ => i += 1, @@ -77,10 +85,18 @@ mod ffn_coverage { } if raw.len() < 2 { - eprintln!("Usage: ffn_coverage [--k N|full] [--output PATH] [--limit N]"); + eprintln!( + "Usage: ffn_coverage [--k N|full] [--output PATH] [--limit N]" + ); std::process::exit(2); } - Args { model: raw[0].clone(), vindex: raw[1].clone(), output, k, limit } + Args { + model: raw[0].clone(), + vindex: raw[1].clone(), + output, + k, + limit, + } } // ── Measurement records ── @@ -133,7 +149,9 @@ mod ffn_coverage { impl<'a> FfnBackend for InstrumentedFfn<'a> { fn forward(&self, layer: usize, x: &Array2) -> Array2 { - let dense = WeightFfn { weights: self.weights }; + let dense = WeightFfn { + weights: self.weights, + }; let dense_out = dense.forward(layer, x); let walk_out = self.walk.forward(layer, x); @@ -145,11 +163,17 @@ mod ffn_coverage { // gate_knn internally; we re-run with a small K purely to grab // top-K scores for measurement. Redundant but cheap. let x_last = Array1::from_iter(x.row(last).iter().copied()); - let top_hits = self.index.gate_knn(layer, &x_last, self.gate_k_for_measurement); + let top_hits = self + .index + .gate_knn(layer, &x_last, self.gate_k_for_measurement); let (feat0, score0) = top_hits.first().copied().unwrap_or((0, 0.0)); let score1 = top_hits.get(1).map(|(_, s)| s.abs()).unwrap_or(0.0); let margin = score0.abs() - score1; - let token = self.index.feature_meta(layer, feat0).map(|m| m.top_token).unwrap_or_default(); + let token = self + .index + .feature_meta(layer, feat0) + .map(|m| m.top_token) + .unwrap_or_default(); // Lookup count: gate_knn (1) + K feature reads (K) + K down reads (K). // When K_walk = features, this is ~2*F + 1. Report the effective K @@ -171,8 +195,15 @@ mod ffn_coverage { dense_out } - fn forward_with_activation(&self, layer: usize, x: &Array2) -> (Array2, Array2) { - let (out, act) = WeightFfn { weights: self.weights }.forward_with_activation(layer, x); + fn forward_with_activation( + &self, + layer: usize, + x: &Array2, + ) -> (Array2, Array2) { + let (out, act) = WeightFfn { + weights: self.weights, + } + .forward_with_activation(layer, x); // Re-run walk for measurement; discard its activation (we return dense). let _ = self.forward(layer, x); (out, act) @@ -215,7 +246,9 @@ mod ffn_coverage { println!( "WalkFfn: {} layers, K = {}", num_layers, - args.k.map(|k| k.to_string()).unwrap_or_else(|| "full".into()) + args.k + .map(|k| k.to_string()) + .unwrap_or_else(|| "full".into()) ); let all_prompts = diverse_100(); @@ -263,8 +296,12 @@ mod ffn_coverage { let mut layers = instrumented.measurements.into_inner(); layers.sort_by_key(|m| m.layer); - let worst_cos = layers.iter().map(|m| m.cos_walk_vs_dense).fold(f32::INFINITY, f32::min); - let mean_cos = layers.iter().map(|m| m.cos_walk_vs_dense).sum::() / layers.len() as f32; + let worst_cos = layers + .iter() + .map(|m| m.cos_walk_vs_dense) + .fold(f32::INFINITY, f32::min); + let mean_cos = + layers.iter().map(|m| m.cos_walk_vs_dense).sum::() / layers.len() as f32; println!( "[{:>3}/{}] {:<60} top1={:<15} mean_cos={:.4} worst_cos={:.4} {:>6.1}s", i + 1, @@ -294,7 +331,11 @@ mod ffn_coverage { } let json = serde_json::to_string_pretty(&results).expect("serialize"); std::fs::write(out_path, json).expect("write output"); - println!("\nWrote {} query results to {}", results.len(), out_path.display()); + println!( + "\nWrote {} query results to {}", + results.len(), + out_path.display() + ); print_coverage_summary(&results); } @@ -313,7 +354,11 @@ mod ffn_coverage { let thresholds: [f32; 5] = [0.95, 0.99, 0.999, 0.9999, 1.0]; println!("\n── Coverage summary ──"); - println!("queries={}, layers/query={}", results.len(), results.first().map(|r| r.layers.len()).unwrap_or(0)); + println!( + "queries={}, layers/query={}", + results.len(), + results.first().map(|r| r.layers.len()).unwrap_or(0) + ); println!("\nFully-walked rate (all layers cos ≥ τ):"); for &tau in &thresholds { @@ -321,15 +366,22 @@ mod ffn_coverage { .iter() .filter(|r| r.layers.iter().all(|m| m.cos_walk_vs_dense >= tau)) .count(); - println!(" τ={:<8} fully-walked: {}/{} ({:>5.1}%)", - format_tau(tau), fully_walked, results.len(), - 100.0 * fully_walked as f32 / results.len() as f32); + println!( + " τ={:<8} fully-walked: {}/{} ({:>5.1}%)", + format_tau(tau), + fully_walked, + results.len(), + 100.0 * fully_walked as f32 / results.len() as f32 + ); } println!("\nPer-layer walk rate at τ=0.99:"); let num_layers = results.first().map(|r| r.layers.len()).unwrap_or(0); for l in 0..num_layers { - let hits = results.iter().filter(|r| r.layers[l].cos_walk_vs_dense >= 0.99).count(); + let hits = results + .iter() + .filter(|r| r.layers[l].cos_walk_vs_dense >= 0.99) + .count(); let bar = "█".repeat(((hits as f32 / results.len() as f32) * 20.0) as usize); println!(" L{:<2} {:<20} {}/{}", l, bar, hits, results.len()); } diff --git a/crates/kv-cache-benchmark/examples/multi_turn_demo.rs b/crates/kv-cache-benchmark/examples/multi_turn_demo.rs index 3318df31..2d36d5e4 100644 --- a/crates/kv-cache-benchmark/examples/multi_turn_demo.rs +++ b/crates/kv-cache-benchmark/examples/multi_turn_demo.rs @@ -7,13 +7,13 @@ //! cargo run --example multi_turn_demo fn main() { - use kv_cache_benchmark::*; use kv_cache_benchmark::benchmark; + use kv_cache_benchmark::graph_walk::GraphWalk; + use kv_cache_benchmark::markov_residual::MarkovResidual; use kv_cache_benchmark::model_config::ModelConfig; use kv_cache_benchmark::standard_kv::StandardKv; use kv_cache_benchmark::turboquant::TurboQuant; - use kv_cache_benchmark::markov_residual::MarkovResidual; - use kv_cache_benchmark::graph_walk::GraphWalk; + use kv_cache_benchmark::*; let config = ModelConfig::gemma_4b(); let num_turns = 25; @@ -55,7 +55,10 @@ fn main() { // Summary let final_tokens = num_turns * tokens_per_turn; - println!("\n=== At {} tokens (turn {}) ===\n", final_tokens, num_turns); + println!( + "\n=== At {} tokens (turn {}) ===\n", + final_tokens, num_turns + ); let strategies: Vec<(&str, usize)> = vec![ ("Standard KV", standard.memory_bytes(&config, final_tokens)), @@ -66,8 +69,17 @@ fn main() { let baseline = strategies[0].1; for (name, mem) in &strategies { - let ratio = if *mem > 0 { baseline as f64 / *mem as f64 } else { 0.0 }; - println!(" {:<15} {:>12} ({:.1}× vs baseline)", name, format_bytes(*mem), ratio); + let ratio = if *mem > 0 { + baseline as f64 / *mem as f64 + } else { + 0.0 + }; + println!( + " {:<15} {:>12} ({:.1}× vs baseline)", + name, + format_bytes(*mem), + ratio + ); } // Full comparative table (KV-reconstructing strategies only). @@ -76,10 +88,14 @@ fn main() { // Crossover analysis println!("\n=== Crossover Analysis ===\n"); - println!("Standard KV grows linearly: every turn adds {} per token", - format_bytes(config.kv_bytes_per_token())); + println!( + "Standard KV grows linearly: every turn adds {} per token", + format_bytes(config.kv_bytes_per_token()) + ); println!("Markov RS is bounded: window = 512 tokens, cold tier = 4 bytes/token"); - println!("Graph Walk is constant: per-conversation = token IDs only (requires cracked attention)"); + println!( + "Graph Walk is constant: per-conversation = token IDs only (requires cracked attention)" + ); // Find crossover point where Markov RS < Standard KV for turn in 1..=50 { @@ -87,7 +103,10 @@ fn main() { let std_mem = standard.memory_bytes(&config, tokens); let mrk_mem = markov.memory_bytes(&config, tokens); if mrk_mem < std_mem { - println!("\nMarkov RS < Standard KV at turn {} ({} tokens)", turn, tokens); + println!( + "\nMarkov RS < Standard KV at turn {} ({} tokens)", + turn, tokens + ); break; } } diff --git a/crates/kv-cache-benchmark/examples/patch_propagation_q4k.rs b/crates/kv-cache-benchmark/examples/patch_propagation_q4k.rs new file mode 100644 index 00000000..891b3b0d --- /dev/null +++ b/crates/kv-cache-benchmark/examples/patch_propagation_q4k.rs @@ -0,0 +1,537 @@ +//! Exp36 patch-propagation MVP on the low-memory Q4K inference path. +//! +//! Builds the exp04 Atlantis->Poseidon multilayer insert in memory, then +//! force-scores controlled answer surfaces before and after the patch using +//! the finite-K q4k walk path. +//! +//! Usage: +//! cargo run -p kv-cache-benchmark --example patch_propagation_q4k \ +//! --features real-model --release -- \ +//! --vindex output/gemma3-4b-q4k-v2.vindex \ +//! --out experiments/36_patch_propagation/results/q4k_final_slot_bits.json + +#[cfg(feature = "real-model")] +fn main() -> Result<(), Box> { + runner::run() +} + +#[cfg(not(feature = "real-model"))] +fn main() { + eprintln!("This example requires the 'real-model' feature."); + std::process::exit(1); +} + +#[cfg(feature = "real-model")] +mod runner { + use std::collections::HashMap; + use std::fs::File; + use std::io::{BufRead, BufReader, Write}; + use std::path::PathBuf; + + use larql_inference::vindex::{predict_q4k_hidden_with_ffn, predict_q4k_with_ffn, WalkFfn}; + use larql_inference::{ + encode_prompt, hidden_to_raw_logits, open_inference_vindex, PredictResult, + }; + use larql_vindex::{load_model_weights_q4k, load_vindex_tokenizer, FeatureMeta}; + use ndarray::Array1; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + #[derive(Debug)] + struct Args { + vindex: PathBuf, + prompts: PathBuf, + out: PathBuf, + csv: PathBuf, + alpha: f32, + layer_start: usize, + layer_end: usize, + top_k: usize, + feature_top_k: usize, + } + + #[derive(Clone, Debug, Deserialize)] + struct PromptRow { + group: String, + relation: String, + prefix: String, + answers: Vec, + description: Option, + } + + #[derive(Clone, Debug, Serialize)] + struct ScoreRow { + group: String, + relation: String, + prefix: String, + answer: String, + surface_kind: String, + description: Option, + slot_bits_total: f64, + slot_bits_per_token: f64, + answer_n_tokens: usize, + token_ids: Vec, + token_bits: Vec, + token_probs: Vec, + clipped_tokens: usize, + } + + #[derive(Clone, Debug, Serialize)] + struct SummaryRow { + group: String, + relation: String, + prefix: String, + answer: String, + before_bits: f64, + after_bits: f64, + delta_bits: f64, + before_bits_per_token: f64, + after_bits_per_token: f64, + answer_n_tokens: usize, + before_clipped_tokens: usize, + after_clipped_tokens: usize, + } + + #[derive(Clone, Debug, Serialize)] + struct InsertedSlot { + layer: usize, + feature: usize, + alpha: f32, + gate_rank: Option, + gate_score: Option, + } + + pub fn run() -> Result<(), Box> { + let args = parse_args(); + std::fs::create_dir_all(args.out.parent().unwrap())?; + std::fs::create_dir_all(args.csv.parent().unwrap())?; + + let prompts = load_prompts(&args.prompts)?; + + println!("Loading q4k vindex {}", args.vindex.display()); + let mut cb = larql_vindex::SilentLoadCallbacks; + let mut weights = load_model_weights_q4k(&args.vindex, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.vindex)?; + let mut index = open_inference_vindex(&args.vindex)?; + + println!("Scoring baseline with top_k={}", args.top_k); + let before = score_prompts( + &mut weights, + &tokenizer, + &index, + &prompts, + args.top_k, + args.feature_top_k, + )?; + + println!( + "Building Atlantis patch L{}-L{} alpha={}", + args.layer_start, + args.layer_end - 1, + args.alpha + ); + let inserted = build_atlantis_patch( + &mut weights, + &tokenizer, + &mut index, + args.alpha, + args.layer_start..args.layer_end, + args.feature_top_k, + )?; + + println!("Scoring patched"); + let after = score_prompts( + &mut weights, + &tokenizer, + &index, + &prompts, + args.top_k, + args.feature_top_k, + )?; + let summary = summarize(&before, &after); + + let out = json!({ + "experiment": "36_patch_propagation", + "path": "q4k", + "scoring": "exact_target_logprob", + "vindex": args.vindex, + "top_k_predictions": args.top_k, + "feature_top_k": args.feature_top_k, + "patch": { + "type": "exp04_multilayer_atlantis_poseidon", + "alpha": args.alpha, + "layers": (args.layer_start..args.layer_end).collect::>(), + "inserted": inserted, + }, + "before": before, + "after": after, + "summary": summary, + }); + std::fs::write(&args.out, serde_json::to_string_pretty(&out)?)?; + write_summary_csv(&args.csv, &summary)?; + println!("wrote {}", args.out.display()); + println!("wrote {}", args.csv.display()); + Ok(()) + } + + fn parse_args() -> Args { + let mut args = Args { + vindex: PathBuf::from("output/gemma3-4b-q4k-v2.vindex"), + prompts: PathBuf::from("experiments/36_patch_propagation/data/prompts.jsonl"), + out: PathBuf::from("experiments/36_patch_propagation/results/q4k_final_slot_bits.json"), + csv: PathBuf::from( + "experiments/36_patch_propagation/results/q4k_final_slot_summary.csv", + ), + alpha: 0.25, + layer_start: 20, + layer_end: 28, + top_k: 2048, + feature_top_k: 2048, + }; + + let raw: Vec = std::env::args().collect(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--vindex" => { + i += 1; + args.vindex = PathBuf::from(&raw[i]); + } + "--prompts" => { + i += 1; + args.prompts = PathBuf::from(&raw[i]); + } + "--out" => { + i += 1; + args.out = PathBuf::from(&raw[i]); + } + "--csv" => { + i += 1; + args.csv = PathBuf::from(&raw[i]); + } + "--alpha" => { + i += 1; + args.alpha = raw[i].parse().expect("--alpha must be f32"); + } + "--layers" => { + i += 1; + let (start, end) = raw[i].split_once(':').expect("--layers START:END"); + args.layer_start = start.parse().expect("layer start"); + args.layer_end = end.parse().expect("layer end"); + } + "--top-k" => { + i += 1; + args.top_k = raw[i].parse().expect("--top-k must be usize"); + } + "--feature-top-k" => { + i += 1; + args.feature_top_k = raw[i].parse().expect("--feature-top-k must be usize"); + } + other => { + eprintln!("unknown arg: {other}"); + std::process::exit(2); + } + } + i += 1; + } + args + } + + fn load_prompts(path: &PathBuf) -> Result, Box> { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut rows = Vec::new(); + for line in reader.lines() { + let line = line?; + if line.trim().is_empty() { + continue; + } + rows.push(serde_json::from_str(&line)?); + } + Ok(rows) + } + + fn build_atlantis_patch( + weights: &mut larql_models::ModelWeights, + tokenizer: &tokenizers::Tokenizer, + index: &mut larql_vindex::VectorIndex, + alpha: f32, + layers: std::ops::Range, + feature_top_k: usize, + ) -> Result, Box> { + let prompt_ids = encode_prompt(tokenizer, &*weights.arch, "The capital of Atlantis is")?; + let (_, trace_residuals) = + run_q4k_walk(weights, tokenizer, index, &prompt_ids, 5, feature_top_k); + let residuals: HashMap> = trace_residuals.into_iter().collect(); + + let poseidon_surface = " Poseidon"; + let poseidon_ids = tokenizer + .encode(poseidon_surface, false) + .map_err(|e| format!("tokenize {poseidon_surface:?}: {e}"))? + .get_ids() + .to_vec(); + let poseidon_id = *poseidon_ids + .first() + .ok_or("leading-space Poseidon tokenized empty")? as usize; + let embed_scale = weights.arch.embed_scale(); + let poseidon_vec: Vec = weights + .embed + .row(poseidon_id) + .iter() + .map(|v| v * embed_scale * alpha) + .collect(); + + let mut inserted = Vec::new(); + for layer in layers { + let residual = residuals + .get(&layer) + .ok_or_else(|| format!("missing residual for layer {layer}"))?; + let residual_norm = l2(residual); + if residual_norm == 0.0 { + continue; + } + let mut norms = Vec::new(); + for feature in 0..index.num_features(layer).min(50) { + if let Some(gate) = index.gate_vector(layer, feature) { + let n = l2(gate.as_slice()); + if n > 0.0 { + norms.push(n); + } + } + } + let avg_norm = norms.iter().sum::() / norms.len().max(1) as f32; + let gate_vec = + Array1::from_iter(residual.iter().map(|v| v * (avg_norm / residual_norm))); + let feature = index + .find_free_feature(layer) + .ok_or_else(|| format!("no free feature at layer {layer}"))?; + let gate_score = dot(gate_vec.as_slice().unwrap_or(&[]), residual); + let up_vec = if gate_score.abs() > 1e-6 { + gate_vec.iter().map(|v| v / gate_score).collect() + } else { + gate_vec.to_vec() + }; + index.set_gate_vector(layer, feature, &gate_vec); + index.set_up_vector(layer, feature, up_vec); + index.set_down_vector(layer, feature, poseidon_vec.clone()); + index.set_feature_meta( + layer, + feature, + FeatureMeta { + top_token: "Poseidon".to_string(), + top_token_id: poseidon_id as u32, + c_score: 0.95, + top_k: Vec::new(), + }, + ); + + let verify = index.gate_knn( + layer, + &Array1::from_vec(residual.clone()), + feature_top_k.min(128), + ); + let rank = verify + .iter() + .position(|(f, _)| *f == feature) + .map(|x| x + 1); + let score = verify.iter().find(|(f, _)| *f == feature).map(|(_, s)| *s); + inserted.push(InsertedSlot { + layer, + feature, + alpha, + gate_rank: rank, + gate_score: score, + }); + } + Ok(inserted) + } + + fn score_prompts( + weights: &mut larql_models::ModelWeights, + tokenizer: &tokenizers::Tokenizer, + index: &larql_vindex::VectorIndex, + prompts: &[PromptRow], + top_k: usize, + feature_top_k: usize, + ) -> Result, Box> { + let mut rows = Vec::new(); + for prompt in prompts { + for (surface_idx, answer) in prompt.answers.iter().enumerate() { + rows.push(score_answer( + weights, + tokenizer, + index, + prompt, + answer, + surface_idx, + top_k, + feature_top_k, + )?); + } + } + Ok(rows) + } + + fn score_answer( + weights: &mut larql_models::ModelWeights, + tokenizer: &tokenizers::Tokenizer, + index: &larql_vindex::VectorIndex, + prompt: &PromptRow, + answer: &str, + surface_idx: usize, + _top_k: usize, + feature_top_k: usize, + ) -> Result> { + let mut context_ids = encode_prompt(tokenizer, &*weights.arch, &prompt.prefix)?; + let answer_ids = tokenizer + .encode(format!(" {answer}"), false) + .map_err(|e| format!("tokenize answer {answer:?}: {e}"))? + .get_ids() + .to_vec(); + let mut token_bits = Vec::new(); + let mut token_probs = Vec::new(); + let clipped = 0usize; + + for &target_id in &answer_ids { + let prob = exact_target_prob( + weights, + index, + &context_ids, + target_id as usize, + feature_top_k, + ); + token_probs.push(prob); + token_bits.push(-prob.log2()); + context_ids.push(target_id); + } + let total: f64 = token_bits.iter().sum(); + Ok(ScoreRow { + group: prompt.group.clone(), + relation: prompt.relation.clone(), + prefix: prompt.prefix.clone(), + answer: answer.to_string(), + surface_kind: if surface_idx == 0 { + "canonical".to_string() + } else { + format!("alias_{surface_idx}") + }, + description: prompt.description.clone(), + slot_bits_total: total, + slot_bits_per_token: total / answer_ids.len().max(1) as f64, + answer_n_tokens: answer_ids.len(), + token_ids: answer_ids, + token_bits, + token_probs, + clipped_tokens: clipped, + }) + } + + fn summarize(before: &[ScoreRow], after: &[ScoreRow]) -> Vec { + let mut by_key: HashMap<(String, String, String), &ScoreRow> = HashMap::new(); + for row in before { + by_key.insert( + (row.group.clone(), row.prefix.clone(), row.answer.clone()), + row, + ); + } + after + .iter() + .map(|a| { + let b = by_key[&(a.group.clone(), a.prefix.clone(), a.answer.clone())]; + SummaryRow { + group: a.group.clone(), + relation: a.relation.clone(), + prefix: a.prefix.clone(), + answer: a.answer.clone(), + before_bits: b.slot_bits_total, + after_bits: a.slot_bits_total, + delta_bits: b.slot_bits_total - a.slot_bits_total, + before_bits_per_token: b.slot_bits_per_token, + after_bits_per_token: a.slot_bits_per_token, + answer_n_tokens: a.answer_n_tokens, + before_clipped_tokens: b.clipped_tokens, + after_clipped_tokens: a.clipped_tokens, + } + }) + .collect() + } + + fn write_summary_csv( + path: &PathBuf, + rows: &[SummaryRow], + ) -> Result<(), Box> { + let mut file = File::create(path)?; + writeln!( + file, + "group,relation,prefix,answer,before_bits,after_bits,delta_bits,before_bits_per_token,after_bits_per_token,answer_n_tokens,before_clipped_tokens,after_clipped_tokens" + )?; + for row in rows { + writeln!( + file, + "{},{},{:?},{},{:.6},{:.6},{:.6},{:.6},{:.6},{},{},{}", + row.group, + row.relation, + row.prefix, + row.answer, + row.before_bits, + row.after_bits, + row.delta_bits, + row.before_bits_per_token, + row.after_bits_per_token, + row.answer_n_tokens, + row.before_clipped_tokens, + row.after_clipped_tokens + )?; + } + Ok(()) + } + + fn l2(xs: &[f32]) -> f32 { + xs.iter().map(|v| v * v).sum::().sqrt() + } + + fn dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(x, y)| x * y).sum() + } + + fn run_q4k_walk( + weights: &mut larql_models::ModelWeights, + tokenizer: &tokenizers::Tokenizer, + index: &larql_vindex::VectorIndex, + token_ids: &[u32], + pred_top_k: usize, + feature_top_k: usize, + ) -> (PredictResult, Vec<(usize, Vec)>) { + // SAFETY: this mirrors `infer_patched_q4k`: the q4k forward mutates + // `weights.tensors`, while WalkFfn reads `weights.arch` and + // `weights.vectors`. + let weights_ref: &larql_models::ModelWeights = + unsafe { &*(weights as *const larql_models::ModelWeights) }; + let walk_ffn = WalkFfn::new_with_trace(weights_ref, index, feature_top_k); + let result = + predict_q4k_with_ffn(weights, tokenizer, token_ids, pred_top_k, index, &walk_ffn); + let residuals = walk_ffn.take_residuals(); + (result, residuals) + } + + fn exact_target_prob( + weights: &mut larql_models::ModelWeights, + index: &larql_vindex::VectorIndex, + token_ids: &[u32], + target_id: usize, + feature_top_k: usize, + ) -> f64 { + let weights_ref: &larql_models::ModelWeights = + unsafe { &*(weights as *const larql_models::ModelWeights) }; + let walk_ffn = WalkFfn::new(weights_ref, index, feature_top_k); + let h = predict_q4k_hidden_with_ffn(weights, token_ids, index, &walk_ffn); + let seq_len = h.shape()[0]; + let h_last = h.slice(ndarray::s![seq_len - 1..seq_len, ..]).to_owned(); + let logits = hidden_to_raw_logits(weights, &h_last); + let target = logits[target_id] as f64; + let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max) as f64; + let exp_sum: f64 = logits.iter().map(|&l| ((l as f64) - max_logit).exp()).sum(); + let logsumexp = max_logit + exp_sum.ln(); + (target - logsumexp).exp().max(f64::MIN_POSITIVE) + } +} diff --git a/crates/kv-cache-benchmark/examples/q4k_ffn_raw_bridge.rs b/crates/kv-cache-benchmark/examples/q4k_ffn_raw_bridge.rs new file mode 100644 index 00000000..6e369c5a --- /dev/null +++ b/crates/kv-cache-benchmark/examples/q4k_ffn_raw_bridge.rs @@ -0,0 +1,198 @@ +//! Q4K FFN raw-output bridge for exp35. +//! +//! Reads LARQLF32 matrices exported by +//! `experiments/35_ffn_functional_fidelity/ffn_functional_fidelity.py`, runs +//! the production `q4k_ffn_forward_layer` path for one layer, and writes the +//! resulting raw FFN outputs back as LARQLF32 matrices. +//! +//! Usage: +//! cargo run -p kv-cache-benchmark --example q4k_ffn_raw_bridge \ +//! --features real-model --release -- \ +//! output/gemma3-4b-q4k-v2.vindex \ +//! experiments/35_ffn_functional_fidelity/results/q4k_bridge_inputs_l30_seed \ +//! experiments/35_ffn_functional_fidelity/results/q4k_bridge_outputs_l30_seed \ +//! --layer 30 --k full + +#[cfg(feature = "real-model")] +fn main() { + bridge::run(); +} + +#[cfg(not(feature = "real-model"))] +fn main() { + eprintln!("This example requires the 'real-model' feature."); + std::process::exit(1); +} + +#[cfg(feature = "real-model")] +mod bridge { + use std::fs::File; + use std::io::{Read, Write}; + use std::path::{Path, PathBuf}; + + use ndarray::Array2; + + use larql_inference::ffn::FfnBackend; + use larql_inference::vindex::{q4k_ffn_forward_layer, WalkFfn, WalkFfnConfig}; + use larql_vindex::{load_model_weights_q4k, SilentLoadCallbacks, VectorIndex}; + + const MAGIC: &[u8; 8] = b"LARQLF32"; + + struct Args { + vindex: PathBuf, + input_dir: PathBuf, + output_dir: PathBuf, + layer: usize, + k: Option, + } + + fn parse_args() -> Args { + let mut raw: Vec = std::env::args().skip(1).collect(); + let mut layer = 30usize; + let mut k: Option = None; + + let mut i = 0; + while i < raw.len() { + match raw[i].as_str() { + "--layer" => { + layer = raw + .get(i + 1) + .and_then(|s| s.parse().ok()) + .expect("--layer needs usize"); + raw.drain(i..i + 2); + } + "--k" => { + let v = raw.get(i + 1).cloned().unwrap_or_else(|| "full".into()); + k = if v == "full" { + None + } else { + Some(v.parse().expect("--k must be int or 'full'")) + }; + raw.drain(i..i + 2); + } + _ => i += 1, + } + } + + if raw.len() != 3 { + eprintln!( + "Usage: q4k_ffn_raw_bridge --layer N --k N|full" + ); + std::process::exit(2); + } + Args { + vindex: PathBuf::from(&raw[0]), + input_dir: PathBuf::from(&raw[1]), + output_dir: PathBuf::from(&raw[2]), + layer, + k, + } + } + + pub fn run() { + let args = parse_args(); + std::fs::create_dir_all(&args.output_dir).expect("create output dir"); + + println!("Loading q4k weights/index from {}", args.vindex.display()); + let mut cb = SilentLoadCallbacks; + let weights = load_model_weights_q4k(&args.vindex, &mut cb).expect("load q4k weights"); + let mut index = VectorIndex::load_vindex(&args.vindex, &mut cb).expect("load vindex"); + index + .load_interleaved_q4k(&args.vindex) + .expect("load interleaved q4k"); + + let mut inputs: Vec = std::fs::read_dir(&args.input_dir) + .expect("read input dir") + .filter_map(|e| e.ok().map(|e| e.path())) + .filter(|p| { + p.file_name() + .and_then(|s| s.to_str()) + .map(|s| s.ends_with("_mlp_input.f32bin")) + .unwrap_or(false) + }) + .collect(); + inputs.sort(); + + if inputs.is_empty() { + panic!( + "no *_mlp_input.f32bin files found in {}", + args.input_dir.display() + ); + } + + for input_path in inputs { + let name = input_path + .file_name() + .and_then(|s| s.to_str()) + .expect("utf8 filename"); + let window_id = name + .strip_suffix("_mlp_input.f32bin") + .expect("input suffix"); + let x = read_matrix(&input_path).expect("read input matrix"); + let method_name = args + .k + .map(|k| format!("q4k_top{k}_walk")) + .unwrap_or_else(|| "q4k_full_walk".to_string()); + println!( + "{}: running {} L{} on {}x{}", + window_id, + method_name, + args.layer, + x.shape()[0], + x.shape()[1] + ); + let out = if let Some(k) = args.k { + let walk = WalkFfn::from_config( + &weights, + &index, + WalkFfnConfig::sparse(weights.num_layers, k), + ); + walk.forward(args.layer, &x) + } else { + q4k_ffn_forward_layer(weights.arch.as_ref(), &index, args.layer, &x) + }; + let output_path = args + .output_dir + .join(format!("{window_id}_{method_name}.f32bin")); + write_matrix(&output_path, &out).expect("write output matrix"); + } + } + + fn read_matrix(path: &Path) -> std::io::Result> { + let mut f = File::open(path)?; + let mut magic = [0u8; 8]; + f.read_exact(&mut magic)?; + if &magic != MAGIC { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "bad LARQLF32 magic", + )); + } + let rows = read_u64(&mut f)? as usize; + let cols = read_u64(&mut f)? as usize; + let mut bytes = vec![0u8; rows * cols * 4]; + f.read_exact(&mut bytes)?; + let mut vals = Vec::with_capacity(rows * cols); + for chunk in bytes.chunks_exact(4) { + vals.push(f32::from_le_bytes(chunk.try_into().unwrap())); + } + Ok(Array2::from_shape_vec((rows, cols), vals).expect("matrix shape")) + } + + fn write_matrix(path: &Path, arr: &Array2) -> std::io::Result<()> { + let mut f = File::create(path)?; + f.write_all(MAGIC)?; + f.write_all(&(arr.shape()[0] as u64).to_le_bytes())?; + f.write_all(&(arr.shape()[1] as u64).to_le_bytes())?; + for v in arr.iter().copied() { + f.write_all(&v.to_le_bytes())?; + } + Ok(()) + } + + fn read_u64(f: &mut File) -> std::io::Result { + let mut buf = [0u8; 8]; + f.read_exact(&mut buf)?; + Ok(u64::from_le_bytes(buf)) + } +} diff --git a/crates/kv-cache-benchmark/examples/real_model_bench.rs b/crates/kv-cache-benchmark/examples/real_model_bench.rs index 074cb9a6..a7c9022a 100644 --- a/crates/kv-cache-benchmark/examples/real_model_bench.rs +++ b/crates/kv-cache-benchmark/examples/real_model_bench.rs @@ -12,34 +12,36 @@ fn main() { let args: Vec = std::env::args().collect(); // Load model - let model_name = args.get(1).map(|s| s.as_str()).unwrap_or("google/gemma-3-4b-it"); + let model_name = args + .get(1) + .map(|s| s.as_str()) + .unwrap_or("google/gemma-3-4b-it"); println!("Loading model: {model_name}"); - let model = larql_inference::InferenceModel::load(model_name) - .expect("Failed to load model"); + let model = larql_inference::InferenceModel::load(model_name).expect("Failed to load model"); // Load vindex (requires explicit path) - let vindex_path = args.get(2).expect( - "Usage: real_model_bench " - ); + let vindex_path = args + .get(2) + .expect("Usage: real_model_bench "); println!("Loading vindex from: {vindex_path}"); let index = larql_vindex::VectorIndex::load_vindex( std::path::Path::new(vindex_path), &mut larql_vindex::SilentLoadCallbacks, - ).expect("Failed to load vindex"); + ) + .expect("Failed to load vindex"); // Create compute backend let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), - model.tokenizer(), - &index, - backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); // Run default prompts let prompts = runner::default_prompts(); - println!("\nRunning {} prompts through strategies...\n", prompts.len()); + println!( + "\nRunning {} prompts through strategies...\n", + prompts.len() + ); for prompt in &prompts { let results = runner::run_all_strategies(&bench, prompt, 5, 512); @@ -56,7 +58,10 @@ fn main() { use kv_cache_benchmark::KvStrategy; let strategies: Vec<&dyn KvStrategy> = vec![&standard, &tq4, &markov]; - println!("{}", kv_cache_benchmark::benchmark::format_comparative_table(&config, &strategies)); + println!( + "{}", + kv_cache_benchmark::benchmark::format_comparative_table(&config, &strategies) + ); println!( "\n{} @ 370K tokens: {} bytes per-conversation, {} bytes shared infrastructure", graph.name(), diff --git a/crates/kv-cache-benchmark/examples/shader_bench.rs b/crates/kv-cache-benchmark/examples/shader_bench.rs index 2cf648a2..8f1f6993 100644 --- a/crates/kv-cache-benchmark/examples/shader_bench.rs +++ b/crates/kv-cache-benchmark/examples/shader_bench.rs @@ -23,14 +23,17 @@ fn main() { // Memory comparison table (KV-reconstructing strategies only). let config = kv_cache_benchmark::model_config::ModelConfig::gemma_4b(); - println!("\n{}", kv_cache_benchmark::benchmark::format_comparative_table( - &config, - &[ - &kv_cache_benchmark::standard_kv::StandardKv as &dyn kv_cache_benchmark::KvStrategy, - &kv_cache_benchmark::turboquant::TurboQuant::new(4), - &kv_cache_benchmark::markov_residual::MarkovResidual::new(512), - ], - )); + println!( + "\n{}", + kv_cache_benchmark::benchmark::format_comparative_table( + &config, + &[ + &kv_cache_benchmark::standard_kv::StandardKv as &dyn kv_cache_benchmark::KvStrategy, + &kv_cache_benchmark::turboquant::TurboQuant::new(4), + &kv_cache_benchmark::markov_residual::MarkovResidual::new(512), + ], + ) + ); // Graph Walk is projected (no K/V reconstruction); report memory separately. let gw = kv_cache_benchmark::graph_walk::GraphWalk::gemma_4b(); diff --git a/crates/kv-cache-benchmark/examples/vindex_compare.rs b/crates/kv-cache-benchmark/examples/vindex_compare.rs new file mode 100644 index 00000000..0457e5b1 --- /dev/null +++ b/crates/kv-cache-benchmark/examples/vindex_compare.rs @@ -0,0 +1,320 @@ +//! Vindex A/B comparison runner. Format-agnostic — works for any pair +//! of VectorIndex instances sharing the same underlying model. +//! +//! Primary use: exp 26 Q2 (FP4 end-to-end correctness) via +//! +//! cargo run --release --features real-model -p kv-cache-benchmark \ +//! --example vindex_compare -- \ +//! --reference output/gemma3-4b-f16.vindex \ +//! --candidate output/gemma3-4b-fp4.vindex \ +//! --prompts experiments/26_fp4_quantisation/prompts.txt \ +//! --out experiments/26_fp4_quantisation/results/q2_fp4.json +//! +//! Any future storage-format comparison (FP6, NF4, Q4K regression +//! tests) reuses the same binary — nothing here is FP4-specific. + +#[cfg(feature = "real-model")] +use std::path::PathBuf; + +#[cfg(feature = "real-model")] +use kv_cache_benchmark::vindex_compare::{ + compare_many, forward_to_logits_traced, ComparisonConfig, +}; +#[cfg(feature = "real-model")] +use larql_inference::InferenceModel; +#[cfg(feature = "real-model")] +use larql_vindex::{SilentLoadCallbacks, VectorIndex}; + +#[cfg(not(feature = "real-model"))] +const REAL_MODEL_FEATURE_NAME: &str = "real-model"; + +#[cfg(feature = "real-model")] +struct Args { + reference: PathBuf, + candidate: PathBuf, + prompts_path: Option, + model: String, + out: Option, + top_k: usize, + max_seq_len: Option, + max_layers: Option, + inline_prompts: Vec, + trace: bool, +} + +#[cfg(feature = "real-model")] +fn parse_args() -> Args { + let argv: Vec = std::env::args().collect(); + let mut a = Args { + reference: PathBuf::new(), + candidate: PathBuf::new(), + prompts_path: None, + model: "google/gemma-3-4b-it".into(), + out: None, + top_k: 5, + max_seq_len: None, + max_layers: None, + inline_prompts: Vec::new(), + trace: false, + }; + let mut i = 1; + while i < argv.len() { + match argv[i].as_str() { + "--reference" => { + i += 1; + a.reference = PathBuf::from(&argv[i]); + } + "--candidate" => { + i += 1; + a.candidate = PathBuf::from(&argv[i]); + } + "--prompts" => { + i += 1; + a.prompts_path = Some(PathBuf::from(&argv[i])); + } + "--model" => { + i += 1; + a.model = argv[i].clone(); + } + "--out" => { + i += 1; + a.out = Some(PathBuf::from(&argv[i])); + } + "--top-k" => { + i += 1; + a.top_k = argv[i].parse().expect("int"); + } + "--max-seq" => { + i += 1; + a.max_seq_len = Some(argv[i].parse().expect("int")); + } + "--max-layers" => { + i += 1; + a.max_layers = Some(argv[i].parse().expect("int")); + } + "--prompt" => { + i += 1; + a.inline_prompts.push(argv[i].clone()); + } + "--trace" => { + a.trace = true; + } + other => eprintln!("warn: ignored arg {other}"), + } + i += 1; + } + if a.reference.as_os_str().is_empty() || a.candidate.as_os_str().is_empty() { + eprintln!( + "usage: vindex_compare --reference PATH --candidate PATH \\ + [--prompts FILE] [--prompt 'inline text' ...] \\ + [--model NAME] [--out PATH] [--top-k K] [--max-seq N] [--max-layers L] + +At least one of --prompts or --prompt must be provided." + ); + std::process::exit(1); + } + a +} + +#[cfg(feature = "real-model")] +fn load_prompts(args: &Args) -> Vec { + let mut prompts = args.inline_prompts.clone(); + if let Some(path) = &args.prompts_path { + let content = std::fs::read_to_string(path) + .unwrap_or_else(|e| panic!("read {}: {e}", path.display())); + for line in content.lines() { + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with('#') { + continue; + } + prompts.push(trimmed.to_string()); + } + } + if prompts.is_empty() { + // Small default set so running with just --reference / --candidate + // produces something on stdout. Real use cases should pass --prompts. + prompts = default_prompt_set(); + } + prompts +} + +#[cfg(feature = "real-model")] +fn default_prompt_set() -> Vec { + vec![ + "The capital of France is".into(), + "Two plus two equals".into(), + "The quick brown fox".into(), + "Once upon a time".into(), + "The largest planet in the solar system is".into(), + "Shakespeare wrote".into(), + "In 1969, the first man to walk on the moon was".into(), + "The chemical formula for water is".into(), + ] +} + +#[cfg(feature = "real-model")] +fn main() { + let args = parse_args(); + + println!("== vindex_compare =="); + println!(" reference: {}", args.reference.display()); + println!(" candidate: {}", args.candidate.display()); + println!(" model : {}", args.model); + println!(" top-k : {}", args.top_k); + if let Some(cap) = args.max_seq_len { + println!(" max_seq : {cap}"); + } + if let Some(l) = args.max_layers { + println!(" max_layers: {l}"); + } + println!(); + + let t_load = std::time::Instant::now(); + eprintln!("Loading model weights ({})...", args.model); + let model = InferenceModel::load(&args.model).unwrap_or_else(|e| panic!("load model: {e}")); + let tokenizer = model.tokenizer().clone(); + + eprintln!("Loading reference vindex..."); + let mut cb = SilentLoadCallbacks; + let reference = VectorIndex::load_vindex(&args.reference, &mut cb) + .unwrap_or_else(|e| panic!("load reference: {e:?}")); + eprintln!("Loading candidate vindex..."); + let candidate = VectorIndex::load_vindex(&args.candidate, &mut cb) + .unwrap_or_else(|e| panic!("load candidate: {e:?}")); + eprintln!(" loaded in {:.1}s", t_load.elapsed().as_secs_f64()); + eprintln!( + " reference has_fp4_storage={}", + reference.has_fp4_storage() + ); + eprintln!( + " candidate has_fp4_storage={}", + candidate.has_fp4_storage() + ); + eprintln!(); + + // Tokenise the prompt set. + let prompts = load_prompts(&args); + eprintln!("Prompt set: {} prompts", prompts.len()); + let prompts_and_tokens: Vec<(&str, Vec)> = prompts + .iter() + .map(|p| { + let enc = tokenizer + .encode(p.as_str(), true) + .unwrap_or_else(|e| panic!("tokenize: {e}")); + (p.as_str(), enc.get_ids().to_vec()) + }) + .collect(); + + let config = ComparisonConfig { + top_k: args.top_k, + max_seq_len: args.max_seq_len, + max_layers: args.max_layers, + }; + + let weights = model.weights(); + + // Optional single-prompt dispatch trace — isolates which walk path + // each vindex actually fires, per layer. Exp 26 Q2 surfaced a bug + // where an FP4 vindex silently fell through to the safetensors- + // weights path; --trace is the tool for catching that class again. + if args.trace { + let (prompt, tokens) = &prompts_and_tokens[0]; + eprintln!(); + eprintln!("── dispatch trace (prompt 0: {}) ──", prompt); + let cfg = ComparisonConfig { + top_k: args.top_k, + max_seq_len: args.max_seq_len, + max_layers: args.max_layers, + }; + let (_logits, ref_trace) = forward_to_logits_traced(weights, &reference, tokens, &cfg); + let (_logits, cand_trace) = forward_to_logits_traced(weights, &candidate, tokens, &cfg); + eprintln!(" {:>3} {:<32} {:<32}", "L", "reference", "candidate"); + for (layer, (r_path, c_path)) in ref_trace.iter().zip(cand_trace.iter()).enumerate() { + let flag = if r_path.1 == c_path.1 { " " } else { "≠" }; + eprintln!(" {:>3} {:<32} {:<32} {flag}", layer, r_path.1, c_path.1); + } + eprintln!(); + } + + let t_run = std::time::Instant::now(); + let mut report = compare_many( + weights, + &reference, + &candidate, + &prompts_and_tokens, + &args.reference.display().to_string(), + &args.candidate.display().to_string(), + &config, + ); + eprintln!("Compared in {:.1}s", t_run.elapsed().as_secs_f64()); + + // Decode top tokens for human-readable output (tokenizer-free library + // keeps this in the CLI). + for p in report.prompts.iter_mut() { + p.ref_top_token = Some(decode_token(&tokenizer, p.ref_top_token_id)); + p.cand_top_token = Some(decode_token(&tokenizer, p.cand_top_token_id)); + } + + print_human_report(&report); + + if let Some(out_path) = &args.out { + if let Some(parent) = out_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + let json = + serde_json::to_string_pretty(&report).unwrap_or_else(|e| panic!("serialise: {e}")); + std::fs::write(out_path, json) + .unwrap_or_else(|e| panic!("write {}: {e}", out_path.display())); + println!(); + println!("→ wrote {}", out_path.display()); + } +} + +#[cfg(not(feature = "real-model"))] +fn main() { + eprintln!( + "vindex_compare requires the `{REAL_MODEL_FEATURE_NAME}` feature: \ + cargo run --release --features {REAL_MODEL_FEATURE_NAME} \ + -p kv-cache-benchmark --example vindex_compare -- ..." + ); + std::process::exit(1); +} + +#[cfg(feature = "real-model")] +fn decode_token(tokenizer: &tokenizers::Tokenizer, id: u32) -> String { + tokenizer + .decode(&[id], false) + .unwrap_or_else(|_| format!("<{id}>")) +} + +#[cfg(feature = "real-model")] +fn print_human_report(report: &kv_cache_benchmark::vindex_compare::AggregateReport) { + println!("── per-prompt ──"); + for p in &report.prompts { + let ref_t = p.ref_top_token.as_deref().unwrap_or("?"); + let cand_t = p.cand_top_token.as_deref().unwrap_or("?"); + let flag = if p.argmax_match { "✓" } else { "✗" }; + let short: String = p.prompt.chars().take(50).collect(); + println!( + " {flag} {short:<50} ref={ref_t:<12} cand={cand_t:<12} cos={:.4} jac={:.2} KL={:.4}", + p.logit_cos, p.top_k_jaccard, p.kl_symmetric + ); + } + println!(); + println!("── aggregate ──"); + println!(" n prompts : {}", report.n_prompts); + println!( + " argmax agreement : {:.4} ({}/{})", + report.argmax_agreement, + (report.argmax_agreement * report.n_prompts as f64).round() as usize, + report.n_prompts + ); + println!( + " top-{} Jaccard mean : {:.4}", + report.config.top_k, report.top_k_agreement_mean + ); + println!(" logit cosine mean : {:.4}", report.logit_cos_mean); + println!(" symmetric KL mean : {:.5}", report.kl_mean); + println!(" symmetric KL p95 : {:.5}", report.kl_p95); + println!(" symmetric KL max : {:.5}", report.kl_max); +} diff --git a/crates/kv-cache-benchmark/src/accuracy.rs b/crates/kv-cache-benchmark/src/accuracy.rs index 7e65fcb4..5c67041b 100644 --- a/crates/kv-cache-benchmark/src/accuracy.rs +++ b/crates/kv-cache-benchmark/src/accuracy.rs @@ -89,7 +89,11 @@ pub fn kl_divergence(p: &[f64], q: &[f64]) -> f64 { /// Compute Jensen-Shannon divergence (symmetric, bounded 0-1). pub fn js_divergence(p: &[f64], q: &[f64]) -> f64 { - let m: Vec = p.iter().zip(q.iter()).map(|(&a, &b)| (a + b) / 2.0).collect(); + let m: Vec = p + .iter() + .zip(q.iter()) + .map(|(&a, &b)| (a + b) / 2.0) + .collect(); (kl_divergence(p, &m) + kl_divergence(q, &m)) / 2.0 } @@ -121,7 +125,9 @@ pub fn first_divergence(a: &[u32], b: &[u32]) -> Option { /// Token-level match rate between two sequences. pub fn token_match_rate(a: &[u32], b: &[u32]) -> f32 { - if a.is_empty() { return 0.0; } + if a.is_empty() { + return 0.0; + } let matches = a.iter().zip(b.iter()).filter(|(&x, &y)| x == y).count(); matches as f32 / a.len().min(b.len()) as f32 } @@ -205,11 +211,13 @@ pub fn generate_haystack( /// Build a multi-turn fact retention conversation. pub fn build_retention_conversation(num_turns: usize) -> Vec { - let facts = [("My name is Alice and I work at Anthropic.", "name", "Alice"), + let facts = [ + ("My name is Alice and I work at Anthropic.", "name", "Alice"), ("I'm based in San Francisco.", "location", "San Francisco"), ("My project is called Lighthouse.", "project", "Lighthouse"), ("My favorite color is blue.", "color", "blue"), - ("I have two cats named Luna and Sol.", "pets", "Luna")]; + ("I have two cats named Luna and Sol.", "pets", "Luna"), + ]; let queries = vec![ ("What project am I working on?", "project", "Lighthouse"), @@ -307,10 +315,8 @@ pub fn format_accuracy_summary(results: &[AccuracyResult]) -> String { out.push('\n'); for strategy in &strategies { - let strat_results: Vec<&AccuracyResult> = results - .iter() - .filter(|r| &r.strategy == strategy) - .collect(); + let strat_results: Vec<&AccuracyResult> = + results.iter().filter(|r| &r.strategy == strategy).collect(); let total = strat_results.len(); let top1_matches = strat_results.iter().filter(|r| r.top1_match).count(); @@ -336,7 +342,10 @@ pub fn format_accuracy_summary(results: &[AccuracyResult]) -> String { .filter(|r| r.needle_found.is_some()) .copied() .collect(); - let needles_found = needles.iter().filter(|r| r.needle_found == Some(true)).count(); + let needles_found = needles + .iter() + .filter(|r| r.needle_found == Some(true)) + .count(); let needle_str = if needles.is_empty() { "n/a".to_string() } else { diff --git a/crates/kv-cache-benchmark/src/accuracy_suite/mod.rs b/crates/kv-cache-benchmark/src/accuracy_suite/mod.rs index 8238e430..77658479 100644 --- a/crates/kv-cache-benchmark/src/accuracy_suite/mod.rs +++ b/crates/kv-cache-benchmark/src/accuracy_suite/mod.rs @@ -8,9 +8,9 @@ //! //! Requires `real-model` feature — needs actual model weights. +#[cfg(feature = "real-model")] +pub mod needle; #[cfg(feature = "real-model")] pub mod prompts; #[cfg(feature = "real-model")] pub mod runner; -#[cfg(feature = "real-model")] -pub mod needle; diff --git a/crates/kv-cache-benchmark/src/accuracy_suite/needle.rs b/crates/kv-cache-benchmark/src/accuracy_suite/needle.rs index 6344c367..6b819a8e 100644 --- a/crates/kv-cache-benchmark/src/accuracy_suite/needle.rs +++ b/crates/kv-cache-benchmark/src/accuracy_suite/needle.rs @@ -23,31 +23,87 @@ pub fn needle_tests() -> Vec { let query = "What is the secret project code name?"; vec![ - NeedleTest { context_tokens: 512, needle_text: needle, needle_answer: answer, query_text: query }, - NeedleTest { context_tokens: 1024, needle_text: needle, needle_answer: answer, query_text: query }, - NeedleTest { context_tokens: 2048, needle_text: needle, needle_answer: answer, query_text: query }, - NeedleTest { context_tokens: 4096, needle_text: needle, needle_answer: answer, query_text: query }, - NeedleTest { context_tokens: 8192, needle_text: needle, needle_answer: answer, query_text: query }, - NeedleTest { context_tokens: 16384, needle_text: needle, needle_answer: answer, query_text: query }, - NeedleTest { context_tokens: 32768, needle_text: needle, needle_answer: answer, query_text: query }, + NeedleTest { + context_tokens: 512, + needle_text: needle, + needle_answer: answer, + query_text: query, + }, + NeedleTest { + context_tokens: 1024, + needle_text: needle, + needle_answer: answer, + query_text: query, + }, + NeedleTest { + context_tokens: 2048, + needle_text: needle, + needle_answer: answer, + query_text: query, + }, + NeedleTest { + context_tokens: 4096, + needle_text: needle, + needle_answer: answer, + query_text: query, + }, + NeedleTest { + context_tokens: 8192, + needle_text: needle, + needle_answer: answer, + query_text: query, + }, + NeedleTest { + context_tokens: 16384, + needle_text: needle, + needle_answer: answer, + query_text: query, + }, + NeedleTest { + context_tokens: 32768, + needle_text: needle, + needle_answer: answer, + query_text: query, + }, ] } /// Multi-needle test: 5 facts at different positions in 32K context. pub fn multi_needle_tests() -> Vec<(&'static str, &'static str, &'static str)> { vec![ - ("Agent Alpha's code name is FALCON.", "FALCON", "What is Agent Alpha's code name?"), - ("The launch date is March 15th.", "March", "What is the launch date?"), - ("Budget allocation is $4.7 million.", "4.7", "What is the budget allocation?"), - ("The target city is Reykjavik.", "Reykjavik", "What is the target city?"), - ("Project sponsor is Dr. Kimura.", "Kimura", "Who is the project sponsor?"), + ( + "Agent Alpha's code name is FALCON.", + "FALCON", + "What is Agent Alpha's code name?", + ), + ( + "The launch date is March 15th.", + "March", + "What is the launch date?", + ), + ( + "Budget allocation is $4.7 million.", + "4.7", + "What is the budget allocation?", + ), + ( + "The target city is Reykjavik.", + "Reykjavik", + "What is the target city?", + ), + ( + "Project sponsor is Dr. Kimura.", + "Kimura", + "Who is the project sponsor?", + ), ] } /// Build a haystack context with needle planted at ~10% position. pub fn build_haystack(target_tokens: usize, needle: &str) -> String { // Filler: ~4 chars per token average - let filler_sentence = "The quick brown fox jumps over the lazy dog near the old oak tree by the river. "; + let filler_sentence = + "The quick brown fox jumps over the lazy dog near the old oak tree by the river. "; let needle_position = target_tokens / 10; // Plant early (~10% in) let chars_per_token = 4; diff --git a/crates/kv-cache-benchmark/src/accuracy_suite/prompts.rs b/crates/kv-cache-benchmark/src/accuracy_suite/prompts.rs index 7081a669..c2de82fe 100644 --- a/crates/kv-cache-benchmark/src/accuracy_suite/prompts.rs +++ b/crates/kv-cache-benchmark/src/accuracy_suite/prompts.rs @@ -24,122 +24,514 @@ pub fn paris_test() -> TestPrompt { pub fn diverse_100() -> Vec { vec![ // Factual: capitals (20) - TestPrompt { text: "The capital of France is", expected_contains: "Paris", category: "factual" }, - TestPrompt { text: "The capital of Germany is", expected_contains: "Berlin", category: "factual" }, - TestPrompt { text: "The capital of Japan is", expected_contains: "Tokyo", category: "factual" }, - TestPrompt { text: "The capital of Italy is", expected_contains: "Rome", category: "factual" }, - TestPrompt { text: "The capital of Spain is", expected_contains: "Madrid", category: "factual" }, - TestPrompt { text: "The capital of Brazil is", expected_contains: "Bras", category: "factual" }, - TestPrompt { text: "The capital of Australia is", expected_contains: "Canberra", category: "factual" }, - TestPrompt { text: "The capital of Canada is", expected_contains: "Ottawa", category: "factual" }, - TestPrompt { text: "The capital of Egypt is", expected_contains: "Cairo", category: "factual" }, - TestPrompt { text: "The capital of India is", expected_contains: "Delhi", category: "factual" }, - TestPrompt { text: "The capital of Mexico is", expected_contains: "Mexico", category: "factual" }, - TestPrompt { text: "The capital of Russia is", expected_contains: "Moscow", category: "factual" }, - TestPrompt { text: "The capital of China is", expected_contains: "Beijing", category: "factual" }, - TestPrompt { text: "The capital of South Korea is", expected_contains: "Seoul", category: "factual" }, - TestPrompt { text: "The capital of Turkey is", expected_contains: "Ankara", category: "factual" }, - TestPrompt { text: "The capital of Thailand is", expected_contains: "Bangkok", category: "factual" }, - TestPrompt { text: "The capital of Argentina is", expected_contains: "Buenos", category: "factual" }, - TestPrompt { text: "The capital of Sweden is", expected_contains: "Stockholm", category: "factual" }, - TestPrompt { text: "The capital of Norway is", expected_contains: "Oslo", category: "factual" }, - TestPrompt { text: "The capital of Poland is", expected_contains: "Warsaw", category: "factual" }, - + TestPrompt { + text: "The capital of France is", + expected_contains: "Paris", + category: "factual", + }, + TestPrompt { + text: "The capital of Germany is", + expected_contains: "Berlin", + category: "factual", + }, + TestPrompt { + text: "The capital of Japan is", + expected_contains: "Tokyo", + category: "factual", + }, + TestPrompt { + text: "The capital of Italy is", + expected_contains: "Rome", + category: "factual", + }, + TestPrompt { + text: "The capital of Spain is", + expected_contains: "Madrid", + category: "factual", + }, + TestPrompt { + text: "The capital of Brazil is", + expected_contains: "Bras", + category: "factual", + }, + TestPrompt { + text: "The capital of Australia is", + expected_contains: "Canberra", + category: "factual", + }, + TestPrompt { + text: "The capital of Canada is", + expected_contains: "Ottawa", + category: "factual", + }, + TestPrompt { + text: "The capital of Egypt is", + expected_contains: "Cairo", + category: "factual", + }, + TestPrompt { + text: "The capital of India is", + expected_contains: "Delhi", + category: "factual", + }, + TestPrompt { + text: "The capital of Mexico is", + expected_contains: "Mexico", + category: "factual", + }, + TestPrompt { + text: "The capital of Russia is", + expected_contains: "Moscow", + category: "factual", + }, + TestPrompt { + text: "The capital of China is", + expected_contains: "Beijing", + category: "factual", + }, + TestPrompt { + text: "The capital of South Korea is", + expected_contains: "Seoul", + category: "factual", + }, + TestPrompt { + text: "The capital of Turkey is", + expected_contains: "Ankara", + category: "factual", + }, + TestPrompt { + text: "The capital of Thailand is", + expected_contains: "Bangkok", + category: "factual", + }, + TestPrompt { + text: "The capital of Argentina is", + expected_contains: "Buenos", + category: "factual", + }, + TestPrompt { + text: "The capital of Sweden is", + expected_contains: "Stockholm", + category: "factual", + }, + TestPrompt { + text: "The capital of Norway is", + expected_contains: "Oslo", + category: "factual", + }, + TestPrompt { + text: "The capital of Poland is", + expected_contains: "Warsaw", + category: "factual", + }, // Factual: people (10) - TestPrompt { text: "Mozart was born in", expected_contains: "Salzburg", category: "factual" }, - TestPrompt { text: "Einstein was born in", expected_contains: "Ulm", category: "factual" }, - TestPrompt { text: "Shakespeare was born in", expected_contains: "Strat", category: "factual" }, - TestPrompt { text: "The Mona Lisa was painted by", expected_contains: "Leonardo", category: "factual" }, - TestPrompt { text: "The theory of relativity was developed by", expected_contains: "Einstein", category: "factual" }, - TestPrompt { text: "The first president of the United States was", expected_contains: "George", category: "factual" }, - TestPrompt { text: "Apple Inc. was co-founded by Steve", expected_contains: "Jobs", category: "factual" }, - TestPrompt { text: "The author of Harry Potter is J.K.", expected_contains: "Rowling", category: "factual" }, - TestPrompt { text: "Beethoven's first name was", expected_contains: "Ludwig", category: "factual" }, - TestPrompt { text: "Isaac Newton discovered", expected_contains: "grav", category: "factual" }, - + TestPrompt { + text: "Mozart was born in", + expected_contains: "Salzburg", + category: "factual", + }, + TestPrompt { + text: "Einstein was born in", + expected_contains: "Ulm", + category: "factual", + }, + TestPrompt { + text: "Shakespeare was born in", + expected_contains: "Strat", + category: "factual", + }, + TestPrompt { + text: "The Mona Lisa was painted by", + expected_contains: "Leonardo", + category: "factual", + }, + TestPrompt { + text: "The theory of relativity was developed by", + expected_contains: "Einstein", + category: "factual", + }, + TestPrompt { + text: "The first president of the United States was", + expected_contains: "George", + category: "factual", + }, + TestPrompt { + text: "Apple Inc. was co-founded by Steve", + expected_contains: "Jobs", + category: "factual", + }, + TestPrompt { + text: "The author of Harry Potter is J.K.", + expected_contains: "Rowling", + category: "factual", + }, + TestPrompt { + text: "Beethoven's first name was", + expected_contains: "Ludwig", + category: "factual", + }, + TestPrompt { + text: "Isaac Newton discovered", + expected_contains: "grav", + category: "factual", + }, // Factual: science (10) - TestPrompt { text: "Water freezes at", expected_contains: "0", category: "scientific" }, - TestPrompt { text: "The chemical symbol for gold is", expected_contains: "Au", category: "scientific" }, - TestPrompt { text: "The chemical formula for water is", expected_contains: "H", category: "scientific" }, - TestPrompt { text: "The speed of light is approximately", expected_contains: "3", category: "scientific" }, - TestPrompt { text: "The largest planet in our solar system is", expected_contains: "Jupiter", category: "scientific" }, - TestPrompt { text: "DNA stands for deoxyribonucle", expected_contains: "ic", category: "scientific" }, - TestPrompt { text: "The atomic number of carbon is", expected_contains: "6", category: "scientific" }, - TestPrompt { text: "Photosynthesis converts sunlight into", expected_contains: "energy", category: "scientific" }, - TestPrompt { text: "The boiling point of water is", expected_contains: "100", category: "scientific" }, - TestPrompt { text: "The nearest star to Earth is the", expected_contains: "Sun", category: "scientific" }, - + TestPrompt { + text: "Water freezes at", + expected_contains: "0", + category: "scientific", + }, + TestPrompt { + text: "The chemical symbol for gold is", + expected_contains: "Au", + category: "scientific", + }, + TestPrompt { + text: "The chemical formula for water is", + expected_contains: "H", + category: "scientific", + }, + TestPrompt { + text: "The speed of light is approximately", + expected_contains: "3", + category: "scientific", + }, + TestPrompt { + text: "The largest planet in our solar system is", + expected_contains: "Jupiter", + category: "scientific", + }, + TestPrompt { + text: "DNA stands for deoxyribonucle", + expected_contains: "ic", + category: "scientific", + }, + TestPrompt { + text: "The atomic number of carbon is", + expected_contains: "6", + category: "scientific", + }, + TestPrompt { + text: "Photosynthesis converts sunlight into", + expected_contains: "energy", + category: "scientific", + }, + TestPrompt { + text: "The boiling point of water is", + expected_contains: "100", + category: "scientific", + }, + TestPrompt { + text: "The nearest star to Earth is the", + expected_contains: "Sun", + category: "scientific", + }, // Factual: geography (10) - TestPrompt { text: "The longest river in Africa is the", expected_contains: "Nile", category: "geographic" }, - TestPrompt { text: "The tallest mountain in the world is", expected_contains: "Everest", category: "geographic" }, - TestPrompt { text: "The largest ocean is the", expected_contains: "Pacific", category: "geographic" }, - TestPrompt { text: "The Amazon River flows through", expected_contains: "Brazil", category: "geographic" }, - TestPrompt { text: "The Sahara Desert is located in", expected_contains: "Africa", category: "geographic" }, - TestPrompt { text: "The Great Wall of China is located in", expected_contains: "China", category: "geographic" }, - TestPrompt { text: "The currency of Japan is the", expected_contains: "yen", category: "geographic" }, - TestPrompt { text: "The currency of the United Kingdom is the", expected_contains: "pound", category: "geographic" }, - TestPrompt { text: "The official language of Brazil is", expected_contains: "Portug", category: "geographic" }, - TestPrompt { text: "The smallest continent is", expected_contains: "Australia", category: "geographic" }, - + TestPrompt { + text: "The longest river in Africa is the", + expected_contains: "Nile", + category: "geographic", + }, + TestPrompt { + text: "The tallest mountain in the world is", + expected_contains: "Everest", + category: "geographic", + }, + TestPrompt { + text: "The largest ocean is the", + expected_contains: "Pacific", + category: "geographic", + }, + TestPrompt { + text: "The Amazon River flows through", + expected_contains: "Brazil", + category: "geographic", + }, + TestPrompt { + text: "The Sahara Desert is located in", + expected_contains: "Africa", + category: "geographic", + }, + TestPrompt { + text: "The Great Wall of China is located in", + expected_contains: "China", + category: "geographic", + }, + TestPrompt { + text: "The currency of Japan is the", + expected_contains: "yen", + category: "geographic", + }, + TestPrompt { + text: "The currency of the United Kingdom is the", + expected_contains: "pound", + category: "geographic", + }, + TestPrompt { + text: "The official language of Brazil is", + expected_contains: "Portug", + category: "geographic", + }, + TestPrompt { + text: "The smallest continent is", + expected_contains: "Australia", + category: "geographic", + }, // Completion (10) - TestPrompt { text: "To be or not to be, that is the", expected_contains: "question", category: "completion" }, - TestPrompt { text: "I think, therefore I", expected_contains: "am", category: "completion" }, - TestPrompt { text: "All that glitters is not", expected_contains: "gold", category: "completion" }, - TestPrompt { text: "A journey of a thousand miles begins with a single", expected_contains: "step", category: "completion" }, - TestPrompt { text: "The early bird catches the", expected_contains: "worm", category: "completion" }, - TestPrompt { text: "Actions speak louder than", expected_contains: "words", category: "completion" }, - TestPrompt { text: "Rome was not built in a", expected_contains: "day", category: "completion" }, - TestPrompt { text: "Knowledge is", expected_contains: "power", category: "completion" }, - TestPrompt { text: "Practice makes", expected_contains: "perfect", category: "completion" }, - TestPrompt { text: "Where there is smoke, there is", expected_contains: "fire", category: "completion" }, - + TestPrompt { + text: "To be or not to be, that is the", + expected_contains: "question", + category: "completion", + }, + TestPrompt { + text: "I think, therefore I", + expected_contains: "am", + category: "completion", + }, + TestPrompt { + text: "All that glitters is not", + expected_contains: "gold", + category: "completion", + }, + TestPrompt { + text: "A journey of a thousand miles begins with a single", + expected_contains: "step", + category: "completion", + }, + TestPrompt { + text: "The early bird catches the", + expected_contains: "worm", + category: "completion", + }, + TestPrompt { + text: "Actions speak louder than", + expected_contains: "words", + category: "completion", + }, + TestPrompt { + text: "Rome was not built in a", + expected_contains: "day", + category: "completion", + }, + TestPrompt { + text: "Knowledge is", + expected_contains: "power", + category: "completion", + }, + TestPrompt { + text: "Practice makes", + expected_contains: "perfect", + category: "completion", + }, + TestPrompt { + text: "Where there is smoke, there is", + expected_contains: "fire", + category: "completion", + }, // Arithmetic (10) - TestPrompt { text: "2 + 2 =", expected_contains: "4", category: "arithmetic" }, - TestPrompt { text: "10 × 10 =", expected_contains: "100", category: "arithmetic" }, - TestPrompt { text: "100 / 4 =", expected_contains: "25", category: "arithmetic" }, - TestPrompt { text: "The square root of 144 is", expected_contains: "12", category: "arithmetic" }, - TestPrompt { text: "15 + 27 =", expected_contains: "42", category: "arithmetic" }, - TestPrompt { text: "One dozen equals", expected_contains: "12", category: "arithmetic" }, - TestPrompt { text: "A century is", expected_contains: "100", category: "arithmetic" }, - TestPrompt { text: "One kilometer equals", expected_contains: "1", category: "arithmetic" }, - TestPrompt { text: "There are 60 seconds in a", expected_contains: "minute", category: "arithmetic" }, - TestPrompt { text: "There are 24 hours in a", expected_contains: "day", category: "arithmetic" }, - + TestPrompt { + text: "2 + 2 =", + expected_contains: "4", + category: "arithmetic", + }, + TestPrompt { + text: "10 × 10 =", + expected_contains: "100", + category: "arithmetic", + }, + TestPrompt { + text: "100 / 4 =", + expected_contains: "25", + category: "arithmetic", + }, + TestPrompt { + text: "The square root of 144 is", + expected_contains: "12", + category: "arithmetic", + }, + TestPrompt { + text: "15 + 27 =", + expected_contains: "42", + category: "arithmetic", + }, + TestPrompt { + text: "One dozen equals", + expected_contains: "12", + category: "arithmetic", + }, + TestPrompt { + text: "A century is", + expected_contains: "100", + category: "arithmetic", + }, + TestPrompt { + text: "One kilometer equals", + expected_contains: "1", + category: "arithmetic", + }, + TestPrompt { + text: "There are 60 seconds in a", + expected_contains: "minute", + category: "arithmetic", + }, + TestPrompt { + text: "There are 24 hours in a", + expected_contains: "day", + category: "arithmetic", + }, // Code (10) - TestPrompt { text: "In Python, to print 'hello' you write print(", expected_contains: "'", category: "code" }, - TestPrompt { text: "In JavaScript, a variable is declared with let, const, or", expected_contains: "var", category: "code" }, - TestPrompt { text: "HTML stands for Hyper", expected_contains: "Text", category: "code" }, - TestPrompt { text: "The HTTP status code for 'Not Found' is", expected_contains: "404", category: "code" }, - TestPrompt { text: "In SQL, to select all columns you use SELECT", expected_contains: "*", category: "code" }, - TestPrompt { text: "Git is a distributed version", expected_contains: "control", category: "code" }, - TestPrompt { text: "JSON stands for JavaScript Object", expected_contains: "Notation", category: "code" }, - TestPrompt { text: "The file extension for Python files is .", expected_contains: "py", category: "code" }, - TestPrompt { text: "In CSS, to make text bold you use font-weight:", expected_contains: "bold", category: "code" }, - TestPrompt { text: "The command to list files in Linux is", expected_contains: "ls", category: "code" }, - + TestPrompt { + text: "In Python, to print 'hello' you write print(", + expected_contains: "'", + category: "code", + }, + TestPrompt { + text: "In JavaScript, a variable is declared with let, const, or", + expected_contains: "var", + category: "code", + }, + TestPrompt { + text: "HTML stands for Hyper", + expected_contains: "Text", + category: "code", + }, + TestPrompt { + text: "The HTTP status code for 'Not Found' is", + expected_contains: "404", + category: "code", + }, + TestPrompt { + text: "In SQL, to select all columns you use SELECT", + expected_contains: "*", + category: "code", + }, + TestPrompt { + text: "Git is a distributed version", + expected_contains: "control", + category: "code", + }, + TestPrompt { + text: "JSON stands for JavaScript Object", + expected_contains: "Notation", + category: "code", + }, + TestPrompt { + text: "The file extension for Python files is .", + expected_contains: "py", + category: "code", + }, + TestPrompt { + text: "In CSS, to make text bold you use font-weight:", + expected_contains: "bold", + category: "code", + }, + TestPrompt { + text: "The command to list files in Linux is", + expected_contains: "ls", + category: "code", + }, // Conversational (10) - TestPrompt { text: "How are you today? I'm doing", expected_contains: "well", category: "conversational" }, - TestPrompt { text: "Thank you very much! You're", expected_contains: "welcome", category: "conversational" }, - TestPrompt { text: "Good morning! How did you", expected_contains: "sleep", category: "conversational" }, - TestPrompt { text: "See you later! Have a great", expected_contains: "day", category: "conversational" }, - TestPrompt { text: "Happy birthday! How old are", expected_contains: "you", category: "conversational" }, - TestPrompt { text: "Sorry for the delay. I was", expected_contains: "busy", category: "conversational" }, - TestPrompt { text: "What do you think about", expected_contains: "the", category: "conversational" }, - TestPrompt { text: "Let me know if you need any", expected_contains: "help", category: "conversational" }, - TestPrompt { text: "I completely agree with", expected_contains: "you", category: "conversational" }, - TestPrompt { text: "That's a really good", expected_contains: "point", category: "conversational" }, - + TestPrompt { + text: "How are you today? I'm doing", + expected_contains: "well", + category: "conversational", + }, + TestPrompt { + text: "Thank you very much! You're", + expected_contains: "welcome", + category: "conversational", + }, + TestPrompt { + text: "Good morning! How did you", + expected_contains: "sleep", + category: "conversational", + }, + TestPrompt { + text: "See you later! Have a great", + expected_contains: "day", + category: "conversational", + }, + TestPrompt { + text: "Happy birthday! How old are", + expected_contains: "you", + category: "conversational", + }, + TestPrompt { + text: "Sorry for the delay. I was", + expected_contains: "busy", + category: "conversational", + }, + TestPrompt { + text: "What do you think about", + expected_contains: "the", + category: "conversational", + }, + TestPrompt { + text: "Let me know if you need any", + expected_contains: "help", + category: "conversational", + }, + TestPrompt { + text: "I completely agree with", + expected_contains: "you", + category: "conversational", + }, + TestPrompt { + text: "That's a really good", + expected_contains: "point", + category: "conversational", + }, // Reasoning (10) - TestPrompt { text: "If it rains, the ground gets", expected_contains: "wet", category: "reasoning" }, - TestPrompt { text: "The opposite of hot is", expected_contains: "cold", category: "reasoning" }, - TestPrompt { text: "The color of grass is", expected_contains: "green", category: "reasoning" }, - TestPrompt { text: "The day after Monday is", expected_contains: "Tuesday", category: "reasoning" }, - TestPrompt { text: "Ice is the solid form of", expected_contains: "water", category: "reasoning" }, - TestPrompt { text: "The month after January is", expected_contains: "February", category: "reasoning" }, - TestPrompt { text: "Cats are a type of", expected_contains: "animal", category: "reasoning" }, - TestPrompt { text: "The sun rises in the", expected_contains: "east", category: "reasoning" }, - TestPrompt { text: "The plural of child is", expected_contains: "children", category: "reasoning" }, - TestPrompt { text: "A triangle has three", expected_contains: "side", category: "reasoning" }, + TestPrompt { + text: "If it rains, the ground gets", + expected_contains: "wet", + category: "reasoning", + }, + TestPrompt { + text: "The opposite of hot is", + expected_contains: "cold", + category: "reasoning", + }, + TestPrompt { + text: "The color of grass is", + expected_contains: "green", + category: "reasoning", + }, + TestPrompt { + text: "The day after Monday is", + expected_contains: "Tuesday", + category: "reasoning", + }, + TestPrompt { + text: "Ice is the solid form of", + expected_contains: "water", + category: "reasoning", + }, + TestPrompt { + text: "The month after January is", + expected_contains: "February", + category: "reasoning", + }, + TestPrompt { + text: "Cats are a type of", + expected_contains: "animal", + category: "reasoning", + }, + TestPrompt { + text: "The sun rises in the", + expected_contains: "east", + category: "reasoning", + }, + TestPrompt { + text: "The plural of child is", + expected_contains: "children", + category: "reasoning", + }, + TestPrompt { + text: "A triangle has three", + expected_contains: "side", + category: "reasoning", + }, ] } diff --git a/crates/kv-cache-benchmark/src/accuracy_suite/runner.rs b/crates/kv-cache-benchmark/src/accuracy_suite/runner.rs index 67651566..2b9048e4 100644 --- a/crates/kv-cache-benchmark/src/accuracy_suite/runner.rs +++ b/crates/kv-cache-benchmark/src/accuracy_suite/runner.rs @@ -8,10 +8,10 @@ //! Markov RS 100% 0.0 100% 100% //! ``` -use larql_inference::model::ModelWeights; -use larql_inference::forward::predict; -use crate::accuracy; use super::prompts::TestPrompt; +use crate::accuracy; +use larql_inference::forward::predict; +use larql_inference::model::ModelWeights; /// Per-strategy accuracy scores across all tests. #[derive(Debug, Clone, serde::Serialize)] @@ -53,7 +53,8 @@ pub fn test_paris( backend: &dyn larql_compute::ComputeBackend, ) -> Vec<(String, bool)> { let bench = crate::real_model::RealModelBenchmark::new(weights, tokenizer, index, backend); - let results = crate::real_model::runner::run_all_strategies(&bench, "The capital of France is", 5, 512); + let results = + crate::real_model::runner::run_all_strategies(&bench, "The capital of France is", 5, 512); results .iter() @@ -79,19 +80,14 @@ pub fn test_top1_match_rate( let mut results = Vec::new(); for prompt in prompts { - let strat_results = crate::real_model::runner::run_all_strategies( - &bench, prompt.text, 5, 512, - ); + let strat_results = + crate::real_model::runner::run_all_strategies(&bench, prompt.text, 5, 512); let baseline_top1 = strat_results[0].top1_token.clone(); let mut strategy_results = Vec::new(); for r in &strat_results { - strategy_results.push(( - r.strategy.clone(), - r.top1_token.clone(), - r.top1_match, - )); + strategy_results.push((r.strategy.clone(), r.top1_token.clone(), r.top1_match)); } results.push(PromptResult { @@ -198,9 +194,17 @@ pub fn compute_strategy_accuracy(prompt_results: &[PromptResult]) -> Vec String { +pub fn format_comparative_table(config: &ModelConfig, strategies: &[&dyn KvStrategy]) -> String { let mut out = String::new(); - out.push_str(&format!("\n=== KV Cache Strategy Comparison: {} ===\n\n", config.name)); + out.push_str(&format!( + "\n=== KV Cache Strategy Comparison: {} ===\n\n", + config.name + )); let col_width = 15; out.push_str(&format!("{:<25}", "Context Length")); @@ -136,7 +135,11 @@ pub fn format_comparative_table( out.push_str(&format!("{:<25}", format_tokens(seq_len))); for strategy in strategies { let mem = strategy.memory_bytes(config, seq_len); - out.push_str(&format!(" {:>width$}", format_bytes(mem), width = col_width)); + out.push_str(&format!( + " {:>width$}", + format_bytes(mem), + width = col_width + )); } out.push('\n'); } diff --git a/crates/kv-cache-benchmark/src/graph_walk/fallback.rs b/crates/kv-cache-benchmark/src/graph_walk/fallback.rs index f7f7d556..d20be976 100644 --- a/crates/kv-cache-benchmark/src/graph_walk/fallback.rs +++ b/crates/kv-cache-benchmark/src/graph_walk/fallback.rs @@ -6,7 +6,6 @@ /// /// The benchmark reports what % of queries resolve at each tier /// and the accuracy per tier vs full forward pass baseline. - use super::walk_state::{WalkState, WalkTier}; /// Result of tier-based routing. @@ -77,22 +76,34 @@ impl TierDistribution { } pub fn tier_a_pct(&self) -> f64 { - if self.total == 0 { 0.0 } else { self.tier_a_count as f64 / self.total as f64 * 100.0 } + if self.total == 0 { + 0.0 + } else { + self.tier_a_count as f64 / self.total as f64 * 100.0 + } } pub fn tier_b_pct(&self) -> f64 { - if self.total == 0 { 0.0 } else { self.tier_b_count as f64 / self.total as f64 * 100.0 } + if self.total == 0 { + 0.0 + } else { + self.tier_b_count as f64 / self.total as f64 * 100.0 + } } pub fn tier_c_pct(&self) -> f64 { - if self.total == 0 { 0.0 } else { self.tier_c_count as f64 / self.total as f64 * 100.0 } + if self.total == 0 { + 0.0 + } else { + self.tier_c_count as f64 / self.total as f64 * 100.0 + } } } #[cfg(test)] mod tests { - use super::*; use super::super::walk_state::WalkMode; + use super::*; #[test] fn test_tier_routing() { diff --git a/crates/kv-cache-benchmark/src/graph_walk/mod.rs b/crates/kv-cache-benchmark/src/graph_walk/mod.rs index 9685aa06..957be0a2 100644 --- a/crates/kv-cache-benchmark/src/graph_walk/mod.rs +++ b/crates/kv-cache-benchmark/src/graph_walk/mod.rs @@ -1,7 +1,7 @@ +pub mod fallback; pub mod routing_table; -pub mod walk_state; pub mod template; -pub mod fallback; +pub mod walk_state; /// Residual Stream Graph Walk — projected architecture, memory-accounting only. /// @@ -43,7 +43,7 @@ impl GraphWalk { /// Default for Gemma 3-4B based on measured values. pub fn gemma_4b() -> Self { Self { - vindex_bytes: 1_500_000_000, // 1.5 GB Q4 vindex + vindex_bytes: 1_500_000_000, // 1.5 GB Q4 vindex routing_table_bytes: 360_448, // 352 KB routing table num_features: 348_000, num_layers: 34, @@ -51,7 +51,12 @@ impl GraphWalk { } /// Create with custom parameters. - pub fn new(vindex_bytes: usize, routing_table_bytes: usize, num_features: usize, num_layers: usize) -> Self { + pub fn new( + vindex_bytes: usize, + routing_table_bytes: usize, + num_features: usize, + num_layers: usize, + ) -> Self { Self { vindex_bytes, routing_table_bytes, diff --git a/crates/kv-cache-benchmark/src/graph_walk/routing_table.rs b/crates/kv-cache-benchmark/src/graph_walk/routing_table.rs index 750f42ce..039156f1 100644 --- a/crates/kv-cache-benchmark/src/graph_walk/routing_table.rs +++ b/crates/kv-cache-benchmark/src/graph_walk/routing_table.rs @@ -58,9 +58,7 @@ impl RoutingTable { let entry_bytes: usize = self .routes .iter() - .map(|(name, entries)| { - name.len() + entries.len() * 40 - }) + .map(|(name, entries)| name.len() + entries.len() * 40) .sum(); entry_bytes.max(360_448) // At least the measured 352 KB } diff --git a/crates/kv-cache-benchmark/src/graph_walk/template.rs b/crates/kv-cache-benchmark/src/graph_walk/template.rs index 9ad69ae1..bc2cf3a5 100644 --- a/crates/kv-cache-benchmark/src/graph_walk/template.rs +++ b/crates/kv-cache-benchmark/src/graph_walk/template.rs @@ -32,9 +32,9 @@ impl PatternWalk { template_id: "capital-of".to_string(), critical_layers: vec![13, 15, 24, 25, 26], feature_ranges: vec![ - (13, vec![8000..8500]), // Task classifier features - (15, vec![3000..3200]), // Confidence router - (24, vec![5000..6000]), // Factual retrieval + (13, vec![8000..8500]), // Task classifier features + (15, vec![3000..3200]), // Confidence router + (24, vec![5000..6000]), // Factual retrieval (25, vec![5000..6000]), (26, vec![5000..6000]), ], diff --git a/crates/kv-cache-benchmark/src/graph_walk/walk_state.rs b/crates/kv-cache-benchmark/src/graph_walk/walk_state.rs index 51a107b4..8627358f 100644 --- a/crates/kv-cache-benchmark/src/graph_walk/walk_state.rs +++ b/crates/kv-cache-benchmark/src/graph_walk/walk_state.rs @@ -97,8 +97,8 @@ impl WalkState { /// Estimated latency for this walk tier in microseconds. pub fn estimated_latency_us(&self) -> f64 { match self.tier { - WalkTier::CachedTemplate => 100.0, // <0.1ms - WalkTier::DynamicWalk => 3_000.0, // ~3ms + WalkTier::CachedTemplate => 100.0, // <0.1ms + WalkTier::DynamicWalk => 3_000.0, // ~3ms WalkTier::MarkovFallback => 200_000.0, // ~200ms } } @@ -112,7 +112,10 @@ fn extract_entity(text: &str) -> Option { let clean = word.trim_matches(|c: char| !c.is_alphanumeric()); if clean.len() > 1 && clean.chars().next().is_some_and(|c| c.is_uppercase()) - && !["The", "What", "Who", "Where", "How", "Is", "Was", "Tell", "A"].contains(&clean) + && ![ + "The", "What", "Who", "Where", "How", "Is", "Was", "Tell", "A", + ] + .contains(&clean) { return Some(clean.to_string()); } diff --git a/crates/kv-cache-benchmark/src/lib.rs b/crates/kv-cache-benchmark/src/lib.rs index 0d8fa60f..f4976acd 100644 --- a/crates/kv-cache-benchmark/src/lib.rs +++ b/crates/kv-cache-benchmark/src/lib.rs @@ -1,26 +1,29 @@ #![allow(clippy::empty_line_after_doc_comments)] #![allow(clippy::single_range_in_vec_init)] -pub mod model_config; +pub mod accuracy; +pub mod accuracy_suite; +pub mod benchmark; +pub mod graph_walk; +pub mod markov_residual; pub mod metrics; +pub mod model_config; +pub mod shader_bench; pub mod standard_kv; pub mod turboquant; -pub mod markov_residual; -pub mod graph_walk; -pub mod benchmark; -pub mod shader_bench; -pub mod accuracy; -pub mod accuracy_suite; #[cfg(feature = "real-model")] pub mod real_model; -#[cfg(feature = "real-model")] +// unlimited_context re-exports from larql_inference::engines — always available. pub mod unlimited_context; #[cfg(feature = "real-model")] pub mod apollo; +#[cfg(feature = "real-model")] +pub mod vindex_compare; + use metrics::Metrics; use model_config::ModelConfig; @@ -45,7 +48,12 @@ pub trait KvStrategy { fn encode(&self, keys: &[Vec], values: &[Vec]) -> Vec; /// Decode encoded bytes back to KV vectors. - fn decode(&self, encoded: &[u8], num_vectors: usize, dim: usize) -> (Vec>, Vec>); + fn decode( + &self, + encoded: &[u8], + num_vectors: usize, + dim: usize, + ) -> (Vec>, Vec>); /// Analytical memory for `seq_len` tokens (config-level, no data needed). fn memory_bytes(&self, config: &ModelConfig, seq_len: usize) -> usize; diff --git a/crates/kv-cache-benchmark/src/markov_residual/mod.rs b/crates/kv-cache-benchmark/src/markov_residual/mod.rs index 4cd9f1b4..731c5926 100644 --- a/crates/kv-cache-benchmark/src/markov_residual/mod.rs +++ b/crates/kv-cache-benchmark/src/markov_residual/mod.rs @@ -1,8 +1,8 @@ -pub mod window; pub mod checkpoint; pub mod cold_tier; +pub mod window; -use crate::{KvStrategy, model_config::ModelConfig}; +use crate::{model_config::ModelConfig, KvStrategy}; /// Strategy 3: Markov Residual Stream. /// @@ -89,7 +89,12 @@ impl KvStrategy for MarkovResidual { buf } - fn decode(&self, encoded: &[u8], num_vectors: usize, dim: usize) -> (Vec>, Vec>) { + fn decode( + &self, + encoded: &[u8], + num_vectors: usize, + dim: usize, + ) -> (Vec>, Vec>) { let total = u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize; let window = u32::from_le_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]) as usize; @@ -110,7 +115,12 @@ impl KvStrategy for MarkovResidual { let mut v = Vec::with_capacity(dim); for j in 0..dim { let o = offset + j * 4; - let x = f32::from_le_bytes([encoded[o], encoded[o + 1], encoded[o + 2], encoded[o + 3]]); + let x = f32::from_le_bytes([ + encoded[o], + encoded[o + 1], + encoded[o + 2], + encoded[o + 3], + ]); v.push(x); } keys.push(v.clone()); @@ -121,7 +131,9 @@ impl KvStrategy for MarkovResidual { } fn memory_bytes(&self, config: &ModelConfig, seq_len: usize) -> usize { - self.window_bytes(config) + self.checkpoint_bytes(config, seq_len) + self.cold_tier_bytes(seq_len) + self.window_bytes(config) + + self.checkpoint_bytes(config, seq_len) + + self.cold_tier_bytes(seq_len) } } @@ -143,7 +155,10 @@ mod tests { let _checkpoint_fixed = strategy.checkpoint_bytes(&config, 370_000); let cold_370k = strategy.cold_tier_bytes(370_000); - assert!(cold_370k < 2_000_000, "Cold tier (token IDs) should be < 2MB at 370K"); + assert!( + cold_370k < 2_000_000, + "Cold tier (token IDs) should be < 2MB at 370K" + ); // Total should be WAY less than standard KV let standard_mem = config.kv_memory(370_000); diff --git a/crates/kv-cache-benchmark/src/metrics.rs b/crates/kv-cache-benchmark/src/metrics.rs index a84aa794..3eb449ff 100644 --- a/crates/kv-cache-benchmark/src/metrics.rs +++ b/crates/kv-cache-benchmark/src/metrics.rs @@ -69,7 +69,11 @@ impl Metrics { let mut total = 0.0f64; for q in queries { assert_eq!(q.len(), original.len()); - let dot_orig: f64 = q.iter().zip(original).map(|(a, b)| *a as f64 * *b as f64).sum(); + let dot_orig: f64 = q + .iter() + .zip(original) + .map(|(a, b)| *a as f64 * *b as f64) + .sum(); let dot_recon: f64 = q .iter() .zip(reconstructed) diff --git a/crates/kv-cache-benchmark/src/real_model/decode_comparison.rs b/crates/kv-cache-benchmark/src/real_model/decode_comparison.rs index 2f71e76d..40602670 100644 --- a/crates/kv-cache-benchmark/src/real_model/decode_comparison.rs +++ b/crates/kv-cache-benchmark/src/real_model/decode_comparison.rs @@ -17,14 +17,15 @@ //! L1/L32 → parametric routing (static for in-context queries) //! L29/L30 → in-context comprehension (dynamic for in-context, static for parametric) -use ndarray::Array2; -use larql_inference::model::ModelWeights; +use larql_compute::MatMul; use larql_inference::attention::run_attention_block_decode_step; -use larql_inference::forward::{embed_tokens_pub, run_ffn, logits_to_predictions_pub}; use larql_inference::ffn::WeightFfn; +use larql_inference::forward::{embed_tokens_pub, logits_to_predictions_pub, run_ffn}; +use larql_inference::model::ModelWeights; +use ndarray::Array2; use super::kv_capture::capture_kv; -use super::markov_layer::{rs_prefill, rs_decode_step}; +use super::markov_layer::{rs_decode_step, rs_prefill}; /// Whether the answer is in the model's weights or planted in the prompt. #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)] @@ -83,20 +84,21 @@ pub fn run_decode_comparison( window_size: usize, decode_steps: usize, ) -> DecodeComparisonResult { - let prompt = tokenizer - .decode(token_ids, false) - .unwrap_or_default(); + let prompt = tokenizer.decode(token_ids, false).unwrap_or_default(); // --- Prefill ----------------------------------------------------------- // Both strategies share the same prefill. Divergence is decode-only. let kv = capture_kv(weights, token_ids); - let rs_result = rs_prefill(weights, token_ids, Some(window_size)); + let rs_result = rs_prefill( + weights, + token_ids, + Some(window_size), + &larql_compute::CpuBackend, + ); // Build per-layer mutable KV cache from captured tensors. - let mut kv_cache: Vec<(Array2, Array2)> = kv.keys - .into_iter() - .zip(kv.values) - .collect(); + let mut kv_cache: Vec<(Array2, Array2)> = + kv.keys.into_iter().zip(kv.values).collect(); // RS store starts with the bounded window from prefill. let mut rs_store = rs_result.store; @@ -104,7 +106,8 @@ pub fn run_decode_comparison( // Seed both decoders with the first predicted token (from the identical // prefill — this token is the same for both). let preds = logits_to_predictions_pub(weights, &kv.hidden, tokenizer, 1, 1.0); - let seed_token = preds.predictions + let seed_token = preds + .predictions .first() .map(|(t, _)| t.clone()) .unwrap_or_default(); @@ -123,17 +126,30 @@ pub fn run_decode_comparison( // --- Full-KV decode step --- let h_full = full_kv_step(weights, full_id, &mut kv_cache, next_pos, &ffn); let full_preds = logits_to_predictions_pub(weights, &h_full, tokenizer, 3, 1.0); - let next_full = full_preds.predictions.first().map(|(t, _)| t.clone()).unwrap_or_default(); - let next_full_prob = full_preds.predictions.first().map(|(_, p)| *p).unwrap_or(0.0); + let next_full = full_preds + .predictions + .first() + .map(|(t, _)| t.clone()) + .unwrap_or_default(); + let next_full_prob = full_preds + .predictions + .first() + .map(|(_, p)| *p) + .unwrap_or(0.0); // --- RS decode step --- - let (h_rs, new_store) = match rs_decode_step(weights, rs_id, rs_store) { - Some(r) => r, - None => break, - }; + let (h_rs, new_store) = + match rs_decode_step(weights, rs_id, rs_store, &larql_compute::CpuBackend) { + Some(r) => r, + None => break, + }; rs_store = new_store; let rs_preds = logits_to_predictions_pub(weights, &h_rs, tokenizer, 3, 1.0); - let next_rs = rs_preds.predictions.first().map(|(t, _)| t.clone()).unwrap_or_default(); + let next_rs = rs_preds + .predictions + .first() + .map(|(t, _)| t.clone()) + .unwrap_or_default(); let next_rs_prob = rs_preds.predictions.first().map(|(_, p)| *p).unwrap_or(0.0); let cosine = hidden_cosine(&h_full, &h_rs); @@ -182,9 +198,9 @@ fn full_kv_step( ) -> Array2 { let mut h = embed_tokens_pub(weights, &[token_id]); for (layer, kv_slot) in kv_cache.iter_mut().enumerate() { - let (h_post, new_kv) = run_attention_block_decode_step( - weights, &h, layer, Some(kv_slot), abs_position, - ).expect("full-KV decode step failed"); + let (h_post, new_kv) = + run_attention_block_decode_step(weights, &h, layer, Some(kv_slot), abs_position) + .expect("full-KV decode step failed"); *kv_slot = new_kv; let (h_out, _) = run_ffn(weights, &h_post, layer, ffn, false); h = h_out; @@ -196,10 +212,18 @@ fn full_kv_step( fn hidden_cosine(h1: &Array2, h2: &Array2) -> f64 { let v1 = h1.row(h1.shape()[0] - 1); let v2 = h2.row(h2.shape()[0] - 1); - let dot: f64 = v1.iter().zip(v2.iter()).map(|(&a, &b)| a as f64 * b as f64).sum(); + let dot: f64 = v1 + .iter() + .zip(v2.iter()) + .map(|(&a, &b)| a as f64 * b as f64) + .sum(); let n1: f64 = v1.iter().map(|&a| a as f64 * a as f64).sum::().sqrt(); let n2: f64 = v2.iter().map(|&a| a as f64 * a as f64).sum::().sqrt(); - if n1 * n2 < 1e-12 { 0.0 } else { dot / (n1 * n2) } + if n1 * n2 < 1e-12 { + 0.0 + } else { + dot / (n1 * n2) + } } /// Get the first token ID for a token string. @@ -268,7 +292,9 @@ pub fn format_window_sweep(results: &[DecodeComparisonResult]) -> String { r.window_size, format!("{:?}", r.query_type), r.match_rate * 100.0, - r.first_divergence.map(|d| d.to_string()).unwrap_or("-".to_string()), + r.first_divergence + .map(|d| d.to_string()) + .unwrap_or("-".to_string()), r.verdict(), )); } @@ -279,7 +305,14 @@ fn truncate(s: &str, max: usize) -> String { if s.chars().count() <= max { s.to_string() } else { - format!("{}…", &s[..s.char_indices().nth(max - 1).map(|(i, _)| i).unwrap_or(s.len())]) + format!( + "{}…", + &s[..s + .char_indices() + .nth(max - 1) + .map(|(i, _)| i) + .unwrap_or(s.len())] + ) } } @@ -302,11 +335,13 @@ pub fn in_context_prompts() -> Vec { // Medium gap — fact buried under filler "Remember: the answer is forty-two. \ The weather today is pleasant and calm. \ - The answer is".to_string(), + The answer is" + .to_string(), // Long gap — fact far from query "Note: the password is CRIMSON. \ It is a beautiful day outside. The sun is shining brightly. \ The birds are singing in the trees. \ - The password is".to_string(), + The password is" + .to_string(), ] } diff --git a/crates/kv-cache-benchmark/src/real_model/graph_walk_layer.rs b/crates/kv-cache-benchmark/src/real_model/graph_walk_layer.rs index bdbbb04c..dd3aaf94 100644 --- a/crates/kv-cache-benchmark/src/real_model/graph_walk_layer.rs +++ b/crates/kv-cache-benchmark/src/real_model/graph_walk_layer.rs @@ -8,10 +8,10 @@ //! B: dynamic graph walk (1-5ms) //! C: fallback to Markov RS (~200ms) -use larql_inference::model::ModelWeights; +use crate::graph_walk::walk_state::{WalkState, WalkTier}; use larql_inference::forward::embed_tokens_pub; +use larql_inference::model::ModelWeights; use larql_vindex::VectorIndex; -use crate::graph_walk::walk_state::{WalkState, WalkTier}; /// Result of graph walk prediction. pub struct GraphWalkResult { @@ -125,7 +125,12 @@ pub fn run_graph_walk_vindex_logits( // Use the existing predict_with_graph_vindex_logits pipeline let result = larql_inference::predict_with_graph_vindex_logits( - weights, tokenizer, token_ids, top_k, &walk_graph, index, + weights, + tokenizer, + token_ids, + top_k, + &walk_graph, + index, ); let latency_us = t0.elapsed().as_secs_f64() * 1e6; diff --git a/crates/kv-cache-benchmark/src/real_model/kv_capture.rs b/crates/kv-cache-benchmark/src/real_model/kv_capture.rs index dac1749b..1044c198 100644 --- a/crates/kv-cache-benchmark/src/real_model/kv_capture.rs +++ b/crates/kv-cache-benchmark/src/real_model/kv_capture.rs @@ -3,11 +3,11 @@ //! Runs `run_attention_with_kv()` per layer and collects the post-RoPE K and V //! tensors. These are the ground-truth vectors that TurboQuant compresses. -use ndarray::Array2; -use larql_inference::model::ModelWeights; use larql_inference::attention::run_attention_with_kv; -use larql_inference::forward::{embed_tokens_pub, run_ffn}; use larql_inference::ffn::WeightFfn; +use larql_inference::forward::{embed_tokens_pub, run_ffn}; +use larql_inference::model::ModelWeights; +use ndarray::Array2; /// Captured K/V tensors from a full forward pass. pub struct KvCapture { @@ -32,8 +32,8 @@ pub fn capture_kv(weights: &ModelWeights, token_ids: &[u32]) -> KvCapture { let mut values = Vec::with_capacity(num_layers); for layer in 0..num_layers { - let (h_post_attn, k_rope, v) = run_attention_with_kv(weights, &h, layer) - .expect("attention failed"); + let (h_post_attn, k_rope, v) = + run_attention_with_kv(weights, &h, layer).expect("attention failed"); keys.push(k_rope); values.push(v); diff --git a/crates/kv-cache-benchmark/src/real_model/markov_layer.rs b/crates/kv-cache-benchmark/src/real_model/markov_layer.rs index 77cac548..5c120c35 100644 --- a/crates/kv-cache-benchmark/src/real_model/markov_layer.rs +++ b/crates/kv-cache-benchmark/src/real_model/markov_layer.rs @@ -1,590 +1,10 @@ -//! Markov Residual Stream (RS) strategy on the real model. +//! Markov Residual Stream strategy — delegates to `larql_inference::engines::markov_residual`. //! -//! ## Core claim -//! -//! The pre-layer residual vector IS the complete Markov state of the -//! transformer at that position. Proven empirically on Gemma 3-4B: -//! transplanting full residuals from one forward pass into another -//! produces KL divergence = 0.0. No K/V cache is needed; K and V can be -//! recomputed from the stored residual at decode time at zero information -//! loss. -//! -//! ## Three-tier storage -//! -//! ```text -//! ┌─────────────────────────────────────────────────────────────────┐ -//! │ Cold tier │ Hot window │ New token │ -//! │ (evicted) │ (last W positions) │ (current decode) │ -//! │ residuals │ residuals │ embedded │ -//! └─────────────────────────────────────────────────────────────────┘ -//! ``` -//! -//! - **Hot window** (`stored`): the last `W` pre-layer residuals per layer, -//! shape `[W, hidden_dim]`. These are recomputed into K/V at every decode -//! step. W is small (e.g. 6–24 for the bounded-state experiment; 32 768 -//! for production RS+CA). -//! -//! - **Cold tier** (`cold_residuals`): residuals evicted from the hot window -//! during prefill are *kept* rather than discarded. At decode time these -//! are prepended to the hot window so the full attention prefix is -//! visible, matching full-KV output exactly (cos h = 1.000000). -//! -//! This is the Rust port of the Python `extend()` / `replay_window()` -//! mechanism in `rs_generator.py` / `unlimited_engine.py`. -//! -//! - **New token** (`h_new`): the freshly embedded token being decoded. -//! Its pre-layer residual is appended to the hot window after each step. -//! -//! ## Memory accounting (Gemma 3-4B: hidden=2560, num_kv=4, head_dim=256) -//! -//! ```text -//! Storage kind Bytes / position / layer -//! ───────────────────────────────────────────── -//! Hot-window residual 10,240 (f32, hidden_dim × 4) -//! Cold-tier residual 10,240 (same — full residual saved) -//! Standard KV (fp16) 4,096 (K + V × num_kv × head_dim × 2 bytes) -//! ``` -//! -//! For bounded-window decode experiments the cold tier stores the full -//! prefill history, so total memory equals standard KV × 2.5. The -//! production boundary-residual approach (store one summary residual per -//! window boundary + token IDs for replay) reduces cold storage to -//! ≈ 4 bytes/token — the v12 "56 GB → 2.1 MB" insight — but that -//! optimisation is orthogonal to the Markov correctness claim tested here. -//! -//! ## Decode step -//! -//! ```text -//! For each layer: -//! 1. full_h = concat([cold_residuals[l], hot_window[l]]) // [C+W, hidden] -//! 2. (K, V) = recompute_kv(full_h, abs_start=cold_abs_start) -//! (layernorm → K/V proj → QK-norm → RoPE at original positions) -//! 3. h_new = GQA(Q_new, K, V) // single-token query against full history -//! 4. h_new = FFN(h_new) -//! 5. Append h_new residual to hot window; clip overflow to cold tier. -//! ``` +//! This module is a thin re-export / compat shim so the benchmark runner +//! continues to work while the implementation lives in larql-inference. -use ndarray::{Array2, s}; -use larql_inference::model::ModelWeights; -use larql_inference::forward::{embed_tokens_pub, run_ffn, apply_norm, dot_proj, add_bias}; -use larql_inference::attention::{ - run_attention_with_kv, run_attention_block_decode_step, - apply_rope_partial_at, +pub use larql_inference::engines::accuracy::compare_hidden as compare_hidden_states; +pub use larql_inference::engines::markov_residual::{ + kv_memory_bytes_for_seq, recompute_kv, rs_decode_step, rs_prefill, MarkovResidualEngine, + RsPrefillResult, RsStore, }; -use larql_inference::residual::{rms_norm_heads, rms_norm_heads_no_weight}; -use larql_inference::ffn::WeightFfn; - -/// Per-layer pre-attention residuals for all stored positions. -/// `stored[i]` shape: `[S, hidden_dim]` — the residual entering layer `i` -/// for positions `[next_position - S, next_position)`. -/// -/// Cold-tier: when the hot window is smaller than the full sequence, -/// the evicted rows are saved in `cold_residuals` (one per layer). At -/// decode time both tiers are concatenated so attention covers the full -/// history — same as the Python `extend()` replay mechanism. -pub struct RsStore { - pub stored: Vec>, - /// Evicted (cold-tier) residuals: `cold_residuals[i]` holds rows that - /// were clipped from `stored[i]`. `None` when no eviction has occurred. - pub cold_residuals: Option>>, - /// Absolute position of the first token in the cold tier (0 if no cold tier). - pub cold_abs_start: usize, - /// Absolute token position of the NEXT token to be appended. - pub next_position: usize, - /// Optional sliding window: if `Some(W)`, only the last W residuals - /// are kept per layer; older ones are moved to the cold tier. - pub max_window: Option, -} - -impl RsStore { - /// Memory used by the stored residuals in bytes (f32). - pub fn memory_bytes(&self) -> usize { - let hot: usize = self.stored.iter().map(|s| s.len() * 4).sum(); - let cold: usize = self.cold_residuals.as_ref() - .map(|c| c.iter().map(|s| s.len() * 4).sum()) - .unwrap_or(0); - hot + cold - } - - /// Evict old positions beyond the window, saving them in the cold tier. - pub(crate) fn clip_layer(&mut self, layer: usize, cold: &mut Vec>) { - let window = match self.max_window { - Some(w) => w, - None => return, - }; - let s = &self.stored[layer]; - let rows = s.shape()[0]; - if rows <= window { - cold.push(Array2::zeros((0, s.shape()[1]))); - return; - } - let start = rows - window; - cold.push(s.slice(s![..start, ..]).to_owned()); - self.stored[layer] = s.slice(s![start.., ..]).to_owned(); - } -} - -/// Result of an RS prefill or decode step. -pub struct RsMarkovResult { - /// Final hidden state (last token position) after the forward pass. - pub hidden: Array2, - /// Residual store — holds pre-layer residuals for the active window. - pub store: RsStore, - /// Total memory used by the RS store in bytes. - pub memory_bytes: usize, - /// Active window token count (how many positions are stored). - pub window_tokens: usize, - /// Wall clock for the forward pass in microseconds. - pub forward_us: f64, -} - -/// Run the full prefill forward pass, storing pre-layer residuals. -/// -/// Equivalent to `capture_kv` but stores residuals instead of K/V. -/// The hidden state is identical — this is the same forward pass. -pub fn rs_prefill( - weights: &ModelWeights, - token_ids: &[u32], - max_window: Option, -) -> RsMarkovResult { - let num_layers = weights.num_layers; - let seq_len = token_ids.len(); - let ffn = WeightFfn { weights }; - - let t0 = std::time::Instant::now(); - - let mut h = embed_tokens_pub(weights, token_ids); - let mut stored: Vec> = Vec::with_capacity(num_layers); - - for layer in 0..num_layers { - // Store the pre-layer residual — this is the Markov state for this layer. - stored.push(h.clone()); - - let (h_post_attn, _k, _v) = run_attention_with_kv(weights, &h, layer) - .expect("attention failed"); - let (h_out, _) = run_ffn(weights, &h_post_attn, layer, &ffn, false); - h = h_out; - } - - let forward_us = t0.elapsed().as_secs_f64() * 1e6; - - let mut rs = RsStore { - stored, - cold_residuals: None, - cold_abs_start: 0, - next_position: seq_len, - max_window, - }; - - // Apply window clipping to all layers, saving evicted rows as cold tier. - let mut cold: Vec> = Vec::with_capacity(num_layers); - for layer in 0..num_layers { - rs.clip_layer(layer, &mut cold); - } - - // How many cold rows were saved (use layer 0 as reference). - let cold_rows = cold.first().map_or(0, |c| c.shape()[0]); - if cold_rows > 0 { - rs.cold_residuals = Some(cold); - // cold tier starts at position 0 (beginning of the prefill). - rs.cold_abs_start = 0; - } - - let window_tokens = rs.stored.first().map_or(0, |s| s.shape()[0]); - let memory_bytes = rs.memory_bytes(); - - RsMarkovResult { - hidden: last_row(&h), - store: rs, - memory_bytes, - window_tokens, - forward_us, - } -} - -/// Run one decode step for a new token using the RS store. -/// -/// For each layer: -/// 1. Recompute K/V from stored residuals (norm → proj → k-norm → RoPE at -/// original positions). -/// 2. Run single-token decode attention against [K_old | K_new]. -/// 3. Run FFN on the new token. -/// 4. Append the pre-layer residual of the new token to the store. -/// -/// Returns the updated hidden state (1 × hidden_dim) and updated store. -pub fn rs_decode_step( - weights: &ModelWeights, - new_token_id: u32, - rs: RsStore, -) -> Option<(Array2, RsStore)> { - let num_layers = weights.num_layers; - let ffn = WeightFfn { weights }; - let abs_position = rs.next_position; - - let mut h_new = embed_tokens_pub(weights, &[new_token_id]); - let mut new_stored: Vec> = Vec::with_capacity(num_layers); - - for layer in 0..num_layers { - let h_hot = &rs.stored[layer]; // [S_hot, hidden_dim] - let s_hot = h_hot.shape()[0]; - - // Concatenate cold tier + hot tier for full-history attention. - let (h_full, full_abs_start) = if let Some(cold) = &rs.cold_residuals { - let h_cold = &cold[layer]; - let s_cold = h_cold.shape()[0]; - if s_cold > 0 { - let hidden = h_hot.shape()[1]; - let mut combined = Array2::::zeros((s_cold + s_hot, hidden)); - combined.slice_mut(s![..s_cold, ..]).assign(h_cold); - combined.slice_mut(s![s_cold.., ..]).assign(h_hot); - (combined, rs.cold_abs_start) - } else { - (h_hot.clone(), abs_position.saturating_sub(s_hot)) - } - } else { - (h_hot.clone(), abs_position.saturating_sub(s_hot)) - }; - - // Recompute K/V from full history (cold + hot). - let (k_recomputed, v_recomputed) = - recompute_kv(weights, &h_full, layer, full_abs_start)?; - - // Save pre-layer residual for the new token before processing. - new_stored.push(h_new.clone()); - - // Decode-step attention: new token Q against [K_old | K_new]. - let (h_post_attn, _new_kv) = run_attention_block_decode_step( - weights, &h_new, layer, Some(&(k_recomputed, v_recomputed)), abs_position, - )?; - - let (h_out, _) = run_ffn(weights, &h_post_attn, layer, &ffn, false); - h_new = h_out; - } - - // Merge old hot residuals with new token's pre-layer residual. - let mut updated_stored: Vec> = Vec::with_capacity(num_layers); - for (stored, new_row) in rs.stored.iter().zip(new_stored.iter()) { - let s_old = stored.shape()[0]; - let hidden_dim = stored.shape()[1]; - let mut combined = Array2::::zeros((s_old + 1, hidden_dim)); - combined.slice_mut(s![..s_old, ..]).assign(stored); - combined.slice_mut(s![s_old.., ..]).assign(new_row); - updated_stored.push(combined); - } - - // Preserve cold tier; carry cold_abs_start forward. - let cold_residuals = rs.cold_residuals; - let cold_abs_start = rs.cold_abs_start; - let max_window = rs.max_window; - - let mut updated_rs = RsStore { - stored: updated_stored, - cold_residuals, - cold_abs_start, - next_position: abs_position + 1, - max_window, - }; - - // Clip hot tier; any newly evicted rows accumulate into the cold tier. - let mut overflow: Vec> = Vec::with_capacity(num_layers); - for layer in 0..num_layers { - updated_rs.clip_layer(layer, &mut overflow); - } - // Merge overflow into existing cold tier (append at the end of each layer). - let overflow_rows = overflow.first().map_or(0, |c| c.shape()[0]); - if overflow_rows > 0 { - match updated_rs.cold_residuals.as_mut() { - Some(cold) => { - for layer in 0..num_layers { - let hidden = cold[layer].shape()[1]; - let c_old = cold[layer].shape()[0]; - let c_new = overflow[layer].shape()[0]; - let mut merged = Array2::::zeros((c_old + c_new, hidden)); - merged.slice_mut(s![..c_old, ..]).assign(&cold[layer]); - merged.slice_mut(s![c_old.., ..]).assign(&overflow[layer]); - cold[layer] = merged; - } - } - None => { - updated_rs.cold_residuals = Some(overflow); - } - } - } - - Some((last_row(&h_new), updated_rs)) -} - -/// Recompute K/V from stored pre-layer residuals. -/// -/// Mirrors the Python `_raw_step` K/V recomputation: -/// x_old = layernorm(h_old) -/// k_old = k_proj(x_old) → k_norm → RoPE at positions abs_start.. -/// v_old = v_proj(x_old) → v_norm -pub(crate) fn recompute_kv( - weights: &ModelWeights, - h_stored: &Array2, // [S, hidden_dim] - layer: usize, - abs_start: usize, -) -> Option<(Array2, Array2)> { - let arch = &*weights.arch; - let head_dim = arch.head_dim_for_layer(layer); - let num_kv = arch.num_kv_heads_for_layer(layer); - let norm_offset = arch.norm_weight_offset(); - let qk_offset = arch.qk_norm_weight_offset(); - let qk_norm_off = if qk_offset != 0.0 { qk_offset } else { norm_offset }; - - let h_norm = apply_norm(weights, h_stored, &arch.input_layernorm_key(layer), norm_offset); - - let w_k = weights.tensors.get(&arch.attn_k_key(layer))?; - let v_from_k = !weights.tensors.contains_key(&arch.attn_v_key(layer)); - let w_v = if v_from_k { w_k } else { weights.tensors.get(&arch.attn_v_key(layer))? }; - - let mut k = dot_proj(&h_norm, w_k); - let mut v = dot_proj(&h_norm, w_v); - - if let Some(bias) = arch.attn_k_bias_key(layer).and_then(|k| weights.vectors.get(&k)) { - add_bias(&mut k, bias); - } - if let Some(bias) = arch.attn_v_bias_key(layer).and_then(|k| weights.vectors.get(&k)) { - add_bias(&mut v, bias); - } - - if arch.has_v_norm() { - v = rms_norm_heads_no_weight(&v, num_kv, head_dim); - } - let k_normed = match arch.attn_k_norm_key(layer).and_then(|k| weights.vectors.get(&k)) { - Some(norm_w) => rms_norm_heads(&k, norm_w, num_kv, head_dim, qk_norm_off), - None => k, - }; - - let layer_rope_base = arch.rope_base_for_layer(layer); - let rotary_frac = arch.rotary_fraction_for_layer(layer); - // Apply RoPE at the original absolute positions of the stored tokens. - let k_rope = apply_rope_partial_at( - &k_normed, num_kv, head_dim, layer_rope_base, rotary_frac, abs_start, - ); - - Some((k_rope, v)) -} - -/// Memory used by a standard KV cache (FP16) for comparison. -pub fn kv_memory_bytes_for_seq(weights: &ModelWeights, seq_len: usize) -> usize { - let arch = &*weights.arch; - let mut total = 0; - for layer in 0..weights.num_layers { - let num_kv = arch.num_kv_heads_for_layer(layer); - let head_dim = arch.head_dim_for_layer(layer); - let kv_dim = num_kv * head_dim; - // K + V, FP16 (2 bytes each) - total += seq_len * kv_dim * 2 * 2; - } - total -} - -/// Compare two hidden states (last-row cosine and MSE). -pub fn compare_hidden_states(h1: &Array2, h2: &Array2) -> (f64, f64) { - let v1: Vec = h1.row(h1.shape()[0] - 1).to_vec(); - let v2: Vec = h2.row(h2.shape()[0] - 1).to_vec(); - let mse = crate::metrics::Metrics::compute_mse(&v1, &v2); - let cosine = crate::metrics::Metrics::compute_cosine(&v1, &v2); - (mse, cosine) -} - -fn last_row(h: &Array2) -> Array2 { - let last = h.shape()[0] - 1; - h.slice(s![last..=last, ..]).to_owned() -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_rs(num_layers: usize, seq_len: usize, hidden: usize, window: Option) -> RsStore { - let stored = (0..num_layers) - .enumerate() - .map(|(l, _)| { - // Each layer gets distinct row values so splits are verifiable. - let mut a = Array2::::zeros((seq_len, hidden)); - for i in 0..seq_len { - a.row_mut(i).fill((l * 1000 + i) as f32); - } - a - }) - .collect(); - RsStore { - stored, - cold_residuals: None, - cold_abs_start: 0, - next_position: seq_len, - max_window: window, - } - } - - // ── clip_layer ─────────────────────────────────────────────────────────── - - #[test] - fn clip_no_window_keeps_all() { - let mut rs = make_rs(1, 10, 4, None); - let mut cold = Vec::new(); - rs.clip_layer(0, &mut cold); - assert_eq!(rs.stored[0].shape()[0], 10); - assert!(cold.is_empty(), "no cold entry pushed when max_window is None"); - } - - #[test] - fn clip_exact_window_keeps_all() { - let mut rs = make_rs(1, 5, 4, Some(5)); - let mut cold = Vec::new(); - rs.clip_layer(0, &mut cold); - assert_eq!(rs.stored[0].shape()[0], 5); - assert_eq!(cold[0].shape()[0], 0, "no cold rows when seq_len == window"); - } - - #[test] - fn clip_splits_hot_cold_correctly() { - // 10 rows, window=4 → cold gets rows 0..6, hot keeps rows 6..10. - let mut rs = make_rs(1, 10, 4, Some(4)); - let mut cold = Vec::new(); - rs.clip_layer(0, &mut cold); - - assert_eq!(cold[0].shape()[0], 6, "6 rows evicted to cold"); - assert_eq!(rs.stored[0].shape()[0], 4, "4 rows remain in hot window"); - - // Cold contains the OLDEST rows (indices 0..6). - for i in 0..6 { - assert_eq!(cold[0][[i, 0]], i as f32, "cold row {i} has correct value"); - } - // Hot contains the NEWEST rows (indices 6..10). - for i in 0..4 { - assert_eq!(rs.stored[0][[i, 0]], (6 + i) as f32, "hot row {i} has correct value"); - } - } - - #[test] - fn clip_multi_layer_consistent() { - // Each layer has different values but the same split should apply. - let mut rs = make_rs(3, 8, 4, Some(3)); - let mut cold = Vec::new(); - for layer in 0..3 { - rs.clip_layer(layer, &mut cold); - } - for (l, (c, s)) in cold.iter().zip(rs.stored.iter()).enumerate() { - assert_eq!(c.shape()[0], 5, "layer {l}: 5 cold rows"); - assert_eq!(s.shape()[0], 3, "layer {l}: 3 hot rows"); - } - } - - // ── RsStore cold-tier field wiring (simulating rs_prefill clip) ────────── - - #[test] - fn prefill_clip_wires_cold_residuals() { - let num_layers = 2; - let seq_len = 10; - let window = 4; - let hidden = 8; - - let mut rs = make_rs(num_layers, seq_len, hidden, Some(window)); - let mut cold: Vec> = Vec::with_capacity(num_layers); - for layer in 0..num_layers { - rs.clip_layer(layer, &mut cold); - } - let cold_rows = cold.first().map_or(0, |c| c.shape()[0]); - assert_eq!(cold_rows, seq_len - window); - - rs.cold_residuals = Some(cold); - rs.cold_abs_start = 0; - - assert_eq!(rs.stored[0].shape()[0], window, "hot window trimmed to {window}"); - let cold_ref = rs.cold_residuals.as_ref().unwrap(); - assert_eq!(cold_ref[0].shape()[0], seq_len - window, "cold tier has evicted rows"); - assert_eq!(rs.cold_abs_start, 0); - } - - #[test] - fn no_cold_when_seq_within_window() { - let mut rs = make_rs(2, 3, 4, Some(6)); - let mut cold: Vec> = Vec::new(); - for layer in 0..2 { - rs.clip_layer(layer, &mut cold); - } - let cold_rows = cold.first().map_or(0, |c| c.shape()[0]); - assert_eq!(cold_rows, 0, "no cold tier when seq_len ≤ window"); - } - - // ── memory_bytes includes both tiers ───────────────────────────────────── - - #[test] - fn memory_bytes_hot_only() { - let rs = make_rs(2, 4, 8, None); - // 2 layers × 4 rows × 8 hidden × 4 bytes = 256 - assert_eq!(rs.memory_bytes(), 2 * 4 * 8 * 4); - } - - #[test] - fn memory_bytes_includes_cold_tier() { - let num_layers = 2; - let seq_len = 10; - let window = 4; - let hidden = 8; - let mut rs = make_rs(num_layers, seq_len, hidden, Some(window)); - let mut cold: Vec> = Vec::with_capacity(num_layers); - for layer in 0..num_layers { - rs.clip_layer(layer, &mut cold); - } - rs.cold_residuals = Some(cold); - - let hot_bytes = num_layers * window * hidden * 4; - let cold_bytes = num_layers * (seq_len - window) * hidden * 4; - assert_eq!(rs.memory_bytes(), hot_bytes + cold_bytes); - } - - // ── cold-tier carry-forward in decode step ──────────────────────────────── - - #[test] - fn decode_step_overflow_merges_into_cold() { - // Simulate the overflow merge: hot at capacity + 1 new row → 1 row - // spills to cold, cold grows by 1. - let window = 3; - let hidden = 4; - - // Start: hot = [window rows], cold = [2 rows] already - let hot: Vec> = vec![Array2::ones((window, hidden))]; - let existing_cold: Vec> = vec![Array2::zeros((2, hidden))]; - - let mut rs = RsStore { - stored: hot.clone(), - cold_residuals: Some(existing_cold), - cold_abs_start: 0, - next_position: 2 + window, // cold=2, hot=3 - max_window: Some(window), - }; - - // Append one new row — hot grows to window+1, then clip evicts 1 row to overflow. - let new_row = Array2::::from_elem((1, hidden), 9.0); - let s_old = rs.stored[0].shape()[0]; - let mut combined = Array2::::zeros((s_old + 1, hidden)); - combined.slice_mut(s![..s_old, ..]).assign(&rs.stored[0]); - combined.slice_mut(s![s_old.., ..]).assign(&new_row); - rs.stored[0] = combined; - - let mut overflow: Vec> = Vec::new(); - rs.clip_layer(0, &mut overflow); - - // overflow should have 1 row - assert_eq!(overflow[0].shape()[0], 1); - - // Merge into existing cold - if let Some(cold) = rs.cold_residuals.as_mut() { - let c_old = cold[0].shape()[0]; - let c_new = overflow[0].shape()[0]; - let mut merged = Array2::::zeros((c_old + c_new, hidden)); - merged.slice_mut(s![..c_old, ..]).assign(&cold[0]); - merged.slice_mut(s![c_old.., ..]).assign(&overflow[0]); - cold[0] = merged; - } - - let cold_ref = rs.cold_residuals.as_ref().unwrap(); - assert_eq!(cold_ref[0].shape()[0], 3, "existing 2 + overflow 1 = 3 cold rows"); - assert_eq!(rs.stored[0].shape()[0], window, "hot stays at window size"); - } -} diff --git a/crates/kv-cache-benchmark/src/real_model/mod.rs b/crates/kv-cache-benchmark/src/real_model/mod.rs index 5cccfe67..409c5a42 100644 --- a/crates/kv-cache-benchmark/src/real_model/mod.rs +++ b/crates/kv-cache-benchmark/src/real_model/mod.rs @@ -8,11 +8,11 @@ //! - Markov RS: runs bounded-window forward pass, stores residuals + cold tier token IDs //! - Graph Walk: vindex walk through FFN graph, no forward pass for factual queries -pub mod runner; +pub mod decode_comparison; +pub mod graph_walk_layer; pub mod kv_capture; -pub mod turboquant_layer; pub mod markov_layer; -pub mod graph_walk_layer; -pub mod decode_comparison; +pub mod runner; +pub mod turboquant_layer; -pub use runner::{RealModelBenchmark, RealModelResult, run_all_strategies}; +pub use runner::{run_all_strategies, RealModelBenchmark, RealModelResult}; diff --git a/crates/kv-cache-benchmark/src/real_model/runner.rs b/crates/kv-cache-benchmark/src/real_model/runner.rs index 04480368..387c9bd9 100644 --- a/crates/kv-cache-benchmark/src/real_model/runner.rs +++ b/crates/kv-cache-benchmark/src/real_model/runner.rs @@ -13,18 +13,20 @@ //! decode time. //! 4. Graph Walk — vindex FFN walk; no forward pass for factual queries. +use larql_compute::ComputeBackend; +use larql_inference::engines::accuracy::compare_hidden; +use larql_inference::engines::markov_residual::kv_memory_bytes_for_seq; +use larql_inference::engines::{EngineKind, KvEngine}; +use larql_inference::forward::{hidden_to_raw_logits, logits_to_predictions_pub}; use larql_inference::model::ModelWeights; -use larql_inference::forward::logits_to_predictions_pub; use larql_vindex::VectorIndex; -use larql_compute::ComputeBackend; +use super::graph_walk_layer; use super::kv_capture; -use super::turboquant_layer; use super::markov_layer; -use super::graph_walk_layer; +use super::turboquant_layer; use crate::turboquant::TurboQuant; - /// Result from running one strategy on a real model. #[derive(Debug, Clone, serde::Serialize)] pub struct RealModelResult { @@ -39,6 +41,34 @@ pub struct RealModelResult { pub top1_match: bool, /// Cosine similarity of hidden state vs baseline (where applicable) pub hidden_cosine: Option, + /// Hot-window bytes (for engines that expose it). + pub hot_bytes: Option, + /// Cold-tier bytes. + pub cold_bytes: Option, + /// Compression ratio vs Standard KV (FP16). + pub compression_ratio: Option, +} + +/// Timing + accuracy result from a single `KvEngine` run. +#[derive(Debug, Clone, serde::Serialize)] +pub struct EngineTimingResult { + pub engine: String, + pub prompt: String, + pub top1_token: String, + pub top1_match: bool, + pub hidden_cosine: f64, + pub prefill_ms: f64, + pub hot_bytes: usize, + pub cold_bytes: usize, + pub total_bytes: usize, + pub kv_ref_bytes: usize, + pub compression_ratio: f64, +} + +impl EngineTimingResult { + pub fn compression_label(&self) -> String { + format!("{:.0}×", self.compression_ratio) + } } /// Full benchmark: run all four strategies on the same prompt. @@ -56,7 +86,12 @@ impl<'a> RealModelBenchmark<'a> { index: &'a VectorIndex, backend: &'a dyn ComputeBackend, ) -> Self { - Self { weights, tokenizer, index, backend } + Self { + weights, + tokenizer, + index, + backend, + } } } @@ -67,7 +102,10 @@ pub fn run_all_strategies( top_k: usize, window_size: usize, ) -> Vec { - let encoding = bench.tokenizer.encode(prompt, true).expect("tokenize failed"); + let encoding = bench + .tokenizer + .encode(prompt, true) + .expect("tokenize failed"); let token_ids: Vec = encoding.get_ids().to_vec(); let mut results = Vec::with_capacity(4); @@ -75,26 +113,35 @@ pub fn run_all_strategies( // === Strategy 1: Standard KV (baseline) === let t0 = std::time::Instant::now(); let kv = kv_capture::capture_kv(bench.weights, &token_ids); - let baseline_preds = logits_to_predictions_pub( - bench.weights, &kv.hidden, bench.tokenizer, top_k, 1.0, - ); + let baseline_preds = + logits_to_predictions_pub(bench.weights, &kv.hidden, bench.tokenizer, top_k, 1.0); let std_us = t0.elapsed().as_secs_f64() * 1e6; let std_mem = kv_capture::kv_memory_bytes(&kv); - let baseline_top1 = baseline_preds.predictions.first() + let baseline_top1 = baseline_preds + .predictions + .first() .map(|(t, _)| t.clone()) .unwrap_or_default(); + let kv_ref_bytes = kv_memory_bytes_for_seq(bench.weights, token_ids.len()); results.push(RealModelResult { strategy: "Standard KV (FP16)".to_string(), prompt: prompt.to_string(), top1_token: baseline_top1.clone(), - top1_prob: baseline_preds.predictions.first().map(|(_, p)| *p).unwrap_or(0.0), + top1_prob: baseline_preds + .predictions + .first() + .map(|(_, p)| *p) + .unwrap_or(0.0), top5: baseline_preds.predictions.clone(), memory_bytes: std_mem, wall_clock_us: std_us, - top1_match: true, // baseline matches itself + top1_match: true, hidden_cosine: Some(1.0), + hot_bytes: Some(std_mem), + cold_bytes: Some(0), + compression_ratio: Some(1.0), }); // === Strategy 2: TurboQuant 4-bit === @@ -102,84 +149,91 @@ pub fn run_all_strategies( let tq = TurboQuant::new(4); let tq_result = turboquant_layer::apply_turboquant(&kv, &tq); let tq_us = t0.elapsed().as_secs_f64() * 1e6; - - // TurboQuant doesn't change the forward pass output — it compresses the stored K/V. - // The accuracy impact shows up when dequantized K/V is used for attention. - // For the benchmark, we report compression stats. The hidden state is identical - // because TQ is applied post-forward-pass (cache compression, not compute change). + let tq_ratio = kv_ref_bytes as f64 / tq_result.compressed_bytes as f64; results.push(RealModelResult { - strategy: format!("TurboQuant 4-bit (MSE={:.6}, cos={:.4})", tq_result.mse, tq_result.cosine_sim), + strategy: format!("TurboQuant 4-bit (cos={:.4})", tq_result.cosine_sim), prompt: prompt.to_string(), - top1_token: baseline_top1.clone(), // Same forward pass - top1_prob: baseline_preds.predictions.first().map(|(_, p)| *p).unwrap_or(0.0), + top1_token: baseline_top1.clone(), + top1_prob: baseline_preds + .predictions + .first() + .map(|(_, p)| *p) + .unwrap_or(0.0), top5: baseline_preds.predictions.clone(), memory_bytes: tq_result.compressed_bytes, - wall_clock_us: std_us + tq_us, // Forward pass + quantize overhead - top1_match: true, // Same forward pass, TQ is storage compression - hidden_cosine: Some(1.0), // Hidden state unchanged + wall_clock_us: std_us + tq_us, + top1_match: true, + hidden_cosine: Some(1.0), + hot_bytes: Some(tq_result.compressed_bytes), + cold_bytes: Some(0), + compression_ratio: Some(tq_ratio), }); - // === Strategy 3: Markov Residual Stream === - // - // Stores pre-layer residuals instead of K/V. At decode time, K/V are - // recomputed from stored residuals — the residual IS the complete Markov - // state (proven: KL=0.0, cos h=1.000000 at all window sizes). + // === Strategy 3: Markov Residual Stream (via KvEngine trait) === // - // Three-tier storage (Rust port of Python rs_generator.py extend()): - // hot window — last W residuals per layer (recomputed into K/V each step) - // cold tier — evicted residuals from prefill (prepended at decode time - // so full history is visible; matches full-KV exactly) - // new token — current embed, appended after each decode step - // - // The memory_bytes reported here includes both hot + cold tier residuals. + // Uses `MarkovResidualEngine::prefill` via the unified `KvEngine` interface. + // Backend-dispatched: K/V projection matmuls route through the compute backend. let t0 = std::time::Instant::now(); - let rs_result = markov_layer::rs_prefill(bench.weights, &token_ids, Some(window_size)); - let rs_preds = logits_to_predictions_pub( - bench.weights, &rs_result.hidden, bench.tokenizer, top_k, 1.0, - ); + let mut rs_engine = EngineKind::MarkovResidual { + window_size: Some(window_size), + } + .build(larql_compute::cpu_backend()); + let rs_hidden = rs_engine + .prefill(bench.weights, &token_ids) + .expect("MarkovRS prefill failed"); + let rs_preds = + logits_to_predictions_pub(bench.weights, &rs_hidden, bench.tokenizer, top_k, 1.0); let rs_us = t0.elapsed().as_secs_f64() * 1e6; - let rs_top1 = rs_preds.predictions.first() + let rs_top1 = rs_preds + .predictions + .first() .map(|(t, _)| t.clone()) .unwrap_or_default(); + let rs_acc = compare_hidden(&kv.hidden, &rs_hidden); + let rs_cold = rs_engine.cold_bytes(); + let rs_hot = rs_engine.memory_bytes().saturating_sub(rs_cold); + let rs_ratio = if rs_engine.memory_bytes() > 0 { + kv_ref_bytes as f64 / rs_engine.memory_bytes() as f64 + } else { + 0.0 + }; - let (_rs_mse, rs_cosine) = markov_layer::compare_hidden_states( - &kv.hidden, &rs_result.hidden, - ); - - // Show both RS store memory and equivalent standard-KV memory for context. - let kv_equiv_bytes = markov_layer::kv_memory_bytes_for_seq(bench.weights, token_ids.len()); - let rs_window = rs_result.window_tokens; - let cold_bytes = rs_result.store.cold_residuals.as_ref() - .map(|c| c.iter().map(|s| s.len() * 4).sum::()) - .unwrap_or(0); - let hot_bytes = rs_result.memory_bytes - cold_bytes; results.push(RealModelResult { strategy: format!( - "Markov RS (hot={:.1}KB cold={:.1}KB KV={:.1}KB win={})", - hot_bytes as f64 / 1024.0, - cold_bytes as f64 / 1024.0, - kv_equiv_bytes as f64 / 1024.0, - rs_window, + "Markov RS W={} (hot={:.1}KB cold={:.1}KB {:.0}×)", + rs_engine.window_tokens(), + rs_hot as f64 / 1024.0, + rs_cold as f64 / 1024.0, + rs_ratio, ), prompt: prompt.to_string(), top1_token: rs_top1.clone(), top1_prob: rs_preds.predictions.first().map(|(_, p)| *p).unwrap_or(0.0), top5: rs_preds.predictions, - memory_bytes: rs_result.memory_bytes, + memory_bytes: rs_engine.memory_bytes(), wall_clock_us: rs_us, top1_match: rs_top1 == baseline_top1, - hidden_cosine: Some(rs_cosine), + hidden_cosine: Some(rs_acc.cosine), + hot_bytes: Some(rs_hot), + cold_bytes: Some(rs_cold), + compression_ratio: Some(rs_ratio), }); // === Strategy 4: Graph Walk === let t0 = std::time::Instant::now(); let gw = graph_walk_layer::run_graph_walk( - bench.weights, bench.tokenizer, bench.index, &token_ids, top_k, + bench.weights, + bench.tokenizer, + bench.index, + &token_ids, + top_k, ); let gw_us = t0.elapsed().as_secs_f64() * 1e6; - let gw_top1 = gw.predictions.first() + let gw_top1 = gw + .predictions + .first() .map(|(t, _)| t.clone()) .unwrap_or_default(); @@ -193,11 +247,133 @@ pub fn run_all_strategies( wall_clock_us: gw_us, top1_match: gw_top1 == baseline_top1, hidden_cosine: None, + hot_bytes: None, + cold_bytes: None, + compression_ratio: Some(kv_ref_bytes as f64 / gw.memory_bytes.max(1) as f64), }); results } +/// Benchmark all registered `KvEngine` implementations on a prompt. +/// +/// Times prefill only (single token generation is too noisy for a one-shot +/// call; for decode timing use `larql bench --engine`). Returns one result +/// per engine in insertion order. +pub fn run_all_engines_bench( + weights: &ModelWeights, + tokenizer: &tokenizers::Tokenizer, + prompt: &str, + window_size: usize, + backend: &dyn ComputeBackend, +) -> Vec { + let encoding = tokenizer.encode(prompt, true).expect("tokenize failed"); + let token_ids: Vec = encoding.get_ids().to_vec(); + + // Standard KV hidden state for cosine comparison. + let kv = kv_capture::capture_kv(weights, &token_ids); + let kv_ref_bytes = kv_memory_bytes_for_seq(weights, token_ids.len()); + + let engines: &[(&str, EngineKind)] = &[ + ( + "markov-rs", + EngineKind::MarkovResidual { + window_size: Some(window_size), + }, + ), + ( + "unlimited-context", + EngineKind::UnlimitedContext { window_size }, + ), + ]; + + let mut results = Vec::new(); + for (label, kind) in engines { + let mut engine = kind.clone().build(larql_compute::cpu_backend()); + + let t0 = std::time::Instant::now(); + let hidden = match engine.prefill(weights, &token_ids) { + Some(h) => h, + None => { + eprintln!("[engine bench] {label}: prefill returned None"); + continue; + } + }; + let prefill_ms = t0.elapsed().as_secs_f64() * 1000.0; + + let logits = hidden_to_raw_logits(weights, &hidden); + let top1_idx = logits + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i as u32) + .unwrap_or(0); + let top1_token = tokenizer.decode(&[top1_idx], true).unwrap_or_default(); + let top1_match = top1_token + == tokenizer + .decode( + &[logits + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i as u32) + .unwrap_or(0)], + true, + ) + .unwrap_or_default(); + + let acc = compare_hidden(&kv.hidden, &hidden); + let cold = engine.cold_bytes(); + let hot = engine.memory_bytes().saturating_sub(cold); + let total = engine.memory_bytes(); + let ratio = if total > 0 { + kv_ref_bytes as f64 / total as f64 + } else { + 0.0 + }; + let _ = backend; // engines build with cpu_backend(); backend param reserved for future + + results.push(EngineTimingResult { + engine: label.to_string(), + prompt: prompt.to_string(), + top1_token, + top1_match, + hidden_cosine: acc.cosine, + prefill_ms, + hot_bytes: hot, + cold_bytes: cold, + total_bytes: total, + kv_ref_bytes, + compression_ratio: ratio, + }); + } + results +} + +/// Format `run_all_engines_bench` output as an ASCII table. +pub fn format_engine_results(results: &[EngineTimingResult]) -> String { + let mut out = String::new(); + out.push_str(&format!( + "\n{:<22} {:>10} {:>10} {:>10} {:>8} {:>6} {}\n", + "Engine", "prefill_ms", "hot_MB", "cold_MB", "ratio×", "cos", "top1", + )); + out.push_str(&"-".repeat(90)); + out.push('\n'); + for r in results { + out.push_str(&format!( + "{:<22} {:>10.1} {:>10.1} {:>10.1} {:>8.0} {:>6.4} {}\n", + r.engine, + r.prefill_ms, + r.hot_bytes as f64 / 1_048_576.0, + r.cold_bytes as f64 / 1_048_576.0, + r.compression_ratio, + r.hidden_cosine, + r.top1_token, + )); + } + out +} + /// Run multiple prompts and aggregate results. pub fn run_prompt_suite( bench: &RealModelBenchmark, @@ -205,48 +381,55 @@ pub fn run_prompt_suite( top_k: usize, window_size: usize, ) -> Vec> { - prompts.iter().map(|p| run_all_strategies(bench, p, top_k, window_size)).collect() + prompts + .iter() + .map(|p| run_all_strategies(bench, p, top_k, window_size)) + .collect() } -/// Format results as a comparison table. +/// Format results as a comparison table including compression ratio. pub fn format_results(results: &[RealModelResult]) -> String { let mut out = String::new(); - out.push_str(&format!("\n=== Real Model Benchmark: \"{}\" ===\n\n", results[0].prompt)); + if let Some(r) = results.first() { + out.push_str(&format!( + "\n=== Real Model Benchmark: {:?} ===\n\n", + r.prompt + )); + } out.push_str(&format!( - "{:<40} {:>10} {:>12} {:>10} {:>8}\n", - "Strategy", "Top-1", "Memory", "Time (ms)", "Match?" + "{:<44} {:>8} {:>10} {:>8} {:>7} {}\n", + "Strategy", "Top-1", "Memory", "ms", "ratio×", "cos/match", )); - out.push_str(&"-".repeat(85)); + out.push_str(&"-".repeat(95)); out.push('\n'); for r in results { let mem_str = if r.memory_bytes >= 1_000_000 { - format!("{:.1} MB", r.memory_bytes as f64 / 1e6) + format!("{:.1}MB", r.memory_bytes as f64 / 1e6) } else if r.memory_bytes >= 1_000 { - format!("{:.1} KB", r.memory_bytes as f64 / 1e3) + format!("{:.1}KB", r.memory_bytes as f64 / 1e3) + } else { + format!("{}B", r.memory_bytes) + }; + let ratio_str = r + .compression_ratio + .map(|c| format!("{c:.0}×")) + .unwrap_or_else(|| "—".into()); + let accuracy_str = if let Some(cos) = r.hidden_cosine { + format!("{cos:.4}") } else { - format!("{} B", r.memory_bytes) + (if r.top1_match { "match" } else { "miss" }).into() }; - let match_str = if r.top1_match { "YES" } else { "no" }; out.push_str(&format!( - "{:<40} {:>10} {:>12} {:>10.1} {:>8}\n", + "{:<44} {:>8} {:>10} {:>8.1} {:>7} {}\n", r.strategy, r.top1_token, mem_str, r.wall_clock_us / 1000.0, - match_str, + ratio_str, + accuracy_str, )); } - - if let Some(r) = results.iter().find(|r| r.strategy.contains("Markov RS")) { - if let Some(cosine) = r.hidden_cosine { - out.push_str(&format!( - "\nMarkov RS: hidden cosine vs baseline = {cosine:.6} \ - (should be ~1.0 — same forward pass, different storage format)\n" - )); - } - } - out } diff --git a/crates/kv-cache-benchmark/src/real_model/turboquant_layer.rs b/crates/kv-cache-benchmark/src/real_model/turboquant_layer.rs index 020d1062..08586522 100644 --- a/crates/kv-cache-benchmark/src/real_model/turboquant_layer.rs +++ b/crates/kv-cache-benchmark/src/real_model/turboquant_layer.rs @@ -3,10 +3,10 @@ //! Intercepts K/V capture, quantizes each head vector via WHT + Lloyd-Max, //! then dequantizes on read. Measures MSE, cosine, and compression vs FP16. -use ndarray::Array2; -use crate::turboquant::TurboQuant; -use crate::metrics::Metrics; use super::kv_capture::KvCapture; +use crate::metrics::Metrics; +use crate::turboquant::TurboQuant; +use ndarray::Array2; /// Result of applying TurboQuant to captured K/V. pub struct TurboQuantResult { @@ -49,10 +49,8 @@ pub fn apply_turboquant(capture: &KvCapture, tq: &TurboQuant) -> TurboQuantResul let k = &capture.keys[layer]; let v = &capture.values[layer]; - let (dk, enc_bytes_k, enc_us_k, dec_us_k, mse_k, cos_k, count_k) = - quantize_tensor(k, tq); - let (dv, enc_bytes_v, enc_us_v, dec_us_v, mse_v, cos_v, count_v) = - quantize_tensor(v, tq); + let (dk, enc_bytes_k, enc_us_k, dec_us_k, mse_k, cos_k, count_k) = quantize_tensor(k, tq); + let (dv, enc_bytes_v, enc_us_v, dec_us_v, mse_v, cos_v, count_v) = quantize_tensor(v, tq); total_compressed += enc_bytes_k + enc_bytes_v; total_original += (k.len() + v.len()) * 2; // FP16 @@ -66,8 +64,16 @@ pub fn apply_turboquant(capture: &KvCapture, tq: &TurboQuant) -> TurboQuantResul decoded_values.push(dv); } - let avg_mse = if vector_count > 0 { total_mse / vector_count as f64 } else { 0.0 }; - let avg_cosine = if vector_count > 0 { total_cosine / vector_count as f64 } else { 0.0 }; + let avg_mse = if vector_count > 0 { + total_mse / vector_count as f64 + } else { + 0.0 + }; + let avg_cosine = if vector_count > 0 { + total_cosine / vector_count as f64 + } else { + 0.0 + }; let compression = if total_compressed > 0 { total_original as f64 / total_compressed as f64 } else { @@ -134,7 +140,15 @@ fn quantize_tensor( } } - (decoded, total_encoded_bytes, encode_us, decode_us, total_mse, total_cosine, count) + ( + decoded, + total_encoded_bytes, + encode_us, + decode_us, + total_mse, + total_cosine, + count, + ) } /// Find the largest power-of-2 that divides cols (for WHT compatibility). diff --git a/crates/kv-cache-benchmark/src/shader_bench.rs b/crates/kv-cache-benchmark/src/shader_bench.rs index c0c16b4d..a54f40fe 100644 --- a/crates/kv-cache-benchmark/src/shader_bench.rs +++ b/crates/kv-cache-benchmark/src/shader_bench.rs @@ -9,9 +9,9 @@ //! Gate KNN ✓ ✓ ✓ //! Sparse FFN walk ✓ ✓ n/a -use crate::turboquant::TurboQuant; -use crate::turboquant::rotation; use crate::metrics::Metrics; +use crate::turboquant::rotation; +use crate::turboquant::TurboQuant; /// Benchmark result for a single operation. #[derive(Debug, Clone, serde::Serialize)] @@ -26,7 +26,9 @@ pub struct ShaderBenchResult { /// Run CPU WHT benchmark at given dimension. pub fn bench_wht_cpu(dim: usize, iterations: usize) -> ShaderBenchResult { - let x: Vec = (0..dim).map(|i| (i as f32 - dim as f32 / 2.0) / 100.0).collect(); + let x: Vec = (0..dim) + .map(|i| (i as f32 - dim as f32 / 2.0) / 100.0) + .collect(); let t0 = std::time::Instant::now(); for _ in 0..iterations { diff --git a/crates/kv-cache-benchmark/src/standard_kv.rs b/crates/kv-cache-benchmark/src/standard_kv.rs index 74ace4a2..7d7b06b8 100644 --- a/crates/kv-cache-benchmark/src/standard_kv.rs +++ b/crates/kv-cache-benchmark/src/standard_kv.rs @@ -1,4 +1,4 @@ -use crate::{KvStrategy, model_config::ModelConfig}; +use crate::{model_config::ModelConfig, KvStrategy}; /// Strategy 1: Standard FP16 KV cache. /// @@ -25,7 +25,12 @@ impl KvStrategy for StandardKv { buf } - fn decode(&self, encoded: &[u8], num_vectors: usize, dim: usize) -> (Vec>, Vec>) { + fn decode( + &self, + encoded: &[u8], + num_vectors: usize, + dim: usize, + ) -> (Vec>, Vec>) { let floats_per_set = num_vectors * dim; let bytes_per_set = floats_per_set * 2; @@ -90,7 +95,11 @@ fn f16_decode(bytes: [u8; 2]) -> f32 { // Subnormal fp16 let mut f = frac as f32 / 1024.0; f *= 2.0f32.powi(-14); - if sign == 1 { -f } else { f } + if sign == 1 { + -f + } else { + f + } } else if exp == 0x1F { if frac == 0 { f32::from_bits((sign << 31) | (0xFF << 23)) @@ -130,7 +139,10 @@ mod tests { let decoded = f16_decode(encoded); let err = (v - decoded).abs(); // FP16 has ~3 decimal digits of precision - assert!(err < 0.01 * v.abs().max(0.001), "fp16 roundtrip failed for {v}: got {decoded}, err {err}"); + assert!( + err < 0.01 * v.abs().max(0.001), + "fp16 roundtrip failed for {v}: got {decoded}, err {err}" + ); } } diff --git a/crates/kv-cache-benchmark/src/turboquant/codebooks.rs b/crates/kv-cache-benchmark/src/turboquant/codebooks.rs index 1fc91ab2..94bd7f8f 100644 --- a/crates/kv-cache-benchmark/src/turboquant/codebooks.rs +++ b/crates/kv-cache-benchmark/src/turboquant/codebooks.rs @@ -5,7 +5,6 @@ /// /// These codebooks are the optimal scalar quantizers for this distribution. /// Values validated against llama.cpp Discussion #20969 reference implementation. - use super::lloyd_max::Codebook; /// Get the pre-computed codebook for a given dimension and bit-width. diff --git a/crates/kv-cache-benchmark/src/turboquant/lloyd_max.rs b/crates/kv-cache-benchmark/src/turboquant/lloyd_max.rs index 577b588c..4d4e4114 100644 --- a/crates/kv-cache-benchmark/src/turboquant/lloyd_max.rs +++ b/crates/kv-cache-benchmark/src/turboquant/lloyd_max.rs @@ -23,9 +23,7 @@ impl Codebook { /// Quantize a scalar to its nearest centroid index using binary search on boundaries. pub fn quantize_scalar(value: f32, codebook: &Codebook) -> u8 { // Binary search: find the first boundary > value - let idx = codebook - .boundaries - .partition_point(|&b| b <= value); + let idx = codebook.boundaries.partition_point(|&b| b <= value); idx as u8 } @@ -53,10 +51,7 @@ pub fn compute_codebook(samples: &[f32], n_levels: usize, max_iters: usize) -> C for _ in 0..max_iters { // Compute boundaries (midpoints between adjacent centroids) - let boundaries: Vec = centroids - .windows(2) - .map(|w| (w[0] + w[1]) / 2.0) - .collect(); + let boundaries: Vec = centroids.windows(2).map(|w| (w[0] + w[1]) / 2.0).collect(); // Assign samples to nearest centroid and compute new means let mut sums = vec![0.0f64; n_levels]; @@ -84,10 +79,7 @@ pub fn compute_codebook(samples: &[f32], n_levels: usize, max_iters: usize) -> C } } - let boundaries: Vec = centroids - .windows(2) - .map(|w| (w[0] + w[1]) / 2.0) - .collect(); + let boundaries: Vec = centroids.windows(2).map(|w| (w[0] + w[1]) / 2.0).collect(); Codebook { boundaries, diff --git a/crates/kv-cache-benchmark/src/turboquant/mod.rs b/crates/kv-cache-benchmark/src/turboquant/mod.rs index 52dc77ac..6d907c4c 100644 --- a/crates/kv-cache-benchmark/src/turboquant/mod.rs +++ b/crates/kv-cache-benchmark/src/turboquant/mod.rs @@ -1,84 +1,16 @@ -pub mod rotation; +//! TurboQuant — re-exported from `larql_inference::engines::turbo_quant`. +//! +//! Algorithm modules still live here for the benchmark's KvStrategy impl; +//! the KvEngine integration lives in larql-inference. + +pub mod codebooks; pub mod lloyd_max; pub mod packing; -pub mod codebooks; - -use crate::{KvStrategy, model_config::ModelConfig}; - -/// Strategy 2: TurboQuant (ICLR 2026). -/// -/// Algorithm 1 (MSE-only, no QJL): -/// 1. Normalize → unit norm, store scalar -/// 2. Walsh-Hadamard rotation (spreads coordinates to Beta distribution) -/// 3. Lloyd-Max scalar quantization (3 or 4 bits per coordinate) -/// 4. Bit-pack indices -/// 5. Decode: unpack → centroids → inverse WHT → rescale -pub struct TurboQuant { - pub bits: u8, // 3 or 4 -} - -impl TurboQuant { - pub fn new(bits: u8) -> Self { - assert!(bits == 3 || bits == 4, "TurboQuant supports 3 or 4 bits"); - Self { bits } - } - - /// Encode a single vector: normalize → WHT → quantize → pack. - pub fn encode_vector(&self, x: &[f32]) -> Vec { - let d = x.len(); - - // Step 1: compute norm and normalize - let norm = x.iter().map(|v| v * v).sum::().sqrt(); - let x_hat: Vec = if norm > 1e-12 { - x.iter().map(|v| v / norm).collect() - } else { - vec![0.0; d] - }; - - // Step 2: Walsh-Hadamard transform (in-place) - let y = rotation::wht(&x_hat); - - // Step 3: Lloyd-Max quantize each coordinate - let codebook = codebooks::get_codebook(d, self.bits); - let indices: Vec = y - .iter() - .map(|&val| lloyd_max::quantize_scalar(val, codebook)) - .collect(); - - // Step 4: pack norm (4 bytes f32) + bit-packed indices - let mut buf = Vec::new(); - buf.extend_from_slice(&norm.to_le_bytes()); - packing::pack_indices(&indices, self.bits, &mut buf); - buf - } - - /// Decode a single vector: unpack → centroids → inverse WHT → rescale. - pub fn decode_vector(&self, encoded: &[u8], dim: usize) -> Vec { - // Read norm - let norm = f32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]); - - // Unpack indices - let indices = packing::unpack_indices(&encoded[4..], dim, self.bits); - - // Centroid lookup - let codebook = codebooks::get_codebook(dim, self.bits); - let y: Vec = indices - .iter() - .map(|&idx| codebook.centroids[idx as usize]) - .collect(); - - // Inverse WHT (WHT is self-inverse up to scaling) - let x_hat = rotation::wht(&y); +pub mod rotation; - // Rescale - x_hat.iter().map(|&v| v * norm).collect() - } +pub use larql_inference::engines::turbo_quant::TurboQuant; - /// Bytes per encoded vector. - fn bytes_per_vector(&self, dim: usize) -> usize { - 4 + packing::packed_size(dim, self.bits) // norm + packed indices - } -} +use crate::{model_config::ModelConfig, KvStrategy}; impl KvStrategy for TurboQuant { fn name(&self) -> &str { @@ -92,17 +24,20 @@ impl KvStrategy for TurboQuant { fn encode(&self, keys: &[Vec], values: &[Vec]) -> Vec { let mut buf = Vec::new(); for v in keys.iter().chain(values.iter()) { - let enc = self.encode_vector(v); - buf.extend_from_slice(&enc); + buf.extend_from_slice(&self.encode_vector(v)); } buf } - fn decode(&self, encoded: &[u8], num_vectors: usize, dim: usize) -> (Vec>, Vec>) { + fn decode( + &self, + encoded: &[u8], + num_vectors: usize, + dim: usize, + ) -> (Vec>, Vec>) { let bytes_per = self.bytes_per_vector(dim); let mut keys = Vec::with_capacity(num_vectors); let mut values = Vec::with_capacity(num_vectors); - for i in 0..num_vectors { let offset = i * bytes_per; keys.push(self.decode_vector(&encoded[offset..offset + bytes_per], dim)); @@ -115,7 +50,7 @@ impl KvStrategy for TurboQuant { } fn memory_bytes(&self, config: &ModelConfig, seq_len: usize) -> usize { - let num_vectors = seq_len * config.layers * config.kv_heads * 2; // K+V + let num_vectors = seq_len * config.layers * config.kv_heads * 2; num_vectors * self.bytes_per_vector(config.kv_dim()) } } diff --git a/crates/kv-cache-benchmark/src/turboquant/rotation.rs b/crates/kv-cache-benchmark/src/turboquant/rotation.rs index d910ce33..cd9f0d03 100644 --- a/crates/kv-cache-benchmark/src/turboquant/rotation.rs +++ b/crates/kv-cache-benchmark/src/turboquant/rotation.rs @@ -24,7 +24,10 @@ fn apply_sign_flips(y: &mut [f32]) { /// Self-inverse because (DHD)^2 = DH(DD)HD = DH·I·HD = D(HH)D = D·I·D = I pub fn wht(x: &[f32]) -> Vec { let d = x.len(); - assert!(d.is_power_of_two(), "WHT requires power-of-2 dimension, got {d}"); + assert!( + d.is_power_of_two(), + "WHT requires power-of-2 dimension, got {d}" + ); let mut y = x.to_vec(); @@ -70,10 +73,7 @@ mod tests { let x_recon = wht(&y); for (a, b) in x.iter().zip(x_recon.iter()) { - assert!( - (a - b).abs() < 1e-4, - "WHT not self-inverse: {a} vs {b}" - ); + assert!((a - b).abs() < 1e-4, "WHT not self-inverse: {a} vs {b}"); } } diff --git a/crates/kv-cache-benchmark/src/unlimited_context/engine.rs b/crates/kv-cache-benchmark/src/unlimited_context/engine.rs deleted file mode 100644 index bd02b499..00000000 --- a/crates/kv-cache-benchmark/src/unlimited_context/engine.rs +++ /dev/null @@ -1,242 +0,0 @@ -//! Top-level `UnlimitedContextEngine` — Rust port of -//! `chuk-mlx/src/chuk_lazarus/inference/context/research/unlimited_engine.py`. -//! -//! Window lifecycle: -//! 1. `process(tokens)` — extends active window's K,V via -//! `rs_extend_from_checkpoint`. When window fills, auto-closes. -//! 2. `close_window()` — saves last-position K,V to `CheckpointStore`, -//! appends token IDs to `TokenArchive`, resets active window. -//! 3. `replay_window(id)` — reconstructs a window's full K,V by running -//! a forward pass over the archived tokens from the prior checkpoint. -//! 4. `stats()` — total bytes, windows, compression ratio vs full KV. - -use larql_inference::attention::SharedKV; -use larql_inference::model::ModelWeights; -use serde::Serialize; - -use super::checkpoint_store::CheckpointStore; -use super::extend::{empty_prior, rs_extend_from_checkpoint}; -use super::token_archive::TokenArchive; - -/// Storage and context statistics for `UnlimitedContextEngine`. -#[derive(Debug, Clone, Serialize)] -pub struct EngineStats { - pub total_tokens: usize, - pub archived_windows: usize, - pub current_window_id: usize, - pub current_window_tokens: usize, - pub checkpoint_bytes: usize, - pub archive_bytes: usize, - pub total_boundary_bytes: usize, - pub equivalent_kv_bytes: usize, - pub compression_ratio: f64, -} - -impl EngineStats { - pub fn summary(&self) -> String { - format!( - "{} windows / {} tokens — {:.0}× compression vs full KV", - self.archived_windows, self.total_tokens, self.compression_ratio - ) - } -} - -pub struct UnlimitedContextEngine { - pub window_size: usize, - pub checkpoints: CheckpointStore, - pub archive: TokenArchive, - - current_window_id: usize, - current_window_tokens: Vec, - current_window_kv: Option>, - abs_offset: usize, -} - -impl UnlimitedContextEngine { - pub fn new(window_size: usize) -> Self { - Self { - window_size, - checkpoints: CheckpointStore::new(), - archive: TokenArchive::new(), - current_window_id: 0, - current_window_tokens: Vec::new(), - current_window_kv: None, - abs_offset: 0, - } - } - - /// Feed tokens into the engine. Windows auto-close when they fill. - /// - /// Processes in chunks that fit within the current window; whenever the - /// current window is exactly `window_size` tokens, closes it (saves - /// checkpoint + archives tokens) and starts a new window. - pub fn process(&mut self, weights: &ModelWeights, tokens: &[u32]) -> Option<()> { - let mut remaining = tokens; - while !remaining.is_empty() { - let free = self.window_size - self.current_window_tokens.len(); - let take = remaining.len().min(free); - let (chunk, rest) = remaining.split_at(take); - self.extend_current(weights, chunk)?; - remaining = rest; - if self.current_window_tokens.len() >= self.window_size { - self.close_window(); - } - } - Some(()) - } - - /// Close any partial current window. Call before replay if the current - /// window hasn't filled naturally. - pub fn flush(&mut self) { - if !self.current_window_tokens.is_empty() { - self.close_window(); - } - } - - /// Reconstruct a window's full K,V by replaying its archived tokens - /// from the prior window's boundary checkpoint. - /// - /// Returns `(kv_per_layer, abs_end)` where `kv_per_layer[l]` has shape - /// `(prior_len + |w|, num_kv × head_dim)` and `abs_end` is the - /// absolute position of the last token in this window. - /// - /// For `window_id == 0` (no prior), runs a fresh prefill — bit-exact - /// with the original processing. For `window_id > 0`, starts from the - /// saved 1-token checkpoint of the previous window — within-window K,V - /// are produced by the actual forward pass; the 1-token prior summary - /// is the only cross-window approximation. - pub fn replay_window( - &self, - weights: &ModelWeights, - window_id: usize, - ) -> Option<(Vec, usize)> { - let (tokens, abs_offset) = self.archive.retrieve(window_id)?; - - let prior = if window_id > 0 && self.checkpoints.contains(window_id - 1) { - let (ckpt, _) = self.checkpoints.load(window_id - 1)?; - ckpt - } else { - empty_prior(weights) - }; - - let out = rs_extend_from_checkpoint(weights, tokens, &prior, abs_offset)?; - let abs_end = abs_offset + tokens.len() - 1; - Some((out.kv_cache, abs_end)) - } - - /// Total storage and context statistics. - pub fn stats(&self, weights: &ModelWeights) -> EngineStats { - let arch = &*weights.arch; - let num_layers = weights.num_layers; - let kv_dim_sum: usize = (0..num_layers) - .map(|l| arch.num_kv_heads_for_layer(l) * arch.head_dim_for_layer(l)) - .sum(); - - let total_archived = self.archive.total_tokens(); - let current = self.current_window_tokens.len(); - let total_tokens = total_archived + current; - - // Standard KV reference: bf16 (2 bytes per K and V entry) - let equivalent_kv_bytes = total_tokens * kv_dim_sum * 2 * 2; - - let checkpoint_bytes = self.checkpoints.total_bytes(); - let archive_bytes = self.archive.total_bytes(); - let total_boundary_bytes = checkpoint_bytes + archive_bytes; - - let compression_ratio = if total_boundary_bytes == 0 { - 0.0 - } else { - equivalent_kv_bytes as f64 / total_boundary_bytes as f64 - }; - - EngineStats { - total_tokens, - archived_windows: self.archive.len(), - current_window_id: self.current_window_id, - current_window_tokens: current, - checkpoint_bytes, - archive_bytes, - total_boundary_bytes, - equivalent_kv_bytes, - compression_ratio, - } - } - - // ------------------------------------------------------------------ - // internals - // ------------------------------------------------------------------ - - fn extend_current(&mut self, weights: &ModelWeights, chunk: &[u32]) -> Option<()> { - if chunk.is_empty() { - return Some(()); - } - - // Seed with prior window's checkpoint on first extend of a new window, - // or continue from whatever K,V the active window has accumulated. - let prior = if self.current_window_tokens.is_empty() { - if self.current_window_id > 0 && self.checkpoints.contains(self.current_window_id - 1) - { - let (ckpt, _) = self.checkpoints.load(self.current_window_id - 1)?; - ckpt - } else { - empty_prior(weights) - } - } else { - self.current_window_kv - .take() - .unwrap_or_else(|| empty_prior(weights)) - }; - - let abs_start = self.abs_offset + self.current_window_tokens.len(); - let out = rs_extend_from_checkpoint(weights, chunk, &prior, abs_start)?; - - self.current_window_kv = Some(out.kv_cache); - self.current_window_tokens.extend_from_slice(chunk); - Some(()) - } - - fn close_window(&mut self) { - let kv = match self.current_window_kv.take() { - Some(kv) => kv, - None => return, - }; - - // Extract last-position K,V per layer = next boundary checkpoint. - let last_kv: Vec = kv - .iter() - .map(|(k, v)| { - let n = k.shape()[0]; - let last_k = k.slice(ndarray::s![n - 1..n, ..]).to_owned(); - let last_v = v.slice(ndarray::s![n - 1..n, ..]).to_owned(); - (last_k, last_v) - }) - .collect(); - - let window_len = self.current_window_tokens.len(); - let abs_end = self.abs_offset + window_len - 1; - - self.checkpoints.save(self.current_window_id, last_kv, abs_end); - self.archive.archive( - self.current_window_id, - std::mem::take(&mut self.current_window_tokens), - self.abs_offset, - ); - self.abs_offset += window_len; - self.current_window_id += 1; - } -} - -#[cfg(test)] -mod tests { - use super::*; - - // Engine construction + storage accounting without running a model. - #[test] - fn new_engine_is_empty() { - let eng = UnlimitedContextEngine::new(512); - assert_eq!(eng.window_size, 512); - assert_eq!(eng.archive.len(), 0); - assert_eq!(eng.checkpoints.len(), 0); - assert_eq!(eng.current_window_id, 0); - } -} diff --git a/crates/kv-cache-benchmark/src/unlimited_context/extend.rs b/crates/kv-cache-benchmark/src/unlimited_context/extend.rs deleted file mode 100644 index cce22670..00000000 --- a/crates/kv-cache-benchmark/src/unlimited_context/extend.rs +++ /dev/null @@ -1,121 +0,0 @@ -//! Multi-token extend with prior K,V checkpoint. -//! -//! Runs a forward pass over new tokens, seeding each layer's attention with -//! an optional prior K,V cache (the window boundary checkpoint). Equivalent -//! to Python `UnlimitedContextEngine.replay_window` inner loop. -//! -//! The implementation loops over tokens calling -//! `run_attention_block_decode_step`, which extends a per-layer K,V cache by -//! one position per call. After N tokens, the per-layer cache has -//! `prior_len + N` rows of K and V. -//! -//! This is O(N × L × head_ops) per window replay — matching what Python's -//! `extend()` does in a single batched call, just unrolled sequentially. -//! Slightly slower on CPU but functionally identical; the `SharedKV` -//! returned by each call carries the exact same values the batched path -//! would produce. - -use ndarray::Array2; - -use larql_inference::attention::{run_attention_block_decode_step, SharedKV}; -use larql_inference::ffn::WeightFfn; -use larql_inference::forward::{embed_tokens_pub, run_ffn}; -use larql_inference::model::ModelWeights; - -/// Output of `rs_extend_from_checkpoint`. -pub struct ExtendOutput { - /// Hidden state at the last processed token, shape (1, hidden). - pub last_hidden: Array2, - /// Per-layer full K,V cache covering `[prior_tokens, new_tokens]`. - /// Shape of each K/V: `(prior_len + new_len, num_kv * head_dim)`. - pub kv_cache: Vec, - /// Per-layer last-row K,V, ready to save as the next boundary - /// checkpoint. Shape of each: `(1, num_kv * head_dim)`. - pub new_checkpoint: Vec, -} - -/// Run the decoder forward over `token_ids` with an optional prior K,V -/// checkpoint seeded at each layer. Returns: -/// - `last_hidden`: hidden state at the last new token -/// - `kv_cache`: full K,V per layer after extension (prior + new) -/// - `new_checkpoint`: last-row K,V per layer for saving as a boundary -/// -/// `prior_kv` should contain one K,V pair per layer. Each pair's K,V may be -/// empty (0 rows) for the "no prior" case (replay of window 0) or have 1 -/// row for a standard boundary checkpoint. Multi-row priors are allowed — -/// in that case attention sees the prior as a multi-token prefix. -/// -/// `abs_start` is the absolute position of the *first new token* in the -/// original sequence. RoPE is applied at that position and following. -pub fn rs_extend_from_checkpoint( - weights: &ModelWeights, - token_ids: &[u32], - prior_kv: &[SharedKV], - abs_start: usize, -) -> Option { - let num_layers = weights.num_layers; - let ffn = WeightFfn { weights }; - - if token_ids.is_empty() { - return None; - } - if prior_kv.len() != num_layers { - return None; - } - - let mut kv_cache: Vec = prior_kv.to_vec(); - let mut last_hidden: Option> = None; - - for (i, &token_id) in token_ids.iter().enumerate() { - let abs_position = abs_start + i; - let mut h = embed_tokens_pub(weights, &[token_id]); - - for (layer, kv_slot) in kv_cache.iter_mut().enumerate() { - let kv_entry: Option<&SharedKV> = if kv_slot.0.shape()[0] > 0 { - Some(kv_slot) - } else { - None - }; - - let (h_post_attn, new_kv) = - run_attention_block_decode_step(weights, &h, layer, kv_entry, abs_position)?; - - let (h_out, _capture) = run_ffn(weights, &h_post_attn, layer, &ffn, false); - h = h_out; - *kv_slot = new_kv; - } - - last_hidden = Some(h); - } - - let new_checkpoint: Vec = kv_cache - .iter() - .map(|(k, v)| { - let n = k.shape()[0]; - let last_k = k.slice(ndarray::s![n - 1..n, ..]).to_owned(); - let last_v = v.slice(ndarray::s![n - 1..n, ..]).to_owned(); - (last_k, last_v) - }) - .collect(); - - Some(ExtendOutput { - last_hidden: last_hidden?, - kv_cache, - new_checkpoint, - }) -} - -/// Build an empty (zero-row) K,V seed for use as `prior_kv` when replaying -/// window 0 or any window with no prior checkpoint. -pub fn empty_prior(weights: &ModelWeights) -> Vec { - let arch = &*weights.arch; - (0..weights.num_layers) - .map(|layer| { - let kv_dim = arch.num_kv_heads_for_layer(layer) * arch.head_dim_for_layer(layer); - ( - Array2::::zeros((0, kv_dim)), - Array2::::zeros((0, kv_dim)), - ) - }) - .collect() -} diff --git a/crates/kv-cache-benchmark/src/unlimited_context/mod.rs b/crates/kv-cache-benchmark/src/unlimited_context/mod.rs index 65e9cc00..b02a6f7d 100644 --- a/crates/kv-cache-benchmark/src/unlimited_context/mod.rs +++ b/crates/kv-cache-benchmark/src/unlimited_context/mod.rs @@ -1,51 +1,12 @@ -//! Tier 2 — Unlimited Context Engine (Rust port of Python/MLX `UnlimitedContextEngine`). +//! Unlimited Context Engine — re-exported from `larql_inference::engines::unlimited_context`. //! -//! Three-tier storage with sparse K,V checkpoints and model-forward replay: -//! -//! ```text -//! ┌──────────────────────┬─────────────────────┬──────────────────┐ -//! │ Boundary (WARM) │ Active window KV │ Token archive │ -//! │ 1 K,V per layer │ grows as window │ ~4 B / token │ -//! │ per closed window │ is extended │ (cold tier) │ -//! └──────────────────────┴─────────────────────┴──────────────────┘ -//! ``` -//! -//! - Each window is `window_size` tokens (default 512). As the window fills, -//! the engine extends an in-memory K,V cache via `rs_extend_from_checkpoint`. -//! - When the window closes: (a) the last-position K,V per layer is saved to -//! `CheckpointStore`, (b) the window's token IDs are appended to -//! `TokenArchive`, (c) the full window K,V is evicted. -//! - To query any past window, call `replay_window(id)` — it reconstructs the -//! window's K,V by running a model-forward pass over the archived tokens -//! starting from the prior window's boundary checkpoint. -//! -//! ## Correctness claim (what's bit-exact, what isn't) -//! -//! - **Within-window bit-exact**: `rs_extend_from_checkpoint(tokens, prior, abs_start)` -//! produces the same `h_new` and K,V for `tokens` as the same call with -//! identical inputs. The forward pass is deterministic up to numerical -//! precision (bf16/f32 arithmetic). -//! - **Against joint prefill**: replay(window_N, N>0) differs from joint -//! `prefill([w_0, ..., w_N])` at the window-N positions because the 1-token -//! prior checkpoint compresses `|w_{N-1}|` positions of K,V to 1. This is -//! the same lossiness variant (ii) per-layer boundary gives, measured at -//! cos ≈ 0.965 in `experiments/20_free_monoids_poincare/f1prime_*.py`. -//! -//! **Memory** on Gemma 3 4B (34 layers, 4 KV heads, head_dim=256, bf16): -//! 1 checkpoint = 34 × 2 × (4 × 256) × 2 B ≈ 139 KB. Python docs call this -//! ~174 KB accounting for some overhead. Matches either way. - -mod checkpoint_store; -mod token_archive; -mod extend; -mod engine; +//! The implementation now lives in larql-inference. This module is a thin +//! re-export so existing benchmark code continues to compile unchanged. -pub use checkpoint_store::CheckpointStore; -pub use token_archive::TokenArchive; -pub use extend::{empty_prior, rs_extend_from_checkpoint, ExtendOutput}; -pub use engine::{UnlimitedContextEngine, EngineStats}; +pub use larql_inference::engines::unlimited_context::{ + empty_prior, rs_extend_from_checkpoint, CheckpointStore, EngineStats, ExtendOutput, + TokenArchive, UnlimitedContextEngine, +}; -/// Test-only re-export so integration tests can construct an empty prior -/// without importing the inner module path. #[doc(hidden)] -pub use extend::empty_prior as __empty_prior_for_test; +pub use larql_inference::engines::unlimited_context::empty_prior as __empty_prior_for_test; diff --git a/crates/kv-cache-benchmark/src/vindex_compare.rs b/crates/kv-cache-benchmark/src/vindex_compare.rs new file mode 100644 index 00000000..0328c3f5 --- /dev/null +++ b/crates/kv-cache-benchmark/src/vindex_compare.rs @@ -0,0 +1,548 @@ +//! Vindex A/B comparison — run the same forward pass against two +//! `VectorIndex` instances and report how much their final logits +//! diverge. +//! +//! Format-agnostic by construction. Works for any pair of loaded +//! vindexes: f32 vs FP4, FP4 vs FP6, Q4K vs FP4, etc. The only thing +//! that varies between runs is the `VectorIndex` the walk kernel +//! dispatches through — everything else (attention weights, lm_head, +//! embeddings, tokenizer) is shared. That isolates the measurement to +//! the storage-format delta. +//! +//! Primary consumer: exp 26 Q2 (FP4 end-to-end correctness) via the +//! `vindex_compare` example. But the library has no FP4-specific +//! behaviour and is ready for any future storage-format A/B. + +#![cfg(feature = "real-model")] + +use std::collections::HashMap; + +use serde::Serialize; + +use larql_inference::attention::SharedKV; +use larql_inference::forward::{embed_tokens_pub, hidden_to_raw_logits, run_layer_with_ffn}; +use larql_inference::model::ModelWeights; +use larql_inference::vindex::WalkFfn; +use larql_vindex::VectorIndex; + +/// Per-comparison knobs. Kept minimal; future options added as fields. +#[derive(Debug, Clone)] +pub struct ComparisonConfig { + /// K for top-K agreement measurement. `5` by default. + pub top_k: usize, + /// Cap prompt length to this many tokens (None = full). + pub max_seq_len: Option, + /// Stop at this many layers (None = all of them). + pub max_layers: Option, +} + +impl Default for ComparisonConfig { + fn default() -> Self { + Self { + top_k: 5, + max_seq_len: None, + max_layers: None, + } + } +} + +/// Metrics for a single prompt comparison. +#[derive(Debug, Clone, Serialize)] +pub struct PromptReport { + pub prompt: String, + pub seq_len: usize, + /// Cosine similarity between reference and candidate logit vectors + /// at the final position. + pub logit_cos: f64, + /// Did argmax(logits_ref) == argmax(logits_cand)? + pub argmax_match: bool, + /// Jaccard index of the top-K token-id sets. + pub top_k_jaccard: f64, + /// KL(softmax(ref) || softmax(cand)). Symmetric reported separately. + pub kl_forward: f64, + /// KL(softmax(cand) || softmax(ref)). + pub kl_reverse: f64, + /// Symmetrised KL (mean of forward + reverse). + pub kl_symmetric: f64, + /// Argmax token id for each side. + pub ref_top_token_id: u32, + pub cand_top_token_id: u32, + /// Optional human-readable decoded tokens (filled by the CLI, not + /// the library — we don't want a tokenizer dep in the pure path). + pub ref_top_token: Option, + pub cand_top_token: Option, +} + +/// Aggregate report across a prompt set. +#[derive(Debug, Clone, Serialize)] +pub struct AggregateReport { + pub n_prompts: usize, + pub reference_label: String, + pub candidate_label: String, + pub config: ComparisonConfigSerde, + pub prompts: Vec, + /// Fraction of prompts where argmax matches. + pub argmax_agreement: f64, + /// Mean top-K Jaccard. + pub top_k_agreement_mean: f64, + /// Mean logit cosine similarity. + pub logit_cos_mean: f64, + /// Mean / 95th percentile / max symmetric KL. + pub kl_mean: f64, + pub kl_p95: f64, + pub kl_max: f64, +} + +#[derive(Debug, Clone, Serialize)] +pub struct ComparisonConfigSerde { + pub top_k: usize, + pub max_seq_len: Option, + pub max_layers: Option, +} + +impl From<&ComparisonConfig> for ComparisonConfigSerde { + fn from(c: &ComparisonConfig) -> Self { + Self { + top_k: c.top_k, + max_seq_len: c.max_seq_len, + max_layers: c.max_layers, + } + } +} + +/// Run the same forward pass against two vindexes, one prompt per call. +/// +/// Returns the final-position logits for each side. Shared model +/// weights, shared tokenisation, identical prefill through every layer +/// — the only axis of variation is which `VectorIndex` backs the walk +/// kernel during the FFN stage. +/// +/// The function is entirely format-blind: `WalkFfn::new_unlimited` +/// uses the unified `GateIndex::ffn_row_*` dispatch we wired in the +/// trait refactor, so whichever backend the vindex carries (FP4, Q4K, +/// native f32) automatically fires. +pub fn forward_to_logits( + weights: &ModelWeights, + index: &VectorIndex, + token_ids: &[u32], + config: &ComparisonConfig, +) -> Vec { + forward_to_logits_traced(weights, index, token_ids, config).0 +} + +/// Same as `forward_to_logits` but also returns the per-layer walk-path +/// trace (one `(layer, path_name)` per layer). Enables the CLI +/// `--trace` flag and catches cases where a candidate vindex silently +/// falls through to an unexpected backend — the bug class exp 26 Q2 +/// surfaced during development. +pub fn forward_to_logits_traced( + weights: &ModelWeights, + index: &VectorIndex, + token_ids: &[u32], + config: &ComparisonConfig, +) -> (Vec, Vec<(usize, &'static str)>) { + let mut h = embed_tokens_pub(weights, token_ids); + + let num_layers = config.max_layers.unwrap_or(weights.num_layers); + let mut kv_cache: HashMap = HashMap::new(); + let mut trace: Vec<(usize, &'static str)> = Vec::with_capacity(num_layers); + + for layer in 0..num_layers { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + + // WalkFfn with dispatch trace enabled. The trace is drained + // per-layer so we can pin which path fired even when multiple + // positions are processed. + let walk_ffn = WalkFfn::new_unlimited(weights, index).with_dispatch_trace(); + + if let Some((h_new, _, kv_out)) = + run_layer_with_ffn(weights, &h, layer, &walk_ffn, false, None, shared_kv) + { + h = h_new; + if let Some(kv) = kv_out { + kv_cache.insert(layer, kv); + } + // Surface the first trace entry for this layer (there are + // seq_len entries at the serial sparse path, but they all + // report the same name). Missing trace == cache hit or + // zero-features-dense. + let entries = walk_ffn.take_dispatch_trace(); + let path = entries.first().map(|e| e.path).unwrap_or("unknown"); + trace.push((layer, path)); + } else { + break; + } + } + + let seq_len = h.shape()[0]; + let last_h = h.slice(ndarray::s![seq_len - 1..seq_len, ..]).to_owned(); + (hidden_to_raw_logits(weights, &last_h), trace) +} + +/// Compare two vindexes on a single prompt. Computes logits via +/// `forward_to_logits` on each and then the full set of metrics. +pub fn compare_prompt( + weights: &ModelWeights, + reference: &VectorIndex, + candidate: &VectorIndex, + prompt: &str, + token_ids: &[u32], + config: &ComparisonConfig, +) -> PromptReport { + let logits_ref = forward_to_logits(weights, reference, token_ids, config); + let logits_cand = forward_to_logits(weights, candidate, token_ids, config); + metrics_from_logits( + prompt, + token_ids.len(), + &logits_ref, + &logits_cand, + config.top_k, + ) +} + +/// Compare a whole prompt set. Returns an `AggregateReport`. +/// +/// Tokenisation is the caller's job (pass `token_ids_per_prompt` +/// alongside the prompts). Keeps this library tokenizer-free. +pub fn compare_many( + weights: &ModelWeights, + reference: &VectorIndex, + candidate: &VectorIndex, + prompts_and_tokens: &[(&str, Vec)], + reference_label: &str, + candidate_label: &str, + config: &ComparisonConfig, +) -> AggregateReport { + let mut per_prompt = Vec::with_capacity(prompts_and_tokens.len()); + for (prompt, token_ids) in prompts_and_tokens { + let mut ids = token_ids.clone(); + if let Some(cap) = config.max_seq_len { + if ids.len() > cap { + ids.truncate(cap); + } + } + per_prompt.push(compare_prompt( + weights, reference, candidate, prompt, &ids, config, + )); + } + aggregate(per_prompt, reference_label, candidate_label, config) +} + +// ── Metrics ──────────────────────────────────────────────────────────────── + +fn metrics_from_logits( + prompt: &str, + seq_len: usize, + logits_ref: &[f32], + logits_cand: &[f32], + top_k: usize, +) -> PromptReport { + assert_eq!( + logits_ref.len(), + logits_cand.len(), + "logit vectors must have the same vocab size" + ); + + let argmax_ref = argmax(logits_ref); + let argmax_cand = argmax(logits_cand); + let cos = cosine(logits_ref, logits_cand); + + let top_ref = top_k_ids(logits_ref, top_k); + let top_cand = top_k_ids(logits_cand, top_k); + let jac = jaccard(&top_ref, &top_cand); + + let probs_ref = softmax(logits_ref); + let probs_cand = softmax(logits_cand); + let kl_forward = kl_divergence(&probs_ref, &probs_cand); + let kl_reverse = kl_divergence(&probs_cand, &probs_ref); + let kl_sym = 0.5 * (kl_forward + kl_reverse); + + PromptReport { + prompt: prompt.to_string(), + seq_len, + logit_cos: cos, + argmax_match: argmax_ref == argmax_cand, + top_k_jaccard: jac, + kl_forward, + kl_reverse, + kl_symmetric: kl_sym, + ref_top_token_id: argmax_ref, + cand_top_token_id: argmax_cand, + ref_top_token: None, + cand_top_token: None, + } +} + +fn aggregate( + prompts: Vec, + reference_label: &str, + candidate_label: &str, + config: &ComparisonConfig, +) -> AggregateReport { + let n = prompts.len(); + if n == 0 { + return AggregateReport { + n_prompts: 0, + reference_label: reference_label.to_string(), + candidate_label: candidate_label.to_string(), + config: config.into(), + prompts, + argmax_agreement: f64::NAN, + top_k_agreement_mean: f64::NAN, + logit_cos_mean: f64::NAN, + kl_mean: f64::NAN, + kl_p95: f64::NAN, + kl_max: f64::NAN, + }; + } + + let argmax_hits = prompts.iter().filter(|p| p.argmax_match).count() as f64; + let top_k_mean = prompts.iter().map(|p| p.top_k_jaccard).sum::() / n as f64; + let cos_mean = prompts.iter().map(|p| p.logit_cos).sum::() / n as f64; + + let mut kls: Vec = prompts.iter().map(|p| p.kl_symmetric).collect(); + kls.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let kl_mean = kls.iter().sum::() / n as f64; + let kl_p95 = percentile(&kls, 0.95); + let kl_max = *kls.last().unwrap_or(&f64::NAN); + + AggregateReport { + n_prompts: n, + reference_label: reference_label.to_string(), + candidate_label: candidate_label.to_string(), + config: config.into(), + prompts, + argmax_agreement: argmax_hits / n as f64, + top_k_agreement_mean: top_k_mean, + logit_cos_mean: cos_mean, + kl_mean, + kl_p95, + kl_max, + } +} + +// ── Numeric helpers ──────────────────────────────────────────────────────── + +fn argmax(xs: &[f32]) -> u32 { + let mut idx = 0usize; + let mut best = f32::NEG_INFINITY; + for (i, &v) in xs.iter().enumerate() { + if v > best { + best = v; + idx = i; + } + } + idx as u32 +} + +fn top_k_ids(xs: &[f32], k: usize) -> Vec { + let k = k.min(xs.len()); + let mut indexed: Vec<(usize, f32)> = xs.iter().copied().enumerate().collect(); + indexed.select_nth_unstable_by(k - 1, |a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); + let mut top: Vec = indexed[..k].iter().map(|(i, _)| *i as u32).collect(); + top.sort_unstable(); + top +} + +fn jaccard(a: &[u32], b: &[u32]) -> f64 { + if a.is_empty() && b.is_empty() { + return 1.0; + } + let sa: std::collections::BTreeSet = a.iter().copied().collect(); + let sb: std::collections::BTreeSet = b.iter().copied().collect(); + let intersect = sa.intersection(&sb).count() as f64; + let union = sa.union(&sb).count() as f64; + if union == 0.0 { + 1.0 + } else { + intersect / union + } +} + +fn cosine(a: &[f32], b: &[f32]) -> f64 { + let mut num = 0.0f64; + let mut na = 0.0f64; + let mut nb = 0.0f64; + for (&x, &y) in a.iter().zip(b.iter()) { + num += x as f64 * y as f64; + na += x as f64 * x as f64; + nb += y as f64 * y as f64; + } + let denom = (na.sqrt()) * (nb.sqrt()); + if denom == 0.0 { + 1.0 + } else { + num / denom + } +} + +fn softmax(logits: &[f32]) -> Vec { + let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max); + let exps: Vec = logits.iter().map(|&v| ((v - max) as f64).exp()).collect(); + let sum: f64 = exps.iter().sum(); + if sum == 0.0 { + return vec![1.0 / logits.len() as f64; logits.len()]; + } + exps.into_iter().map(|e| e / sum).collect() +} + +fn kl_divergence(p: &[f64], q: &[f64]) -> f64 { + // KL(p || q) = Σ p_i * log(p_i / q_i). Skip p_i == 0 (by + // convention 0 log 0 = 0). Guard against q_i == 0 with a floor. + const EPS: f64 = 1e-12; + let mut kl = 0.0f64; + for (&pi, &qi) in p.iter().zip(q.iter()) { + if pi <= 0.0 { + continue; + } + let qi_safe = qi.max(EPS); + kl += pi * (pi.ln() - qi_safe.ln()); + } + kl +} + +fn percentile(sorted: &[f64], q: f64) -> f64 { + if sorted.is_empty() { + return f64::NAN; + } + let idx = ((sorted.len() - 1) as f64 * q).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn argmax_finds_max() { + assert_eq!(argmax(&[1.0, 3.0, 2.0, -5.0]), 1); + assert_eq!(argmax(&[-1.0, -3.0, -2.0]), 0); + } + + #[test] + fn top_k_ids_returns_correct_indices() { + // Top-3 by value: idx 1 (3.0), idx 2 (2.0), idx 0 (1.0). + let logits = [1.0, 3.0, 2.0, -5.0, 0.5]; + let top = top_k_ids(&logits, 3); + assert_eq!(top.len(), 3); + // Results are sorted by id; set-equality with {0, 1, 2}. + let expected: std::collections::BTreeSet = [0u32, 1, 2].into_iter().collect(); + let got: std::collections::BTreeSet = top.into_iter().collect(); + assert_eq!(got, expected); + } + + #[test] + fn jaccard_full_overlap_equals_one() { + assert_eq!(jaccard(&[1, 2, 3], &[1, 2, 3]), 1.0); + } + + #[test] + fn jaccard_no_overlap_equals_zero() { + assert_eq!(jaccard(&[1, 2], &[3, 4]), 0.0); + } + + #[test] + fn jaccard_partial() { + // {1,2,3} ∩ {2,3,4} = {2,3}; ∪ = {1,2,3,4}; jac = 2/4 = 0.5. + assert!((jaccard(&[1, 2, 3], &[2, 3, 4]) - 0.5).abs() < 1e-9); + } + + #[test] + fn cosine_identical_vectors() { + let v = vec![1.0f32, 2.0, 3.0]; + assert!((cosine(&v, &v) - 1.0).abs() < 1e-9); + } + + #[test] + fn cosine_orthogonal_vectors() { + let a = [1.0f32, 0.0]; + let b = [0.0f32, 1.0]; + assert!((cosine(&a, &b) - 0.0).abs() < 1e-9); + } + + #[test] + fn softmax_sums_to_one() { + let s = softmax(&[1.0f32, 2.0, 3.0]); + let sum: f64 = s.iter().sum(); + assert!((sum - 1.0).abs() < 1e-9); + } + + #[test] + fn kl_identical_is_zero() { + let p = softmax(&[1.0f32, 2.0, 3.0]); + assert!(kl_divergence(&p, &p).abs() < 1e-9); + } + + #[test] + fn kl_is_nonnegative() { + let p = softmax(&[1.0f32, 2.0, 3.0]); + let q = softmax(&[3.0f32, 1.0, 2.0]); + let kl = kl_divergence(&p, &q); + assert!(kl >= 0.0, "KL must be non-negative, got {kl}"); + } + + #[test] + fn aggregate_handles_empty_gracefully() { + let r = aggregate(vec![], "ref", "cand", &ComparisonConfig::default()); + assert_eq!(r.n_prompts, 0); + assert!(r.argmax_agreement.is_nan()); + } + + #[test] + fn aggregate_computes_means() { + // Two prompts: one argmax match, one argmax miss. Expected + // argmax_agreement = 0.5. + let prompts = vec![ + PromptReport { + prompt: "a".into(), + seq_len: 1, + logit_cos: 0.9, + argmax_match: true, + top_k_jaccard: 0.8, + kl_forward: 0.01, + kl_reverse: 0.01, + kl_symmetric: 0.01, + ref_top_token_id: 42, + cand_top_token_id: 42, + ref_top_token: None, + cand_top_token: None, + }, + PromptReport { + prompt: "b".into(), + seq_len: 2, + logit_cos: 0.7, + argmax_match: false, + top_k_jaccard: 0.4, + kl_forward: 0.05, + kl_reverse: 0.05, + kl_symmetric: 0.05, + ref_top_token_id: 1, + cand_top_token_id: 7, + ref_top_token: None, + cand_top_token: None, + }, + ]; + let r = aggregate(prompts, "r", "c", &ComparisonConfig::default()); + assert_eq!(r.n_prompts, 2); + assert!((r.argmax_agreement - 0.5).abs() < 1e-9); + assert!((r.top_k_agreement_mean - 0.6).abs() < 1e-9); + assert!((r.logit_cos_mean - 0.8).abs() < 1e-9); + assert!((r.kl_mean - 0.03).abs() < 1e-9); + } + + #[test] + fn percentile_handles_edges() { + let sorted = [0.1, 0.2, 0.3, 0.4, 0.5]; + assert_eq!(percentile(&sorted, 0.0), 0.1); + assert_eq!(percentile(&sorted, 1.0), 0.5); + // p95 on 5 elements → round((5-1)*0.95) = round(3.8) = 4 → sorted[4] = 0.5. + assert_eq!(percentile(&sorted, 0.95), 0.5); + } +} diff --git a/crates/kv-cache-benchmark/tests/test_accuracy.rs b/crates/kv-cache-benchmark/tests/test_accuracy.rs index 6e23d5c9..cb3d804d 100644 --- a/crates/kv-cache-benchmark/tests/test_accuracy.rs +++ b/crates/kv-cache-benchmark/tests/test_accuracy.rs @@ -5,7 +5,11 @@ use kv_cache_benchmark::accuracy::*; #[test] fn test_accuracy_factual_prompts_exist() { let prompts = factual_prompts(); - assert!(prompts.len() >= 20, "Need at least 20 factual prompts, got {}", prompts.len()); + assert!( + prompts.len() >= 20, + "Need at least 20 factual prompts, got {}", + prompts.len() + ); // All should have non-empty prompt and expected answer for (prompt, answer) in &prompts { assert!(!prompt.is_empty()); @@ -16,7 +20,11 @@ fn test_accuracy_factual_prompts_exist() { #[test] fn test_accuracy_diverse_prompts_exist() { let prompts = diverse_prompts(); - assert!(prompts.len() >= 10, "Need at least 10 diverse prompts, got {}", prompts.len()); + assert!( + prompts.len() >= 10, + "Need at least 10 diverse prompts, got {}", + prompts.len() + ); } // ── Category 2: KL Divergence ── @@ -25,7 +33,10 @@ fn test_accuracy_diverse_prompts_exist() { fn test_kl_divergence_identical() { let p = vec![0.7, 0.2, 0.1]; let kl = kl_divergence(&p, &p); - assert!(kl.abs() < 1e-10, "KL of identical distributions should be 0, got {kl}"); + assert!( + kl.abs() < 1e-10, + "KL of identical distributions should be 0, got {kl}" + ); } #[test] @@ -63,7 +74,10 @@ fn test_softmax_sums_to_one() { let logits = vec![2.0f32, 1.0, 0.5, -1.0, 3.0]; let probs = softmax(&logits); let sum: f64 = probs.iter().sum(); - assert!((sum - 1.0).abs() < 1e-6, "Softmax should sum to 1, got {sum}"); + assert!( + (sum - 1.0).abs() < 1e-6, + "Softmax should sum to 1, got {sum}" + ); } #[test] @@ -162,7 +176,8 @@ fn test_haystack_generation_short() { #[test] fn test_haystack_generation_long() { - let (context, _needle) = generate_haystack(32000, 5000, "The secret project code is AURORA-7749"); + let (context, _needle) = + generate_haystack(32000, 5000, "The secret project code is AURORA-7749"); assert!(context.contains("AURORA-7749")); assert!(context.len() > 10000); } @@ -205,7 +220,10 @@ fn test_retention_conversation_25_turns() { let queries: Vec<_> = turns.iter().filter(|t| t.is_query).collect(); assert!(queries.len() >= 3); - let facts: Vec<_> = turns.iter().filter(|t| !t.is_query && t.fact_key.is_some()).collect(); + let facts: Vec<_> = turns + .iter() + .filter(|t| !t.is_query && t.fact_key.is_some()) + .collect(); assert!(facts.len() >= 3, "Need at least 3 fact-establishing turns"); } diff --git a/crates/kv-cache-benchmark/tests/test_accuracy_suite.rs b/crates/kv-cache-benchmark/tests/test_accuracy_suite.rs index b7ce7585..2c9657e9 100644 --- a/crates/kv-cache-benchmark/tests/test_accuracy_suite.rs +++ b/crates/kv-cache-benchmark/tests/test_accuracy_suite.rs @@ -4,8 +4,8 @@ #[cfg(feature = "real-model")] mod with_model { - use kv_cache_benchmark::accuracy_suite::prompts; use kv_cache_benchmark::accuracy_suite::needle; + use kv_cache_benchmark::accuracy_suite::prompts; use kv_cache_benchmark::accuracy_suite::runner; #[test] @@ -22,8 +22,14 @@ mod with_model { categories.dedup(); let expected = vec![ - "arithmetic", "code", "completion", "conversational", - "factual", "geographic", "reasoning", "scientific", + "arithmetic", + "code", + "completion", + "conversational", + "factual", + "geographic", + "reasoning", + "scientific", ]; assert_eq!(categories, expected, "Missing categories"); } @@ -31,13 +37,17 @@ mod with_model { #[test] fn test_diverse_100_balanced_categories() { let prompts = prompts::diverse_100(); - let mut categories: std::collections::HashMap<&str, usize> = std::collections::HashMap::new(); + let mut categories: std::collections::HashMap<&str, usize> = + std::collections::HashMap::new(); for p in &prompts { *categories.entry(p.category).or_default() += 1; } // Each category should have at least 10 prompts for (cat, count) in &categories { - assert!(*count >= 10, "Category '{cat}' has {count} prompts, expected >=10"); + assert!( + *count >= 10, + "Category '{cat}' has {count} prompts, expected >=10" + ); } // Total should be 100 let total: usize = categories.values().sum(); @@ -116,14 +126,20 @@ mod with_model { #[test] fn test_format_needle_results() { let results = vec![ - (512, vec![ - ("Standard KV".to_string(), true), - ("Markov RS".to_string(), true), - ]), - (32768, vec![ - ("Standard KV".to_string(), false), - ("Markov RS".to_string(), true), - ]), + ( + 512, + vec![ + ("Standard KV".to_string(), true), + ("Markov RS".to_string(), true), + ], + ), + ( + 32768, + vec![ + ("Standard KV".to_string(), false), + ("Markov RS".to_string(), true), + ], + ), ]; let table = needle::format_needle_results(&results); assert!(table.contains("PASS")); diff --git a/crates/kv-cache-benchmark/tests/test_apollo_accuracy.rs b/crates/kv-cache-benchmark/tests/test_apollo_accuracy.rs index c090a124..66be68c0 100644 --- a/crates/kv-cache-benchmark/tests/test_apollo_accuracy.rs +++ b/crates/kv-cache-benchmark/tests/test_apollo_accuracy.rs @@ -51,14 +51,17 @@ fn test_apollo_accuracy_sweep() { let mut engine = ApolloEngine::new(InjectionConfig::default()).with_store(store); engine.build_routing_index().expect("build routing"); - let model_path = std::env::var("LARQL_MODEL_PATH") - .unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); + let model_path = + std::env::var("LARQL_MODEL_PATH").unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); let model = larql_inference::InferenceModel::load(&model_path).expect("load model"); let weights = model.weights(); let tok = model.tokenizer(); println!("\n{}", "=".repeat(100)); - println!("Apollo accuracy sweep — {} queries × 2 paths", QUERIES.len()); + println!( + "Apollo accuracy sweep — {} queries × 2 paths", + QUERIES.len() + ); println!("{}", "=".repeat(100)); println!( @@ -75,9 +78,7 @@ fn test_apollo_accuracy_sweep() { match r { Ok(t) => { let t: &kv_cache_benchmark::apollo::QueryTrace = t; - let txt = tok - .decode(&[t.top1_token_id], false) - .unwrap_or_default(); + let txt = tok.decode(&[t.top1_token_id], false).unwrap_or_default(); ( format!("{:?} @ {:.1}", txt, t.top1_logit), t.context_tokens, @@ -97,10 +98,7 @@ fn test_apollo_accuracy_sweep() { }; let truncq: String = q.chars().take(46).collect(); - println!( - "{:<48} {:<20} {:<20} {}", - truncq, u_fmt, c_fmt, ratio - ); + println!("{:<48} {:<20} {:<20} {}", truncq, u_fmt, c_fmt, ratio); } println!(); } diff --git a/crates/kv-cache-benchmark/tests/test_apollo_query.rs b/crates/kv-cache-benchmark/tests/test_apollo_query.rs index cc29773c..9a5f2199 100644 --- a/crates/kv-cache-benchmark/tests/test_apollo_query.rs +++ b/crates/kv-cache-benchmark/tests/test_apollo_query.rs @@ -32,8 +32,8 @@ fn store_path() -> std::path::PathBuf { } fn load_model() -> larql_inference::InferenceModel { - let model_path = std::env::var("LARQL_MODEL_PATH") - .unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); + let model_path = + std::env::var("LARQL_MODEL_PATH").unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); larql_inference::InferenceModel::load(&model_path).expect("load gemma") } @@ -49,11 +49,7 @@ fn test_routing_resolves_porridge_to_w170_region() { let model = load_model(); let tok = model.tokenizer(); - for query in [ - "porridge eating contest", - "Corby England", - "John Coyle", - ] { + for query in ["porridge eating contest", "Corby England", "John Coyle"] { let enc = tok.encode(query, false).expect("tokenize"); let qids: Vec = enc.get_ids().to_vec(); let q = kv_cache_benchmark::apollo::RoutingQuery { token_ids: qids }; @@ -85,9 +81,7 @@ fn test_retrieve_entries_for_query() { assert!(!windows.is_empty()); // Retrieve entries scoped to routed windows - let entries = engine - .retrieve_entries(&qids, &windows) - .expect("retrieve"); + let entries = engine.retrieve_entries(&qids, &windows).expect("retrieve"); println!(" retrieved {} entries", entries.len()); for e in entries.iter().take(10) { let txt = tok.decode(&[e.token_id], false).unwrap_or_default(); @@ -135,7 +129,9 @@ fn test_end_to_end_query_produces_nonempty_answer() { ); } println!(" context tokens: {}", trace.context_tokens); - let top1_txt = tok.decode(&[trace.top1_token_id], false).unwrap_or_default(); + let top1_txt = tok + .decode(&[trace.top1_token_id], false) + .unwrap_or_default(); println!( " top-1 prediction: token {} ({top1_txt:?}) logit={:.3}", trace.top1_token_id, trace.top1_logit, @@ -189,7 +185,9 @@ fn test_end_to_end_query_compressed_path() { e.token_id, e.coefficient, e.window_id, ); } - let top1_txt = tok.decode(&[trace.top1_token_id], false).unwrap_or_default(); + let top1_txt = tok + .decode(&[trace.top1_token_id], false) + .unwrap_or_default(); println!( " top-1 prediction: token {} ({top1_txt:?}) logit={:.3}", trace.top1_token_id, trace.top1_logit, @@ -231,18 +229,12 @@ fn test_apollo_generate_compressed() { println!("\n=== Apollo iterative decode (COMPRESSED path) ==="); println!(" query: {query:?}"); - println!( - " routed windows: {:?}", - trace.routed_windows - ); + println!(" routed windows: {:?}", trace.routed_windows); println!( " initial context: {} tokens (boundary + query)", trace.initial_context_tokens, ); - println!( - " injected entries ({}):", - trace.injected_entries.len() - ); + println!(" injected entries ({}):", trace.injected_entries.len()); for e in &trace.injected_entries { let txt = tok.decode(&[e.token_id], false).unwrap_or_default(); println!( @@ -250,7 +242,11 @@ fn test_apollo_generate_compressed() { e.token_id, e.coefficient, ); } - println!(" generated ({} tokens, stopped_on_eos={}):", trace.generated_token_ids.len(), trace.stopped_on_eos); + println!( + " generated ({} tokens, stopped_on_eos={}):", + trace.generated_token_ids.len(), + trace.stopped_on_eos + ); println!(" {generated_text:?}"); print!(" per-step logits:"); for v in &trace.per_step_logits { diff --git a/crates/kv-cache-benchmark/tests/test_comparative.rs b/crates/kv-cache-benchmark/tests/test_comparative.rs index 9d633f1a..0b09cd75 100644 --- a/crates/kv-cache-benchmark/tests/test_comparative.rs +++ b/crates/kv-cache-benchmark/tests/test_comparative.rs @@ -1,10 +1,10 @@ -use kv_cache_benchmark::*; use kv_cache_benchmark::benchmark; +use kv_cache_benchmark::graph_walk::GraphWalk; +use kv_cache_benchmark::markov_residual::MarkovResidual; use kv_cache_benchmark::model_config::ModelConfig; use kv_cache_benchmark::standard_kv::StandardKv; use kv_cache_benchmark::turboquant::TurboQuant; -use kv_cache_benchmark::markov_residual::MarkovResidual; -use kv_cache_benchmark::graph_walk::GraphWalk; +use kv_cache_benchmark::*; #[test] fn test_all_strategies_memory_ordering() { @@ -21,23 +21,34 @@ fn test_all_strategies_memory_ordering() { let mem_gw = graph.memory_bytes(seq_len); // Standard KV is always the largest. - assert!(mem_std > mem_tq, "At {seq_len}: Standard ({mem_std}) > TurboQuant ({mem_tq})"); + assert!( + mem_std > mem_tq, + "At {seq_len}: Standard ({mem_std}) > TurboQuant ({mem_tq})" + ); // MarkovRS W=512 is bounded by the hot window (~192 MB) regardless of seq_len. // At short contexts (<~11K) the window dominates and MarkovRS > TurboQuant. // At long contexts TurboQuant grows larger. Both beat standard KV. - assert!(mem_std > mem_mrk, "At {seq_len}: Standard ({mem_std}) > Markov RS ({mem_mrk})"); + assert!( + mem_std > mem_mrk, + "At {seq_len}: Standard ({mem_std}) > Markov RS ({mem_mrk})" + ); // Graph Walk is the per-conversation minimum (token IDs only). - assert!(mem_gw < mem_mrk, "At {seq_len}: Graph Walk ({mem_gw}) < Markov RS ({mem_mrk})"); + assert!( + mem_gw < mem_mrk, + "At {seq_len}: Graph Walk ({mem_gw}) < Markov RS ({mem_mrk})" + ); } // At very long contexts, MarkovRS stays flat while TurboQuant grows O(n). // Crossover: MarkovRS fixed window (~192 MB) < TurboQuant at ~11K+ tokens. let mem_mrk_370k = markov.memory_bytes(&config, 370_000) as f64; - let mem_tq_370k = tq4.memory_bytes(&config, 370_000) as f64; - assert!(mem_tq_370k > mem_mrk_370k, - "At 370K: TurboQuant ({mem_tq_370k:.0}) should exceed Markov RS ({mem_mrk_370k:.0})"); + let mem_tq_370k = tq4.memory_bytes(&config, 370_000) as f64; + assert!( + mem_tq_370k > mem_mrk_370k, + "At 370K: TurboQuant ({mem_tq_370k:.0}) should exceed Markov RS ({mem_mrk_370k:.0})" + ); } #[test] @@ -56,7 +67,11 @@ fn test_memory_sweep_produces_data() { assert_eq!(points.len(), 9); for point in &points { - assert!(point.memory_bytes > 0, "Zero memory for {}", point.strategy_name); + assert!( + point.memory_bytes > 0, + "Zero memory for {}", + point.strategy_name + ); } } @@ -102,7 +117,10 @@ fn test_370k_memory_ratios() { assert!(ratio_mrk > 100.0, "Markov ratio: {ratio_mrk:.1}×"); // Graph Walk: per-conversation is even smaller (token IDs only). - assert!(ratio_gw > ratio_mrk, "Graph Walk should compress more than Markov RS"); + assert!( + ratio_gw > ratio_mrk, + "Graph Walk should compress more than Markov RS" + ); println!("At 370K tokens on {}:", config.name); println!(" Standard KV: {:.1} GB", mem_std / 1e9); diff --git a/crates/kv-cache-benchmark/tests/test_graph_walk.rs b/crates/kv-cache-benchmark/tests/test_graph_walk.rs index efeaa182..1d389097 100644 --- a/crates/kv-cache-benchmark/tests/test_graph_walk.rs +++ b/crates/kv-cache-benchmark/tests/test_graph_walk.rs @@ -1,6 +1,6 @@ -use kv_cache_benchmark::graph_walk::GraphWalk; -use kv_cache_benchmark::graph_walk::walk_state::{WalkState, WalkMode, WalkTier}; use kv_cache_benchmark::graph_walk::fallback::TierDistribution; +use kv_cache_benchmark::graph_walk::walk_state::{WalkMode, WalkState, WalkTier}; +use kv_cache_benchmark::graph_walk::GraphWalk; #[test] fn test_graph_walk_memory_tiny() { @@ -12,7 +12,10 @@ fn test_graph_walk_memory_tiny() { let mem_370k = gw.memory_bytes(370_000); assert_eq!(mem_370k, 370_000 * 4); - assert!(mem_370k < 2_000_000, "Graph walk per-conversation should be < 2MB"); + assert!( + mem_370k < 2_000_000, + "Graph walk per-conversation should be < 2MB" + ); } #[test] diff --git a/crates/kv-cache-benchmark/tests/test_markov.rs b/crates/kv-cache-benchmark/tests/test_markov.rs index b718b534..237e33b9 100644 --- a/crates/kv-cache-benchmark/tests/test_markov.rs +++ b/crates/kv-cache-benchmark/tests/test_markov.rs @@ -1,6 +1,6 @@ -use kv_cache_benchmark::*; -use kv_cache_benchmark::model_config::ModelConfig; use kv_cache_benchmark::markov_residual::MarkovResidual; +use kv_cache_benchmark::model_config::ModelConfig; +use kv_cache_benchmark::*; #[test] fn test_markov_cold_tier_size() { @@ -61,7 +61,10 @@ fn test_markov_much_smaller_than_standard() { // At 4K the window still dominates, but MarkovRS is still smaller than standard. let std_4k = standard.memory_bytes(&config, 4096); let mrk_4k = markov.memory_bytes(&config, 4096); - assert!(mrk_4k < std_4k, "Markov RS should be smaller than standard KV at 4K"); + assert!( + mrk_4k < std_4k, + "Markov RS should be smaller than standard KV at 4K" + ); } #[test] @@ -69,12 +72,8 @@ fn test_markov_encode_decode() { let strategy = MarkovResidual::new(4); let dim = 8; - let keys: Vec> = (0..10) - .map(|i| vec![i as f32; dim]) - .collect(); - let values: Vec> = (0..10) - .map(|i| vec![i as f32 + 100.0; dim]) - .collect(); + let keys: Vec> = (0..10).map(|i| vec![i as f32; dim]).collect(); + let values: Vec> = (0..10).map(|i| vec![i as f32 + 100.0; dim]).collect(); let encoded = strategy.encode(&keys, &values); let (dec_keys, _dec_values) = strategy.decode(&encoded, 10, dim); @@ -121,7 +120,8 @@ fn test_markov_reconstruction_exact() { assert!( (dec_keys[i][j] - keys[i][j]).abs() < 1e-6, "Not bit-perfect at [{i}][{j}]: {} vs {}", - dec_keys[i][j], keys[i][j], + dec_keys[i][j], + keys[i][j], ); } } diff --git a/crates/kv-cache-benchmark/tests/test_real_model.rs b/crates/kv-cache-benchmark/tests/test_real_model.rs index b31305a9..0e553bad 100644 --- a/crates/kv-cache-benchmark/tests/test_real_model.rs +++ b/crates/kv-cache-benchmark/tests/test_real_model.rs @@ -12,24 +12,22 @@ #![cfg(feature = "real-model")] -use kv_cache_benchmark::real_model::*; use kv_cache_benchmark::real_model::runner::*; +use kv_cache_benchmark::real_model::*; /// Helper to load model + vindex for tests. Returns None if model not available. /// Set LARQL_MODEL_PATH and LARQL_VINDEX_PATH env vars, or uses default HF paths. -fn load_test_model() -> Option<( - larql_inference::InferenceModel, - larql_vindex::VectorIndex, -)> { - let model_path = std::env::var("LARQL_MODEL_PATH") - .unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); +fn load_test_model() -> Option<(larql_inference::InferenceModel, larql_vindex::VectorIndex)> { + let model_path = + std::env::var("LARQL_MODEL_PATH").unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); let model = larql_inference::InferenceModel::load(&model_path).ok()?; let vindex_path = std::env::var("LARQL_VINDEX_PATH").ok()?; let index = larql_vindex::VectorIndex::load_vindex( std::path::Path::new(&vindex_path), &mut larql_vindex::SilentLoadCallbacks, - ).ok()?; + ) + .ok()?; Some((model, index)) } @@ -40,9 +38,8 @@ fn test_all_strategies_produce_paris() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); let results = run_all_strategies(&bench, "The capital of France is", 5, 512); @@ -74,8 +71,7 @@ fn test_all_strategies_produce_paris() { assert!( results[2].top1_match, "Markov RS top-1 didn't match baseline: got '{}', expected '{}'", - results[2].top1_token, - results[0].top1_token, + results[2].top1_token, results[0].top1_token, ); // Graph Walk @@ -91,9 +87,8 @@ fn test_markov_rs_bit_perfect() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); let prompts = default_prompts(); for prompt in &prompts { @@ -122,7 +117,10 @@ fn test_markov_rs_bit_perfect() { fn test_turboquant_compression_on_real_vectors() { let (model, _index) = load_test_model().expect("Model not available"); - let encoding = model.tokenizer().encode("The capital of France is", true).unwrap(); + let encoding = model + .tokenizer() + .encode("The capital of France is", true) + .unwrap(); let token_ids: Vec = encoding.get_ids().to_vec(); let kv = kv_capture::capture_kv(model.weights(), &token_ids); @@ -139,8 +137,16 @@ fn test_turboquant_compression_on_real_vectors() { // Cosine is the meaningful metric (scale-invariant). // Paper MSE target (0.009) is for unit-norm vectors; raw K/V have larger norms. // Cosine 0.991 on real vectors = near-lossless. - assert!(result.cosine_sim > 0.98, "Cosine too low: {}", result.cosine_sim); - assert!(result.compression_ratio > 3.0, "Compression too low: {}", result.compression_ratio); + assert!( + result.cosine_sim > 0.98, + "Cosine too low: {}", + result.cosine_sim + ); + assert!( + result.compression_ratio > 3.0, + "Compression too low: {}", + result.compression_ratio + ); println!(" Note: MSE is on raw vectors (not unit-norm). Cosine is the fair metric."); } @@ -150,9 +156,8 @@ fn test_multi_turn_memory_bounded() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); // Simulate growing context let base_prompt = "The capital of France is Paris. The capital of Germany is Berlin. "; @@ -187,9 +192,8 @@ fn test_graph_walk_factual_accuracy() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); let prompts = default_prompts(); let mut matches = 0; @@ -218,9 +222,8 @@ fn test_graph_walk_factual_accuracy() { fn test_accuracy_top1_factual_20() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); let prompts = kv_cache_benchmark::accuracy::factual_prompts(); let total = prompts.len(); @@ -271,11 +274,14 @@ fn test_accuracy_top1_factual_20() { fn test_accuracy_markov_rs_bitperfect() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); - - for prompt in &["The capital of France is", "Mozart was born in", "Water freezes at"] { + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); + + for prompt in &[ + "The capital of France is", + "Mozart was born in", + "Water freezes at", + ] { let results = runner::run_all_strategies(&bench, prompt, 5, 512); let markov = &results[2]; @@ -301,9 +307,8 @@ fn test_accuracy_markov_rs_bitperfect() { fn test_needle_short_512() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); // Plant a fact early, query it at the end let prompt = "The secret code is AURORA-7749. Remember this. Now, some filler text about various topics. The weather is nice today. The sky is blue. What is the secret code?"; @@ -311,8 +316,16 @@ fn test_needle_short_512() { // All strategies should find AURORA or 7749 in their predictions for r in &results { - let top5_text: String = r.top5.iter().map(|(t, _)| t.as_str()).collect::>().join(" "); - println!("{}: top-1='{}', top-5=[{}]", r.strategy, r.top1_token, top5_text); + let top5_text: String = r + .top5 + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); + println!( + "{}: top-1='{}', top-5=[{}]", + r.strategy, r.top1_token, top5_text + ); } } @@ -323,9 +336,8 @@ fn test_needle_short_512() { fn test_adversarial_entity_confusion() { let (model, index) = load_test_model().expect("Model not available"); let backend = larql_inference::default_backend(); - let bench = RealModelBenchmark::new( - model.weights(), model.tokenizer(), &index, backend.as_ref(), - ); + let bench = + RealModelBenchmark::new(model.weights(), model.tokenizer(), &index, backend.as_ref()); // Same template, different entities — must give different answers let pairs = vec![ @@ -354,7 +366,8 @@ fn test_needle_scaling_context() { let needle = "The secret project code name is AURORA-7749."; let query = " What is the secret project code name?"; - let filler_sentence = "The quick brown fox jumps over the lazy dog near the old oak tree by the river. "; + let filler_sentence = + "The quick brown fox jumps over the lazy dog near the old oak tree by the river. "; // Test at increasing context lengths for target_tokens in [512, 1024, 2048, 4096] { @@ -375,7 +388,10 @@ fn test_needle_scaling_context() { context.push_str(query); // Tokenize and check actual length - let encoding = model.tokenizer().encode(context.as_str(), true).expect("tokenize"); + let encoding = model + .tokenizer() + .encode(context.as_str(), true) + .expect("tokenize"); let token_ids: Vec = encoding.get_ids().to_vec(); let actual_tokens = token_ids.len(); @@ -385,19 +401,31 @@ fn test_needle_scaling_context() { let elapsed = t0.elapsed(); // Check if AURORA or 7749 appears in top-10 - let top10_text: String = result.predictions.iter() + let top10_text: String = result + .predictions + .iter() .map(|(t, _)| t.as_str()) .collect::>() .join(" "); - let needle_found = top10_text.contains("AUR") || top10_text.contains("7749") || top10_text.contains("AURORA"); + let needle_found = top10_text.contains("AUR") + || top10_text.contains("7749") + || top10_text.contains("AURORA"); - let top1 = result.predictions.first().map(|(t, _)| t.as_str()).unwrap_or("?"); + let top1 = result + .predictions + .first() + .map(|(t, _)| t.as_str()) + .unwrap_or("?"); let found_mark = if needle_found { "FOUND" } else { "MISSED" }; println!( " {:>6} tokens (actual {:>5}): top-1='{}' needle={} [{:.1}s] top-10=[{}]", - target_tokens, actual_tokens, top1, found_mark, - elapsed.as_secs_f64(), top10_text, + target_tokens, + actual_tokens, + top1, + found_mark, + elapsed.as_secs_f64(), + top10_text, ); } } @@ -411,12 +439,15 @@ fn test_needle_bounded_window_vs_full() { let needle = "The secret project code name is AURORA-7749."; let query = " What is the secret project code name?"; - let filler_sentence = "The quick brown fox jumps over the lazy dog near the old oak tree by the river. "; + let filler_sentence = + "The quick brown fox jumps over the lazy dog near the old oak tree by the river. "; let window_size = 512; println!("\n=== Needle: Standard KV (full context) vs Markov RS (bounded window) ===\n"); - println!("{:>8} {:>8} {:>12} {:>12} {:>12} {:>12}", - "Target", "Actual", "StdKV top-1", "StdKV needle", "MarkovRS t1", "MarkovRS ndl"); + println!( + "{:>8} {:>8} {:>12} {:>12} {:>12} {:>12}", + "Target", "Actual", "StdKV top-1", "StdKV needle", "MarkovRS t1", "MarkovRS ndl" + ); println!("{}", "-".repeat(75)); for target_tokens in [512, 1024, 2048, 4096] { @@ -438,21 +469,36 @@ fn test_needle_bounded_window_vs_full() { context.push_str(query); // === Standard KV: full context forward pass === - let full_encoding = model.tokenizer().encode(context.as_str(), true).expect("tokenize"); + let full_encoding = model + .tokenizer() + .encode(context.as_str(), true) + .expect("tokenize"); let full_ids: Vec = full_encoding.get_ids().to_vec(); let full_len = full_ids.len(); - let full_result = larql_inference::predict(model.weights(), model.tokenizer(), &full_ids, 10); - let full_top10: String = full_result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join(" "); - let full_found = full_top10.contains("AUR") || full_top10.contains("7749") || full_top10.contains("AURORA"); - let full_top1 = full_result.predictions.first().map(|(t, _)| t.as_str()).unwrap_or("?"); + let full_result = + larql_inference::predict(model.weights(), model.tokenizer(), &full_ids, 10); + let full_top10: String = full_result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); + let full_found = full_top10.contains("AUR") + || full_top10.contains("7749") + || full_top10.contains("AURORA"); + let full_top1 = full_result + .predictions + .first() + .map(|(t, _)| t.as_str()) + .unwrap_or("?"); // === Markov RS: bounded window around needle + query === // Find which token position the needle is at - let needle_encoding = model.tokenizer().encode( - &context[..needle_char_pos + needle.len()], true - ).expect("tokenize needle prefix"); + let needle_encoding = model + .tokenizer() + .encode(&context[..needle_char_pos + needle.len()], true) + .expect("tokenize needle prefix"); let needle_token_pos = needle_encoding.get_ids().len(); // Window: 256 tokens before needle, needle tokens, then skip to query @@ -460,7 +506,10 @@ fn test_needle_bounded_window_vs_full() { let needle_end = needle_token_pos + 20; // needle is ~15 tokens // Build windowed token sequence: [window around needle] + [query tokens] - let query_encoding = model.tokenizer().encode(query, false).expect("tokenize query"); + let query_encoding = model + .tokenizer() + .encode(query, false) + .expect("tokenize query"); let query_ids: Vec = query_encoding.get_ids().to_vec(); let mut windowed_ids: Vec = Vec::new(); @@ -474,17 +523,29 @@ fn test_needle_bounded_window_vs_full() { let windowed_len = windowed_ids.len(); - let win_result = larql_inference::predict(model.weights(), model.tokenizer(), &windowed_ids, 10); - let win_top10: String = win_result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join(" "); - let win_found = win_top10.contains("AUR") || win_top10.contains("7749") || win_top10.contains("AURORA"); - let win_top1 = win_result.predictions.first().map(|(t, _)| t.as_str()).unwrap_or("?"); + let win_result = + larql_inference::predict(model.weights(), model.tokenizer(), &windowed_ids, 10); + let win_top10: String = win_result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); + let win_found = + win_top10.contains("AUR") || win_top10.contains("7749") || win_top10.contains("AURORA"); + let win_top1 = win_result + .predictions + .first() + .map(|(t, _)| t.as_str()) + .unwrap_or("?"); let full_mark = if full_found { "FOUND" } else { "MISSED" }; let win_mark = if win_found { "FOUND" } else { "MISSED" }; - println!("{:>8} {:>8} {:>12} {:>12} {:>12} {:>12} (window={}tok)", - target_tokens, full_len, full_top1, full_mark, win_top1, win_mark, windowed_len); + println!( + "{:>8} {:>8} {:>12} {:>12} {:>12} {:>12} (window={}tok)", + target_tokens, full_len, full_top1, full_mark, win_top1, win_mark, windowed_len + ); } println!("\nStandard KV = full forward pass over all tokens (softmax over full context)"); @@ -504,8 +565,14 @@ fn test_multi_turn_fact_retention() { // Establish facts then query them after filler turns let facts = [ ("My name is Alice and I work at Anthropic.", "Alice"), - ("I live in San Francisco near the Golden Gate Bridge.", "San Francisco"), - ("My current project is called Lighthouse and it launches in March.", "Lighthouse"), + ( + "I live in San Francisco near the Golden Gate Bridge.", + "San Francisco", + ), + ( + "My current project is called Lighthouse and it launches in March.", + "Lighthouse", + ), ]; let filler_turns = vec![ @@ -528,7 +595,7 @@ fn test_multi_turn_fact_retention() { // Build the full conversation as a single prompt // (simulates multi-turn by concatenating with turn markers) let mut conversation = String::new(); - + // Establish facts (turns 1-3) for (fact, _) in facts.iter() { conversation.push_str(&format!("User: {fact}\nAssistant: I'll remember that.\n\n")); @@ -536,7 +603,9 @@ fn test_multi_turn_fact_retention() { // Filler turns (turns 4-11) for filler in &filler_turns { - conversation.push_str(&format!("User: {filler}\nAssistant: Sure, let me explain briefly.\n\n")); + conversation.push_str(&format!( + "User: {filler}\nAssistant: Sure, let me explain briefly.\n\n" + )); } // Query turn @@ -544,19 +613,32 @@ fn test_multi_turn_fact_retention() { let mut prompt = conversation.clone(); prompt.push_str(&format!("User: {query}\nAssistant:")); - let encoding = model.tokenizer().encode(prompt.as_str(), true).expect("tokenize"); + let encoding = model + .tokenizer() + .encode(prompt.as_str(), true) + .expect("tokenize"); let token_ids: Vec = encoding.get_ids().to_vec(); let num_tokens = token_ids.len(); let result = larql_inference::predict(model.weights(), model.tokenizer(), &token_ids, 10); - let top10: String = result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join("|"); - let top1 = result.predictions.first().map(|(t, _)| t.as_str()).unwrap_or("?"); - + let top10: String = result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join("|"); + let top1 = result + .predictions + .first() + .map(|(t, _)| t.as_str()) + .unwrap_or("?"); + let found = top10.to_lowercase().contains(&expected.to_lowercase()); let mark = if found { "FOUND" } else { "MISSED" }; - println!(" Q: {query:<40} top-1='{top1}' {mark} (expected '{expected}', {num_tokens} tokens)"); + println!( + " Q: {query:<40} top-1='{top1}' {mark} (expected '{expected}', {num_tokens} tokens)" + ); println!(" top-10: [{top10}]"); } } @@ -607,9 +689,17 @@ fn test_generation_stability_50_tokens() { } let generated_text = generated_tokens.join(""); - let short_prompt = if prompt.len() > 60 { &prompt[..60] } else { prompt }; + let short_prompt = if prompt.len() > 60 { + &prompt[..60] + } else { + prompt + }; println!(" Prompt: \"{short_prompt}...\""); - println!(" Generated ({} tokens): \"{}\"", generated_tokens.len(), generated_text); + println!( + " Generated ({} tokens): \"{}\"", + generated_tokens.len(), + generated_text + ); println!(" Coherent: {}\n", !generated_text.is_empty()); } @@ -631,7 +721,10 @@ fn test_needle_position_sweep() { let target_tokens = 2048; // Context length where StdKV fails println!("\n=== Needle Position Sweep at ~{target_tokens} tokens ===\n"); - println!("{:>10} {:>8} {:>12} {:>12}", "Position", "Actual", "Full ctx", "Window"); + println!( + "{:>10} {:>8} {:>12} {:>12}", + "Position", "Actual", "Full ctx", "Window" + ); println!("{}", "-".repeat(50)); // Test needle at 10%, 25%, 50%, 75%, 90% of context @@ -652,17 +745,30 @@ fn test_needle_position_sweep() { } context.push_str(query); - let full_enc = model.tokenizer().encode(context.as_str(), true).expect("tokenize"); + let full_enc = model + .tokenizer() + .encode(context.as_str(), true) + .expect("tokenize"); let full_ids: Vec = full_enc.get_ids().to_vec(); // Full context - let full_result = larql_inference::predict(model.weights(), model.tokenizer(), &full_ids, 10); - let full_top10: String = full_result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join(" "); - let full_found = full_top10.contains("AUR") || full_top10.contains("7749") || full_top10.contains("AURORA"); + let full_result = + larql_inference::predict(model.weights(), model.tokenizer(), &full_ids, 10); + let full_top10: String = full_result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); + let full_found = full_top10.contains("AUR") + || full_top10.contains("7749") + || full_top10.contains("AURORA"); // Bounded window around needle - let needle_enc = model.tokenizer().encode(&context[..needle_char_start + needle.len()], true).expect("tok"); + let needle_enc = model + .tokenizer() + .encode(&context[..needle_char_start + needle.len()], true) + .expect("tok"); let needle_tok_pos = needle_enc.get_ids().len(); let win_start = needle_tok_pos.saturating_sub(64); let win_end = (needle_tok_pos + 20).min(full_ids.len()); @@ -671,13 +777,24 @@ fn test_needle_position_sweep() { win_ids.extend_from_slice(query_enc.get_ids()); let win_result = larql_inference::predict(model.weights(), model.tokenizer(), &win_ids, 10); - let win_top10: String = win_result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join(" "); - let win_found = win_top10.contains("AUR") || win_top10.contains("7749") || win_top10.contains("AURORA"); + let win_top10: String = win_result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); + let win_found = + win_top10.contains("AUR") || win_top10.contains("7749") || win_top10.contains("AURORA"); let full_mark = if full_found { "FOUND" } else { "MISSED" }; let win_mark = if win_found { "FOUND" } else { "MISSED" }; - println!("{:>9}% {:>8} {:>12} {:>12}", pct, full_ids.len(), full_mark, win_mark); + println!( + "{:>9}% {:>8} {:>12} {:>12}", + pct, + full_ids.len(), + full_mark, + win_mark + ); } } @@ -690,11 +807,31 @@ fn test_multifact_5_facts_at_2k() { let filler = "The quick brown fox jumps over the lazy dog near the old oak tree by the river. "; let facts = vec![ - ("Agent Alpha code name is FALCON.", "FALCON", "What is Agent Alpha's code name?"), - ("The launch date is March 15th.", "March", "What is the launch date?"), - ("Budget allocation is 4.7 million dollars.", "4.7", "What is the budget?"), - ("The target city is Reykjavik.", "Reykjavik", "What is the target city?"), - ("Project sponsor is Dr. Kimura.", "Kimura", "Who is the project sponsor?"), + ( + "Agent Alpha code name is FALCON.", + "FALCON", + "What is Agent Alpha's code name?", + ), + ( + "The launch date is March 15th.", + "March", + "What is the launch date?", + ), + ( + "Budget allocation is 4.7 million dollars.", + "4.7", + "What is the budget?", + ), + ( + "The target city is Reykjavik.", + "Reykjavik", + "What is the target city?", + ), + ( + "Project sponsor is Dr. Kimura.", + "Kimura", + "Who is the project sponsor?", + ), ]; println!("\n=== Multi-Fact Retrieval: 5 facts in ~2K context ===\n"); @@ -725,32 +862,53 @@ fn test_multifact_5_facts_at_2k() { let mut prompt = context.clone(); prompt.push_str(&format!(" {query}")); - let enc = model.tokenizer().encode(prompt.as_str(), true).expect("tok"); + let enc = model + .tokenizer() + .encode(prompt.as_str(), true) + .expect("tok"); let full_ids: Vec = enc.get_ids().to_vec(); // Full context let result = larql_inference::predict(model.weights(), model.tokenizer(), &full_ids, 10); - let top10: String = result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join(" "); + let top10: String = result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); let found_full = top10.to_lowercase().contains(&answer.to_lowercase()); - if found_full { full_found += 1; } + if found_full { + full_found += 1; + } // Window: find fact position, extract window around it let fact_pos = context.find(*fact).unwrap_or(0); - let fact_enc = model.tokenizer().encode(&context[..fact_pos + fact.len()], true).expect("tok"); + let fact_enc = model + .tokenizer() + .encode(&context[..fact_pos + fact.len()], true) + .expect("tok"); let fact_tok = fact_enc.get_ids().len(); let ws = fact_tok.saturating_sub(32); let we = (fact_tok + 20).min(full_ids.len()); let q_str = format!(" {query}"); - let query_enc = model.tokenizer().encode(q_str.as_str(), false).expect("tok"); + let query_enc = model + .tokenizer() + .encode(q_str.as_str(), false) + .expect("tok"); let mut win_ids: Vec = full_ids[ws..we].to_vec(); win_ids.extend_from_slice(query_enc.get_ids()); let win_result = larql_inference::predict(model.weights(), model.tokenizer(), &win_ids, 10); - let win_top10: String = win_result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join(" "); + let win_top10: String = win_result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); let found_win = win_top10.to_lowercase().contains(&answer.to_lowercase()); - if found_win { win_found += 1; } + if found_win { + win_found += 1; + } let fm = if found_full { "FOUND" } else { "MISSED" }; let wm = if found_win { "FOUND" } else { "MISSED" }; @@ -790,7 +948,10 @@ fn test_conflict_context_overrides_parametric() { ), ]; - println!("{:<25} {:>12} {:>12} {:>15}", "Test", "Top-1", "Context?", "Parametric?"); + println!( + "{:<25} {:>12} {:>12} {:>15}", + "Test", "Top-1", "Context?", "Parametric?" + ); println!("{}", "-".repeat(70)); for (prompt, context_answer, parametric_answer, label) in &tests { @@ -798,20 +959,123 @@ fn test_conflict_context_overrides_parametric() { let ids: Vec = enc.get_ids().to_vec(); let result = larql_inference::predict(model.weights(), model.tokenizer(), &ids, 10); - let top1 = result.predictions.first().map(|(t, _)| t.clone()).unwrap_or_default(); - let top10: String = result.predictions.iter() - .map(|(t, _)| t.as_str()).collect::>().join(" "); + let top1 = result + .predictions + .first() + .map(|(t, _)| t.clone()) + .unwrap_or_default(); + let top10: String = result + .predictions + .iter() + .map(|(t, _)| t.as_str()) + .collect::>() + .join(" "); - let follows_context = top10.to_lowercase().contains(&context_answer.to_lowercase()); - let follows_parametric = top10.to_lowercase().contains(¶metric_answer.to_lowercase()); + let follows_context = top10 + .to_lowercase() + .contains(&context_answer.to_lowercase()); + let follows_parametric = top10 + .to_lowercase() + .contains(¶metric_answer.to_lowercase()); let ctx_mark = if follows_context { "YES" } else { "no" }; let par_mark = if follows_parametric { "YES" } else { "no" }; - println!("{:<25} {:>12} {:>12} {:>15}", label, top1, ctx_mark, par_mark); + println!( + "{:<25} {:>12} {:>12} {:>15}", + label, top1, ctx_mark, par_mark + ); } println!("\nNote: Standard KV should follow context (full attention sees it)."); println!("Markov RS follows context IF in bounded window, parametric if outside."); println!("Graph Walk always follows parametric (graph is weights, not context)."); } + +/// Engine performance benchmark: times each KvEngine on a suite of prompts, +/// reports prefill ms, memory breakdown, compression ratio vs Standard KV. +/// +/// Run with: +/// cargo test --features real-model -p kv-cache-benchmark \ +/// --test test_real_model test_engine_performance -- --ignored --nocapture +#[test] +#[ignore] +fn test_engine_performance() { + let (model, _index) = load_test_model().expect("Model not available"); + let backend = larql_inference::default_backend(); + + let prompts = [ + "The capital of France is", + "The population of Tokyo is approximately", + "In the beginning God created the heavens and the", + ]; + + for prompt in &prompts { + let results = kv_cache_benchmark::real_model::runner::run_all_engines_bench( + model.weights(), + model.tokenizer(), + prompt, + 512, + backend.as_ref(), + ); + println!( + "{}", + kv_cache_benchmark::real_model::runner::format_engine_results(&results) + ); + + for r in &results { + // Accuracy: hidden cosine must be high (same forward path as Standard KV) + assert!( + r.hidden_cosine > 0.99, + "{}: cosine {:.4} < 0.99 for {:?}", + r.engine, + r.hidden_cosine, + prompt, + ); + // Memory: engine state should be smaller than Standard KV reference + assert!( + r.total_bytes < r.kv_ref_bytes, + "{}: engine mem {}B >= kv_ref {}B", + r.engine, + r.total_bytes, + r.kv_ref_bytes, + ); + } + } +} + +/// Side-by-side prefill timing: Standard KV (via run_all_strategies) vs all KvEngines. +/// Useful for measuring the cost of the residual-recompute path vs straight KV capture. +#[test] +#[ignore] +fn test_prefill_timing_comparison() { + let (model, index) = load_test_model().expect("Model not available"); + let backend = larql_inference::default_backend(); + let bench = kv_cache_benchmark::real_model::runner::RealModelBenchmark::new( + model.weights(), + model.tokenizer(), + &index, + backend.as_ref(), + ); + + let prompt = "The capital of France is"; + + let strategies = + kv_cache_benchmark::real_model::runner::run_all_strategies(&bench, prompt, 5, 512); + println!( + "{}", + kv_cache_benchmark::real_model::runner::format_results(&strategies) + ); + + let engines = kv_cache_benchmark::real_model::runner::run_all_engines_bench( + model.weights(), + model.tokenizer(), + prompt, + 512, + backend.as_ref(), + ); + println!( + "{}", + kv_cache_benchmark::real_model::runner::format_engine_results(&engines) + ); +} diff --git a/crates/kv-cache-benchmark/tests/test_shaders.rs b/crates/kv-cache-benchmark/tests/test_shaders.rs index 5f4a88f6..73db49fd 100644 --- a/crates/kv-cache-benchmark/tests/test_shaders.rs +++ b/crates/kv-cache-benchmark/tests/test_shaders.rs @@ -6,7 +6,10 @@ fn test_wht_cpu_benchmark() { assert_eq!(result.dimension, 256); assert!(result.time_us > 0.0); assert!(result.throughput_ops_per_sec > 0.0); - println!("WHT d=256: {:.2} us/op, {:.0} ops/sec", result.time_us, result.throughput_ops_per_sec); + println!( + "WHT d=256: {:.2} us/op, {:.0} ops/sec", + result.time_us, result.throughput_ops_per_sec + ); } #[test] @@ -74,5 +77,8 @@ fn test_wht_d128_faster_than_d256() { // d=128 should be faster (fewer butterfly stages) // Allow some margin for noise - println!("WHT d=128: {:.2} us, d=256: {:.2} us", r128.time_us, r256.time_us); + println!( + "WHT d=128: {:.2} us, d=256: {:.2} us", + r128.time_us, r256.time_us + ); } diff --git a/crates/kv-cache-benchmark/tests/test_standard.rs b/crates/kv-cache-benchmark/tests/test_standard.rs index fc6895fe..85f84970 100644 --- a/crates/kv-cache-benchmark/tests/test_standard.rs +++ b/crates/kv-cache-benchmark/tests/test_standard.rs @@ -1,6 +1,6 @@ -use kv_cache_benchmark::*; use kv_cache_benchmark::model_config::ModelConfig; use kv_cache_benchmark::standard_kv::StandardKv; +use kv_cache_benchmark::*; use rand::prelude::*; #[test] @@ -76,7 +76,11 @@ fn test_standard_kv_benchmark_runs() { assert_eq!(result.strategy_name, "Standard KV (FP16)"); assert_eq!(result.seq_len, 64); // MSE should be very small (FP16 quantization noise only) - assert!(result.metrics.mse < 0.001, "MSE too high: {}", result.metrics.mse); + assert!( + result.metrics.mse < 0.001, + "MSE too high: {}", + result.metrics.mse + ); // Cosine sim should be very high assert!( result.metrics.cosine_sim > 0.999, diff --git a/crates/kv-cache-benchmark/tests/test_turboquant.rs b/crates/kv-cache-benchmark/tests/test_turboquant.rs index db063240..c735130d 100644 --- a/crates/kv-cache-benchmark/tests/test_turboquant.rs +++ b/crates/kv-cache-benchmark/tests/test_turboquant.rs @@ -1,8 +1,8 @@ -use kv_cache_benchmark::*; use kv_cache_benchmark::metrics::Metrics; use kv_cache_benchmark::model_config::ModelConfig; -use kv_cache_benchmark::turboquant::TurboQuant; use kv_cache_benchmark::turboquant::rotation; +use kv_cache_benchmark::turboquant::TurboQuant; +use kv_cache_benchmark::*; use rand::prelude::*; #[test] @@ -138,7 +138,10 @@ fn test_turboquant_benchmark_runs() { let result = kv_cache_benchmark::run_strategy_benchmark(&tq, &config, 32, &mut rng); assert_eq!(result.strategy_name, "TurboQuant 4-bit"); - assert!(result.metrics.mse > 0.0, "MSE should be non-zero for lossy compression"); + assert!( + result.metrics.mse > 0.0, + "MSE should be non-zero for lossy compression" + ); assert!(result.metrics.cosine_sim > 0.9, "Cosine should be high"); assert!(result.metrics.compression_ratio > 1.0, "Should compress"); } diff --git a/crates/kv-cache-benchmark/tests/test_unlimited_context.rs b/crates/kv-cache-benchmark/tests/test_unlimited_context.rs index 80b83f18..bc4c2f1f 100644 --- a/crates/kv-cache-benchmark/tests/test_unlimited_context.rs +++ b/crates/kv-cache-benchmark/tests/test_unlimited_context.rs @@ -9,13 +9,11 @@ #![cfg(feature = "real-model")] -use kv_cache_benchmark::unlimited_context::{ - rs_extend_from_checkpoint, UnlimitedContextEngine, -}; +use kv_cache_benchmark::unlimited_context::{rs_extend_from_checkpoint, UnlimitedContextEngine}; fn load_model() -> Option { - let model_path = std::env::var("LARQL_MODEL_PATH") - .unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); + let model_path = + std::env::var("LARQL_MODEL_PATH").unwrap_or_else(|_| "google/gemma-3-4b-it".to_string()); larql_inference::InferenceModel::load(&model_path).ok() } @@ -54,9 +52,7 @@ fn test_window0_replay_bit_exact() { assert_eq!(engine.archive.len(), 1, "expected 1 archived window"); // Replay window 0 - let (replay_kv, _abs_end) = engine - .replay_window(weights, 0) - .expect("replay failed"); + let (replay_kv, _abs_end) = engine.replay_window(weights, 0).expect("replay failed"); // Independent fresh forward with empty prior let empty_prior = kv_cache_benchmark::unlimited_context::rs_extend_from_checkpoint( @@ -68,7 +64,11 @@ fn test_window0_replay_bit_exact() { .expect("fresh extend failed"); // Per-layer K cos should be 1.0 to float precision - for (li, ((k_r, v_r), (k_f, v_f))) in replay_kv.iter().zip(empty_prior.kv_cache.iter()).enumerate() { + for (li, ((k_r, v_r), (k_f, v_f))) in replay_kv + .iter() + .zip(empty_prior.kv_cache.iter()) + .enumerate() + { let ck = cos(k_r, k_f); let cv = cos(v_r, v_f); assert!(ck > 0.99999, "layer {li}: K cos {ck:.6} < 0.99999"); @@ -102,13 +102,21 @@ fn test_replay_is_deterministic() { // Replay window 1 twice let (kv_a, _) = engine.replay_window(weights, 1).expect("replay 1 failed"); - let (kv_b, _) = engine.replay_window(weights, 1).expect("replay 1 failed (2nd)"); + let (kv_b, _) = engine + .replay_window(weights, 1) + .expect("replay 1 failed (2nd)"); for (li, ((k_a, v_a), (k_b, v_b))) in kv_a.iter().zip(kv_b.iter()).enumerate() { let ck = cos(k_a, k_b); let cv = cos(v_a, v_b); - assert!(ck > 0.999999, "layer {li}: K not deterministic, cos {ck:.8}"); - assert!(cv > 0.999999, "layer {li}: V not deterministic, cos {cv:.8}"); + assert!( + ck > 0.999999, + "layer {li}: K not deterministic, cos {ck:.8}" + ); + assert!( + cv > 0.999999, + "layer {li}: V not deterministic, cos {cv:.8}" + ); } println!("replay is deterministic"); } @@ -125,7 +133,9 @@ fn test_compression_ratio() { // Build a ~256-token sequence let long = "The capital of France is Paris. ".repeat(32); - let enc = tokenizer.encode(long.as_str(), true).expect("tokenize failed"); + let enc = tokenizer + .encode(long.as_str(), true) + .expect("tokenize failed"); let tokens: Vec = enc.get_ids().to_vec(); let window_size = 64; @@ -162,12 +172,13 @@ fn test_extend_output_shapes() { let weights = model.weights(); let tokenizer = model.tokenizer(); - let enc = tokenizer.encode("Hello world.", true).expect("tokenize failed"); + let enc = tokenizer + .encode("Hello world.", true) + .expect("tokenize failed"); let tokens: Vec = enc.get_ids().to_vec(); let empty = kv_cache_benchmark::unlimited_context::__empty_prior_for_test(weights); - let out = rs_extend_from_checkpoint(weights, &tokens, &empty, 0) - .expect("extend failed"); + let out = rs_extend_from_checkpoint(weights, &tokens, &empty, 0).expect("extend failed"); assert_eq!(out.last_hidden.shape()[0], 1, "last_hidden should be 1 row"); assert_eq!(out.kv_cache.len(), weights.num_layers); diff --git a/crates/larql-cli/Cargo.toml b/crates/larql-cli/Cargo.toml index f5206582..f8bb48a6 100644 --- a/crates/larql-cli/Cargo.toml +++ b/crates/larql-cli/Cargo.toml @@ -17,7 +17,7 @@ larql-inference = { path = "../larql-inference" } larql-models = { path = "../larql-models" } larql-lql = { path = "../larql-lql" } larql-vindex = { path = "../larql-vindex" } -clap = { version = "4", features = ["derive"] } +clap = { version = "4", features = ["derive", "env"] } indicatif = "0.17" reqwest = { version = "0.12", features = ["blocking", "json"] } base64 = "0.22" diff --git a/crates/larql-cli/README.md b/crates/larql-cli/README.md index 03743a3f..5699252c 100644 --- a/crates/larql-cli/README.md +++ b/crates/larql-cli/README.md @@ -23,6 +23,20 @@ cargo run --release -p larql-cli -- repl # Serve over HTTP/gRPC cargo run --release -p larql-cli -- serve --dir output/ --port 8080 + +# Quantise an existing vindex (FP4 or GGML Q4_K_M) — see docs/specs/quantize-cli-spec.md +cargo run --release -p larql-cli -- convert quantize fp4 \ + --input output/gemma3-4b.vindex \ + --output output/gemma3-4b-fp4.vindex +cargo run --release -p larql-cli -- convert quantize q4k \ + --input output/gemma3-4b.vindex \ + --output output/gemma3-4b-q4k.vindex + +# Engine diagnostic — print which kernel paths the loader picks for a +# vindex, validate Q4_K/Q6_K strides, and (with --probe) run a real +# forward pass and print per-stage timings. +cargo run --release --features metal -p larql-cli -- diag \ + output/gemma3-4b-q4k-v2.vindex --probe --probe-tokens 50 ``` See [`docs/cli.md`](../../docs/cli.md) for the full command reference. @@ -32,6 +46,7 @@ See [`docs/cli.md`](../../docs/cli.md) for the full command reference. | Family | Commands | What they do | |---|---|---| | **Vindex lifecycle** | `extract-index`, `build`, `slice`, `publish`, `pull`, `compile`, `convert`, `verify`, `hf` | Extract, build from a Vindexfile, **carve deployment slices** (`client`/`attn`/`embed`/`server`/`browse`/`router`), **publish** (full + 5 default slice siblings + collections to HF with SHA256-skip-if-unchanged), **pull** (with sibling hints, `--preset`, `--all-slices`, `--collection`), bake patches into weights, convert GGUF↔vindex↔safetensors, checksum, low-level HF helper | +| **Diagnostics** | `bench`, `diag`, `parity`, `verify`, `stats`, `validate` | `bench` runs end-to-end decode throughput; `diag [--probe]` reports which kernel paths the loader will pick (lm_head fast/slow, attn fused/per-proj), validates Q4_K/Q6_K manifest strides against canonical 144-byte GGUF layout, and surfaces the silent-slowdown classes (stale 148-byte stride, `vocab_size=0`) at a glance | | **LQL** | `repl`, `lql`, `query`, `describe`, `filter`, `merge`, `validate`, `stats` | Interactive REPL + one-shot LQL, plus lower-level graph utilities | | **Weight-space extraction** | `weight-extract`, `attention-extract`, `vector-extract`, `index-gates`, `qk-templates`, `qk-rank`, `qk-modes`, `ov-gate`, `circuit-discover`, `fingerprint-extract` | Pull edges / templates / circuits from the model weights — zero forward passes | | **Forward-pass analysis** | `predict`, `walk`, `residuals`, `attention-capture`, `extract-routes`, `trajectory-trace`, `bfs` | Run the model and capture residuals, attention patterns, trajectories | diff --git a/crates/larql-cli/ROADMAP.md b/crates/larql-cli/ROADMAP.md new file mode 100644 index 00000000..dff26fac --- /dev/null +++ b/crates/larql-cli/ROADMAP.md @@ -0,0 +1,154 @@ +# Roadmap — larql-cli + +## Current state + +Primary verbs: `run`, `chat`, `pull`, `list`, `show`, `rm`, `link`, `serve`, `bench`. +490 tests passing across the workspace. Legacy research commands gated under +`larql dev ` for backwards-compat. Dual cache (HuggingFace hub + +`~/.cache/larql/local/`) with shorthand resolution (`larql run gemma3-4b-it-vindex`). + +--- + +## P0: Generation UX (blocks demo) + +### `larql parity` — backend parity diagnostic +**Status**: Designed 2026-04-27, not started. +**Files**: new `src/commands/diagnostics/parity.rs` and a `Subcommand::Parity` + variant in `src/main.rs`. Trace-point infrastructure lives in + `larql-inference/src/diagnostics/` (new module). + +Cross-backend numerical diff tool. Catches "I refactored quantization / +activation / norm and silently broke something" regressions that latency +benches and synthetic-weight unit tests miss. Today's specific motivation: +the CPU MoE path on Gemma 4 26B-A4B produces incoherent text while Metal +produces "Paris." (See `larql-server/ROADMAP.md` P0 F0.) + +**Shape:** +```bash +larql parity --component [--prompt "..."] [--seed N] + [--layer N] [--expert M] + [--backends cpu,metal,hf] + [--tolerance 1e-3] [--verbose] +``` + +**Components (in order of build priority):** +| Component | What it diffs | When it lands | +|---|---|---| +| `moe-expert` | Single expert forward (gate matmul, up matmul, gelu_tanh, down matmul) | v1 | +| `moe-block` | Full MoE block, one layer (router → top-K → K experts → weighted sum → post-norm) | v1 — finds today's bug | +| `attention` | Single attention block (Q/K/V proj, RoPE, softmax, O proj) | v2 | +| `dense-ffn` | Dense FFN layer (gate, up, act, down) | v2 | +| `layer` | Full transformer layer end-to-end | v2 | +| `forward` | Full forward pass; per-layer divergence trace | v3 | + +**Backends (in order of build priority):** +| Backend | Source of truth | When | +|---|---|---| +| `reference` | Slow naive triple-loop CPU; f64 accumulators; no BLAS, no padding tricks. The bedrock other backends compare against. | v1 | +| `cpu` | Production `cpu_moe_forward` / `predict_q4k` paths | v1 | +| `metal` | `gpu_moe_dispatch` / Metal `predict_q4k_metal`. Requires exposing public entry points or adding `gpu_dispatch_one_` shims. | v2 | +| `hf` | HuggingFace `transformers` reference loaded from a sidecar dump. Python script (`tools/hf_capture.py`) runs `model.forward` with intermediate captures, writes `.safetensors`; Rust harness loads and compares. | v3 | + +**Architecture:** +- Trace points at well-defined checkpoints (`post_pre_norm`, `post_router_softmax`, + `post_gate_matmul`, `post_activation`, `post_down_matmul`, `post_combine`, + `post_post_norm`). Each checkpoint emits `(name: &str, &[f32])` to a + registered `TraceSink`. +- One sink per backend. The diagnostic runs the same input through each + backend with its sink attached, then walks the merged traces and prints + the **first divergence** beyond `--tolerance` along with magnitude, index, + and surrounding context. +- Trace points are zero-overhead in release builds (gated on a `diagnostics` + feature flag in `larql-inference`). When the feature is off, sinks are no-ops + and the compiler optimises them away. + +**v1 has already been validated as a one-shot prototype** (deleted after +proving the approach): a slow naive reference matches `cpu_moe_forward` +bit-identically (max diff 4.3e-6) on layer 0, expert 0 of the 26B-A4B vindex +— so today's bug is **not** in per-expert compute. It must be in routing or +expert combination, which v1's `moe-block` component will catch. + +**Testing strategy:** +- `cargo test -p larql-cli --test test_parity_smoke`: synthetic 4-expert + MoE built from known weights; reference and CPU must agree to fp32 noise. +- `cargo run -p larql-cli -- parity --component moe-block` + in CI on a representative MoE vindex once we have one in the test fleet. + +**Open scoping decisions:** +- Output format: human-readable table by default, `--json` for CI consumption? +- Should `larql parity` accept `--from-recording ` to replay a previously + captured trace (avoids loading the model twice for repeated diffs)? Probably + yes for v3 once HF sidecar exists. +- Tolerance per-component: `forward` after 30 layers will accumulate to + ~1e-2 even for "correct" backends; need component-specific defaults. + +### Chat template — CLI side +**Status**: Not started +**Files**: `src/commands/run_cmd.rs` +Instruction-tuned models need the prompt wrapped in the model's turn format before +tokenisation. `larql chat` should always apply the template; `larql run` exposes +`--no-chat-template` to skip it on base models. The inference-side Jinja parsing +is tracked in `larql-inference/ROADMAP.md`; this item is only the flag wiring and +auto-detect logic in `run_cmd.rs`. + +### Streaming display +**Status**: Not started +**Files**: `src/commands/run_cmd.rs` +Once `generate.rs` emits an `on_token` callback (see larql-inference P0), the CLI +side is: print each token to stdout and `flush()` immediately. One-liner in the +callback closure; without it the terminal is silent for the full `--max-tokens` run. + +--- + +## P1: Usability + +### Sampling flags +**Status**: Not started +**Files**: `src/commands/run_cmd.rs` +Add `--temperature F`, `--top-p F`, `--top-k N`, `--repetition-penalty F` to +the `run` / `chat` subcommands. Values are threaded through to `generate.rs` +logit post-processing (tracked in larql-inference P0). + +### `--max-context N` +**Status**: Not started +**Files**: `src/commands/run_cmd.rs` +Expose `--max-context N` (default 8192). Thread through to `KVCache::new_per_layer` +in `generate.rs`. `larql chat` should also respect this for multi-turn state. + +### Auto-extract on `larql run hf://` +**Status**: Not started +**Files**: `src/cache/resolve_model.rs` (or equivalent resolver) +If the shorthand looks like `hf://owner/name` and no cached vindex is found, offer +to run `larql extract` inline (confirm prompt or `--yes`). Collapses the three-step +`extract → link → run` flow to one command. + +### OpenAI-compatible surface — CLI side +**Status**: Not started +**Files**: `src/commands/run_cmd.rs` +After the server-side `/v1/chat/completions` endpoint lands (larql-server P0), +expose `larql run --openai-url URL` to send prompts to any OpenAI-compatible +endpoint (including the local `larql serve` instance). Useful for round-trip +testing without a client library. + +--- + +## P2: MoE / expert routing + +### `--experts` flag +**Status**: Not started +**Files**: `src/commands/run_cmd.rs`, `src/commands/serve_cmd.rs` +`larql run --experts '0-31=http://host1,32-63=http://host2'` — MoE counterpart +to `--ffn URL`. Maps expert ID ranges to remote URLs; passed through to +`RemoteExpertBackend` in larql-inference. See also `larql-lql/ROADMAP.md` Phase 3 +for the LQL grammar surface. + +--- + +## Shipped — 2026-04-30 + +| What | Notes | +|------|-------| +| `larql parity --component layer` extended to dense models | Was MoE-only via `LARQL_DUMP_RESIDUALS`; now also handles dense by setting `LARQL_METAL_DUMP_LAYERS` and reading per-layer `metal_layer_NN_h_out.f32` / `metal_layer_NN_h_post_attn.f32`. Used to confirm Gemma 4 31B dense matches between CPU and Metal at every layer (cos ≥ 0.9999), which localised the bug to chat-template / sampling rather than the math | +| `larql parity --component lm-head` works on dense vindexes | The MoE-only gate (`is_hybrid_moe()` check) only fires for `moe-expert` / `moe-block` now; `lm-head` is backend-agnostic (Q4_K matvec vs f32 reference) and works on any vindex with an lm_head | +| Dense Metal path applies chat templates | `walk_cmd::run_predict_q4k` was sending the raw user prompt to `encode_prompt`; chat-template wrapping only happened for the `--moe-shards` / `--moe-units-manifest` paths. Both paths now go through `larql_inference::chat::render_user_prompt`. Fixes "The answer is:" looping on Gemma 4 31B dense and the "more questions instead of answers" frame on Gemma 3 | +| Auto-injected default system prompt for Gemma 4 (all variants) | Gemma 4 needs a system prompt to enter answer mode; `LARQL_NO_DEFAULT_SYSTEM=1` opts out, `LARQL_SYSTEM=` overrides | diff --git a/crates/larql-cli/docs/quantize-spec.md b/crates/larql-cli/docs/quantize-spec.md new file mode 100644 index 00000000..2ba8e051 --- /dev/null +++ b/crates/larql-cli/docs/quantize-spec.md @@ -0,0 +1,449 @@ +# `larql convert quantize` — CLI surface spec + +**Status:** FP4 + Q4K shipped (exp 26). Future formats extensible +through the same grammar. +**Scope:** CLI shape for converting a loaded vindex into a quantised +variant. Each format is a sibling subcommand under `quantize`, with +its own flag surface. FP4 and Q4K are wired today; future formats +land as additional subcommands without changing the grammar. +**Format-specific references:** +- FP4: [`fp4-format-spec.md`](fp4-format-spec.md) (byte layout), + [`fp4-precision-policy.md`](fp4-precision-policy.md) (A/B/C + policies + compliance gate). +- Q4K: GGML "Q4_K_M" mix (Q4_K gate/up + Q6_K down), Ollama- + compatible. Library entry: `larql_vindex::quant::vindex_to_q4k` + on top of `format::weights::write_model_weights_q4k_with_opts`. + +--- + +## 0. The umbrella + +`larql convert quantize ` is the family entry point: + +``` +larql convert quantize fp4 [fp4 flags] ← wired today +larql convert quantize q4k [q4k flags] ← wired today +larql convert quantize fp6 [fp6 flags] ← future +larql convert quantize ... [format-specific] +``` + +Format-specific flag sets stay isolated (FP4's `--policy` / +`--compliance-floor` / `--threshold` don't clutter Q4K's +invocation), but users have one mental model: "quantise a vindex." + +**Adding a new format is three edits:** + +1. One `QuantizeCommand::FooBar { ... }` variant in `convert_cmd.rs`. +2. One `run_quantize_foobar` fn delegating to the format's library + entry. +3. One library fn `larql_vindex::quant::vindex_to_foobar(src, dst, config)` + mirroring the shape of `vindex_to_fp4`. + +No other CLI or library code touches. Other formats' flag surfaces +are unaffected. This is the structural payoff of the nested- +subcommand grammar: the CLI grows linearly, not combinatorially. + +## 1. Why a spec before code + +The example binary (`crates/larql-vindex/examples/fp4_convert.rs`) +already did the work. Promoting it to `larql convert quantize fp4` +was mostly mechanical, but a few things needed pinning before we +wrote the clap subcommand so the output is stable across format +revisions: + +- **Flag surface** — which knobs are user-facing, which are internal, + which get deprecated later. +- **Self-policing gate** — what happens when a projection fails the + compliance floor, how it's reported, whether the run is allowed to + continue or is treated as an error. +- **Output directory layout** — what files land, what gets hard-linked + from the source, what's optional. +- **Failure modes** — what a non-success run looks like (what's + written, what's emitted to stderr, what the exit code is). +- **Diagnostics** — where the dispatch trace / describe helpers + integrate so a user can tell at a glance whether the output will + actually be FP4 end-to-end. + +Pinning these now means the first real `larql convert` run that ships +to someone outside the repo produces output whose schema is stable. + +## 2. FP4 invocation + +``` +larql convert quantize fp4 \ + --input SRC # existing vindex directory + --output DST # new vindex directory + [--policy option-a | option-b | option-c] # default: option-b + [--compliance-floor FRAC] # default: 0.99 + [--threshold RATIO] # default: 16.0 (format-derived) + [--force] # overwrite DST if present + [--strict] # fail on any compliance-floor miss + [--no-sidecar] # skip fp4_compliance.json emission + [--quiet] # suppress backend-describe output +``` + +**Defaults are the "just works for the common case" path.** Running +`larql convert quantize fp4 --input X --output Y` produces an +Option B vindex (source-dtype gate + FP4 up + FP8 down), with the Q1 +compliance scan written to `DST/fp4_compliance.json` and the one-line +backend summary printed on stdout. The defaults match the policy +spec's recommended Option B, so users who just want "the default FP4 +vindex" don't need any flags. + +**`--threshold` help text must explain the default, not leave it as a +number.** The 16.0 default is the format-derived E4M3-vs-E2M1 exponent +budget (see `FP4_FORMAT_SPEC.md` §5.1 and the DeepSeek reference). +Users who raise it are being more permissive about FP4 block +compliance; users who lower it are being stricter. Example help +text: `--threshold RATIO max/min sub-block scale ratio for the +FP4 compliance gate (default: 16.0, the E4M3/E2M1 exponent budget; +lower = stricter, higher = more permissive)`. + +## 3. FP4 behavior sketch + +``` +> larql convert quantize fp4 --input output/gemma3-4b-f16.vindex --output output/gemma3-4b-fp4.vindex + +== quantize fp4 == + in : output/gemma3-4b-f16.vindex + out : output/gemma3-4b-fp4.vindex + model : google/gemma-3-4b-it + policy : option-b (gate=source, up=FP4, down=FP8) + floor : 99.0% compliance at R<16.0 + +→ scanning reference vindex … + gate : 99.91% → keep as f32 (gate stays at source dtype; FP4 gate blocked on FP4-aware KNN path) + up : 99.93% → FP4 (meets floor) + down : 99.65% → FP8 (policy: down is always FP8 under option-b; compliance floor N/A for FP8) + +→ writing output … + gate_vectors.bin (hard-link, 3.32 GB) + up_features_fp4.bin (new, 0.44 GB) + down_features_fp8.bin (new, 0.85 GB) + fp4_compliance.json (new) + index.json (new, fp4 manifest attached) + [auxiliary files hard-linked: attn_weights.bin, down_meta.bin, embeddings.bin, …] + +── summary ── + FFN storage : 9.96 GB → 4.60 GB (2.17× compression) + Walk backend: FP4 sparse (gate=f32, up=fp4, down=fp8), gate KNN (F32 mmap) + Wall time : 12.3s + + → load output with LARQL_VINDEX_DESCRIBE=1 to verify the backend at runtime. +``` + +Compliance failures (projection targeted for FP4 falls below floor): + +``` + down : 98.42% → FP8 (policy: down is always FP8 under option-b; floor N/A for FP8) + up : 97.80% ⚠ DOWNGRADE: FP4 floor (99.0%) missed → writing as FP8 (fallback_precision from manifest) + +⚠ compliance floor missed on 1 projection; see fp4_compliance.json for details. +(Use --strict to treat this as a fatal error.) +``` + +The compliance floor is a **precision-FP4 gate**, not a per-projection +gate. It only applies where the policy says "write this projection +as FP4"; projections targeted for FP8 or F16 skip the check entirely +(FP8 doesn't use the max/min-sub-block-scale distributional +assumption, and F16 is bit-identical to source). That's why the down +line above reads "floor N/A for FP8" — it's not a bug in the log +output, it's the honest description of what the floor measures. + +Under `--strict`, the same scenario exits non-zero after writing the +compliance sidecar. Under default, the converter downgrades the +affected projection to the fallback precision from the manifest's +`compliance_gate` and continues. + +## 4. Q4K invocation + behavior + +``` +larql convert quantize q4k \ + --input SRC # existing vindex with full f32/f16 weights + --output DST # new vindex directory + [--down-q4k] # FFN down at Q4_K instead of Q6_K (Q4_K_M default keeps it at Q6_K) + [--force] # overwrite DST if present + [--quiet] # suppress backend-describe output +``` + +**The default produces an Ollama-compatible Q4_K_M mix:** attention +Q/K/O at Q4_K, attention V at Q6_K, FFN gate/up at Q4_K, FFN down at +Q6_K. `--down-q4k` switches FFN down to Q4_K uniformly — saves ~30 MB +per layer on a 31B model (~1.8 GB total) at modest precision cost +that the empirical scatter-sum averages across the intermediate +dimension (validated by `walk_correctness`, which auto-relaxes its +prob-delta gate from 0.02 to 0.035 when Q4_K down is detected). + +**Precondition:** the source vindex must have full model weights +(`extract_level: inference` or `all`). The Q4K writer reads every +attention and FFN tensor from the source and rewrites them as +quantised blocks; a browse-only vindex (no `attn_weights.bin` / +`up_weights.bin` / `down_weights.bin`) is rejected with a clear +error pointing at `--level inference`. Quantised sources (`quant != +none`) are also rejected — re-quantising an already-quantised vindex +is a no-op or worse. + +``` +> larql convert quantize q4k --input output/gemma3-4b-f16.vindex --output output/gemma3-4b-q4k.vindex + +== quantize q4k == + in : output/gemma3-4b-f16.vindex + out : output/gemma3-4b-q4k.vindex + down_q4k : false (Q6_K down (Q4_K_M mix)) + +── summary ── + FFN storage : 6.64 GB → 4.94 GB (1.35× compression) + Linked aux : 6 files (4.63 GB) + Wall time : 13.5s + Walk backend: Q4K interleaved, gate KNN (F32 mmap) + +→ output/gemma3-4b-q4k.vindex +``` + +Q4K's compression ratio is more modest than FP4's because (a) the +4-bit nibble is paired with a richer per-block scale + min layout +(GGML Q4_K is 144 B per 256-element super-block vs FP4's 137 B), and +(b) the V-projection and FFN down stay at Q6_K by default. The +tradeoff is precision: Q4K is the same format llama.cpp / Ollama +ship with and is validated against the Gemma walk-correctness gate; +FP4 is an experimental spatially-sparser layout with its own +compliance regime. + +### Output layout (Q4K) + +``` +DST/ +├── index.json # quant=q4k, has_model_weights=true +│ +│ # ── Hard-linked from SRC (zero-copy, no rewrite) ── +├── gate_vectors.bin # gate matrix (KNN still wants the dense float view) +├── embeddings.bin +├── down_meta.bin +├── feature_labels.json +├── tokenizer.json +├── README.md # if SRC carried one +│ +│ # ── Written by this run ── +├── attn_weights_q4k.bin # Q/K/O at Q4_K, V at Q6_K +├── attn_weights_q4k_manifest.json +├── interleaved_q4k.bin # gate + up at Q4_K, down at Q6_K (or Q4_K with --down-q4k) +├── interleaved_q4k_manifest.json +├── lm_head_q4.bin # output projection at Q4_K +├── norms.bin # layer + final norms (always f32) +└── weight_manifest.json +``` + +The float weight files (`attn_weights.bin`, `up_weights.bin`, +`down_weights.bin`, `interleaved.bin`, `lm_head.bin`) from the +source are **not** hard-linked — the Q4K weight files replace them. +Hard-linking the floats too would inflate the output by 6+ GB on a +4B model with no consumer for those bytes. + +### Atomic write + +Like FP4, the writer stages into `DST.tmp/` and renames on success. +Partial output never carries a valid `index.json`, so a crashed run +is unambiguously distinguishable from a complete one. + +## 5. Exit codes + +| Code | Meaning | +| ---- | ------------------------------------------------------------------ | +| 0 | Output produced; all policy-specified projections written. | +| 1 | Input vindex invalid, missing files, or unsupported geometry. | +| 2 | Compliance floor missed on ≥ 1 projection AND `--strict` was set. | +| 3 | I/O error writing output. | +| 4 | Output exists and `--force` not provided. | + +Non-success codes always leave `DST` either absent (on early failure) +or with a partial output clearly tagged by the absence of +`index.json` (written atomically at the end of the run). + +## 6. Self-policing gate integration (FP4 only) + +The Q1 scanner (`crates/larql-vindex/examples/fp4_q1_scan.rs`) +currently lives as an example. For `larql convert quantize fp4` it +is promoted to `larql_vindex::quant::scan` — a library entry the +convert subcommand calls directly, producing an in-memory +`ComplianceReport` that the converter consults before deciding the +per-projection precision. + +Scanner-as-library invariants: +- No filesystem I/O inside the scanner itself (reads come from the + `VectorIndex` accessors, which already mmap the data). +- Pure function: `scan(index, threshold) -> ComplianceReport`. +- Report is the same JSON shape the example emits, minus any CLI-only + framing. + +This makes the Q1 scanner usable anywhere — the convert subcommand +today, future `larql verify --fp4` tomorrow, regression tests next +week. One implementation, multiple consumers. + +## 7. FP4 output layout + +``` +DST/ +├── index.json # updated: fp4 manifest attached, checksums refreshed +├── fp4_compliance.json # per-projection scan + action taken +│ +│ # ── Hard-linked from SRC (zero-copy, no rewrite) ── +├── attn_weights.bin # attention +├── down_meta.bin # per-feature output token metadata +├── embeddings.bin # embed +├── feature_labels.json # labels +├── gate_vectors.bin # gate kept at source dtype (policy default) +├── norms.bin # layer norms +├── tokenizer.json +├── weight_manifest.json +│ +│ # ── Written by this run ── +├── up_features_fp4.bin # FP4 E2M1, 256-elem blocks +└── down_features_fp8.bin # FP8 E4M3, 256-elem blocks +``` + +Files are listed in the same order the converter's summary prints +them, so the stdout output can be diffed against `ls DST/` to +confirm the write. + +### Hard-link fallback + +On filesystems that don't support hard links (cross-filesystem, some +network mounts), the converter falls back to file copy and emits a +one-line notice. The output is functionally identical; size on disk +doubles for the hard-linked portion. Should be rare in practice. + +## 8. Diagnostics that ship with the subcommand + +Three observability hooks, all default-on: + +1. **Backend summary line** (already implemented via + `VectorIndex::describe_ffn_backend()`). Printed on stdout after + the write. Suppressed with `--quiet`. +2. **Compliance sidecar path** echoed in the summary. Makes it + obvious where to look when investigating a compliance miss. +3. **One-liner suggesting `LARQL_VINDEX_DESCRIBE=1`** for users who + want to double-check the backend at runtime (not just at convert + time). + +This is deliberately conservative — we're not emitting verbose trace +by default. Users running into trouble enable `LARQL_WALK_TRACE=1` at +runtime. The convert subcommand itself should be quiet by default +and only noisy on anomalies. + +## 9. Testing surface + +The existing tests mostly transfer: + +| Existing test | Covers | +| ------------------------------------------------------------ | ------ | +| `tests/test_fp4_synthetic` (7 tests) | Per-feature round-trip through a loaded FP4 vindex — the kind `larql convert` produces. | +| `tests/test_fp4_storage` (4 tests, real fixture) | End-to-end against `gemma3-4b-fp4.vindex`. Switching to `larql convert`-produced output changes nothing. | +| `format::fp4_storage::tests` (7 tests) | File-level writer/reader. The converter uses these via `write_fp4_projection` / `write_fp8_projection`. | +| `index::fp4_storage::tests` (13 tests) | Per-projection storage — same abstraction. | +| `walk_ffn::routing_tests` (3 tests) | Predicate ladder, including the Q2-regression guard. | + +New tests the CLI subcommand needs: + +1. **Smoke:** invoke the CLI with a small synthetic input vindex, + assert stdout contains the expected summary lines and that DST + has the expected filenames. +2. **Exit codes:** invoke with `--force` absent when DST exists → + exit 4. Invoke with `--strict` and a synthetic input rigged to + miss compliance → exit 2. +3. **Self-policing:** invoke with a synthetic input that has a + projection below the floor (inject a pathological block) → + verify the output manifest records the downgrade and the stored + file is the fallback precision. +4. **Round-trip parity:** convert synthetic SRC → DST, load DST, + compare row reads to SRC f32 data within the expected FP4 bound. + +Four tests, ~200 LOC total, all using the tempdir pattern already +established in `tests/test_fp4_synthetic.rs`. + +## 10. What this does NOT do (v1) + +- **Safetensors-direct FP4 extract.** Two-step (`extract` then + `quantize fp4`) remains the workflow. The reason is decoupling: + the FP4 writer should never need to know about extract-time + concerns (HuggingFace format quirks, model-specific weight + reorganisation, tied-embedding detection, PLE handling for + Gemma 4 E2B). The vindex is the stable intermediate — if FP4 + conversion is a function of a vindex, it composes cleanly with + whatever extract path produced that vindex, now and in the future. + Merging the two into a single "safetensors-to-FP4" entry point + would duplicate extract logic and couple the FP4 writer to + loader-specific surprises. +- **Mixed-precision override per-layer.** `--layers 0..12 down=fp4, + 13.. down=fp8` style is deferred. Data doesn't yet say it buys + anything; revisit after cross-model Q2. +- **In-place conversion.** No `--in-place` flag. The existing vindex + stays untouched; the FP4 copy is separate. Reversibility matters. +- **GGUF / MLX interop.** Out of scope; this operates on LARQL + vindexes only. + +## 11. Shipping checklist + +- [x] Promote `fp4_q1_scan` from example to library + (`larql_vindex::quant::scan`). Preserve the example binary as a + thin wrapper so existing scripts keep working. +- [x] Promote `fp4_convert` logic to a library fn + (`larql_vindex::quant::vindex_to_fp4`). Example binary becomes + a thin wrapper. +- [x] Add `ConvertCommand::Quantize(QuantizeCommand)` + `Fp4` and + `Q4k` variants in + `crates/larql-cli/src/commands/extraction/convert_cmd.rs` with + the flag surfaces above. +- [x] Wire `run_quantize_fp4` and `run_quantize_q4k` to the library + fns. +- [x] Add the 4 CLI-level tests listed in §9 (FP4) plus 4 lifecycle + tests for Q4K (preconditions + force/no-force + already-q4k). +- [ ] Update `docs/cli.md` and `docs/specs/vindex-format-spec.md` + §12.1 with the new subcommands and example invocations. +- [x] Smoke: run on `gemma3-4b-f16.vindex` for both FP4 and Q4K, + verify the converted vindex loads and decodes ("Paris is the + capital of" → " France …"). + +Deferred until shipping: + +- [ ] Integrate a progress callback (currently `vindex_to_q4k` / + `vindex_to_fp4` use silent callbacks; the CLI should print + per-stage timing without needing `eprintln!` spam). Reuse the + existing `larql_vindex::IndexLoadCallbacks`-style trait shape. + +## 12. v1 decisions closed + open items + +### Closed by this spec + +1. **Subcommand name: `quantize fp4`** (nested under `convert + quantize`). Replaces the earlier draft's `vindex-to-fp4` flat + subcommand. The nested shape extends to other formats without + the CLI growing a new top-level entry per format. Matches the + existing + `gguf-to-vindex` / `safetensors-to-vindex` pattern. Keep. + +2. **Atomic conversion: write to `DST.tmp/`, fsync, rename to `DST/` + on success.** Moved from "open / defer" to v1 baseline. Rationale: + partial output that *looks* complete (some files written, + `index.json` absent or stale) is a foot-gun for users scripting + against this tool. Atomic-rename is the right pattern for any + tool that produces a directory of related files, and the cost is + trivial (~20 LOC). On filesystems where `rename` would cross a + mount boundary (rare), the converter falls back to in-place write + with a warning. + +3. **Compliance sidecar: always-on by default, `--no-sidecar` + opt-out.** Sidecar is ~1 KB and removes the foot-gun of "why did + my FP4 vindex get reshaped?" Silence is a CI-only concern. + +### Still open + +1. **Should the default policy be settable globally?** e.g. via + `~/.larql/config.toml` or `LARQL_FP4_POLICY=option-a`. Not obvious + Option A will ever be the common default (Q2 ablation confirms B + as default); defer until a concrete use case emerges. + +2. **Should the Q1 scan output the full JSON sidecar even when the + scan is run standalone (not through convert)?** The example + binary already does this. Library version should expose both a + `ComplianceReport` struct (for programmatic use) and a `to_json` + helper (for CLI write). Non-blocking. diff --git a/crates/larql-cli/examples/convert_moe_to_per_layer.rs b/crates/larql-cli/examples/convert_moe_to_per_layer.rs new file mode 100644 index 00000000..6cbbdedc --- /dev/null +++ b/crates/larql-cli/examples/convert_moe_to_per_layer.rs @@ -0,0 +1,128 @@ +//! Convert an existing MoE vindex from BF16 monolithic blob (`experts_packed.bin`) +//! to per-layer Q4_K files (`layers/layer_{L:02}.weights`). +//! +//! Usage: +//! cargo run --release --example convert_moe_to_per_layer -- +//! +//! Reads `weight_manifest.json` for BF16 expert byte ranges, quantizes each +//! expert to Q4_K, writes the new binary format, then updates `index.json` +//! with `"ffn_layout": "per_layer"`. + +use std::collections::HashMap; +use std::path::Path; + +use larql_vindex::format::weights::write_layers::{ + quantize_moe_entries, write_layer_weights, LayerWeightFormat, +}; + +fn main() -> Result<(), Box> { + let args: Vec = std::env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: {} ", args[0]); + std::process::exit(1); + } + let vindex_path = Path::new(&args[1]); + + // Load and parse index.json + let index_path = vindex_path.join("index.json"); + let index_text = std::fs::read_to_string(&index_path)?; + let mut config: serde_json::Value = serde_json::from_str(&index_text)?; + + let num_layers = config["num_layers"].as_u64().ok_or("missing num_layers")? as usize; + let hidden = config["hidden_size"] + .as_u64() + .ok_or("missing hidden_size")? as usize; + + let moe_cfg = config["model_config"]["moe"] + .as_object() + .ok_or("not a MoE model (no model_config.moe)")?; + let num_experts = moe_cfg["num_experts"] + .as_u64() + .ok_or("missing num_experts")? as usize; + let moe_inter = moe_cfg["moe_intermediate_size"] + .as_u64() + .ok_or("missing moe_intermediate_size")? as usize; + + eprintln!( + "Model: {num_layers} layers, hidden={hidden}, {num_experts} experts, inter={moe_inter}" + ); + + // Parse weight_manifest.json → BF16 byte ranges + let manifest_text = std::fs::read_to_string(vindex_path.join("weight_manifest.json"))?; + let manifest: Vec = serde_json::from_str(&manifest_text)?; + + let mut bf16_ranges: HashMap = HashMap::new(); + for entry in &manifest { + if entry["kind"].as_str() != Some("packed_bf16") { + continue; + } + let key = entry["key"].as_str().unwrap_or("").to_string(); + let file = entry["file"].as_str().unwrap_or("").to_string(); + let offset = entry["offset"].as_u64().unwrap_or(0) as usize; + let length = entry["length"].as_u64().unwrap_or(0) as usize; + bf16_ranges.insert(key, (file, offset, length)); + } + + if bf16_ranges.is_empty() { + return Err("no packed_bf16 entries in weight_manifest.json — already converted?".into()); + } + + // Open source mmaps lazily + let mut open_mmaps: HashMap = HashMap::new(); + let get_bytes = |file: &str, + offset: usize, + length: usize, + mmaps: &mut HashMap| + -> Result, Box> { + if !mmaps.contains_key(file) { + let f = std::fs::File::open(vindex_path.join(file))?; + mmaps.insert(file.to_string(), unsafe { memmap2::Mmap::map(&f)? }); + } + Ok(mmaps[file][offset..offset + length].to_vec()) + }; + + // Convert each layer + let fmt = LayerWeightFormat::Q4_K; + let t_start = std::time::Instant::now(); + for layer in 0..num_layers { + let gu_key = format!("layers.{layer}.experts.gate_up_proj"); + let dn_key = format!("layers.{layer}.experts.down_proj"); + + let (gu_file, gu_off, gu_len) = bf16_ranges + .get(&gu_key) + .ok_or_else(|| format!("missing {gu_key}"))? + .clone(); + let (dn_file, dn_off, dn_len) = bf16_ranges + .get(&dn_key) + .ok_or_else(|| format!("missing {dn_key}"))? + .clone(); + + let gu_bytes = get_bytes(&gu_file, gu_off, gu_len, &mut open_mmaps)?; + let dn_bytes = get_bytes(&dn_file, dn_off, dn_len, &mut open_mmaps)?; + + let entries = + quantize_moe_entries(&gu_bytes, &dn_bytes, num_experts, moe_inter, hidden, fmt); + write_layer_weights(vindex_path, layer, fmt, &entries, moe_inter, hidden)?; + + let elapsed = t_start.elapsed().as_secs_f64(); + let rate = (layer + 1) as f64 / elapsed; + let eta = (num_layers - layer - 1) as f64 / rate; + eprintln!( + " layer {:02}/{} ({:.1}s elapsed, ETA {:.0}s)", + layer, + num_layers - 1, + elapsed, + eta + ); + } + + // Update index.json + config["ffn_layout"] = serde_json::Value::String("per_layer".into()); + std::fs::write(&index_path, serde_json::to_string_pretty(&config)?)?; + + eprintln!( + "\nDone in {:.1}s. layers/ ready. experts_packed.bin can be removed after validation.", + t_start.elapsed().as_secs_f64() + ); + Ok(()) +} diff --git a/crates/larql-cli/examples/patch_down_proj.rs b/crates/larql-cli/examples/patch_down_proj.rs index 144c21f4..afa8cd65 100644 --- a/crates/larql-cli/examples/patch_down_proj.rs +++ b/crates/larql-cli/examples/patch_down_proj.rs @@ -36,8 +36,14 @@ use serde_json::Value; fn main() -> Result<(), Box> { let mut args = std::env::args().skip(1); - let vindex_path: PathBuf = args.next().ok_or("usage: patch_down_proj ")?.into(); - let hf_root: PathBuf = args.next().ok_or("usage: patch_down_proj ")?.into(); + let vindex_path: PathBuf = args + .next() + .ok_or("usage: patch_down_proj ")? + .into(); + let hf_root: PathBuf = args + .next() + .ok_or("usage: patch_down_proj ")? + .into(); println!("vindex = {}", vindex_path.display()); println!("hf-root = {}", hf_root.display()); @@ -69,7 +75,10 @@ fn main() -> Result<(), Box> { // Cache safetensors shards so we don't re-mmap per layer. let mut shards: BTreeMap = BTreeMap::new(); - let shard_mmap = |name: &str, shards: &mut BTreeMap, hf_root: &Path| -> Result<(), Box> { + let shard_mmap = |name: &str, + shards: &mut BTreeMap, + hf_root: &Path| + -> Result<(), Box> { if !shards.contains_key(name) { let p = hf_root.join(name); let mm = unsafe { Mmap::map(&fs::File::open(&p)?)? }; @@ -90,9 +99,18 @@ fn main() -> Result<(), Box> { let gate_key = gate_e["key"].as_str().unwrap(); let up_key = up_e["key"].as_str().unwrap(); let down_key = down_e["key"].as_str().unwrap(); - assert!(gate_key.ends_with(".mlp.gate_proj.weight"), "unexpected entry[0]: {gate_key}"); - assert!(up_key.ends_with(".mlp.up_proj.weight"), "unexpected entry[1]: {up_key}"); - assert!(down_key.ends_with(".mlp.down_proj.weight"), "unexpected entry[2]: {down_key}"); + assert!( + gate_key.ends_with(".mlp.gate_proj.weight"), + "unexpected entry[0]: {gate_key}" + ); + assert!( + up_key.ends_with(".mlp.up_proj.weight"), + "unexpected entry[1]: {up_key}" + ); + assert!( + down_key.ends_with(".mlp.down_proj.weight"), + "unexpected entry[2]: {down_key}" + ); // Copy gate and up bytes unchanged. let copy_entry = |e: &Value, sink: &mut Vec| -> (u64, u64) { @@ -155,8 +173,13 @@ fn main() -> Result<(), Box> { "length": q_bytes.len(), })); if layer % 5 == 0 { - println!(" L{layer:02} down {} → {} bytes (padded {}→{})", - down_e["length"], q_bytes.len(), cols, padded_cols); + println!( + " L{layer:02} down {} → {} bytes (padded {}→{})", + down_e["length"], + q_bytes.len(), + cols, + padded_cols + ); } } diff --git a/crates/larql-cli/src/commands/dev/mod.rs b/crates/larql-cli/src/commands/dev/mod.rs new file mode 100644 index 00000000..8a70a877 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/mod.rs @@ -0,0 +1 @@ +pub mod ov_rd; diff --git a/crates/larql-cli/src/commands/dev/ov_rd/README.md b/crates/larql-cli/src/commands/dev/ov_rd/README.md new file mode 100644 index 00000000..7a370156 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/README.md @@ -0,0 +1,204 @@ +# OV/RD Dev Command + +`larql dev ov-rd` is the experimental harness for attention output-vector +rate-distortion work. It is deliberately a `dev` command, not a production +extraction command. + +The core question is whether an attention head's pre-`W_O` output can be +replaced by a compact table: + +```text +runtime state -> address -> residual-space lookup/add +``` + +For the current L0H6 line of work, the stable findings are: + +```text +oracle table exists +Mode D residual-table materialization works +held-out mean/p95 can pass +the current dominant group-0 code is not addressable from shallow state +full/reduced-QK attention-pattern clusters also fail on the hard L0H6 group +``` + +## Engine Boundary + +The main engine now owns the reusable runtime pieces that were previously +embedded in this command: + +```text +larql_inference::vindex::insert_q4k_layer_tensors +larql_inference::vindex::remove_layer_tensors +larql_inference::vindex::predict_q4k_hidden_hooked +larql_inference::vindex::predict_q4k_hidden_with_mapped_pre_o_head +larql_inference::vindex::predict_q4k_hidden_with_replaced_pre_o_head +larql_inference::vindex::predict_q4k_hidden_with_zeroed_pre_o_heads +larql_inference::vindex::predict_q4k_hidden_with_subtracted_pre_o_heads +larql_inference::vindex::predict_q4k_hidden_with_mapped_head_residual_delta +larql_inference::vindex::predict_q4k_hidden_with_replaced_head_residual_delta +larql_inference::vindex::predict_q4k_hidden_with_original_head_residual_delta +larql_inference::attention::run_attention_block_with_pre_o_and_all_attention_weights +larql_inference::attention::run_attention_block_with_pre_o_and_reduced_qk_attention_weights +``` + +Those APIs preserve the hard runtime invariants: + +```text +Q4K layer tensor scope +PLE input propagation +Gemma 4 shared-KV routing +FFN / PLE / layer-scalar tail +target-layer intervention ordering +``` + +OV/RD code should use those APIs whenever it is evaluating a full-model +intervention. Do not reimplement the full Q4K layer loop in the command unless +the command is collecting intermediate training/capture data that the engine API +does not expose yet. + +## What Belongs Here + +Keep Rust code here when it needs exact model/vindex behavior: + +- experiment-specific Q4K vindex loading and prompt orchestration +- attention `pre_W_O` capture for fitting/statistics passes +- `W_O`-visible projection and roundtrip checks +- oracle low-rank and PQ reconstruction +- direct residual-edit catalogue diagnostics +- base-PQ-plus-exception residual catalogue diagnostics +- Mode D residual-delta table materialization +- final-logit KL/top-k evaluation through the real forward path +- model-native discrete address probes whose inputs are already produced by a + real forward pass, for example previous-layer FFN top-feature IDs and + attention/relation summaries or learned attention-pattern cluster IDs +- targeted majority/stratum controls for selected PQ groups, so scale-up + diagnostics do not need full 48-group importance sweeps +- balanced Stage-0 capture subsets via `capture --max-per-stratum`, so grouped + prompt files can be sampled without creating one-off JSONL fixtures +- W_O-visible Stage-0 ranking controls, for example + `zero-ablate --stage0-rank wo-visible-variance`, so Gate 1 promotes heads by + residual-space impact rather than raw pre-W_O variance when available +- canonical JSON artifacts that other tools consume + +The command should remain an orchestrator plus faithful runtime validator. It +should not become the place where every new probe, plot, or clustering variant +lives. + +## What Should Move To Python + +Use Python over exported artifacts for fast-changing analysis: + +- code stability tables +- plotting and report tables +- window hashes, bag-of-token hashes, shingling, MinHash +- decision trees, nearest-centroid variants, and classifier sweeps +- feature/code correlation scans + +If a Python probe becomes a serious runtime candidate, reimplement only that +candidate in Rust after its artifact contract is clear. + +Small summary diagnostics that are part of the canonical JSON schema can stay in +Rust. For example, entropy/JS divergence helpers belong in `metrics.rs` if they +are emitted by `oracle-pq`, while broader exploratory scans should use Python +against exported artifacts. + +## Artifact Contract + +Rust should export enough canonical state that Python can iterate without +rerunning full model forward passes for every idea: + +```text +prompt id / stratum / tokens +layer-input residual rows +captured pre-W_O head rows +oracle PQ codes by position +baseline and replacement logits or metrics +per-prompt KL/top-k summaries +``` + +Prefer compact binary arrays plus JSON metadata for large matrices. JSON alone +is fine for summaries and small diagnostics. + +## Documentation Boundary + +Use `experiments/38_ov_rate_distortion/RESULTS.md` as the lab notebook: commands +run, artifacts written, negative results, and interpretation. + +When a result becomes architectural rather than experimental, promote it to a +short stable doc under `docs/`, for example: + +```text +docs/attention-tableability.md +``` + +The experiment log should stay detailed and chronological. The docs should be +short, curated, and claim-focused. + +## Current Refactor Direction + +This directory replaced the old single-file +`commands/extraction/ov_rd_cmd.rs`. The command is now under `dev` because these +runs are experimental probes, not stable vindex extraction verbs. + +Current split: + +```text +cmd.rs CLI dispatch only +address.rs address predictor models and address-match helpers +basis.rs W_O roundtrip basis, z-space PCA fitting, and eigensolver +capture.rs stage-0 pre-W_O capture and head statistics +input.rs prompt loading, held-out splits, and CLI string parsers +metrics.rs KL, entropy, top-k, and distribution helpers +oracle.rs roundtrip and low-rank oracle checks +edit_catalog.rs full-vector residual-edit catalogue diagnostics in hidden/PCA space +gamma_address.rs gamma-aligned supervised address probes over raw layer input, + diagonal-affine projections toward later residual states, + fixed random low-rank projections, and learned low-rank + target-residual bridges +oracle_pq.rs PQ experiment orchestration, address probe evaluation, and + direct code-level rule diagnostics +oracle_pq_address.rs + address-probe, previous-FFN feature-key, FFN-first feature-key, + attention-relation-key, full/reduced-QK attention-cluster-key, + code-substitution/coarsening controls, code-occurrence export, + oracle binary code/default upper bounds, class-collapse + behavioral quotient probes, and majority-code fitting +oracle_pq_eval.rs shared predicted-address evaluation helper +oracle_pq_fit.rs PQ codebook fitting +oracle_pq_forward.rs + PQ/Mode-D model calls plus experiment-specific capture/mapping logic +oracle_pq_mode_d.rs + Mode D residual-table materialization helpers +oracle_pq_reports.rs + PQ/address report accumulators +oracle_pq_stability.rs + PQ code distribution stability diagnostics +pq.rs PQ codebooks, Mode D tables, and k-means mechanics +pq_exception.rs base-PQ-plus-exception residual catalogue diagnostics, with + residual-error/prompt-KL/position-restore-KL/CE tail + selectors and k-means/exemplar fits +reports.rs JSON artifact schemas +runtime.rs thin shim over inference Q4K tensor insertion/removal +sanity.rs no-op/subtract/residual-delta equivalence checks +static_replace.rs static mean replacement gate and shared static fitting +stats.rs running head stats and static mean accumulators +types.rs shared input/config identifiers +zero_ablate.rs zero pre-W_O ablation gate +``` + +Remaining CLI-owned tensor-scope loops are mostly fitting/capture passes: + +```text +capture.rs stage-0 statistics +basis.rs W_O/PCA basis fitting +static_replace.rs static mean fitting pass +oracle_pq_fit.rs PQ training rows +oracle_pq_address.rs layer-input residual capture for address probes +oracle_pq_stability.rs code stability diagnostics +oracle_pq_mode_d.rs Mode D table materialization +``` + +Those may move later if they become generally useful capture APIs, but they are +not production forward paths. Do this incrementally. The first invariant is that +existing `larql dev ov-rd` commands keep their behavior and artifact schema +unless a schema change is intentional and documented in the experiment results. diff --git a/crates/larql-cli/src/commands/dev/ov_rd/address.rs b/crates/larql-cli/src/commands/dev/ov_rd/address.rs new file mode 100644 index 00000000..980641f8 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/address.rs @@ -0,0 +1,825 @@ +use std::collections::HashMap; + +use ndarray::{Array2, ArrayView1}; + +#[derive(Debug, Clone)] +pub(super) struct AddressProbeModel { + pub(super) name: String, + pub(super) group_majority: Vec, + pub(super) group_maps: Vec>, + pub(super) group_train_accuracy: Vec, + pub(super) selected_group_keys: Vec, +} + +impl AddressProbeModel { + pub(super) fn predict_codes( + &self, + token_ids: &[u32], + stratum: &str, + position: usize, + ) -> Vec { + let key = address_feature_key(&self.name, token_ids, stratum, position); + self.group_maps + .iter() + .enumerate() + .map(|(group, map)| { + map.get(&key) + .copied() + .unwrap_or_else(|| self.group_majority[group]) + }) + .collect() + } + + pub(super) fn predict_codes_from_key(&self, key: &str) -> Vec { + self.group_maps + .iter() + .enumerate() + .map(|(group, map)| { + map.get(key) + .copied() + .unwrap_or_else(|| self.group_majority[group]) + }) + .collect() + } +} + +#[derive(Debug, Clone)] +pub(super) struct AddressLshGroupModel { + pub(super) groups: Vec, + pub(super) bits: usize, + pub(super) group_majority: Vec, + pub(super) group_maps: Vec>, + pub(super) group_seeds: Vec, + pub(super) group_train_accuracy: Vec, +} + +impl AddressLshGroupModel { + pub(super) fn selected_group_keys(&self) -> Vec { + (0..self.group_majority.len()) + .map(|group| { + if self.groups.contains(&group) { + format!( + "lsh{}bits_seed{}_train_acc_{:.3}", + self.bits, self.group_seeds[group], self.group_train_accuracy[group] + ) + } else { + "majority".to_string() + } + }) + .collect() + } + + pub(super) fn predict_selected_groups( + &self, + layer_input: &Array2, + position: usize, + base_codes: &[usize], + ) -> Vec { + let mut codes = base_codes.to_vec(); + let row = layer_input.row(position); + for &group in &self.groups { + let bucket = lsh_bucket(row, self.group_seeds[group], self.bits); + codes[group] = self.group_maps[group] + .get(&bucket) + .copied() + .unwrap_or(self.group_majority[group]); + } + codes + } +} + +#[derive(Debug, Clone)] +pub(super) struct BinaryHyperplane { + pub(super) weights: Vec, + pub(super) bias: f32, +} + +impl BinaryHyperplane { + fn predict_bit(&self, row: ArrayView1<'_, f32>) -> bool { + normalized_hyperplane_logit(row, &self.weights, self.bias) >= 0.0 + } +} + +#[derive(Debug, Clone)] +pub(super) struct AddressSupervisedGroupModel { + pub(super) groups: Vec, + pub(super) bits_per_group: usize, + pub(super) epochs: usize, + pub(super) lr: f32, + pub(super) l2: f32, + pub(super) group_majority: Vec, + pub(super) group_hyperplanes: Vec>, + pub(super) group_train_accuracy: Vec, +} + +impl AddressSupervisedGroupModel { + pub(super) fn selected_group_keys(&self) -> Vec { + (0..self.group_majority.len()) + .map(|group| { + if self.groups.contains(&group) { + format!( + "supervised{}bit_train_acc_{:.3}_epochs{}_lr{:.3}_l2_{:.1e}", + self.bits_per_group, + self.group_train_accuracy[group], + self.epochs, + self.lr, + self.l2 + ) + } else { + "majority".to_string() + } + }) + .collect() + } + + pub(super) fn predict_selected_groups( + &self, + layer_input: &Array2, + position: usize, + base_codes: &[usize], + ) -> Vec { + let mut codes = base_codes.to_vec(); + let row = layer_input.row(position); + for &group in &self.groups { + let mut code = 0usize; + for (bit, hyperplane) in self.group_hyperplanes[group].iter().enumerate() { + if hyperplane.predict_bit(row) { + code |= 1usize << bit; + } + } + codes[group] = code; + } + codes + } +} + +#[derive(Debug, Clone)] +pub(super) struct AddressAttentionClusterGroupModel { + pub(super) name: String, + pub(super) groups: Vec, + pub(super) qk_rank: Option, + pub(super) centroids: Vec>, + pub(super) group_majority: Vec, + pub(super) group_maps: Vec>, + pub(super) selected_group_keys: Vec, +} + +impl AddressAttentionClusterGroupModel { + pub(super) fn predict_selected_groups( + &self, + token_ids: &[u32], + stratum: &str, + position: usize, + attention_weights: &[f32], + base_codes: &[usize], + ) -> Vec { + let features = attention_pattern_features(attention_weights, position); + let cluster = nearest_attention_cluster(&features, &self.centroids); + let key = attention_cluster_key(&self.name, token_ids, stratum, position, cluster); + let mut codes = base_codes.to_vec(); + for &group in &self.groups { + codes[group] = self.group_maps[group] + .get(&key) + .copied() + .unwrap_or(self.group_majority[group]); + } + codes + } +} + +#[derive(Debug, Clone, Copy)] +pub(super) struct AddressMatchSummary { + pub(super) groups_correct: usize, + pub(super) groups_total: usize, + pub(super) exact_address_match: bool, +} + +pub(super) fn address_probe_names() -> Vec<&'static str> { + vec![ + "position", + "stratum", + "position_stratum", + "token_id", + "prev_token_id", + "token_bigram", + "position_stratum_token", + ] +} + +pub(super) fn prev_ffn_feature_probe_names() -> Vec<&'static str> { + vec![ + "prev_ffn_top1", + "prev_ffn_top2_hash", + "prev_ffn_top4_hash", + "prev_ffn_top8_hash", + "prev_ffn_top16_hash", + "stratum_prev_ffn_top1", + "stratum_prev_ffn_top8_hash", + "token_prev_ffn_top1", + "token_prev_ffn_top8_hash", + "position_prev_ffn_top1", + "position_prev_ffn_top8_hash", + ] +} + +pub(super) fn ffn_first_feature_probe_names() -> Vec<&'static str> { + vec![ + "ffn_first_top1", + "ffn_first_top2_hash", + "ffn_first_top4_hash", + "ffn_first_top8_hash", + "ffn_first_top16_hash", + "stratum_ffn_first_top1", + "stratum_ffn_first_top8_hash", + "token_ffn_first_top1", + "token_ffn_first_top8_hash", + "position_ffn_first_top1", + "position_ffn_first_top8_hash", + ] +} + +pub(super) fn attention_relation_probe_names() -> Vec<&'static str> { + vec![ + "attn_argmax", + "attn_top2_hash", + "attn_top4_hash", + "attn_entropy_bucket", + "attn_bos_bucket", + "attn_distance_bucket", + "attn_relation_class", + "stratum_attn_relation_class", + "token_attn_relation_class", + "position_attn_relation_class", + ] +} + +pub(super) fn attention_cluster_probe_names(cluster_count: usize) -> Vec { + vec![ + format!("attn_cluster_{cluster_count}"), + format!("stratum_attn_cluster_{cluster_count}"), + format!("position_attn_cluster_{cluster_count}"), + format!("token_attn_cluster_{cluster_count}"), + ] +} + +pub(super) fn address_feature_key( + name: &str, + token_ids: &[u32], + stratum: &str, + position: usize, +) -> String { + let token = token_ids.get(position).copied().unwrap_or(0); + let prev = if position == 0 { + u32::MAX + } else { + token_ids.get(position - 1).copied().unwrap_or(0) + }; + match name { + "position" => format!("p:{position}"), + "stratum" => format!("s:{stratum}"), + "position_stratum" => format!("p:{position}|s:{stratum}"), + "token_id" => format!("t:{token}"), + "prev_token_id" => format!("pt:{prev}"), + "token_bigram" => format!("pt:{prev}|t:{token}"), + "position_stratum_token" => format!("p:{position}|s:{stratum}|t:{token}"), + _ => format!("p:{position}"), + } +} + +pub(super) fn attention_relation_key( + name: &str, + token_ids: &[u32], + stratum: &str, + position: usize, + weights: &[f32], +) -> String { + let token = token_ids.get(position).copied().unwrap_or(0); + let argmax = attention_argmax(weights, position); + let top2 = attention_topk_key(weights, position, 2); + let top4 = attention_topk_key(weights, position, 4); + let entropy = attention_entropy_bucket(weights, position); + let bos = attention_bos_bucket(weights.first().copied().unwrap_or(0.0)); + let distance = attention_distance_bucket(argmax, position); + let relation = attention_relation_class(argmax, position); + match name { + "attn_argmax" => format!("aa:{argmax}"), + "attn_top2_hash" => format!("at2:{top2}"), + "attn_top4_hash" => format!("at4:{top4}"), + "attn_entropy_bucket" => format!("ae:{entropy}"), + "attn_bos_bucket" => format!("ab:{bos}"), + "attn_distance_bucket" => format!("ad:{distance}"), + "attn_relation_class" => format!("ar:{relation}"), + "stratum_attn_relation_class" => format!("s:{stratum}|ar:{relation}"), + "token_attn_relation_class" => format!("t:{token}|ar:{relation}"), + "position_attn_relation_class" => format!("p:{position}|ar:{relation}"), + _ => format!("ar:{relation}"), + } +} + +pub(super) fn attention_cluster_key( + name: &str, + token_ids: &[u32], + stratum: &str, + position: usize, + cluster: usize, +) -> String { + let token = token_ids.get(position).copied().unwrap_or(0); + if name.contains("stratum_attn_cluster_") { + format!("s:{stratum}|ac:{cluster}") + } else if name.contains("position_attn_cluster_") { + format!("p:{position}|ac:{cluster}") + } else if name.contains("token_attn_cluster_") { + format!("t:{token}|ac:{cluster}") + } else { + format!("ac:{cluster}") + } +} + +pub(super) fn prev_ffn_feature_key( + name: &str, + token_ids: &[u32], + stratum: &str, + position: usize, + prev_features: &[usize], +) -> String { + let token = token_ids.get(position).copied().unwrap_or(0); + let top1 = prev_features + .first() + .map(|feature| feature.to_string()) + .unwrap_or_else(|| "none".to_string()); + let top2 = prev_features + .iter() + .take(2) + .map(|feature| feature.to_string()) + .collect::>() + .join(","); + let top2 = if top2.is_empty() { + "none".to_string() + } else { + top2 + }; + let top4 = feature_set_key(prev_features, 4); + let top8 = feature_set_key(prev_features, 8); + let top16 = feature_set_key(prev_features, 16); + match name { + "prev_ffn_top1" => format!("pf1:{top1}"), + "prev_ffn_top2_hash" => format!("pf2:{top2}"), + "prev_ffn_top4_hash" => format!("pf4:{top4}"), + "prev_ffn_top8_hash" => format!("pf8:{top8}"), + "prev_ffn_top16_hash" => format!("pf16:{top16}"), + "stratum_prev_ffn_top1" => format!("s:{stratum}|pf1:{top1}"), + "stratum_prev_ffn_top8_hash" => format!("s:{stratum}|pf8:{top8}"), + "token_prev_ffn_top1" => format!("t:{token}|pf1:{top1}"), + "token_prev_ffn_top8_hash" => format!("t:{token}|pf8:{top8}"), + "position_prev_ffn_top1" => format!("p:{position}|pf1:{top1}"), + "position_prev_ffn_top8_hash" => format!("p:{position}|pf8:{top8}"), + _ => format!("pf1:{top1}"), + } +} + +pub(super) fn ffn_first_feature_key( + name: &str, + token_ids: &[u32], + stratum: &str, + position: usize, + features: &[usize], +) -> String { + let token = token_ids.get(position).copied().unwrap_or(0); + let top1 = features + .first() + .map(|feature| feature.to_string()) + .unwrap_or_else(|| "none".to_string()); + let top2 = features + .iter() + .take(2) + .map(|feature| feature.to_string()) + .collect::>() + .join(","); + let top2 = if top2.is_empty() { + "none".to_string() + } else { + top2 + }; + let top4 = feature_set_key(features, 4); + let top8 = feature_set_key(features, 8); + let top16 = feature_set_key(features, 16); + match name { + "ffn_first_top1" => format!("ff1:{top1}"), + "ffn_first_top2_hash" => format!("ff2:{top2}"), + "ffn_first_top4_hash" => format!("ff4:{top4}"), + "ffn_first_top8_hash" => format!("ff8:{top8}"), + "ffn_first_top16_hash" => format!("ff16:{top16}"), + "stratum_ffn_first_top1" => format!("s:{stratum}|ff1:{top1}"), + "stratum_ffn_first_top8_hash" => format!("s:{stratum}|ff8:{top8}"), + "token_ffn_first_top1" => format!("t:{token}|ff1:{top1}"), + "token_ffn_first_top8_hash" => format!("t:{token}|ff8:{top8}"), + "position_ffn_first_top1" => format!("p:{position}|ff1:{top1}"), + "position_ffn_first_top8_hash" => format!("p:{position}|ff8:{top8}"), + _ => format!("ff1:{top1}"), + } +} + +pub(super) fn attention_argmax(weights: &[f32], position: usize) -> usize { + let causal_len = (position + 1).min(weights.len()); + weights + .iter() + .take(causal_len) + .copied() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx) + .unwrap_or(0) +} + +fn attention_topk_key(weights: &[f32], position: usize, k: usize) -> String { + let causal_len = (position + 1).min(weights.len()); + let mut indexed = weights + .iter() + .take(causal_len) + .copied() + .enumerate() + .collect::>(); + indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + let key = indexed + .into_iter() + .take(k) + .map(|(source, _)| source.to_string()) + .collect::>() + .join(","); + if key.is_empty() { + "none".to_string() + } else { + key + } +} + +pub(super) fn attention_entropy_bits(weights: &[f32], position: usize) -> f64 { + let causal_len = (position + 1).min(weights.len()); + weights + .iter() + .take(causal_len) + .copied() + .filter(|&p| p > 0.0) + .map(|p| { + let p = p as f64; + -p * p.log2() + }) + .sum::() +} + +fn attention_entropy_bucket(weights: &[f32], position: usize) -> usize { + let entropy_bits = attention_entropy_bits(weights, position); + ((entropy_bits * 2.0).floor() as usize).min(16) +} + +fn attention_bos_bucket(mass: f32) -> &'static str { + match mass { + x if x < 0.01 => "lt001", + x if x < 0.05 => "lt005", + x if x < 0.10 => "lt010", + x if x < 0.25 => "lt025", + x if x < 0.50 => "lt050", + _ => "ge050", + } +} + +fn attention_distance_bucket(argmax: usize, position: usize) -> &'static str { + if argmax == 0 { + "bos" + } else if argmax == position { + "self" + } else if argmax + 1 == position { + "prev" + } else if argmax > position { + "future" + } else { + match position - argmax { + 0 => "self", + 1 => "prev", + 2..=4 => "d2_4", + 5..=8 => "d5_8", + 9..=16 => "d9_16", + _ => "far", + } + } +} + +fn attention_relation_class(argmax: usize, position: usize) -> &'static str { + if argmax == 0 { + "bos" + } else if argmax == position { + "self" + } else if argmax + 1 == position { + "prev" + } else if argmax > position { + "future" + } else { + match position - argmax { + 0 => "self", + 1 => "prev", + 2..=4 => "local", + 5..=16 => "mid", + _ => "far", + } + } +} + +fn feature_set_key(prev_features: &[usize], k: usize) -> String { + let key = prev_features + .iter() + .take(k) + .map(|feature| feature.to_string()) + .collect::>() + .join(","); + if key.is_empty() { + "none".to_string() + } else { + key + } +} + +pub(super) fn top_feature_ids_from_activation_row( + row: ArrayView1<'_, f32>, + top_k: usize, +) -> Vec { + let mut indexed = row.iter().copied().enumerate().collect::>(); + indexed.sort_unstable_by(|a, b| { + b.1.abs() + .partial_cmp(&a.1.abs()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + indexed + .into_iter() + .take(top_k) + .map(|(feature, _)| feature) + .collect() +} + +pub(super) fn attention_pattern_features(weights: &[f32], position: usize) -> Vec { + let causal_len = (position + 1).min(weights.len()); + if causal_len == 0 { + return vec![0.0; 35]; + } + let denom = causal_len.max(1) as f64; + let argmax = attention_argmax(weights, position); + let max_mass = weights.get(argmax).copied().unwrap_or(0.0) as f64; + let entropy_bits = weights + .iter() + .take(causal_len) + .copied() + .filter(|&p| p > 0.0) + .map(|p| { + let p = p as f64; + -p * p.log2() + }) + .sum::(); + let entropy_norm = if causal_len > 1 { + entropy_bits / (causal_len as f64).log2() + } else { + 0.0 + }; + + let mut bos_mass = 0.0; + let mut self_mass = 0.0; + let mut prev_mass = 0.0; + let mut local_mass = 0.0; + let mut mid_mass = 0.0; + let mut far_mass = 0.0; + for (source, &mass) in weights.iter().take(causal_len).enumerate() { + let mass = mass as f64; + if source == 0 { + bos_mass += mass; + } + if source == position { + self_mass += mass; + } else if source + 1 == position { + prev_mass += mass; + } else if source < position { + let distance = position - source; + if distance <= 4 { + local_mass += mass; + } else if distance <= 16 { + mid_mass += mass; + } else { + far_mass += mass; + } + } + } + + let argmax_source_norm = argmax as f64 / denom; + let argmax_distance_norm = if argmax <= position { + (position - argmax) as f64 / denom + } else { + 0.0 + }; + + let mut features = vec![ + bos_mass, + self_mass, + prev_mass, + local_mass, + mid_mass, + far_mass, + entropy_bits, + entropy_norm, + max_mass, + argmax_source_norm, + argmax_distance_norm, + ]; + + let mut indexed = weights + .iter() + .take(causal_len) + .copied() + .enumerate() + .collect::>(); + indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + for rank in 0..8 { + if let Some((source, mass)) = indexed.get(rank).copied() { + let source_norm = source as f64 / denom; + let rel_distance = if source <= position { + (position - source) as f64 / denom + } else { + 0.0 + }; + features.push(mass as f64); + features.push(source_norm); + features.push(rel_distance); + } else { + features.push(0.0); + features.push(0.0); + features.push(0.0); + } + } + + features +} + +pub(super) fn nearest_attention_cluster(features: &[f64], centroids: &[Vec]) -> usize { + let mut best_idx = 0usize; + let mut best_dist = f64::INFINITY; + for (idx, centroid) in centroids.iter().enumerate() { + let dist = features + .iter() + .zip(centroid.iter()) + .map(|(&a, &b)| { + let d = a - b; + d * d + }) + .sum::(); + if dist < best_dist { + best_dist = dist; + best_idx = idx; + } + } + best_idx +} + +pub(super) fn lsh_bucket(row: ArrayView1<'_, f32>, seed: u64, bits: usize) -> usize { + let mut bucket = 0usize; + for bit in 0..bits { + let mut sum = 0.0_f64; + for (dim, &value) in row.iter().enumerate() { + let hash = splitmix64( + seed ^ ((bit as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)) + ^ ((dim as u64).wrapping_mul(0xBF58_476D_1CE4_E5B9)), + ); + let sign = if hash & 1 == 0 { -1.0 } else { 1.0 }; + sum += value as f64 * sign; + } + if sum >= 0.0 { + bucket |= 1usize << bit; + } + } + bucket +} + +pub(super) fn train_binary_hyperplane( + rows: &[&[f32]], + labels: &[bool], + dim: usize, + epochs: usize, + lr: f32, + l2: f32, +) -> BinaryHyperplane { + let mut weights = vec![0.0_f32; dim]; + let positives = labels.iter().filter(|&&label| label).count(); + let negatives = labels.len().saturating_sub(positives); + let mut bias = if positives == 0 { + -4.0 + } else if negatives == 0 { + 4.0 + } else { + ((positives as f32 + 0.5) / (negatives as f32 + 0.5)).ln() + }; + + for _ in 0..epochs { + for (row, &label) in rows.iter().zip(labels.iter()) { + let scale = normalized_row_scale_slice(row); + let dot = row + .iter() + .zip(weights.iter()) + .map(|(&x, &w)| (x / scale) * w) + .sum::(); + let logit = (bias + dot).clamp(-30.0, 30.0); + let prob = 1.0 / (1.0 + (-logit).exp()); + let target = if label { 1.0 } else { 0.0 }; + let grad = prob - target; + for (w, &x) in weights.iter_mut().zip(row.iter()) { + *w -= lr * (grad * (x / scale) + l2 * *w); + } + bias -= lr * grad; + } + } + + BinaryHyperplane { weights, bias } +} + +pub(super) fn predict_code_from_hyperplanes( + row: &[f32], + hyperplanes: &[BinaryHyperplane], +) -> usize { + let scale = normalized_row_scale_slice(row); + let mut code = 0usize; + for (bit, hyperplane) in hyperplanes.iter().enumerate() { + let dot = row + .iter() + .zip(hyperplane.weights.iter()) + .map(|(&x, &w)| (x / scale) * w) + .sum::(); + if hyperplane.bias + dot >= 0.0 { + code |= 1usize << bit; + } + } + code +} + +pub(super) fn address_match_report( + oracle_codes_by_position: &[Vec], + predicted_codes_by_position: &[Vec], +) -> AddressMatchSummary { + let mut groups_correct = 0usize; + let mut groups_total = 0usize; + let mut exact_address_match = true; + for (oracle, predicted) in oracle_codes_by_position + .iter() + .zip(predicted_codes_by_position.iter()) + { + if oracle != predicted { + exact_address_match = false; + } + for (&oracle_code, &predicted_code) in oracle.iter().zip(predicted.iter()) { + groups_total += 1; + if oracle_code == predicted_code { + groups_correct += 1; + } + } + } + AddressMatchSummary { + groups_correct, + groups_total, + exact_address_match, + } +} + +fn normalized_row_scale_slice(row: &[f32]) -> f32 { + let mean_square = if row.is_empty() { + 0.0 + } else { + row.iter() + .map(|&value| (value as f64) * (value as f64)) + .sum::() + / row.len() as f64 + }; + (mean_square.sqrt() as f32).max(1e-6) +} + +fn normalized_row_scale_view(row: ArrayView1<'_, f32>) -> f32 { + let mean_square = if row.is_empty() { + 0.0 + } else { + row.iter() + .map(|&value| (value as f64) * (value as f64)) + .sum::() + / row.len() as f64 + }; + (mean_square.sqrt() as f32).max(1e-6) +} + +fn normalized_hyperplane_logit(row: ArrayView1<'_, f32>, weights: &[f32], bias: f32) -> f32 { + let scale = normalized_row_scale_view(row); + let dot = row + .iter() + .zip(weights.iter()) + .map(|(&x, &w)| (x / scale) * w) + .sum::(); + bias + dot +} + +fn splitmix64(mut x: u64) -> u64 { + x = x.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = x; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/basis.rs b/crates/larql-cli/src/commands/dev/ov_rd/basis.rs new file mode 100644 index 00000000..42e5740e --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/basis.rs @@ -0,0 +1,441 @@ +use std::collections::HashMap; + +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::VectorIndex; +use ndarray::s; + +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PromptRecord}; + +#[derive(Debug)] +pub(super) struct WoRoundtripBasis { + pub(super) head_dim: usize, + gram: Vec>, + vectors: Vec>, + sigmas: Vec, + pub(super) sigma_max: f64, + pub(super) sigma_min_retained: f64, + pub(super) sigma_rel_cutoff: f64, +} + +impl WoRoundtripBasis { + pub(super) fn rank_retained(&self) -> usize { + self.vectors.len() + } + + pub(super) fn project(&self, y: &[f32]) -> Vec { + self.project_with_rank(y, self.vectors.len()) + } + + pub(super) fn project_with_rank(&self, y: &[f32], k: usize) -> Vec { + let mut out = vec![0.0f64; self.head_dim]; + for v in self.vectors.iter().take(k.min(self.vectors.len())) { + let coeff = v + .iter() + .zip(y.iter()) + .map(|(&vi, &yi)| vi * yi as f64) + .sum::(); + for (dst, &vi) in out.iter_mut().zip(v.iter()) { + *dst += coeff * vi; + } + } + out.into_iter().map(|value| value as f32).collect() + } + + pub(super) fn residual_to_z(&self, residual: &[f32]) -> Vec { + self.vectors + .iter() + .zip(self.sigmas.iter()) + .map(|(v, &sigma)| { + sigma + * v.iter() + .zip(residual.iter()) + .map(|(&vi, &ri)| vi * ri as f64) + .sum::() + }) + .collect() + } + + pub(super) fn z_to_residual(&self, z: &[f64]) -> Vec { + let mut residual = vec![0.0f64; self.head_dim]; + for ((v, &sigma), &zi) in self.vectors.iter().zip(self.sigmas.iter()).zip(z.iter()) { + if sigma == 0.0 { + continue; + } + let coeff = zi / sigma; + for (dst, &vi) in residual.iter_mut().zip(v.iter()) { + *dst += coeff * vi; + } + } + residual.into_iter().map(|value| value as f32).collect() + } + + pub(super) fn visible_sq_norm(&self, delta: &[f64]) -> f64 { + let mut total = 0.0; + for i in 0..self.head_dim { + let mut row = 0.0; + for j in 0..self.head_dim { + row += self.gram[i][j] * delta[j]; + } + total += delta[i] * row; + } + total.max(0.0) + } +} + +#[derive(Debug, Clone, Copy)] +pub(super) struct RoundtripPatchMetrics { + pub(super) pre_wo_l2: f64, + pub(super) wo_visible_l2: f64, +} + +#[derive(Debug)] +pub(super) struct ZPcaBasis { + pub(super) vectors: Vec>, +} + +impl ZPcaBasis { + pub(super) fn rank(&self) -> usize { + self.vectors.len() + } + + pub(super) fn coordinates_with_rank(&self, z: &[f64], k: usize) -> Vec { + self.vectors + .iter() + .take(k.min(self.vectors.len())) + .map(|v| v.iter().zip(z.iter()).map(|(&vi, &zi)| vi * zi).sum()) + .collect() + } + + pub(super) fn reconstruct_from_coordinates(&self, coords: &[f64]) -> Vec { + let dim = self.vectors.first().map(|v| v.len()).unwrap_or(0); + let mut out = vec![0.0; dim]; + for (coord, v) in coords.iter().zip(self.vectors.iter()) { + for (dst, &vi) in out.iter_mut().zip(v.iter()) { + *dst += coord * vi; + } + } + out + } + + pub(super) fn project_with_rank(&self, z: &[f64], k: usize) -> Vec { + let coords = self.coordinates_with_rank(z, k); + self.reconstruct_from_coordinates(&coords) + } +} + +pub(super) fn build_roundtrip_bases( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + heads: &[HeadId], + sigma_rel_cutoff: f64, +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + + let mut bases = HashMap::new(); + for (layer, layer_heads) in heads_by_layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + let w_o = weights + .tensors + .get(&weights.arch.attn_o_key(layer)) + .ok_or_else(|| format!("missing W_O tensor at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let start = head.head * head_dim; + let end = start + head_dim; + let w_o_head = w_o.slice(s![.., start..end]); + let basis = build_wo_roundtrip_basis(&w_o_head, sigma_rel_cutoff)?; + bases.insert(head, basis); + } + remove_layer_tensors(weights, inserted); + } + + Ok(bases) +} + +#[derive(Debug)] +struct ZPcaAccumulator { + count: u64, + sum: Vec, + sum_outer: Vec>, +} + +impl ZPcaAccumulator { + fn new(dim: usize) -> Self { + Self { + count: 0, + sum: vec![0.0; dim], + sum_outer: vec![vec![0.0; dim]; dim], + } + } + + fn add(&mut self, z: &[f64]) { + self.count += 1; + for (dst, &value) in self.sum.iter_mut().zip(z.iter()) { + *dst += value; + } + for i in 0..z.len() { + for j in i..z.len() { + self.sum_outer[i][j] += z[i] * z[j]; + } + } + } + + fn finish(mut self) -> ZPcaBasis { + let dim = self.sum.len(); + if self.count == 0 { + return ZPcaBasis { + vectors: Vec::new(), + }; + } + for i in 0..dim { + for j in 0..i { + self.sum_outer[i][j] = self.sum_outer[j][i]; + } + } + let n = self.count as f64; + let mut covariance = self.sum_outer; + for i in 0..dim { + for j in 0..dim { + covariance[i][j] = covariance[i][j] / n - (self.sum[i] / n) * (self.sum[j] / n); + } + } + let (eigenvalues, eigenvectors) = jacobi_symmetric_eigen(&covariance, 100, 1e-8); + let mut pairs: Vec<(f64, Vec)> = eigenvalues.into_iter().zip(eigenvectors).collect(); + pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + ZPcaBasis { + vectors: pairs + .into_iter() + .filter(|(value, _)| *value > 0.0) + .map(|(_, vector)| vector) + .collect(), + } + } +} + +pub(super) fn fit_z_pca_bases( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + + let mut accumulators: HashMap = HashMap::new(); + for head in heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing W_O basis for L{} H{}", head.layer, head.head))?; + accumulators.insert(*head, ZPcaAccumulator::new(basis.rank_retained())); + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" pca-fit [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let basis = bases.get(head).expect("basis pre-created for PCA fit"); + let head_means = means.get(head).expect("means pre-created for PCA fit"); + let start = head.head * head_dim; + let end = start + head_dim; + let acc = accumulators.get_mut(head).expect("PCA accumulator missing"); + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during PCA fit")?; + let base = head_means.positions.get(pos).unwrap_or(&head_means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + acc.add(&z); + } + } + } + + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + } + + Ok(accumulators + .into_iter() + .map(|(head, acc)| (head, acc.finish())) + .collect()) +} + +fn build_wo_roundtrip_basis( + w_o_head: &ndarray::ArrayBase, ndarray::Ix2>, + sigma_rel_cutoff: f64, +) -> Result> { + let hidden = w_o_head.nrows(); + let head_dim = w_o_head.ncols(); + let mut gram = vec![vec![0.0f64; head_dim]; head_dim]; + for row in 0..hidden { + for i in 0..head_dim { + let wi = w_o_head[[row, i]] as f64; + for j in i..head_dim { + gram[i][j] += wi * w_o_head[[row, j]] as f64; + } + } + } + for i in 0..head_dim { + for j in 0..i { + gram[i][j] = gram[j][i]; + } + } + + let (eigenvalues, eigenvectors) = jacobi_symmetric_eigen(&gram, 100, 1e-10); + let mut pairs: Vec<(f64, Vec)> = eigenvalues.into_iter().zip(eigenvectors).collect(); + pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + + let sigma_max = pairs + .first() + .map(|(value, _)| value.max(0.0).sqrt()) + .unwrap_or(0.0); + let cutoff = sigma_max * sigma_rel_cutoff; + let mut vectors = Vec::new(); + let mut sigmas = Vec::new(); + let mut sigma_min_retained: f64 = 0.0; + for (value, vector) in pairs { + let sigma = value.max(0.0).sqrt(); + if sigma > cutoff { + sigma_min_retained = if sigma_min_retained == 0.0 { + sigma + } else { + sigma_min_retained.min(sigma) + }; + sigmas.push(sigma); + vectors.push(vector); + } + } + if vectors.is_empty() && sigma_max > 0.0 { + return Err("W_O roundtrip retained zero singular directions".into()); + } + + Ok(WoRoundtripBasis { + head_dim, + gram, + vectors, + sigmas, + sigma_max, + sigma_min_retained, + sigma_rel_cutoff, + }) +} + +pub(super) fn jacobi_symmetric_eigen( + input: &[Vec], + max_sweeps: usize, + tolerance: f64, +) -> (Vec, Vec>) { + let n = input.len(); + let mut a = input.to_vec(); + let mut v = vec![vec![0.0f64; n]; n]; + for i in 0..n { + v[i][i] = 1.0; + } + + for _ in 0..max_sweeps { + let mut max_value = 0.0; + let mut p = 0; + let mut q = 1.min(n.saturating_sub(1)); + for i in 0..n { + for j in (i + 1)..n { + let value = a[i][j].abs(); + if value > max_value { + max_value = value; + p = i; + q = j; + } + } + } + if max_value < tolerance || n < 2 { + break; + } + + let app = a[p][p]; + let aqq = a[q][q]; + let apq = a[p][q]; + if apq == 0.0 { + continue; + } + let tau = (aqq - app) / (2.0 * apq); + let t = if tau >= 0.0 { + 1.0 / (tau + (1.0 + tau * tau).sqrt()) + } else { + -1.0 / (-tau + (1.0 + tau * tau).sqrt()) + }; + let c = 1.0 / (1.0 + t * t).sqrt(); + let s = t * c; + + for k in 0..n { + if k != p && k != q { + let akp = a[k][p]; + let akq = a[k][q]; + let new_kp = c * akp - s * akq; + let new_kq = s * akp + c * akq; + a[k][p] = new_kp; + a[p][k] = new_kp; + a[k][q] = new_kq; + a[q][k] = new_kq; + } + } + a[p][p] = c * c * app - 2.0 * s * c * apq + s * s * aqq; + a[q][q] = s * s * app + 2.0 * s * c * apq + c * c * aqq; + a[p][q] = 0.0; + a[q][p] = 0.0; + + for row in &mut v { + let vip = row[p]; + let viq = row[q]; + row[p] = c * vip - s * viq; + row[q] = s * vip + c * viq; + } + } + + let eigenvalues = (0..n).map(|i| a[i][i]).collect::>(); + let eigenvectors = (0..n) + .map(|col| (0..n).map(|row| v[row][col]).collect::>()) + .collect::>(); + (eigenvalues, eigenvectors) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/capture.rs b/crates/larql-cli/src/commands/dev/ov_rd/capture.rs new file mode 100644 index 00000000..d509476d --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/capture.rs @@ -0,0 +1,261 @@ +use std::path::PathBuf; +use std::time::Instant; + +use clap::Args; +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{dot_proj, embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use ndarray::{s, Array2}; + +use super::input::{limit_prompts_per_stratum, load_prompts, parse_layer_spec}; +use super::reports::{CaptureReport, HeadReport}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::RunningHeadStats; + +#[derive(Args)] +pub(super) struct CaptureArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Layers to capture. Comma-separated or range. Default: all. + #[arg(long)] + layers: Option, + + /// Limit prompts for smoke runs. + #[arg(long)] + max_prompts: Option, + + /// Limit prompts per stratum after loading the prompt file. + #[arg(long)] + max_per_stratum: Option, + + /// Limit token positions per prompt for smoke runs. + #[arg(long)] + max_positions: Option, + + /// Also compute W_O-visible residual-contribution statistics. + /// + /// This is slower than raw pre-W_O capture because it projects each head + /// through its W_O block, but it gives the ranking the downstream residual + /// actually sees. + #[arg(long)] + wo_visible: bool, +} + +pub(super) fn run_capture(args: CaptureArgs) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let layers: Vec = match &args.layers { + Some(spec) => parse_layer_spec(spec)?, + None => (0..weights.num_layers).collect(), + }; + let capture_layer = |layer: usize| layers.contains(&layer); + + let mut prompts = load_prompts(&args.prompts, args.max_prompts)?; + if let Some(max_per_stratum) = args.max_per_stratum { + prompts = limit_prompts_per_stratum(prompts, max_per_stratum); + } + eprintln!("Prompts: {}", prompts.len()); + eprintln!("Layers: {:?}", layers); + + let mut stats: Vec> = (0..weights.num_layers) + .map(|layer| { + let heads = weights.arch.num_q_heads_for_layer(layer); + let head_dim = weights.arch.head_dim_for_layer(layer); + (0..heads) + .map(|_| RunningHeadStats::new(head_dim)) + .collect() + }) + .collect(); + let mut wo_visible_stats: Vec>> = (0..weights.num_layers) + .map(|layer| { + let heads = weights.arch.num_q_heads_for_layer(layer); + (0..heads) + .map(|_| { + if args.wo_visible { + Some(RunningHeadStats::new(weights.hidden_size)) + } else { + None + } + }) + .collect() + }) + .collect(); + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + + let mut h = embed_tokens_pub(&weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(&weights, &h, &token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(&mut weights, &index, layer)?; + + if capture_layer(layer) { + let (_, pre_o) = run_attention_block_with_pre_o(&weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + add_pre_o_stats( + &mut stats[layer], + &pre_o, + weights.arch.num_q_heads_for_layer(layer), + weights.arch.head_dim_for_layer(layer), + args.max_positions, + ); + if args.wo_visible { + let w_o = weights + .tensors + .get(&weights.arch.attn_o_key(layer)) + .ok_or_else(|| format!("missing W_O tensor at layer {layer}"))?; + add_pre_o_wo_visible_stats( + &mut wo_visible_stats[layer], + &pre_o, + w_o, + weights.arch.num_q_heads_for_layer(layer), + weights.arch.head_dim_for_layer(layer), + args.max_positions, + ); + } + } + + { + let ffn = WeightFfn { weights: &weights }; + if let Some((h_new, _, _)) = run_layer_with_ffn( + &weights, + &h, + layer, + &ffn, + false, + ple_inputs.get(layer), + None, + ) { + h = h_new; + } + } + + remove_layer_tensors(&mut weights, inserted); + } + } + + let mut heads = Vec::new(); + for &layer in &layers { + let head_dim = weights.arch.head_dim_for_layer(layer); + for (head, stat) in stats[layer].iter().enumerate() { + heads.push(HeadReport { + layer, + head, + head_dim, + stats: stat.finish(), + wo_visible_stats: wo_visible_stats[layer][head] + .as_ref() + .map(RunningHeadStats::finish), + }); + } + } + + let report = CaptureReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen: prompts.len(), + layers, + max_positions: args.max_positions, + wo_visible: args.wo_visible, + heads, + }; + + let out_path = args.out.join("stage0_pre_o_stats.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +fn add_pre_o_stats( + stats: &mut [RunningHeadStats], + pre_o: &Array2, + num_heads: usize, + head_dim: usize, + max_positions: Option, +) { + let positions = max_positions + .map(|n| n.min(pre_o.nrows())) + .unwrap_or_else(|| pre_o.nrows()); + for pos in 0..positions { + for head in 0..num_heads { + let start = head * head_dim; + let end = start + head_dim; + let row = pre_o.slice(s![pos, start..end]); + if let Some(values) = row.as_slice() { + stats[head].add(values); + } + } + } +} + +fn add_pre_o_wo_visible_stats( + stats: &mut [Option], + pre_o: &Array2, + w_o: &ndarray::ArrayBase, ndarray::Ix2>, + num_heads: usize, + head_dim: usize, + max_positions: Option, +) { + let positions = max_positions + .map(|n| n.min(pre_o.nrows())) + .unwrap_or_else(|| pre_o.nrows()); + for head in 0..num_heads { + let Some(head_stats) = stats.get_mut(head).and_then(Option::as_mut) else { + continue; + }; + let start = head * head_dim; + let end = start + head_dim; + let head_out = pre_o.slice(s![0..positions, start..end]); + let w_o_head = w_o.slice(s![.., start..end]); + let contribution = dot_proj(&head_out, &w_o_head); + for row in contribution.rows() { + if let Some(values) = row.as_slice() { + head_stats.add(values); + } + } + } +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/cmd.rs b/crates/larql-cli/src/commands/dev/ov_rd/cmd.rs new file mode 100644 index 00000000..56031dfa --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/cmd.rs @@ -0,0 +1,62 @@ +use clap::{Args, Subcommand}; + +use super::capture::{run_capture, CaptureArgs}; +use super::edit_catalog::{run_oracle_edit_catalog, OracleEditCatalogArgs}; +use super::oracle::{ + run_oracle_lowrank, run_oracle_roundtrip, OracleLowrankArgs, OracleRoundtripArgs, +}; +use super::oracle_pq::{run_oracle_pq, OraclePqArgs}; +use super::pq_exception::{run_oracle_pq_exception, OraclePqExceptionArgs}; +use super::sanity::{run_sanity_check, SanityCheckArgs}; +use super::static_replace::{run_static_replace, StaticReplaceArgs}; +use super::zero_ablate::{run_zero_ablate, ZeroAblateArgs}; + +#[derive(Args)] +pub struct OvRdArgs { + #[command(subcommand)] + command: OvRdCommand, +} + +#[derive(Subcommand)] +enum OvRdCommand { + /// Capture pre-W_O OV output statistics from a Q4K vindex. + Capture(CaptureArgs), + + /// Gate 1: zero selected pre-W_O heads and measure final-logit KL. + ZeroAblate(ZeroAblateArgs), + + /// Static replacement gate: zero/global/position/stratum pre-W_O means. + StaticReplace(StaticReplaceArgs), + + /// Sanity checks for pre-W_O replacement and W_O block equivalence. + SanityCheck(SanityCheckArgs), + + /// Oracle RD plumbing check: W_O-coordinate roundtrip with no truncation. + OracleRoundtrip(OracleRoundtripArgs), + + /// Oracle RD: unquantized low-rank sweep in W_O-visible coordinates. + OracleLowrank(OracleLowrankArgs), + + /// Oracle RD: oracle-addressed product quantization in PCA coordinates. + OraclePq(OraclePqArgs), + + /// Oracle RD: full residual-edit catalogues in hidden/PCA spaces. + OracleEditCatalog(OracleEditCatalogArgs), + + /// Oracle RD: base PQ table plus oracle-addressed exception residuals. + OraclePqException(OraclePqExceptionArgs), +} + +pub fn run(args: OvRdArgs) -> Result<(), Box> { + match args.command { + OvRdCommand::Capture(capture) => run_capture(capture), + OvRdCommand::ZeroAblate(zero) => run_zero_ablate(zero), + OvRdCommand::StaticReplace(static_replace) => run_static_replace(static_replace), + OvRdCommand::SanityCheck(sanity) => run_sanity_check(sanity), + OvRdCommand::OracleRoundtrip(roundtrip) => run_oracle_roundtrip(roundtrip), + OvRdCommand::OracleLowrank(lowrank) => run_oracle_lowrank(lowrank), + OvRdCommand::OraclePq(pq) => run_oracle_pq(pq), + OvRdCommand::OracleEditCatalog(edit_catalog) => run_oracle_edit_catalog(edit_catalog), + OvRdCommand::OraclePqException(exception) => run_oracle_pq_exception(exception), + } +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/edit_catalog.rs b/crates/larql-cli/src/commands/dev/ov_rd/edit_catalog.rs new file mode 100644 index 00000000..84270e59 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/edit_catalog.rs @@ -0,0 +1,838 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Instant; + +use clap::Args; +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use ndarray::{s, Array2}; + +use super::basis::{build_roundtrip_bases, fit_z_pca_bases, WoRoundtripBasis, ZPcaBasis}; +use super::input::{ + limit_prompts_per_stratum, load_prompts, parse_head_spec, parse_usize_list, + split_prompt_records, +}; +use super::metrics::{ + argmax, bool_rate, kl_logp, log_softmax, mean, percentile, token_prob, top_k_indices, +}; +use super::oracle_pq_forward::final_logits; +use super::pq::{kmeans_centroids, nearest_centroid_index}; +use super::reports::{ + OracleEditCatalogHeadReport, OracleEditCatalogPointReport, OracleEditCatalogPromptReport, + OracleEditCatalogReport, +}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::static_replace::fit_static_means; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PromptRecord}; + +#[derive(Args)] +pub(super) struct OracleEditCatalogArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 20:6. + #[arg(long)] + heads: String, + + /// Comma-separated full-edit catalogue sizes. + #[arg(long, default_value = "32,64,128,256")] + edit_counts: String, + + /// Comma-separated catalogue spaces: hidden,pca. + #[arg(long, default_value = "hidden,pca")] + spaces: String, + + /// PCA coordinate rank used by the pca catalogue space. + #[arg(long, default_value_t = 192)] + pca_rank: usize, + + /// Relative singular value cutoff for retained W_O-visible directions. + #[arg(long, default_value_t = 1e-6)] + sigma_rel_cutoff: f64, + + /// Lloyd iterations per full-edit catalogue. + #[arg(long, default_value_t = 25)] + kmeans_iters: usize, + + /// Limit prompts for bounded oracle runs. + #[arg(long)] + max_prompts: Option, + + /// Keep at most N prompts per stratum after loading. + #[arg(long)] + max_per_stratum: Option, + + /// Evaluate only prompts where prompt_index % eval_mod == eval_offset. + /// The remaining prompts are used to fit static means, PCA, and catalogues. + #[arg(long)] + eval_mod: Option, + + /// Held-out modulo offset used with --eval-mod. + #[arg(long, default_value_t = 0)] + eval_offset: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum EditCatalogSpace { + Hidden, + Pca, +} + +impl EditCatalogSpace { + fn parse(name: &str) -> Result> { + match name.trim() { + "hidden" => Ok(Self::Hidden), + "pca" => Ok(Self::Pca), + other => { + Err(format!("invalid edit-catalog space '{other}', expected hidden or pca").into()) + } + } + } + + fn as_str(self) -> &'static str { + match self { + Self::Hidden => "hidden", + Self::Pca => "pca", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct EditCatalogKey { + head: HeadId, + space: EditCatalogSpace, + edits: usize, +} + +#[derive(Debug, Clone)] +struct EditCatalog { + space: EditCatalogSpace, + feature_centroids: Vec>, + residual_table: Vec>, +} + +pub(super) fn run_oracle_edit_catalog( + args: OracleEditCatalogArgs, +) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd oracle-edit-catalog currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = parse_head_spec(&args.heads)?; + if selected_heads.is_empty() { + return Err("no heads selected for oracle edit catalogue".into()); + } + let mut edit_counts = parse_usize_list(&args.edit_counts)?; + edit_counts.sort_unstable(); + edit_counts.dedup(); + if edit_counts.is_empty() { + return Err("no edit counts selected".into()); + } + if edit_counts.iter().any(|&edits| edits == 0) { + return Err("--edit-counts values must be greater than zero".into()); + } + let mut spaces = parse_string_list(&args.spaces) + .into_iter() + .map(|space| EditCatalogSpace::parse(&space)) + .collect::, _>>()?; + spaces.sort_by_key(|space| space.as_str()); + spaces.dedup(); + if spaces.is_empty() { + return Err("no edit-catalog spaces selected".into()); + } + + let mut prompts = load_prompts(&args.prompts, args.max_prompts)?; + if let Some(max_per_stratum) = args.max_per_stratum { + prompts = limit_prompts_per_stratum(prompts, max_per_stratum); + } + let prompts_seen = prompts.len(); + let (fit_prompts, eval_prompts) = if let Some(eval_mod) = args.eval_mod { + split_prompt_records(&prompts, eval_mod, args.eval_offset)? + } else { + (prompts.clone(), prompts) + }; + + eprintln!("Selected heads: {:?}", selected_heads); + eprintln!("Edit counts: {:?}", edit_counts); + eprintln!( + "Edit spaces: {:?}", + spaces + .iter() + .map(|space| space.as_str()) + .collect::>() + ); + eprintln!("Prompts: {}", prompts_seen); + + eprintln!("Fitting position-mean static bases"); + let means = fit_static_means( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + )?; + + eprintln!("Building W_O-visible bases"); + let bases = + build_roundtrip_bases(&mut weights, &index, &selected_heads, args.sigma_rel_cutoff)?; + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + eprintln!( + " L{}H{} rank={} sigma_max={:.6} sigma_min_retained={:.6}", + head.layer, + head.head, + basis.rank_retained(), + basis.sigma_max, + basis.sigma_min_retained + ); + } + + eprintln!("Fitting empirical z-space PCA bases"); + let pca_bases = fit_z_pca_bases( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + )?; + + eprintln!("Fitting full-edit catalogues"); + let catalogs = fit_edit_catalogs( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &spaces, + &edit_counts, + args.pca_rank, + args.kmeans_iters, + )?; + + let hidden_tables = build_static_hidden_tables(&mut weights, &index, &selected_heads, &means)?; + let w_o_heads = copy_w_o_heads(&mut weights, &index, &selected_heads)?; + + let mut accumulators: HashMap = HashMap::new(); + for head in &selected_heads { + for &space in &spaces { + for &edits in &edit_counts { + accumulators.insert( + EditCatalogKey { + head: *head, + space, + edits, + }, + EditCatalogAccumulator::new(), + ); + } + } + } + + for (prompt_idx, record) in eval_prompts.iter().enumerate() { + let label = prompt_label(record); + eprintln!(" [{}/{}] {}", prompt_idx + 1, eval_prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + let baseline_top1 = argmax(&baseline_logits); + let baseline_top2 = top_k_indices(&baseline_logits, 2); + let baseline_top2_token = baseline_top2.get(1).copied().unwrap_or(baseline_top1); + let baseline_top1_prob = token_prob(&baseline_logp, baseline_top1); + let baseline_top2_prob = token_prob(&baseline_logp, baseline_top2_token); + let baseline_top1_margin = baseline_top1_prob - baseline_top2_prob; + + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{}H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{}H{}", head.layer, head.head))?; + let head_means = means + .get(head) + .ok_or_else(|| format!("missing means for L{}H{}", head.layer, head.head))?; + let static_hidden = hidden_tables.get(head).ok_or_else(|| { + format!( + "missing static hidden table for L{}H{}", + head.layer, head.head + ) + })?; + let w_o_head = w_o_heads + .get(head) + .ok_or_else(|| format!("missing W_O head for L{}H{}", head.layer, head.head))?; + + for &space in &spaces { + for &edits in &edit_counts { + let key = EditCatalogKey { + head: *head, + space, + edits, + }; + let catalog = catalogs.get(&key).ok_or_else(|| { + format!( + "missing edit catalog for L{}H{} {} {edits}", + head.layer, + head.head, + space.as_str() + ) + })?; + let catalog_hidden = forward_q4k_oracle_edit_catalog_head( + &mut weights, + &token_ids, + &index, + *head, + basis, + pca_basis, + head_means, + static_hidden, + w_o_head, + catalog, + args.pca_rank, + )?; + let catalog_logits = final_logits(&weights, &catalog_hidden); + let catalog_logp = log_softmax(&catalog_logits); + let kl = kl_logp(&baseline_logp, &catalog_logp); + let catalog_top1 = argmax(&catalog_logits); + let catalog_top5 = top_k_indices(&catalog_logits, 5); + let catalog_top2 = top_k_indices(&catalog_logits, 2); + let catalog_top2_token = catalog_top2.get(1).copied().unwrap_or(catalog_top1); + let catalog_top1_prob = token_prob(&catalog_logp, catalog_top1); + let catalog_top2_prob = token_prob(&catalog_logp, catalog_top2_token); + let catalog_top1_margin = catalog_top1_prob - catalog_top2_prob; + let catalog_prob_of_baseline_top1 = token_prob(&catalog_logp, baseline_top1); + accumulators + .get_mut(&key) + .expect("edit-catalog accumulator missing") + .add(OracleEditCatalogPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl, + delta_cross_entropy_bits: kl / std::f64::consts::LN_2, + baseline_top1, + catalog_top1, + top1_agree: baseline_top1 == catalog_top1, + baseline_top1_in_catalog_top5: catalog_top5.contains(&baseline_top1), + baseline_top1_prob, + baseline_top2: baseline_top2_token, + baseline_top2_prob, + baseline_top1_margin, + catalog_top1_prob, + catalog_prob_of_baseline_top1, + catalog_top1_margin, + }); + } + } + } + } + + let mut head_reports = Vec::new(); + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{} H{}", head.layer, head.head))?; + let mut points = Vec::new(); + for &space in &spaces { + for &edits in &edit_counts { + let key = EditCatalogKey { + head: *head, + space, + edits, + }; + let acc = accumulators + .remove(&key) + .expect("edit-catalog accumulator missing at finish"); + points.push(acc.finish(space, edits, weights.hidden_size)); + } + } + let static_train_samples = means.get(head).map(|m| m.count).unwrap_or(0); + head_reports.push(OracleEditCatalogHeadReport { + layer: head.layer, + head: head.head, + head_dim: basis.head_dim, + rank_retained: basis.rank_retained(), + empirical_rank: pca_basis.rank(), + sigma_max: basis.sigma_max, + sigma_min_retained: basis.sigma_min_retained, + static_train_samples, + points, + }); + } + + let report = OracleEditCatalogReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen, + train_prompts_seen: fit_prompts.len(), + eval_prompts_seen: eval_prompts.len(), + max_per_stratum: args.max_per_stratum, + eval_mod: args.eval_mod, + eval_offset: args.eval_offset, + static_base: "position_mean".to_string(), + spaces: spaces + .iter() + .map(|space| space.as_str().to_string()) + .collect(), + edit_counts, + pca_rank: args.pca_rank, + sigma_rel_cutoff: args.sigma_rel_cutoff, + kmeans_iters: args.kmeans_iters, + selected_heads, + heads: head_reports, + }; + + let out_path = args.out.join("oracle_edit_catalog.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +fn fit_edit_catalogs( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + spaces: &[EditCatalogSpace], + edit_counts: &[usize], + pca_rank: usize, + iterations: usize, +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + let w_o_heads = copy_w_o_heads(weights, index, heads)?; + + let mut samples: HashMap<(HeadId, EditCatalogSpace), Vec>> = HashMap::new(); + for head in heads { + for &space in spaces { + samples.insert((*head, space), Vec::new()); + } + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = prompt_label(record); + eprintln!( + " catalog-fit [{}/{}] {}", + prompt_idx + 1, + prompts.len(), + label + ); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let basis = bases.get(head).expect("basis pre-created for edit catalog"); + let head_means = means.get(head).expect("means pre-created for edit catalog"); + let pca_basis = pca_bases + .get(head) + .expect("PCA pre-created for edit catalog"); + if pca_basis.rank() < pca_rank && spaces.contains(&EditCatalogSpace::Pca) { + return Err(format!( + "PCA rank {} is below requested rank {} for L{}H{}", + pca_basis.rank(), + pca_rank, + head.layer, + head.head + ) + .into()); + } + let w_o_head = w_o_heads + .get(head) + .expect("W_O head pre-copied for edit catalog"); + let start = head.head * head_dim; + let end = start + head_dim; + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during edit catalog fit")?; + let residual = head_residual(values, head_means, pos); + for &space in spaces { + let sample = match space { + EditCatalogSpace::Hidden => { + project_head_vector_to_hidden(w_o_head, &residual) + .into_iter() + .map(|value| value as f64) + .collect::>() + } + EditCatalogSpace::Pca => { + let z = basis.residual_to_z(&residual); + pca_basis.coordinates_with_rank(&z, pca_rank) + } + }; + samples + .get_mut(&(*head, space)) + .expect("edit samples missing") + .push(sample); + } + } + } + } + + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + } + + let mut catalogs = HashMap::new(); + for head in heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{}H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{}H{}", head.layer, head.head))?; + let w_o_head = w_o_heads + .get(head) + .ok_or_else(|| format!("missing W_O head for L{}H{}", head.layer, head.head))?; + for &space in spaces { + let head_samples = samples + .get(&(*head, space)) + .ok_or_else(|| format!("missing edit samples for L{}H{}", head.layer, head.head))?; + for &edits in edit_counts { + let feature_centroids = kmeans_centroids(head_samples, edits, iterations); + let residual_table = match space { + EditCatalogSpace::Hidden => feature_centroids + .iter() + .map(|centroid| centroid.iter().map(|&value| value as f32).collect()) + .collect(), + EditCatalogSpace::Pca => feature_centroids + .iter() + .map(|centroid| { + let z = pca_basis.reconstruct_from_coordinates(centroid); + let residual = basis.z_to_residual(&z); + project_head_vector_to_hidden(w_o_head, &residual) + }) + .collect(), + }; + catalogs.insert( + EditCatalogKey { + head: *head, + space, + edits, + }, + EditCatalog { + space, + feature_centroids, + residual_table, + }, + ); + } + } + } + + Ok(catalogs) +} + +fn forward_q4k_oracle_edit_catalog_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + static_hidden: &StaticHiddenTable, + w_o_head: &[Vec], + catalog: &EditCatalog, + pca_rank: usize, +) -> Result, Box> { + let hidden_size = weights.hidden_size; + larql_inference::vindex::predict_q4k_hidden_with_mapped_head_residual_delta( + weights, + token_ids, + index, + head.layer, + head.head, + |original_head| { + let mut replacement_delta = Vec::with_capacity(original_head.nrows() * hidden_size); + for pos in 0..original_head.nrows() { + let row = original_head.row(pos); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during edit catalog eval")?; + let residual = head_residual(values, means, pos); + let feature = match catalog.space { + EditCatalogSpace::Hidden => project_head_vector_to_hidden(w_o_head, &residual) + .into_iter() + .map(|value| value as f64) + .collect::>(), + EditCatalogSpace::Pca => { + let z = basis.residual_to_z(&residual); + pca_basis.coordinates_with_rank(&z, pca_rank) + } + }; + let code = nearest_centroid_index(&feature, &catalog.feature_centroids); + let static_delta = static_hidden.delta_for_position(pos); + let edit_delta = &catalog.residual_table[code]; + for (&base, &edit) in static_delta.iter().zip(edit_delta.iter()) { + replacement_delta.push(base + edit); + } + } + Array2::from_shape_vec((original_head.nrows(), hidden_size), replacement_delta) + .map_err(|err| err.to_string()) + }, + ) + .map_err(Into::into) +} + +#[derive(Debug, Clone)] +struct StaticHiddenTable { + by_position: Vec>, + global: Vec, +} + +impl StaticHiddenTable { + fn delta_for_position(&self, position: usize) -> &[f32] { + self.by_position + .get(position) + .map(|delta| delta.as_slice()) + .unwrap_or(&self.global) + } +} + +fn build_static_hidden_tables( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + heads: &[HeadId], + means: &HashMap, +) -> Result, Box> { + let w_o_heads = copy_w_o_heads(weights, index, heads)?; + let mut tables = HashMap::new(); + for head in heads { + let w_o_head = w_o_heads + .get(head) + .ok_or_else(|| format!("missing W_O head for L{}H{}", head.layer, head.head))?; + let head_means = means + .get(head) + .ok_or_else(|| format!("missing means for L{}H{}", head.layer, head.head))?; + let global = project_head_vector_to_hidden(w_o_head, &head_means.global); + let by_position = head_means + .positions + .iter() + .map(|mean| project_head_vector_to_hidden(w_o_head, mean)) + .collect(); + tables.insert( + *head, + StaticHiddenTable { + by_position, + global, + }, + ); + } + Ok(tables) +} + +fn copy_w_o_heads( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + heads: &[HeadId], +) -> Result>>, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + let mut out = HashMap::new(); + for (layer, layer_heads) in heads_by_layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + let w_o = weights + .tensors + .get(&weights.arch.attn_o_key(layer)) + .ok_or_else(|| format!("missing W_O tensor at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let start = head.head * head_dim; + let end = start + head_dim; + let w_o_head = w_o.slice(s![.., start..end]); + let rows = (0..w_o_head.nrows()) + .map(|row| { + (0..w_o_head.ncols()) + .map(|col| w_o_head[[row, col]]) + .collect::>() + }) + .collect::>(); + out.insert(head, rows); + } + remove_layer_tensors(weights, inserted); + } + Ok(out) +} + +fn head_residual(values: &[f32], means: &StaticHeadMeans, position: usize) -> Vec { + let base = means.positions.get(position).unwrap_or(&means.global); + values + .iter() + .zip(base.iter()) + .map(|(&value, &mean)| value - mean) + .collect() +} + +fn project_head_vector_to_hidden(w_o_head: &[Vec], values: &[f32]) -> Vec { + let mut out = vec![0.0f32; w_o_head.len()]; + for (row_idx, row) in w_o_head.iter().enumerate() { + let mut sum = 0.0f32; + for (&value, &weight) in values.iter().zip(row.iter()) { + sum += value * weight; + } + out[row_idx] = sum; + } + out +} + +#[derive(Debug)] +struct EditCatalogAccumulator { + prompts: Vec, +} + +impl EditCatalogAccumulator { + fn new() -> Self { + Self { + prompts: Vec::new(), + } + } + + fn add(&mut self, prompt: OracleEditCatalogPromptReport) { + self.prompts.push(prompt); + } + + fn finish( + self, + space: EditCatalogSpace, + edits: usize, + hidden_dim: usize, + ) -> OracleEditCatalogPointReport { + let kls = self.prompts.iter().map(|p| p.kl).collect::>(); + OracleEditCatalogPointReport { + space: space.as_str().to_string(), + edits, + address_bits: edits.next_power_of_two().trailing_zeros() as usize, + residual_table_bytes_bf16: edits * hidden_dim * 2, + prompts: self.prompts.len(), + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + mean_delta_cross_entropy_bits: mean( + &self + .prompts + .iter() + .map(|p| p.delta_cross_entropy_bits) + .collect::>(), + ), + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts.iter().map(|p| p.baseline_top1_in_catalog_top5), + ), + mean_baseline_top1_prob: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_prob) + .collect::>(), + ), + mean_catalog_prob_of_baseline_top1: mean( + &self + .prompts + .iter() + .map(|p| p.catalog_prob_of_baseline_top1) + .collect::>(), + ), + mean_baseline_top1_margin: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_margin) + .collect::>(), + ), + per_prompt: self.prompts, + } + } +} + +fn prompt_label(record: &PromptRecord) -> &str { + record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt") +} + +fn parse_string_list(spec: &str) -> Vec { + spec.split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + .map(ToOwned::to_owned) + .collect() +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/gamma_address.rs b/crates/larql-cli/src/commands/dev/ov_rd/gamma_address.rs new file mode 100644 index 00000000..2cb6d69a --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/gamma_address.rs @@ -0,0 +1,831 @@ +use std::collections::HashMap; + +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::VectorIndex; +use ndarray::{s, Array2}; + +use super::address::{ + predict_code_from_hyperplanes, train_binary_hyperplane, AddressSupervisedGroupModel, +}; +use super::basis::{WoRoundtripBasis, ZPcaBasis}; +use super::metrics::argmax_usize; +use super::pq::PqCodebook; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PqConfig, PromptRecord}; + +#[derive(Debug, Clone)] +pub(super) struct GammaProjectedAddressModel { + pub(super) name: String, + pub(super) source: GammaProjectionSource, + pub(super) supervised: AddressSupervisedGroupModel, +} + +impl GammaProjectedAddressModel { + pub(super) fn selected_group_keys(&self) -> Vec { + self.supervised + .selected_group_keys() + .into_iter() + .map(|key| format!("{}:{key}", self.name)) + .collect() + } + + pub(super) fn project_layer_input( + &self, + layer_input: &Array2, + ) -> Result, Box> { + match &self.source { + GammaProjectionSource::Raw => Ok(layer_input.clone()), + GammaProjectionSource::DiagonalAffine(map) => { + let mut rows = Vec::with_capacity(layer_input.len()); + for row in layer_input.rows() { + rows.extend( + map.project( + row.as_slice().ok_or( + "layer input row was not contiguous during gamma projection", + )?, + ), + ); + } + Ok(Array2::from_shape_vec(layer_input.raw_dim(), rows)?) + } + GammaProjectionSource::RandomProjection(map) => { + let mut rows = Vec::with_capacity(layer_input.nrows() * map.rank); + for row in layer_input.rows() { + rows.extend( + map.project(row.as_slice().ok_or( + "layer input row was not contiguous during random projection", + )?), + ); + } + Ok(Array2::from_shape_vec( + (layer_input.nrows(), map.rank), + rows, + )?) + } + GammaProjectionSource::LearnedLowRank(map) => { + let mut rows = Vec::with_capacity(layer_input.nrows() * map.rank); + for row in layer_input.rows() { + rows.extend(map.project(row.as_slice().ok_or( + "layer input row was not contiguous during learned gamma projection", + )?)); + } + Ok(Array2::from_shape_vec( + (layer_input.nrows(), map.rank), + rows, + )?) + } + } + } +} + +#[derive(Debug, Clone)] +pub(super) enum GammaProjectionSource { + Raw, + DiagonalAffine(DiagonalAffineMap), + RandomProjection(RandomProjectionMap), + LearnedLowRank(LearnedLowRankMap), +} + +#[derive(Debug, Clone)] +pub(super) struct DiagonalAffineMap { + mean_x: Vec, + mean_y: Vec, + slope: Vec, +} + +impl DiagonalAffineMap { + fn project(&self, row: &[f32]) -> Vec { + row.iter() + .enumerate() + .map(|(dim, &x)| self.mean_y[dim] + self.slope[dim] * (x - self.mean_x[dim])) + .collect() + } +} + +#[derive(Debug, Clone)] +pub(super) struct RandomProjectionMap { + input_dim: usize, + rank: usize, + seed: u64, +} + +impl RandomProjectionMap { + fn new(input_dim: usize, rank: usize, seed: u64) -> Self { + Self { + input_dim, + rank, + seed, + } + } + + fn project(&self, row: &[f32]) -> Vec { + let scale = (self.input_dim as f32).sqrt().max(1.0); + let mut out = vec![0.0_f32; self.rank]; + for (out_dim, value) in out.iter_mut().enumerate() { + let mut sum = 0.0_f32; + for (in_dim, &x) in row.iter().enumerate() { + let hash = splitmix64( + self.seed + ^ ((out_dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)) + ^ ((in_dim as u64).wrapping_mul(0xBF58_476D_1CE4_E5B9)), + ); + let sign = if hash & 1 == 0 { -1.0 } else { 1.0 }; + sum += sign * x; + } + *value = sum / scale; + } + out + } +} + +#[derive(Debug, Clone)] +pub(super) struct LearnedLowRankMap { + mean_x: Vec, + mean_y: Vec, + basis_y: Vec>, + weights: Vec>, + bias: Vec, + rank: usize, +} + +impl LearnedLowRankMap { + fn project(&self, row: &[f32]) -> Vec { + let mut out = vec![0.0_f32; self.rank]; + for (component, value) in out.iter_mut().enumerate() { + let mut sum = self.bias[component]; + for (dim, &x) in row.iter().enumerate() { + sum += self.weights[component][dim] * (x - self.mean_x[dim]); + } + *value = sum; + } + out + } + + fn target_coordinates(&self, target: &[f32]) -> Vec { + self.basis_y + .iter() + .map(|basis| { + target + .iter() + .zip(self.mean_y.iter()) + .zip(basis.iter()) + .map(|((&y, &mean), &direction)| (y - mean) * direction) + .sum() + }) + .collect() + } +} + +#[derive(Debug, Clone)] +struct GammaCodeSample { + head: HeadId, + config: PqConfig, + position: usize, + raw_input: Vec, + targets: HashMap>, + codes: Vec, +} + +pub(super) fn fit_gamma_projected_address_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + projection_layers: &[usize], + random_ranks: &[usize], + random_seeds: &[u64], + learned_ranks: &[usize], + learned_epochs: usize, + learned_lr: f32, + learned_l2: f32, + learned_pca_iters: usize, + epochs: usize, + lr: f32, + l2: f32, +) -> Result>, Box> +{ + let samples = collect_gamma_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + projection_layers, + "gamma-address-fit", + )?; + let dim = weights.hidden_size; + + let mut samples_by_head_config: HashMap<(HeadId, PqConfig), Vec<&GammaCodeSample>> = + HashMap::new(); + let mut samples_by_head: HashMap> = HashMap::new(); + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + for sample in &samples { + samples_by_head_config + .entry((sample.head, sample.config)) + .or_default() + .push(sample); + samples_by_head.entry(sample.head).or_default().push(sample); + for (group, &code) in sample.codes.iter().enumerate() { + let levels = 1usize << sample.config.bits_per_group; + majority_counts + .entry((sample.head, sample.config, group)) + .or_insert_with(|| vec![0; levels])[code] += 1; + } + } + + let mut maps_by_head_layer: HashMap<(HeadId, usize), DiagonalAffineMap> = HashMap::new(); + for head in heads { + let head_samples = samples_by_head.get(head).cloned().unwrap_or_default(); + for &projection_layer in projection_layers { + let pairs = head_samples + .iter() + .filter_map(|sample| { + sample + .targets + .get(&projection_layer) + .map(|target| (sample.raw_input.as_slice(), target.as_slice())) + }) + .collect::>(); + if !pairs.is_empty() { + maps_by_head_layer.insert( + (*head, projection_layer), + fit_diagonal_affine_map(&pairs, dim), + ); + } + } + } + + let mut learned_maps_by_head_layer_rank: HashMap<(HeadId, usize, usize), LearnedLowRankMap> = + HashMap::new(); + for head in heads { + let head_samples = samples_by_head.get(head).cloned().unwrap_or_default(); + for &projection_layer in projection_layers { + let pairs = head_samples + .iter() + .filter_map(|sample| { + sample + .targets + .get(&projection_layer) + .map(|target| (sample.raw_input.as_slice(), target.as_slice())) + }) + .collect::>(); + if pairs.is_empty() { + continue; + } + for &rank in learned_ranks { + learned_maps_by_head_layer_rank.insert( + (*head, projection_layer, rank), + fit_learned_low_rank_map( + &pairs, + dim, + rank, + learned_pca_iters, + learned_epochs, + learned_lr, + learned_l2, + ((*head).layer as u64) << 32 + ^ ((*head).head as u64) << 24 + ^ (projection_layer as u64) << 8 + ^ rank as u64, + ), + ); + } + } + } + + let mut out = HashMap::new(); + for ((head, config), _) in codebooks { + let train_samples = samples_by_head_config + .get(&(*head, *config)) + .cloned() + .unwrap_or_default(); + let mut group_majority = Vec::with_capacity(config.groups); + for group in 0..config.groups { + group_majority.push( + majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0), + ); + } + + let mut models = Vec::new(); + let raw_rows = train_samples + .iter() + .map(|sample| sample.raw_input.clone()) + .collect::>(); + models.push(fit_one_projected_model( + "gamma_raw", + GammaProjectionSource::Raw, + &raw_rows, + &train_samples, + *config, + selected_groups, + &group_majority, + epochs, + lr, + l2, + )); + + for &projection_layer in projection_layers { + let Some(map) = maps_by_head_layer.get(&(*head, projection_layer)).cloned() else { + continue; + }; + let projected_rows = train_samples + .iter() + .map(|sample| map.project(&sample.raw_input)) + .collect::>(); + models.push(fit_one_projected_model( + &format!("gamma_diag_post_l{projection_layer}"), + GammaProjectionSource::DiagonalAffine(map), + &projected_rows, + &train_samples, + *config, + selected_groups, + &group_majority, + epochs, + lr, + l2, + )); + } + for &rank in random_ranks { + for &seed in random_seeds { + let map = RandomProjectionMap::new(dim, rank, seed); + let projected_rows = train_samples + .iter() + .map(|sample| map.project(&sample.raw_input)) + .collect::>(); + models.push(fit_one_projected_model( + &format!("random_rank{rank}_seed{seed}"), + GammaProjectionSource::RandomProjection(map), + &projected_rows, + &train_samples, + *config, + selected_groups, + &group_majority, + epochs, + lr, + l2, + )); + } + } + for &projection_layer in projection_layers { + for &rank in learned_ranks { + let Some(map) = learned_maps_by_head_layer_rank + .get(&(*head, projection_layer, rank)) + .cloned() + else { + continue; + }; + let projected_rows = train_samples + .iter() + .map(|sample| map.project(&sample.raw_input)) + .collect::>(); + models.push(fit_one_projected_model( + &format!("gamma_learned_post_l{projection_layer}_rank{rank}"), + GammaProjectionSource::LearnedLowRank(map), + &projected_rows, + &train_samples, + *config, + selected_groups, + &group_majority, + epochs, + lr, + l2, + )); + } + } + + out.insert((*head, *config), models); + } + + Ok(out) +} + +fn splitmix64(mut x: u64) -> u64 { + x = x.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = x; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) +} + +fn fit_one_projected_model( + name: &str, + source: GammaProjectionSource, + rows: &[Vec], + samples: &[&GammaCodeSample], + config: PqConfig, + selected_groups: &[usize], + group_majority: &[usize], + epochs: usize, + lr: f32, + l2: f32, +) -> GammaProjectedAddressModel { + let dim = rows.first().map(Vec::len).unwrap_or(0); + let row_refs = rows.iter().map(Vec::as_slice).collect::>(); + let mut group_hyperplanes = vec![Vec::new(); config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for &group in selected_groups { + let mut bit_planes = Vec::with_capacity(config.bits_per_group); + for bit in 0..config.bits_per_group { + let labels = samples + .iter() + .map(|sample| ((sample.codes[group] >> bit) & 1) != 0) + .collect::>(); + bit_planes.push(train_binary_hyperplane( + &row_refs, &labels, dim, epochs, lr, l2, + )); + } + + let mut correct = 0usize; + for (row, sample) in rows.iter().zip(samples.iter()) { + let predicted = predict_code_from_hyperplanes(row, &bit_planes); + if predicted == sample.codes[group] { + correct += 1; + } + } + group_train_accuracy[group] = if rows.is_empty() { + 0.0 + } else { + correct as f64 / rows.len() as f64 + }; + group_hyperplanes[group] = bit_planes; + } + + GammaProjectedAddressModel { + name: name.to_string(), + source, + supervised: AddressSupervisedGroupModel { + groups: selected_groups.to_vec(), + bits_per_group: config.bits_per_group, + epochs, + lr, + l2, + group_majority: group_majority.to_vec(), + group_hyperplanes, + group_train_accuracy, + }, + } +} + +fn fit_diagonal_affine_map(pairs: &[(&[f32], &[f32])], dim: usize) -> DiagonalAffineMap { + let n = pairs.len().max(1) as f64; + let mut sum_x = vec![0.0_f64; dim]; + let mut sum_y = vec![0.0_f64; dim]; + let mut sum_xx = vec![0.0_f64; dim]; + let mut sum_xy = vec![0.0_f64; dim]; + for &(x, y) in pairs { + for dim_idx in 0..dim { + let xi = x[dim_idx] as f64; + let yi = y[dim_idx] as f64; + sum_x[dim_idx] += xi; + sum_y[dim_idx] += yi; + sum_xx[dim_idx] += xi * xi; + sum_xy[dim_idx] += xi * yi; + } + } + + let mut mean_x = vec![0.0_f32; dim]; + let mut mean_y = vec![0.0_f32; dim]; + let mut slope = vec![0.0_f32; dim]; + for dim_idx in 0..dim { + let mx = sum_x[dim_idx] / n; + let my = sum_y[dim_idx] / n; + let var_x = (sum_xx[dim_idx] / n) - mx * mx; + let cov_xy = (sum_xy[dim_idx] / n) - mx * my; + mean_x[dim_idx] = mx as f32; + mean_y[dim_idx] = my as f32; + slope[dim_idx] = if var_x.abs() > 1e-12 { + (cov_xy / var_x) as f32 + } else { + 0.0 + }; + } + + DiagonalAffineMap { + mean_x, + mean_y, + slope, + } +} + +fn fit_learned_low_rank_map( + pairs: &[(&[f32], &[f32])], + dim: usize, + rank: usize, + pca_iters: usize, + epochs: usize, + lr: f32, + l2: f32, + seed: u64, +) -> LearnedLowRankMap { + let (mean_x, mean_y) = pair_means(pairs, dim); + let basis_y = fit_target_power_pca_basis(pairs, &mean_y, dim, rank, pca_iters, seed); + let mut map = LearnedLowRankMap { + mean_x, + mean_y, + basis_y, + weights: vec![vec![0.0_f32; dim]; rank], + bias: vec![0.0_f32; rank], + rank, + }; + let target_coords = pairs + .iter() + .map(|(_, target)| map.target_coordinates(target)) + .collect::>(); + let input_norms = pairs + .iter() + .map(|(input, _)| { + input + .iter() + .zip(map.mean_x.iter()) + .map(|(&x, &mean)| { + let centered = x - mean; + centered * centered + }) + .sum::() + .max(1.0) + }) + .collect::>(); + + for _ in 0..epochs { + for (sample_idx, (input, _)) in pairs.iter().enumerate() { + let norm = input_norms[sample_idx]; + let step = lr / norm; + for component in 0..rank { + let mut pred = map.bias[component]; + for (dim_idx, &x) in input.iter().enumerate() { + pred += map.weights[component][dim_idx] * (x - map.mean_x[dim_idx]); + } + let err = pred - target_coords[sample_idx][component]; + map.bias[component] -= lr * err * 0.01; + for (dim_idx, &x) in input.iter().enumerate() { + let centered = x - map.mean_x[dim_idx]; + let grad = err * centered + l2 * map.weights[component][dim_idx]; + map.weights[component][dim_idx] -= step * grad; + } + } + } + } + map +} + +fn pair_means(pairs: &[(&[f32], &[f32])], dim: usize) -> (Vec, Vec) { + let n = pairs.len().max(1) as f64; + let mut mean_x = vec![0.0_f64; dim]; + let mut mean_y = vec![0.0_f64; dim]; + for &(x, y) in pairs { + for dim_idx in 0..dim { + mean_x[dim_idx] += x[dim_idx] as f64; + mean_y[dim_idx] += y[dim_idx] as f64; + } + } + ( + mean_x.into_iter().map(|value| (value / n) as f32).collect(), + mean_y.into_iter().map(|value| (value / n) as f32).collect(), + ) +} + +fn fit_target_power_pca_basis( + pairs: &[(&[f32], &[f32])], + mean_y: &[f32], + dim: usize, + rank: usize, + pca_iters: usize, + seed: u64, +) -> Vec> { + let mut basis = Vec::with_capacity(rank); + for component in 0..rank { + let mut v = deterministic_unit_vector(dim, seed ^ component as u64); + orthonormalize(&mut v, &basis); + for _ in 0..pca_iters { + let mut next = vec![0.0_f64; dim]; + for &(_, y) in pairs { + let dot = y + .iter() + .zip(mean_y.iter()) + .zip(v.iter()) + .map(|((&yi, &mean), &vi)| (yi - mean) as f64 * vi as f64) + .sum::(); + for dim_idx in 0..dim { + next[dim_idx] += (y[dim_idx] - mean_y[dim_idx]) as f64 * dot; + } + } + let inv_n = 1.0 / pairs.len().max(1) as f64; + let mut next_f32 = next + .into_iter() + .map(|value| (value * inv_n) as f32) + .collect::>(); + orthonormalize(&mut next_f32, &basis); + v = next_f32; + } + basis.push(v); + } + basis +} + +fn deterministic_unit_vector(dim: usize, seed: u64) -> Vec { + let mut values = (0..dim) + .map(|idx| { + let hash = splitmix64(seed ^ (idx as u64).wrapping_mul(0xD6E8_FEB8_6659_FD93)); + let unit = ((hash >> 11) as f64) * (1.0 / ((1_u64 << 53) as f64)); + (2.0 * unit - 1.0) as f32 + }) + .collect::>(); + normalize(&mut values); + values +} + +fn orthonormalize(v: &mut [f32], basis: &[Vec]) { + for prev in basis { + let dot = v + .iter() + .zip(prev.iter()) + .map(|(&a, &b)| a as f64 * b as f64) + .sum::() as f32; + for (value, &prev_value) in v.iter_mut().zip(prev.iter()) { + *value -= dot * prev_value; + } + } + normalize(v); +} + +fn normalize(v: &mut [f32]) { + let norm = v + .iter() + .map(|&value| value as f64 * value as f64) + .sum::() + .sqrt(); + if norm > 1e-12 { + let inv = (1.0 / norm) as f32; + for value in v { + *value *= inv; + } + } +} + +fn collect_gamma_code_samples( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + projection_layers: &[usize], + label_prefix: &str, +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + let max_head_layer = heads.iter().map(|head| head.layer).max().unwrap_or(0); + let max_projection_layer = projection_layers + .iter() + .copied() + .max() + .unwrap_or(max_head_layer); + let max_layer = max_head_layer.max(max_projection_layer); + let projection_set = projection_layers.iter().copied().collect::>(); + let mut all_samples = Vec::new(); + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!( + " {} [{}/{}] {}", + label_prefix, + prompt_idx + 1, + prompts.len(), + label + ); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + let mut prompt_samples = Vec::new(); + let mut target_rows_by_layer: HashMap>> = HashMap::new(); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let layer_input = h.clone(); + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let basis = bases.get(head).ok_or_else(|| { + format!("missing basis for L{}H{}", head.layer, head.head) + })?; + let head_means = means.get(head).ok_or_else(|| { + format!("missing means for L{}H{}", head.layer, head.head) + })?; + let pca_basis = pca_bases.get(head).ok_or_else(|| { + format!("missing PCA basis for L{}H{}", head.layer, head.head) + })?; + let start = head.head * head_dim; + let end = start + head_dim; + let head_codebooks = codebooks + .iter() + .filter(|((codebook_head, _), _)| codebook_head == head) + .collect::>(); + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during gamma fit")?; + let base = head_means.positions.get(pos).unwrap_or(&head_means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + let raw_input = layer_input + .row(pos) + .as_slice() + .ok_or("layer input row was not contiguous during gamma fit")? + .to_vec(); + for ((_, config), codebook) in &head_codebooks { + let coords = pca_basis.coordinates_with_rank(&z, config.k); + let codes = codebook.quantize_indices_for_stratum(&coords, stratum); + prompt_samples.push(GammaCodeSample { + head: *head, + config: *config, + position: pos, + raw_input: raw_input.clone(), + targets: HashMap::new(), + codes, + }); + } + } + } + } + + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } else { + remove_layer_tensors(weights, inserted); + return Err(format!("layer {layer} returned no output").into()); + } + } + remove_layer_tensors(weights, inserted); + + if projection_set.contains(&layer) { + target_rows_by_layer.insert( + layer, + h.rows() + .into_iter() + .map(|row| row.as_slice().unwrap_or(&[]).to_vec()) + .collect(), + ); + } + if layer >= max_layer { + break; + } + } + + for sample in &mut prompt_samples { + for &projection_layer in projection_layers { + if projection_layer < sample.head.layer { + continue; + } + if let Some(rows) = target_rows_by_layer.get(&projection_layer) { + if let Some(target) = rows.get(sample.position) { + sample.targets.insert(projection_layer, target.clone()); + } + } + } + } + all_samples.extend(prompt_samples); + } + + Ok(all_samples) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/input.rs b/crates/larql-cli/src/commands/dev/ov_rd/input.rs new file mode 100644 index 00000000..acde2221 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/input.rs @@ -0,0 +1,156 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use super::types::{HeadId, PqConfig, PromptRecord}; + +pub(super) fn load_prompts( + path: &PathBuf, + max_prompts: Option, +) -> Result, Box> { + let text = std::fs::read_to_string(path)?; + let mut prompts = Vec::new(); + for line in text.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + prompts.push(serde_json::from_str::(line)?); + if max_prompts.is_some_and(|n| prompts.len() >= n) { + break; + } + } + Ok(prompts) +} + +pub(super) fn limit_prompts_per_stratum( + prompts: Vec, + max_per_stratum: usize, +) -> Vec { + let mut counts: HashMap = HashMap::new(); + let mut selected = Vec::new(); + for prompt in prompts { + let key = prompt + .stratum + .clone() + .unwrap_or_else(|| "unknown".to_string()); + let count = counts.entry(key).or_default(); + if *count < max_per_stratum { + *count += 1; + selected.push(prompt); + } + } + selected +} + +pub(super) fn split_prompt_records( + prompts: &[PromptRecord], + eval_mod: usize, + eval_offset: usize, +) -> Result<(Vec, Vec), Box> { + if eval_mod == 0 { + return Err("--eval-mod must be greater than zero".into()); + } + if eval_offset >= eval_mod { + return Err("--eval-offset must be smaller than --eval-mod".into()); + } + let mut fit = Vec::new(); + let mut eval = Vec::new(); + for (idx, prompt) in prompts.iter().cloned().enumerate() { + if idx % eval_mod == eval_offset { + eval.push(prompt); + } else { + fit.push(prompt); + } + } + if fit.is_empty() || eval.is_empty() { + return Err("held-out split produced an empty fit or eval set".into()); + } + eprintln!( + "Held-out split: fit_prompts={}, eval_prompts={} (idx % {} == {})", + fit.len(), + eval.len(), + eval_mod, + eval_offset + ); + Ok((fit, eval)) +} + +pub(super) fn parse_head_spec(spec: &str) -> Result, Box> { + let mut heads = Vec::new(); + for part in spec.split(',') { + let part = part.trim(); + if part.is_empty() { + continue; + } + let (layer, head) = part + .split_once(':') + .ok_or_else(|| format!("invalid head spec '{part}', expected layer:head"))?; + heads.push(HeadId { + layer: layer.parse()?, + head: head.parse()?, + }); + } + Ok(heads) +} + +pub(super) fn parse_usize_list(spec: &str) -> Result, Box> { + let mut values = Vec::new(); + for part in spec.split(',') { + let part = part.trim(); + if part.is_empty() { + continue; + } + values.push(part.parse()?); + } + Ok(values) +} + +pub(super) fn parse_pq_configs(spec: &str) -> Result, Box> { + let mut configs = Vec::new(); + for part in spec.split(',') { + let part = part.trim(); + if part.is_empty() { + continue; + } + let fields = part.split(':').collect::>(); + if fields.len() != 3 { + return Err(format!("invalid PQ config '{part}', expected K:groups:bits").into()); + } + let config = PqConfig { + k: fields[0].parse()?, + groups: fields[1].parse()?, + bits_per_group: fields[2].parse()?, + }; + if config.k == 0 || config.groups == 0 || config.bits_per_group == 0 { + return Err(format!("invalid zero value in PQ config '{part}'").into()); + } + if config.k % config.groups != 0 { + return Err(format!("PQ config '{part}' requires K divisible by groups").into()); + } + if config.bits_per_group > 12 { + return Err(format!("PQ config '{part}' has too many bits/group for smoke run").into()); + } + configs.push(config); + } + configs.sort_by_key(|c| (c.k, c.groups, c.bits_per_group)); + configs.dedup(); + Ok(configs) +} + +pub(super) fn parse_layer_spec(spec: &str) -> Result, Box> { + let mut layers = Vec::new(); + for part in spec.split(',') { + let part = part.trim(); + if part.contains('-') { + let (a, b) = part + .split_once('-') + .ok_or_else(|| format!("invalid range: {part}"))?; + let start: usize = a.parse()?; + let end: usize = b.parse()?; + layers.extend(start..=end); + } else if !part.is_empty() { + layers.push(part.parse()?); + } + } + Ok(layers) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/metrics.rs b/crates/larql-cli/src/commands/dev/ov_rd/metrics.rs new file mode 100644 index 00000000..a0265f57 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/metrics.rs @@ -0,0 +1,154 @@ +pub(super) fn log_softmax(logits: &[f32]) -> Vec { + let max_logit = logits + .iter() + .map(|&v| v as f64) + .fold(f64::NEG_INFINITY, f64::max); + let sum_exp = logits + .iter() + .map(|&v| ((v as f64) - max_logit).exp()) + .sum::(); + let log_z = max_logit + sum_exp.ln(); + logits.iter().map(|&v| (v as f64) - log_z).collect() +} + +pub(super) fn kl_logp(p_logp: &[f64], q_logp: &[f64]) -> f64 { + p_logp + .iter() + .zip(q_logp.iter()) + .map(|(&lp, &lq)| { + let p = lp.exp(); + p * (lp - lq) + }) + .sum() +} + +pub(super) fn token_prob(logp: &[f64], token_id: u32) -> f64 { + logp.get(token_id as usize) + .map(|value| value.exp()) + .unwrap_or(0.0) +} + +pub(super) fn argmax_usize(values: &[usize]) -> usize { + values + .iter() + .enumerate() + .max_by_key(|(_, value)| *value) + .map(|(idx, _)| idx) + .unwrap_or(0) +} + +pub(super) fn code_mass(counts: &[usize], code: usize) -> f64 { + let total = counts.iter().sum::(); + if total == 0 { + 0.0 + } else { + counts.get(code).copied().unwrap_or(0) as f64 / total as f64 + } +} + +pub(super) fn entropy_bits(counts: &[usize]) -> f64 { + let total = counts.iter().sum::(); + if total == 0 { + return 0.0; + } + counts + .iter() + .filter(|&&count| count > 0) + .map(|&count| { + let p = count as f64 / total as f64; + -p * p.log2() + }) + .sum() +} + +fn kl_counts_to_probs_bits(counts: &[usize], probs: &[f64]) -> f64 { + let total = counts.iter().sum::(); + if total == 0 { + return 0.0; + } + counts + .iter() + .zip(probs.iter()) + .filter(|(&count, _)| count > 0) + .map(|(&count, &q)| { + let p = count as f64 / total as f64; + p * (p / q.max(1e-12)).log2() + }) + .sum() +} + +pub(super) fn js_divergence_bits(a: &[usize], b: &[usize]) -> f64 { + let total_a = a.iter().sum::(); + let total_b = b.iter().sum::(); + if total_a == 0 || total_b == 0 { + return 0.0; + } + let levels = a.len().max(b.len()); + let mut midpoint = vec![0.0; levels]; + for (idx, value) in midpoint.iter_mut().enumerate() { + let pa = a.get(idx).copied().unwrap_or(0) as f64 / total_a as f64; + let pb = b.get(idx).copied().unwrap_or(0) as f64 / total_b as f64; + *value = 0.5 * (pa + pb); + } + 0.5 * kl_counts_to_probs_bits(a, &midpoint) + 0.5 * kl_counts_to_probs_bits(b, &midpoint) +} + +pub(super) fn max_abs_diff(a: &[f32], b: &[f32]) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(&x, &y)| ((x as f64) - (y as f64)).abs()) + .fold(0.0, f64::max) +} + +pub(super) fn argmax(values: &[f32]) -> u32 { + values + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx as u32) + .unwrap_or(0) +} + +pub(super) fn top_k_indices(values: &[f32], k: usize) -> Vec { + let mut pairs: Vec<(usize, f32)> = values.iter().copied().enumerate().collect(); + let take = k.min(pairs.len()); + pairs.select_nth_unstable_by(take.saturating_sub(1), |a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); + pairs.truncate(take); + pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + pairs.into_iter().map(|(idx, _)| idx as u32).collect() +} + +pub(super) fn mean(values: &[f64]) -> f64 { + if values.is_empty() { + 0.0 + } else { + values.iter().sum::() / values.len() as f64 + } +} + +pub(super) fn bool_rate(values: impl Iterator) -> f64 { + let mut total = 0usize; + let mut hits = 0usize; + for value in values { + total += 1; + if value { + hits += 1; + } + } + if total == 0 { + 0.0 + } else { + hits as f64 / total as f64 + } +} + +pub(super) fn percentile(mut values: Vec, p: f64) -> f64 { + if values.is_empty() { + return 0.0; + } + values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let rank = ((values.len() - 1) as f64 * p).ceil() as usize; + values[rank.min(values.len() - 1)] +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/mod.rs b/crates/larql-cli/src/commands/dev/ov_rd/mod.rs new file mode 100644 index 00000000..15225c8d --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/mod.rs @@ -0,0 +1,26 @@ +mod address; +mod basis; +mod capture; +pub mod cmd; +mod edit_catalog; +mod gamma_address; +mod input; +mod metrics; +mod oracle; +mod oracle_pq; +mod oracle_pq_address; +mod oracle_pq_eval; +mod oracle_pq_fit; +mod oracle_pq_forward; +mod oracle_pq_mode_d; +mod oracle_pq_reports; +mod oracle_pq_stability; +mod pq; +mod pq_exception; +mod reports; +mod runtime; +mod sanity; +mod static_replace; +mod stats; +mod types; +mod zero_ablate; diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle.rs new file mode 100644 index 00000000..61f629aa --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle.rs @@ -0,0 +1,670 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Instant; + +use clap::Args; +use larql_inference::{encode_prompt, hidden_to_raw_logits}; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use ndarray::{s, Array2}; + +use super::basis::{ + build_roundtrip_bases, fit_z_pca_bases, RoundtripPatchMetrics, WoRoundtripBasis, ZPcaBasis, +}; +use super::input::{load_prompts, parse_head_spec, parse_usize_list}; +use super::metrics::{ + argmax, bool_rate, kl_logp, log_softmax, max_abs_diff, mean, percentile, token_prob, + top_k_indices, +}; +use super::reports::{ + OracleLowrankHeadReport, OracleLowrankPointReport, OracleLowrankPromptReport, + OracleLowrankReport, OracleRoundtripHeadReport, OracleRoundtripPromptReport, + OracleRoundtripReport, +}; +use super::static_replace::fit_static_means; +use super::stats::StaticHeadMeans; +use super::types::HeadId; + +#[derive(Args)] +pub(super) struct OracleRoundtripArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 0:4,0:6. + #[arg(long)] + heads: String, + + /// Relative singular value cutoff for retained W_O-visible directions. + #[arg(long, default_value_t = 1e-6)] + sigma_rel_cutoff: f64, + + /// Limit prompts for bounded sanity runs. + #[arg(long)] + max_prompts: Option, +} + +#[derive(Args)] +pub(super) struct OracleLowrankArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 0:4,0:6. + #[arg(long)] + heads: String, + + /// Comma-separated K values for the low-rank sweep. + #[arg(long, default_value = "1,2,4,8,16,32")] + ks: String, + + /// Relative singular value cutoff for retained W_O-visible directions. + #[arg(long, default_value_t = 1e-6)] + sigma_rel_cutoff: f64, + + /// Limit prompts for bounded sanity runs. + #[arg(long)] + max_prompts: Option, +} + +#[derive(Debug)] +struct OracleLowrankPointAccumulator { + prompts: Vec, +} + +impl OracleLowrankPointAccumulator { + fn new() -> Self { + Self { + prompts: Vec::new(), + } + } + + fn add(&mut self, prompt: OracleLowrankPromptReport) { + self.prompts.push(prompt); + } + + fn finish(self, k: usize) -> OracleLowrankPointReport { + let kls: Vec = self.prompts.iter().map(|p| p.kl).collect(); + let mean_delta_cross_entropy_bits = mean( + &self + .prompts + .iter() + .map(|p| p.delta_cross_entropy_bits) + .collect::>(), + ); + OracleLowrankPointReport { + k, + prompts: self.prompts.len(), + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + mean_delta_cross_entropy_bits, + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts.iter().map(|p| p.baseline_top1_in_lowrank_top5), + ), + mean_baseline_top1_prob: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_prob) + .collect::>(), + ), + mean_lowrank_prob_of_baseline_top1: mean( + &self + .prompts + .iter() + .map(|p| p.lowrank_prob_of_baseline_top1) + .collect::>(), + ), + mean_baseline_top1_margin: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_margin) + .collect::>(), + ), + mean_pre_wo_l2: mean(&self.prompts.iter().map(|p| p.pre_wo_l2).collect::>()), + mean_wo_visible_l2: mean( + &self + .prompts + .iter() + .map(|p| p.wo_visible_l2) + .collect::>(), + ), + per_prompt: self.prompts, + } + } +} + +#[derive(Debug)] +struct OracleRoundtripAccumulator { + prompts: Vec, +} + +impl OracleRoundtripAccumulator { + fn new() -> Self { + Self { + prompts: Vec::new(), + } + } + + fn add(&mut self, prompt: OracleRoundtripPromptReport) { + self.prompts.push(prompt); + } + + fn finish(self, head: HeadId, basis: &WoRoundtripBasis) -> OracleRoundtripHeadReport { + let kls: Vec = self.prompts.iter().map(|p| p.kl).collect(); + let pre_l2: Vec = self.prompts.iter().map(|p| p.pre_wo_l2).collect(); + let visible_l2: Vec = self.prompts.iter().map(|p| p.wo_visible_l2).collect(); + OracleRoundtripHeadReport { + layer: head.layer, + head: head.head, + head_dim: basis.head_dim, + rank_retained: basis.rank_retained(), + sigma_max: basis.sigma_max, + sigma_min_retained: basis.sigma_min_retained, + sigma_rel_cutoff: basis.sigma_rel_cutoff, + prompts: self.prompts.len(), + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + max_abs_logit_diff: self + .prompts + .iter() + .map(|p| p.max_abs_logit_diff) + .fold(0.0, f64::max), + mean_pre_wo_l2: mean(&pre_l2), + max_pre_wo_l2: pre_l2.iter().copied().fold(0.0, f64::max), + mean_wo_visible_l2: mean(&visible_l2), + max_wo_visible_l2: visible_l2.iter().copied().fold(0.0, f64::max), + per_prompt: self.prompts, + } + } +} + +pub(super) fn run_oracle_roundtrip( + args: OracleRoundtripArgs, +) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd oracle-roundtrip currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = parse_head_spec(&args.heads)?; + if selected_heads.is_empty() { + return Err("no heads selected for oracle roundtrip".into()); + } + let prompts = load_prompts(&args.prompts, args.max_prompts)?; + eprintln!("Selected heads: {:?}", selected_heads); + eprintln!("Prompts: {}", prompts.len()); + + eprintln!("Building W_O-visible roundtrip bases"); + let bases = + build_roundtrip_bases(&mut weights, &index, &selected_heads, args.sigma_rel_cutoff)?; + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + eprintln!( + " L{}H{} rank={} sigma_max={:.6} sigma_min_retained={:.6}", + head.layer, + head.head, + basis.rank_retained(), + basis.sigma_max, + basis.sigma_min_retained + ); + } + + let mut accumulators: Vec = selected_heads + .iter() + .map(|_| OracleRoundtripAccumulator::new()) + .collect(); + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + + for (idx, head) in selected_heads.iter().copied().enumerate() { + let basis = bases + .get(&head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + let (roundtrip_hidden, metrics) = + forward_q4k_oracle_roundtrip_head(&mut weights, &token_ids, &index, head, basis)?; + let roundtrip_logits = final_logits(&weights, &roundtrip_hidden); + let roundtrip_logp = log_softmax(&roundtrip_logits); + accumulators[idx].add(OracleRoundtripPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl: kl_logp(&baseline_logp, &roundtrip_logp), + max_abs_logit_diff: max_abs_diff(&baseline_logits, &roundtrip_logits), + pre_wo_l2: metrics.pre_wo_l2, + wo_visible_l2: metrics.wo_visible_l2, + }); + } + } + + let heads = selected_heads + .iter() + .copied() + .zip(accumulators) + .map(|(head, acc)| { + let basis = bases + .get(&head) + .expect("basis existed during oracle roundtrip"); + acc.finish(head, basis) + }) + .collect(); + let report = OracleRoundtripReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen: prompts.len(), + sigma_rel_cutoff: args.sigma_rel_cutoff, + selected_heads, + heads, + }; + + let out_path = args.out.join("oracle_roundtrip.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +pub(super) fn run_oracle_lowrank( + args: OracleLowrankArgs, +) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd oracle-lowrank currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = parse_head_spec(&args.heads)?; + if selected_heads.is_empty() { + return Err("no heads selected for oracle lowrank".into()); + } + let mut ks = parse_usize_list(&args.ks)?; + ks.sort_unstable(); + ks.dedup(); + if ks.is_empty() { + return Err("no K values selected for oracle lowrank".into()); + } + let prompts = load_prompts(&args.prompts, args.max_prompts)?; + eprintln!("Selected heads: {:?}", selected_heads); + eprintln!("K sweep: {:?}", ks); + eprintln!("Prompts: {}", prompts.len()); + + eprintln!("Fitting position-mean static bases"); + let means = fit_static_means(&mut weights, &index, &tokenizer, &prompts, &selected_heads)?; + + eprintln!("Building W_O-visible bases"); + let bases = + build_roundtrip_bases(&mut weights, &index, &selected_heads, args.sigma_rel_cutoff)?; + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + eprintln!( + " L{}H{} rank={} sigma_max={:.6} sigma_min_retained={:.6}", + head.layer, + head.head, + basis.rank_retained(), + basis.sigma_max, + basis.sigma_min_retained + ); + } + + eprintln!("Fitting empirical z-space PCA bases"); + let pca_bases = fit_z_pca_bases( + &mut weights, + &index, + &tokenizer, + &prompts, + &selected_heads, + &bases, + &means, + )?; + + let mut accumulators: HashMap<(HeadId, usize), OracleLowrankPointAccumulator> = HashMap::new(); + for head in &selected_heads { + for &k in &ks { + accumulators.insert((*head, k), OracleLowrankPointAccumulator::new()); + } + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + let baseline_top1 = argmax(&baseline_logits); + let baseline_top2 = top_k_indices(&baseline_logits, 2); + let baseline_top2_token = baseline_top2.get(1).copied().unwrap_or(baseline_top1); + let baseline_top1_prob = token_prob(&baseline_logp, baseline_top1); + let baseline_top2_prob = token_prob(&baseline_logp, baseline_top2_token); + let baseline_top1_margin = baseline_top1_prob - baseline_top2_prob; + + for head in &selected_heads { + let basis = bases.get(head).ok_or_else(|| { + format!( + "missing basis for oracle lowrank L{} H{}", + head.layer, head.head + ) + })?; + let head_means = means.get(head).ok_or_else(|| { + format!( + "missing position means for oracle lowrank L{} H{}", + head.layer, head.head + ) + })?; + let pca_basis = pca_bases.get(head).ok_or_else(|| { + format!( + "missing empirical PCA basis for oracle lowrank L{} H{}", + head.layer, head.head + ) + })?; + for &k in &ks { + let (lowrank_hidden, metrics) = forward_q4k_oracle_lowrank_head( + &mut weights, + &token_ids, + &index, + *head, + basis, + pca_basis, + head_means, + k, + )?; + let lowrank_logits = final_logits(&weights, &lowrank_hidden); + let lowrank_logp = log_softmax(&lowrank_logits); + let kl = kl_logp(&baseline_logp, &lowrank_logp); + let lowrank_top1 = argmax(&lowrank_logits); + let lowrank_top5 = top_k_indices(&lowrank_logits, 5); + let lowrank_top2 = top_k_indices(&lowrank_logits, 2); + let lowrank_top2_token = lowrank_top2.get(1).copied().unwrap_or(lowrank_top1); + let lowrank_top1_prob = token_prob(&lowrank_logp, lowrank_top1); + let lowrank_top2_prob = token_prob(&lowrank_logp, lowrank_top2_token); + let lowrank_top1_margin = lowrank_top1_prob - lowrank_top2_prob; + let lowrank_prob_of_baseline_top1 = token_prob(&lowrank_logp, baseline_top1); + accumulators + .get_mut(&(*head, k)) + .expect("oracle lowrank accumulator missing") + .add(OracleLowrankPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl, + delta_cross_entropy_bits: kl / std::f64::consts::LN_2, + baseline_top1, + lowrank_top1, + top1_agree: baseline_top1 == lowrank_top1, + baseline_top1_in_lowrank_top5: lowrank_top5.contains(&baseline_top1), + baseline_top1_prob, + baseline_top2: baseline_top2_token, + baseline_top2_prob, + baseline_top1_margin, + lowrank_top1_prob, + lowrank_prob_of_baseline_top1, + lowrank_top1_margin, + pre_wo_l2: metrics.pre_wo_l2, + wo_visible_l2: metrics.wo_visible_l2, + }); + } + } + } + + let mut head_reports = Vec::new(); + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{} H{}", head.layer, head.head))?; + let mut points = Vec::new(); + for &k in &ks { + let acc = accumulators + .remove(&(*head, k)) + .expect("oracle lowrank accumulator missing at finish"); + points.push(acc.finish(k)); + } + let static_train_samples = means.get(head).map(|m| m.count).unwrap_or(0); + head_reports.push(OracleLowrankHeadReport { + layer: head.layer, + head: head.head, + head_dim: basis.head_dim, + rank_retained: basis.rank_retained(), + empirical_rank: pca_basis.rank(), + sigma_max: basis.sigma_max, + sigma_min_retained: basis.sigma_min_retained, + static_train_samples, + points, + }); + } + + let report = OracleLowrankReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen: prompts.len(), + static_base: "position_mean".to_string(), + ks, + sigma_rel_cutoff: args.sigma_rel_cutoff, + selected_heads, + heads: head_reports, + }; + + let out_path = args.out.join("oracle_lowrank.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +fn forward_q4k_oracle_roundtrip_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + basis: &WoRoundtripBasis, +) -> Result<(Array2, RoundtripPatchMetrics), Box> { + let mut metrics = None; + + let h = larql_inference::vindex::predict_q4k_hidden_with_mapped_pre_o_head( + weights, + token_ids, + index, + head.layer, + head.head, + |original_head| { + let mut replacement = Vec::with_capacity(original_head.len()); + let mut pre_sq = 0.0; + let mut visible_sq = 0.0; + let mut count = 0usize; + for pos in 0..original_head.nrows() { + let row = original_head.row(pos); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during roundtrip")?; + let projected = basis.project(values); + for (&original, &recon) in values.iter().zip(projected.iter()) { + let delta = original as f64 - recon as f64; + pre_sq += delta * delta; + } + let delta = values + .iter() + .zip(projected.iter()) + .map(|(&original, &recon)| original as f64 - recon as f64) + .collect::>(); + visible_sq += basis.visible_sq_norm(&delta); + count += 1; + replacement.extend_from_slice(&projected); + } + metrics = Some(RoundtripPatchMetrics { + pre_wo_l2: (pre_sq / count.max(1) as f64).sqrt(), + wo_visible_l2: (visible_sq / count.max(1) as f64).sqrt(), + }); + Array2::from_shape_vec((original_head.nrows(), original_head.ncols()), replacement) + .map_err(|err| err.to_string()) + }, + )?; + + Ok(( + h, + metrics.ok_or("oracle roundtrip did not visit target layer")?, + )) +} + +fn forward_q4k_oracle_lowrank_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + k: usize, +) -> Result<(Array2, RoundtripPatchMetrics), Box> { + let mut metrics = None; + + let h = larql_inference::vindex::predict_q4k_hidden_with_mapped_pre_o_head( + weights, + token_ids, + index, + head.layer, + head.head, + |original_head| { + let mut replacement = Vec::with_capacity(original_head.len()); + let mut pre_sq = 0.0; + let mut visible_sq = 0.0; + let mut count = 0usize; + for pos in 0..original_head.nrows() { + let row = original_head.row(pos); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during lowrank")?; + let base = means.positions.get(pos).unwrap_or(&means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + let z_projected = pca_basis.project_with_rank(&z, k); + let residual_projected = basis.z_to_residual(&z_projected); + let projected = residual_projected + .into_iter() + .zip(base.iter()) + .map(|(ri, &bi)| ri + bi) + .collect::>(); + for (&original, &recon) in values.iter().zip(projected.iter()) { + let delta = original as f64 - recon as f64; + pre_sq += delta * delta; + } + let delta = values + .iter() + .zip(projected.iter()) + .map(|(&original, &recon)| original as f64 - recon as f64) + .collect::>(); + visible_sq += basis.visible_sq_norm(&delta); + count += 1; + replacement.extend_from_slice(&projected); + } + metrics = Some(RoundtripPatchMetrics { + pre_wo_l2: (pre_sq / count.max(1) as f64).sqrt(), + wo_visible_l2: (visible_sq / count.max(1) as f64).sqrt(), + }); + Array2::from_shape_vec((original_head.nrows(), original_head.ncols()), replacement) + .map_err(|err| err.to_string()) + }, + )?; + + Ok(( + h, + metrics.ok_or("oracle lowrank did not visit target layer")?, + )) +} + +fn final_logits(weights: &larql_inference::ModelWeights, h: &Array2) -> Vec { + let last = h.nrows().saturating_sub(1); + let h_last = h.slice(s![last..last + 1, ..]).to_owned(); + hidden_to_raw_logits(weights, &h_last) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq.rs new file mode 100644 index 00000000..fb6f8ceb --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq.rs @@ -0,0 +1,4042 @@ +use std::path::PathBuf; +use std::time::Instant; + +use clap::Args; +use larql_inference::encode_prompt; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use std::collections::HashMap; + +use super::address::{ + attention_argmax, attention_relation_key, ffn_first_feature_key, prev_ffn_feature_key, +}; +use super::basis::*; +use super::gamma_address::fit_gamma_projected_address_models; +use super::input::*; +use super::metrics::*; +use super::oracle_pq_address::{ + collect_code_occurrences, fit_address_attention_cluster_group_models, + fit_address_attention_relation_group_models, fit_address_ffn_first_feature_group_models, + fit_address_lsh_group_models, fit_address_prev_ffn_feature_group_models, + fit_address_probe_models, fit_address_reduced_qk_cluster_group_models, + fit_address_supervised_group_models, fit_majority_codes_for_codebooks, +}; +use super::oracle_pq_eval::evaluate_predicted_address; +use super::oracle_pq_fit::fit_pq_codebooks; +use super::oracle_pq_forward::{ + capture_attention_relation_rows, capture_ffn_first_feature_keys, capture_layer_input_hidden, + capture_prev_ffn_feature_keys, capture_reduced_qk_attention_rows, final_logits, + forward_q4k_oracle_pq_head, forward_q4k_oracle_pq_mode_d_head, +}; +use super::oracle_pq_mode_d::{corruption_keep_values, materialize_mode_d_tables}; +use super::oracle_pq_reports::OraclePqPointAccumulator; +use super::oracle_pq_stability::measure_code_stability; +use super::reports::*; +use super::static_replace::fit_static_means; +use super::types::*; + +#[derive(Args)] +pub(super) struct OraclePqArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 0:6. + #[arg(long)] + heads: String, + + /// Comma-separated PQ configs as K:groups:bits, e.g. 128:16:4,192:24:4. + #[arg(long)] + configs: String, + + /// Relative singular value cutoff for retained W_O-visible directions. + #[arg(long, default_value_t = 1e-6)] + sigma_rel_cutoff: f64, + + /// Lloyd iterations per product-codebook group. + #[arg(long, default_value_t = 25)] + pq_iters: usize, + + /// Also materialize residual-space additive tables and compare Mode D injection. + #[arg(long)] + mode_d_check: bool, + + /// Fit and evaluate graph-native discrete address probes. + /// + /// The probes use only prompt metadata and token ids, not residual vectors. + /// Requires --mode-d-check because predicted addresses are evaluated through + /// the materialized residual-space tables. + #[arg(long)] + address_probes: bool, + + /// Add a mixed simple-key address probe that picks the best discrete key + /// independently for each PQ group on the training split. + #[arg(long)] + address_mixed_key_probe: bool, + + /// Evaluate simple discrete keys on selected PQ groups only. Selected + /// groups are predicted from each key; unselected groups are evaluated as + /// either oracle-correct or majority/default. + #[arg(long)] + address_key_group_probe: bool, + + /// Comma-separated PQ groups for --address-key-group-probe. + #[arg(long, default_value = "0")] + address_key_groups: String, + + /// Optional comma-separated simple-key probe names for + /// --address-key-group-probe. Empty evaluates all simple-key probes. + #[arg(long, default_value = "")] + address_key_group_probe_names: String, + + /// Evaluate selected PQ groups by replacing them with train-set majority + /// codes while all unselected groups remain oracle-correct. + #[arg(long)] + address_majority_group_probe: bool, + + /// Comma-separated PQ groups for --address-majority-group-probe. + #[arg(long, default_value = "0")] + address_majority_groups: String, + + /// Evaluate code-level behavioral substitution for selected PQ groups. + /// + /// Positions whose oracle group code equals a selected from-code are + /// substituted to each selected to-code while all other groups and + /// positions remain oracle-correct. + #[arg(long)] + address_code_substitution_group_probe: bool, + + /// Comma-separated PQ groups for --address-code-substitution-group-probe. + #[arg(long, default_value = "0")] + address_code_substitution_groups: String, + + /// Optional comma-separated source codes. Empty means all codes. + #[arg(long, default_value = "")] + address_code_substitution_from_codes: String, + + /// Target codes. Use "majority" or a comma-separated list of codes. + #[arg(long, default_value = "majority")] + address_code_substitution_to_codes: String, + + /// Evaluate simultaneous behavioral class-collapse substitutions. + /// + /// Spec format: + /// name=6+10+13:13 + /// name=6+10+13:13|7:10 + /// Multiple specs are separated by semicolons. + #[arg(long)] + address_code_class_collapse_group_probe: bool, + + /// Comma-separated PQ groups for --address-code-class-collapse-group-probe. + #[arg(long, default_value = "0")] + address_code_class_collapse_groups: String, + + /// Semicolon-separated class-collapse specs. + #[arg(long, default_value = "")] + address_code_class_collapse_specs: String, + + /// Probe position-local interactions for one prompt and one PQ group. + /// + /// This is a targeted diagnostic for quotient failures: selected primary + /// and secondary source codes are changed to one target code only within + /// the requested prompt, while all other positions/groups remain oracle. + #[arg(long)] + address_code_position_interaction_probe: bool, + + /// Prompt id for --address-code-position-interaction-probe. + #[arg(long, default_value = "")] + address_code_position_prompt_id: String, + + /// PQ group for --address-code-position-interaction-probe. + #[arg(long, default_value_t = 0)] + address_code_position_group: usize, + + /// Primary source codes for --address-code-position-interaction-probe. + #[arg(long, default_value = "10")] + address_code_position_primary_codes: String, + + /// Secondary source codes for --address-code-position-interaction-probe. + #[arg(long, default_value = "6")] + address_code_position_secondary_codes: String, + + /// Target code for --address-code-position-interaction-probe. + #[arg(long, default_value_t = 13)] + address_code_position_target_code: usize, + + /// Evaluate split-wide conditional quotient rules for one PQ group. + /// + /// Primary codes are mapped to the target unconditionally. Secondary codes + /// are mapped to the target except where a built-in guard preserves the + /// oracle code. This tests whether a quotient plus local exception guard + /// clears the held-out gate. + #[arg(long)] + address_code_conditional_quotient_group_probe: bool, + + /// PQ group for --address-code-conditional-quotient-group-probe. + #[arg(long, default_value_t = 0)] + address_code_conditional_quotient_group: usize, + + /// Primary source codes for the conditional quotient probe. + #[arg(long, default_value = "10")] + address_code_conditional_quotient_primary_codes: String, + + /// Secondary source codes for the conditional quotient probe. + #[arg(long, default_value = "6")] + address_code_conditional_quotient_secondary_codes: String, + + /// Target code for the conditional quotient probe. + #[arg(long, default_value_t = 13)] + address_code_conditional_quotient_target_code: usize, + + /// Max early position guarded by early-prose conditional quotient variants. + #[arg(long, default_value_t = 1)] + address_code_conditional_quotient_early_position_max: usize, + + /// Conditional quotient guards to evaluate. + /// + /// Supported: early_prose_position, early_prose_bos_prev, prose_bos_prev. + #[arg( + long, + default_value = "early_prose_position,early_prose_bos_prev,prose_bos_prev" + )] + address_code_conditional_quotient_guards: String, + + /// Extra source:target mappings layered on top of the conditional quotient. + /// + /// Spec format matches class-collapse specs. Empty adds only the base + /// conditional quotient. Example: + /// code4_to13=4:13;code7_to10=7:10 + #[arg(long, default_value = "")] + address_code_conditional_quotient_extra_specs: String, + + /// Export per-position occurrences for selected PQ group codes. + #[arg(long)] + address_code_occurrences: bool, + + /// Comma-separated PQ groups for --address-code-occurrences. + #[arg(long, default_value = "0")] + address_code_occurrence_groups: String, + + /// Optional comma-separated codes for --address-code-occurrences. + /// Empty means all codes. + #[arg(long, default_value = "")] + address_code_occurrence_codes: String, + + /// Occurrence split to export: train, eval, or all. + #[arg(long, default_value = "eval")] + address_code_occurrence_split: String, + + /// Evaluate a hard-coded code7 fallback rule for L0H6-style probes. + /// + /// For selected groups, predict special code when attention argmax is BOS + /// and stratum is not arithmetic; otherwise predict the train majority + /// code. Unselected groups remain oracle-correct. + #[arg(long)] + address_code7_bos_rule_group_probe: bool, + + /// Comma-separated PQ groups for --address-code7-bos-rule-group-probe. + #[arg(long, default_value = "0")] + address_code7_bos_rule_groups: String, + + /// Special code used by --address-code7-bos-rule-group-probe. + #[arg(long, default_value_t = 7)] + address_code7_bos_rule_code: usize, + + /// Evaluate oracle upper bounds for a binary code7-vs-default address. + /// + /// Selected groups use the special code only where the oracle address has + /// that code and the requested structural filter matches; all other + /// positions use the train majority code. Unselected groups remain + /// oracle-correct. + #[arg(long)] + address_code7_oracle_binary_group_probe: bool, + + /// Comma-separated PQ groups for --address-code7-oracle-binary-group-probe. + #[arg(long, default_value = "0")] + address_code7_oracle_binary_groups: String, + + /// Special code used by --address-code7-oracle-binary-group-probe. + #[arg(long, default_value_t = 7)] + address_code7_oracle_binary_code: usize, + + /// Comma-separated filters for oracle binary code7 upper bounds. + /// + /// Supported: all, natural_prose_bos, natural_prose_bos_or_prev. + #[arg( + long, + default_value = "all,natural_prose_bos,natural_prose_bos_or_prev" + )] + address_code7_oracle_binary_filters: String, + + /// Evaluate how sensitive Mode D is to address corruption. + /// + /// This keeps a prefix of oracle PQ groups and replaces the rest with + /// per-group majority codes learned from the training split. It estimates + /// how many groups must be addressed correctly before predicted addressing + /// can pass the KL gate. + #[arg(long)] + address_corruption_sweep: bool, + + /// Evaluate one-group-at-a-time address importance by replacing each group + /// with its train-set majority code while all other groups remain oracle. + #[arg(long)] + address_group_importance: bool, + + /// Fit and evaluate fixed random-hyperplane LSH probes for selected PQ + /// groups. The selected groups are predicted from the residual entering the + /// target layer; other groups are evaluated both oracle-correct and + /// majority/default. + #[arg(long)] + address_lsh_group_probe: bool, + + /// Comma-separated PQ groups for --address-lsh-group-probe. + #[arg(long, default_value = "0")] + address_lsh_groups: String, + + /// Number of LSH bits per selected group. For a 4-bit PQ group, 4 LSH bits + /// creates 16 buckets. + #[arg(long, default_value_t = 4)] + address_lsh_bits: usize, + + /// Number of deterministic random-hyperplane seeds to try per selected + /// group. The best seed is selected by train code accuracy. + #[arg(long, default_value_t = 32)] + address_lsh_seeds: usize, + + /// Fit and evaluate supervised binary-hyperplane address probes for + /// selected PQ groups. The selected groups are predicted from the residual + /// entering the target layer; other groups are evaluated both + /// oracle-correct and majority/default. + #[arg(long)] + address_supervised_group_probe: bool, + + /// Comma-separated PQ groups for --address-supervised-group-probe. + #[arg(long, default_value = "0")] + address_supervised_groups: String, + + /// SGD epochs for supervised binary-hyperplane group address probes. + #[arg(long, default_value_t = 16)] + address_supervised_epochs: usize, + + /// SGD learning rate for supervised binary-hyperplane group address probes. + #[arg(long, default_value_t = 0.05)] + address_supervised_lr: f32, + + /// L2 weight decay for supervised binary-hyperplane group address probes. + #[arg(long, default_value_t = 1e-4)] + address_supervised_l2: f32, + + /// Fit and evaluate supervised group address probes after a diagonal + /// affine gamma-alignment projection from the layer input toward later + /// post-layer residual snapshots. + #[arg(long)] + address_gamma_projected_group_probe: bool, + + /// Comma-separated PQ groups for --address-gamma-projected-group-probe. + #[arg(long, default_value = "0")] + address_gamma_projected_groups: String, + + /// Comma-separated post-layer residual snapshots used as gamma-alignment + /// targets, e.g. 20,26,29,33. The raw layer-input supervised probe is + /// always included as gamma_raw for comparison. + #[arg(long, default_value = "20,26,29,33")] + address_gamma_projected_layers: String, + + /// Comma-separated random projection ranks for the gamma bridge control, + /// e.g. 64,128. These are fixed Rademacher low-rank projections of the + /// layer input followed by the same supervised bit probes. + #[arg(long, default_value = "")] + address_gamma_random_ranks: String, + + /// Comma-separated deterministic seeds for random projection ranks. + #[arg(long, default_value = "0")] + address_gamma_random_seeds: String, + + /// Comma-separated learned bridge ranks for the gamma bridge test. These + /// fit a low-rank target-PCA proxy from layer input to later residual + /// snapshots before training the same supervised group-bit probes. + #[arg(long, default_value = "")] + address_gamma_learned_ranks: String, + + /// SGD epochs for learned low-rank gamma bridge fitting. + #[arg(long, default_value_t = 8)] + address_gamma_learned_epochs: usize, + + /// Normalized LMS learning rate for learned low-rank gamma bridge fitting. + #[arg(long, default_value_t = 0.5)] + address_gamma_learned_lr: f32, + + /// L2 weight decay for learned low-rank gamma bridge fitting. + #[arg(long, default_value_t = 1e-5)] + address_gamma_learned_l2: f32, + + /// Power-iteration steps for the learned bridge target PCA basis. + #[arg(long, default_value_t = 8)] + address_gamma_learned_pca_iters: usize, + + /// Report train/eval PQ code distribution stability for selected groups. + #[arg(long)] + address_code_stability: bool, + + /// Comma-separated PQ groups for --address-code-stability. + #[arg(long, default_value = "0")] + address_code_stability_groups: String, + + /// Fit and evaluate selected PQ groups from previous-layer FFN top-feature + /// keys. This is the first model-native discrete-state address probe for + /// non-layer-0 dynamic heads. + #[arg(long)] + address_prev_ffn_feature_group_probe: bool, + + /// Comma-separated PQ groups for --address-prev-ffn-feature-group-probe. + #[arg(long, default_value = "0")] + address_prev_ffn_feature_groups: String, + + /// Number of previous-layer FFN activation features retained for feature + /// hash keys. + #[arg(long, default_value_t = 4)] + address_prev_ffn_feature_top_k: usize, + + /// Fit and evaluate selected PQ groups from an FFN-first diagnostic state: + /// run the target layer's FFN on the pre-attention residual, use top + /// activation features as keys, but leave the real forward ordering + /// unchanged. This tests whether computed L0 FFN features would bootstrap + /// attention addressability under an FFN-first reorder. + #[arg(long)] + address_ffn_first_feature_group_probe: bool, + + /// Comma-separated PQ groups for --address-ffn-first-feature-group-probe. + #[arg(long, default_value = "0")] + address_ffn_first_feature_groups: String, + + /// Number of FFN-first activation features retained for feature hash keys. + #[arg(long, default_value_t = 4)] + address_ffn_first_feature_top_k: usize, + + /// Fit and evaluate selected PQ groups from discrete attention/relation + /// state keys. This tests whether the dominant address is carried by QK + /// routing structure rather than token or FFN-feature state. + #[arg(long)] + address_attention_relation_group_probe: bool, + + /// Comma-separated PQ groups for --address-attention-relation-group-probe. + #[arg(long, default_value = "0")] + address_attention_relation_groups: String, + + /// Fit and evaluate selected PQ groups from learned attention-pattern + /// cluster IDs. This is a discrete relation-catalogue probe over fixed + /// features derived from the full attention distribution. + #[arg(long)] + address_attention_cluster_group_probe: bool, + + /// Comma-separated PQ groups for --address-attention-cluster-group-probe. + #[arg(long, default_value = "0")] + address_attention_cluster_groups: String, + + /// Comma-separated k values for attention-pattern clustering. + #[arg(long, default_value = "16,32")] + address_attention_cluster_ks: String, + + /// Optional comma-separated attention-cluster probe names. Empty evaluates + /// all cluster probe names for the selected k values. + #[arg(long, default_value = "")] + address_attention_cluster_probe_names: String, + + /// Fit/evaluate selected PQ groups from attention-pattern clusters where + /// the attention distribution is recomputed from only the first r Q/K + /// dimensions. Use rank 0 for the full-QK control. + #[arg(long)] + address_reduced_qk_cluster_group_probe: bool, + + /// Comma-separated PQ groups for --address-reduced-qk-cluster-group-probe. + #[arg(long, default_value = "0")] + address_reduced_qk_cluster_groups: String, + + /// Comma-separated QK ranks. Rank 0 means full QK; positive ranks are + /// clamped to the layer head dimension. + #[arg(long, default_value = "0,128,64,32,16")] + address_reduced_qk_ranks: String, + + /// Comma-separated k values for reduced-QK attention-pattern clustering. + #[arg(long, default_value = "16,32")] + address_reduced_qk_cluster_ks: String, + + /// Optional comma-separated reduced-QK cluster probe names. Empty evaluates + /// all generated names. + #[arg(long, default_value = "")] + address_reduced_qk_cluster_probe_names: String, + + /// Comma-separated PQ groups whose centroids are fit separately per + /// prompt stratum. This is a codebook-layout diagnostic for cases where a + /// single global PQ group carries a hard prose/structured tail. + #[arg(long, default_value = "")] + stratum_conditioned_pq_groups: String, + + /// Limit prompts for bounded oracle runs. + #[arg(long)] + max_prompts: Option, + + /// Keep at most N prompts per stratum after loading. Useful for balanced + /// held-out smoke runs from a larger ordered corpus. + #[arg(long)] + max_per_stratum: Option, + + /// Evaluate only prompts where prompt_index % eval_mod == eval_offset. + /// The remaining prompts are used to fit static means, PCA, and PQ. + #[arg(long)] + eval_mod: Option, + + /// Held-out modulo offset used with --eval-mod. + #[arg(long, default_value_t = 0)] + eval_offset: usize, +} + +pub(super) fn run_oracle_pq(args: OraclePqArgs) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd oracle-pq currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = parse_head_spec(&args.heads)?; + if selected_heads.is_empty() { + return Err("no heads selected for oracle PQ".into()); + } + let configs = parse_pq_configs(&args.configs)?; + if configs.is_empty() { + return Err("no PQ configs selected".into()); + } + let mut key_groups = parse_usize_list(&args.address_key_groups)?; + key_groups.sort_unstable(); + key_groups.dedup(); + let key_group_probe_names = parse_string_list(&args.address_key_group_probe_names); + if args.address_key_group_probe { + if key_groups.is_empty() { + return Err( + "--address-key-group-probe requires at least one --address-key-groups value".into(), + ); + } + for config in &configs { + for &group in &key_groups { + if group >= config.groups { + return Err(format!( + "--address-key-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut majority_groups = parse_usize_list(&args.address_majority_groups)?; + majority_groups.sort_unstable(); + majority_groups.dedup(); + if args.address_majority_group_probe { + if majority_groups.is_empty() { + return Err("--address-majority-group-probe requires at least one --address-majority-groups value".into()); + } + for config in &configs { + for &group in &majority_groups { + if group >= config.groups { + return Err(format!( + "--address-majority-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut code_substitution_groups = parse_usize_list(&args.address_code_substitution_groups)?; + code_substitution_groups.sort_unstable(); + code_substitution_groups.dedup(); + let mut code_substitution_from_codes = + parse_usize_list(&args.address_code_substitution_from_codes)?; + code_substitution_from_codes.sort_unstable(); + code_substitution_from_codes.dedup(); + let code_substitution_to_specs = + parse_code_substitution_to_specs(&args.address_code_substitution_to_codes)?; + if args.address_code_substitution_group_probe { + if code_substitution_groups.is_empty() { + return Err("--address-code-substitution-group-probe requires at least one --address-code-substitution-groups value".into()); + } + if code_substitution_to_specs.is_empty() { + return Err("--address-code-substitution-group-probe requires at least one --address-code-substitution-to-codes value".into()); + } + for config in &configs { + let levels = 1usize << config.bits_per_group; + for &group in &code_substitution_groups { + if group >= config.groups { + return Err(format!( + "--address-code-substitution-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + for &code in &code_substitution_from_codes { + if code >= levels { + return Err(format!( + "--address-code-substitution-from-codes includes code {code}, but config {:?} has only {levels} levels", + config + ) + .into()); + } + } + for spec in &code_substitution_to_specs { + if let CodeSubstitutionToSpec::Code(code) = spec { + if *code >= levels { + return Err(format!( + "--address-code-substitution-to-codes includes code {code}, but config {:?} has only {levels} levels", + config + ) + .into()); + } + } + } + } + } + let mut code_class_collapse_groups = + parse_usize_list(&args.address_code_class_collapse_groups)?; + code_class_collapse_groups.sort_unstable(); + code_class_collapse_groups.dedup(); + let code_class_collapse_specs = + parse_code_class_collapse_specs(&args.address_code_class_collapse_specs)?; + if args.address_code_class_collapse_group_probe { + if code_class_collapse_groups.is_empty() { + return Err("--address-code-class-collapse-group-probe requires at least one --address-code-class-collapse-groups value".into()); + } + if code_class_collapse_specs.is_empty() { + return Err( + "--address-code-class-collapse-specs must include at least one spec".into(), + ); + } + for config in &configs { + let levels = 1usize << config.bits_per_group; + for &group in &code_class_collapse_groups { + if group >= config.groups { + return Err(format!( + "--address-code-class-collapse-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + for spec in &code_class_collapse_specs { + for mapping in &spec.mappings { + if mapping.target >= levels { + return Err(format!( + "class-collapse spec {:?} targets code {}, but config {:?} has only {levels} levels", + spec.name, mapping.target, config + ) + .into()); + } + for &source in &mapping.sources { + if source >= levels { + return Err(format!( + "class-collapse spec {:?} includes source code {source}, but config {:?} has only {levels} levels", + spec.name, config + ) + .into()); + } + } + } + } + } + } + let mut code_position_primary_codes = + parse_usize_list(&args.address_code_position_primary_codes)?; + code_position_primary_codes.sort_unstable(); + code_position_primary_codes.dedup(); + let mut code_position_secondary_codes = + parse_usize_list(&args.address_code_position_secondary_codes)?; + code_position_secondary_codes.sort_unstable(); + code_position_secondary_codes.dedup(); + let code_position_prompt_id = args.address_code_position_prompt_id.trim().to_string(); + if args.address_code_position_interaction_probe { + if code_position_prompt_id.is_empty() { + return Err("--address-code-position-interaction-probe requires --address-code-position-prompt-id".into()); + } + if code_position_primary_codes.is_empty() { + return Err( + "--address-code-position-primary-codes must include at least one code".into(), + ); + } + if code_position_secondary_codes.is_empty() { + return Err( + "--address-code-position-secondary-codes must include at least one code".into(), + ); + } + for config in &configs { + let levels = 1usize << config.bits_per_group; + if args.address_code_position_group >= config.groups { + return Err(format!( + "--address-code-position-group is {}, but config {:?} has only {} groups", + args.address_code_position_group, config, config.groups + ) + .into()); + } + if args.address_code_position_target_code >= levels { + return Err(format!( + "--address-code-position-target-code is {}, but config {:?} has only {levels} levels", + args.address_code_position_target_code, config + ) + .into()); + } + for &code in code_position_primary_codes + .iter() + .chain(code_position_secondary_codes.iter()) + { + if code >= levels { + return Err(format!( + "--address-code-position primary/secondary code {code} exceeds config {:?} with {levels} levels", + config + ) + .into()); + } + } + } + } + let mut code_conditional_quotient_primary_codes = + parse_usize_list(&args.address_code_conditional_quotient_primary_codes)?; + code_conditional_quotient_primary_codes.sort_unstable(); + code_conditional_quotient_primary_codes.dedup(); + let mut code_conditional_quotient_secondary_codes = + parse_usize_list(&args.address_code_conditional_quotient_secondary_codes)?; + code_conditional_quotient_secondary_codes.sort_unstable(); + code_conditional_quotient_secondary_codes.dedup(); + let code_conditional_quotient_guards = + parse_conditional_quotient_guards(&args.address_code_conditional_quotient_guards)?; + let mut code_conditional_quotient_extra_specs = + parse_code_class_collapse_specs(&args.address_code_conditional_quotient_extra_specs)?; + code_conditional_quotient_extra_specs.insert( + 0, + CodeClassCollapseSpec { + name: "base".to_string(), + mappings: Vec::new(), + }, + ); + if args.address_code_conditional_quotient_group_probe { + if code_conditional_quotient_primary_codes.is_empty() { + return Err( + "--address-code-conditional-quotient-primary-codes must include at least one code" + .into(), + ); + } + if code_conditional_quotient_secondary_codes.is_empty() { + return Err("--address-code-conditional-quotient-secondary-codes must include at least one code".into()); + } + if code_conditional_quotient_guards.is_empty() { + return Err( + "--address-code-conditional-quotient-guards must include at least one guard".into(), + ); + } + for config in &configs { + let levels = 1usize << config.bits_per_group; + if args.address_code_conditional_quotient_group >= config.groups { + return Err(format!( + "--address-code-conditional-quotient-group is {}, but config {:?} has only {} groups", + args.address_code_conditional_quotient_group, config, config.groups + ) + .into()); + } + if args.address_code_conditional_quotient_target_code >= levels { + return Err(format!( + "--address-code-conditional-quotient-target-code is {}, but config {:?} has only {levels} levels", + args.address_code_conditional_quotient_target_code, config + ) + .into()); + } + for &code in code_conditional_quotient_primary_codes + .iter() + .chain(code_conditional_quotient_secondary_codes.iter()) + { + if code >= levels { + return Err(format!( + "--address-code-conditional-quotient primary/secondary code {code} exceeds config {:?} with {levels} levels", + config + ) + .into()); + } + } + for spec in &code_conditional_quotient_extra_specs { + for mapping in &spec.mappings { + if mapping.target >= levels { + return Err(format!( + "conditional quotient extra spec {:?} targets code {}, but config {:?} has only {levels} levels", + spec.name, mapping.target, config + ) + .into()); + } + for &source in &mapping.sources { + if source >= levels { + return Err(format!( + "conditional quotient extra spec {:?} includes source code {source}, but config {:?} has only {levels} levels", + spec.name, config + ) + .into()); + } + } + } + } + } + } + let mut code_occurrence_groups = parse_usize_list(&args.address_code_occurrence_groups)?; + code_occurrence_groups.sort_unstable(); + code_occurrence_groups.dedup(); + let mut code_occurrence_codes = parse_usize_list(&args.address_code_occurrence_codes)?; + code_occurrence_codes.sort_unstable(); + code_occurrence_codes.dedup(); + let code_occurrence_split = args + .address_code_occurrence_split + .trim() + .to_ascii_lowercase(); + if args.address_code_occurrences { + if code_occurrence_groups.is_empty() { + return Err( + "--address-code-occurrences requires at least one --address-code-occurrence-groups value" + .into(), + ); + } + if !matches!(code_occurrence_split.as_str(), "train" | "eval" | "all") { + return Err("--address-code-occurrence-split must be train, eval, or all".into()); + } + for config in &configs { + let levels = 1usize << config.bits_per_group; + for &group in &code_occurrence_groups { + if group >= config.groups { + return Err(format!( + "--address-code-occurrence-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + for &code in &code_occurrence_codes { + if code >= levels { + return Err(format!( + "--address-code-occurrence-codes includes code {code}, but config {:?} has only {levels} levels", + config + ) + .into()); + } + } + } + } + let mut code7_bos_rule_groups = parse_usize_list(&args.address_code7_bos_rule_groups)?; + code7_bos_rule_groups.sort_unstable(); + code7_bos_rule_groups.dedup(); + if args.address_code7_bos_rule_group_probe { + if code7_bos_rule_groups.is_empty() { + return Err("--address-code7-bos-rule-group-probe requires at least one --address-code7-bos-rule-groups value".into()); + } + for config in &configs { + let levels = 1usize << config.bits_per_group; + for &group in &code7_bos_rule_groups { + if group >= config.groups { + return Err(format!( + "--address-code7-bos-rule-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + if args.address_code7_bos_rule_code >= levels { + return Err(format!( + "--address-code7-bos-rule-code is {}, but config {:?} has only {levels} levels", + args.address_code7_bos_rule_code, config + ) + .into()); + } + } + } + let mut code7_oracle_binary_groups = + parse_usize_list(&args.address_code7_oracle_binary_groups)?; + code7_oracle_binary_groups.sort_unstable(); + code7_oracle_binary_groups.dedup(); + let code7_oracle_binary_filters = parse_string_list(&args.address_code7_oracle_binary_filters); + if args.address_code7_oracle_binary_group_probe { + if code7_oracle_binary_groups.is_empty() { + return Err("--address-code7-oracle-binary-group-probe requires at least one --address-code7-oracle-binary-groups value".into()); + } + if code7_oracle_binary_filters.is_empty() { + return Err( + "--address-code7-oracle-binary-filters must include at least one filter".into(), + ); + } + for filter in &code7_oracle_binary_filters { + if !matches!( + filter.as_str(), + "all" | "natural_prose_bos" | "natural_prose_bos_or_prev" + ) { + return Err(format!( + "unsupported --address-code7-oracle-binary-filters value {filter:?}; expected all, natural_prose_bos, or natural_prose_bos_or_prev" + ) + .into()); + } + } + for config in &configs { + let levels = 1usize << config.bits_per_group; + for &group in &code7_oracle_binary_groups { + if group >= config.groups { + return Err(format!( + "--address-code7-oracle-binary-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + if args.address_code7_oracle_binary_code >= levels { + return Err(format!( + "--address-code7-oracle-binary-code is {}, but config {:?} has only {levels} levels", + args.address_code7_oracle_binary_code, config + ) + .into()); + } + } + } + let mut lsh_groups = parse_usize_list(&args.address_lsh_groups)?; + lsh_groups.sort_unstable(); + lsh_groups.dedup(); + if args.address_lsh_group_probe { + if lsh_groups.is_empty() { + return Err( + "--address-lsh-group-probe requires at least one --address-lsh-groups value".into(), + ); + } + if args.address_lsh_bits == 0 { + return Err("--address-lsh-bits must be greater than zero".into()); + } + if args.address_lsh_bits > 16 { + return Err("--address-lsh-bits is capped at 16 for bounded diagnostics".into()); + } + if args.address_lsh_seeds == 0 { + return Err("--address-lsh-seeds must be greater than zero".into()); + } + for config in &configs { + for &group in &lsh_groups { + if group >= config.groups { + return Err(format!( + "--address-lsh-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut supervised_groups = parse_usize_list(&args.address_supervised_groups)?; + supervised_groups.sort_unstable(); + supervised_groups.dedup(); + if args.address_supervised_group_probe { + if supervised_groups.is_empty() { + return Err( + "--address-supervised-group-probe requires at least one --address-supervised-groups value".into(), + ); + } + if args.address_supervised_epochs == 0 { + return Err("--address-supervised-epochs must be greater than zero".into()); + } + if args.address_supervised_lr <= 0.0 { + return Err("--address-supervised-lr must be greater than zero".into()); + } + if args.address_supervised_l2 < 0.0 { + return Err("--address-supervised-l2 must be non-negative".into()); + } + for config in &configs { + for &group in &supervised_groups { + if group >= config.groups { + return Err(format!( + "--address-supervised-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut gamma_projected_groups = parse_usize_list(&args.address_gamma_projected_groups)?; + gamma_projected_groups.sort_unstable(); + gamma_projected_groups.dedup(); + let mut gamma_projected_layers = parse_usize_list(&args.address_gamma_projected_layers)?; + gamma_projected_layers.sort_unstable(); + gamma_projected_layers.dedup(); + let mut gamma_random_ranks = parse_usize_list(&args.address_gamma_random_ranks)?; + gamma_random_ranks.sort_unstable(); + gamma_random_ranks.dedup(); + let mut gamma_random_seeds = parse_usize_list(&args.address_gamma_random_seeds)? + .into_iter() + .map(|seed| seed as u64) + .collect::>(); + gamma_random_seeds.sort_unstable(); + gamma_random_seeds.dedup(); + let mut gamma_learned_ranks = parse_usize_list(&args.address_gamma_learned_ranks)?; + gamma_learned_ranks.sort_unstable(); + gamma_learned_ranks.dedup(); + if args.address_gamma_projected_group_probe { + if gamma_projected_groups.is_empty() { + return Err("--address-gamma-projected-group-probe requires at least one --address-gamma-projected-groups value".into()); + } + if gamma_projected_layers.is_empty() + && gamma_random_ranks.is_empty() + && gamma_learned_ranks.is_empty() + { + return Err("--address-gamma-projected-layers, --address-gamma-random-ranks, or --address-gamma-learned-ranks must include at least one value".into()); + } + if !gamma_learned_ranks.is_empty() && gamma_projected_layers.is_empty() { + return Err( + "--address-gamma-learned-ranks requires at least one --address-gamma-projected-layers value" + .into(), + ); + } + for &layer in &gamma_projected_layers { + if layer >= weights.num_layers { + return Err(format!( + "--address-gamma-projected-layers includes layer {layer}, but the model has only {} layers", + weights.num_layers + ) + .into()); + } + } + for head in &selected_heads { + for &layer in &gamma_projected_layers { + if layer < head.layer { + return Err(format!( + "--address-gamma-projected-layers includes post-L{layer}, before target L{}H{}", + head.layer, head.head + ) + .into()); + } + } + } + for &rank in &gamma_random_ranks { + if !(1..=weights.hidden_size).contains(&rank) { + return Err(format!( + "--address-gamma-random-ranks includes rank {rank}, expected 1..={}", + weights.hidden_size + ) + .into()); + } + } + if !gamma_random_ranks.is_empty() && gamma_random_seeds.is_empty() { + return Err( + "--address-gamma-random-seeds must include at least one seed when random ranks are enabled" + .into(), + ); + } + for &rank in &gamma_learned_ranks { + if !(1..=weights.hidden_size).contains(&rank) { + return Err(format!( + "--address-gamma-learned-ranks includes rank {rank}, expected 1..={}", + weights.hidden_size + ) + .into()); + } + } + if args.address_gamma_learned_epochs == 0 { + return Err("--address-gamma-learned-epochs must be greater than zero".into()); + } + if args.address_gamma_learned_lr <= 0.0 { + return Err("--address-gamma-learned-lr must be greater than zero".into()); + } + if args.address_gamma_learned_l2 < 0.0 { + return Err("--address-gamma-learned-l2 must be non-negative".into()); + } + if args.address_gamma_learned_pca_iters == 0 { + return Err("--address-gamma-learned-pca-iters must be greater than zero".into()); + } + for config in &configs { + for &group in &gamma_projected_groups { + if group >= config.groups { + return Err(format!( + "--address-gamma-projected-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut code_stability_groups = parse_usize_list(&args.address_code_stability_groups)?; + code_stability_groups.sort_unstable(); + code_stability_groups.dedup(); + if args.address_code_stability { + if code_stability_groups.is_empty() { + return Err( + "--address-code-stability requires at least one --address-code-stability-groups value" + .into(), + ); + } + for config in &configs { + for &group in &code_stability_groups { + if group >= config.groups { + return Err(format!( + "--address-code-stability-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut prev_ffn_feature_groups = parse_usize_list(&args.address_prev_ffn_feature_groups)?; + prev_ffn_feature_groups.sort_unstable(); + prev_ffn_feature_groups.dedup(); + if args.address_prev_ffn_feature_group_probe { + if prev_ffn_feature_groups.is_empty() { + return Err("--address-prev-ffn-feature-group-probe requires at least one --address-prev-ffn-feature-groups value".into()); + } + if args.address_prev_ffn_feature_top_k == 0 { + return Err("--address-prev-ffn-feature-top-k must be greater than zero".into()); + } + for head in &selected_heads { + if head.layer == 0 { + eprintln!( + "warning: L{}H{} has no previous layer; previous-FFN feature keys will be 'none'", + head.layer, head.head + ); + } + } + for config in &configs { + for &group in &prev_ffn_feature_groups { + if group >= config.groups { + return Err(format!( + "--address-prev-ffn-feature-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut ffn_first_feature_groups = parse_usize_list(&args.address_ffn_first_feature_groups)?; + ffn_first_feature_groups.sort_unstable(); + ffn_first_feature_groups.dedup(); + if args.address_ffn_first_feature_group_probe { + if ffn_first_feature_groups.is_empty() { + return Err("--address-ffn-first-feature-group-probe requires at least one --address-ffn-first-feature-groups value".into()); + } + if args.address_ffn_first_feature_top_k == 0 { + return Err("--address-ffn-first-feature-top-k must be greater than zero".into()); + } + for config in &configs { + for &group in &ffn_first_feature_groups { + if group >= config.groups { + return Err(format!( + "--address-ffn-first-feature-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut attention_relation_groups = parse_usize_list(&args.address_attention_relation_groups)?; + attention_relation_groups.sort_unstable(); + attention_relation_groups.dedup(); + if args.address_attention_relation_group_probe { + if attention_relation_groups.is_empty() { + return Err("--address-attention-relation-group-probe requires at least one --address-attention-relation-groups value".into()); + } + for config in &configs { + for &group in &attention_relation_groups { + if group >= config.groups { + return Err(format!( + "--address-attention-relation-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut attention_cluster_groups = parse_usize_list(&args.address_attention_cluster_groups)?; + attention_cluster_groups.sort_unstable(); + attention_cluster_groups.dedup(); + let mut attention_cluster_ks = parse_usize_list(&args.address_attention_cluster_ks)?; + attention_cluster_ks.sort_unstable(); + attention_cluster_ks.dedup(); + let attention_cluster_probe_names = + parse_string_list(&args.address_attention_cluster_probe_names); + if args.address_attention_cluster_group_probe { + if attention_cluster_groups.is_empty() { + return Err("--address-attention-cluster-group-probe requires at least one --address-attention-cluster-groups value".into()); + } + if attention_cluster_ks.is_empty() { + return Err("--address-attention-cluster-ks must include at least one k".into()); + } + for &cluster_count in &attention_cluster_ks { + if !(2..=128).contains(&cluster_count) { + return Err( + "--address-attention-cluster-ks values must be between 2 and 128".into(), + ); + } + } + for config in &configs { + for &group in &attention_cluster_groups { + if group >= config.groups { + return Err(format!( + "--address-attention-cluster-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut reduced_qk_cluster_groups = parse_usize_list(&args.address_reduced_qk_cluster_groups)?; + reduced_qk_cluster_groups.sort_unstable(); + reduced_qk_cluster_groups.dedup(); + let mut reduced_qk_ranks = parse_usize_list(&args.address_reduced_qk_ranks)?; + reduced_qk_ranks.sort_unstable(); + reduced_qk_ranks.dedup(); + let mut reduced_qk_cluster_ks = parse_usize_list(&args.address_reduced_qk_cluster_ks)?; + reduced_qk_cluster_ks.sort_unstable(); + reduced_qk_cluster_ks.dedup(); + let reduced_qk_cluster_probe_names = + parse_string_list(&args.address_reduced_qk_cluster_probe_names); + if args.address_reduced_qk_cluster_group_probe { + if reduced_qk_cluster_groups.is_empty() { + return Err("--address-reduced-qk-cluster-group-probe requires at least one --address-reduced-qk-cluster-groups value".into()); + } + if reduced_qk_ranks.is_empty() { + return Err("--address-reduced-qk-ranks must include at least one rank".into()); + } + if reduced_qk_cluster_ks.is_empty() { + return Err("--address-reduced-qk-cluster-ks must include at least one k".into()); + } + for &cluster_count in &reduced_qk_cluster_ks { + if !(2..=128).contains(&cluster_count) { + return Err( + "--address-reduced-qk-cluster-ks values must be between 2 and 128".into(), + ); + } + } + for config in &configs { + for &group in &reduced_qk_cluster_groups { + if group >= config.groups { + return Err(format!( + "--address-reduced-qk-cluster-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + } + let mut stratum_conditioned_pq_groups = parse_usize_list(&args.stratum_conditioned_pq_groups)?; + stratum_conditioned_pq_groups.sort_unstable(); + stratum_conditioned_pq_groups.dedup(); + for config in &configs { + for &group in &stratum_conditioned_pq_groups { + if group >= config.groups { + return Err(format!( + "--stratum-conditioned-pq-groups includes group {group}, but config {:?} has only {} groups", + config, config.groups + ) + .into()); + } + } + } + let mut prompts = load_prompts(&args.prompts, args.max_prompts)?; + if let Some(max_per_stratum) = args.max_per_stratum { + prompts = limit_prompts_per_stratum(prompts, max_per_stratum); + } + eprintln!("Selected heads: {:?}", selected_heads); + eprintln!("PQ configs: {:?}", configs); + eprintln!("Prompts: {}", prompts.len()); + let (fit_prompts, eval_prompts): (Vec, Vec) = + if let Some(eval_mod) = args.eval_mod { + split_prompt_records(&prompts, eval_mod, args.eval_offset)? + } else { + (prompts.clone(), prompts.clone()) + }; + eprintln!( + "Oracle PQ split: fit_prompts={}, eval_prompts={}", + fit_prompts.len(), + eval_prompts.len() + ); + + eprintln!("Fitting position-mean static bases"); + let means = fit_static_means( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + )?; + + eprintln!("Building W_O-visible bases"); + let bases = + build_roundtrip_bases(&mut weights, &index, &selected_heads, args.sigma_rel_cutoff)?; + + eprintln!("Fitting empirical z-space PCA bases"); + let pca_bases = fit_z_pca_bases( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + )?; + + eprintln!("Fitting product quantizers"); + let codebooks = fit_pq_codebooks( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &configs, + args.pq_iters, + &stratum_conditioned_pq_groups, + )?; + let mode_d_tables = if args.mode_d_check { + eprintln!("Materializing Mode D residual-space tables"); + materialize_mode_d_tables( + &mut weights, + &index, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &stratum_conditioned_pq_groups, + )? + } else { + HashMap::new() + }; + let run_address_probes = + args.address_probes || args.address_mixed_key_probe || args.address_key_group_probe; + let address_probe_models = if run_address_probes { + if !args.mode_d_check { + return Err( + "--address-probes/--address-mixed-key-probe requires --mode-d-check".into(), + ); + } + eprintln!("Fitting graph-native address probes"); + fit_address_probe_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + args.address_mixed_key_probe, + )? + } else { + HashMap::new() + }; + let address_lsh_models = if args.address_lsh_group_probe { + if !args.mode_d_check { + return Err("--address-lsh-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting LSH group address probes for groups {:?} (bits={}, seeds={})", + lsh_groups, args.address_lsh_bits, args.address_lsh_seeds + ); + fit_address_lsh_group_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &lsh_groups, + args.address_lsh_bits, + args.address_lsh_seeds, + )? + } else { + HashMap::new() + }; + let address_supervised_models = if args.address_supervised_group_probe { + if !args.mode_d_check { + return Err("--address-supervised-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting supervised group address probes for groups {:?} (epochs={}, lr={}, l2={})", + supervised_groups, + args.address_supervised_epochs, + args.address_supervised_lr, + args.address_supervised_l2 + ); + fit_address_supervised_group_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &supervised_groups, + args.address_supervised_epochs, + args.address_supervised_lr, + args.address_supervised_l2, + )? + } else { + HashMap::new() + }; + let address_gamma_projected_models = if args.address_gamma_projected_group_probe { + if !args.mode_d_check { + return Err("--address-gamma-projected-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting gamma-projected supervised group address probes for groups {:?} (post_layers={:?}, random_ranks={:?}, random_seeds={:?}, learned_ranks={:?}, learned_epochs={}, learned_lr={}, learned_l2={}, learned_pca_iters={}, epochs={}, lr={}, l2={})", + gamma_projected_groups, + gamma_projected_layers, + gamma_random_ranks, + gamma_random_seeds, + gamma_learned_ranks, + args.address_gamma_learned_epochs, + args.address_gamma_learned_lr, + args.address_gamma_learned_l2, + args.address_gamma_learned_pca_iters, + args.address_supervised_epochs, + args.address_supervised_lr, + args.address_supervised_l2 + ); + fit_gamma_projected_address_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &gamma_projected_groups, + &gamma_projected_layers, + &gamma_random_ranks, + &gamma_random_seeds, + &gamma_learned_ranks, + args.address_gamma_learned_epochs, + args.address_gamma_learned_lr, + args.address_gamma_learned_l2, + args.address_gamma_learned_pca_iters, + args.address_supervised_epochs, + args.address_supervised_lr, + args.address_supervised_l2, + )? + } else { + HashMap::new() + }; + let address_prev_ffn_feature_models = if args.address_prev_ffn_feature_group_probe { + if !args.mode_d_check { + return Err("--address-prev-ffn-feature-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting previous-FFN feature group address probes for groups {:?} (top_k={})", + prev_ffn_feature_groups, args.address_prev_ffn_feature_top_k + ); + fit_address_prev_ffn_feature_group_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &prev_ffn_feature_groups, + args.address_prev_ffn_feature_top_k, + )? + } else { + HashMap::new() + }; + let address_ffn_first_feature_models = if args.address_ffn_first_feature_group_probe { + if !args.mode_d_check { + return Err("--address-ffn-first-feature-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting FFN-first feature group address probes for groups {:?} (top_k={})", + ffn_first_feature_groups, args.address_ffn_first_feature_top_k + ); + fit_address_ffn_first_feature_group_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &ffn_first_feature_groups, + args.address_ffn_first_feature_top_k, + )? + } else { + HashMap::new() + }; + let address_attention_relation_models = if args.address_attention_relation_group_probe { + if !args.mode_d_check { + return Err("--address-attention-relation-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting attention-relation group address probes for groups {:?}", + attention_relation_groups + ); + fit_address_attention_relation_group_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &attention_relation_groups, + )? + } else { + HashMap::new() + }; + let address_attention_cluster_models = if args.address_attention_cluster_group_probe { + if !args.mode_d_check { + return Err("--address-attention-cluster-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting attention-pattern cluster group address probes for groups {:?} (k={:?})", + attention_cluster_groups, attention_cluster_ks + ); + fit_address_attention_cluster_group_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &attention_cluster_groups, + &attention_cluster_ks, + )? + } else { + HashMap::new() + }; + let address_reduced_qk_cluster_models = if args.address_reduced_qk_cluster_group_probe { + if !args.mode_d_check { + return Err("--address-reduced-qk-cluster-group-probe requires --mode-d-check".into()); + } + eprintln!( + "Fitting reduced-QK cluster group address probes for groups {:?} (ranks={:?}, k={:?})", + reduced_qk_cluster_groups, reduced_qk_ranks, reduced_qk_cluster_ks + ); + fit_address_reduced_qk_cluster_group_models( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &reduced_qk_cluster_groups, + &reduced_qk_ranks, + &reduced_qk_cluster_ks, + )? + } else { + HashMap::new() + }; + if args.address_corruption_sweep && !args.mode_d_check { + return Err("--address-corruption-sweep requires --mode-d-check".into()); + } + if args.address_group_importance && !args.mode_d_check { + return Err("--address-group-importance requires --mode-d-check".into()); + } + if args.address_majority_group_probe && !args.mode_d_check { + return Err("--address-majority-group-probe requires --mode-d-check".into()); + } + if args.address_code_substitution_group_probe && !args.mode_d_check { + return Err("--address-code-substitution-group-probe requires --mode-d-check".into()); + } + if args.address_code_class_collapse_group_probe && !args.mode_d_check { + return Err("--address-code-class-collapse-group-probe requires --mode-d-check".into()); + } + if args.address_code_position_interaction_probe && !args.mode_d_check { + return Err("--address-code-position-interaction-probe requires --mode-d-check".into()); + } + if args.address_code_conditional_quotient_group_probe && !args.mode_d_check { + return Err( + "--address-code-conditional-quotient-group-probe requires --mode-d-check".into(), + ); + } + if args.address_code7_bos_rule_group_probe && !args.mode_d_check { + return Err("--address-code7-bos-rule-group-probe requires --mode-d-check".into()); + } + if args.address_code7_oracle_binary_group_probe && !args.mode_d_check { + return Err("--address-code7-oracle-binary-group-probe requires --mode-d-check".into()); + } + let majority_codes = if args.address_corruption_sweep + || args.address_group_importance + || args.address_lsh_group_probe + || args.address_supervised_group_probe + || args.address_gamma_projected_group_probe + || args.address_key_group_probe + || args.address_majority_group_probe + || args.address_code_substitution_group_probe + || args.address_code_class_collapse_group_probe + || args.address_code_position_interaction_probe + || args.address_code_conditional_quotient_group_probe + || args.address_code7_bos_rule_group_probe + || args.address_code7_oracle_binary_group_probe + || args.address_prev_ffn_feature_group_probe + || args.address_ffn_first_feature_group_probe + || args.address_attention_relation_group_probe + || args.address_attention_cluster_group_probe + || args.address_reduced_qk_cluster_group_probe + { + eprintln!("Fitting per-group majority codes for address diagnostics"); + fit_majority_codes_for_codebooks( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + )? + } else { + HashMap::new() + }; + let code_stability = if args.address_code_stability { + eprintln!( + "Measuring PQ code stability for groups {:?}", + code_stability_groups + ); + measure_code_stability( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &eval_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &code_stability_groups, + )? + } else { + HashMap::new() + }; + + if args.address_code_occurrences { + let occurrence_prompts = match code_occurrence_split.as_str() { + "train" => fit_prompts.clone(), + "eval" => eval_prompts.clone(), + "all" => prompts.clone(), + _ => unreachable!("validated code occurrence split"), + }; + eprintln!( + "Exporting code occurrences for groups {:?}, codes {:?}, split {}", + code_occurrence_groups, code_occurrence_codes, code_occurrence_split + ); + let occurrences = collect_code_occurrences( + &mut weights, + &index, + &tokenizer, + &occurrence_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &codebooks, + &code_occurrence_groups, + &code_occurrence_codes, + )?; + let occurrence_path = args.out.join("code_occurrences.json"); + let file = std::fs::File::create(&occurrence_path)?; + serde_json::to_writer_pretty(file, &occurrences)?; + eprintln!("Wrote {}", occurrence_path.display()); + } + + let mut accumulators: HashMap<(HeadId, PqConfig), OraclePqPointAccumulator> = HashMap::new(); + for head in &selected_heads { + for &config in &configs { + accumulators.insert((*head, config), OraclePqPointAccumulator::new()); + } + } + + for (prompt_idx, record) in eval_prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" [{}/{}] {}", prompt_idx + 1, eval_prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + let baseline_top1 = argmax(&baseline_logits); + let baseline_top2 = top_k_indices(&baseline_logits, 2); + let baseline_top2_token = baseline_top2.get(1).copied().unwrap_or(baseline_top1); + let baseline_top1_prob = token_prob(&baseline_logp, baseline_top1); + let baseline_top2_prob = token_prob(&baseline_logp, baseline_top2_token); + let baseline_top1_margin = baseline_top1_prob - baseline_top2_prob; + + for head in &selected_heads { + let basis = bases.get(head).ok_or_else(|| { + format!("missing basis for oracle PQ L{} H{}", head.layer, head.head) + })?; + let head_means = means.get(head).ok_or_else(|| { + format!( + "missing position means for oracle PQ L{} H{}", + head.layer, head.head + ) + })?; + let pca_basis = pca_bases.get(head).ok_or_else(|| { + format!( + "missing empirical PCA basis for oracle PQ L{} H{}", + head.layer, head.head + ) + })?; + for &config in &configs { + let codebook = codebooks.get(&(*head, config)).ok_or_else(|| { + format!("missing PQ codebook for L{} H{}", head.layer, head.head) + })?; + let (pq_hidden, metrics, oracle_codes_by_position) = forward_q4k_oracle_pq_head( + &mut weights, + &token_ids, + &index, + *head, + basis, + pca_basis, + head_means, + codebook, + stratum, + )?; + let pq_logits = final_logits(&weights, &pq_hidden); + let pq_logp = log_softmax(&pq_logits); + let kl = kl_logp(&baseline_logp, &pq_logp); + let pq_top1 = argmax(&pq_logits); + let pq_top5 = top_k_indices(&pq_logits, 5); + let pq_top2 = top_k_indices(&pq_logits, 2); + let pq_top2_token = pq_top2.get(1).copied().unwrap_or(pq_top1); + let pq_top1_prob = token_prob(&pq_logp, pq_top1); + let pq_top2_prob = token_prob(&pq_logp, pq_top2_token); + let pq_top1_margin = pq_top1_prob - pq_top2_prob; + let pq_prob_of_baseline_top1 = token_prob(&pq_logp, baseline_top1); + + let ( + mode_d_kl, + mode_d_top1, + mode_d_top1_agree, + baseline_top1_in_mode_d_top5, + coeff_mode_d_max_abs_logit_diff, + ) = if args.mode_d_check { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let mode_d_hidden = forward_q4k_oracle_pq_mode_d_head( + &mut weights, + &token_ids, + &index, + *head, + basis, + pca_basis, + head_means, + codebook, + mode_d_table, + stratum, + )?; + let mode_d_logits = final_logits(&weights, &mode_d_hidden); + let mode_d_logp = log_softmax(&mode_d_logits); + let mode_d_top1 = argmax(&mode_d_logits); + let mode_d_top5 = top_k_indices(&mode_d_logits, 5); + ( + Some(kl_logp(&baseline_logp, &mode_d_logp)), + Some(mode_d_top1), + Some(baseline_top1 == mode_d_top1), + Some(mode_d_top5.contains(&baseline_top1)), + Some(max_abs_diff(&pq_logits, &mode_d_logits)), + ) + } else { + (None, None, None, None, None) + }; + + if run_address_probes { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for address probes L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let probe_models = + address_probe_models.get(&(*head, config)).ok_or_else(|| { + format!( + "missing address probe models for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + for probe_model in probe_models { + let full_probe_enabled = + args.address_probes || probe_model.name == "mixed_best_simple_key"; + if full_probe_enabled { + let predicted_codes_by_position = (0..token_ids.len()) + .map(|pos| probe_model.predict_codes(&token_ids, stratum, pos)) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_model.name, + &probe_model.selected_group_keys, + prompt_report, + ); + } + if args.address_key_group_probe { + if !key_group_probe_names.is_empty() + && !key_group_probe_names.contains(&probe_model.name) + { + continue; + } + let group_majority = + majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for key group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + for (probe_name, use_oracle_rest) in [ + ( + format!( + "{}_groups_{:?}_oracle_rest", + probe_model.name, key_groups + ), + true, + ), + ( + format!( + "{}_groups_{:?}_majority_rest", + probe_model.name, key_groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = if use_oracle_rest { + oracle_codes.clone() + } else { + group_majority.clone() + }; + let probe_codes = + probe_model.predict_codes(&token_ids, stratum, pos); + for &group in &key_groups { + codes[group] = probe_codes[group]; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_name, + &probe_model.selected_group_keys, + prompt_report, + ); + } + } + } + } + + if args.address_majority_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for majority group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for majority group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .map(|oracle_codes| { + let mut codes = oracle_codes.clone(); + for &group in &majority_groups { + codes[group] = group_majority[group]; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + let selected_group_keys = (0..config.groups) + .map(|group| { + if majority_groups.contains(&group) { + "majority".to_string() + } else { + "oracle".to_string() + } + }) + .collect::>(); + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &format!("majority_groups_{:?}_oracle_rest", majority_groups), + &selected_group_keys, + prompt_report, + ); + } + + if args.address_code_substitution_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for code substitution probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for code substitution probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let levels = 1usize << config.bits_per_group; + let from_codes = if code_substitution_from_codes.is_empty() { + (0..levels).collect::>() + } else { + code_substitution_from_codes.clone() + }; + for &group in &code_substitution_groups { + for &from_code in &from_codes { + let source_code_present = oracle_codes_by_position + .iter() + .any(|codes| codes[group] == from_code); + for to_spec in &code_substitution_to_specs { + let to_code = match *to_spec { + CodeSubstitutionToSpec::Majority => group_majority[group], + CodeSubstitutionToSpec::Code(code) => code, + }; + if to_code == from_code { + continue; + } + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .map(|oracle_codes| { + let mut codes = oracle_codes.clone(); + if codes[group] == from_code { + codes[group] = to_code; + } + codes + }) + .collect::>(); + let prompt_report = if source_code_present { + evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )? + } else { + oracle_mode_d_address_report( + label, + stratum, + token_ids.len(), + config.groups, + mode_d_kl.unwrap_or(kl), + mode_d_top1_agree.unwrap_or(false), + baseline_top1_in_mode_d_top5.unwrap_or(false), + ) + }; + let to_label = match *to_spec { + CodeSubstitutionToSpec::Majority => { + format!("majority{}", group_majority[group]) + } + CodeSubstitutionToSpec::Code(code) => code.to_string(), + }; + let selected_group_keys = (0..config.groups) + .map(|candidate_group| { + if candidate_group == group { + format!("from{from_code}_to{to_label}") + } else { + "oracle".to_string() + } + }) + .collect::>(); + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &format!( + "code_subst_g{group}_from{from_code}_to{to_label}_oracle_rest" + ), + &selected_group_keys, + prompt_report, + ); + } + } + } + } + + if args.address_code_class_collapse_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for code class-collapse probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + for collapse_spec in &code_class_collapse_specs { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .map(|oracle_codes| { + let mut codes = oracle_codes.clone(); + for &group in &code_class_collapse_groups { + for mapping in &collapse_spec.mappings { + if mapping.sources.contains(&oracle_codes[group]) { + codes[group] = mapping.target; + break; + } + } + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + let selected_group_keys = (0..config.groups) + .map(|group| { + if code_class_collapse_groups.contains(&group) { + collapse_spec.mapping_label() + } else { + "oracle".to_string() + } + }) + .collect::>(); + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &format!( + "code_class_collapse_{}_groups_{:?}_oracle_rest", + collapse_spec.name, code_class_collapse_groups + ), + &selected_group_keys, + prompt_report, + ); + } + } + + if args.address_code_position_interaction_probe + && label == code_position_prompt_id.as_str() + { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for code position-interaction probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group = args.address_code_position_group; + let target_code = args.address_code_position_target_code; + let primary_positions = oracle_codes_by_position + .iter() + .enumerate() + .filter_map(|(pos, codes)| { + code_position_primary_codes + .contains(&codes[group]) + .then_some(pos) + }) + .collect::>(); + let secondary_positions = oracle_codes_by_position + .iter() + .enumerate() + .filter_map(|(pos, codes)| { + code_position_secondary_codes + .contains(&codes[group]) + .then_some(pos) + }) + .collect::>(); + + let mut emit_position_variant = + |variant_name: String, + mut changed_positions: Vec| + -> Result<(), Box> { + changed_positions.sort_unstable(); + changed_positions.dedup(); + if changed_positions.is_empty() { + return Ok(()); + } + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = oracle_codes.clone(); + if changed_positions.binary_search(&pos).is_ok() { + codes[group] = target_code; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + let selected_group_keys = (0..config.groups) + .map(|candidate_group| { + if candidate_group == group { + format!( + "{variant_name}_positions_{}", + changed_positions + .iter() + .map(ToString::to_string) + .collect::>() + .join("+") + ) + } else { + "oracle".to_string() + } + }) + .collect::>(); + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &format!( + "pos_interaction_g{group}_{variant_name}_to{target_code}_oracle_rest" + ), + &selected_group_keys, + prompt_report, + ); + Ok(()) + }; + + emit_position_variant("A0_all_primary".to_string(), primary_positions.clone())?; + emit_position_variant( + "A1_all_secondary".to_string(), + secondary_positions.clone(), + )?; + let mut all_primary_secondary = primary_positions.clone(); + all_primary_secondary.extend(secondary_positions.iter().copied()); + emit_position_variant( + "A2_all_primary_all_secondary".to_string(), + all_primary_secondary, + )?; + for (idx, &secondary_pos) in secondary_positions.iter().enumerate() { + let mut changed = primary_positions.clone(); + changed.push(secondary_pos); + emit_position_variant( + format!("A{}_all_primary_secondary_pos{secondary_pos}", idx + 3), + changed, + )?; + } + let leave_one_offset = 3 + secondary_positions.len(); + for (idx, &secondary_pos) in secondary_positions.iter().enumerate() { + let mut changed = primary_positions.clone(); + changed.extend( + secondary_positions + .iter() + .copied() + .filter(|pos| *pos != secondary_pos), + ); + emit_position_variant( + format!( + "A{}_all_primary_all_secondary_except_pos{secondary_pos}", + leave_one_offset + idx + ), + changed, + )?; + } + for &primary_pos in &primary_positions { + let mut changed = secondary_positions.clone(); + changed.push(primary_pos); + emit_position_variant( + format!("all_secondary_primary_pos{primary_pos}"), + changed, + )?; + } + for &primary_pos in &primary_positions { + let mut changed = secondary_positions.clone(); + changed.extend( + primary_positions + .iter() + .copied() + .filter(|pos| *pos != primary_pos), + ); + emit_position_variant( + format!("all_primary_except_pos{primary_pos}_all_secondary"), + changed, + )?; + } + } + + if args.address_code_conditional_quotient_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for code conditional-quotient probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group = args.address_code_conditional_quotient_group; + let target_code = args.address_code_conditional_quotient_target_code; + let early_position_max = + args.address_code_conditional_quotient_early_position_max; + let attention_rows = + capture_attention_relation_rows(&mut weights, &token_ids, &index, *head)?; + for &guard in &code_conditional_quotient_guards { + for extra_spec in &code_conditional_quotient_extra_specs { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = oracle_codes.clone(); + let group_code = oracle_codes[group]; + if code_conditional_quotient_primary_codes.contains(&group_code) + { + codes[group] = target_code; + } else if code_conditional_quotient_secondary_codes + .contains(&group_code) + && !guard.keeps_secondary_oracle( + stratum, + pos, + early_position_max, + attention_rows + .get(pos) + .map(Vec::as_slice) + .unwrap_or(&[]), + ) + { + codes[group] = target_code; + } + for mapping in &extra_spec.mappings { + if mapping.sources.contains(&group_code) { + codes[group] = mapping.target; + break; + } + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + let selected_group_keys = (0..config.groups) + .map(|candidate_group| { + if candidate_group == group { + format!( + "{}_primary{}_secondary{}_to{}_extra{}", + guard.label(), + code_conditional_quotient_primary_codes + .iter() + .map(ToString::to_string) + .collect::>() + .join("+"), + code_conditional_quotient_secondary_codes + .iter() + .map(ToString::to_string) + .collect::>() + .join("+"), + target_code, + extra_spec.mapping_label_or_base() + ) + } else { + "oracle".to_string() + } + }) + .collect::>(); + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &format!( + "code_conditional_quotient_g{group}_{}_extra{}_to{target_code}_oracle_rest", + guard.label(), + extra_spec.name + ), + &selected_group_keys, + prompt_report, + ); + } + } + } + + if args.address_code7_bos_rule_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for code7 BOS rule probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for code7 BOS rule probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let attention_rows = + capture_attention_relation_rows(&mut weights, &token_ids, &index, *head)?; + let use_special_code = stratum != "arithmetic"; + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = oracle_codes.clone(); + let attention_weights = + attention_rows.get(pos).map(Vec::as_slice).unwrap_or(&[]); + let predicts_special = use_special_code + && !attention_weights.is_empty() + && attention_argmax(attention_weights, pos) == 0; + for &group in &code7_bos_rule_groups { + codes[group] = if predicts_special { + args.address_code7_bos_rule_code + } else { + group_majority[group] + }; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + let selected_group_keys = (0..config.groups) + .map(|group| { + if code7_bos_rule_groups.contains(&group) { + format!( + "bos_non_arithmetic_to_code{}_else_majority{}", + args.address_code7_bos_rule_code, group_majority[group] + ) + } else { + "oracle".to_string() + } + }) + .collect::>(); + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &format!( + "code{}_bos_non_arithmetic_groups_{:?}_oracle_rest", + args.address_code7_bos_rule_code, code7_bos_rule_groups + ), + &selected_group_keys, + prompt_report, + ); + } + + if args.address_code7_oracle_binary_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for code7 oracle binary probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for code7 oracle binary probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let attention_rows = + capture_attention_relation_rows(&mut weights, &token_ids, &index, *head)?; + for filter in &code7_oracle_binary_filters { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = oracle_codes.clone(); + let attention_weights = + attention_rows.get(pos).map(Vec::as_slice).unwrap_or(&[]); + let relation_matches = match filter.as_str() { + "all" => true, + "natural_prose_bos" => { + stratum == "natural_prose" + && !attention_weights.is_empty() + && attention_argmax(attention_weights, pos) == 0 + } + "natural_prose_bos_or_prev" => { + stratum == "natural_prose" + && (!attention_weights.is_empty() + && (attention_argmax(attention_weights, pos) == 0 + || attention_argmax(attention_weights, pos) + == pos.saturating_sub(1))) + } + _ => unreachable!("validated oracle binary filter"), + }; + for &group in &code7_oracle_binary_groups { + codes[group] = if relation_matches + && oracle_codes[group] + == args.address_code7_oracle_binary_code + { + args.address_code7_oracle_binary_code + } else { + group_majority[group] + }; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + let selected_group_keys = (0..config.groups) + .map(|group| { + if code7_oracle_binary_groups.contains(&group) { + format!( + "oracle_{}_code{}_else_majority{}", + filter, + args.address_code7_oracle_binary_code, + group_majority[group] + ) + } else { + "oracle".to_string() + } + }) + .collect::>(); + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &format!( + "oracle_binary_{}_code{}_groups_{:?}_oracle_rest", + filter, + args.address_code7_oracle_binary_code, + code7_oracle_binary_groups + ), + &selected_group_keys, + prompt_report, + ); + } + } + + if args.address_group_importance { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for address group importance L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for address group importance L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + for replaced_group in 0..config.groups { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .map(|codes| { + codes + .iter() + .enumerate() + .map(|(group, &code)| { + if group == replaced_group { + group_majority[group] + } else { + code + } + }) + .collect::>() + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_group_importance(replaced_group, prompt_report); + } + } + + if args.address_lsh_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for LSH group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let lsh_model = address_lsh_models.get(&(*head, config)).ok_or_else(|| { + format!( + "missing LSH group probe model for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for LSH group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let layer_input = + capture_layer_input_hidden(&mut weights, &token_ids, &index, head.layer)?; + let selected_group_keys = lsh_model.selected_group_keys(); + for (probe_name, use_oracle_rest) in [ + ( + format!("lsh_groups_{:?}_oracle_rest", lsh_model.groups), + true, + ), + ( + format!("lsh_groups_{:?}_majority_rest", lsh_model.groups), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let base_codes = if use_oracle_rest { + oracle_codes.as_slice() + } else { + group_majority.as_slice() + }; + lsh_model.predict_selected_groups(&layer_input, pos, base_codes) + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe(&probe_name, &selected_group_keys, prompt_report); + } + } + + if args.address_supervised_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for supervised group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let supervised_model = address_supervised_models + .get(&(*head, config)) + .ok_or_else(|| { + format!( + "missing supervised group probe model for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for supervised group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let layer_input = + capture_layer_input_hidden(&mut weights, &token_ids, &index, head.layer)?; + let selected_group_keys = supervised_model.selected_group_keys(); + for (probe_name, use_oracle_rest) in [ + ( + format!( + "supervised_hyperplane_groups_{:?}_oracle_rest", + supervised_model.groups + ), + true, + ), + ( + format!( + "supervised_hyperplane_groups_{:?}_majority_rest", + supervised_model.groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let base_codes = if use_oracle_rest { + oracle_codes.as_slice() + } else { + group_majority.as_slice() + }; + supervised_model.predict_selected_groups( + &layer_input, + pos, + base_codes, + ) + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe(&probe_name, &selected_group_keys, prompt_report); + } + } + + if args.address_gamma_projected_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for gamma-projected group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let gamma_models = address_gamma_projected_models + .get(&(*head, config)) + .ok_or_else(|| { + format!( + "missing gamma-projected group probe models for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for gamma-projected group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let layer_input = + capture_layer_input_hidden(&mut weights, &token_ids, &index, head.layer)?; + for gamma_model in gamma_models { + let projected_input = gamma_model.project_layer_input(&layer_input)?; + let selected_group_keys = gamma_model.selected_group_keys(); + for (probe_name, use_oracle_rest) in [ + ( + format!( + "{}_groups_{:?}_oracle_rest", + gamma_model.name, gamma_projected_groups + ), + true, + ), + ( + format!( + "{}_groups_{:?}_majority_rest", + gamma_model.name, gamma_projected_groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let base_codes = if use_oracle_rest { + oracle_codes.as_slice() + } else { + group_majority.as_slice() + }; + gamma_model.supervised.predict_selected_groups( + &projected_input, + pos, + base_codes, + ) + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_name, + &selected_group_keys, + prompt_report, + ); + } + } + } + + if args.address_prev_ffn_feature_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for previous-FFN feature group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let prev_feature_models = address_prev_ffn_feature_models + .get(&(*head, config)) + .ok_or_else(|| { + format!( + "missing previous-FFN feature group probe model for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for previous-FFN feature group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let prev_features_by_position = capture_prev_ffn_feature_keys( + &mut weights, + &token_ids, + &index, + head.layer, + args.address_prev_ffn_feature_top_k, + )?; + for probe_model in prev_feature_models { + let selected_group_keys = probe_model.selected_group_keys.clone(); + for (probe_name, use_oracle_rest) in [ + ( + format!( + "{}_groups_{:?}_oracle_rest", + probe_model.name, prev_ffn_feature_groups + ), + true, + ), + ( + format!( + "{}_groups_{:?}_majority_rest", + probe_model.name, prev_ffn_feature_groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = if use_oracle_rest { + oracle_codes.clone() + } else { + group_majority.clone() + }; + let prev_features = prev_features_by_position + .get(pos) + .map(Vec::as_slice) + .unwrap_or(&[]); + let key = prev_ffn_feature_key( + &probe_model.name, + &token_ids, + stratum, + pos, + prev_features, + ); + let probe_codes = probe_model.predict_codes_from_key(&key); + for &group in &prev_ffn_feature_groups { + codes[group] = probe_codes[group]; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_name, + &selected_group_keys, + prompt_report, + ); + } + } + } + + if args.address_ffn_first_feature_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for FFN-first feature group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let ffn_first_models = address_ffn_first_feature_models + .get(&(*head, config)) + .ok_or_else(|| { + format!( + "missing FFN-first feature group probe model for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for FFN-first feature group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let ffn_first_features_by_position = capture_ffn_first_feature_keys( + &mut weights, + &token_ids, + &index, + head.layer, + args.address_ffn_first_feature_top_k, + )?; + for probe_model in ffn_first_models { + let selected_group_keys = probe_model.selected_group_keys.clone(); + for (probe_name, use_oracle_rest) in [ + ( + format!( + "{}_groups_{:?}_oracle_rest", + probe_model.name, ffn_first_feature_groups + ), + true, + ), + ( + format!( + "{}_groups_{:?}_majority_rest", + probe_model.name, ffn_first_feature_groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = if use_oracle_rest { + oracle_codes.clone() + } else { + group_majority.clone() + }; + let ffn_first_features = ffn_first_features_by_position + .get(pos) + .map(Vec::as_slice) + .unwrap_or(&[]); + let key = ffn_first_feature_key( + &probe_model.name, + &token_ids, + stratum, + pos, + ffn_first_features, + ); + let probe_codes = probe_model.predict_codes_from_key(&key); + for &group in &ffn_first_feature_groups { + codes[group] = probe_codes[group]; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_name, + &selected_group_keys, + prompt_report, + ); + } + } + } + + if args.address_attention_relation_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for attention-relation group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let relation_models = address_attention_relation_models + .get(&(*head, config)) + .ok_or_else(|| { + format!( + "missing attention-relation group probe model for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for attention-relation group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let attention_rows = + capture_attention_relation_rows(&mut weights, &token_ids, &index, *head)?; + for probe_model in relation_models { + let selected_group_keys = probe_model.selected_group_keys.clone(); + for (probe_name, use_oracle_rest) in [ + ( + format!( + "{}_groups_{:?}_oracle_rest", + probe_model.name, attention_relation_groups + ), + true, + ), + ( + format!( + "{}_groups_{:?}_majority_rest", + probe_model.name, attention_relation_groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let mut codes = if use_oracle_rest { + oracle_codes.clone() + } else { + group_majority.clone() + }; + let attention_weights = + attention_rows.get(pos).map(Vec::as_slice).unwrap_or(&[]); + let key = attention_relation_key( + &probe_model.name, + &token_ids, + stratum, + pos, + attention_weights, + ); + let probe_codes = probe_model.predict_codes_from_key(&key); + for &group in &attention_relation_groups { + codes[group] = probe_codes[group]; + } + codes + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_name, + &selected_group_keys, + prompt_report, + ); + } + } + } + + if args.address_attention_cluster_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for attention-cluster group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let cluster_models = address_attention_cluster_models + .get(&(*head, config)) + .ok_or_else(|| { + format!( + "missing attention-cluster group probe model for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for attention-cluster group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let attention_rows = + capture_attention_relation_rows(&mut weights, &token_ids, &index, *head)?; + for cluster_model in cluster_models { + if !attention_cluster_probe_names.is_empty() + && !attention_cluster_probe_names.contains(&cluster_model.name) + { + continue; + } + let selected_group_keys = cluster_model.selected_group_keys.clone(); + for (probe_name, use_oracle_rest) in [ + ( + format!( + "{}_groups_{:?}_oracle_rest", + cluster_model.name, attention_cluster_groups + ), + true, + ), + ( + format!( + "{}_groups_{:?}_majority_rest", + cluster_model.name, attention_cluster_groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let base_codes = if use_oracle_rest { + oracle_codes.as_slice() + } else { + group_majority.as_slice() + }; + let attention_weights = + attention_rows.get(pos).map(Vec::as_slice).unwrap_or(&[]); + cluster_model.predict_selected_groups( + &token_ids, + stratum, + pos, + attention_weights, + base_codes, + ) + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_name, + &selected_group_keys, + prompt_report, + ); + } + } + } + + if args.address_reduced_qk_cluster_group_probe { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for reduced-QK cluster group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let cluster_models = address_reduced_qk_cluster_models + .get(&(*head, config)) + .ok_or_else(|| { + format!( + "missing reduced-QK cluster group probe model for L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for reduced-QK cluster group probe L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let mut rows_by_rank: HashMap, Vec>> = HashMap::new(); + for cluster_model in cluster_models { + if !reduced_qk_cluster_probe_names.is_empty() + && !reduced_qk_cluster_probe_names.contains(&cluster_model.name) + { + continue; + } + if !rows_by_rank.contains_key(&cluster_model.qk_rank) { + let rows = if let Some(qk_rank) = cluster_model.qk_rank { + capture_reduced_qk_attention_rows( + &mut weights, + &token_ids, + &index, + *head, + qk_rank, + )? + } else { + capture_attention_relation_rows( + &mut weights, + &token_ids, + &index, + *head, + )? + }; + rows_by_rank.insert(cluster_model.qk_rank, rows); + } + let attention_rows = rows_by_rank + .get(&cluster_model.qk_rank) + .expect("reduced-QK rows were just inserted"); + let selected_group_keys = cluster_model.selected_group_keys.clone(); + for (probe_name, use_oracle_rest) in [ + ( + format!( + "{}_groups_{:?}_oracle_rest", + cluster_model.name, reduced_qk_cluster_groups + ), + true, + ), + ( + format!( + "{}_groups_{:?}_majority_rest", + cluster_model.name, reduced_qk_cluster_groups + ), + false, + ), + ] { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .enumerate() + .map(|(pos, oracle_codes)| { + let base_codes = if use_oracle_rest { + oracle_codes.as_slice() + } else { + group_majority.as_slice() + }; + let attention_weights = + attention_rows.get(pos).map(Vec::as_slice).unwrap_or(&[]); + cluster_model.predict_selected_groups( + &token_ids, + stratum, + pos, + attention_weights, + base_codes, + ) + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_probe( + &probe_name, + &selected_group_keys, + prompt_report, + ); + } + } + } + + if args.address_corruption_sweep { + let mode_d_table = mode_d_tables.get(&(*head, config)).ok_or_else(|| { + format!( + "missing Mode D table for address corruption L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let group_majority = majority_codes.get(&(*head, config)).ok_or_else(|| { + format!( + "missing majority codes for address corruption L{} H{} {:?}", + head.layer, head.head, config + ) + })?; + let keep_values = corruption_keep_values(config.groups); + for oracle_groups_kept in keep_values { + let predicted_codes_by_position = oracle_codes_by_position + .iter() + .map(|codes| { + codes + .iter() + .enumerate() + .map(|(group, &code)| { + if group < oracle_groups_kept { + code + } else { + group_majority[group] + } + }) + .collect::>() + }) + .collect::>(); + let prompt_report = evaluate_predicted_address( + &mut weights, + &token_ids, + &index, + *head, + mode_d_table, + &predicted_codes_by_position, + stratum, + label, + &baseline_logp, + baseline_top1, + &oracle_codes_by_position, + )?; + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add_address_corruption(oracle_groups_kept, prompt_report); + } + } + + accumulators + .get_mut(&(*head, config)) + .expect("oracle PQ accumulator missing") + .add(OraclePqPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl, + delta_cross_entropy_bits: kl / std::f64::consts::LN_2, + baseline_top1, + pq_top1, + top1_agree: baseline_top1 == pq_top1, + baseline_top1_in_pq_top5: pq_top5.contains(&baseline_top1), + baseline_top1_prob, + baseline_top2: baseline_top2_token, + baseline_top2_prob, + baseline_top1_margin, + pq_top1_prob, + pq_prob_of_baseline_top1, + pq_top1_margin, + mode_d_kl, + mode_d_top1, + mode_d_top1_agree, + baseline_top1_in_mode_d_top5, + coeff_mode_d_max_abs_logit_diff, + pre_wo_l2: metrics.pre_wo_l2, + wo_visible_l2: metrics.wo_visible_l2, + }); + } + } + } + + let mut head_reports = Vec::new(); + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{} H{}", head.layer, head.head))?; + let static_train_samples = means.get(head).map(|m| m.count).unwrap_or(0); + let mut points = Vec::new(); + for &config in &configs { + let acc = accumulators + .remove(&(*head, config)) + .expect("oracle PQ accumulator missing at finish"); + let stability = code_stability + .get(&(*head, config)) + .cloned() + .unwrap_or_default(); + points.push(acc.finish(config, weights.hidden_size, stability)); + } + head_reports.push(OraclePqHeadReport { + layer: head.layer, + head: head.head, + head_dim: basis.head_dim, + rank_retained: basis.rank_retained(), + empirical_rank: pca_basis.rank(), + sigma_max: basis.sigma_max, + sigma_min_retained: basis.sigma_min_retained, + static_train_samples, + points, + }); + } + + let report = OraclePqReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen: prompts.len(), + train_prompts_seen: fit_prompts.len(), + eval_prompts_seen: eval_prompts.len(), + max_per_stratum: args.max_per_stratum, + eval_mod: args.eval_mod, + eval_offset: args.eval_offset, + static_base: "position_mean".to_string(), + configs, + sigma_rel_cutoff: args.sigma_rel_cutoff, + pq_iters: args.pq_iters, + mode_d_check: args.mode_d_check, + address_probes: args.address_probes, + address_mixed_key_probe: args.address_mixed_key_probe, + address_key_group_probe: args.address_key_group_probe, + address_key_groups: if args.address_key_group_probe { + key_groups + } else { + Vec::new() + }, + address_key_group_probe_names: if args.address_key_group_probe { + key_group_probe_names + } else { + Vec::new() + }, + address_majority_group_probe: args.address_majority_group_probe, + address_majority_groups: if args.address_majority_group_probe { + majority_groups + } else { + Vec::new() + }, + address_code_substitution_group_probe: args.address_code_substitution_group_probe, + address_code_substitution_groups: if args.address_code_substitution_group_probe { + code_substitution_groups + } else { + Vec::new() + }, + address_code_substitution_from_codes: if args.address_code_substitution_group_probe { + code_substitution_from_codes + } else { + Vec::new() + }, + address_code_substitution_to_codes: if args.address_code_substitution_group_probe { + code_substitution_to_specs + .into_iter() + .map(|spec| match spec { + CodeSubstitutionToSpec::Majority => "majority".to_string(), + CodeSubstitutionToSpec::Code(code) => code.to_string(), + }) + .collect() + } else { + Vec::new() + }, + address_code_class_collapse_group_probe: args.address_code_class_collapse_group_probe, + address_code_class_collapse_groups: if args.address_code_class_collapse_group_probe { + code_class_collapse_groups + } else { + Vec::new() + }, + address_code_class_collapse_specs: if args.address_code_class_collapse_group_probe { + code_class_collapse_specs + .iter() + .map(CodeClassCollapseSpec::label) + .collect() + } else { + Vec::new() + }, + address_code_position_interaction_probe: args.address_code_position_interaction_probe, + address_code_position_prompt_id: if args.address_code_position_interaction_probe { + code_position_prompt_id + } else { + String::new() + }, + address_code_position_group: if args.address_code_position_interaction_probe { + args.address_code_position_group + } else { + 0 + }, + address_code_position_primary_codes: if args.address_code_position_interaction_probe { + code_position_primary_codes + } else { + Vec::new() + }, + address_code_position_secondary_codes: if args.address_code_position_interaction_probe { + code_position_secondary_codes + } else { + Vec::new() + }, + address_code_position_target_code: if args.address_code_position_interaction_probe { + args.address_code_position_target_code + } else { + 0 + }, + address_code_conditional_quotient_group_probe: args + .address_code_conditional_quotient_group_probe, + address_code_conditional_quotient_group: if args + .address_code_conditional_quotient_group_probe + { + args.address_code_conditional_quotient_group + } else { + 0 + }, + address_code_conditional_quotient_primary_codes: if args + .address_code_conditional_quotient_group_probe + { + code_conditional_quotient_primary_codes + } else { + Vec::new() + }, + address_code_conditional_quotient_secondary_codes: if args + .address_code_conditional_quotient_group_probe + { + code_conditional_quotient_secondary_codes + } else { + Vec::new() + }, + address_code_conditional_quotient_target_code: if args + .address_code_conditional_quotient_group_probe + { + args.address_code_conditional_quotient_target_code + } else { + 0 + }, + address_code_conditional_quotient_early_position_max: if args + .address_code_conditional_quotient_group_probe + { + args.address_code_conditional_quotient_early_position_max + } else { + 0 + }, + address_code_conditional_quotient_guards: if args + .address_code_conditional_quotient_group_probe + { + code_conditional_quotient_guards + .iter() + .map(|guard| guard.label().to_string()) + .collect() + } else { + Vec::new() + }, + address_code_conditional_quotient_extra_specs: if args + .address_code_conditional_quotient_group_probe + { + code_conditional_quotient_extra_specs + .iter() + .map(CodeClassCollapseSpec::label) + .collect() + } else { + Vec::new() + }, + address_code7_bos_rule_group_probe: args.address_code7_bos_rule_group_probe, + address_code7_bos_rule_groups: if args.address_code7_bos_rule_group_probe { + code7_bos_rule_groups + } else { + Vec::new() + }, + address_code7_bos_rule_code: if args.address_code7_bos_rule_group_probe { + args.address_code7_bos_rule_code + } else { + 0 + }, + address_code7_oracle_binary_group_probe: args.address_code7_oracle_binary_group_probe, + address_code7_oracle_binary_groups: if args.address_code7_oracle_binary_group_probe { + code7_oracle_binary_groups + } else { + Vec::new() + }, + address_code7_oracle_binary_code: if args.address_code7_oracle_binary_group_probe { + args.address_code7_oracle_binary_code + } else { + 0 + }, + address_code7_oracle_binary_filters: if args.address_code7_oracle_binary_group_probe { + code7_oracle_binary_filters + } else { + Vec::new() + }, + address_corruption_sweep: args.address_corruption_sweep, + address_group_importance: args.address_group_importance, + address_lsh_group_probe: args.address_lsh_group_probe, + address_lsh_groups: if args.address_lsh_group_probe { + lsh_groups + } else { + Vec::new() + }, + address_lsh_bits: args.address_lsh_bits, + address_lsh_seeds: args.address_lsh_seeds, + address_supervised_group_probe: args.address_supervised_group_probe, + address_supervised_groups: if args.address_supervised_group_probe { + supervised_groups + } else { + Vec::new() + }, + address_supervised_epochs: args.address_supervised_epochs, + address_supervised_lr: args.address_supervised_lr, + address_supervised_l2: args.address_supervised_l2, + address_gamma_projected_group_probe: args.address_gamma_projected_group_probe, + address_gamma_projected_groups: if args.address_gamma_projected_group_probe { + gamma_projected_groups + } else { + Vec::new() + }, + address_gamma_projected_layers: if args.address_gamma_projected_group_probe { + gamma_projected_layers + } else { + Vec::new() + }, + address_gamma_random_ranks: if args.address_gamma_projected_group_probe { + gamma_random_ranks + } else { + Vec::new() + }, + address_gamma_random_seeds: if args.address_gamma_projected_group_probe { + gamma_random_seeds + } else { + Vec::new() + }, + address_gamma_learned_ranks: if args.address_gamma_projected_group_probe { + gamma_learned_ranks + } else { + Vec::new() + }, + address_gamma_learned_epochs: if args.address_gamma_projected_group_probe { + args.address_gamma_learned_epochs + } else { + 0 + }, + address_gamma_learned_lr: if args.address_gamma_projected_group_probe { + args.address_gamma_learned_lr + } else { + 0.0 + }, + address_gamma_learned_l2: if args.address_gamma_projected_group_probe { + args.address_gamma_learned_l2 + } else { + 0.0 + }, + address_gamma_learned_pca_iters: if args.address_gamma_projected_group_probe { + args.address_gamma_learned_pca_iters + } else { + 0 + }, + address_code_stability: args.address_code_stability, + address_code_stability_groups: if args.address_code_stability { + code_stability_groups + } else { + Vec::new() + }, + address_prev_ffn_feature_group_probe: args.address_prev_ffn_feature_group_probe, + address_prev_ffn_feature_groups: if args.address_prev_ffn_feature_group_probe { + prev_ffn_feature_groups + } else { + Vec::new() + }, + address_prev_ffn_feature_top_k: args.address_prev_ffn_feature_top_k, + address_ffn_first_feature_group_probe: args.address_ffn_first_feature_group_probe, + address_ffn_first_feature_groups: if args.address_ffn_first_feature_group_probe { + ffn_first_feature_groups + } else { + Vec::new() + }, + address_ffn_first_feature_top_k: args.address_ffn_first_feature_top_k, + address_attention_relation_group_probe: args.address_attention_relation_group_probe, + address_attention_relation_groups: if args.address_attention_relation_group_probe { + attention_relation_groups + } else { + Vec::new() + }, + address_attention_cluster_group_probe: args.address_attention_cluster_group_probe, + address_attention_cluster_groups: if args.address_attention_cluster_group_probe { + attention_cluster_groups + } else { + Vec::new() + }, + address_attention_cluster_ks: if args.address_attention_cluster_group_probe { + attention_cluster_ks + } else { + Vec::new() + }, + address_attention_cluster_probe_names: if args.address_attention_cluster_group_probe { + attention_cluster_probe_names + } else { + Vec::new() + }, + address_reduced_qk_cluster_group_probe: args.address_reduced_qk_cluster_group_probe, + address_reduced_qk_cluster_groups: if args.address_reduced_qk_cluster_group_probe { + reduced_qk_cluster_groups + } else { + Vec::new() + }, + address_reduced_qk_ranks: if args.address_reduced_qk_cluster_group_probe { + reduced_qk_ranks + } else { + Vec::new() + }, + address_reduced_qk_cluster_ks: if args.address_reduced_qk_cluster_group_probe { + reduced_qk_cluster_ks + } else { + Vec::new() + }, + address_reduced_qk_cluster_probe_names: if args.address_reduced_qk_cluster_group_probe { + reduced_qk_cluster_probe_names + } else { + Vec::new() + }, + stratum_conditioned_pq_groups, + selected_heads, + heads: head_reports, + }; + + let out_path = args.out.join("oracle_pq.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +fn parse_string_list(spec: &str) -> Vec { + spec.split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + .map(ToString::to_string) + .collect() +} + +fn oracle_mode_d_address_report( + label: &str, + stratum: &str, + positions: usize, + groups: usize, + kl: f64, + top1_agree: bool, + baseline_top1_in_predicted_top5: bool, +) -> AddressProbePromptReport { + AddressProbePromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl, + positions, + groups_correct: positions * groups, + groups_total: positions * groups, + exact_address_match: true, + top1_agree, + baseline_top1_in_predicted_top5, + } +} + +#[derive(Debug, Clone)] +struct CodeClassCollapseSpec { + name: String, + mappings: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ConditionalQuotientGuard { + EarlyProsePosition, + EarlyProseBosPrev, + ProseBosPrev, +} + +impl ConditionalQuotientGuard { + fn parse(raw: &str) -> Option { + match raw.trim() { + "early_prose_position" | "E_early_prose_position_guard" => { + Some(ConditionalQuotientGuard::EarlyProsePosition) + } + "early_prose_bos_prev" | "F_early_prose_bos_prev_guard" => { + Some(ConditionalQuotientGuard::EarlyProseBosPrev) + } + "prose_bos_prev" | "G_prose_bos_prev_guard" => { + Some(ConditionalQuotientGuard::ProseBosPrev) + } + _ => None, + } + } + + fn label(self) -> &'static str { + match self { + ConditionalQuotientGuard::EarlyProsePosition => "E_early_prose_position_guard", + ConditionalQuotientGuard::EarlyProseBosPrev => "F_early_prose_bos_prev_guard", + ConditionalQuotientGuard::ProseBosPrev => "G_prose_bos_prev_guard", + } + } + + fn keeps_secondary_oracle( + self, + stratum: &str, + pos: usize, + early_position_max: usize, + attention_weights: &[f32], + ) -> bool { + if stratum != "natural_prose" { + return false; + } + let is_early = pos <= early_position_max; + match self { + ConditionalQuotientGuard::EarlyProsePosition => is_early, + ConditionalQuotientGuard::EarlyProseBosPrev => { + is_early && is_bos_or_previous_attention(pos, attention_weights) + } + ConditionalQuotientGuard::ProseBosPrev => { + is_bos_or_previous_attention(pos, attention_weights) + } + } + } +} + +fn is_bos_or_previous_attention(pos: usize, attention_weights: &[f32]) -> bool { + if attention_weights.is_empty() { + return false; + } + let source = attention_argmax(attention_weights, pos); + source == 0 || (pos > 0 && source + 1 == pos) +} + +impl CodeClassCollapseSpec { + fn label(&self) -> String { + format!("{}={}", self.name, self.mapping_label()) + } + + fn mapping_label(&self) -> String { + self.mappings + .iter() + .map(|mapping| { + let sources = mapping + .sources + .iter() + .map(ToString::to_string) + .collect::>() + .join("+"); + format!("{sources}:{}", mapping.target) + }) + .collect::>() + .join("|") + } + + fn mapping_label_or_base(&self) -> String { + if self.mappings.is_empty() { + "base".to_string() + } else { + self.mapping_label() + } + } +} + +#[derive(Debug, Clone)] +struct CodeClassCollapseMapping { + sources: Vec, + target: usize, +} + +fn parse_code_class_collapse_specs( + spec: &str, +) -> Result, Box> { + let mut out = Vec::new(); + for (idx, raw_spec) in spec + .split(';') + .map(str::trim) + .filter(|part| !part.is_empty()) + .enumerate() + { + let (raw_name, raw_mappings) = raw_spec + .split_once('=') + .map(|(name, mappings)| (name.trim(), mappings.trim())) + .unwrap_or(("", raw_spec)); + let mappings = parse_code_class_collapse_mappings(raw_mappings)?; + let fallback_name = sanitize_probe_name( + &mappings + .iter() + .map(|mapping| { + let sources = mapping + .sources + .iter() + .map(ToString::to_string) + .collect::>() + .join("+"); + format!("{sources}_to_{}", mapping.target) + }) + .collect::>() + .join("_and_"), + ); + let name = if raw_name.is_empty() { + format!("collapse{idx}_{fallback_name}") + } else { + sanitize_probe_name(raw_name) + }; + if name.is_empty() { + return Err(format!("invalid empty class-collapse name in spec {raw_spec:?}").into()); + } + out.push(CodeClassCollapseSpec { name, mappings }); + } + Ok(out) +} + +fn parse_conditional_quotient_guards( + spec: &str, +) -> Result, Box> { + let mut out = Vec::new(); + for raw in spec + .split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + { + let guard = ConditionalQuotientGuard::parse(raw).ok_or_else(|| { + format!( + "unsupported conditional quotient guard {raw:?}; expected early_prose_position, early_prose_bos_prev, or prose_bos_prev" + ) + })?; + if !out.contains(&guard) { + out.push(guard); + } + } + Ok(out) +} + +fn parse_code_class_collapse_mappings( + spec: &str, +) -> Result, Box> { + let mut mappings = Vec::new(); + let mut seen_sources = Vec::new(); + for raw_mapping in spec + .split('|') + .map(str::trim) + .filter(|part| !part.is_empty()) + { + let (raw_sources, raw_target) = raw_mapping.split_once(':').ok_or_else(|| { + format!("invalid class-collapse mapping {raw_mapping:?}; expected sources:target") + })?; + let mut sources = Vec::new(); + for part in raw_sources + .split('+') + .map(str::trim) + .filter(|part| !part.is_empty()) + { + sources + .push(part.parse::().map_err(|err| { + format!("invalid class-collapse source code {part:?}: {err}") + })?); + } + sources.sort_unstable(); + sources.dedup(); + if sources.is_empty() { + return Err(format!("class-collapse mapping {raw_mapping:?} has no sources").into()); + } + for &source in &sources { + if seen_sources.contains(&source) { + return Err(format!( + "class-collapse source code {source} appears in more than one mapping" + ) + .into()); + } + seen_sources.push(source); + } + let target = raw_target.trim().parse::().map_err(|err| { + format!( + "invalid class-collapse target code {:?}: {err}", + raw_target.trim() + ) + })?; + mappings.push(CodeClassCollapseMapping { sources, target }); + } + if mappings.is_empty() { + return Err(format!("class-collapse spec {spec:?} has no mappings").into()); + } + Ok(mappings) +} + +fn sanitize_probe_name(name: &str) -> String { + name.chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' { + ch + } else { + '_' + } + }) + .collect() +} + +#[derive(Debug, Clone, Copy)] +enum CodeSubstitutionToSpec { + Majority, + Code(usize), +} + +fn parse_code_substitution_to_specs( + spec: &str, +) -> Result, Box> { + let mut out = Vec::new(); + for part in spec + .split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + { + if part.eq_ignore_ascii_case("majority") { + out.push(CodeSubstitutionToSpec::Majority); + } else { + out.push(CodeSubstitutionToSpec::Code( + part.parse::() + .map_err(|err| format!("invalid code substitution target {part:?}: {err}"))?, + )); + } + } + Ok(out) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_address.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_address.rs new file mode 100644 index 00000000..259d6e06 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_address.rs @@ -0,0 +1,1433 @@ +use std::collections::HashMap; + +use larql_inference::attention::{ + run_attention_block_with_pre_o, run_attention_block_with_pre_o_and_all_attention_weights, + run_attention_block_with_pre_o_and_reduced_qk_attention_weights, +}; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_ffn, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::VectorIndex; +use ndarray::{s, ArrayView1}; + +use super::address::{ + address_feature_key, address_probe_names, attention_argmax, attention_cluster_key, + attention_cluster_probe_names, attention_entropy_bits, attention_pattern_features, + attention_relation_key, attention_relation_probe_names, ffn_first_feature_key, + ffn_first_feature_probe_names, lsh_bucket, nearest_attention_cluster, + predict_code_from_hyperplanes, prev_ffn_feature_key, prev_ffn_feature_probe_names, + top_feature_ids_from_activation_row, train_binary_hyperplane, + AddressAttentionClusterGroupModel, AddressLshGroupModel, AddressProbeModel, + AddressSupervisedGroupModel, +}; +use super::basis::{WoRoundtripBasis, ZPcaBasis}; +use super::metrics::argmax_usize; +use super::pq::{kmeans_centroids, PqCodebook}; +use super::reports::CodeOccurrenceRecord; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PqConfig, PromptRecord}; + +type SampleVisitResult = Result<(), Box>; + +#[derive(Debug, Clone)] +struct AttentionClusterFitSample { + features: Vec, + codes: Vec, + token_ids: Vec, + stratum: String, + position: usize, +} + +pub(super) fn fit_address_probe_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + include_mixed_key_probe: bool, +) -> Result>, Box> { + let names = address_probe_names(); + let mut key_counts: HashMap<(HeadId, PqConfig, String, usize, String), Vec> = + HashMap::new(); + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "address-fit", + false, + 0, + 0, + false, + None, + |head, config, pos, codes, token_ids, stratum, _, _, _, _, _| { + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + for name in &names { + let key = address_feature_key(name, token_ids, stratum, pos); + let counts = key_counts + .entry((head, config, (*name).to_string(), group, key)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + } + Ok(()) + }, + )?; + + let mut models = HashMap::new(); + for ((head, config), _) in codebooks { + let mut probe_models = Vec::new(); + for name in &names { + let mut group_majority = Vec::with_capacity(config.groups); + let mut group_maps = Vec::with_capacity(config.groups); + let mut group_train_accuracy = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + let mut map = HashMap::new(); + let mut correct = 0usize; + let mut total = 0usize; + for ((map_head, map_config, map_name, map_group, key), counts) in key_counts.iter() + { + if map_head == head + && map_config == config + && map_name == name + && *map_group == group + { + let best = argmax_usize(counts); + correct += counts[best]; + total += counts.iter().sum::(); + map.insert(key.clone(), best); + } + } + group_maps.push(map); + group_train_accuracy.push(if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + }); + } + probe_models.push(AddressProbeModel { + name: (*name).to_string(), + group_majority, + group_maps, + group_train_accuracy, + selected_group_keys: Vec::new(), + }); + } + if include_mixed_key_probe && !probe_models.is_empty() { + let mut group_majority = Vec::with_capacity(config.groups); + let mut group_maps = Vec::with_capacity(config.groups); + let mut group_train_accuracy = Vec::with_capacity(config.groups); + let mut selected_group_keys = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let best_idx = probe_models + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| { + a.group_train_accuracy[group] + .partial_cmp(&b.group_train_accuracy[group]) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(idx, _)| idx) + .unwrap_or(0); + let best = &probe_models[best_idx]; + group_majority.push(best.group_majority[group]); + group_maps.push(best.group_maps[group].clone()); + group_train_accuracy.push(best.group_train_accuracy[group]); + selected_group_keys.push(best.name.clone()); + } + probe_models.push(AddressProbeModel { + name: "mixed_best_simple_key".to_string(), + group_majority, + group_maps, + group_train_accuracy, + selected_group_keys, + }); + } + models.insert((*head, *config), probe_models); + } + + Ok(models) +} + +pub(super) fn fit_address_prev_ffn_feature_group_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + feature_top_k: usize, +) -> Result>, Box> { + let names = prev_ffn_feature_probe_names(); + let mut key_counts: HashMap<(HeadId, PqConfig, String, usize, String), Vec> = + HashMap::new(); + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "prev-ffn-feature-fit", + false, + feature_top_k, + 0, + false, + None, + |head, config, pos, codes, token_ids, stratum, _, _, prev_features, _, _| { + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + let prev_features = prev_features.unwrap_or(&[]); + for &group in selected_groups { + let code = codes[group]; + for name in &names { + let key = prev_ffn_feature_key(name, token_ids, stratum, pos, prev_features); + let levels = 1usize << config.bits_per_group; + let counts = key_counts + .entry((head, config, (*name).to_string(), group, key)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + } + Ok(()) + }, + )?; + + let mut models = HashMap::new(); + for ((head, config), _) in codebooks { + let mut probe_models = Vec::new(); + for name in &names { + let mut group_majority = Vec::with_capacity(config.groups); + let mut group_maps = vec![HashMap::new(); config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + } + for &group in selected_groups { + let mut map = HashMap::new(); + let mut correct = 0usize; + let mut total = 0usize; + for ((map_head, map_config, map_name, map_group, key), counts) in key_counts.iter() + { + if map_head == head + && map_config == config + && map_name == name + && *map_group == group + { + let best = argmax_usize(counts); + correct += counts[best]; + total += counts.iter().sum::(); + map.insert(key.clone(), best); + } + } + group_maps[group] = map; + group_train_accuracy[group] = if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + }; + } + let selected_group_keys = (0..config.groups) + .map(|group| { + if selected_groups.contains(&group) { + format!("{}_train_acc_{:.3}", name, group_train_accuracy[group]) + } else { + "majority".to_string() + } + }) + .collect(); + probe_models.push(AddressProbeModel { + name: (*name).to_string(), + group_majority, + group_maps, + group_train_accuracy, + selected_group_keys, + }); + } + models.insert((*head, *config), probe_models); + } + + Ok(models) +} + +pub(super) fn fit_address_ffn_first_feature_group_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + feature_top_k: usize, +) -> Result>, Box> { + let names = ffn_first_feature_probe_names(); + let mut key_counts: HashMap<(HeadId, PqConfig, String, usize, String), Vec> = + HashMap::new(); + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "ffn-first-feature-fit", + false, + 0, + feature_top_k, + false, + None, + |head, config, pos, codes, token_ids, stratum, _, _, _, ffn_first_features, _| { + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + let ffn_first_features = ffn_first_features.unwrap_or(&[]); + for &group in selected_groups { + let code = codes[group]; + for name in &names { + let key = + ffn_first_feature_key(name, token_ids, stratum, pos, ffn_first_features); + let levels = 1usize << config.bits_per_group; + let counts = key_counts + .entry((head, config, (*name).to_string(), group, key)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + } + Ok(()) + }, + )?; + + let mut models = HashMap::new(); + for ((head, config), _) in codebooks { + let mut probe_models = Vec::new(); + for name in &names { + let mut group_majority = Vec::with_capacity(config.groups); + let mut group_maps = vec![HashMap::new(); config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + } + for &group in selected_groups { + let mut map = HashMap::new(); + let mut correct = 0usize; + let mut total = 0usize; + for ((map_head, map_config, map_name, map_group, key), counts) in key_counts.iter() + { + if map_head == head + && map_config == config + && map_name == name + && *map_group == group + { + let best = argmax_usize(counts); + correct += counts[best]; + total += counts.iter().sum::(); + map.insert(key.clone(), best); + } + } + group_maps[group] = map; + group_train_accuracy[group] = if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + }; + } + let selected_group_keys = (0..config.groups) + .map(|group| { + if selected_groups.contains(&group) { + format!("{}_train_acc_{:.3}", name, group_train_accuracy[group]) + } else { + "majority".to_string() + } + }) + .collect(); + probe_models.push(AddressProbeModel { + name: (*name).to_string(), + group_majority, + group_maps, + group_train_accuracy, + selected_group_keys, + }); + } + models.insert((*head, *config), probe_models); + } + + Ok(models) +} + +pub(super) fn fit_address_attention_relation_group_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], +) -> Result>, Box> { + let names = attention_relation_probe_names(); + let mut key_counts: HashMap<(HeadId, PqConfig, String, usize, String), Vec> = + HashMap::new(); + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "attention-relation-fit", + false, + 0, + 0, + true, + None, + |head, config, pos, codes, token_ids, stratum, _, _, _, _, attention_weights| { + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + let attention_weights = + attention_weights.ok_or("missing attention row during relation address fit")?; + for &group in selected_groups { + let code = codes[group]; + for name in &names { + let key = + attention_relation_key(name, token_ids, stratum, pos, attention_weights); + let levels = 1usize << config.bits_per_group; + let counts = key_counts + .entry((head, config, (*name).to_string(), group, key)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + } + Ok(()) + }, + )?; + + let mut models = HashMap::new(); + for ((head, config), _) in codebooks { + let mut probe_models = Vec::new(); + for name in &names { + let mut group_majority = Vec::with_capacity(config.groups); + let mut group_maps = vec![HashMap::new(); config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + } + for &group in selected_groups { + let mut map = HashMap::new(); + let mut correct = 0usize; + let mut total = 0usize; + for ((map_head, map_config, map_name, map_group, key), counts) in key_counts.iter() + { + if map_head == head + && map_config == config + && map_name == name + && *map_group == group + { + let best = argmax_usize(counts); + correct += counts[best]; + total += counts.iter().sum::(); + map.insert(key.clone(), best); + } + } + group_maps[group] = map; + group_train_accuracy[group] = if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + }; + } + let selected_group_keys = (0..config.groups) + .map(|group| { + if selected_groups.contains(&group) { + format!("{}_train_acc_{:.3}", name, group_train_accuracy[group]) + } else { + "majority".to_string() + } + }) + .collect(); + probe_models.push(AddressProbeModel { + name: (*name).to_string(), + group_majority, + group_maps, + group_train_accuracy, + selected_group_keys, + }); + } + models.insert((*head, *config), probe_models); + } + + Ok(models) +} + +pub(super) fn fit_address_attention_cluster_group_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + cluster_counts: &[usize], +) -> Result< + HashMap<(HeadId, PqConfig), Vec>, + Box, +> { + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + let mut samples: HashMap<(HeadId, PqConfig), Vec> = HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "attention-cluster-fit", + false, + 0, + 0, + true, + None, + |head, config, pos, codes, token_ids, stratum, _, _, _, _, attention_weights| { + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + let attention_weights = + attention_weights.ok_or("missing attention row during cluster address fit")?; + samples + .entry((head, config)) + .or_default() + .push(AttentionClusterFitSample { + features: attention_pattern_features(attention_weights, pos), + codes: codes.to_vec(), + token_ids: token_ids.to_vec(), + stratum: stratum.to_string(), + position: pos, + }); + Ok(()) + }, + )?; + + let mut models = HashMap::new(); + for ((head, config), _) in codebooks { + let train_samples = samples.get(&(*head, *config)).cloned().unwrap_or_default(); + let feature_rows = train_samples + .iter() + .map(|sample| sample.features.clone()) + .collect::>(); + let mut group_majority = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + } + + let mut cluster_models = Vec::new(); + for &cluster_count in cluster_counts { + let centroids = kmeans_centroids(&feature_rows, cluster_count, 25); + let assignments = train_samples + .iter() + .map(|sample| nearest_attention_cluster(&sample.features, ¢roids)) + .collect::>(); + for name in attention_cluster_probe_names(cluster_count) { + let mut key_counts: HashMap<(usize, String), Vec> = HashMap::new(); + for (sample, &cluster) in train_samples.iter().zip(assignments.iter()) { + let key = attention_cluster_key( + &name, + &sample.token_ids, + &sample.stratum, + sample.position, + cluster, + ); + for &group in selected_groups { + let levels = 1usize << config.bits_per_group; + let counts = key_counts + .entry((group, key.clone())) + .or_insert_with(|| vec![0; levels]); + counts[sample.codes[group]] += 1; + } + } + + let mut group_maps = vec![HashMap::new(); config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for &group in selected_groups { + let mut correct = 0usize; + let mut total = 0usize; + for ((map_group, key), counts) in key_counts.iter() { + if *map_group == group { + let best = argmax_usize(counts); + correct += counts[best]; + total += counts.iter().sum::(); + group_maps[group].insert(key.clone(), best); + } + } + group_train_accuracy[group] = if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + }; + } + let selected_group_keys = (0..config.groups) + .map(|group| { + if selected_groups.contains(&group) { + format!("{}_train_acc_{:.3}", name, group_train_accuracy[group]) + } else { + "majority".to_string() + } + }) + .collect(); + cluster_models.push(AddressAttentionClusterGroupModel { + name, + groups: selected_groups.to_vec(), + qk_rank: None, + centroids: centroids.clone(), + group_majority: group_majority.clone(), + group_maps, + selected_group_keys, + }); + } + } + models.insert((*head, *config), cluster_models); + } + + Ok(models) +} + +pub(super) fn fit_address_reduced_qk_cluster_group_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + qk_ranks: &[usize], + cluster_counts: &[usize], +) -> Result< + HashMap<(HeadId, PqConfig), Vec>, + Box, +> { + let mut models: HashMap<(HeadId, PqConfig), Vec> = + HashMap::new(); + + for &qk_rank in qk_ranks { + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + let mut samples: HashMap<(HeadId, PqConfig), Vec> = + HashMap::new(); + + let label = if qk_rank == 0 { + "full-qk-cluster-fit".to_string() + } else { + format!("reduced-qk-r{qk_rank}-cluster-fit") + }; + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + &label, + false, + 0, + 0, + true, + if qk_rank == 0 { None } else { Some(qk_rank) }, + |head, config, pos, codes, token_ids, stratum, _, _, _, _, attention_weights| { + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + let attention_weights = + attention_weights.ok_or("missing attention row during reduced-QK fit")?; + samples + .entry((head, config)) + .or_default() + .push(AttentionClusterFitSample { + features: attention_pattern_features(attention_weights, pos), + codes: codes.to_vec(), + token_ids: token_ids.to_vec(), + stratum: stratum.to_string(), + position: pos, + }); + Ok(()) + }, + )?; + + for ((head, config), _) in codebooks { + let train_samples = samples.get(&(*head, *config)).cloned().unwrap_or_default(); + let feature_rows = train_samples + .iter() + .map(|sample| sample.features.clone()) + .collect::>(); + let mut group_majority = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + } + + let rank_prefix = if qk_rank == 0 { + "qk_full".to_string() + } else { + format!("qk_rank{qk_rank}") + }; + let entry = models.entry((*head, *config)).or_default(); + for &cluster_count in cluster_counts { + let centroids = kmeans_centroids(&feature_rows, cluster_count, 25); + let assignments = train_samples + .iter() + .map(|sample| nearest_attention_cluster(&sample.features, ¢roids)) + .collect::>(); + for base_name in attention_cluster_probe_names(cluster_count) { + let name = format!("{rank_prefix}_{base_name}"); + let mut key_counts: HashMap<(usize, String), Vec> = HashMap::new(); + for (sample, &cluster) in train_samples.iter().zip(assignments.iter()) { + let key = attention_cluster_key( + &base_name, + &sample.token_ids, + &sample.stratum, + sample.position, + cluster, + ); + for &group in selected_groups { + let levels = 1usize << config.bits_per_group; + let counts = key_counts + .entry((group, key.clone())) + .or_insert_with(|| vec![0; levels]); + counts[sample.codes[group]] += 1; + } + } + + let mut group_maps = vec![HashMap::new(); config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for &group in selected_groups { + let mut correct = 0usize; + let mut total = 0usize; + for ((map_group, key), counts) in key_counts.iter() { + if *map_group == group { + let best = argmax_usize(counts); + correct += counts[best]; + total += counts.iter().sum::(); + group_maps[group].insert(key.clone(), best); + } + } + group_train_accuracy[group] = if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + }; + } + let selected_group_keys = (0..config.groups) + .map(|group| { + if selected_groups.contains(&group) { + format!("{name}_train_acc_{:.3}", group_train_accuracy[group]) + } else { + "majority".to_string() + } + }) + .collect(); + entry.push(AddressAttentionClusterGroupModel { + name, + groups: selected_groups.to_vec(), + qk_rank: if qk_rank == 0 { None } else { Some(qk_rank) }, + centroids: centroids.clone(), + group_majority: group_majority.clone(), + group_maps, + selected_group_keys, + }); + } + } + } + } + + Ok(models) +} + +pub(super) fn fit_address_lsh_group_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + bits: usize, + seeds: usize, +) -> Result, Box> { + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + let mut bucket_counts: HashMap<(HeadId, PqConfig, usize, u64, usize), Vec> = + HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "lsh-fit", + true, + 0, + 0, + false, + None, + |head, config, _pos, codes, _token_ids, _stratum, _, input_row, _, _, _| { + let input_row = input_row.ok_or("missing layer-input row during LSH address fit")?; + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + for &group in selected_groups { + let code = codes[group]; + for seed in 0..seeds { + let bucket = lsh_bucket(ArrayView1::from(input_row), seed as u64, bits); + let levels = 1usize << config.bits_per_group; + let counts = bucket_counts + .entry((head, config, group, seed as u64, bucket)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + } + Ok(()) + }, + )?; + + let mut models = HashMap::new(); + for ((head, config), _) in codebooks { + let mut group_majority = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + } + + let mut group_maps = vec![HashMap::new(); config.groups]; + let mut group_seeds = vec![0_u64; config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for &group in selected_groups { + let mut best_seed = 0_u64; + let mut best_accuracy = -1.0_f64; + let mut best_map = HashMap::new(); + for seed in 0..seeds { + let seed = seed as u64; + let mut map = HashMap::new(); + let mut correct = 0usize; + let mut total = 0usize; + for ((map_head, map_config, map_group, map_seed, bucket), counts) in + bucket_counts.iter() + { + if map_head == head + && map_config == config + && *map_group == group + && *map_seed == seed + { + let best = argmax_usize(counts); + correct += counts[best]; + total += counts.iter().sum::(); + map.insert(*bucket, best); + } + } + let accuracy = if total == 0 { + 0.0 + } else { + correct as f64 / total as f64 + }; + if accuracy > best_accuracy { + best_accuracy = accuracy; + best_seed = seed; + best_map = map; + } + } + group_maps[group] = best_map; + group_seeds[group] = best_seed; + group_train_accuracy[group] = best_accuracy.max(0.0); + } + + models.insert( + (*head, *config), + AddressLshGroupModel { + groups: selected_groups.to_vec(), + bits, + group_majority, + group_maps, + group_seeds, + group_train_accuracy, + }, + ); + } + + Ok(models) +} + +pub(super) fn fit_address_supervised_group_models( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + epochs: usize, + lr: f32, + l2: f32, +) -> Result, Box> { + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + let mut samples: HashMap<(HeadId, PqConfig), Vec<(Vec, Vec)>> = HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "supervised-fit", + true, + 0, + 0, + false, + None, + |head, config, _pos, codes, _token_ids, _stratum, _, input_row, _, _, _| { + let input_row = + input_row.ok_or("missing layer-input row during supervised address fit")?; + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + samples + .entry((head, config)) + .or_default() + .push((input_row.to_vec(), codes.to_vec())); + Ok(()) + }, + )?; + + let mut models = HashMap::new(); + for ((head, config), _) in codebooks { + let train_samples = samples.get(&(*head, *config)).cloned().unwrap_or_default(); + let dim = train_samples.first().map(|(row, _)| row.len()).unwrap_or(0); + let mut group_majority = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let majority = majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0); + group_majority.push(majority); + } + + let mut group_hyperplanes = vec![Vec::new(); config.groups]; + let mut group_train_accuracy = vec![0.0; config.groups]; + for &group in selected_groups { + let mut bit_planes = Vec::with_capacity(config.bits_per_group); + for bit in 0..config.bits_per_group { + let labels = train_samples + .iter() + .map(|(_, codes)| ((codes[group] >> bit) & 1) != 0) + .collect::>(); + let rows = train_samples + .iter() + .map(|(row, _)| row.as_slice()) + .collect::>(); + bit_planes.push(train_binary_hyperplane(&rows, &labels, dim, epochs, lr, l2)); + } + + let mut correct = 0usize; + for (row, codes) in &train_samples { + let predicted = predict_code_from_hyperplanes(row, &bit_planes); + if predicted == codes[group] { + correct += 1; + } + } + group_train_accuracy[group] = if train_samples.is_empty() { + 0.0 + } else { + correct as f64 / train_samples.len() as f64 + }; + group_hyperplanes[group] = bit_planes; + } + + models.insert( + (*head, *config), + AddressSupervisedGroupModel { + groups: selected_groups.to_vec(), + bits_per_group: config.bits_per_group, + epochs, + lr, + l2, + group_majority, + group_hyperplanes, + group_train_accuracy, + }, + ); + } + + Ok(models) +} + +pub(super) fn fit_majority_codes_for_codebooks( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, +) -> Result>, Box> { + let mut majority_counts: HashMap<(HeadId, PqConfig, usize), Vec> = HashMap::new(); + + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "majority-fit", + false, + 0, + 0, + false, + None, + |head, config, _pos, codes, _token_ids, _stratum, _, _, _, _, _| { + for (group, &code) in codes.iter().enumerate() { + let levels = 1usize << config.bits_per_group; + let counts = majority_counts + .entry((head, config, group)) + .or_insert_with(|| vec![0; levels]); + counts[code] += 1; + } + Ok(()) + }, + )?; + + let mut out = HashMap::new(); + for ((head, config), _) in codebooks { + let mut group_majority = Vec::with_capacity(config.groups); + for group in 0..config.groups { + group_majority.push( + majority_counts + .get(&(*head, *config, group)) + .map(|counts| argmax_usize(counts)) + .unwrap_or(0), + ); + } + out.insert((*head, *config), group_majority); + } + Ok(out) +} + +pub(super) fn collect_code_occurrences( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + selected_codes: &[usize], +) -> Result, Box> { + let mut records = Vec::new(); + visit_code_samples( + weights, + index, + tokenizer, + prompts, + heads, + bases, + means, + pca_bases, + codebooks, + "code-occurrence", + false, + 0, + 0, + true, + None, + |head, config, pos, codes, token_ids, stratum, prompt_id, _, _, _, attention_weights| { + for &group in selected_groups { + let code = codes[group]; + if !selected_codes.is_empty() && !selected_codes.contains(&code) { + continue; + } + let token_id = token_ids.get(pos).copied().unwrap_or(0); + let prev_token_id = pos + .checked_sub(1) + .and_then(|prev| token_ids.get(prev).copied()); + let attn_argmax = attention_weights.map(|weights| attention_argmax(weights, pos)); + let attn_argmax_token_id = + attn_argmax.and_then(|source| token_ids.get(source).copied()); + records.push(CodeOccurrenceRecord { + prompt_id: prompt_id.to_string(), + stratum: stratum.to_string(), + layer: head.layer, + head: head.head, + config, + group, + code, + position: pos, + token_id, + token_text: decode_token(tokenizer, token_id), + prev_token_id, + prev_token_text: prev_token_id.map(|id| decode_token(tokenizer, id)), + attn_argmax_position: attn_argmax, + attn_argmax_token_id, + attn_argmax_token_text: attn_argmax_token_id + .map(|id| decode_token(tokenizer, id)), + attn_entropy_bits: attention_weights + .map(|weights| attention_entropy_bits(weights, pos)), + attn_relation_class_key: attention_weights.map(|weights| { + attention_relation_key( + "attn_relation_class", + token_ids, + stratum, + pos, + weights, + ) + }), + }); + } + Ok(()) + }, + )?; + Ok(records) +} + +fn decode_token(tokenizer: &tokenizers::Tokenizer, token_id: u32) -> String { + tokenizer + .decode(&[token_id], true) + .unwrap_or_else(|_| format!("<{token_id}>")) +} + +fn visit_code_samples( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + label_prefix: &str, + with_layer_input: bool, + prev_ffn_feature_top_k: usize, + ffn_first_feature_top_k: usize, + with_attention_relation: bool, + reduced_qk_rank: Option, + mut visit: F, +) -> Result<(), Box> +where + F: FnMut( + HeadId, + PqConfig, + usize, + &[usize], + &[u32], + &str, + &str, + Option<&[f32]>, + Option<&[usize]>, + Option<&[usize]>, + Option<&[f32]>, + ) -> SampleVisitResult, +{ + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + let max_target_layer = heads.iter().map(|head| head.layer).max().unwrap_or(0); + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!( + " {} [{}/{}] {}", + label_prefix, + prompt_idx + 1, + prompts.len(), + label + ); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + let mut prev_ffn_features_by_pos = vec![Vec::::new(); token_ids.len()]; + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let layer_input = if with_layer_input { + Some(h.clone()) + } else { + None + }; + let ffn_first_features_by_pos = if ffn_first_feature_top_k > 0 { + let ffn = WeightFfn { weights }; + let (_, activation) = run_ffn(weights, &h, layer, &ffn, true); + activation + .map(|activation| { + activation + .rows() + .into_iter() + .map(|row| { + top_feature_ids_from_activation_row( + row, + ffn_first_feature_top_k, + ) + }) + .collect::>() + }) + .unwrap_or_else(|| vec![Vec::::new(); token_ids.len()]) + } else { + vec![Vec::::new(); token_ids.len()] + }; + let capture = if with_attention_relation { + if let Some(qk_rank) = reduced_qk_rank { + let (_, pre_o, all_weights) = + run_attention_block_with_pre_o_and_reduced_qk_attention_weights( + weights, &h, layer, None, qk_rank, + ) + .ok_or_else(|| { + format!( + "pre-W_O/reduced-QK attention capture failed at layer {layer}" + ) + })?; + (pre_o, Some(all_weights)) + } else { + let (_, pre_o, all_weights) = + run_attention_block_with_pre_o_and_all_attention_weights( + weights, &h, layer, None, + ) + .ok_or_else(|| { + format!("pre-W_O/all-attention capture failed at layer {layer}") + })?; + (pre_o, Some(all_weights)) + } + } else { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + (pre_o, None) + }; + let (pre_o, all_weights) = capture; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let basis = bases.get(head).ok_or_else(|| { + format!("missing basis for L{}H{}", head.layer, head.head) + })?; + let head_means = means.get(head).ok_or_else(|| { + format!("missing means for L{}H{}", head.layer, head.head) + })?; + let pca_basis = pca_bases.get(head).ok_or_else(|| { + format!("missing PCA basis for L{}H{}", head.layer, head.head) + })?; + let start = head.head * head_dim; + let end = start + head_dim; + let head_codebooks = codebooks + .iter() + .filter(|((codebook_head, _), _)| codebook_head == head) + .collect::>(); + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during address fit")?; + let base = head_means.positions.get(pos).unwrap_or(&head_means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + let input_row = layer_input.as_ref().map(|input| input.row(pos).to_vec()); + let prev_features = prev_ffn_features_by_pos.get(pos).map(Vec::as_slice); + let ffn_first_features = + ffn_first_features_by_pos.get(pos).map(Vec::as_slice); + let attention_row = all_weights + .as_ref() + .and_then(|weights| weights.heads.get(head.head)) + .and_then(|head_weights| head_weights.get(pos)) + .map(Vec::as_slice); + for ((_, config), codebook) in &head_codebooks { + let coords = pca_basis.coordinates_with_rank(&z, config.k); + let codes = codebook.quantize_indices_for_stratum(&coords, stratum); + visit( + *head, + *config, + pos, + &codes, + &token_ids, + stratum, + label, + input_row.as_deref(), + prev_features, + ffn_first_features, + attention_row, + )?; + } + } + } + } + + if layer == max_target_layer { + remove_layer_tensors(weights, inserted); + break; + } + + { + let ffn = WeightFfn { weights }; + if let Some((h_new, activation, _)) = run_layer_with_ffn( + weights, + &h, + layer, + &ffn, + prev_ffn_feature_top_k > 0, + ple_inputs.get(layer), + None, + ) { + if let Some(activation) = activation { + prev_ffn_features_by_pos = activation + .rows() + .into_iter() + .map(|row| { + top_feature_ids_from_activation_row(row, prev_ffn_feature_top_k) + }) + .collect(); + } + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + } + + Ok(()) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_eval.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_eval.rs new file mode 100644 index 00000000..25714a73 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_eval.rs @@ -0,0 +1,49 @@ +use larql_vindex::VectorIndex; + +use super::address::address_match_report; +use super::metrics::{argmax, kl_logp, log_softmax, top_k_indices}; +use super::oracle_pq_forward::{final_logits, forward_q4k_predicted_address_mode_d_head}; +use super::pq::ModeDTable; +use super::reports::AddressProbePromptReport; +use super::types::HeadId; + +pub(super) fn evaluate_predicted_address( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + mode_d_table: &ModeDTable, + predicted_codes_by_position: &[Vec], + stratum: &str, + label: &str, + baseline_logp: &[f64], + baseline_top1: u32, + oracle_codes_by_position: &[Vec], +) -> Result> { + let address_match = address_match_report(oracle_codes_by_position, predicted_codes_by_position); + let predicted_hidden = forward_q4k_predicted_address_mode_d_head( + weights, + token_ids, + index, + head, + mode_d_table, + predicted_codes_by_position, + stratum, + )?; + let predicted_logits = final_logits(weights, &predicted_hidden); + let predicted_logp = log_softmax(&predicted_logits); + let predicted_top1 = argmax(&predicted_logits); + let predicted_top5 = top_k_indices(&predicted_logits, 5); + + Ok(AddressProbePromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl: kl_logp(baseline_logp, &predicted_logp), + positions: oracle_codes_by_position.len(), + groups_correct: address_match.groups_correct, + groups_total: address_match.groups_total, + exact_address_match: address_match.exact_address_match, + top1_agree: baseline_top1 == predicted_top1, + baseline_top1_in_predicted_top5: predicted_top5.contains(&baseline_top1), + }) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_fit.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_fit.rs new file mode 100644 index 00000000..a0fc4a96 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_fit.rs @@ -0,0 +1,162 @@ +use std::collections::HashMap; + +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::VectorIndex; +use ndarray::s; + +use super::basis::{WoRoundtripBasis, ZPcaBasis}; +use super::pq::{kmeans_centroids, PqCodebook}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PqConfig, PromptRecord}; + +pub(super) fn fit_pq_codebooks( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + configs: &[PqConfig], + iterations: usize, + stratum_conditioned_groups: &[usize], +) -> Result, Box> { + let max_k = configs.iter().map(|c| c.k).max().unwrap_or(0); + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + + let mut samples: HashMap>> = HashMap::new(); + let mut samples_by_stratum: HashMap<(HeadId, String), Vec>> = HashMap::new(); + for head in heads { + samples.insert(*head, Vec::new()); + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" pq-fit [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let basis = bases.get(head).expect("basis pre-created for PQ fit"); + let head_means = means.get(head).expect("means pre-created for PQ fit"); + let pca_basis = pca_bases.get(head).expect("PCA pre-created for PQ fit"); + if pca_basis.rank() < max_k { + return Err(format!( + "PCA rank {} is below requested K {} for L{}H{}", + pca_basis.rank(), + max_k, + head.layer, + head.head + ) + .into()); + } + let start = head.head * head_dim; + let end = start + head_dim; + let head_samples = samples.get_mut(head).expect("PQ samples missing"); + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during PQ fit")?; + let base = head_means.positions.get(pos).unwrap_or(&head_means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + let coords = pca_basis.coordinates_with_rank(&z, max_k); + head_samples.push(coords.clone()); + if !stratum_conditioned_groups.is_empty() { + samples_by_stratum + .entry((*head, stratum.to_string())) + .or_default() + .push(coords); + } + } + } + } + + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + } + + let mut codebooks = HashMap::new(); + for head in heads { + let head_samples = samples + .get(head) + .ok_or_else(|| format!("missing PQ samples for L{}H{}", head.layer, head.head))?; + for &config in configs { + let levels = 1usize << config.bits_per_group; + let group_dim = config.k / config.groups; + let mut centroids = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let start = group * group_dim; + let group_samples = head_samples + .iter() + .map(|sample| sample[start..start + group_dim].to_vec()) + .collect::>(); + centroids.push(kmeans_centroids(&group_samples, levels, iterations)); + } + let mut stratum_centroids: HashMap>>> = + HashMap::new(); + for &group in stratum_conditioned_groups { + let start = group * group_dim; + for ((sample_head, stratum), stratum_samples) in samples_by_stratum.iter() { + if sample_head != head { + continue; + } + let group_samples = stratum_samples + .iter() + .map(|sample| sample[start..start + group_dim].to_vec()) + .collect::>(); + stratum_centroids + .entry(stratum.clone()) + .or_default() + .insert(group, kmeans_centroids(&group_samples, levels, iterations)); + } + } + codebooks.insert( + (*head, config), + PqCodebook { + config, + centroids, + stratum_centroids, + }, + ); + } + } + + Ok(codebooks) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_forward.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_forward.rs new file mode 100644 index 00000000..a7e24670 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_forward.rs @@ -0,0 +1,471 @@ +use std::collections::HashMap; + +use larql_inference::attention::{ + run_attention_block_with_pre_o_and_all_attention_weights, + run_attention_block_with_pre_o_and_reduced_qk_attention_weights, SharedKV, +}; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_ffn, run_layer_with_ffn}; +use larql_inference::{hidden_to_raw_logits, WeightFfn}; +use larql_vindex::VectorIndex; +use ndarray::{s, Array2}; + +use super::address::top_feature_ids_from_activation_row; +use super::basis::{RoundtripPatchMetrics, WoRoundtripBasis, ZPcaBasis}; +use super::pq::{ModeDTable, PqCodebook}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::StaticHeadMeans; +use super::types::HeadId; + +pub(super) fn forward_q4k_oracle_pq_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + codebook: &PqCodebook, + stratum: &str, +) -> Result<(Array2, RoundtripPatchMetrics, Vec>), Box> { + let mut metrics = None; + let mut oracle_codes = Vec::new(); + + let h = larql_inference::vindex::predict_q4k_hidden_with_mapped_pre_o_head( + weights, + token_ids, + index, + head.layer, + head.head, + |original_head| { + let mut replacement = Vec::with_capacity(original_head.len()); + let mut pre_sq = 0.0; + let mut visible_sq = 0.0; + let mut count = 0usize; + for pos in 0..original_head.nrows() { + let row = original_head.row(pos); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during PQ")?; + let base = means.positions.get(pos).unwrap_or(&means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + let coords = pca_basis.coordinates_with_rank(&z, codebook.config.k); + let codes = codebook.quantize_indices_for_stratum(&coords, stratum); + let quantized_coords = codebook.quantize_from_indices_for_stratum(&codes, stratum); + oracle_codes.push(codes); + let z_projected = pca_basis.reconstruct_from_coordinates(&quantized_coords); + let residual_projected = basis.z_to_residual(&z_projected); + let projected = residual_projected + .into_iter() + .zip(base.iter()) + .map(|(ri, &bi)| ri + bi) + .collect::>(); + for (&original, &recon) in values.iter().zip(projected.iter()) { + let delta = original as f64 - recon as f64; + pre_sq += delta * delta; + } + let delta = values + .iter() + .zip(projected.iter()) + .map(|(&original, &recon)| original as f64 - recon as f64) + .collect::>(); + visible_sq += basis.visible_sq_norm(&delta); + count += 1; + replacement.extend_from_slice(&projected); + } + metrics = Some(RoundtripPatchMetrics { + pre_wo_l2: (pre_sq / count.max(1) as f64).sqrt(), + wo_visible_l2: (visible_sq / count.max(1) as f64).sqrt(), + }); + Array2::from_shape_vec((original_head.nrows(), original_head.ncols()), replacement) + .map_err(|err| err.to_string()) + }, + )?; + + Ok(( + h, + metrics.ok_or("oracle PQ did not visit target layer")?, + oracle_codes, + )) +} + +pub(super) fn forward_q4k_oracle_pq_mode_d_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + codebook: &PqCodebook, + mode_d_table: &ModeDTable, + stratum: &str, +) -> Result, Box> { + let hidden_size = weights.hidden_size; + larql_inference::vindex::predict_q4k_hidden_with_mapped_head_residual_delta( + weights, + token_ids, + index, + head.layer, + head.head, + |original_head| { + let mut replacement_delta = Vec::with_capacity(original_head.nrows() * hidden_size); + for pos in 0..original_head.nrows() { + let row = original_head.row(pos); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during Mode D PQ")?; + let base = means.positions.get(pos).unwrap_or(&means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + let coords = pca_basis.coordinates_with_rank(&z, codebook.config.k); + let codes = codebook.quantize_indices_for_stratum(&coords, stratum); + let delta = + mode_d_table.delta_for_position_codes_with_stratum(pos, &codes, stratum); + replacement_delta.extend_from_slice(&delta); + } + Array2::from_shape_vec((original_head.nrows(), hidden_size), replacement_delta) + .map_err(|err| err.to_string()) + }, + ) + .map_err(Into::into) +} + +pub(super) fn forward_q4k_predicted_address_mode_d_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + mode_d_table: &ModeDTable, + predicted_codes_by_position: &[Vec], + stratum: &str, +) -> Result, Box> { + let mut replacement_delta = Vec::with_capacity(token_ids.len() * weights.hidden_size); + for pos in 0..token_ids.len() { + let codes = predicted_codes_by_position + .get(pos) + .ok_or("missing predicted address for sequence position")?; + let delta = mode_d_table.delta_for_position_codes_with_stratum(pos, codes, stratum); + replacement_delta.extend_from_slice(&delta); + } + let replacement_delta = + Array2::from_shape_vec((token_ids.len(), weights.hidden_size), replacement_delta)?; + larql_inference::vindex::predict_q4k_hidden_with_replaced_head_residual_delta( + weights, + token_ids, + index, + head.layer, + head.head, + &replacement_delta, + ) + .map_err(Into::into) +} + +pub(super) fn capture_layer_input_hidden( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + target_layer: usize, +) -> Result, Box> { + let mut h = embed_tokens_pub(weights, token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, token_ids); + let mut kv_cache: HashMap = HashMap::new(); + + for layer in 0..target_layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + let step = { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + let ffn = WeightFfn { weights }; + run_layer_with_ffn( + weights, + &h, + layer, + &ffn, + false, + ple_inputs.get(layer), + shared_kv, + ) + .map(|(h_new, _, kv_out)| (h_new, kv_out)) + }; + if let Some((h_new, kv_out)) = step { + h = h_new; + if let Some(kv) = kv_out { + kv_cache.insert(layer, kv); + } + } else { + remove_layer_tensors(weights, inserted); + return Err(format!("layer {layer} returned no output").into()); + } + remove_layer_tensors(weights, inserted); + } + + Ok(h) +} + +pub(super) fn capture_prev_ffn_feature_keys( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + target_layer: usize, + feature_top_k: usize, +) -> Result>, Box> { + let mut prev_features_by_pos = vec![Vec::::new(); token_ids.len()]; + if target_layer == 0 || feature_top_k == 0 { + return Ok(prev_features_by_pos); + } + + let mut h = embed_tokens_pub(weights, token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, token_ids); + let mut kv_cache: HashMap = HashMap::new(); + + for layer in 0..target_layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + let step = { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + let ffn = WeightFfn { weights }; + run_layer_with_ffn( + weights, + &h, + layer, + &ffn, + layer + 1 == target_layer, + ple_inputs.get(layer), + shared_kv, + ) + .map(|(h_new, activation, kv_out)| (h_new, activation, kv_out)) + }; + if let Some((h_new, activation, kv_out)) = step { + if let Some(activation) = activation { + prev_features_by_pos = activation + .rows() + .into_iter() + .map(|row| top_feature_ids_from_activation_row(row, feature_top_k)) + .collect(); + } + h = h_new; + if let Some(kv) = kv_out { + kv_cache.insert(layer, kv); + } + } else { + remove_layer_tensors(weights, inserted); + return Err(format!("layer {layer} returned no output").into()); + } + remove_layer_tensors(weights, inserted); + } + + Ok(prev_features_by_pos) +} + +pub(super) fn capture_ffn_first_feature_keys( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + target_layer: usize, + feature_top_k: usize, +) -> Result>, Box> { + let mut h = embed_tokens_pub(weights, token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, token_ids); + let mut kv_cache: HashMap = HashMap::new(); + + for layer in 0..=target_layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if layer == target_layer { + let ffn = WeightFfn { weights }; + let (_, activation) = run_ffn(weights, &h, layer, &ffn, feature_top_k > 0); + remove_layer_tensors(weights, inserted); + if let Some(activation) = activation { + return Ok(activation + .rows() + .into_iter() + .map(|row| top_feature_ids_from_activation_row(row, feature_top_k)) + .collect()); + } + return Ok(vec![Vec::::new(); token_ids.len()]); + } + + let step = { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + let ffn = WeightFfn { weights }; + run_layer_with_ffn( + weights, + &h, + layer, + &ffn, + false, + ple_inputs.get(layer), + shared_kv, + ) + .map(|(h_new, _, kv_out)| (h_new, kv_out)) + }; + if let Some((h_new, kv_out)) = step { + h = h_new; + if let Some(kv) = kv_out { + kv_cache.insert(layer, kv); + } + } else { + remove_layer_tensors(weights, inserted); + return Err(format!("layer {layer} returned no output").into()); + } + remove_layer_tensors(weights, inserted); + } + + Err(format!("target layer {target_layer} was not reached").into()) +} + +pub(super) fn capture_attention_relation_rows( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, +) -> Result>, Box> { + let mut h = embed_tokens_pub(weights, token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, token_ids); + let mut kv_cache: HashMap = HashMap::new(); + + for layer in 0..=head.layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if layer == head.layer { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + let (_, _, all_weights) = run_attention_block_with_pre_o_and_all_attention_weights( + weights, &h, layer, shared_kv, + ) + .ok_or_else(|| { + format!( + "all-position attention capture failed at L{}H{}", + head.layer, head.head + ) + })?; + remove_layer_tensors(weights, inserted); + return all_weights.heads.get(head.head).cloned().ok_or_else(|| { + format!("attention capture missing L{}H{}", head.layer, head.head).into() + }); + } + + let step = { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + let ffn = WeightFfn { weights }; + run_layer_with_ffn( + weights, + &h, + layer, + &ffn, + false, + ple_inputs.get(layer), + shared_kv, + ) + .map(|(h_new, _, kv_out)| (h_new, kv_out)) + }; + if let Some((h_new, kv_out)) = step { + h = h_new; + if let Some(kv) = kv_out { + kv_cache.insert(layer, kv); + } + } else { + remove_layer_tensors(weights, inserted); + return Err(format!("layer {layer} returned no output").into()); + } + remove_layer_tensors(weights, inserted); + } + + Err(format!("target layer {} was not reached", head.layer).into()) +} + +pub(super) fn capture_reduced_qk_attention_rows( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + qk_rank: usize, +) -> Result>, Box> { + let mut h = embed_tokens_pub(weights, token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, token_ids); + let mut kv_cache: HashMap = HashMap::new(); + + for layer in 0..=head.layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if layer == head.layer { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + let (_, _, all_weights) = + run_attention_block_with_pre_o_and_reduced_qk_attention_weights( + weights, &h, layer, shared_kv, qk_rank, + ) + .ok_or_else(|| { + format!( + "reduced-QK attention capture failed at L{}H{} rank {}", + head.layer, head.head, qk_rank + ) + })?; + remove_layer_tensors(weights, inserted); + return all_weights.heads.get(head.head).cloned().ok_or_else(|| { + format!( + "reduced-QK attention capture missing L{}H{}", + head.layer, head.head + ) + .into() + }); + } + + let step = { + let shared_kv = weights + .arch + .kv_shared_source_layer(layer) + .and_then(|src| kv_cache.get(&src)); + let ffn = WeightFfn { weights }; + run_layer_with_ffn( + weights, + &h, + layer, + &ffn, + false, + ple_inputs.get(layer), + shared_kv, + ) + .map(|(h_new, _, kv_out)| (h_new, kv_out)) + }; + if let Some((h_new, kv_out)) = step { + h = h_new; + if let Some(kv) = kv_out { + kv_cache.insert(layer, kv); + } + } else { + remove_layer_tensors(weights, inserted); + return Err(format!("layer {layer} returned no output").into()); + } + remove_layer_tensors(weights, inserted); + } + + Err(format!("target layer {} was not reached", head.layer).into()) +} + +pub(super) fn final_logits(weights: &larql_inference::ModelWeights, h: &Array2) -> Vec { + let last = h.nrows().saturating_sub(1); + let h_last = h.slice(s![last..last + 1, ..]).to_owned(); + hidden_to_raw_logits(weights, &h_last) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_mode_d.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_mode_d.rs new file mode 100644 index 00000000..d0516ee8 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_mode_d.rs @@ -0,0 +1,131 @@ +use std::collections::HashMap; + +use larql_vindex::VectorIndex; +use ndarray::s; + +use super::basis::{WoRoundtripBasis, ZPcaBasis}; +use super::pq::{ModeDTable, PqCodebook}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PqConfig}; + +pub(super) fn corruption_keep_values(groups: usize) -> Vec { + [0usize, 4, 8, 12, 16, 24, 32, 40, groups] + .into_iter() + .filter(|value| *value <= groups) + .collect() +} + +pub(super) fn materialize_mode_d_tables( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + stratum_conditioned_groups: &[usize], +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + + let mut tables = HashMap::new(); + for (layer, layer_heads) in heads_by_layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + let w_o = weights + .tensors + .get(&weights.arch.attn_o_key(layer)) + .ok_or_else(|| format!("missing W_O tensor at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let start = head.head * head_dim; + let end = start + head_dim; + let w_o_head = w_o.slice(s![.., start..end]); + let head_means = means + .get(&head) + .ok_or_else(|| format!("missing means for L{}H{}", head.layer, head.head))?; + let static_global_delta = project_head_vector_to_hidden(&w_o_head, &head_means.global); + let static_delta_by_position = head_means + .positions + .iter() + .map(|mean| project_head_vector_to_hidden(&w_o_head, mean)) + .collect::>(); + let basis = bases + .get(&head) + .ok_or_else(|| format!("missing W_O basis for L{}H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(&head) + .ok_or_else(|| format!("missing PCA basis for L{}H{}", head.layer, head.head))?; + + for ((codebook_head, config), codebook) in codebooks.iter() { + if *codebook_head != head { + continue; + } + let group_dim = config.k / config.groups; + let mut group_tables = Vec::with_capacity(config.groups); + for group in 0..config.groups { + let mut table = Vec::with_capacity(codebook.centroids[group].len()); + for centroid in &codebook.centroids[group] { + let mut coords = vec![0.0; config.k]; + let start_coord = group * group_dim; + coords[start_coord..start_coord + group_dim].copy_from_slice(centroid); + let z_part = pca_basis.reconstruct_from_coordinates(&coords); + let residual_part = basis.z_to_residual(&z_part); + table.push(project_head_vector_to_hidden(&w_o_head, &residual_part)); + } + group_tables.push(table); + } + let mut stratum_group_tables: HashMap>>> = + HashMap::new(); + for (stratum, groups) in &codebook.stratum_centroids { + for &group in stratum_conditioned_groups { + let Some(centroids) = groups.get(&group) else { + continue; + }; + let mut table = Vec::with_capacity(centroids.len()); + for centroid in centroids { + let mut coords = vec![0.0; config.k]; + let start_coord = group * group_dim; + coords[start_coord..start_coord + group_dim].copy_from_slice(centroid); + let z_part = pca_basis.reconstruct_from_coordinates(&coords); + let residual_part = basis.z_to_residual(&z_part); + table.push(project_head_vector_to_hidden(&w_o_head, &residual_part)); + } + stratum_group_tables + .entry(stratum.clone()) + .or_default() + .insert(group, table); + } + } + tables.insert( + (head, *config), + ModeDTable { + static_delta_by_position: static_delta_by_position.clone(), + static_global_delta: static_global_delta.clone(), + group_tables, + stratum_group_tables, + }, + ); + } + } + remove_layer_tensors(weights, inserted); + } + Ok(tables) +} + +fn project_head_vector_to_hidden( + w_o_head: &ndarray::ArrayBase, ndarray::Ix2>, + values: &[f32], +) -> Vec { + let mut out = vec![0.0f32; w_o_head.nrows()]; + for row in 0..w_o_head.nrows() { + let mut sum = 0.0f32; + for col in 0..w_o_head.ncols() { + sum += values[col] * w_o_head[[row, col]]; + } + out[row] = sum; + } + out +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_reports.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_reports.rs new file mode 100644 index 00000000..8266a1a6 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_reports.rs @@ -0,0 +1,371 @@ +use std::collections::{BTreeMap, HashMap}; + +use super::metrics::{bool_rate, mean, percentile}; +use super::reports::{ + AddressCorruptionReport, AddressGroupImportanceReport, AddressProbePromptReport, + AddressProbeReport, AddressProbeStratumReport, CodeStabilityReport, OraclePqPointReport, + OraclePqPromptReport, +}; +use super::types::PqConfig; + +#[derive(Debug)] +pub(super) struct OraclePqPointAccumulator { + prompts: Vec, + address_probe_accumulators: HashMap, + address_corruption_accumulators: HashMap, + address_group_importance_accumulators: HashMap, +} + +impl OraclePqPointAccumulator { + pub(super) fn new() -> Self { + Self { + prompts: Vec::new(), + address_probe_accumulators: HashMap::new(), + address_corruption_accumulators: HashMap::new(), + address_group_importance_accumulators: HashMap::new(), + } + } + + pub(super) fn add(&mut self, prompt: OraclePqPromptReport) { + self.prompts.push(prompt); + } + + pub(super) fn add_address_probe( + &mut self, + name: &str, + selected_group_keys: &[String], + prompt: AddressProbePromptReport, + ) { + self.address_probe_accumulators + .entry(name.to_string()) + .or_insert_with(|| AddressProbeAccumulator::new_with_keys(name, selected_group_keys)) + .add(prompt); + } + + pub(super) fn add_address_corruption( + &mut self, + oracle_groups_kept: usize, + prompt: AddressProbePromptReport, + ) { + self.address_corruption_accumulators + .entry(oracle_groups_kept) + .or_insert_with(|| { + AddressProbeAccumulator::new(&format!("oracle_groups_kept_{oracle_groups_kept}")) + }) + .add(prompt); + } + + pub(super) fn add_address_group_importance( + &mut self, + replaced_group: usize, + prompt: AddressProbePromptReport, + ) { + self.address_group_importance_accumulators + .entry(replaced_group) + .or_insert_with(|| { + AddressProbeAccumulator::new(&format!("replaced_group_{replaced_group}")) + }) + .add(prompt); + } + + pub(super) fn finish( + self, + config: PqConfig, + hidden_dim: usize, + code_stability: Vec, + ) -> OraclePqPointReport { + let kls: Vec = self.prompts.iter().map(|p| p.kl).collect(); + let levels = 1usize << config.bits_per_group; + let mode_d_kls = self + .prompts + .iter() + .filter_map(|p| p.mode_d_kl) + .collect::>(); + let coeff_mode_d_diffs = self + .prompts + .iter() + .filter_map(|p| p.coeff_mode_d_max_abs_logit_diff) + .collect::>(); + OraclePqPointReport { + k: config.k, + groups: config.groups, + bits_per_group: config.bits_per_group, + oracle_address_bits: config.groups * config.bits_per_group, + coefficient_codebook_bytes_f32: config.groups + * levels + * (config.k / config.groups) + * std::mem::size_of::(), + mode_d_residual_table_bytes_bf16: config.groups * levels * hidden_dim * 2, + prompts: self.prompts.len(), + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + mean_delta_cross_entropy_bits: mean( + &self + .prompts + .iter() + .map(|p| p.delta_cross_entropy_bits) + .collect::>(), + ), + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts.iter().map(|p| p.baseline_top1_in_pq_top5), + ), + mean_baseline_top1_prob: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_prob) + .collect::>(), + ), + mean_pq_prob_of_baseline_top1: mean( + &self + .prompts + .iter() + .map(|p| p.pq_prob_of_baseline_top1) + .collect::>(), + ), + mean_baseline_top1_margin: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_margin) + .collect::>(), + ), + mode_d_mean_kl: if mode_d_kls.is_empty() { + None + } else { + Some(mean(&mode_d_kls)) + }, + mode_d_p95_kl: if mode_d_kls.is_empty() { + None + } else { + Some(percentile(mode_d_kls.clone(), 0.95)) + }, + mode_d_max_kl: if mode_d_kls.is_empty() { + None + } else { + Some(mode_d_kls.iter().copied().fold(0.0, f64::max)) + }, + mode_d_top1_agreement: if mode_d_kls.is_empty() { + None + } else { + Some(bool_rate( + self.prompts.iter().filter_map(|p| p.mode_d_top1_agree), + )) + }, + mode_d_top5_contains_baseline_top1: if mode_d_kls.is_empty() { + None + } else { + Some(bool_rate( + self.prompts + .iter() + .filter_map(|p| p.baseline_top1_in_mode_d_top5), + )) + }, + coeff_mode_d_max_abs_logit_diff: if coeff_mode_d_diffs.is_empty() { + None + } else { + Some(coeff_mode_d_diffs.iter().copied().fold(0.0, f64::max)) + }, + address_probes: self + .address_probe_accumulators + .into_values() + .map(|acc| acc.finish()) + .collect(), + address_corruption_sweep: self + .address_corruption_accumulators + .into_iter() + .map(|(oracle_groups_kept, acc)| acc.finish_corruption(oracle_groups_kept)) + .collect(), + address_group_importance: self + .address_group_importance_accumulators + .into_iter() + .map(|(replaced_group, acc)| acc.finish_group_importance(replaced_group)) + .collect(), + code_stability, + mean_pre_wo_l2: mean(&self.prompts.iter().map(|p| p.pre_wo_l2).collect::>()), + mean_wo_visible_l2: mean( + &self + .prompts + .iter() + .map(|p| p.wo_visible_l2) + .collect::>(), + ), + per_prompt: self.prompts, + } + } +} + +fn address_probe_by_stratum( + prompts: &[AddressProbePromptReport], +) -> Vec { + let mut by_stratum: BTreeMap> = BTreeMap::new(); + for prompt in prompts { + by_stratum + .entry(prompt.stratum.clone()) + .or_default() + .push(prompt); + } + + by_stratum + .into_iter() + .map(|(stratum, prompts)| { + let kls = prompts.iter().map(|prompt| prompt.kl).collect::>(); + let positions = prompts.iter().map(|prompt| prompt.positions).sum::(); + let groups_total = prompts + .iter() + .map(|prompt| prompt.groups_total) + .sum::() + .max(1); + let groups_correct = prompts + .iter() + .map(|prompt| prompt.groups_correct) + .sum::(); + AddressProbeStratumReport { + stratum, + prompts: prompts.len(), + positions, + group_accuracy: groups_correct as f64 / groups_total as f64, + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + top1_agreement: bool_rate(prompts.iter().map(|prompt| prompt.top1_agree)), + top5_contains_baseline_top1: bool_rate( + prompts + .iter() + .map(|prompt| prompt.baseline_top1_in_predicted_top5), + ), + } + }) + .collect() +} + +#[derive(Debug)] +struct AddressProbeAccumulator { + name: String, + selected_group_keys: Vec, + prompts: Vec, +} + +impl AddressProbeAccumulator { + fn new(name: &str) -> Self { + Self::new_with_keys(name, &[]) + } + + fn new_with_keys(name: &str, selected_group_keys: &[String]) -> Self { + Self { + name: name.to_string(), + selected_group_keys: selected_group_keys.to_vec(), + prompts: Vec::new(), + } + } + + fn add(&mut self, prompt: AddressProbePromptReport) { + self.prompts.push(prompt); + } + + fn finish(mut self) -> AddressProbeReport { + let kls = self.prompts.iter().map(|p| p.kl).collect::>(); + let positions = self.prompts.iter().map(|p| p.positions).sum::(); + let total_groups = self + .prompts + .iter() + .map(|p| p.groups_total) + .sum::() + .max(1); + let correct_groups = self.prompts.iter().map(|p| p.groups_correct).sum::(); + self.prompts + .sort_by(|a, b| b.kl.partial_cmp(&a.kl).unwrap_or(std::cmp::Ordering::Equal)); + AddressProbeReport { + name: self.name, + selected_group_keys: self.selected_group_keys, + prompts: self.prompts.len(), + positions, + group_accuracy: correct_groups as f64 / total_groups as f64, + exact_address_accuracy: bool_rate(self.prompts.iter().map(|p| p.exact_address_match)), + mean_groups_correct_per_sequence: mean( + &self + .prompts + .iter() + .map(|p| p.groups_correct as f64) + .collect::>(), + ), + mean_groups_correct_per_position: correct_groups as f64 / positions.max(1) as f64, + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts + .iter() + .map(|p| p.baseline_top1_in_predicted_top5), + ), + by_stratum: address_probe_by_stratum(&self.prompts), + worst_examples: self.prompts.into_iter().take(8).collect(), + } + } + + fn finish_corruption(mut self, oracle_groups_kept: usize) -> AddressCorruptionReport { + let kls = self.prompts.iter().map(|p| p.kl).collect::>(); + let positions = self.prompts.iter().map(|p| p.positions).sum::(); + let total_groups = self + .prompts + .iter() + .map(|p| p.groups_total) + .sum::() + .max(1); + let correct_groups = self.prompts.iter().map(|p| p.groups_correct).sum::(); + self.prompts + .sort_by(|a, b| b.kl.partial_cmp(&a.kl).unwrap_or(std::cmp::Ordering::Equal)); + AddressCorruptionReport { + label: self.name, + oracle_groups_kept, + prompts: self.prompts.len(), + positions, + group_accuracy: correct_groups as f64 / total_groups as f64, + exact_address_accuracy: bool_rate(self.prompts.iter().map(|p| p.exact_address_match)), + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts + .iter() + .map(|p| p.baseline_top1_in_predicted_top5), + ), + worst_examples: self.prompts.into_iter().take(8).collect(), + } + } + + fn finish_group_importance(mut self, replaced_group: usize) -> AddressGroupImportanceReport { + let kls = self.prompts.iter().map(|p| p.kl).collect::>(); + let positions = self.prompts.iter().map(|p| p.positions).sum::(); + let total_groups = self + .prompts + .iter() + .map(|p| p.groups_total) + .sum::() + .max(1); + let correct_groups = self.prompts.iter().map(|p| p.groups_correct).sum::(); + self.prompts + .sort_by(|a, b| b.kl.partial_cmp(&a.kl).unwrap_or(std::cmp::Ordering::Equal)); + AddressGroupImportanceReport { + replaced_group, + prompts: self.prompts.len(), + positions, + group_accuracy: correct_groups as f64 / total_groups as f64, + exact_address_accuracy: bool_rate(self.prompts.iter().map(|p| p.exact_address_match)), + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts + .iter() + .map(|p| p.baseline_top1_in_predicted_top5), + ), + worst_examples: self.prompts.into_iter().take(8).collect(), + } + } +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_stability.rs b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_stability.rs new file mode 100644 index 00000000..ab53299d --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/oracle_pq_stability.rs @@ -0,0 +1,277 @@ +use std::collections::HashMap; + +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::VectorIndex; +use ndarray::s; + +use super::basis::{WoRoundtripBasis, ZPcaBasis}; +use super::metrics::{argmax_usize, code_mass, entropy_bits, js_divergence_bits}; +use super::pq::PqCodebook; +use super::reports::{CodeStabilityReport, CodeStabilityStratumReport}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PqConfig, PromptRecord}; + +#[derive(Debug, Clone)] +struct CodeDistributionCounts { + group_counts: HashMap>, + stratum_group_counts: HashMap>>, +} + +impl CodeDistributionCounts { + fn new(selected_groups: &[usize], levels: usize) -> Self { + Self { + group_counts: selected_groups + .iter() + .map(|&group| (group, vec![0; levels])) + .collect(), + stratum_group_counts: HashMap::new(), + } + } + + fn add(&mut self, group: usize, code: usize, stratum: &str, levels: usize) { + if let Some(counts) = self.group_counts.get_mut(&group) { + counts[code] += 1; + } + self.stratum_group_counts + .entry(stratum.to_string()) + .or_default() + .entry(group) + .or_insert_with(|| vec![0; levels])[code] += 1; + } +} + +pub(super) fn measure_code_stability( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + train_prompts: &[PromptRecord], + eval_prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], +) -> Result>, Box> { + let train = collect_code_distribution_counts( + weights, + index, + tokenizer, + train_prompts, + heads, + bases, + means, + pca_bases, + codebooks, + selected_groups, + "code-stability-train", + )?; + let eval = collect_code_distribution_counts( + weights, + index, + tokenizer, + eval_prompts, + heads, + bases, + means, + pca_bases, + codebooks, + selected_groups, + "code-stability-eval", + )?; + + let mut reports = HashMap::new(); + for ((head, config), _) in codebooks { + let levels = 1usize << config.bits_per_group; + let empty_counts = CodeDistributionCounts::new(selected_groups, levels); + let train_counts = train.get(&(*head, *config)).unwrap_or(&empty_counts); + let eval_counts = eval.get(&(*head, *config)).unwrap_or(&empty_counts); + let mut group_reports = Vec::new(); + for &group in selected_groups { + let train_group = train_counts + .group_counts + .get(&group) + .cloned() + .unwrap_or_else(|| vec![0; levels]); + let eval_group = eval_counts + .group_counts + .get(&group) + .cloned() + .unwrap_or_else(|| vec![0; levels]); + let train_top = argmax_usize(&train_group); + let eval_top = argmax_usize(&eval_group); + let mut stratum_names = train_counts + .stratum_group_counts + .keys() + .chain(eval_counts.stratum_group_counts.keys()) + .cloned() + .collect::>(); + stratum_names.sort(); + stratum_names.dedup(); + let by_stratum = stratum_names + .into_iter() + .map(|stratum| { + let train_s = train_counts + .stratum_group_counts + .get(&stratum) + .and_then(|groups| groups.get(&group)) + .cloned() + .unwrap_or_else(|| vec![0; levels]); + let eval_s = eval_counts + .stratum_group_counts + .get(&stratum) + .and_then(|groups| groups.get(&group)) + .cloned() + .unwrap_or_else(|| vec![0; levels]); + let train_s_top = argmax_usize(&train_s); + let eval_s_top = argmax_usize(&eval_s); + CodeStabilityStratumReport { + stratum, + train_positions: train_s.iter().sum(), + eval_positions: eval_s.iter().sum(), + train_entropy_bits: entropy_bits(&train_s), + eval_entropy_bits: entropy_bits(&eval_s), + train_top_code: train_s_top, + train_top_code_mass: code_mass(&train_s, train_s_top), + eval_top_code: eval_s_top, + eval_top_code_mass: code_mass(&eval_s, eval_s_top), + train_eval_js_bits: js_divergence_bits(&train_s, &eval_s), + } + }) + .collect(); + group_reports.push(CodeStabilityReport { + group, + train_positions: train_group.iter().sum(), + eval_positions: eval_group.iter().sum(), + train_entropy_bits: entropy_bits(&train_group), + eval_entropy_bits: entropy_bits(&eval_group), + train_top_code: train_top, + train_top_code_mass: code_mass(&train_group, train_top), + eval_top_code: eval_top, + eval_top_code_mass: code_mass(&eval_group, eval_top), + train_eval_js_bits: js_divergence_bits(&train_group, &eval_group), + by_stratum, + }); + } + reports.insert((*head, *config), group_reports); + } + + Ok(reports) +} + +fn collect_code_distribution_counts( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + selected_groups: &[usize], + label_prefix: &str, +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + let mut counts = HashMap::new(); + for ((head, config), _) in codebooks { + counts.insert( + (*head, *config), + CodeDistributionCounts::new(selected_groups, 1usize << config.bits_per_group), + ); + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!( + " {label_prefix} [{}/{}] {}", + prompt_idx + 1, + prompts.len(), + label + ); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let basis = bases.get(head).ok_or_else(|| { + format!("missing basis for L{}H{}", head.layer, head.head) + })?; + let head_means = means.get(head).ok_or_else(|| { + format!("missing means for L{}H{}", head.layer, head.head) + })?; + let pca_basis = pca_bases.get(head).ok_or_else(|| { + format!("missing PCA basis for L{}H{}", head.layer, head.head) + })?; + let start = head.head * head_dim; + let end = start + head_dim; + let head_codebooks = codebooks + .iter() + .filter(|((codebook_head, _), _)| codebook_head == head) + .collect::>(); + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during code stability")?; + let base = head_means.positions.get(pos).unwrap_or(&head_means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&yi, &bi)| yi - bi) + .collect::>(); + let z = basis.residual_to_z(&residual); + for ((_, config), codebook) in &head_codebooks { + let coords = pca_basis.coordinates_with_rank(&z, config.k); + let codes = codebook.quantize_indices_for_stratum(&coords, stratum); + let levels = 1usize << config.bits_per_group; + let point_counts = + counts.get_mut(&(*head, *config)).ok_or_else(|| { + format!( + "missing code stability counts for L{}H{} {:?}", + head.layer, head.head, config + ) + })?; + for &group in selected_groups { + point_counts.add(group, codes[group], stratum, levels); + } + } + } + } + } + + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + } + + Ok(counts) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/pq.rs b/crates/larql-cli/src/commands/dev/ov_rd/pq.rs new file mode 100644 index 00000000..85685fd5 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/pq.rs @@ -0,0 +1,149 @@ +use std::collections::HashMap; + +use super::types::PqConfig; + +#[derive(Debug, Clone)] +pub(super) struct PqCodebook { + pub(super) config: PqConfig, + pub(super) centroids: Vec>>, + pub(super) stratum_centroids: HashMap>>>, +} + +impl PqCodebook { + pub(super) fn quantize_indices_for_stratum(&self, coords: &[f64], stratum: &str) -> Vec { + let group_dim = self.config.k / self.config.groups; + (0..self.config.groups) + .map(|group| { + let start = group * group_dim; + let end = start + group_dim; + nearest_centroid_index( + &coords[start..end], + self.centroids_for_group(stratum, group), + ) + }) + .collect() + } + + pub(super) fn quantize_from_indices_for_stratum( + &self, + indices: &[usize], + stratum: &str, + ) -> Vec { + let group_dim = self.config.k / self.config.groups; + let mut out = vec![0.0; self.config.k]; + for (group, &index) in indices.iter().take(self.config.groups).enumerate() { + let start = group * group_dim; + let end = start + group_dim; + let centroid = &self.centroids_for_group(stratum, group)[index]; + out[start..end].copy_from_slice(centroid); + } + out + } + + fn centroids_for_group(&self, stratum: &str, group: usize) -> &[Vec] { + self.stratum_centroids + .get(stratum) + .and_then(|groups| groups.get(&group)) + .unwrap_or(&self.centroids[group]) + } +} + +#[derive(Debug, Clone)] +pub(super) struct ModeDTable { + pub(super) static_delta_by_position: Vec>, + pub(super) static_global_delta: Vec, + pub(super) group_tables: Vec>>, + pub(super) stratum_group_tables: HashMap>>>, +} + +impl ModeDTable { + pub(super) fn delta_for_position_codes_with_stratum( + &self, + position: usize, + codes: &[usize], + stratum: &str, + ) -> Vec { + let mut out = self + .static_delta_by_position + .get(position) + .unwrap_or(&self.static_global_delta) + .clone(); + for (group, &code) in codes.iter().enumerate() { + let table = &self.table_for_group(stratum, group)[code]; + for (dst, &value) in out.iter_mut().zip(table.iter()) { + *dst += value; + } + } + out + } + + fn table_for_group(&self, stratum: &str, group: usize) -> &[Vec] { + self.stratum_group_tables + .get(stratum) + .and_then(|groups| groups.get(&group)) + .unwrap_or(&self.group_tables[group]) + } +} + +pub(super) fn kmeans_centroids(samples: &[Vec], k: usize, iterations: usize) -> Vec> { + if samples.is_empty() { + return vec![Vec::new(); k]; + } + let dim = samples[0].len(); + let mut centroids = (0..k) + .map(|idx| samples[(idx * samples.len()) / k].clone()) + .collect::>(); + let mut assignments = vec![0usize; samples.len()]; + for _ in 0..iterations { + let mut changed = false; + for (sample_idx, sample) in samples.iter().enumerate() { + let nearest = nearest_centroid_index(sample, ¢roids); + if assignments[sample_idx] != nearest { + assignments[sample_idx] = nearest; + changed = true; + } + } + let mut sums = vec![vec![0.0; dim]; k]; + let mut counts = vec![0usize; k]; + for (sample, &cluster) in samples.iter().zip(assignments.iter()) { + counts[cluster] += 1; + for (dst, &value) in sums[cluster].iter_mut().zip(sample.iter()) { + *dst += value; + } + } + for cluster in 0..k { + if counts[cluster] == 0 { + continue; + } + let inv = 1.0 / counts[cluster] as f64; + for value in &mut sums[cluster] { + *value *= inv; + } + centroids[cluster] = sums[cluster].clone(); + } + if !changed { + break; + } + } + centroids +} + +pub(super) fn nearest_centroid_index(sample: &[f64], centroids: &[Vec]) -> usize { + let mut best_idx = 0usize; + let mut best_dist = f64::INFINITY; + for (idx, centroid) in centroids.iter().enumerate() { + let dist = sample + .iter() + .zip(centroid.iter()) + .map(|(&a, &b)| { + let d = a - b; + d * d + }) + .sum::(); + if dist < best_dist { + best_dist = dist; + best_idx = idx; + } + } + best_idx +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/pq_exception.rs b/crates/larql-cli/src/commands/dev/ov_rd/pq_exception.rs new file mode 100644 index 00000000..06adb68b --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/pq_exception.rs @@ -0,0 +1,1245 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Instant; + +use clap::Args; +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, WeightFfn}; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use ndarray::{s, Array2}; + +use super::basis::{build_roundtrip_bases, fit_z_pca_bases, WoRoundtripBasis, ZPcaBasis}; +use super::input::{ + limit_prompts_per_stratum, load_prompts, parse_head_spec, parse_pq_configs, parse_usize_list, + split_prompt_records, +}; +use super::metrics::{ + argmax, bool_rate, kl_logp, log_softmax, mean, percentile, token_prob, top_k_indices, +}; +use super::oracle_pq_fit::fit_pq_codebooks; +use super::oracle_pq_forward::{final_logits, forward_q4k_oracle_pq_mode_d_head}; +use super::oracle_pq_mode_d::materialize_mode_d_tables; +use super::pq::{kmeans_centroids, nearest_centroid_index, ModeDTable, PqCodebook}; +use super::reports::{ + OraclePqExceptionHeadReport, OraclePqExceptionPointReport, OraclePqExceptionPromptReport, + OraclePqExceptionReport, +}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::static_replace::fit_static_means; +use super::stats::StaticHeadMeans; +use super::types::{HeadId, PqConfig, PromptRecord}; + +#[derive(Args)] +pub(super) struct OraclePqExceptionArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 20:6. + #[arg(long)] + heads: String, + + /// Base PQ config as K:groups:bits, e.g. 192:48:4. + #[arg(long)] + base_config: String, + + /// Comma-separated exception edit counts. + #[arg(long, default_value = "4,8,16,32")] + exception_edits: String, + + /// Comma-separated top-error fractions used to fit exception edits. + #[arg(long, default_value = "1.0,0.25,0.1")] + tail_fracs: String, + + /// Training-position selector for exception fitting: residual-error, prompt-kl, position-restore-kl, or position-restore-ce. + #[arg(long, default_value = "residual-error")] + tail_selector: String, + + /// Exception catalogue fitting method: kmeans or exemplar. + #[arg(long, default_value = "kmeans")] + exception_fit: String, + + /// Candidate positions per prompt/head for position-local restore selectors. + #[arg(long, default_value_t = 4)] + position_candidates_per_prompt: usize, + + /// Relative singular value cutoff for retained W_O-visible directions. + #[arg(long, default_value_t = 1e-6)] + sigma_rel_cutoff: f64, + + /// Lloyd iterations for the base PQ codebook. + #[arg(long, default_value_t = 25)] + pq_iters: usize, + + /// Lloyd iterations for exception residual catalogues. + #[arg(long, default_value_t = 25)] + exception_iters: usize, + + /// Limit prompts for bounded oracle runs. + #[arg(long)] + max_prompts: Option, + + /// Keep at most N prompts per stratum after loading. + #[arg(long)] + max_per_stratum: Option, + + /// Evaluate only prompts where prompt_index % eval_mod == eval_offset. + /// The remaining prompts are used to fit static means, PCA, PQ, and exceptions. + #[arg(long)] + eval_mod: Option, + + /// Held-out modulo offset used with --eval-mod. + #[arg(long, default_value_t = 0)] + eval_offset: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ExceptionKey { + head: HeadId, + edits: usize, + tail_frac_key: u64, +} + +#[derive(Debug, Clone)] +struct ExceptionCatalog { + edits: usize, + tail_frac: f64, + train_error_samples: usize, + train_error_samples_used: usize, + centroids: Vec>, +} + +#[derive(Debug, Clone)] +struct ErrorSample { + score: f64, + sq_norm: f64, + values: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TailSelector { + ResidualError, + PromptKl, + PositionRestoreKl, + PositionRestoreCe, +} + +impl TailSelector { + fn parse(value: &str) -> Result> { + match value { + "residual-error" => Ok(Self::ResidualError), + "prompt-kl" => Ok(Self::PromptKl), + "position-restore-kl" => Ok(Self::PositionRestoreKl), + "position-restore-ce" => Ok(Self::PositionRestoreCe), + other => Err(format!( + "invalid --tail-selector '{other}', expected residual-error, prompt-kl, position-restore-kl, or position-restore-ce" + ) + .into()), + } + } + + fn as_str(self) -> &'static str { + match self { + Self::ResidualError => "residual-error", + Self::PromptKl => "prompt-kl", + Self::PositionRestoreKl => "position-restore-kl", + Self::PositionRestoreCe => "position-restore-ce", + } + } + + fn is_position_restore(self) -> bool { + matches!(self, Self::PositionRestoreKl | Self::PositionRestoreCe) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ExceptionFit { + Kmeans, + Exemplar, +} + +impl ExceptionFit { + fn parse(value: &str) -> Result> { + match value { + "kmeans" => Ok(Self::Kmeans), + "exemplar" => Ok(Self::Exemplar), + other => Err( + format!("invalid --exception-fit '{other}', expected kmeans or exemplar").into(), + ), + } + } + + fn as_str(self) -> &'static str { + match self { + Self::Kmeans => "kmeans", + Self::Exemplar => "exemplar", + } + } +} + +pub(super) fn run_oracle_pq_exception( + args: OraclePqExceptionArgs, +) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd oracle-pq-exception currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = parse_head_spec(&args.heads)?; + if selected_heads.is_empty() { + return Err("no heads selected for oracle PQ exception".into()); + } + let mut base_configs = parse_pq_configs(&args.base_config)?; + if base_configs.len() != 1 { + return Err("--base-config must contain exactly one K:groups:bits config".into()); + } + let base_config = base_configs.remove(0); + let mut exception_edits = parse_usize_list(&args.exception_edits)?; + exception_edits.sort_unstable(); + exception_edits.dedup(); + if exception_edits.is_empty() || exception_edits.iter().any(|&edits| edits == 0) { + return Err("--exception-edits values must be greater than zero".into()); + } + let mut tail_fracs = parse_f64_list(&args.tail_fracs)?; + tail_fracs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + tail_fracs.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON); + if tail_fracs.is_empty() + || tail_fracs + .iter() + .any(|&frac| !(frac.is_finite() && frac > 0.0 && frac <= 1.0)) + { + return Err("--tail-fracs values must be finite and in (0, 1]".into()); + } + let tail_selector = TailSelector::parse(&args.tail_selector)?; + let exception_fit = ExceptionFit::parse(&args.exception_fit)?; + + let mut prompts = load_prompts(&args.prompts, args.max_prompts)?; + if let Some(max_per_stratum) = args.max_per_stratum { + prompts = limit_prompts_per_stratum(prompts, max_per_stratum); + } + let prompts_seen = prompts.len(); + let (fit_prompts, eval_prompts) = if let Some(eval_mod) = args.eval_mod { + split_prompt_records(&prompts, eval_mod, args.eval_offset)? + } else { + (prompts.clone(), prompts) + }; + eprintln!("Selected heads: {:?}", selected_heads); + eprintln!("Base PQ config: {:?}", base_config); + eprintln!("Exception edits: {:?}", exception_edits); + eprintln!("Tail fractions: {:?}", tail_fracs); + eprintln!("Tail selector: {}", tail_selector.as_str()); + eprintln!("Exception fit: {}", exception_fit.as_str()); + eprintln!( + "Position candidates per prompt: {}", + args.position_candidates_per_prompt + ); + eprintln!("Prompts: {}", prompts_seen); + + eprintln!("Fitting position-mean static bases"); + let means = fit_static_means( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + )?; + + eprintln!("Building W_O-visible bases"); + let bases = + build_roundtrip_bases(&mut weights, &index, &selected_heads, args.sigma_rel_cutoff)?; + + eprintln!("Fitting empirical z-space PCA bases"); + let pca_bases = fit_z_pca_bases( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + )?; + + eprintln!("Fitting base product quantizer"); + let base_codebooks = fit_pq_codebooks( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &[base_config], + args.pq_iters, + &[], + )?; + + eprintln!("Materializing base Mode D tables"); + let base_tables = materialize_mode_d_tables( + &mut weights, + &index, + &selected_heads, + &bases, + &means, + &pca_bases, + &base_codebooks, + &[], + )?; + let w_o_heads = copy_w_o_heads(&mut weights, &index, &selected_heads)?; + let prompt_scores = if tail_selector == TailSelector::PromptKl { + eprintln!("Measuring fit-prompt base-PQ KL for exception selection"); + measure_fit_prompt_base_pq_kl( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &base_codebooks, + &base_tables, + base_config, + )? + } else { + HashMap::new() + }; + let position_scores = if tail_selector.is_position_restore() { + eprintln!("Measuring position-local restore gains for exception selection"); + measure_fit_position_restore_gains( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &base_codebooks, + &base_tables, + &w_o_heads, + base_config, + tail_selector, + args.position_candidates_per_prompt, + )? + } else { + HashMap::new() + }; + + eprintln!("Fitting exception residual catalogues"); + let exception_catalogs = fit_exception_catalogs( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + &bases, + &means, + &pca_bases, + &base_codebooks, + &base_tables, + &w_o_heads, + base_config, + &exception_edits, + &tail_fracs, + tail_selector, + exception_fit, + &prompt_scores, + &position_scores, + args.exception_iters, + )?; + + let mut accumulators: HashMap = HashMap::new(); + for head in &selected_heads { + for &edits in &exception_edits { + for &tail_frac in &tail_fracs { + accumulators.insert( + ExceptionKey { + head: *head, + edits, + tail_frac_key: tail_frac_key(tail_frac), + }, + PqExceptionAccumulator::new(), + ); + } + } + } + + for (prompt_idx, record) in eval_prompts.iter().enumerate() { + let label = prompt_label(record); + eprintln!(" [{}/{}] {}", prompt_idx + 1, eval_prompts.len(), label); + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + let baseline_top1 = argmax(&baseline_logits); + let baseline_top2 = top_k_indices(&baseline_logits, 2); + let baseline_top2_token = baseline_top2.get(1).copied().unwrap_or(baseline_top1); + let baseline_top1_prob = token_prob(&baseline_logp, baseline_top1); + let baseline_top2_prob = token_prob(&baseline_logp, baseline_top2_token); + let baseline_top1_margin = baseline_top1_prob - baseline_top2_prob; + + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{}H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{}H{}", head.layer, head.head))?; + let head_means = means + .get(head) + .ok_or_else(|| format!("missing means for L{}H{}", head.layer, head.head))?; + let codebook = base_codebooks.get(&(*head, base_config)).ok_or_else(|| { + format!("missing base codebook for L{}H{}", head.layer, head.head) + })?; + let table = base_tables + .get(&(*head, base_config)) + .ok_or_else(|| format!("missing base table for L{}H{}", head.layer, head.head))?; + let w_o_head = w_o_heads + .get(head) + .ok_or_else(|| format!("missing W_O head for L{}H{}", head.layer, head.head))?; + for &edits in &exception_edits { + for &tail_frac in &tail_fracs { + let key = ExceptionKey { + head: *head, + edits, + tail_frac_key: tail_frac_key(tail_frac), + }; + let catalog = exception_catalogs.get(&key).ok_or_else(|| { + format!( + "missing exception catalog for L{}H{} edits={} tail={}", + head.layer, head.head, edits, tail_frac + ) + })?; + let exception_hidden = forward_q4k_oracle_pq_exception_head( + &mut weights, + &token_ids, + &index, + *head, + basis, + pca_basis, + head_means, + codebook, + table, + w_o_head, + catalog, + stratum, + )?; + let exception_logits = final_logits(&weights, &exception_hidden); + let exception_logp = log_softmax(&exception_logits); + let kl = kl_logp(&baseline_logp, &exception_logp); + let exception_top1 = argmax(&exception_logits); + let exception_top5 = top_k_indices(&exception_logits, 5); + let exception_top2 = top_k_indices(&exception_logits, 2); + let exception_top2_token = + exception_top2.get(1).copied().unwrap_or(exception_top1); + let exception_top1_prob = token_prob(&exception_logp, exception_top1); + let exception_top2_prob = token_prob(&exception_logp, exception_top2_token); + let exception_top1_margin = exception_top1_prob - exception_top2_prob; + let exception_prob_of_baseline_top1 = + token_prob(&exception_logp, baseline_top1); + accumulators + .get_mut(&key) + .expect("exception accumulator missing") + .add(OraclePqExceptionPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl, + delta_cross_entropy_bits: kl / std::f64::consts::LN_2, + baseline_top1, + exception_top1, + top1_agree: baseline_top1 == exception_top1, + baseline_top1_in_exception_top5: exception_top5 + .contains(&baseline_top1), + baseline_top1_prob, + baseline_top2: baseline_top2_token, + baseline_top2_prob, + baseline_top1_margin, + exception_top1_prob, + exception_prob_of_baseline_top1, + exception_top1_margin, + }); + } + } + } + } + + let mut head_reports = Vec::new(); + for head in &selected_heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{} H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{} H{}", head.layer, head.head))?; + let mut points = Vec::new(); + for &edits in &exception_edits { + for &tail_frac in &tail_fracs { + let key = ExceptionKey { + head: *head, + edits, + tail_frac_key: tail_frac_key(tail_frac), + }; + let acc = accumulators + .remove(&key) + .expect("exception accumulator missing at finish"); + let catalog = exception_catalogs + .get(&key) + .expect("exception catalog missing at finish"); + points.push(acc.finish(base_config, catalog, weights.hidden_size)); + } + } + let static_train_samples = means.get(head).map(|m| m.count).unwrap_or(0); + head_reports.push(OraclePqExceptionHeadReport { + layer: head.layer, + head: head.head, + head_dim: basis.head_dim, + rank_retained: basis.rank_retained(), + empirical_rank: pca_basis.rank(), + sigma_max: basis.sigma_max, + sigma_min_retained: basis.sigma_min_retained, + static_train_samples, + points, + }); + } + + let report = OraclePqExceptionReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen, + train_prompts_seen: fit_prompts.len(), + eval_prompts_seen: eval_prompts.len(), + max_per_stratum: args.max_per_stratum, + eval_mod: args.eval_mod, + eval_offset: args.eval_offset, + static_base: "position_mean".to_string(), + base_config, + exception_edits, + tail_fracs, + tail_selector: tail_selector.as_str().to_string(), + exception_fit: exception_fit.as_str().to_string(), + position_candidates_per_prompt: args.position_candidates_per_prompt, + sigma_rel_cutoff: args.sigma_rel_cutoff, + pq_iters: args.pq_iters, + exception_iters: args.exception_iters, + selected_heads, + heads: head_reports, + }; + + let out_path = args.out.join("oracle_pq_exception.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +fn fit_exception_catalogs( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + tables: &HashMap<(HeadId, PqConfig), ModeDTable>, + w_o_heads: &HashMap>>, + base_config: PqConfig, + exception_edits: &[usize], + tail_fracs: &[f64], + tail_selector: TailSelector, + exception_fit: ExceptionFit, + prompt_scores: &HashMap<(HeadId, usize), f64>, + position_scores: &HashMap<(HeadId, usize, usize), f64>, + iterations: usize, +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + let mut samples: HashMap> = HashMap::new(); + for head in heads { + samples.insert(*head, Vec::new()); + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = prompt_label(record); + eprintln!( + " exception-fit [{}/{}] {}", + prompt_idx + 1, + prompts.len(), + label + ); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let basis = bases.get(head).expect("basis pre-created"); + let pca_basis = pca_bases.get(head).expect("PCA pre-created"); + let head_means = means.get(head).expect("means pre-created"); + let codebook = codebooks + .get(&(*head, base_config)) + .expect("base codebook pre-created"); + let table = tables + .get(&(*head, base_config)) + .expect("base Mode D table pre-created"); + let w_o_head = w_o_heads.get(head).expect("W_O head pre-copied"); + let start = head.head * head_dim; + let end = start + head_dim; + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during exception fit")?; + let base_delta = base_pq_delta( + values, basis, pca_basis, head_means, codebook, table, pos, stratum, + ); + let true_delta = project_head_vector_to_hidden(w_o_head, values); + let error = true_delta + .iter() + .zip(base_delta.iter()) + .map(|(&true_value, &base_value)| true_value as f64 - base_value as f64) + .collect::>(); + let sq_norm = error.iter().map(|value| value * value).sum::(); + let score = match tail_selector { + TailSelector::ResidualError => sq_norm, + TailSelector::PromptKl => { + *prompt_scores.get(&(*head, prompt_idx)).unwrap_or(&0.0) + } + TailSelector::PositionRestoreKl => *position_scores + .get(&(*head, prompt_idx, pos)) + .unwrap_or(&0.0), + TailSelector::PositionRestoreCe => *position_scores + .get(&(*head, prompt_idx, pos)) + .unwrap_or(&0.0), + }; + samples + .get_mut(head) + .expect("exception samples missing") + .push(ErrorSample { + score, + sq_norm, + values: error, + }); + } + } + } + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + } + + let mut catalogs = HashMap::new(); + for head in heads { + let mut head_samples = samples.remove(head).ok_or_else(|| { + format!( + "missing exception samples for L{}H{}", + head.layer, head.head + ) + })?; + head_samples.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| { + b.sq_norm + .partial_cmp(&a.sq_norm) + .unwrap_or(std::cmp::Ordering::Equal) + }) + }); + let total = head_samples.len(); + for &tail_frac in tail_fracs { + let used = ((total as f64) * tail_frac).ceil() as usize; + let used = used.clamp(1, total.max(1)); + let selected = head_samples + .iter() + .take(used) + .map(|sample| sample.values.clone()) + .collect::>(); + for &edits in exception_edits { + let centroids = match exception_fit { + ExceptionFit::Kmeans => kmeans_centroids(&selected, edits, iterations), + ExceptionFit::Exemplar => exemplar_centroids(&selected, edits), + }; + catalogs.insert( + ExceptionKey { + head: *head, + edits, + tail_frac_key: tail_frac_key(tail_frac), + }, + ExceptionCatalog { + edits, + tail_frac, + train_error_samples: total, + train_error_samples_used: used, + centroids, + }, + ); + } + } + } + + Ok(catalogs) +} + +fn exemplar_centroids(selected: &[Vec], edits: usize) -> Vec> { + if edits == 0 { + return Vec::new(); + } + if selected.is_empty() { + return vec![Vec::new(); edits]; + } + (0..edits) + .map(|idx| selected[idx.min(selected.len() - 1)].clone()) + .collect() +} + +fn measure_fit_prompt_base_pq_kl( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + tables: &HashMap<(HeadId, PqConfig), ModeDTable>, + base_config: PqConfig, +) -> Result, Box> { + let mut scores = HashMap::new(); + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = prompt_label(record); + eprintln!( + " selector-fit [{}/{}] {}", + prompt_idx + 1, + prompts.len(), + label + ); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(weights, &token_ids, index, None); + let baseline_logits = final_logits(weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + for head in heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{}H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{}H{}", head.layer, head.head))?; + let head_means = means + .get(head) + .ok_or_else(|| format!("missing means for L{}H{}", head.layer, head.head))?; + let codebook = codebooks.get(&(*head, base_config)).ok_or_else(|| { + format!("missing base codebook for L{}H{}", head.layer, head.head) + })?; + let table = tables + .get(&(*head, base_config)) + .ok_or_else(|| format!("missing base table for L{}H{}", head.layer, head.head))?; + let pq_hidden = forward_q4k_oracle_pq_mode_d_head( + weights, &token_ids, index, *head, basis, pca_basis, head_means, codebook, table, + stratum, + )?; + let pq_logits = final_logits(weights, &pq_hidden); + let pq_logp = log_softmax(&pq_logits); + scores.insert((*head, prompt_idx), kl_logp(&baseline_logp, &pq_logp)); + } + } + Ok(scores) +} + +fn measure_fit_position_restore_gains( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], + bases: &HashMap, + means: &HashMap, + pca_bases: &HashMap, + codebooks: &HashMap<(HeadId, PqConfig), PqCodebook>, + tables: &HashMap<(HeadId, PqConfig), ModeDTable>, + w_o_heads: &HashMap>>, + base_config: PqConfig, + tail_selector: TailSelector, + candidates_per_prompt: usize, +) -> Result, Box> { + let mut scores = HashMap::new(); + if candidates_per_prompt == 0 { + return Ok(scores); + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = prompt_label(record); + eprintln!( + " position-restore-fit [{}/{}] {}", + prompt_idx + 1, + prompts.len(), + label + ); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(weights, &token_ids, index, None); + let baseline_logits = final_logits(weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + let baseline_top1 = argmax(&baseline_logits); + + for head in heads { + let basis = bases + .get(head) + .ok_or_else(|| format!("missing basis for L{}H{}", head.layer, head.head))?; + let pca_basis = pca_bases + .get(head) + .ok_or_else(|| format!("missing PCA basis for L{}H{}", head.layer, head.head))?; + let head_means = means + .get(head) + .ok_or_else(|| format!("missing means for L{}H{}", head.layer, head.head))?; + let codebook = codebooks.get(&(*head, base_config)).ok_or_else(|| { + format!("missing base codebook for L{}H{}", head.layer, head.head) + })?; + let table = tables + .get(&(*head, base_config)) + .ok_or_else(|| format!("missing base table for L{}H{}", head.layer, head.head))?; + let w_o_head = w_o_heads + .get(head) + .ok_or_else(|| format!("missing W_O head for L{}H{}", head.layer, head.head))?; + + let base_hidden = forward_q4k_oracle_pq_mode_d_head( + weights, &token_ids, index, *head, basis, pca_basis, head_means, codebook, table, + stratum, + )?; + let base_logits = final_logits(weights, &base_hidden); + let base_logp = log_softmax(&base_logits); + let base_kl = kl_logp(&baseline_logp, &base_logp); + let base_ce = -token_prob(&base_logp, baseline_top1).ln(); + + let mut candidates = capture_head_position_sq_errors( + weights, index, &token_ids, *head, basis, pca_basis, head_means, codebook, table, + w_o_head, stratum, + )?; + candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + candidates.truncate(candidates_per_prompt.min(candidates.len())); + + for (position, _sq_norm) in candidates { + let restored_hidden = forward_q4k_oracle_pq_position_restore_head( + weights, &token_ids, index, *head, basis, pca_basis, head_means, codebook, + table, w_o_head, position, stratum, + )?; + let restored_logits = final_logits(weights, &restored_hidden); + let restored_logp = log_softmax(&restored_logits); + let gain = match tail_selector { + TailSelector::PositionRestoreKl => { + let restored_kl = kl_logp(&baseline_logp, &restored_logp); + base_kl - restored_kl + } + TailSelector::PositionRestoreCe => { + let restored_ce = -token_prob(&restored_logp, baseline_top1).ln(); + base_ce - restored_ce + } + TailSelector::ResidualError | TailSelector::PromptKl => 0.0, + } + .max(0.0); + scores.insert((*head, prompt_idx, position), gain); + } + } + } + + Ok(scores) +} + +fn capture_head_position_sq_errors( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + token_ids: &[u32], + head: HeadId, + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + codebook: &PqCodebook, + table: &ModeDTable, + w_o_head: &[Vec], + stratum: &str, +) -> Result, Box> { + let mut h = embed_tokens_pub(weights, token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if layer == head.layer { + let result = (|| -> Result, Box> { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + let start = head.head * head_dim; + let end = start + head_dim; + let mut errors = Vec::with_capacity(pre_o.nrows()); + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during restore fit")?; + let base_delta = base_pq_delta( + values, basis, pca_basis, means, codebook, table, pos, stratum, + ); + let true_delta = project_head_vector_to_hidden(w_o_head, values); + let sq_norm = true_delta + .iter() + .zip(base_delta.iter()) + .map(|(&true_value, &base_value)| { + let delta = true_value as f64 - base_value as f64; + delta * delta + }) + .sum::(); + errors.push((pos, sq_norm)); + } + Ok(errors) + })(); + remove_layer_tensors(weights, inserted); + return result; + } + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + + Err(format!("target layer {} was not reached", head.layer).into()) +} + +fn forward_q4k_oracle_pq_exception_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + codebook: &PqCodebook, + table: &ModeDTable, + w_o_head: &[Vec], + catalog: &ExceptionCatalog, + stratum: &str, +) -> Result, Box> { + let hidden_size = weights.hidden_size; + larql_inference::vindex::predict_q4k_hidden_with_mapped_head_residual_delta( + weights, + token_ids, + index, + head.layer, + head.head, + |original_head| { + let mut replacement_delta = Vec::with_capacity(original_head.nrows() * hidden_size); + for pos in 0..original_head.nrows() { + let row = original_head.row(pos); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during exception eval")?; + let base_delta = base_pq_delta( + values, basis, pca_basis, means, codebook, table, pos, stratum, + ); + let true_delta = project_head_vector_to_hidden(w_o_head, values); + let error = true_delta + .iter() + .zip(base_delta.iter()) + .map(|(&true_value, &base_value)| true_value as f64 - base_value as f64) + .collect::>(); + let code = nearest_centroid_index(&error, &catalog.centroids); + let exception = &catalog.centroids[code]; + for (&base, &extra) in base_delta.iter().zip(exception.iter()) { + replacement_delta.push(base + extra as f32); + } + } + Array2::from_shape_vec((original_head.nrows(), hidden_size), replacement_delta) + .map_err(|err| err.to_string()) + }, + ) + .map_err(Into::into) +} + +fn forward_q4k_oracle_pq_position_restore_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + codebook: &PqCodebook, + table: &ModeDTable, + w_o_head: &[Vec], + restore_position: usize, + stratum: &str, +) -> Result, Box> { + let hidden_size = weights.hidden_size; + larql_inference::vindex::predict_q4k_hidden_with_mapped_head_residual_delta( + weights, + token_ids, + index, + head.layer, + head.head, + |original_head| { + let mut replacement_delta = Vec::with_capacity(original_head.nrows() * hidden_size); + for pos in 0..original_head.nrows() { + let row = original_head.row(pos); + let values = row + .as_slice() + .ok_or("pre-W_O head row was not contiguous during position restore")?; + if pos == restore_position { + let true_delta = project_head_vector_to_hidden(w_o_head, values); + replacement_delta.extend_from_slice(&true_delta); + } else { + let base_delta = base_pq_delta( + values, basis, pca_basis, means, codebook, table, pos, stratum, + ); + replacement_delta.extend_from_slice(&base_delta); + } + } + Array2::from_shape_vec((original_head.nrows(), hidden_size), replacement_delta) + .map_err(|err| err.to_string()) + }, + ) + .map_err(Into::into) +} + +fn base_pq_delta( + values: &[f32], + basis: &WoRoundtripBasis, + pca_basis: &ZPcaBasis, + means: &StaticHeadMeans, + codebook: &PqCodebook, + table: &ModeDTable, + position: usize, + stratum: &str, +) -> Vec { + let base = means.positions.get(position).unwrap_or(&means.global); + let residual = values + .iter() + .zip(base.iter()) + .map(|(&value, &mean)| value - mean) + .collect::>(); + let z = basis.residual_to_z(&residual); + let coords = pca_basis.coordinates_with_rank(&z, codebook.config.k); + let codes = codebook.quantize_indices_for_stratum(&coords, stratum); + table.delta_for_position_codes_with_stratum(position, &codes, stratum) +} + +fn copy_w_o_heads( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + heads: &[HeadId], +) -> Result>>, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + let mut out = HashMap::new(); + for (layer, layer_heads) in heads_by_layer { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + let w_o = weights + .tensors + .get(&weights.arch.attn_o_key(layer)) + .ok_or_else(|| format!("missing W_O tensor at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let start = head.head * head_dim; + let end = start + head_dim; + let w_o_head = w_o.slice(s![.., start..end]); + let rows = (0..w_o_head.nrows()) + .map(|row| { + (0..w_o_head.ncols()) + .map(|col| w_o_head[[row, col]]) + .collect::>() + }) + .collect::>(); + out.insert(head, rows); + } + remove_layer_tensors(weights, inserted); + } + Ok(out) +} + +fn project_head_vector_to_hidden(w_o_head: &[Vec], values: &[f32]) -> Vec { + let mut out = vec![0.0f32; w_o_head.len()]; + for (row_idx, row) in w_o_head.iter().enumerate() { + let mut sum = 0.0f32; + for (&value, &weight) in values.iter().zip(row.iter()) { + sum += value * weight; + } + out[row_idx] = sum; + } + out +} + +#[derive(Debug)] +struct PqExceptionAccumulator { + prompts: Vec, +} + +impl PqExceptionAccumulator { + fn new() -> Self { + Self { + prompts: Vec::new(), + } + } + + fn add(&mut self, prompt: OraclePqExceptionPromptReport) { + self.prompts.push(prompt); + } + + fn finish( + self, + base_config: PqConfig, + catalog: &ExceptionCatalog, + hidden_dim: usize, + ) -> OraclePqExceptionPointReport { + let kls = self.prompts.iter().map(|p| p.kl).collect::>(); + let levels = 1usize << base_config.bits_per_group; + let base_bytes = base_config.groups * levels * hidden_dim * 2; + let exception_bytes = catalog.edits * hidden_dim * 2; + let exception_bits = catalog.edits.next_power_of_two().trailing_zeros() as usize; + let base_bits = base_config.groups * base_config.bits_per_group; + OraclePqExceptionPointReport { + exception_edits: catalog.edits, + tail_frac: catalog.tail_frac, + train_error_samples: catalog.train_error_samples, + train_error_samples_used: catalog.train_error_samples_used, + base_address_bits: base_bits, + exception_address_bits: exception_bits, + total_address_bits: base_bits + exception_bits, + base_table_bytes_bf16: base_bytes, + exception_table_bytes_bf16: exception_bytes, + total_table_bytes_bf16: base_bytes + exception_bytes, + prompts: self.prompts.len(), + mean_kl: mean(&kls), + p95_kl: percentile(kls.clone(), 0.95), + max_kl: kls.iter().copied().fold(0.0, f64::max), + mean_delta_cross_entropy_bits: mean( + &self + .prompts + .iter() + .map(|p| p.delta_cross_entropy_bits) + .collect::>(), + ), + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts + .iter() + .map(|p| p.baseline_top1_in_exception_top5), + ), + mean_baseline_top1_prob: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_prob) + .collect::>(), + ), + mean_exception_prob_of_baseline_top1: mean( + &self + .prompts + .iter() + .map(|p| p.exception_prob_of_baseline_top1) + .collect::>(), + ), + mean_baseline_top1_margin: mean( + &self + .prompts + .iter() + .map(|p| p.baseline_top1_margin) + .collect::>(), + ), + per_prompt: self.prompts, + } + } +} + +fn parse_f64_list(spec: &str) -> Result, Box> { + let mut values = Vec::new(); + for part in spec.split(',') { + let part = part.trim(); + if part.is_empty() { + continue; + } + values.push(part.parse()?); + } + Ok(values) +} + +fn tail_frac_key(tail_frac: f64) -> u64 { + (tail_frac * 1_000_000.0).round() as u64 +} + +fn prompt_label(record: &PromptRecord) -> &str { + record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt") +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/reports.rs b/crates/larql-cli/src/commands/dev/ov_rd/reports.rs new file mode 100644 index 00000000..be499525 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/reports.rs @@ -0,0 +1,737 @@ +#![allow(dead_code)] + +use serde::{Deserialize, Serialize}; + +use super::types::{HeadId, PqConfig}; + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct FinishedHeadStats { + pub(super) count: u64, + pub(super) mean_norm_sq: f64, + pub(super) second_moment: f64, + pub(super) variance: f64, + pub(super) rms_norm: f64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct HeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) head_dim: usize, + pub(super) stats: FinishedHeadStats, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(super) wo_visible_stats: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct CaptureReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) layers: Vec, + pub(super) max_positions: Option, + #[serde(default)] + pub(super) wo_visible: bool, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct ZeroStratumReport { + pub(super) stratum: String, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) max_kl: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct ZeroPromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) kl: f64, + pub(super) delta_cross_entropy_bits: f64, + pub(super) baseline_top1: u32, + pub(super) ablated_top1: u32, + pub(super) top1_agree: bool, + pub(super) baseline_top1_in_ablated_top5: bool, +} + +#[derive(Debug, Serialize)] +pub(super) struct ZeroHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) ablation_kind: String, + pub(super) patch_location: String, + pub(super) preserved_components: Vec, + pub(super) bounded_vocab_size: Option, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) mean_delta_cross_entropy_bits: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) strata: Vec, + pub(super) worst_examples: Vec, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct ZeroAblationReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct StaticReplacementReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) train_prompts_seen: usize, + pub(super) eval_prompts_seen: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) eval_mod: Option, + pub(super) eval_offset: usize, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct StaticHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) train_samples: u64, + pub(super) modes: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct StaticModeReport { + pub(super) replacement_kind: String, + pub(super) patch_location: String, + pub(super) runtime_class: String, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) mean_delta_cross_entropy_bits: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) strata: Vec, + pub(super) worst_examples: Vec, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct SanityCheckReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct SanityHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) prompts: usize, + pub(super) noop_mean_kl: f64, + pub(super) noop_max_kl: f64, + pub(super) noop_max_abs_logit_diff: f64, + pub(super) residual_delta_noop_mean_kl: f64, + pub(super) residual_delta_noop_max_kl: f64, + pub(super) residual_delta_noop_max_abs_logit_diff: f64, + pub(super) zero_subtract_mean_kl: f64, + pub(super) zero_subtract_max_kl: f64, + pub(super) zero_subtract_max_abs_logit_diff: f64, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct SanityPromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) noop_kl: f64, + pub(super) noop_max_abs_logit_diff: f64, + pub(super) residual_delta_noop_kl: f64, + pub(super) residual_delta_noop_max_abs_logit_diff: f64, + pub(super) zero_subtract_kl: f64, + pub(super) zero_subtract_max_abs_logit_diff: f64, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleRoundtripReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) sigma_rel_cutoff: f64, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleRoundtripHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) head_dim: usize, + pub(super) rank_retained: usize, + pub(super) sigma_max: f64, + pub(super) sigma_min_retained: f64, + pub(super) sigma_rel_cutoff: f64, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) max_abs_logit_diff: f64, + pub(super) mean_pre_wo_l2: f64, + pub(super) max_pre_wo_l2: f64, + pub(super) mean_wo_visible_l2: f64, + pub(super) max_wo_visible_l2: f64, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct OracleRoundtripPromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) kl: f64, + pub(super) max_abs_logit_diff: f64, + pub(super) pre_wo_l2: f64, + pub(super) wo_visible_l2: f64, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleLowrankReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) static_base: String, + pub(super) ks: Vec, + pub(super) sigma_rel_cutoff: f64, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleLowrankHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) head_dim: usize, + pub(super) rank_retained: usize, + pub(super) empirical_rank: usize, + pub(super) sigma_max: f64, + pub(super) sigma_min_retained: f64, + pub(super) static_train_samples: u64, + pub(super) points: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleLowrankPointReport { + pub(super) k: usize, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) mean_delta_cross_entropy_bits: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) mean_baseline_top1_prob: f64, + pub(super) mean_lowrank_prob_of_baseline_top1: f64, + pub(super) mean_baseline_top1_margin: f64, + pub(super) mean_pre_wo_l2: f64, + pub(super) mean_wo_visible_l2: f64, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct OracleLowrankPromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) kl: f64, + pub(super) delta_cross_entropy_bits: f64, + pub(super) baseline_top1: u32, + pub(super) lowrank_top1: u32, + pub(super) top1_agree: bool, + pub(super) baseline_top1_in_lowrank_top5: bool, + pub(super) baseline_top1_prob: f64, + pub(super) baseline_top2: u32, + pub(super) baseline_top2_prob: f64, + pub(super) baseline_top1_margin: f64, + pub(super) lowrank_top1_prob: f64, + pub(super) lowrank_prob_of_baseline_top1: f64, + pub(super) lowrank_top1_margin: f64, + pub(super) pre_wo_l2: f64, + pub(super) wo_visible_l2: f64, +} + +#[derive(Debug, Serialize)] +pub(super) struct OraclePqReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) train_prompts_seen: usize, + pub(super) eval_prompts_seen: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) max_per_stratum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) eval_mod: Option, + pub(super) eval_offset: usize, + pub(super) static_base: String, + pub(super) configs: Vec, + pub(super) sigma_rel_cutoff: f64, + pub(super) pq_iters: usize, + pub(super) mode_d_check: bool, + pub(super) address_probes: bool, + pub(super) address_mixed_key_probe: bool, + pub(super) address_key_group_probe: bool, + pub(super) address_key_groups: Vec, + pub(super) address_key_group_probe_names: Vec, + pub(super) address_majority_group_probe: bool, + pub(super) address_majority_groups: Vec, + pub(super) address_code_substitution_group_probe: bool, + pub(super) address_code_substitution_groups: Vec, + pub(super) address_code_substitution_from_codes: Vec, + pub(super) address_code_substitution_to_codes: Vec, + pub(super) address_code_class_collapse_group_probe: bool, + pub(super) address_code_class_collapse_groups: Vec, + pub(super) address_code_class_collapse_specs: Vec, + pub(super) address_code_position_interaction_probe: bool, + pub(super) address_code_position_prompt_id: String, + pub(super) address_code_position_group: usize, + pub(super) address_code_position_primary_codes: Vec, + pub(super) address_code_position_secondary_codes: Vec, + pub(super) address_code_position_target_code: usize, + pub(super) address_code_conditional_quotient_group_probe: bool, + pub(super) address_code_conditional_quotient_group: usize, + pub(super) address_code_conditional_quotient_primary_codes: Vec, + pub(super) address_code_conditional_quotient_secondary_codes: Vec, + pub(super) address_code_conditional_quotient_target_code: usize, + pub(super) address_code_conditional_quotient_early_position_max: usize, + pub(super) address_code_conditional_quotient_guards: Vec, + pub(super) address_code_conditional_quotient_extra_specs: Vec, + pub(super) address_code7_bos_rule_group_probe: bool, + pub(super) address_code7_bos_rule_groups: Vec, + pub(super) address_code7_bos_rule_code: usize, + pub(super) address_code7_oracle_binary_group_probe: bool, + pub(super) address_code7_oracle_binary_groups: Vec, + pub(super) address_code7_oracle_binary_code: usize, + pub(super) address_code7_oracle_binary_filters: Vec, + pub(super) address_corruption_sweep: bool, + pub(super) address_group_importance: bool, + pub(super) address_lsh_group_probe: bool, + pub(super) address_lsh_groups: Vec, + pub(super) address_lsh_bits: usize, + pub(super) address_lsh_seeds: usize, + pub(super) address_supervised_group_probe: bool, + pub(super) address_supervised_groups: Vec, + pub(super) address_supervised_epochs: usize, + pub(super) address_supervised_lr: f32, + pub(super) address_supervised_l2: f32, + pub(super) address_gamma_projected_group_probe: bool, + pub(super) address_gamma_projected_groups: Vec, + pub(super) address_gamma_projected_layers: Vec, + pub(super) address_gamma_random_ranks: Vec, + pub(super) address_gamma_random_seeds: Vec, + pub(super) address_gamma_learned_ranks: Vec, + pub(super) address_gamma_learned_epochs: usize, + pub(super) address_gamma_learned_lr: f32, + pub(super) address_gamma_learned_l2: f32, + pub(super) address_gamma_learned_pca_iters: usize, + pub(super) address_code_stability: bool, + pub(super) address_code_stability_groups: Vec, + pub(super) address_prev_ffn_feature_group_probe: bool, + pub(super) address_prev_ffn_feature_groups: Vec, + pub(super) address_prev_ffn_feature_top_k: usize, + pub(super) address_ffn_first_feature_group_probe: bool, + pub(super) address_ffn_first_feature_groups: Vec, + pub(super) address_ffn_first_feature_top_k: usize, + pub(super) address_attention_relation_group_probe: bool, + pub(super) address_attention_relation_groups: Vec, + pub(super) address_attention_cluster_group_probe: bool, + pub(super) address_attention_cluster_groups: Vec, + pub(super) address_attention_cluster_ks: Vec, + pub(super) address_attention_cluster_probe_names: Vec, + pub(super) address_reduced_qk_cluster_group_probe: bool, + pub(super) address_reduced_qk_cluster_groups: Vec, + pub(super) address_reduced_qk_ranks: Vec, + pub(super) address_reduced_qk_cluster_ks: Vec, + pub(super) address_reduced_qk_cluster_probe_names: Vec, + pub(super) stratum_conditioned_pq_groups: Vec, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct CodeOccurrenceRecord { + pub(super) prompt_id: String, + pub(super) stratum: String, + pub(super) layer: usize, + pub(super) head: usize, + pub(super) config: PqConfig, + pub(super) group: usize, + pub(super) code: usize, + pub(super) position: usize, + pub(super) token_id: u32, + pub(super) token_text: String, + pub(super) prev_token_id: Option, + pub(super) prev_token_text: Option, + pub(super) attn_argmax_position: Option, + pub(super) attn_argmax_token_id: Option, + pub(super) attn_argmax_token_text: Option, + pub(super) attn_entropy_bits: Option, + pub(super) attn_relation_class_key: Option, +} + +#[derive(Debug, Serialize)] +pub(super) struct OraclePqHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) head_dim: usize, + pub(super) rank_retained: usize, + pub(super) empirical_rank: usize, + pub(super) sigma_max: f64, + pub(super) sigma_min_retained: f64, + pub(super) static_train_samples: u64, + pub(super) points: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OraclePqPointReport { + pub(super) k: usize, + pub(super) groups: usize, + pub(super) bits_per_group: usize, + pub(super) oracle_address_bits: usize, + pub(super) coefficient_codebook_bytes_f32: usize, + pub(super) mode_d_residual_table_bytes_bf16: usize, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) mean_delta_cross_entropy_bits: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) mean_baseline_top1_prob: f64, + pub(super) mean_pq_prob_of_baseline_top1: f64, + pub(super) mean_baseline_top1_margin: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_mean_kl: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_p95_kl: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_max_kl: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_top1_agreement: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_top5_contains_baseline_top1: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) coeff_mode_d_max_abs_logit_diff: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(super) address_probes: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(super) address_corruption_sweep: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(super) address_group_importance: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(super) code_stability: Vec, + pub(super) mean_pre_wo_l2: f64, + pub(super) mean_wo_visible_l2: f64, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct CodeStabilityReport { + pub(super) group: usize, + pub(super) train_positions: usize, + pub(super) eval_positions: usize, + pub(super) train_entropy_bits: f64, + pub(super) eval_entropy_bits: f64, + pub(super) train_top_code: usize, + pub(super) train_top_code_mass: f64, + pub(super) eval_top_code: usize, + pub(super) eval_top_code_mass: f64, + pub(super) train_eval_js_bits: f64, + pub(super) by_stratum: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct CodeStabilityStratumReport { + pub(super) stratum: String, + pub(super) train_positions: usize, + pub(super) eval_positions: usize, + pub(super) train_entropy_bits: f64, + pub(super) eval_entropy_bits: f64, + pub(super) train_top_code: usize, + pub(super) train_top_code_mass: f64, + pub(super) eval_top_code: usize, + pub(super) eval_top_code_mass: f64, + pub(super) train_eval_js_bits: f64, +} + +#[derive(Debug, Serialize)] +pub(super) struct AddressProbeReport { + pub(super) name: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(super) selected_group_keys: Vec, + pub(super) prompts: usize, + pub(super) positions: usize, + pub(super) group_accuracy: f64, + pub(super) exact_address_accuracy: f64, + pub(super) mean_groups_correct_per_sequence: f64, + pub(super) mean_groups_correct_per_position: f64, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(super) by_stratum: Vec, + pub(super) worst_examples: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct AddressProbeStratumReport { + pub(super) stratum: String, + pub(super) prompts: usize, + pub(super) positions: usize, + pub(super) group_accuracy: f64, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct AddressProbePromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) kl: f64, + pub(super) positions: usize, + pub(super) groups_correct: usize, + pub(super) groups_total: usize, + pub(super) exact_address_match: bool, + pub(super) top1_agree: bool, + pub(super) baseline_top1_in_predicted_top5: bool, +} + +#[derive(Debug, Serialize)] +pub(super) struct AddressCorruptionReport { + pub(super) label: String, + pub(super) oracle_groups_kept: usize, + pub(super) prompts: usize, + pub(super) positions: usize, + pub(super) group_accuracy: f64, + pub(super) exact_address_accuracy: f64, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) worst_examples: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct AddressGroupImportanceReport { + pub(super) replaced_group: usize, + pub(super) prompts: usize, + pub(super) positions: usize, + pub(super) group_accuracy: f64, + pub(super) exact_address_accuracy: f64, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) worst_examples: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct OraclePqPromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) kl: f64, + pub(super) delta_cross_entropy_bits: f64, + pub(super) baseline_top1: u32, + pub(super) pq_top1: u32, + pub(super) top1_agree: bool, + pub(super) baseline_top1_in_pq_top5: bool, + pub(super) baseline_top1_prob: f64, + pub(super) baseline_top2: u32, + pub(super) baseline_top2_prob: f64, + pub(super) baseline_top1_margin: f64, + pub(super) pq_top1_prob: f64, + pub(super) pq_prob_of_baseline_top1: f64, + pub(super) pq_top1_margin: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_kl: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_top1: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) mode_d_top1_agree: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) baseline_top1_in_mode_d_top5: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) coeff_mode_d_max_abs_logit_diff: Option, + pub(super) pre_wo_l2: f64, + pub(super) wo_visible_l2: f64, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleEditCatalogReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) train_prompts_seen: usize, + pub(super) eval_prompts_seen: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) max_per_stratum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) eval_mod: Option, + pub(super) eval_offset: usize, + pub(super) static_base: String, + pub(super) spaces: Vec, + pub(super) edit_counts: Vec, + pub(super) pca_rank: usize, + pub(super) sigma_rel_cutoff: f64, + pub(super) kmeans_iters: usize, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleEditCatalogHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) head_dim: usize, + pub(super) rank_retained: usize, + pub(super) empirical_rank: usize, + pub(super) sigma_max: f64, + pub(super) sigma_min_retained: f64, + pub(super) static_train_samples: u64, + pub(super) points: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OracleEditCatalogPointReport { + pub(super) space: String, + pub(super) edits: usize, + pub(super) address_bits: usize, + pub(super) residual_table_bytes_bf16: usize, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) mean_delta_cross_entropy_bits: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) mean_baseline_top1_prob: f64, + pub(super) mean_catalog_prob_of_baseline_top1: f64, + pub(super) mean_baseline_top1_margin: f64, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct OracleEditCatalogPromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) kl: f64, + pub(super) delta_cross_entropy_bits: f64, + pub(super) baseline_top1: u32, + pub(super) catalog_top1: u32, + pub(super) top1_agree: bool, + pub(super) baseline_top1_in_catalog_top5: bool, + pub(super) baseline_top1_prob: f64, + pub(super) baseline_top2: u32, + pub(super) baseline_top2_prob: f64, + pub(super) baseline_top1_margin: f64, + pub(super) catalog_top1_prob: f64, + pub(super) catalog_prob_of_baseline_top1: f64, + pub(super) catalog_top1_margin: f64, +} + +#[derive(Debug, Serialize)] +pub(super) struct OraclePqExceptionReport { + pub(super) index: String, + pub(super) prompt_file: String, + pub(super) prompts_seen: usize, + pub(super) train_prompts_seen: usize, + pub(super) eval_prompts_seen: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) max_per_stratum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) eval_mod: Option, + pub(super) eval_offset: usize, + pub(super) static_base: String, + pub(super) base_config: PqConfig, + pub(super) exception_edits: Vec, + pub(super) tail_fracs: Vec, + pub(super) tail_selector: String, + pub(super) exception_fit: String, + pub(super) position_candidates_per_prompt: usize, + pub(super) sigma_rel_cutoff: f64, + pub(super) pq_iters: usize, + pub(super) exception_iters: usize, + pub(super) selected_heads: Vec, + pub(super) heads: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OraclePqExceptionHeadReport { + pub(super) layer: usize, + pub(super) head: usize, + pub(super) head_dim: usize, + pub(super) rank_retained: usize, + pub(super) empirical_rank: usize, + pub(super) sigma_max: f64, + pub(super) sigma_min_retained: f64, + pub(super) static_train_samples: u64, + pub(super) points: Vec, +} + +#[derive(Debug, Serialize)] +pub(super) struct OraclePqExceptionPointReport { + pub(super) exception_edits: usize, + pub(super) tail_frac: f64, + pub(super) train_error_samples: usize, + pub(super) train_error_samples_used: usize, + pub(super) base_address_bits: usize, + pub(super) exception_address_bits: usize, + pub(super) total_address_bits: usize, + pub(super) base_table_bytes_bf16: usize, + pub(super) exception_table_bytes_bf16: usize, + pub(super) total_table_bytes_bf16: usize, + pub(super) prompts: usize, + pub(super) mean_kl: f64, + pub(super) p95_kl: f64, + pub(super) max_kl: f64, + pub(super) mean_delta_cross_entropy_bits: f64, + pub(super) top1_agreement: f64, + pub(super) top5_contains_baseline_top1: f64, + pub(super) mean_baseline_top1_prob: f64, + pub(super) mean_exception_prob_of_baseline_top1: f64, + pub(super) mean_baseline_top1_margin: f64, + pub(super) per_prompt: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub(super) struct OraclePqExceptionPromptReport { + pub(super) id: String, + pub(super) stratum: String, + pub(super) kl: f64, + pub(super) delta_cross_entropy_bits: f64, + pub(super) baseline_top1: u32, + pub(super) exception_top1: u32, + pub(super) top1_agree: bool, + pub(super) baseline_top1_in_exception_top5: bool, + pub(super) baseline_top1_prob: f64, + pub(super) baseline_top2: u32, + pub(super) baseline_top2_prob: f64, + pub(super) baseline_top1_margin: f64, + pub(super) exception_top1_prob: f64, + pub(super) exception_prob_of_baseline_top1: f64, + pub(super) exception_top1_margin: f64, +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/runtime.rs b/crates/larql-cli/src/commands/dev/ov_rd/runtime.rs new file mode 100644 index 00000000..a9346368 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/runtime.rs @@ -0,0 +1,16 @@ +use larql_inference::ModelWeights; +use larql_vindex::VectorIndex; + +pub(super) fn insert_q4k_layer_tensors( + weights: &mut ModelWeights, + index: &VectorIndex, + layer: usize, +) -> Result, Box> { + larql_inference::vindex::insert_q4k_layer_tensors(weights, index, layer).map_err(|err| { + Box::::from(std::io::Error::new(std::io::ErrorKind::Other, err)) + }) +} + +pub(super) fn remove_layer_tensors(weights: &mut ModelWeights, keys: Vec) { + larql_inference::vindex::remove_layer_tensors(weights, keys); +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/sanity.rs b/crates/larql-cli/src/commands/dev/ov_rd/sanity.rs new file mode 100644 index 00000000..7ea5d891 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/sanity.rs @@ -0,0 +1,260 @@ +use std::path::PathBuf; +use std::time::Instant; + +use clap::Args; +use larql_inference::{encode_prompt, hidden_to_raw_logits}; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use ndarray::{s, Array2}; + +use super::input::{load_prompts, parse_head_spec}; +use super::metrics::{kl_logp, log_softmax, max_abs_diff, mean}; +use super::reports::{SanityCheckReport, SanityHeadReport, SanityPromptReport}; +use super::types::HeadId; +use super::zero_ablate::forward_q4k_zero_pre_o_head; + +#[derive(Args)] +pub(super) struct SanityCheckArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 0:4,0:6. + #[arg(long)] + heads: String, + + /// Limit prompts for bounded sanity runs. + #[arg(long)] + max_prompts: Option, +} + +#[derive(Debug)] +struct SanityHeadAccumulator { + prompts: Vec, +} + +impl SanityHeadAccumulator { + fn new() -> Self { + Self { + prompts: Vec::new(), + } + } + + fn add(&mut self, prompt: SanityPromptReport) { + self.prompts.push(prompt); + } + + fn finish(self, head: HeadId) -> SanityHeadReport { + let noop_kls: Vec = self.prompts.iter().map(|p| p.noop_kl).collect(); + let residual_delta_noop_kls: Vec = self + .prompts + .iter() + .map(|p| p.residual_delta_noop_kl) + .collect(); + let zero_subtract_kls: Vec = self.prompts.iter().map(|p| p.zero_subtract_kl).collect(); + SanityHeadReport { + layer: head.layer, + head: head.head, + prompts: self.prompts.len(), + noop_mean_kl: mean(&noop_kls), + noop_max_kl: noop_kls.iter().copied().fold(0.0, f64::max), + noop_max_abs_logit_diff: self + .prompts + .iter() + .map(|p| p.noop_max_abs_logit_diff) + .fold(0.0, f64::max), + residual_delta_noop_mean_kl: mean(&residual_delta_noop_kls), + residual_delta_noop_max_kl: residual_delta_noop_kls.iter().copied().fold(0.0, f64::max), + residual_delta_noop_max_abs_logit_diff: self + .prompts + .iter() + .map(|p| p.residual_delta_noop_max_abs_logit_diff) + .fold(0.0, f64::max), + zero_subtract_mean_kl: mean(&zero_subtract_kls), + zero_subtract_max_kl: zero_subtract_kls.iter().copied().fold(0.0, f64::max), + zero_subtract_max_abs_logit_diff: self + .prompts + .iter() + .map(|p| p.zero_subtract_max_abs_logit_diff) + .fold(0.0, f64::max), + per_prompt: self.prompts, + } + } +} + +pub(super) fn run_sanity_check(args: SanityCheckArgs) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd sanity-check currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = parse_head_spec(&args.heads)?; + if selected_heads.is_empty() { + return Err("no heads selected for sanity check".into()); + } + let prompts = load_prompts(&args.prompts, args.max_prompts)?; + eprintln!("Selected heads: {:?}", selected_heads); + eprintln!("Prompts: {}", prompts.len()); + + let mut accumulators: Vec = selected_heads + .iter() + .map(|_| SanityHeadAccumulator::new()) + .collect(); + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + + for (idx, head) in selected_heads.iter().copied().enumerate() { + let noop_hidden = + forward_q4k_noop_replace_pre_o_head(&mut weights, &token_ids, &index, head)?; + let noop_logits = final_logits(&weights, &noop_hidden); + let noop_logp = log_softmax(&noop_logits); + + let residual_delta_noop_hidden = forward_q4k_noop_replace_head_residual_delta( + &mut weights, + &token_ids, + &index, + head, + )?; + let residual_delta_noop_logits = final_logits(&weights, &residual_delta_noop_hidden); + let residual_delta_noop_logp = log_softmax(&residual_delta_noop_logits); + + let zero_hidden = forward_q4k_zero_pre_o_head(&mut weights, &token_ids, &index, head)?; + let zero_logits = final_logits(&weights, &zero_hidden); + let zero_logp = log_softmax(&zero_logits); + + let subtract_hidden = + forward_q4k_subtract_pre_o_head(&mut weights, &token_ids, &index, head)?; + let subtract_logits = final_logits(&weights, &subtract_hidden); + let subtract_logp = log_softmax(&subtract_logits); + + accumulators[idx].add(SanityPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + noop_kl: kl_logp(&baseline_logp, &noop_logp), + noop_max_abs_logit_diff: max_abs_diff(&baseline_logits, &noop_logits), + residual_delta_noop_kl: kl_logp(&baseline_logp, &residual_delta_noop_logp), + residual_delta_noop_max_abs_logit_diff: max_abs_diff( + &baseline_logits, + &residual_delta_noop_logits, + ), + zero_subtract_kl: kl_logp(&zero_logp, &subtract_logp), + zero_subtract_max_abs_logit_diff: max_abs_diff(&zero_logits, &subtract_logits), + }); + } + } + + let heads = selected_heads + .iter() + .copied() + .zip(accumulators) + .map(|(head, acc)| acc.finish(head)) + .collect(); + let report = SanityCheckReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen: prompts.len(), + selected_heads, + heads, + }; + + let out_path = args.out.join("sanity_check.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +fn forward_q4k_noop_replace_pre_o_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, +) -> Result, Box> { + larql_inference::vindex::predict_q4k_hidden_with_mapped_pre_o_head( + weights, + token_ids, + index, + head.layer, + head.head, + |original| Ok(original.clone()), + ) + .map_err(Into::into) +} + +fn forward_q4k_subtract_pre_o_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, +) -> Result, Box> { + larql_inference::vindex::predict_q4k_hidden_with_subtracted_pre_o_heads( + weights, + token_ids, + index, + head.layer, + &[head.head], + ) + .map_err(Into::into) +} + +fn forward_q4k_noop_replace_head_residual_delta( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, +) -> Result, Box> { + larql_inference::vindex::predict_q4k_hidden_with_original_head_residual_delta( + weights, token_ids, index, head.layer, head.head, + ) + .map_err(Into::into) +} + +fn final_logits(weights: &larql_inference::ModelWeights, h: &Array2) -> Vec { + let last = h.nrows().saturating_sub(1); + let h_last = h.slice(s![last..last + 1, ..]).to_owned(); + hidden_to_raw_logits(weights, &h_last) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/static_replace.rs b/crates/larql-cli/src/commands/dev/ov_rd/static_replace.rs new file mode 100644 index 00000000..7d48beec --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/static_replace.rs @@ -0,0 +1,447 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Instant; + +use clap::Args; +use larql_inference::attention::run_attention_block_with_pre_o; +use larql_inference::forward::ple::precompute_per_layer_inputs; +use larql_inference::forward::{embed_tokens_pub, run_layer_with_ffn}; +use larql_inference::{encode_prompt, hidden_to_raw_logits, WeightFfn}; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use ndarray::{s, Array2}; + +use super::input::{load_prompts, parse_head_spec, split_prompt_records}; +use super::metrics::{argmax, bool_rate, kl_logp, log_softmax, mean, percentile, top_k_indices}; +use super::reports::{ + StaticHeadReport, StaticModeReport, StaticReplacementReport, ZeroPromptReport, + ZeroStratumReport, +}; +use super::runtime::{insert_q4k_layer_tensors, remove_layer_tensors}; +use super::stats::{StaticHeadAccumulator, StaticHeadMeans}; +use super::types::{HeadId, PromptRecord}; + +#[derive(Args)] +pub(super) struct StaticReplaceArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 11:3,11:0,0:4. + #[arg(long)] + heads: String, + + /// Limit prompts for bounded gate runs. + #[arg(long)] + max_prompts: Option, + + /// Evaluate only prompts where prompt_index % eval_mod == eval_offset. + /// The remaining prompts are used to fit static means. Omit for in-sample + /// fit/eval on the same prompt set. + #[arg(long)] + eval_mod: Option, + + /// Held-out modulo offset used with --eval-mod. + #[arg(long, default_value_t = 0)] + eval_offset: usize, +} + +#[derive(Debug, Clone, Copy)] +enum StaticReplacementKind { + Zero, + Global, + Position, + Stratum, + PositionPlusStratum, + PositionStratum, +} + +impl StaticReplacementKind { + fn as_str(self) -> &'static str { + match self { + Self::Zero => "zero", + Self::Global => "global_mean", + Self::Position => "position_mean", + Self::Stratum => "stratum_mean", + Self::PositionPlusStratum => "position_plus_stratum_mean", + Self::PositionStratum => "position_stratum_mean", + } + } +} + +const STATIC_REPLACEMENT_KINDS: [StaticReplacementKind; 6] = [ + StaticReplacementKind::Zero, + StaticReplacementKind::Global, + StaticReplacementKind::Position, + StaticReplacementKind::Stratum, + StaticReplacementKind::PositionPlusStratum, + StaticReplacementKind::PositionStratum, +]; + +#[derive(Debug)] +struct StaticModeAccumulator { + prompts: Vec, + by_stratum: HashMap>, +} + +impl StaticModeAccumulator { + fn new() -> Self { + Self { + prompts: Vec::new(), + by_stratum: HashMap::new(), + } + } + + fn add(&mut self, prompt: ZeroPromptReport) { + let stratum = prompt.stratum.clone(); + self.prompts.push(prompt.clone()); + self.by_stratum.entry(stratum).or_default().push(prompt); + } + + fn finish(self, kind: StaticReplacementKind) -> StaticModeReport { + let kl_values: Vec = self.prompts.iter().map(|p| p.kl).collect(); + let mean_delta_cross_entropy_bits = mean( + &self + .prompts + .iter() + .map(|p| p.delta_cross_entropy_bits) + .collect::>(), + ); + let mut worst_examples = self.prompts.clone(); + worst_examples.sort_by(|a, b| b.kl.partial_cmp(&a.kl).unwrap_or(std::cmp::Ordering::Equal)); + worst_examples.truncate(10); + let mut strata: Vec<_> = self + .by_stratum + .into_iter() + .map(|(stratum, prompts)| { + let values: Vec = prompts.iter().map(|p| p.kl).collect(); + ZeroStratumReport { + stratum, + prompts: prompts.len(), + mean_kl: mean(&values), + max_kl: values.iter().copied().fold(0.0, f64::max), + top1_agreement: bool_rate(prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + prompts.iter().map(|p| p.baseline_top1_in_ablated_top5), + ), + } + }) + .collect(); + strata.sort_by(|a, b| a.stratum.cmp(&b.stratum)); + StaticModeReport { + replacement_kind: kind.as_str().to_string(), + patch_location: "before_W_O".to_string(), + runtime_class: match kind { + StaticReplacementKind::Zero => "negligible_test", + _ => "static_injection_lookup_add", + } + .to_string(), + prompts: self.prompts.len(), + mean_kl: mean(&kl_values), + p95_kl: percentile(kl_values.clone(), 0.95), + max_kl: kl_values.iter().copied().fold(0.0, f64::max), + mean_delta_cross_entropy_bits, + top1_agreement: bool_rate(self.prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + self.prompts.iter().map(|p| p.baseline_top1_in_ablated_top5), + ), + strata, + worst_examples, + per_prompt: self.prompts, + } + } +} + +pub(super) fn run_static_replace( + args: StaticReplaceArgs, +) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd static-replace currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = parse_head_spec(&args.heads)?; + if selected_heads.is_empty() { + return Err("no heads selected for static replacement".into()); + } + let prompts = load_prompts(&args.prompts, args.max_prompts)?; + eprintln!("Selected heads: {:?}", selected_heads); + eprintln!("Prompts: {}", prompts.len()); + let (fit_prompts, eval_prompts): (Vec, Vec) = + if let Some(eval_mod) = args.eval_mod { + split_prompt_records(&prompts, eval_mod, args.eval_offset)? + } else { + (prompts.clone(), prompts.clone()) + }; + + eprintln!("Pass 1/2: fitting static pre-W_O means"); + let means = fit_static_means( + &mut weights, + &index, + &tokenizer, + &fit_prompts, + &selected_heads, + )?; + + eprintln!("Pass 2/2: evaluating static replacements"); + let mut accumulators: HashMap<(HeadId, &'static str), StaticModeAccumulator> = HashMap::new(); + for head in &selected_heads { + for kind in STATIC_REPLACEMENT_KINDS { + accumulators.insert((*head, kind.as_str()), StaticModeAccumulator::new()); + } + } + + for (prompt_idx, record) in eval_prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" [{}/{}] {}", prompt_idx + 1, eval_prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + let baseline_top1 = argmax(&baseline_logits); + for head in &selected_heads { + let head_means = means.get(head).ok_or_else(|| { + format!("missing fitted means for L{} H{}", head.layer, head.head) + })?; + for kind in STATIC_REPLACEMENT_KINDS { + let replacement = + build_static_replacement(kind, token_ids.len(), head_means, stratum)?; + let replaced_hidden = forward_q4k_replace_pre_o_head( + &mut weights, + &token_ids, + &index, + *head, + &replacement, + )?; + let replaced_logits = final_logits(&weights, &replaced_hidden); + let replaced_logp = log_softmax(&replaced_logits); + let kl = kl_logp(&baseline_logp, &replaced_logp); + let replaced_top1 = argmax(&replaced_logits); + let replaced_top5 = top_k_indices(&replaced_logits, 5); + accumulators + .get_mut(&(*head, kind.as_str())) + .expect("static accumulator missing") + .add(ZeroPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl, + delta_cross_entropy_bits: kl / std::f64::consts::LN_2, + baseline_top1, + ablated_top1: replaced_top1, + top1_agree: baseline_top1 == replaced_top1, + baseline_top1_in_ablated_top5: replaced_top5.contains(&baseline_top1), + }); + } + } + } + + let mut head_reports = Vec::new(); + for head in &selected_heads { + let mut modes = Vec::new(); + for kind in STATIC_REPLACEMENT_KINDS { + let acc = accumulators + .remove(&(*head, kind.as_str())) + .expect("static accumulator missing at finish"); + modes.push(acc.finish(kind)); + } + let train_samples = means.get(head).map(|m| m.count).unwrap_or(0); + head_reports.push(StaticHeadReport { + layer: head.layer, + head: head.head, + train_samples, + modes, + }); + } + + let report = StaticReplacementReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen: prompts.len(), + train_prompts_seen: fit_prompts.len(), + eval_prompts_seen: eval_prompts.len(), + eval_mod: args.eval_mod, + eval_offset: args.eval_offset, + selected_heads, + heads: head_reports, + }; + + let out_path = args.out.join("gate_static_replacement.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +pub(super) fn fit_static_means( + weights: &mut larql_inference::ModelWeights, + index: &VectorIndex, + tokenizer: &tokenizers::Tokenizer, + prompts: &[PromptRecord], + heads: &[HeadId], +) -> Result, Box> { + let mut heads_by_layer: HashMap> = HashMap::new(); + for head in heads { + heads_by_layer.entry(head.layer).or_default().push(*head); + } + + let mut accumulators: HashMap = HashMap::new(); + for head in heads { + let head_dim = weights.arch.head_dim_for_layer(head.layer); + accumulators.insert(*head, StaticHeadAccumulator::new(head_dim)); + } + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" fit [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + let token_ids = encode_prompt(tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + let mut h = embed_tokens_pub(weights, &token_ids); + let ple_inputs = precompute_per_layer_inputs(weights, &h, &token_ids); + + for layer in 0..weights.num_layers { + let inserted = insert_q4k_layer_tensors(weights, index, layer)?; + if let Some(layer_heads) = heads_by_layer.get(&layer) { + let (_, pre_o) = run_attention_block_with_pre_o(weights, &h, layer) + .ok_or_else(|| format!("pre-W_O capture failed at layer {layer}"))?; + let head_dim = weights.arch.head_dim_for_layer(layer); + for head in layer_heads { + let start = head.head * head_dim; + let end = start + head_dim; + let acc = accumulators + .get_mut(head) + .expect("static mean accumulator missing"); + for pos in 0..pre_o.nrows() { + let row = pre_o.slice(s![pos, start..end]); + if let Some(values) = row.as_slice() { + acc.add(pos, stratum, values); + } + } + } + } + + { + let ffn = WeightFfn { weights }; + if let Some((h_new, _, _)) = + run_layer_with_ffn(weights, &h, layer, &ffn, false, ple_inputs.get(layer), None) + { + h = h_new; + } + } + remove_layer_tensors(weights, inserted); + } + } + + Ok(accumulators + .into_iter() + .map(|(head, acc)| (head, acc.finish())) + .collect()) +} + +fn build_static_replacement( + kind: StaticReplacementKind, + seq_len: usize, + means: &StaticHeadMeans, + stratum: &str, +) -> Result, Box> { + let mut values = Vec::with_capacity(seq_len * means.head_dim); + for pos in 0..seq_len { + let owned_row; + let row = match kind { + StaticReplacementKind::Zero => None, + StaticReplacementKind::Global => Some(&means.global), + StaticReplacementKind::Position => means.positions.get(pos).or(Some(&means.global)), + StaticReplacementKind::Stratum => means.strata.get(stratum).or(Some(&means.global)), + StaticReplacementKind::PositionPlusStratum => { + let pos_row = means.positions.get(pos).unwrap_or(&means.global); + let stratum_row = means.strata.get(stratum).unwrap_or(&means.global); + owned_row = pos_row + .iter() + .zip(stratum_row.iter()) + .zip(means.global.iter()) + .map(|((&p, &s), &g)| p + s - g) + .collect::>(); + Some(&owned_row) + } + StaticReplacementKind::PositionStratum => means + .position_strata + .get(stratum) + .and_then(|rows| rows.get(pos)) + .or_else(|| means.positions.get(pos)) + .or(Some(&means.global)), + }; + if let Some(row) = row { + values.extend_from_slice(row); + } else { + values.extend(std::iter::repeat(0.0).take(means.head_dim)); + } + } + Ok(Array2::from_shape_vec((seq_len, means.head_dim), values)?) +} + +fn forward_q4k_replace_pre_o_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, + replacement: &Array2, +) -> Result, Box> { + larql_inference::vindex::predict_q4k_hidden_with_replaced_pre_o_head( + weights, + token_ids, + index, + head.layer, + head.head, + replacement, + ) + .map_err(Into::into) +} + +fn final_logits(weights: &larql_inference::ModelWeights, h: &Array2) -> Vec { + let last = h.nrows().saturating_sub(1); + let h_last = h.slice(s![last..last + 1, ..]).to_owned(); + hidden_to_raw_logits(weights, &h_last) +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/stats.rs b/crates/larql-cli/src/commands/dev/ov_rd/stats.rs new file mode 100644 index 00000000..066bf4b8 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/stats.rs @@ -0,0 +1,162 @@ +use std::collections::HashMap; + +use super::reports::FinishedHeadStats; + +#[derive(Debug)] +pub(super) struct RunningHeadStats { + count: u64, + sum: Vec, + sum_sq_norm: f64, +} + +impl RunningHeadStats { + pub(super) fn new(head_dim: usize) -> Self { + Self { + count: 0, + sum: vec![0.0; head_dim], + sum_sq_norm: 0.0, + } + } + + pub(super) fn add(&mut self, values: &[f32]) { + self.count += 1; + let mut sq = 0.0f64; + for (dst, &v) in self.sum.iter_mut().zip(values.iter()) { + let vf = v as f64; + *dst += vf; + sq += vf * vf; + } + self.sum_sq_norm += sq; + } + + pub(super) fn finish(&self) -> FinishedHeadStats { + if self.count == 0 { + return FinishedHeadStats { + count: 0, + mean_norm_sq: 0.0, + second_moment: 0.0, + variance: 0.0, + rms_norm: 0.0, + }; + } + let n = self.count as f64; + let mean_norm_sq = self + .sum + .iter() + .map(|v| { + let m = *v / n; + m * m + }) + .sum::(); + let second_moment = self.sum_sq_norm / n; + let variance = (second_moment - mean_norm_sq).max(0.0); + FinishedHeadStats { + count: self.count, + mean_norm_sq, + second_moment, + variance, + rms_norm: second_moment.sqrt(), + } + } +} + +#[derive(Debug, Clone)] +struct MeanAccumulator { + count: u64, + sum: Vec, +} + +impl MeanAccumulator { + fn new(dim: usize) -> Self { + Self { + count: 0, + sum: vec![0.0; dim], + } + } + + fn add(&mut self, values: &[f32]) { + self.count += 1; + for (dst, &value) in self.sum.iter_mut().zip(values.iter()) { + *dst += value as f64; + } + } + + fn mean(&self) -> Vec { + if self.count == 0 { + return vec![0.0; self.sum.len()]; + } + let n = self.count as f64; + self.sum.iter().map(|v| (*v / n) as f32).collect() + } +} + +#[derive(Debug)] +pub(super) struct StaticHeadAccumulator { + global: MeanAccumulator, + positions: Vec, + strata: HashMap, + position_strata: HashMap>, +} + +impl StaticHeadAccumulator { + pub(super) fn new(head_dim: usize) -> Self { + Self { + global: MeanAccumulator::new(head_dim), + positions: Vec::new(), + strata: HashMap::new(), + position_strata: HashMap::new(), + } + } + + pub(super) fn add(&mut self, position: usize, stratum: &str, values: &[f32]) { + self.global.add(values); + while self.positions.len() <= position { + self.positions + .push(MeanAccumulator::new(self.global.sum.len())); + } + self.positions[position].add(values); + self.strata + .entry(stratum.to_string()) + .or_insert_with(|| MeanAccumulator::new(self.global.sum.len())) + .add(values); + let by_position = self.position_strata.entry(stratum.to_string()).or_default(); + while by_position.len() <= position { + by_position.push(MeanAccumulator::new(self.global.sum.len())); + } + by_position[position].add(values); + } + + pub(super) fn finish(&self) -> StaticHeadMeans { + StaticHeadMeans { + count: self.global.count, + head_dim: self.global.sum.len(), + global: self.global.mean(), + positions: self.positions.iter().map(MeanAccumulator::mean).collect(), + strata: self + .strata + .iter() + .map(|(key, value)| (key.clone(), value.mean())) + .collect(), + position_strata: self + .position_strata + .iter() + .map(|(key, values)| { + ( + key.clone(), + values.iter().map(MeanAccumulator::mean).collect(), + ) + }) + .collect(), + } + } +} + +#[derive(Debug, Clone)] +pub(super) struct StaticHeadMeans { + pub(super) count: u64, + pub(super) head_dim: usize, + pub(super) global: Vec, + pub(super) positions: Vec>, + pub(super) strata: HashMap>, + pub(super) position_strata: HashMap>>, +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/types.rs b/crates/larql-cli/src/commands/dev/ov_rd/types.rs new file mode 100644 index 00000000..375527be --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/types.rs @@ -0,0 +1,21 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Deserialize)] +pub(super) struct PromptRecord { + pub(super) id: Option, + pub(super) stratum: Option, + pub(super) prompt: String, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub(super) struct HeadId { + pub(super) layer: usize, + pub(super) head: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] +pub(super) struct PqConfig { + pub(super) k: usize, + pub(super) groups: usize, + pub(super) bits_per_group: usize, +} diff --git a/crates/larql-cli/src/commands/dev/ov_rd/zero_ablate.rs b/crates/larql-cli/src/commands/dev/ov_rd/zero_ablate.rs new file mode 100644 index 00000000..b3aa30b6 --- /dev/null +++ b/crates/larql-cli/src/commands/dev/ov_rd/zero_ablate.rs @@ -0,0 +1,310 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Instant; + +use clap::{Args, ValueEnum}; +use larql_inference::{encode_prompt, hidden_to_raw_logits}; +use larql_vindex::{ + load_model_weights_q4k, load_vindex_tokenizer, SilentLoadCallbacks, VectorIndex, +}; +use ndarray::{s, Array2}; + +use super::input::{load_prompts, parse_head_spec}; +use super::metrics::{argmax, bool_rate, kl_logp, log_softmax, mean, percentile, top_k_indices}; +use super::reports::{ + CaptureReport, ZeroAblationReport, ZeroHeadReport, ZeroPromptReport, ZeroStratumReport, +}; +use super::types::HeadId; + +#[derive(Args)] +pub(super) struct ZeroAblateArgs { + /// Self-contained Q4K vindex directory. + #[arg(long)] + index: PathBuf, + + /// JSONL prompt file. Each line must include at least {"prompt": "..."}. + #[arg(long)] + prompts: PathBuf, + + /// Output directory. + #[arg(long)] + out: PathBuf, + + /// Explicit heads as layer:head comma list, e.g. 11:3,11:0,0:4. + #[arg(long)] + heads: Option, + + /// Stage-0 stats JSON. Used with --top-heads when --heads is absent. + #[arg(long)] + stage0: Option, + + /// Number of highest-variance Stage-0 heads to test. + #[arg(long, default_value_t = 8)] + top_heads: usize, + + /// Stage-0 statistic used to rank --top-heads. + #[arg(long, value_enum, default_value_t = Stage0Rank::RawVariance)] + stage0_rank: Stage0Rank, + + /// Limit prompts for bounded gate runs. + #[arg(long)] + max_prompts: Option, +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Stage0Rank { + /// Rank by raw pre-W_O variance. + RawVariance, + /// Rank by W_O-visible residual contribution variance. + WoVisibleVariance, +} + +#[derive(Debug)] +struct ZeroHeadAccumulator { + prompts: Vec, + by_stratum: HashMap>, +} + +impl ZeroHeadAccumulator { + fn new() -> Self { + Self { + prompts: Vec::new(), + by_stratum: HashMap::new(), + } + } + + fn add(&mut self, prompt: ZeroPromptReport) { + let stratum = prompt.stratum.clone(); + self.prompts.push(prompt.clone()); + self.by_stratum.entry(stratum).or_default().push(prompt); + } + + fn finish(self, head: HeadId) -> ZeroHeadReport { + let prompts_len = self.prompts.len(); + let kl_values: Vec = self.prompts.iter().map(|p| p.kl).collect(); + let mean_kl = mean(&kl_values); + let p95_kl = percentile(kl_values.clone(), 0.95); + let max_kl = kl_values.iter().copied().fold(0.0, f64::max); + let mean_delta_cross_entropy_bits = mean( + &self + .prompts + .iter() + .map(|p| p.delta_cross_entropy_bits) + .collect::>(), + ); + let top1_agreement = bool_rate(self.prompts.iter().map(|p| p.top1_agree)); + let top5_contains_baseline_top1 = + bool_rate(self.prompts.iter().map(|p| p.baseline_top1_in_ablated_top5)); + let mut worst_examples = self.prompts.clone(); + worst_examples.sort_by(|a, b| b.kl.partial_cmp(&a.kl).unwrap_or(std::cmp::Ordering::Equal)); + worst_examples.truncate(10); + + let mut strata: Vec<_> = self + .by_stratum + .into_iter() + .map(|(stratum, prompts)| { + let values: Vec = prompts.iter().map(|p| p.kl).collect(); + ZeroStratumReport { + stratum, + prompts: prompts.len(), + mean_kl: mean(&values), + max_kl: values.iter().copied().fold(0.0, f64::max), + top1_agreement: bool_rate(prompts.iter().map(|p| p.top1_agree)), + top5_contains_baseline_top1: bool_rate( + prompts.iter().map(|p| p.baseline_top1_in_ablated_top5), + ), + } + }) + .collect(); + strata.sort_by(|a, b| a.stratum.cmp(&b.stratum)); + ZeroHeadReport { + layer: head.layer, + head: head.head, + ablation_kind: "zero_pre_wo".to_string(), + patch_location: "before_W_O".to_string(), + preserved_components: vec![ + "FFN".to_string(), + "PLE".to_string(), + "layer_scalar".to_string(), + ], + bounded_vocab_size: None, + prompts: prompts_len, + mean_kl, + p95_kl, + max_kl, + mean_delta_cross_entropy_bits, + top1_agreement, + top5_contains_baseline_top1, + strata, + worst_examples, + per_prompt: self.prompts, + } + } +} + +pub(super) fn run_zero_ablate(args: ZeroAblateArgs) -> Result<(), Box> { + std::fs::create_dir_all(&args.out)?; + + eprintln!("Loading vindex: {}", args.index.display()); + let start = Instant::now(); + let mut cb = SilentLoadCallbacks; + let mut index = VectorIndex::load_vindex(&args.index, &mut cb)?; + index.load_attn_q4k(&args.index)?; + index.load_interleaved_q4k(&args.index)?; + let mut weights = load_model_weights_q4k(&args.index, &mut cb)?; + let tokenizer = load_vindex_tokenizer(&args.index)?; + if weights.arch.is_hybrid_moe() { + return Err("ov-rd zero-ablate currently supports dense FFN vindexes only".into()); + } + eprintln!( + " {} layers, hidden_size={}, q_heads={}, head_dim={} ({:.1}s)", + weights.num_layers, + weights.hidden_size, + weights.num_q_heads, + weights.head_dim, + start.elapsed().as_secs_f64() + ); + + let selected_heads = select_zero_ablation_heads(&args)?; + if selected_heads.is_empty() { + return Err("no heads selected for zero-ablation".into()); + } + eprintln!("Selected heads: {:?}", selected_heads); + + let prompts = load_prompts(&args.prompts, args.max_prompts)?; + eprintln!("Prompts: {}", prompts.len()); + + let mut accumulators: Vec = selected_heads + .iter() + .map(|_| ZeroHeadAccumulator::new()) + .collect(); + + for (prompt_idx, record) in prompts.iter().enumerate() { + let label = record + .id + .as_deref() + .or(record.stratum.as_deref()) + .unwrap_or("prompt"); + eprintln!(" [{}/{}] {}", prompt_idx + 1, prompts.len(), label); + + let token_ids = encode_prompt(&tokenizer, &*weights.arch, &record.prompt)?; + if token_ids.is_empty() { + continue; + } + let stratum = record.stratum.as_deref().unwrap_or("unknown"); + + let baseline_hidden = + larql_inference::vindex::predict_q4k_hidden(&mut weights, &token_ids, &index, None); + let baseline_logits = final_logits(&weights, &baseline_hidden); + let baseline_logp = log_softmax(&baseline_logits); + let baseline_top1 = argmax(&baseline_logits); + + for (idx, head) in selected_heads.iter().copied().enumerate() { + let ablated_hidden = + forward_q4k_zero_pre_o_head(&mut weights, &token_ids, &index, head)?; + let ablated_logits = final_logits(&weights, &ablated_hidden); + let ablated_logp = log_softmax(&ablated_logits); + let kl = kl_logp(&baseline_logp, &ablated_logp); + let ablated_top1 = argmax(&ablated_logits); + let ablated_top5 = top_k_indices(&ablated_logits, 5); + accumulators[idx].add(ZeroPromptReport { + id: label.to_string(), + stratum: stratum.to_string(), + kl, + delta_cross_entropy_bits: kl / std::f64::consts::LN_2, + baseline_top1, + ablated_top1, + top1_agree: baseline_top1 == ablated_top1, + baseline_top1_in_ablated_top5: ablated_top5.contains(&baseline_top1), + }); + } + } + + let head_reports = selected_heads + .iter() + .copied() + .zip(accumulators) + .map(|(head, acc)| acc.finish(head)) + .collect(); + + let report = ZeroAblationReport { + index: args.index.display().to_string(), + prompt_file: args.prompts.display().to_string(), + prompts_seen: prompts.len(), + selected_heads, + heads: head_reports, + }; + + let out_path = args.out.join("gate1_zero_ablation.json"); + let file = std::fs::File::create(&out_path)?; + serde_json::to_writer_pretty(file, &report)?; + eprintln!("Wrote {}", out_path.display()); + + Ok(()) +} + +fn select_zero_ablation_heads( + args: &ZeroAblateArgs, +) -> Result, Box> { + let mut heads = if let Some(spec) = &args.heads { + parse_head_spec(spec)? + } else { + let stage0_path = args + .stage0 + .as_ref() + .ok_or("--heads or --stage0 must be provided")?; + let file = std::fs::File::open(stage0_path)?; + let report: CaptureReport = serde_json::from_reader(file)?; + let mut candidates = report.heads; + candidates.sort_by(|a, b| { + stage0_rank_score(b, args.stage0_rank) + .partial_cmp(&stage0_rank_score(a, args.stage0_rank)) + .unwrap_or(std::cmp::Ordering::Equal) + }); + candidates + .into_iter() + .take(args.top_heads) + .map(|h| HeadId { + layer: h.layer, + head: h.head, + }) + .collect() + }; + + heads.sort_by_key(|h| (h.layer, h.head)); + heads.dedup(); + Ok(heads) +} + +fn stage0_rank_score(head: &super::reports::HeadReport, rank: Stage0Rank) -> f64 { + match rank { + Stage0Rank::RawVariance => head.stats.variance, + Stage0Rank::WoVisibleVariance => head + .wo_visible_stats + .as_ref() + .map(|stats| stats.variance) + .unwrap_or(f64::NEG_INFINITY), + } +} + +pub(super) fn forward_q4k_zero_pre_o_head( + weights: &mut larql_inference::ModelWeights, + token_ids: &[u32], + index: &VectorIndex, + head: HeadId, +) -> Result, Box> { + larql_inference::vindex::predict_q4k_hidden_with_zeroed_pre_o_heads( + weights, + token_ids, + index, + head.layer, + &[head.head], + ) + .map_err(Into::into) +} + +fn final_logits(weights: &larql_inference::ModelWeights, h: &Array2) -> Vec { + let last = h.nrows().saturating_sub(1); + let h_last = h.slice(s![last..last + 1, ..]).to_owned(); + hidden_to_raw_logits(weights, &h_last) +} diff --git a/crates/larql-cli/src/commands/diagnostics/mod.rs b/crates/larql-cli/src/commands/diagnostics/mod.rs new file mode 100644 index 00000000..5ede1c65 --- /dev/null +++ b/crates/larql-cli/src/commands/diagnostics/mod.rs @@ -0,0 +1,9 @@ +//! Diagnostic / parity tools — `larql parity` and friends. +//! +//! Cross-backend numerical diff tooling. Used to catch silent regressions +//! between the CPU, Metal, and (eventually) HuggingFace reference paths +//! when refactoring quantisation, activations, norms, or expert routing. +//! +//! See `crates/larql-cli/ROADMAP.md` P0 → "`larql parity`" for the design. + +pub mod parity; diff --git a/crates/larql-cli/src/commands/diagnostics/parity.rs b/crates/larql-cli/src/commands/diagnostics/parity.rs new file mode 100644 index 00000000..559b7e5f --- /dev/null +++ b/crates/larql-cli/src/commands/diagnostics/parity.rs @@ -0,0 +1,1290 @@ +//! `larql parity` — cross-backend numerical diff for inference components. +//! +//! Diffs the same input through multiple backends (slow naive reference, +//! production CPU, Metal, HF — backends added incrementally) and reports +//! the first checkpoint where they diverge beyond `--tolerance`. +//! +//! v1 (this file) ships: +//! - `--component moe-expert` — single expert forward (gate / up / act / down) +//! - `--component moe-block` — full MoE block (router → top-K → experts → sum → norm) +//! - backends: `reference` (slow naive), `cpu` (production) +//! +//! v2 (planned) — Metal as a third backend, attention/dense-ffn/layer/forward +//! components. v3 — HF Python sidecar for ground-truth reference. +//! +//! See `crates/larql-cli/ROADMAP.md` P0 → "`larql parity`" for the full design. + +use clap::Args; + +use larql_compute::cpu::ops::moe::{cpu_moe_forward, run_single_expert_with_norm}; +use larql_compute::cpu::ops::q4_common::dequantize_q4_k; +use larql_compute::{Activation, MoeLayerWeights, QuantFormat}; +use larql_models::weights::{per_layer_ffn_key, PER_LAYER_FFN_DOWN, PER_LAYER_FFN_GATE_UP}; +use larql_vindex::{load_model_weights_q4k, load_vindex_config, SilentLoadCallbacks}; + +use crate::commands::primary::cache; + +// ── Component / backend taxonomies ──────────────────────────────────────────── + +/// Inference checkpoints that can be diffed independently. +const COMPONENTS: &[&str] = &[ + "moe-expert", // single expert forward (gate/up/act/down) + "moe-block", // full MoE block (router → top-K → experts → sum → norm) + "lm-head", // final projection parity (Q4_K vs f32 reference) + "layer", // full hybrid-MoE layer: CPU vs Metal, per-layer residual diff +]; + +/// Backends available as comparison targets. +/// +/// `reference` is the slow naive triple-loop CPU baseline. `cpu` is the +/// production path under test. `metal` is the GPU backend (v2 — used by +/// `--component layer`). +const BACKENDS: &[&str] = &[ + "reference", // slow naive baseline (moe-expert, moe-block) + "cpu", // production CPU path + "metal", // Metal GPU backend (layer component) +]; + +#[derive(Args)] +pub struct ParityArgs { + /// Vindex directory, `hf://` URL, or cache shorthand. Same resolution + /// as `larql run`. + pub model: String, + + /// Inference checkpoint to diff. v1: `moe-expert`, `moe-block`. + #[arg(long, default_value = "moe-block")] + pub component: String, + + /// Layer index. Default 0. + #[arg(long, default_value = "0")] + pub layer: usize, + + /// Expert index (used when `--component moe-expert`). + #[arg(long, default_value = "0")] + pub expert: usize, + + /// Comma-separated list of backends to run. v1: `reference,cpu`. + /// First backend in the list is the reference; subsequent backends + /// are diffed against it. + #[arg(long, default_value = "reference,cpu")] + pub backends: String, + + /// Prompt for `--component layer` (drives the actual forward pass). + /// For `moe-expert`/`moe-block`, the prompt seeds a synthetic residual + /// if provided; otherwise a deterministic sin-pattern is used. + #[arg(long)] + pub prompt: Option, + + /// Random-ish seed for the synthetic residual. Ignored when `--prompt` + /// is set. Default 0 produces the canonical sin pattern. + #[arg(long, default_value = "0")] + pub seed: u32, + + /// Max element-wise abs diff allowed before declaring divergence. The + /// right value depends on component depth — per-expert ≈ 1e-3, full + /// forward needs more headroom for accumulated f32 noise. + #[arg(long, default_value = "1e-3")] + pub tolerance: f64, + + /// Print intermediate values at each checkpoint, not just diffs. + #[arg(long, short)] + pub verbose: bool, +} + +pub fn run(args: ParityArgs) -> Result<(), Box> { + if !COMPONENTS.contains(&args.component.as_str()) { + return Err(format!( + "unknown --component '{}'. Available: {}", + args.component, + COMPONENTS.join(", ") + ) + .into()); + } + + // `layer` component always uses metal+cpu internally; other components + // need the backends list validated and require ≥2. + if args.component != "layer" { + let backends: Vec<&str> = args.backends.split(',').map(|s| s.trim()).collect(); + for b in &backends { + if !BACKENDS.contains(b) { + return Err(format!( + "unknown backend '{}'. Available: {}", + b, + BACKENDS.join(", ") + ) + .into()); + } + } + if backends.len() < 2 { + return Err("need at least 2 backends to diff (default is `reference,cpu`)".into()); + } + } + + // ── Resolve + load vindex ──────────────────────────────────────────────── + let path = cache::resolve_model(&args.model)?; + let config = load_vindex_config(&path)?; + let mut cb = SilentLoadCallbacks; + let weights = load_model_weights_q4k(&path, &mut cb)?; + let arch = &*weights.arch; + + println!("Vindex: {}", path.display()); + println!("Model: {}", config.model); + println!("Component: {}", args.component); + println!("Layer: {}", args.layer); + println!(); + + if args.component == "layer" { + return run_layer_diff(&path, &config, &args); + } + + // lm-head parity is backend-agnostic (Q4_K matvec vs f32 reference) — + // works on any vindex that has an lm_head, MoE or dense. + if !arch.is_hybrid_moe() && args.component != "lm-head" { + return Err(format!( + "vindex {} is not hybrid-MoE — moe-* components are MoE-only", + args.model + ) + .into()); + } + + let backends: Vec<&str> = args.backends.split(',').map(|s| s.trim()).collect(); + println!("Backends: {}", backends.join(" → ")); + println!(); + + match args.component.as_str() { + "moe-expert" => run_moe_expert(&config, &weights, &args, &backends), + "moe-block" => run_moe_block(&config, &weights, &args, &backends), + "lm-head" => run_lm_head(&path, &config, &weights, &args, &backends), + _ => unreachable!("validated above"), + } +} + +// ── lm-head: Q4_K-vs-reference logits for the final projection ─────────────── +// +// Diagnostic motivation: a 2026-04-27 silent-corruption bug had the writer +// emit Q4_K (`format/weights/write_q4k`) while `lm_head_knn_backend` dispatched +// `q4_matvec` (Q4_0). Same byte-rate per element (0.5625 B/elem) → identical +// file size → no validation caught the format collision → multilingual +// gibberish under `--metal`. This component diffs the actual on-disk Q4_K +// lm_head against an f32 reference computed from `weights.lm_head` (the model's +// HF-loaded tied embedding for Gemma 3/4 / Llama-tied / etc.). Any future +// format swap (Q4_K → Q4_KF, transposition, scale offset, ...) makes the +// top-1 token mismatch loud. + +fn run_lm_head( + path: &std::path::Path, + config: &larql_vindex::VindexConfig, + weights: &larql_models::ModelWeights, + args: &ParityArgs, + backends: &[&str], +) -> Result<(), Box> { + use larql_compute::CpuBackend; + use larql_vindex::SilentLoadCallbacks; + + let hidden = config.hidden_size; + let vocab = config.vocab_size; + println!("hidden={hidden}, vocab={vocab}"); + + // Build the same residual the moe-block / moe-expert variants use so a + // cross-component diff at the same prompt seed is straightforward. + let h = make_residual(hidden, args.seed); + + // Reference: f32 dot product against `weights.lm_head` (tied embedding + // for Gemma 3 / Gemma 4 / Llama; explicit lm_head row for untied). + let lm = &weights.lm_head; + if lm.is_empty() { + return Err("model has no lm_head loaded — re-run extract with weights enabled".into()); + } + let ref_scores: Vec = lm + .rows() + .into_iter() + .map(|row| row.iter().zip(h.iter()).map(|(a, b)| a * b).sum()) + .collect(); + + // Vindex side: load the index *here* (separately from the f32 weights + // load that load_model_weights_q4k did) so we exercise the production + // `open_inference_vindex` path including `load_lm_head_q4`. + let mut cb = SilentLoadCallbacks; + let mut index = larql_vindex::VectorIndex::load_vindex(path, &mut cb)?; + let _ = index.load_lm_head(path); + let _ = index.load_lm_head_q4(path); + let has_q4 = index.has_lm_head_q4(); + let has_full = index.has_lm_head(); + println!( + "lm_head sources: q4_mmap={has_q4} f32_mmap={has_full} tied_embed={}", + weights.lm_head.shape()[0] == config.vocab_size + ); + + // The cpu backend's lm_head_knn_backend does Q4_K matvec when the + // q4 mmap is present, falls back to f16 mmap, then f32 BLAS. We + // diff each available source against the reference so a regression + // in any one path stands out. + let cpu = CpuBackend; + let h1d = ndarray::Array1::from_vec(h.clone()); + + let mut traces: Vec<(&str, Vec)> = vec![("reference (f32 dot)", ref_scores.clone())]; + + if backends.iter().any(|b| *b == "cpu") { + let hits = index.lm_head_knn_backend(&h1d, vocab.min(8), &cpu); + if !hits.is_empty() { + // hits is (token, score) sorted descending. Reconstruct a + // sparse score vector for the diff helper. + let mut sparse = vec![f32::NEG_INFINITY; vocab]; + for (tok, score) in &hits { + sparse[*tok as usize] = *score; + } + traces.push(("cpu (lm_head_knn_backend)", sparse)); + } else { + println!( + " WARN: lm_head_knn_backend returned empty — vindex has no lm_head sources \ + (no lm_head_q4.bin, no lm_head.bin, no f16 mmap), and tied-embed fallback \ + lives in larql-inference. Re-run via `larql run` for the production path." + ); + } + } + + println!(); + println!("=== lm-head top-1 token comparison ==="); + let (ref_name, ref_v) = &traces[0]; + let ref_top1 = argmax(ref_v); + println!(" {ref_name:<28} top-1 token = {ref_top1}"); + for (name, v) in traces.iter().skip(1) { + let top1 = argmax(v); + let verdict = if top1 == ref_top1 { + "✓ matches reference" + } else { + "✗ DIFFERENT TOP-1 — likely format mismatch (Q4_K vs Q4_0, transposition, ...)" + }; + println!(" {name:<28} top-1 token = {top1} {verdict}"); + } + Ok(()) +} + +fn argmax(v: &[f32]) -> usize { + v.iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i) + .unwrap_or(0) +} + +// ── moe-expert: one expert's forward pass (proven correct in v0) ───────────── + +fn run_moe_expert( + config: &larql_vindex::VindexConfig, + weights: &larql_models::ModelWeights, + args: &ParityArgs, + backends: &[&str], +) -> Result<(), Box> { + let arch = &*weights.arch; + let hidden = config.hidden_size; + let inter = arch.moe_intermediate_size(); + let inter_padded = inter.div_ceil(larql_models::quant::ggml::Q4_K_BLOCK_ELEMS) + * larql_models::quant::ggml::Q4_K_BLOCK_ELEMS; + let num_experts = arch.num_experts(); + if args.expert >= num_experts { + return Err(format!( + "expert {} out of range (model has {num_experts})", + args.expert + ) + .into()); + } + + let (gu_bytes, dn_bytes) = expert_bytes(weights, args.layer, args.expert)?; + let pre_norm = pre_experts_norm_for(weights, args.layer); + let activation = activation_for(arch); + let h = make_residual(hidden, args.seed); + + println!("Expert: {}", args.expert); + println!( + "Per-expert bytes: gate_up={} ({:.2} MB), down={} ({:.2} MB)", + gu_bytes.len(), + gu_bytes.len() as f64 / 1e6, + dn_bytes.len(), + dn_bytes.len() as f64 / 1e6, + ); + println!(); + + let mut traces: Vec<(&str, Vec)> = Vec::new(); + for backend in backends { + let out = match *backend { + "reference" => reference_one_expert( + &h, + gu_bytes, + dn_bytes, + hidden, + inter, + inter_padded, + pre_norm, + arch.norm_weight_offset(), + arch.norm_eps(), + activation, + args.verbose, + ), + "cpu" => run_single_expert_with_norm( + &h, + gu_bytes, + dn_bytes, + inter, + pre_norm, + arch.norm_weight_offset(), + arch.norm_eps(), + QuantFormat::Q4_K, + activation, + ), + _ => return Err(format!("backend '{backend}' not yet wired for moe-expert").into()), + }; + traces.push((backend, out)); + } + + println!("=== expert_output diff ==="); + diff_against_first(&traces, args.tolerance); + Ok(()) +} + +// ── moe-block: full block — router + top-K + K experts + sum + post-norm ───── +// +// This is the v1 component that should localise the current Gemma 4 26B-A4B +// CPU MoE bug — per-expert compute is already proven correct (see v0 +// prototype), so divergence here means routing or combination is off. + +fn run_moe_block( + config: &larql_vindex::VindexConfig, + weights: &larql_models::ModelWeights, + args: &ParityArgs, + backends: &[&str], +) -> Result<(), Box> { + let arch = &*weights.arch; + let hidden = config.hidden_size; + let inter = arch.moe_intermediate_size(); + let inter_padded = inter.div_ceil(larql_models::quant::ggml::Q4_K_BLOCK_ELEMS) + * larql_models::quant::ggml::Q4_K_BLOCK_ELEMS; + let num_experts = arch.num_experts(); + let top_k = arch.num_experts_per_token(); + + let h = make_residual(hidden, args.seed); + let pre_norm = pre_experts_norm_for(weights, args.layer); + let post_norm = post_experts_norm_for(weights, args.layer); + let router_proj = router_proj_for(weights, arch, args.layer)?; + let router_per_expert_scale = router_per_expert_scale_for(weights, arch, args.layer); + let router_norm = router_norm_for(weights, arch, args.layer); + let router_norm_parameter_free = arch.moe_router_norm_parameter_free(); + let router_input_scalar = arch.moe_router_input_scalar().unwrap_or(1.0); + let activation = activation_for(arch); + let norm_offset = arch.norm_weight_offset(); + let eps = arch.norm_eps(); + + println!( + "Block: layer {} of {}, hidden={hidden}, inter={inter} (padded {inter_padded}), \ + experts={num_experts} top_k={top_k}", + args.layer, config.num_layers + ); + println!(); + + // Build per-expert byte tables once — both backends consume the same. + let mut experts_gate_up: Vec<&[u8]> = Vec::with_capacity(num_experts); + let mut experts_down: Vec<&[u8]> = Vec::with_capacity(num_experts); + for e in 0..num_experts { + let (gu, dn) = expert_bytes(weights, args.layer, e)?; + experts_gate_up.push(gu); + experts_down.push(dn); + } + + let moe = MoeLayerWeights { + experts_gate_up: experts_gate_up.clone(), + experts_down: experts_down.clone(), + expert_data_format: QuantFormat::Q4_K, + router_proj: &router_proj, + router_scale: &[], + router_per_expert_scale: &router_per_expert_scale, + router_norm: &router_norm, + router_norm_parameter_free, + router_input_scalar, + pre_experts_norm: pre_norm, + post_ffn1_norm: &[], + post_experts_norm: post_norm, + num_experts, + top_k, + intermediate_size: inter, + activation, + }; + + let mut traces: Vec<(&str, Vec)> = Vec::new(); + for backend in backends { + let out = match *backend { + "reference" => reference_moe_block( + &h, + &experts_gate_up, + &experts_down, + &router_proj, + &router_per_expert_scale, + &router_norm, + router_norm_parameter_free, + router_input_scalar, + pre_norm, + post_norm, + hidden, + inter, + inter_padded, + num_experts, + top_k, + activation, + norm_offset, + eps, + args.verbose, + ), + "cpu" => cpu_moe_forward(&h, &moe, norm_offset, eps), + _ => return Err(format!("backend '{backend}' not yet wired for moe-block").into()), + }; + traces.push((backend, out)); + } + + println!("=== moe_block_output diff ==="); + diff_against_first(&traces, args.tolerance); + + // Side-by-side routing-convention check: which top-K does each + // convention select? Per HF Gemma4TextDecoderLayer.forward, the router + // consumes the raw post-attention residual; experts consume + // pre_experts_norm(residual). If h_norm and raw_h pick different + // experts, mis-routing the input is what produces "fluent but wrong" + // generation. + println!(); + println!("=== Routing-convention comparison ==="); + let h_norm = naive_rms_norm(&h, pre_norm, eps, norm_offset); + let (idx_raw, w_raw) = compute_top_k( + &h, + &router_proj, + &router_per_expert_scale, + &router_norm, + router_norm_parameter_free, + router_input_scalar, + num_experts, + top_k, + hidden, + eps, + norm_offset, + ); + let (idx_norm, w_norm) = compute_top_k( + &h_norm, + &router_proj, + &router_per_expert_scale, + &router_norm, + router_norm_parameter_free, + router_input_scalar, + num_experts, + top_k, + hidden, + eps, + norm_offset, + ); + println!(" router_in=raw_h top_k: {idx_raw:?}"); + println!( + " weights: {}", + w_raw + .iter() + .map(|w| format!("{w:.4}")) + .collect::>() + .join(" ") + ); + println!(" router_in=h_norm top_k: {idx_norm:?} ← Metal/GPU convention"); + println!( + " weights: {}", + w_norm + .iter() + .map(|w| format!("{w:.4}")) + .collect::>() + .join(" ") + ); + let same: Vec = idx_raw + .iter() + .filter(|&&e| idx_norm.contains(&e)) + .copied() + .collect(); + if same.len() == top_k { + println!(" ✓ SAME top-{top_k} experts selected — routing input choice is not the bug"); + } else { + println!( + " ✗ DIFFERENT top-{top_k}: {} overlap, {} differ — expert-selection convention IS the bug surface", + same.len(), + top_k - same.len() + ); + } + Ok(()) +} + +// ── layer: full hybrid-MoE layer CPU vs Metal residual diff ────────────────── +// +// Runs CPU `predict_q4k_hidden` and Metal `generate` on the same prompt with +// their respective dump hooks enabled, then compares per-layer residuals. +// +// CPU dumps: LARQL_CPU_DUMP_LAYERS → cpu_layer_{LL}.f32 (last-position row) +// LARQL_CPU_STAGE_DUMP → cpu_L0_.f32 +// Metal dump: LARQL_DUMP_RESIDUALS → binary (LARQL_RES_V2 header, then per- +// layer records: u32 layer_idx, u32 hidden, f32[hidden] layer_in, +// f32[hidden] h_post_attn, f32[hidden] layer_out) +// +// The comparison is decode-step vs prefill-last-token, so the two are in +// slightly different compute contexts (Metal uses KV cache; CPU re-processes +// the full sequence). This is sufficient to locate the first diverging layer +// but not to compute precise numeric agreement. + +fn run_layer_diff( + path: &std::path::Path, + config: &larql_vindex::VindexConfig, + args: &ParityArgs, +) -> Result<(), Box> { + use larql_inference::layer_graph::{generate::generate, CachedLayerGraph}; + use larql_inference::vindex::predict_q4k_hidden; + + let num_layers = config.num_layers; + let hidden = config.hidden_size; + + let prompt = args.prompt.as_deref().unwrap_or("The capital of France is"); + + println!("Prompt: {prompt:?}"); + println!("Backends: metal (reference) → cpu"); + println!(); + + // ── Set up temp dirs for dump files ───────────────────────────────────── + let base = std::env::temp_dir().join(format!("larql_parity_{}", std::process::id())); + let cpu_path_buf = base.join("cpu"); + let metal_path_buf = base.join("metal_residuals.bin"); + let metal_dense_dir = base.join("metal_dense"); + std::fs::create_dir_all(&cpu_path_buf)?; + let cpu_path = cpu_path_buf.as_path(); + let metal_path = metal_path_buf.as_path(); + struct Cleanup(std::path::PathBuf); + impl Drop for Cleanup { + fn drop(&mut self) { + let _ = std::fs::remove_dir_all(&self.0); + } + } + let _cleanup = Cleanup(base); + + // ── Load vindex (shared mmap; two weight copies for the two runs) ──────── + let mut cb = larql_vindex::SilentLoadCallbacks; + let mut q4_index = larql_vindex::VectorIndex::load_vindex(path, &mut cb)?; + q4_index.load_attn_q4k(path)?; + q4_index.load_interleaved_q4k(path)?; + let _ = q4_index.load_lm_head_q4(path); + let tokenizer = larql_vindex::load_vindex_tokenizer(path)?; + let mut w_metal = larql_vindex::load_model_weights_q4k(path, &mut cb)?; + let mut w_cpu = larql_vindex::load_model_weights_q4k(path, &mut cb)?; + + let wrapped = larql_inference::wrap_chat_prompt(path, Some(config.model.as_str()), prompt); + let token_ids = larql_inference::encode_prompt(&tokenizer, &*w_metal.arch, &wrapped.prompt)?; + println!(" seq_len: {} tokens post-template", token_ids.len()); + println!(); + + // The MoE decode path writes a single LARQL_DUMP_RESIDUALS binary + // covering every layer; the dense Metal decode path doesn't fire that + // hook (it only runs in the MoE branch of decode_token_with_moe_split_fn). + // For dense models we use LARQL_METAL_DUMP_LAYERS, which fires inside + // prefill_q4 and writes one file per layer (metal_layer_NN_h_out.f32 + + // metal_layer_NN_h_post_attn.f32). This aligns with the CPU dumps, + // which are also captured during prefill. + let is_moe = w_metal.arch.is_hybrid_moe(); + if !is_moe { + std::fs::create_dir_all(&metal_dense_dir)?; + } + + // ── Metal run (reference — produces correct output) ────────────────────── + if is_moe { + std::env::set_var("LARQL_DUMP_RESIDUALS", metal_path); + } else { + std::env::set_var("LARQL_METAL_DUMP_LAYERS", &metal_dense_dir); + } + println!("Running Metal…"); + let metal_result = { + let backend = larql_compute::metal::MetalBackend::new() + .ok_or("Metal backend unavailable — build with `--features metal` on M-series Mac")?; + let cache = CachedLayerGraph::from_residuals(Vec::new()); + generate( + &mut w_metal, + &tokenizer, + &token_ids, + 1, + &q4_index, + &backend, + &cache, + 0..num_layers, + ) + }; + std::env::remove_var("LARQL_DUMP_RESIDUALS"); + std::env::remove_var("LARQL_METAL_DUMP_LAYERS"); + println!(" Metal output: {:?}", metal_result.text().trim()); + + // ── CPU run ────────────────────────────────────────────────────────────── + std::env::set_var("LARQL_CPU_DUMP_LAYERS", cpu_path); + std::env::set_var("LARQL_CPU_STAGE_DUMP", cpu_path); + println!("Running CPU…"); + predict_q4k_hidden(&mut w_cpu, &token_ids, &q4_index, None); + std::env::remove_var("LARQL_CPU_DUMP_LAYERS"); + std::env::remove_var("LARQL_CPU_STAGE_DUMP"); + + // ── Load per-layer Metal output ────────────────────────────────────────── + // MoE: parse binary residual dump (richer — includes h_post_attn). + // Dense: read decode_layer_NN.f32 written by LARQL_DECODE_DUMP_LAYERS. + let metal_layers: std::collections::BTreeMap = if is_moe { + let metal_bytes = std::fs::read(metal_path)?; + let parsed = parse_residual_dump(&metal_bytes); + if parsed.is_empty() { + return Err( + "Metal residual dump is empty — LARQL_DUMP_RESIDUALS may not have fired".into(), + ); + } + parsed.into_iter().collect() + } else { + // Prefill dumps: metal_layer_NN_h_out.f32 (post-FFN residual) and + // metal_layer_NN_h_post_attn.f32 (post-attention residual). + // Both have shape [seq_len * hidden]; we take the last position. + let last_pos_slice = |v: Vec| -> Vec { + let n = v.len() / hidden; + if n == 0 { + v + } else { + v[(n - 1) * hidden..].to_vec() + } + }; + let mut out = std::collections::BTreeMap::new(); + for l in 0..num_layers { + let h_out_path = metal_dense_dir.join(format!("metal_layer_{l:02}_h_out.f32")); + let h_pa_path = metal_dense_dir.join(format!("metal_layer_{l:02}_h_post_attn.f32")); + let layer_out = match read_parity_f32(&h_out_path) { + Some(v) => last_pos_slice(v), + None => continue, + }; + let h_post_attn = read_parity_f32(&h_pa_path) + .map(last_pos_slice) + .unwrap_or_default(); + out.insert( + l, + ResidualRecord { + h_post_attn, + layer_out, + }, + ); + } + if out.is_empty() { + return Err( + "Metal dense dump is empty — LARQL_METAL_DUMP_LAYERS may not have fired".into(), + ); + } + out + }; + + // ── Compare per layer ──────────────────────────────────────────────────── + println!(); + println!("━━━ Layer-by-layer residual diff (Metal = reference) ━━━━━━━━━━"); + println!( + " {:>3} {:>10} {:>10} {:>10} {:>12} note", + "L", "cos(h_pa)", "cos(h_out)", "‖cpu‖", "‖metal‖" + ); + println!(" {}", "─".repeat(72)); + + const DRIFT: f32 = 0.9999; + let mut first_bad: Option = None; + + for l in 0..num_layers { + let cpu_out_path = cpu_path.join(format!("cpu_layer_{l:02}.f32")); + let cpu_pa_path = cpu_path.join(format!("cpu_layer_{l:02}_h_post_attn.f32")); + + let cpu_out = match read_parity_f32(&cpu_out_path) { + Some(v) => v, + None => { + println!(" L{l:02} "); + continue; + } + }; + let metal_rec = match metal_layers.get(&l) { + Some(r) => r, + None => { + println!(" L{l:02} "); + continue; + } + }; + + // CPU dump has (seq_len × hidden) elements; take the last position. + let seq_positions = cpu_out.len() / hidden; + let cpu_last = if seq_positions > 0 { + cpu_out[(seq_positions - 1) * hidden..].to_vec() + } else { + cpu_out.clone() + }; + + let cos_out = naive_cos_sim(&cpu_last, &metal_rec.layer_out); + let norm_cpu = naive_rms_mag(&cpu_last); + let norm_mtl = naive_rms_mag(&metal_rec.layer_out); + + // Dense path doesn't capture h_post_attn separately, so cos(h_pa) + // is only computed when we have it (MoE). + let cos_pa = if metal_rec.h_post_attn.is_empty() { + None + } else { + read_parity_f32(&cpu_pa_path).map(|v| { + let n = v.len() / hidden; + let last = if n > 0 { + v[(n - 1) * hidden..].to_vec() + } else { + v + }; + naive_cos_sim(&last, &metal_rec.h_post_attn) + }) + }; + + if cos_out < DRIFT && first_bad.is_none() { + first_bad = Some(l); + } + let flag = if cos_out < DRIFT { " ←" } else { "" }; + let note = match cos_pa { + Some(ca) if ca < DRIFT && cos_out < DRIFT => "attn+ffn", + Some(ca) if ca < DRIFT => "attn", + Some(_) if cos_out < DRIFT => "ffn/moe", + Some(_) => "clean", + None => "?", + }; + let hpa_s = cos_pa + .map(|c| format!("{c:>10.6}")) + .unwrap_or_else(|| " -".into()); + println!( + " L{l:02} {hpa_s} {cos_out:>10.6} {norm_cpu:>10.4} {norm_mtl:>12.4} {note}{flag}" + ); + } + + println!(); + match first_bad { + Some(l) => { + println!("First divergence at L{l} (cos < {DRIFT})."); + let note = if l == 0 { + "L0 drift — culprit is embedding, pre-norm, attention, or MoE combine." + } else { + "Earlier layers match; drift introduced at this layer." + }; + println!("{note}"); + } + None => { + println!("All layers match within cos ≥ {DRIFT}."); + println!("Note: Metal decode vs CPU prefill — slight positional mismatch expected."); + } + } + + Ok(()) +} + +/// Per-layer record from `LARQL_DUMP_RESIDUALS` binary. +struct ResidualRecord { + h_post_attn: Vec, + layer_out: Vec, +} + +/// Parse `LARQL_DUMP_RESIDUALS` binary (written by `moe_combine.rs / diag.rs`). +/// Returns a map from layer_idx → record. Skips the 16-byte magic header. +fn parse_residual_dump(bytes: &[u8]) -> std::collections::HashMap { + let mut map = std::collections::HashMap::new(); + if bytes.len() < 16 { + return map; + } + let mut pos = 16usize; // skip magic + while pos + 8 <= bytes.len() { + let layer_idx = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize; + let hidden = u32::from_le_bytes(bytes[pos + 4..pos + 8].try_into().unwrap()) as usize; + pos += 8; + let n_bytes = hidden * 4; + if pos + n_bytes * 3 > bytes.len() { + break; + } + let layer_in: Vec = bytes[pos..pos + n_bytes] + .chunks_exact(4) + .map(|b| f32::from_le_bytes(b.try_into().unwrap())) + .collect(); + pos += n_bytes; + let h_post_attn: Vec = bytes[pos..pos + n_bytes] + .chunks_exact(4) + .map(|b| f32::from_le_bytes(b.try_into().unwrap())) + .collect(); + pos += n_bytes; + let layer_out: Vec = bytes[pos..pos + n_bytes] + .chunks_exact(4) + .map(|b| f32::from_le_bytes(b.try_into().unwrap())) + .collect(); + pos += n_bytes; + let _ = layer_in; // used for format validation only + map.insert( + layer_idx, + ResidualRecord { + h_post_attn, + layer_out, + }, + ); + } + map +} + +fn read_parity_f32(path: &std::path::Path) -> Option> { + let bytes = std::fs::read(path).ok()?; + if bytes.len() % 4 != 0 { + return None; + } + Some( + bytes + .chunks_exact(4) + .map(|b| f32::from_le_bytes(b.try_into().unwrap())) + .collect(), + ) +} + +fn naive_cos_sim(a: &[f32], b: &[f32]) -> f32 { + let n = a.len().min(b.len()); + let dot: f32 = a[..n].iter().zip(&b[..n]).map(|(x, y)| x * y).sum(); + let na: f32 = a[..n].iter().map(|x| x * x).sum::().sqrt(); + let nb: f32 = b[..n].iter().map(|x| x * x).sum::().sqrt(); + dot / (na * nb + 1e-10) +} + +fn naive_rms_mag(v: &[f32]) -> f32 { + (v.iter().map(|x| x * x).sum::() / v.len() as f32).sqrt() +} + +// ── Reference impls (slow + naive) ──────────────────────────────────────────── + +#[allow(clippy::too_many_arguments)] +fn reference_one_expert( + h: &[f32], + gu_bytes: &[u8], + dn_bytes: &[u8], + hidden: usize, + inter: usize, + inter_padded: usize, + pre_norm: &[f32], + norm_offset: f32, + eps: f32, + activation: Activation, + verbose: bool, +) -> Vec { + let h_norm = naive_rms_norm(h, pre_norm, eps, norm_offset); + if verbose { + dump3("ref h_norm", &h_norm); + } + let gate_up_w = dequantize_q4_k(gu_bytes, 2 * inter * hidden); + let down_w = dequantize_q4_k(dn_bytes, hidden * inter_padded); + + let gate_w = &gate_up_w[..inter * hidden]; + let up_w = &gate_up_w[inter * hidden..2 * inter * hidden]; + + let gate_out = naive_matvec(&h_norm, gate_w, inter, hidden); + let up_out = naive_matvec(&h_norm, up_w, inter, hidden); + if verbose { + dump3("ref gate_out", &gate_out); + dump3("ref up_out ", &up_out); + } + + let mut hidden_state = vec![0.0f32; inter_padded]; + for j in 0..inter { + hidden_state[j] = match activation { + Activation::GeluTanh => naive_gelu_tanh(gate_out[j]) * up_out[j], + _ => naive_silu(gate_out[j]) * up_out[j], + }; + } + naive_matvec(&hidden_state, &down_w, hidden, inter_padded) +} + +#[allow(clippy::too_many_arguments)] +fn reference_moe_block( + h: &[f32], + experts_gate_up: &[&[u8]], + experts_down: &[&[u8]], + router_proj: &[f32], + router_per_expert_scale: &[f32], + router_norm: &[f32], + router_norm_parameter_free: bool, + router_input_scalar: f32, + pre_norm: &[f32], + post_norm: &[f32], + hidden: usize, + inter: usize, + inter_padded: usize, + num_experts: usize, + top_k: usize, + activation: Activation, + norm_offset: f32, + eps: f32, + verbose: bool, +) -> Vec { + // 1. Pre-experts norm — for the expert matmuls. + let h_norm = naive_rms_norm(h, pre_norm, eps, norm_offset); + if verbose { + dump3("ref h_norm ", &h_norm); + } + + // 2. Router input norm — applied to h_norm (matching Metal's + // `cpu_moe_route(&h_norm, ...)` and the routing-convention fix + // in `cpu_moe_forward`). Empirically the trained 26B-A4B weights + // expect this even though HF's modeling_gemma4.py uses raw h. + let router_in_normed = if !router_norm.is_empty() { + naive_rms_norm(&h_norm, router_norm, eps, norm_offset) + } else if router_norm_parameter_free { + naive_rms_norm(&h_norm, &[], eps, 0.0) + } else { + h_norm.clone() + }; + let mut router_in = router_in_normed; + if router_input_scalar != 1.0 && router_input_scalar != 0.0 { + for v in router_in.iter_mut() { + *v *= router_input_scalar; + } + } + if verbose { + dump3("ref router_in ", &router_in); + } + + // 3. Router projection [hidden → num_experts]. + let mut logits = naive_matvec(&router_in, router_proj, num_experts, hidden); + naive_softmax(&mut logits); + + // 4. Top-K + renormalisation. + let (indices, mut weights) = naive_top_k(&logits, top_k); + let sum: f32 = weights.iter().sum(); + if sum > 0.0 { + for w in &mut weights { + *w /= sum; + } + } + if !router_per_expert_scale.is_empty() { + for (i, &ei) in indices.iter().enumerate() { + if ei < router_per_expert_scale.len() { + weights[i] *= router_per_expert_scale[ei]; + } + } + } + if verbose { + println!( + " ref top_k indices: {:?} weights: {:?}", + indices, + weights + .iter() + .map(|w| format!("{w:.4}")) + .collect::>() + ); + } + + // 5. Sum K weighted expert outputs. + let mut moe_out = vec![0.0f32; hidden]; + for (k, &ei) in indices.iter().enumerate() { + let w = weights[k]; + if w == 0.0 { + continue; + } + let contrib = reference_one_expert( + h, + experts_gate_up[ei], + experts_down[ei], + hidden, + inter, + inter_padded, + pre_norm, + norm_offset, + eps, + activation, + false, + ); + for (acc, &v) in moe_out.iter_mut().zip(contrib.iter()) { + *acc += w * v; + } + } + if verbose { + dump3("ref pre-post-norm ", &moe_out); + } + + // 6. Post-experts norm. + if !post_norm.is_empty() { + moe_out = naive_rms_norm(&moe_out, post_norm, eps, norm_offset); + } + moe_out +} + +/// Run only the routing portion of the MoE block — return top-K indices + +/// renormalised weights. Used by the routing-convention diff to expose +/// whether two router-input variants pick different experts. +#[allow(clippy::too_many_arguments)] +fn compute_top_k( + router_in_pre: &[f32], + router_proj: &[f32], + router_per_expert_scale: &[f32], + router_norm: &[f32], + router_norm_parameter_free: bool, + router_input_scalar: f32, + num_experts: usize, + top_k: usize, + hidden: usize, + eps: f32, + norm_offset: f32, +) -> (Vec, Vec) { + let router_in_normed = if !router_norm.is_empty() { + naive_rms_norm(router_in_pre, router_norm, eps, norm_offset) + } else if router_norm_parameter_free { + naive_rms_norm(router_in_pre, &[], eps, 0.0) + } else { + router_in_pre.to_vec() + }; + let mut router_in = router_in_normed; + if router_input_scalar != 1.0 && router_input_scalar != 0.0 { + for v in router_in.iter_mut() { + *v *= router_input_scalar; + } + } + let mut logits = naive_matvec(&router_in, router_proj, num_experts, hidden); + naive_softmax(&mut logits); + let (indices, mut weights) = naive_top_k(&logits, top_k); + let sum: f32 = weights.iter().sum(); + if sum > 0.0 { + for w in &mut weights { + *w /= sum; + } + } + if !router_per_expert_scale.is_empty() { + for (i, &ei) in indices.iter().enumerate() { + if ei < router_per_expert_scale.len() { + weights[i] *= router_per_expert_scale[ei]; + } + } + } + (indices, weights) +} + +// ── Naive primitives (f64 accumulators, no BLAS) ────────────────────────────── + +fn naive_matvec(x: &[f32], w: &[f32], out_rows: usize, in_cols: usize) -> Vec { + let mut out = vec![0.0f32; out_rows]; + for r in 0..out_rows { + let mut s = 0.0f64; + for c in 0..in_cols { + s += (w[r * in_cols + c] as f64) * (x[c] as f64); + } + out[r] = s as f32; + } + out +} + +fn naive_rms_norm(x: &[f32], w: &[f32], eps: f32, offset: f32) -> Vec { + let n = x.len(); + if n == 0 { + return Vec::new(); + } + let rms = (x.iter().map(|v| (*v as f64) * (*v as f64)).sum::() / n as f64 + eps as f64) + .sqrt() as f32; + if w.is_empty() { + return x.iter().map(|v| v / rms).collect(); + } + x.iter() + .zip(w.iter()) + .map(|(v, ww)| (v / rms) * (ww + offset)) + .collect() +} + +fn naive_softmax(x: &mut [f32]) { + let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let mut sum = 0.0f64; + for v in x.iter_mut() { + *v = (*v - max).exp(); + sum += *v as f64; + } + if sum > 0.0 { + let inv = (1.0 / sum) as f32; + for v in x.iter_mut() { + *v *= inv; + } + } +} + +fn naive_top_k(logits: &[f32], k: usize) -> (Vec, Vec) { + let k = k.min(logits.len()); + let mut idx: Vec = (0..logits.len()).collect(); + idx.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap()); + idx.truncate(k); + let weights: Vec = idx.iter().map(|&i| logits[i]).collect(); + (idx, weights) +} + +fn naive_gelu_tanh(x: f32) -> f32 { + let c = 0.7978845608_f32; + 0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh()) +} + +fn naive_silu(x: f32) -> f32 { + x / (1.0 + (-x).exp()) +} + +// ── Vindex helpers ──────────────────────────────────────────────────────────── + +fn expert_bytes<'a>( + weights: &'a larql_models::ModelWeights, + layer: usize, + expert: usize, +) -> Result<(&'a [u8], &'a [u8]), Box> { + let gu_key = per_layer_ffn_key(layer, expert, PER_LAYER_FFN_GATE_UP); + let dn_key = per_layer_ffn_key(layer, expert, PER_LAYER_FFN_DOWN); + let gu = weights + .get_packed_bytes(&gu_key) + .ok_or_else(|| format!("missing per-layer entry: {gu_key}"))?; + let dn = weights + .get_packed_bytes(&dn_key) + .ok_or_else(|| format!("missing per-layer entry: {dn_key}"))?; + Ok((gu, dn)) +} + +fn pre_experts_norm_for<'a>(weights: &'a larql_models::ModelWeights, layer: usize) -> &'a [f32] { + weights + .arch + .moe_pre_experts_norm_key(layer) + .and_then(|k| weights.vectors.get(&k)) + .map(|v| v.as_slice()) + .unwrap_or(&[]) +} + +fn post_experts_norm_for<'a>(weights: &'a larql_models::ModelWeights, layer: usize) -> &'a [f32] { + weights + .arch + .moe_post_experts_norm_key(layer) + .and_then(|k| weights.vectors.get(&k)) + .map(|v| v.as_slice()) + .unwrap_or(&[]) +} + +fn router_proj_for( + weights: &larql_models::ModelWeights, + arch: &dyn larql_models::ModelArchitecture, + layer: usize, +) -> Result, Box> { + let key = arch + .moe_router_key(layer) + .ok_or("arch has no router_proj key for this layer")?; + weights + .vectors + .get(&key) + .cloned() + .ok_or_else(|| format!("router_proj not found in weights: {key}").into()) +} + +fn router_per_expert_scale_for( + weights: &larql_models::ModelWeights, + arch: &dyn larql_models::ModelArchitecture, + layer: usize, +) -> Vec { + arch.moe_router_per_expert_scale_key(layer) + .and_then(|k| weights.vectors.get(&k)) + .cloned() + .unwrap_or_default() +} + +fn router_norm_for( + weights: &larql_models::ModelWeights, + arch: &dyn larql_models::ModelArchitecture, + layer: usize, +) -> Vec { + arch.moe_router_norm_key(layer) + .and_then(|k| weights.vectors.get(&k)) + .cloned() + .unwrap_or_default() +} + +fn activation_for(arch: &dyn larql_models::ModelArchitecture) -> Activation { + match arch.activation() { + larql_models::Activation::GeluTanh => Activation::GeluTanh, + _ => Activation::Silu, + } +} + +fn make_residual(hidden: usize, seed: u32) -> Vec { + // Deterministic per-(hidden, seed) sin pattern. seed=0 reproduces the + // canonical pattern used by the bench / parity tests. + let phase = (seed as f32) * 0.001; + (0..hidden) + .map(|i| ((i as f32 + 1.0) * 0.0007 + phase).sin()) + .collect() +} + +// ── Diff reporter ───────────────────────────────────────────────────────────── + +fn diff_against_first(traces: &[(&str, Vec)], tolerance: f64) { + let (ref_name, ref_v) = &traces[0]; + println!( + "Reference backend: {ref_name} (first {} elems used as the truth)", + ref_v.len() + ); + let n = ref_v.len(); + print!(" {ref_name:<10} [0..3] = ["); + for (i, x) in ref_v.iter().take(3).enumerate() { + if i > 0 { + print!(", "); + } + print!("{:+.4e}", x); + } + println!("]"); + + for (name, v) in traces.iter().skip(1) { + if v.len() != n { + println!( + " {name:<10} LENGTH MISMATCH: ref.len={n}, {name}.len={}", + v.len() + ); + continue; + } + let mut max_abs = 0.0f64; + let mut max_idx = 0; + let mut max_a = 0.0f32; + let mut max_b = 0.0f32; + let mut nan = 0; + for (i, (a, b)) in ref_v.iter().zip(v.iter()).enumerate() { + if a.is_nan() || b.is_nan() { + nan += 1; + continue; + } + let d = ((a - b) as f64).abs(); + if d > max_abs { + max_abs = d; + max_idx = i; + max_a = *a; + max_b = *b; + } + } + let verdict = if max_abs < tolerance { + "✓ within tolerance" + } else if max_abs < tolerance * 100.0 { + "⚠ small drift" + } else { + "✗ DIVERGENCE" + }; + print!(" {name:<10} [0..3] = ["); + for (i, x) in v.iter().take(3).enumerate() { + if i > 0 { + print!(", "); + } + print!("{:+.4e}", x); + } + println!("]"); + println!( + " max |Δ|={:.3e} at idx {} (ref={:+.4e}, {name}={:+.4e}) {verdict}", + max_abs, max_idx, max_a, max_b + ); + if nan > 0 { + println!(" NaN count: {nan}"); + } + } +} + +fn dump3(label: &str, v: &[f32]) { + let n = v.len().min(3); + print!(" {label}: ["); + for (i, x) in v.iter().take(n).enumerate() { + if i > 0 { + print!(", "); + } + print!("{:+.6e}", x); + } + if v.len() > n { + print!(", …] ({} elems)", v.len()); + } else { + print!("]"); + } + println!(); +} diff --git a/crates/larql-cli/src/commands/extraction/attention_capture_cmd.rs b/crates/larql-cli/src/commands/extraction/attention_capture_cmd.rs index 6f00bf53..6af181b5 100644 --- a/crates/larql-cli/src/commands/extraction/attention_capture_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/attention_capture_cmd.rs @@ -82,12 +82,8 @@ pub fn run(args: AttentionCaptureArgs) -> Result<(), Box> eprintln!("\nRunning forward pass for prompt {}...", i + 1); let start = Instant::now(); let trace = trace_forward_full( - weights, - token_ids, - &layers, - false, // no activation capture - 0, - true, // capture attention + weights, token_ids, &layers, false, // no activation capture + 0, true, // capture attention &ffn, ); eprintln!(" {:.1}s", start.elapsed().as_secs_f64()); @@ -115,7 +111,8 @@ pub fn run(args: AttentionCaptureArgs) -> Result<(), Box> // Check if this head is active (above threshold) for any prompt let max_attn: f32 = (0..num_prompts) .filter_map(|pi| { - all_captures.get(pi) + all_captures + .get(pi) .and_then(|c| c.get(li)) .and_then(|h| h.get(head)) .map(|w| w.iter().copied().fold(0.0f32, f32::max)) @@ -130,7 +127,8 @@ pub fn run(args: AttentionCaptureArgs) -> Result<(), Box> if args.verbose || num_prompts <= 3 { println!("L{layer} H{head} (max={max_attn:.3}):"); for (pi, prompt) in args.prompts.iter().enumerate() { - if let Some(weights) = all_captures.get(pi) + if let Some(weights) = all_captures + .get(pi) .and_then(|c| c.get(li)) .and_then(|h| h.get(head)) { @@ -139,7 +137,8 @@ pub fn run(args: AttentionCaptureArgs) -> Result<(), Box> .enumerate() .filter(|(_, &w)| w > 0.01) .map(|(j, &w)| { - let label = all_token_labels.get(pi) + let label = all_token_labels + .get(pi) .and_then(|l| l.get(j)) .map(|s| s.as_str()) .unwrap_or("?"); @@ -171,16 +170,27 @@ pub fn run(args: AttentionCaptureArgs) -> Result<(), Box> for (li, &layer) in layers.iter().enumerate() { for head in 0..num_heads { // Get attention patterns for first two prompts - let w0 = match all_captures.first().and_then(|c| c.get(li)).and_then(|h| h.get(head)) { + let w0 = match all_captures + .first() + .and_then(|c| c.get(li)) + .and_then(|h| h.get(head)) + { Some(w) => w, None => continue, }; - let w1 = match all_captures.get(1).and_then(|c| c.get(li)).and_then(|h| h.get(head)) { + let w1 = match all_captures + .get(1) + .and_then(|c| c.get(li)) + .and_then(|h| h.get(head)) + { Some(w) => w, None => continue, }; - let max_attn = w0.iter().copied().fold(0.0f32, f32::max) + let max_attn = w0 + .iter() + .copied() + .fold(0.0f32, f32::max) .max(w1.iter().copied().fold(0.0f32, f32::max)); if max_attn < args.threshold { @@ -214,16 +224,27 @@ pub fn run(args: AttentionCaptureArgs) -> Result<(), Box> for (li, _) in layers.iter().enumerate() { for head in 0..num_heads { - let w0 = match all_captures.first().and_then(|c| c.get(li)).and_then(|h| h.get(head)) { + let w0 = match all_captures + .first() + .and_then(|c| c.get(li)) + .and_then(|h| h.get(head)) + { Some(w) => w, None => continue, }; - let w1 = match all_captures.get(1).and_then(|c| c.get(li)).and_then(|h| h.get(head)) { + let w1 = match all_captures + .get(1) + .and_then(|c| c.get(li)) + .and_then(|h| h.get(head)) + { Some(w) => w, None => continue, }; - let max_attn = w0.iter().copied().fold(0.0f32, f32::max) + let max_attn = w0 + .iter() + .copied() + .fold(0.0f32, f32::max) .max(w1.iter().copied().fold(0.0f32, f32::max)); if max_attn < args.threshold { continue; @@ -245,10 +266,22 @@ pub fn run(args: AttentionCaptureArgs) -> Result<(), Box> println!("\n═══ Summary ═══"); println!(" Active heads (above threshold): {total_active}"); - println!(" FIXED (corr > 0.95): {fixed} ({:.0}%)", fixed as f64 / total_active as f64 * 100.0); - println!(" SIMILAR (corr > 0.8): {similar} ({:.0}%)", similar as f64 / total_active as f64 * 100.0); - println!(" PARTIAL (corr > 0.5): {partial} ({:.0}%)", partial as f64 / total_active as f64 * 100.0); - println!(" DIFFERENT (corr < 0.5): {different} ({:.0}%)", different as f64 / total_active as f64 * 100.0); + println!( + " FIXED (corr > 0.95): {fixed} ({:.0}%)", + fixed as f64 / total_active as f64 * 100.0 + ); + println!( + " SIMILAR (corr > 0.8): {similar} ({:.0}%)", + similar as f64 / total_active as f64 * 100.0 + ); + println!( + " PARTIAL (corr > 0.5): {partial} ({:.0}%)", + partial as f64 / total_active as f64 * 100.0 + ); + println!( + " DIFFERENT (corr < 0.5): {different} ({:.0}%)", + different as f64 / total_active as f64 * 100.0 + ); if fixed + similar > total_active * 80 / 100 { println!("\n → Attention is largely TEMPLATE-FIXED. Circuit caching viable."); diff --git a/crates/larql-cli/src/commands/extraction/attn_bottleneck_cmd.rs b/crates/larql-cli/src/commands/extraction/attn_bottleneck_cmd.rs index 25b045ee..7ddce999 100644 --- a/crates/larql-cli/src/commands/extraction/attn_bottleneck_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/attn_bottleneck_cmd.rs @@ -1,9 +1,7 @@ use std::time::Instant; use clap::Args; -use larql_inference::{ - trace_forward, InferenceModel, -}; +use larql_inference::{trace_forward, InferenceModel}; #[derive(Args)] pub struct AttnBottleneckArgs { @@ -29,7 +27,9 @@ pub fn run(args: AttnBottleneckArgs) -> Result<(), Box> { let model = InferenceModel::load(&args.model)?; let weights = model.weights(); - let encoding = model.tokenizer().encode(args.prompt.as_str(), true) + let encoding = model + .tokenizer() + .encode(args.prompt.as_str(), true) .map_err(|e| format!("tokenize error: {e}"))?; let token_ids: Vec = encoding.get_ids().to_vec(); let seq_len = token_ids.len(); @@ -87,19 +87,25 @@ pub fn run(args: AttnBottleneckArgs) -> Result<(), Box> { // 1. Q projection: (seq, hidden) @ (hidden, q_dim) → (seq, q_dim) let _ = h_norm.dot(&w_q.t()); let start = Instant::now(); - for _ in 0..iters { let _ = h_norm.dot(&w_q.t()); } + for _ in 0..iters { + let _ = h_norm.dot(&w_q.t()); + } let q_proj_us = start.elapsed().as_micros() as f64 / iters as f64; // 2. K projection let _ = h_norm.dot(&w_k.t()); let start = Instant::now(); - for _ in 0..iters { let _ = h_norm.dot(&w_k.t()); } + for _ in 0..iters { + let _ = h_norm.dot(&w_k.t()); + } let k_proj_us = start.elapsed().as_micros() as f64 / iters as f64; // 3. V projection let _ = h_norm.dot(&w_v.t()); let start = Instant::now(); - for _ in 0..iters { let _ = h_norm.dot(&w_v.t()); } + for _ in 0..iters { + let _ = h_norm.dot(&w_v.t()); + } let v_proj_us = start.elapsed().as_micros() as f64 / iters as f64; // 4. RoPE (approximate — just measure the time to apply_rope) @@ -108,13 +114,16 @@ pub fn run(args: AttnBottleneckArgs) -> Result<(), Box> { let start = Instant::now(); for _ in 0..iters { let _ = larql_inference::attention::apply_rope(&q_full, num_q, head_dim, weights.rope_base); - let _ = larql_inference::attention::apply_rope(&k_full, num_kv, head_dim, weights.rope_base); + let _ = + larql_inference::attention::apply_rope(&k_full, num_kv, head_dim, weights.rope_base); } let rope_us = start.elapsed().as_micros() as f64 / iters as f64; // 5. QK^T attention scores + softmax + V multiply (the full GQA attention) - let q_rope = larql_inference::attention::apply_rope(&q_full, num_q, head_dim, weights.rope_base); - let k_rope = larql_inference::attention::apply_rope(&k_full, num_kv, head_dim, weights.rope_base); + let q_rope = + larql_inference::attention::apply_rope(&q_full, num_q, head_dim, weights.rope_base); + let k_rope = + larql_inference::attention::apply_rope(&k_full, num_kv, head_dim, weights.rope_base); let v_full = h_norm.dot(&w_v.t()); let reps = num_q / num_kv; let scale = (head_dim as f64).powf(-0.5) * arch.attention_multiplier() as f64; @@ -132,7 +141,9 @@ pub fn run(args: AttnBottleneckArgs) -> Result<(), Box> { &q_rope, &k_rope, &v_full, num_q, head_dim, reps, scale, seq_len, false, None, ); let start = Instant::now(); - for _ in 0..iters { let _ = attn_out.dot(&w_o.t()); } + for _ in 0..iters { + let _ = attn_out.dot(&w_o.t()); + } let o_proj_us = start.elapsed().as_micros() as f64 / iters as f64; // 7. Full attention (end-to-end via run_attention_public) @@ -142,39 +153,90 @@ pub fn run(args: AttnBottleneckArgs) -> Result<(), Box> { } let full_attn_us = start.elapsed().as_micros() as f64 / iters as f64; - let sum_parts = norm_us + q_proj_us + k_proj_us + v_proj_us + rope_us + attn_core_us + o_proj_us; + let sum_parts = + norm_us + q_proj_us + k_proj_us + v_proj_us + rope_us + attn_core_us + o_proj_us; println!(); - println!("Attention Layer {} Bottleneck (seq_len={}, hidden={}, {}q/{}kv, head_dim={})", - layer, seq_len, hidden, num_q, num_kv, head_dim); + println!( + "Attention Layer {} Bottleneck (seq_len={}, hidden={}, {}q/{}kv, head_dim={})", + layer, seq_len, hidden, num_q, num_kv, head_dim + ); println!("{}", "=".repeat(65)); - println!("{:>30} {:>10} {:>10}", "Component", "Time (us)", "% of Attn"); + println!( + "{:>30} {:>10} {:>10}", + "Component", "Time (us)", "% of Attn" + ); println!("{}", "-".repeat(65)); - println!("{:>30} {:>8.0}us {:>9.1}%", "input layernorm", norm_us, norm_us / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}%", - format!("Q proj ({}→{})", hidden, q_dim), q_proj_us, q_proj_us / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}%", - format!("K proj ({}→{})", hidden, kv_dim), k_proj_us, k_proj_us / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}%", - format!("V proj ({}→{})", hidden, kv_dim), v_proj_us, v_proj_us / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}%", "RoPE (Q+K)", rope_us, rope_us / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}%", - format!("QK^T + softmax + V ({}h)", num_q), attn_core_us, attn_core_us / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}%", - format!("O proj ({}→{})", q_dim, hidden), o_proj_us, o_proj_us / sum_parts * 100.0); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + "input layernorm", + norm_us, + norm_us / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + format!("Q proj ({}→{})", hidden, q_dim), + q_proj_us, + q_proj_us / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + format!("K proj ({}→{})", hidden, kv_dim), + k_proj_us, + k_proj_us / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + format!("V proj ({}→{})", hidden, kv_dim), + v_proj_us, + v_proj_us / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + "RoPE (Q+K)", + rope_us, + rope_us / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + format!("QK^T + softmax + V ({}h)", num_q), + attn_core_us, + attn_core_us / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + format!("O proj ({}→{})", q_dim, hidden), + o_proj_us, + o_proj_us / sum_parts * 100.0 + ); println!("{}", "-".repeat(65)); - println!("{:>30} {:>8.0}us {:>9.1}%", "Sum of parts", sum_parts, 100.0); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + "Sum of parts", sum_parts, 100.0 + ); println!("{:>30} {:>8.0}us", "Actual full attention", full_attn_us); println!(); let proj_total = q_proj_us + k_proj_us + v_proj_us + o_proj_us; - println!("{:>30} {:>8.0}us {:>9.1}% (4 linear projections)", - "Total projections", proj_total, proj_total / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}% (RoPE + QK^T + softmax + V)", - "Total attention math", rope_us + attn_core_us, (rope_us + attn_core_us) / sum_parts * 100.0); - println!("{:>30} {:>8.0}us {:>9.1}% (input layernorm)", - "Total norms", norm_us, norm_us / sum_parts * 100.0); + println!( + "{:>30} {:>8.0}us {:>9.1}% (4 linear projections)", + "Total projections", + proj_total, + proj_total / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}% (RoPE + QK^T + softmax + V)", + "Total attention math", + rope_us + attn_core_us, + (rope_us + attn_core_us) / sum_parts * 100.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}% (input layernorm)", + "Total norms", + norm_us, + norm_us / sum_parts * 100.0 + ); Ok(()) } diff --git a/crates/larql-cli/src/commands/extraction/bottleneck_test_cmd.rs b/crates/larql-cli/src/commands/extraction/bottleneck_test_cmd.rs index cf9081db..ddd6acad 100644 --- a/crates/larql-cli/src/commands/extraction/bottleneck_test_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/bottleneck_test_cmd.rs @@ -39,8 +39,8 @@ fn rule_score(prompt: &str) -> f32 { let p = prompt.to_lowercase(); // Non-ASCII fraction (multilingual detection) - let ascii_frac = prompt.chars().filter(|c| c.is_ascii()).count() as f32 - / prompt.len().max(1) as f32; + let ascii_frac = + prompt.chars().filter(|c| c.is_ascii()).count() as f32 / prompt.len().max(1) as f32; if ascii_frac < 0.7 { return 6000.0; } @@ -113,7 +113,8 @@ pub fn run(args: BottleneckTestArgs) -> Result<(), Box> { let num_layers = weights.num_layers; eprintln!( " {} layers, hidden_size={} ({:.1}s)", - num_layers, hidden, + num_layers, + hidden, start.elapsed().as_secs_f64() ); @@ -141,7 +142,9 @@ pub fn run(args: BottleneckTestArgs) -> Result<(), Box> { eprintln!( "\n── End-to-end: 9 rules → L{} state → L{}-L{} dense ──\n", - bn.layer, inject_layer, num_layers - 1 + bn.layer, + inject_layer, + num_layers - 1 ); println!( @@ -193,8 +196,13 @@ pub fn run(args: BottleneckTestArgs) -> Result<(), Box> { } // Run L14-33 - let rule_result = - predict_from_hidden(weights, model.tokenizer(), &h_hybrid, inject_layer, args.top_k); + let rule_result = predict_from_hidden( + weights, + model.tokenizer(), + &h_hybrid, + inject_layer, + args.top_k, + ); let (rule_tok, rule_conf) = rule_result .predictions .first() diff --git a/crates/larql-cli/src/commands/extraction/build_cmd.rs b/crates/larql-cli/src/commands/extraction/build_cmd.rs index 200d9c52..5a1729d6 100644 --- a/crates/larql-cli/src/commands/extraction/build_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/build_cmd.rs @@ -33,21 +33,33 @@ pub fn run(args: BuildArgs) -> Result<(), Box> { // Summary let stage_str = args.stage.as_deref().unwrap_or("(default)"); - let num_patches = vf.directives.iter().filter(|d| matches!(d, larql_vindex::VindexfileDirective::Patch(_))).count(); - let num_inserts = vf.directives.iter().filter(|d| matches!(d, larql_vindex::VindexfileDirective::Insert { .. })).count(); - let num_deletes = vf.directives.iter().filter(|d| matches!(d, larql_vindex::VindexfileDirective::Delete { .. })).count(); + let num_patches = vf + .directives + .iter() + .filter(|d| matches!(d, larql_vindex::VindexfileDirective::Patch(_))) + .count(); + let num_inserts = vf + .directives + .iter() + .filter(|d| matches!(d, larql_vindex::VindexfileDirective::Insert { .. })) + .count(); + let num_deletes = vf + .directives + .iter() + .filter(|d| matches!(d, larql_vindex::VindexfileDirective::Delete { .. })) + .count(); eprintln!( " Stage: {}, {} patches, {} inserts, {} deletes, {} stages defined", - stage_str, num_patches, num_inserts, num_deletes, vf.stages.len(), + stage_str, + num_patches, + num_inserts, + num_deletes, + vf.stages.len(), ); // Build eprintln!("\nBuilding..."); - let result = larql_vindex::build_from_vindexfile( - &vf, - args.stage.as_deref(), - &args.dir, - )?; + let result = larql_vindex::build_from_vindexfile(&vf, args.stage.as_deref(), &args.dir)?; // Print build history eprintln!("\nBuild history:"); @@ -61,7 +73,9 @@ pub fn run(args: BuildArgs) -> Result<(), Box> { } // Save to output directory - let output_dir = args.output.unwrap_or_else(|| args.dir.join("build").join("vindex")); + let output_dir = args + .output + .unwrap_or_else(|| args.dir.join("build").join("vindex")); std::fs::create_dir_all(&output_dir)?; eprintln!("\nSaving to {}...", output_dir.display()); @@ -78,14 +92,14 @@ pub fn run(args: BuildArgs) -> Result<(), Box> { // Total overrides let total_modified: usize = result.layers.iter().map(|l| l.features_modified).sum(); - eprintln!( - " Total: {} features modified from base", - total_modified - ); + eprintln!(" Total: {} features modified from base", total_modified); if let Some(format) = args.compile { eprintln!("\nCompiling to {} format...", format); - eprintln!(" (compile not yet implemented — built vindex saved at {})", output_dir.display()); + eprintln!( + " (compile not yet implemented — built vindex saved at {})", + output_dir.display() + ); } eprintln!("\nDone. Usage:"); diff --git a/crates/larql-cli/src/commands/extraction/circuit_discover_cmd.rs b/crates/larql-cli/src/commands/extraction/circuit_discover_cmd.rs index 65ebb86c..8136f6b6 100644 --- a/crates/larql-cli/src/commands/extraction/circuit_discover_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/circuit_discover_cmd.rs @@ -6,8 +6,8 @@ use std::time::Instant; use clap::Args; use larql_inference::ndarray; use larql_inference::tokenizers; -use larql_vindex::load_feature_labels; use larql_inference::InferenceModel; +use larql_vindex::load_feature_labels; #[derive(Args)] pub struct CircuitDiscoverArgs { @@ -53,7 +53,7 @@ struct OvGateEdge { /// A template circuit: a set of attention heads that route to the same FFN features. struct Circuit { id: usize, - heads: Vec<(usize, usize)>, // (layer, head) + heads: Vec<(usize, usize)>, // (layer, head) features: Vec<(usize, usize, f32)>, // (layer, feature, total_coupling) top_tokens: Vec, } @@ -72,7 +72,8 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> eprintln!( " {} layers, {} heads ({:.1}s)", - num_layers, num_q_heads, + num_layers, + num_q_heads, start.elapsed().as_secs_f64() ); @@ -156,7 +157,12 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> eprint!("L{layer}... "); let _ = io::stderr().flush(); if (layer + 1) % 10 == 0 { - eprintln!("({}/{} layers, {:.0}s)", layer + 1, num_layers, start.elapsed().as_secs_f64()); + eprintln!( + "({}/{} layers, {:.0}s)", + layer + 1, + num_layers, + start.elapsed().as_secs_f64() + ); eprint!(" "); let _ = io::stderr().flush(); } @@ -180,20 +186,27 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> edge.gate_top_token = label.clone(); } } - eprintln!(" {} labels loaded ({:.1}s)", label_map.len(), label_start.elapsed().as_secs_f64()); + eprintln!( + " {} labels loaded ({:.1}s)", + label_map.len(), + label_start.elapsed().as_secs_f64() + ); } else { // Slow path: project each feature against vocab eprintln!(" Labeling features (slow — use --labels for instant labels)..."); let mut unique_features: HashMap<(usize, usize), String> = HashMap::new(); for edge in &all_edges { - unique_features.entry((edge.layer, edge.feature)).or_default(); + unique_features + .entry((edge.layer, edge.feature)) + .or_default(); } let total = unique_features.len(); for (i, (&(layer, feat), label)) in unique_features.iter_mut().enumerate() { let gate_key = arch.ffn_gate_key(layer); if let Some(w_gate) = weights.tensors.get(&gate_key) { let gate_row = w_gate.row(feat); - *label = project_top_token(&weights.embed, &gate_row.to_vec(), model.tokenizer()); + *label = + project_top_token(&weights.embed, &gate_row.to_vec(), model.tokenizer()); } if (i + 1) % 500 == 0 { eprint!("\r {}/{} features...", i + 1, total); @@ -205,7 +218,11 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> edge.gate_top_token = label.clone(); } } - eprintln!("\r {} features labeled ({:.1}s)", total, label_start.elapsed().as_secs_f64()); + eprintln!( + "\r {} features labeled ({:.1}s)", + total, + label_start.elapsed().as_secs_f64() + ); } } @@ -320,7 +337,8 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> while let Some(current) = queue.pop() { if let Some(neighbors) = adjacency.get(¤t) { for &(neighbor, _sim) in neighbors { - if let std::collections::hash_map::Entry::Vacant(e) = cluster_id.entry(neighbor) { + if let std::collections::hash_map::Entry::Vacant(e) = cluster_id.entry(neighbor) + { e.insert(cid); queue.push(neighbor); } @@ -329,7 +347,10 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> } } - eprintln!(" Clustered in {:.1}s", cluster_start.elapsed().as_secs_f64()); + eprintln!( + " Clustered in {:.1}s", + cluster_start.elapsed().as_secs_f64() + ); // Build circuits from clusters let mut cluster_heads: HashMap> = HashMap::new(); @@ -368,7 +389,8 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> .iter() .take(10) .filter_map(|&(layer, feat, _)| { - all_edges.iter() + all_edges + .iter() .find(|e| e.layer == layer && e.feature == feat && !e.gate_top_token.is_empty()) .map(|e| e.gate_top_token.clone()) }) @@ -433,16 +455,19 @@ pub fn run(args: CircuitDiscoverArgs) -> Result<(), Box> println!(" Total edges: {}", all_edges.len()); println!(" Total heads: {}", head_keys.len()); println!(" Total circuits: {}", circuits.len()); - println!( - " Large circuits (3+ heads): {}", - large_circuits.len() - ); + println!(" Large circuits (3+ heads): {}", large_circuits.len()); if let Some(biggest) = large_circuits.first() { println!( " Largest circuit: {} heads, tokens: {}", biggest.heads.len(), - biggest.top_tokens.iter().take(5).cloned().collect::>().join(", ") + biggest + .top_tokens + .iter() + .take(5) + .cloned() + .collect::>() + .join(", ") ); } diff --git a/crates/larql-cli/src/commands/extraction/compile_cmd/chat.rs b/crates/larql-cli/src/commands/extraction/compile_cmd/chat.rs index 08a58076..63c16cc9 100644 --- a/crates/larql-cli/src/commands/extraction/compile_cmd/chat.rs +++ b/crates/larql-cli/src/commands/extraction/compile_cmd/chat.rs @@ -12,6 +12,7 @@ use std::path::Path; +use larql_vindex::format::filenames::TOKENIZER_CONFIG_JSON; use minijinja::{context, Environment, Value}; use serde_json::Value as JsonValue; @@ -22,7 +23,7 @@ pub fn render_user_prompt( base_dir: &Path, user_prompt: &str, ) -> Result> { - let cfg_path = base_dir.join("tokenizer_config.json"); + let cfg_path = base_dir.join(TOKENIZER_CONFIG_JSON); if !cfg_path.exists() { return Err(format!( "tokenizer_config.json not found in {} — cannot apply chat template", @@ -47,9 +48,15 @@ pub fn render_user_prompt( let mut env = Environment::new(); // `raise_exception` is a convention some HF templates use for error paths. - env.add_function("raise_exception", |msg: String| -> Result { - Err(minijinja::Error::new(minijinja::ErrorKind::InvalidOperation, msg)) - }); + env.add_function( + "raise_exception", + |msg: String| -> Result { + Err(minijinja::Error::new( + minijinja::ErrorKind::InvalidOperation, + msg, + )) + }, + ); env.add_template("chat", &template)?; let tmpl = env.get_template("chat")?; diff --git a/crates/larql-cli/src/commands/extraction/compile_cmd/detect.rs b/crates/larql-cli/src/commands/extraction/compile_cmd/detect.rs index 68c79e56..16140c61 100644 --- a/crates/larql-cli/src/commands/extraction/compile_cmd/detect.rs +++ b/crates/larql-cli/src/commands/extraction/compile_cmd/detect.rs @@ -4,10 +4,7 @@ use std::collections::HashMap; use ndarray::ArcArray2; -pub fn detect_ffn_pattern( - tensors: &HashMap>, - component: &str, -) -> String { +pub fn detect_ffn_pattern(tensors: &HashMap>, component: &str) -> String { let patterns: &[&str] = match component { "gate" => &[ "model.layers.{}.mlp.gate_proj.weight", diff --git a/crates/larql-cli/src/commands/extraction/compile_cmd/edge.rs b/crates/larql-cli/src/commands/extraction/compile_cmd/edge.rs index 3542f6ee..7f12bc76 100644 --- a/crates/larql-cli/src/commands/extraction/compile_cmd/edge.rs +++ b/crates/larql-cli/src/commands/extraction/compile_cmd/edge.rs @@ -115,7 +115,12 @@ pub fn install_edge( } } - Ok(EdgeStats { g_norm, u_norm, d_norm, alpha }) + Ok(EdgeStats { + g_norm, + u_norm, + d_norm, + alpha, + }) } fn vec_norm(v: &[f32]) -> f32 { @@ -159,7 +164,8 @@ mod tests { let trigger = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; let write = vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; - let stats = install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); + let stats = + install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); let gate = t.get("gate").unwrap(); let expected = stats.g_norm * 30.0; @@ -171,8 +177,8 @@ mod tests { let mut t = fresh_layer(4, 8); let trigger = vec![0.0; 8]; let write = vec![1.0; 8]; - let err = install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0) - .unwrap_err(); + let err = + install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap_err(); assert!(matches!(err, EdgeError::ZeroTrigger)); } @@ -181,8 +187,18 @@ mod tests { let mut t = fresh_layer(4, 8); let trigger = vec![1.0; 8]; let write = vec![1.0; 8]; - let err = install_edge(&mut t, "missing_gate", "up", "down", 0, &trigger, &write, 30.0, 1.0) - .unwrap_err(); + let err = install_edge( + &mut t, + "missing_gate", + "up", + "down", + 0, + &trigger, + &write, + 30.0, + 1.0, + ) + .unwrap_err(); assert!(matches!(err, EdgeError::MissingTensor(k) if k == "missing_gate")); } @@ -192,7 +208,8 @@ mod tests { for &scale in &[0.1_f32, 1.0, 100.0] { let trigger: Vec = (0..8).map(|i| (i as f32 + 1.0) * scale).collect(); let write = vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; - let stats = install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); + let stats = + install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); let gate = t.get("gate").unwrap(); let gate_row_norm = (0..8).map(|j| gate[[0, j]].powi(2)).sum::().sqrt(); let expected = stats.g_norm * 30.0; @@ -206,7 +223,8 @@ mod tests { let mut t = fresh_layer(4, 8); let trigger = vec![1.0; 8]; let write = vec![0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; - let stats = install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); + let stats = + install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); let down = t.get("down").unwrap(); for j in 0..8 { let expected = write[j] * stats.alpha; @@ -229,9 +247,13 @@ mod tests { let mut t = fresh_layer(4, 8); let trigger = vec![1.0; 8]; let write = vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; - let s1 = install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); + let s1 = + install_edge(&mut t, "gate", "up", "down", 0, &trigger, &write, 30.0, 1.0).unwrap(); let mut t2 = fresh_layer(4, 8); - let s2 = install_edge(&mut t2, "gate", "up", "down", 0, &trigger, &write, 30.0, 5.0).unwrap(); + let s2 = install_edge( + &mut t2, "gate", "up", "down", 0, &trigger, &write, 30.0, 5.0, + ) + .unwrap(); assert!((s2.alpha / s1.alpha - 5.0).abs() < 1e-5); } } diff --git a/crates/larql-cli/src/commands/extraction/compile_cmd/patch.rs b/crates/larql-cli/src/commands/extraction/compile_cmd/patch.rs index 0989113c..6fdb6cf8 100644 --- a/crates/larql-cli/src/commands/extraction/compile_cmd/patch.rs +++ b/crates/larql-cli/src/commands/extraction/compile_cmd/patch.rs @@ -49,11 +49,7 @@ pub fn run(args: CompileArgs) -> Result<(), Box> { let mut all_ops = Vec::new(); for pf in &patch_files { let patch = larql_vindex::VindexPatch::load(pf)?; - eprintln!( - " patch: {} ({} ops)", - pf.display(), - patch.operations.len() - ); + eprintln!(" patch: {} ({} ops)", pf.display(), patch.operations.len()); all_ops.extend(patch.operations); } @@ -82,7 +78,10 @@ pub fn run(args: CompileArgs) -> Result<(), Box> { }; let Some(b64) = gate_vector_b64 else { - eprintln!(" skip: insert at L{}[{}] has no gate vector", layer, feature); + eprintln!( + " skip: insert at L{}[{}] has no gate vector", + layer, feature + ); continue; }; let gate_vec = decode_f32_b64(b64)?; diff --git a/crates/larql-cli/src/commands/extraction/compile_cmd/save.rs b/crates/larql-cli/src/commands/extraction/compile_cmd/save.rs index 68fb17a6..7ddea053 100644 --- a/crates/larql-cli/src/commands/extraction/compile_cmd/save.rs +++ b/crates/larql-cli/src/commands/extraction/compile_cmd/save.rs @@ -4,6 +4,7 @@ //! a text-only language model. Tied lm_head is dropped when `embed_tokens` is //! present, matching HuggingFace's tied-embedding convention. +use larql_vindex::format::filenames::*; use std::collections::HashMap; use std::path::Path; @@ -48,9 +49,7 @@ pub fn merge_for_save( vectors.insert(k.clone(), v.clone()); } - if tensors.contains_key("model.embed_tokens.weight") - && tensors.contains_key("lm_head.weight") - { + if tensors.contains_key("model.embed_tokens.weight") && tensors.contains_key("lm_head.weight") { tensors.remove("lm_head.weight"); } @@ -120,11 +119,11 @@ pub fn write_safetensors( /// a text-only Gemma 3 checkpoint (multimodal tensors were skipped above). pub fn copy_model_config(base: &Path, output: &Path) { for name in &[ - "tokenizer.json", - "tokenizer_config.json", + TOKENIZER_JSON, + TOKENIZER_CONFIG_JSON, "special_tokens_map.json", - "generation_config.json", - "tokenizer.model", // SentencePiece model — required by llama.cpp's GGUF converter + GENERATION_CONFIG_JSON, + "tokenizer.model", // SentencePiece model — required by llama.cpp's GGUF converter ] { let src = base.join(name); if src.exists() { diff --git a/crates/larql-cli/src/commands/extraction/compile_cmd/single.rs b/crates/larql-cli/src/commands/extraction/compile_cmd/single.rs index f4e365ee..7c4e4bae 100644 --- a/crates/larql-cli/src/commands/extraction/compile_cmd/single.rs +++ b/crates/larql-cli/src/commands/extraction/compile_cmd/single.rs @@ -5,12 +5,13 @@ //! and pushes the answer token through the LM head. CLI-driven; contrasts //! with patch mode (vindex-driven, many edges). +use larql_vindex::format::filenames::*; use std::collections::HashMap; use ndarray::ArcArray2; -use super::edge::install_edge; use super::detect::detect_ffn_pattern; +use super::edge::install_edge; use super::save::{copy_model_config, merge_for_save, write_safetensors}; use super::CompileArgs; @@ -31,13 +32,9 @@ pub fn run(args: CompileArgs) -> Result<(), Box> { let config = weights.arch.config(); eprintln!(" {} layers, dim={}", config.num_layers, config.hidden_size); - let tokenizer_path = args.base.join("tokenizer.json"); + let tokenizer_path = args.base.join(TOKENIZER_JSON); if !tokenizer_path.exists() { - return Err(format!( - "tokenizer.json not found in {}", - args.base.display() - ) - .into()); + return Err(format!("tokenizer.json not found in {}", args.base.display()).into()); } let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path) .map_err(|e| format!("tokenizer: {}", e))?; @@ -60,11 +57,8 @@ pub fn run(args: CompileArgs) -> Result<(), Box> { eprintln!(" prompt tokens: {}", token_ids.len()); eprintln!("\nCapturing L{} residual...", args.layer); - let residuals = larql_inference::forward::capture_residuals( - &weights, - &token_ids, - &[args.layer], - ); + let residuals = + larql_inference::forward::capture_residuals(&weights, &token_ids, &[args.layer]); let (_, residual) = residuals .into_iter() .find(|(l, _)| *l == args.layer) @@ -121,10 +115,7 @@ pub fn run(args: CompileArgs) -> Result<(), Box> { args.gate_scale, args.alpha, )?; - eprintln!( - " gate_scale={}, alpha={:.3}", - args.gate_scale, stats.alpha - ); + eprintln!(" gate_scale={}, alpha={:.3}", args.gate_scale, stats.alpha); eprintln!(" installed at L{} slot {}", args.layer, args.slot); // ── Balancer: scale the down vector up/down until the target token's @@ -142,9 +133,7 @@ pub fn run(args: CompileArgs) -> Result<(), Box> { for key in [&gate_key, &up_key, &down_key] { weights.tensors.insert(key.clone(), modified[key].clone()); } - let pred = larql_inference::forward::predict( - &weights, &tokenizer, &token_ids, 20, - ); + let pred = larql_inference::forward::predict(&weights, &tokenizer, &token_ids, 20); let prob: f64 = pred .predictions .iter() diff --git a/crates/larql-cli/src/commands/extraction/convert_cmd.rs b/crates/larql-cli/src/commands/extraction/convert_cmd.rs index a088c190..c06eacac 100644 --- a/crates/larql-cli/src/commands/extraction/convert_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/convert_cmd.rs @@ -1,3 +1,4 @@ +use larql_vindex::format::filenames::*; use std::path::PathBuf; use clap::{Args, Subcommand}; @@ -51,20 +52,373 @@ enum ConvertCommand { /// Path to the .gguf file. input: PathBuf, }, + + /// Quantize an existing vindex into a different storage format. + /// Each sub-format has its own flag surface — see + /// `docs/specs/quantize-cli-spec.md` for the shape and how new + /// formats slot in. FP4 is the only format wired as of exp 26; + /// Q4K and future formats land as additional subcommands. + #[command(subcommand)] + Quantize(QuantizeCommand), + + /// Retrofit `down_features_q4k.bin` (W2 feature-major down) into + /// an existing Q4K vindex without re-quantising. Reads the down + /// portion of `interleaved_q4k.bin` per layer, transposes to + /// `[intermediate, hidden]`, re-quantises at the same precision + /// the source used, and writes the W2 file + manifest in place. + /// Idempotent — silent no-op when the file is already present. + /// See ADR-009 for the architectural rationale. + AddFeatureMajorDown { + /// Vindex directory to retrofit. Must already have + /// `interleaved_q4k.bin` + manifest (i.e. `quant: q4k` in + /// `index.json`). + #[arg(long)] + input: PathBuf, + + /// Suppress the per-layer progress line printed during write. + #[arg(long)] + quiet: bool, + }, +} + +#[derive(Subcommand)] +enum QuantizeCommand { + /// Convert an f32/f16 vindex into a Q4_K/Q6_K vindex (the Ollama- + /// compatible "Q4_K_M" mix: attention Q/K/O + FFN gate/up at + /// Q4_K, attention V + FFN down at Q6_K). `--down-q4k` switches + /// FFN down to Q4_K uniformly — saves ~30 MB/layer on 31B at + /// modest precision cost. + /// + /// Source must be extracted with `--level inference` or `--level all` + /// (needs the full f32/f16 weights to quantise). + Q4K { + /// Existing vindex directory (the source). + #[arg(long)] + input: PathBuf, + + /// Output vindex directory. Written atomically (to `.tmp/` + /// then renamed on success). + #[arg(long)] + output: PathBuf, + + /// Quantise FFN down-proj as Q4_K instead of Q6_K. Default off + /// preserves the Ollama Q4_K_M mix (Q4_K gate/up + Q6_K down). + #[arg(long)] + down_q4k: bool, + + /// Emit `down_features_q4k.bin` (W2 feature-major down) so per-feature + /// row decode can skip the `q4k_ffn_layer` cache. Adds ~14 MB / layer + /// at Gemma 4B dims; eliminates the ~840 MB heap cache ceiling. + /// Recommended for CPU sparse walk and grid/MoE workloads. + #[arg(long)] + feature_major_down: bool, + + /// Overwrite the output directory if it already exists. + #[arg(long)] + force: bool, + + /// Suppress the backend-describe summary printed after write. + #[arg(long)] + quiet: bool, + }, + + /// Convert an f32/f16 vindex into an FP4/FP8 vindex per the + /// chosen policy. Exp 26. Policy spec: `docs/specs/fp4-precision-policy.md`. + Fp4 { + /// Existing vindex directory (the source). + #[arg(long)] + input: PathBuf, + + /// Output vindex directory. Written atomically (to `.tmp/` + /// then renamed on success). + #[arg(long)] + output: PathBuf, + + /// Precision policy for up / down (gate stays at source dtype + /// in all three policies — FP4 gate is blocked on an FP4-aware + /// gate KNN path, see policy spec §2). + #[arg(long, default_value = "option-b", value_parser = ["option-a", "option-b", "option-c"])] + policy: String, + + /// Min compliance fraction for an FP4-targeted projection at + /// the given threshold. Projections below this are downgraded + /// to the manifest's fallback precision (FP8). Doesn't apply + /// to FP8 / F16 projections — those don't use the + /// distributional assumption. + #[arg(long, default_value_t = 0.99)] + compliance_floor: f32, + + /// max(sub-block scale)/min(sub-block scale) threshold for + /// the FP4 compliance gate. 16.0 is the E4M3/E2M1 exponent + /// budget (the format's derived default); lower = stricter, + /// higher = more permissive. + #[arg(long, default_value_t = 16.0)] + threshold: f32, + + /// Overwrite the output directory if it already exists. + #[arg(long)] + force: bool, + + /// Fail (non-zero exit) if any FP4-targeted projection misses + /// the compliance floor, instead of downgrading it. + #[arg(long)] + strict: bool, + + /// Skip emitting `fp4_compliance.json` in the output directory. + #[arg(long)] + no_sidecar: bool, + + /// Suppress the backend-describe summary printed after write. + #[arg(long)] + quiet: bool, + }, } pub fn run(args: ConvertArgs) -> Result<(), Box> { match args.command { - ConvertCommand::GgufToVindex { input, output, level, f16 } => { - run_gguf_to_vindex(&input, &output, &level, f16) + ConvertCommand::GgufToVindex { + input, + output, + level, + f16, + } => run_gguf_to_vindex(&input, &output, &level, f16), + ConvertCommand::SafetensorsToVindex { + input, + output, + level, + f16, + } => run_safetensors_to_vindex(&input, &output, &level, f16), + ConvertCommand::GgufInfo { input } => run_gguf_info(&input), + ConvertCommand::Quantize(cmd) => run_quantize(cmd), + ConvertCommand::AddFeatureMajorDown { input, quiet } => { + run_add_feature_major_down(&input, quiet) } - ConvertCommand::SafetensorsToVindex { input, output, level, f16 } => { - run_safetensors_to_vindex(&input, &output, &level, f16) + } +} + +fn run_add_feature_major_down( + input: &std::path::Path, + quiet: bool, +) -> Result<(), Box> { + use larql_vindex::quant::add_feature_major_down; + + if !quiet { + eprintln!("Retrofitting feature-major down → {}", input.display()); + } + let report = add_feature_major_down(input)?; + if report.skipped { + if !quiet { + eprintln!( + " down_features_q4k.bin already present — no-op (skipped {} layers)", + report.num_layers, + ); + } + return Ok(()); + } + if !quiet { + let mb = report.bytes_written as f64 / (1024.0 * 1024.0); + eprintln!( + " wrote down_features_q4k.bin: {} layers, {:.1} MB, {:.2?}", + report.num_layers, mb, report.wall_time, + ); + eprintln!( + " per-feature down decode now skips q4k_ffn_layer cache \ + (verify via GET /v1/stats → q4k_ffn.feature_major_down: true)" + ); + } + Ok(()) +} + +fn run_quantize(cmd: QuantizeCommand) -> Result<(), Box> { + match cmd { + QuantizeCommand::Fp4 { + input, + output, + policy, + compliance_floor, + threshold, + force, + strict, + no_sidecar, + quiet, + } => run_quantize_fp4(QuantizeFp4Opts { + input, + output, + policy, + compliance_floor, + threshold, + force, + strict, + no_sidecar, + quiet, + }), + QuantizeCommand::Q4K { + input, + output, + down_q4k, + feature_major_down, + force, + quiet, + } => run_quantize_q4k(QuantizeQ4kOpts { + input, + output, + down_q4k, + feature_major_down, + force, + quiet, + }), + } +} + +struct QuantizeQ4kOpts { + input: PathBuf, + output: PathBuf, + down_q4k: bool, + feature_major_down: bool, + force: bool, + quiet: bool, +} + +fn run_quantize_q4k(opts: QuantizeQ4kOpts) -> Result<(), Box> { + use larql_vindex::quant::{vindex_to_q4k, Q4kConvertConfig}; + + let config = Q4kConvertConfig { + down_q4k: opts.down_q4k, + feature_major_down: opts.feature_major_down, + force: opts.force, + }; + + if !opts.quiet { + eprintln!("== quantize q4k =="); + eprintln!(" in : {}", opts.input.display()); + eprintln!(" out : {}", opts.output.display()); + eprintln!( + " down_q4k : {} ({})", + opts.down_q4k, + if opts.down_q4k { + "Q4_K down (uniform)" + } else { + "Q6_K down (Q4_K_M mix)" + } + ); + eprintln!(); + } + + let report = vindex_to_q4k(&opts.input, &opts.output, &config)?; + + if !opts.quiet { + eprintln!("── summary ──"); + eprintln!( + " FFN storage : {:.2} GB → {:.2} GB ({:.2}× compression)", + report.src_ffn_bytes as f64 / 1_073_741_824.0, + report.dst_ffn_bytes as f64 / 1_073_741_824.0, + report.compression, + ); + eprintln!( + " Linked aux : {} files ({:.2} GB)", + report.aux_linked_count, + report.aux_linked_bytes as f64 / 1_073_741_824.0 + ); + eprintln!(" Wall time : {:.1}s", report.wall_time.as_secs_f64()); + eprintln!(" Walk backend: {}", report.walk_backend); + eprintln!(); + eprintln!("→ {}", opts.output.display()); + } + + Ok(()) +} + +struct QuantizeFp4Opts { + input: PathBuf, + output: PathBuf, + policy: String, + compliance_floor: f32, + threshold: f32, + force: bool, + strict: bool, + no_sidecar: bool, + quiet: bool, +} + +fn run_quantize_fp4(opts: QuantizeFp4Opts) -> Result<(), Box> { + use larql_vindex::quant::{vindex_to_fp4, Fp4ConvertConfig, Policy, ProjectionOutcome}; + + let policy = Policy::parse(&opts.policy)?; + let config = Fp4ConvertConfig { + policy, + compliance_floor: opts.compliance_floor, + threshold: opts.threshold, + strict: opts.strict, + force: opts.force, + emit_sidecar: !opts.no_sidecar, + }; + + if !opts.quiet { + eprintln!("== quantize fp4 =="); + eprintln!(" in : {}", opts.input.display()); + eprintln!(" out : {}", opts.output.display()); + eprintln!(" policy : {}", policy.label()); + eprintln!( + " floor : {:.1}% @ R<{}", + opts.compliance_floor * 100.0, + opts.threshold + ); + eprintln!(); + } + + let (report, _scan) = vindex_to_fp4(&opts.input, &opts.output, &config)?; + + if !opts.quiet { + eprintln!("── per-projection ──"); + for p in &report.per_projection { + let compliance = p + .compliance_at_threshold + .map(|c| format!("{:.4}%", c * 100.0)) + .unwrap_or_else(|| "N/A".into()); + let downgrade_flag = matches!( + p.outcome, + ProjectionOutcome::DowngradedFp4ToFp8 | ProjectionOutcome::DowngradedFp4ToF16, + ); + let marker = if downgrade_flag { "⚠" } else { " " }; + eprintln!( + " {marker} {:<5} compliance={:<12} → {:?} ({})", + p.name, + compliance, + p.chosen_precision, + p.outcome.action_str(), + ); } - ConvertCommand::GgufInfo { input } => { - run_gguf_info(&input) + eprintln!(); + eprintln!("── summary ──"); + eprintln!( + " FFN storage : {:.2} GB → {:.2} GB ({:.2}× compression)", + report.src_ffn_bytes as f64 / 1_073_741_824.0, + report.dst_ffn_bytes as f64 / 1_073_741_824.0, + report.compression, + ); + eprintln!( + " Linked aux : {} files ({:.2} GB)", + report.aux_linked_count, + report.aux_linked_bytes as f64 / 1_073_741_824.0 + ); + eprintln!(" Wall time : {:.1}s", report.wall_time.as_secs_f64()); + eprintln!(" Walk backend: {}", report.walk_backend); + eprintln!(); + if report.per_projection.iter().any(|p| { + matches!( + p.outcome, + ProjectionOutcome::DowngradedFp4ToFp8 | ProjectionOutcome::DowngradedFp4ToF16 + ) + }) { + eprintln!("⚠ compliance floor missed on ≥ 1 projection; see fp4_compliance.json."); + if !opts.strict { + eprintln!("(Use --strict to treat this as a fatal error.)"); + } } + eprintln!("→ {}", opts.output.display()); } + + Ok(()) } fn run_gguf_to_vindex( @@ -105,25 +459,26 @@ fn run_gguf_to_vindex( larql_vindex::StorageDtype::F32 }; - let model_name = gguf.metadata.get("general.name") + let model_name = gguf + .metadata + .get("general.name") .and_then(|v| v.as_str()) .unwrap_or("gguf-model") .to_string(); // Find tokenizer — check same directory as GGUF file - let tokenizer = input.parent() - .and_then(|dir| { - let tok_path = dir.join("tokenizer.json"); - if tok_path.exists() { - larql_vindex::tokenizers::Tokenizer::from_file(&tok_path).ok() - } else { - None - } - }); + let tokenizer = input.parent().and_then(|dir| { + let tok_path = dir.join(TOKENIZER_JSON); + if tok_path.exists() { + larql_vindex::tokenizers::Tokenizer::from_file(&tok_path).ok() + } else { + None + } + }); - let tokenizer_ref = tokenizer.as_ref().ok_or( - "tokenizer.json not found next to GGUF file. Place it in the same directory." - )?; + let tokenizer_ref = tokenizer + .as_ref() + .ok_or("tokenizer.json not found next to GGUF file. Place it in the same directory.")?; eprintln!("\nExtracting to {}", output.display()); @@ -138,6 +493,14 @@ fn run_gguf_to_vindex( dtype, &mut callbacks, )?; + // GGUF conversion: HF metadata (tokenizer_config.json etc.) is not + // packed in the GGUF itself, but if the user kept the HF files next + // to the `.gguf`, snapshot them. Missing-file case is a no-op. + if let Some(src_dir) = input.parent() { + if let Err(e) = larql_vindex::snapshot_hf_metadata(src_dir, output) { + eprintln!(" warning: failed to snapshot HF metadata: {e}"); + } + } eprintln!("Done: {}", output.display()); Ok(()) @@ -152,13 +515,12 @@ fn run_safetensors_to_vindex( // This is essentially extract-index eprintln!("Loading safetensors: {}", input.display()); let weights = larql_models::load_model_dir(input)?; - let tokenizer = larql_vindex::load_vindex_tokenizer(input) - .or_else(|_| { - // Try to load from the model directory - let tok_path = input.join("tokenizer.json"); - larql_vindex::tokenizers::Tokenizer::from_file(&tok_path) - .map_err(|e| larql_vindex::VindexError::Parse(e.to_string())) - })?; + let tokenizer = larql_vindex::load_vindex_tokenizer(input).or_else(|_| { + // Try to load from the model directory + let tok_path = input.join(TOKENIZER_JSON); + larql_vindex::tokenizers::Tokenizer::from_file(&tok_path) + .map_err(|e| larql_vindex::VindexError::Parse(e.to_string())) + })?; let extract_level = match level { "inference" => larql_vindex::ExtractLevel::Inference, @@ -172,7 +534,8 @@ fn run_safetensors_to_vindex( larql_vindex::StorageDtype::F32 }; - let model_name = input.file_name() + let model_name = input + .file_name() .map(|n| n.to_string_lossy().to_string()) .unwrap_or_else(|| "model".into()); @@ -189,6 +552,12 @@ fn run_safetensors_to_vindex( dtype, &mut callbacks, )?; + // Snapshot HF-side metadata (chat template, special tokens, generation + // config) from the source directory. `input` here is the safetensors + // model dir, which is where these files live in the HF cache. + if let Err(e) = larql_vindex::snapshot_hf_metadata(input, output) { + eprintln!(" warning: failed to snapshot HF metadata: {e}"); + } eprintln!("Done: {}", output.display()); Ok(()) diff --git a/crates/larql-cli/src/commands/extraction/embedding_jump_cmd.rs b/crates/larql-cli/src/commands/extraction/embedding_jump_cmd.rs index 077eea03..9dbcf8dc 100644 --- a/crates/larql-cli/src/commands/extraction/embedding_jump_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/embedding_jump_cmd.rs @@ -60,7 +60,9 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { eprintln!( " {} layers, hidden={}, embed_scale={:.1} ({:.1}s)", - num_layers, hidden, embed_scale, + num_layers, + hidden, + embed_scale, start.elapsed().as_secs_f64() ); @@ -71,7 +73,10 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { .filter(|l| !l.is_empty()) .collect(); - eprintln!("Fitting projection from {} training prompts...", train_prompts.len()); + eprintln!( + "Fitting projection from {} training prompts...", + train_prompts.len() + ); let fit_start = Instant::now(); // ── For each training prompt: compute raw embedding AND real L_target ── @@ -83,12 +88,15 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { let mut y_vecs: Vec> = Vec::new(); // real L_target last-token for (i, prompt) in train_prompts.iter().enumerate() { - let encoding = model.tokenizer() + let encoding = model + .tokenizer() .encode(prompt.as_str(), true) .map_err(|e| format!("tokenize: {e}"))?; let token_ids: Vec = encoding.get_ids().to_vec(); let seq_len = token_ids.len(); - if seq_len < 3 { continue; } + if seq_len < 3 { + continue; + } // Compute input vector let input_vec: Vec = if args.source_layers > 0 { @@ -99,7 +107,9 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { let mut sum = vec![0.0f32; hidden]; for &tid in &token_ids { let row = weights.embed.row(tid as usize); - for j in 0..hidden { sum[j] += row[j] * embed_scale; } + for j in 0..hidden { + sum[j] += row[j] * embed_scale; + } } sum } else { @@ -144,10 +154,12 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { } // Center X - let xc: Vec> = x_vecs.iter() + let xc: Vec> = x_vecs + .iter() .map(|x| x.iter().zip(x_mean.iter()).map(|(a, m)| a - m).collect()) .collect(); - let yc: Vec> = y_vecs.iter() + let yc: Vec> = y_vecs + .iter() .map(|y| y.iter().zip(y_mean.iter()).map(|(a, m)| a - m).collect()) .collect(); @@ -169,7 +181,9 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { for _ in 0..r { let mut v = vec![1.0f32; n_train]; let n: f32 = v.iter().map(|x| x * x).sum::().sqrt(); - for x in v.iter_mut() { *x /= n; } + for x in v.iter_mut() { + *x /= n; + } let mut ev = 0.0f32; for _ in 0..100 { @@ -183,10 +197,16 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { } ev = mv.iter().zip(v.iter()).map(|(a, b)| a * b).sum(); let n: f32 = mv.iter().map(|x| x * x).sum::().sqrt(); - if n < 1e-12 { break; } - for (x, m) in v.iter_mut().zip(mv.iter()) { *x = m / n; } + if n < 1e-12 { + break; + } + for (x, m) in v.iter_mut().zip(mv.iter()) { + *x = m / n; + } + } + if ev < 1e-8 { + break; } - if ev < 1e-8 { break; } eigenvalues.push(ev.sqrt()); eigenvectors.push(v.clone()); @@ -207,17 +227,25 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { let mut dir = vec![0.0f32; hidden]; for i in 0..n_train { let c = eigenvectors[k][i] / eigenvalues[k]; - for j in 0..hidden { dir[j] += c * xc[i][j]; } + for j in 0..hidden { + dir[j] += c * xc[i][j]; + } } let n: f32 = dir.iter().map(|x| x * x).sum::().sqrt(); - if n > 1e-12 { for x in dir.iter_mut() { *x /= n; } } + if n > 1e-12 { + for x in dir.iter_mut() { + *x /= n; + } + } vt_rows.push(dir); // Beta let mut beta = vec![0.0f32; hidden]; for i in 0..n_train { let c = eigenvectors[k][i] / eigenvalues[k]; - for j in 0..hidden { beta[j] += c * yc[i][j]; } + for j in 0..hidden { + beta[j] += c * yc[i][j]; + } } betas.push(beta); } @@ -227,7 +255,10 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { // ── Load test prompts ── let test_prompts: Vec = if let Some(ref file) = args.prompts_file { std::fs::read_to_string(file)? - .lines().map(|l| l.trim().to_string()).filter(|l| !l.is_empty()).collect() + .lines() + .map(|l| l.trim().to_string()) + .filter(|l| !l.is_empty()) + .collect() } else if let Some(ref p) = args.prompts { p.split(',').map(|s| s.trim().to_string()).collect() } else { @@ -237,7 +268,10 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { // ── End-to-end test ── eprintln!( "\n── Embedding Jump: raw embed → rank-{} project → L{} → L{}-L{} dense ──\n", - rank, target, inject_at, num_layers - 1 + rank, + target, + inject_at, + num_layers - 1 ); println!( @@ -251,17 +285,23 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { let mut cosines = Vec::new(); for prompt in &test_prompts { - let encoding = model.tokenizer() + let encoding = model + .tokenizer() .encode(prompt.as_str(), true) .map_err(|e| format!("tokenize: {e}"))?; let token_ids: Vec = encoding.get_ids().to_vec(); let seq_len = token_ids.len(); - if seq_len < 3 { continue; } + if seq_len < 3 { + continue; + } // Baseline let baseline = predict(weights, model.tokenizer(), &token_ids, args.top_k); - let (base_tok, base_conf) = baseline.predictions.first() - .map(|(t, p)| (t.clone(), *p)).unwrap_or_default(); + let (base_tok, base_conf) = baseline + .predictions + .first() + .map(|(t, p)| (t.clone(), *p)) + .unwrap_or_default(); // Compute input (same method as training) let input_vec: Vec = if args.source_layers > 0 { @@ -271,7 +311,9 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { let mut sum = vec![0.0f32; hidden]; for &tid in &token_ids { let row = weights.embed.row(tid as usize); - for j in 0..hidden { sum[j] += row[j] * embed_scale; } + for j in 0..hidden { + sum[j] += row[j] * embed_scale; + } } sum } else { @@ -297,10 +339,18 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { // Cosine between projected and real at target layer let real_last: Vec = h_real.row(seq_len - 1).to_vec(); let cos: f32 = { - let dot: f32 = projected.iter().zip(real_last.iter()).map(|(a, b)| a * b).sum(); + let dot: f32 = projected + .iter() + .zip(real_last.iter()) + .map(|(a, b)| a * b) + .sum(); let na: f32 = projected.iter().map(|x| x * x).sum::().sqrt(); let nb: f32 = real_last.iter().map(|x| x * x).sum::().sqrt(); - if na > 1e-12 && nb > 1e-12 { dot / (na * nb) } else { 0.0 } + if na > 1e-12 && nb > 1e-12 { + dot / (na * nb) + } else { + 0.0 + } }; cosines.push(cos); @@ -311,22 +361,29 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { } // Run decoder - let jump_result = predict_from_hidden( - weights, model.tokenizer(), &h_hybrid, inject_at, args.top_k, - ); - let (jump_tok, jump_conf) = jump_result.predictions.first() - .map(|(t, p)| (t.clone(), *p)).unwrap_or_default(); + let jump_result = + predict_from_hidden(weights, model.tokenizer(), &h_hybrid, inject_at, args.top_k); + let (jump_tok, jump_conf) = jump_result + .predictions + .first() + .map(|(t, p)| (t.clone(), *p)) + .unwrap_or_default(); let matched = jump_tok == base_tok; - if matched { match_count += 1; } + if matched { + match_count += 1; + } total += 1; let m = if matched { "=" } else { "X" }; println!( "{:<45} {:>12} {:>12} {:>7.2}% {:>7.2}% {:>3}", &prompt[..prompt.len().min(44)], - base_tok, jump_tok, - base_conf * 100.0, jump_conf * 100.0, m, + base_tok, + jump_tok, + base_conf * 100.0, + jump_conf * 100.0, + m, ); } @@ -338,21 +395,44 @@ pub fn run(args: EmbeddingJumpArgs) -> Result<(), Box> { eprintln!(" Prompts: {}", total); eprintln!( " Token match: {}/{} ({:.1}%)", - match_count, total, + match_count, + total, match_count as f64 / total.max(1) as f64 * 100.0 ); - eprintln!(" Cosine at L{}: mean={:.6}, min={:.6}", target, mean_cos, min_cos); + eprintln!( + " Cosine at L{}: mean={:.6}, min={:.6}", + target, mean_cos, min_cos + ); if args.source_layers > 0 { - eprintln!(" Method: {} real layers → rank-{} projection → L{}-L{} dense", - args.source_layers, rank, inject_at, num_layers - 1); - eprintln!(" {} real layers + {} dot products → {} decoder layers.", - args.source_layers, rank, num_layers - inject_at); + eprintln!( + " Method: {} real layers → rank-{} projection → L{}-L{} dense", + args.source_layers, + rank, + inject_at, + num_layers - 1 + ); + eprintln!( + " {} real layers + {} dot products → {} decoder layers.", + args.source_layers, + rank, + num_layers - inject_at + ); } else { - eprintln!(" Method: raw embedding → rank-{} projection → L{}-L{} dense", - rank, inject_at, num_layers - 1); - eprintln!(" Zero encoder layers. Just embedding lookup + {} dot products.", rank); + eprintln!( + " Method: raw embedding → rank-{} projection → L{}-L{} dense", + rank, + inject_at, + num_layers - 1 + ); + eprintln!( + " Zero encoder layers. Just embedding lookup + {} dot products.", + rank + ); } - eprintln!(" Zero matmul layers. Just an embedding lookup + {} dot products.", rank); + eprintln!( + " Zero matmul layers. Just an embedding lookup + {} dot products.", + rank + ); Ok(()) } diff --git a/crates/larql-cli/src/commands/extraction/extract_index_cmd.rs b/crates/larql-cli/src/commands/extraction/extract_index_cmd.rs index f3ea4bed..74d8259e 100644 --- a/crates/larql-cli/src/commands/extraction/extract_index_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/extract_index_cmd.rs @@ -1,10 +1,11 @@ +use larql_vindex::format::filenames::*; use std::path::PathBuf; use std::time::Instant; use clap::Args; use indicatif::{ProgressBar, ProgressStyle}; +use larql_inference::InferenceModel; use larql_vindex::IndexBuildCallbacks; -use larql_inference::{ InferenceModel}; #[derive(Args)] pub struct ExtractIndexArgs { @@ -87,6 +88,14 @@ pub struct ExtractIndexArgs { #[arg(long)] down_q4k: bool, + /// Emit `down_features_q4k.bin` (W2 feature-major down) so per-feature + /// row decode can skip the `q4k_ffn_layer` cache. Adds ~14 MB / layer + /// at Gemma 4B dims; eliminates the ~840 MB heap cache ceiling on + /// CPU sparse walk and frees the same headroom across all grid shards. + /// Requires `--quant q4k`. + #[arg(long)] + feature_major_down: bool, + /// Skip stages that already have output files (resume interrupted builds). #[arg(long)] resume: bool, @@ -95,7 +104,7 @@ pub struct ExtractIndexArgs { fn parse_quant(s: &str) -> Result { match s.to_lowercase().as_str() { "none" | "" => Ok(larql_vindex::QuantFormat::None), - "q4k" | "q4_k" => Ok(larql_vindex::QuantFormat::Q4k), + "q4k" | "q4_k" => Ok(larql_vindex::QuantFormat::Q4K), _ => Err(format!("unknown quant format: {s} (expected: none, q4k)")), } } @@ -149,13 +158,7 @@ impl IndexBuildCallbacks for CliBuildCallbacks { .set_message(format!("{component} L{layer} ({}/{})", layer + 1, total)); } - fn on_feature_progress( - &mut self, - component: &str, - _layer: usize, - done: usize, - total: usize, - ) { + fn on_feature_progress(&mut self, component: &str, _layer: usize, done: usize, total: usize) { if total > 0 { self.feature_bar.set_length(total as u64); } @@ -200,7 +203,7 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { // default → F32 // f16 is the default now; --f32 opts out. `--quant q4k` always // forces f16 on the side-channel tensors. - let dtype = if args.f32 && args.quant != larql_vindex::QuantFormat::Q4k { + let dtype = if args.f32 && args.quant != larql_vindex::QuantFormat::Q4K { larql_vindex::StorageDtype::F32 } else { larql_vindex::StorageDtype::F16 @@ -213,7 +216,10 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { larql_vindex::build_vindex_from_vectors(vectors_dir, &args.output, &mut callbacks)?; - if matches!(level, larql_vindex::ExtractLevel::Inference | larql_vindex::ExtractLevel::All) { + if matches!( + level, + larql_vindex::ExtractLevel::Inference | larql_vindex::ExtractLevel::All + ) { let model_name = args.model.as_deref().ok_or( "--model required with --level inference/all (need model to extract weights)", )?; @@ -224,7 +230,10 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { ffn_compact: args.compact, }; larql_vindex::write_model_weights_with_opts( - model.weights(), &args.output, &mut callbacks, weight_opts, + model.weights(), + &args.output, + &mut callbacks, + weight_opts, )?; } } else { @@ -246,13 +255,19 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { larql_vindex::StorageDtype::F32 => "f32", larql_vindex::StorageDtype::F16 => "f16", }; - eprintln!("Extracting: {} → {} (level={}, dtype={}, quant={})", - model_path.display(), args.output.display(), level_str, dtype_str, args.quant); + eprintln!( + "Extracting: {} → {} (level={}, dtype={}, quant={})", + model_path.display(), + args.output.display(), + level_str, + dtype_str, + args.quant + ); let output = &args.output; // Find or create tokenizer - let tok_path = model_path.join("tokenizer.json"); + let tok_path = model_path.join(TOKENIZER_JSON); let tokenizer = if tok_path.exists() { larql_vindex::tokenizers::Tokenizer::from_file(&tok_path) .map_err(|e| format!("failed to load tokenizer: {e}"))? @@ -264,18 +279,27 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { level, ffn_compact: args.compact, }; - if args.drop_gate_vectors && args.quant != larql_vindex::QuantFormat::Q4k { + if args.drop_gate_vectors && args.quant != larql_vindex::QuantFormat::Q4K { return Err( "--drop-gate-vectors requires --quant q4k (gate is rebuilt from Q4K at load)" .into(), ); } - if args.down_q4k && args.quant != larql_vindex::QuantFormat::Q4k { + if args.down_q4k && args.quant != larql_vindex::QuantFormat::Q4K { return Err( "--down-q4k requires --quant q4k (only the Q4K writer honours this flag)".into(), ); } - let q4k_opts = larql_vindex::Q4kWriteOptions { down_q4k: args.down_q4k }; + if args.feature_major_down && args.quant != larql_vindex::QuantFormat::Q4K { + return Err( + "--feature-major-down requires --quant q4k (only the Q4K writer honours this flag)" + .into(), + ); + } + let q4k_opts = larql_vindex::Q4kWriteOptions { + down_q4k: args.down_q4k, + feature_major_down: args.feature_major_down, + }; larql_vindex::build_vindex_streaming( &model_path, &tokenizer, @@ -290,6 +314,15 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { args.drop_gate_vectors, &mut callbacks, )?; + + // Opportunistically copy HF metadata (tokenizer_config.json, + // special_tokens_map.json, generation_config.json) from the source + // directory into the vindex. Chat-template-aware runtimes read + // `tokenizer_config.json::chat_template` from here; missing files + // are silently skipped. + if let Err(e) = larql_vindex::snapshot_hf_metadata(&model_path, output) { + eprintln!(" warning: failed to snapshot HF metadata: {e}"); + } } callbacks.feature_bar.finish_and_clear(); @@ -300,27 +333,24 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { eprintln!(" Output: {}", args.output.display()); if build_elapsed.as_secs() >= 60 { - eprintln!( - " Build time: {:.1}min", - build_elapsed.as_secs_f64() / 60.0 - ); + eprintln!(" Build time: {:.1}min", build_elapsed.as_secs_f64() / 60.0); } else { eprintln!(" Build time: {:.1}s", build_elapsed.as_secs_f64()); } for name in &[ - "index.json", - "gate_vectors.bin", - "embeddings.bin", + INDEX_JSON, + GATE_VECTORS_BIN, + EMBEDDINGS_BIN, "down_meta.jsonl", - "down_meta.bin", - "tokenizer.json", - "attn_weights.bin", - "up_weights.bin", - "down_weights.bin", - "norms.bin", - "lm_head.bin", - "weight_manifest.json", + DOWN_META_BIN, + TOKENIZER_JSON, + ATTN_WEIGHTS_BIN, + UP_WEIGHTS_BIN, + DOWN_WEIGHTS_BIN, + NORMS_BIN, + LM_HEAD_BIN, + WEIGHT_MANIFEST_JSON, ] { let path = args.output.join(name); if let Ok(meta) = std::fs::metadata(&path) { @@ -342,7 +372,8 @@ pub fn run(args: ExtractIndexArgs) -> Result<(), Box> { let total_size: u64 = std::fs::read_dir(&args.output) .ok() .map(|entries| { - entries.filter_map(|e| e.ok()) + entries + .filter_map(|e| e.ok()) .filter_map(|e| e.metadata().ok()) .map(|m| m.len()) .sum() diff --git a/crates/larql-cli/src/commands/extraction/ffn_bottleneck_cmd.rs b/crates/larql-cli/src/commands/extraction/ffn_bottleneck_cmd.rs index e479170b..baa36528 100644 --- a/crates/larql-cli/src/commands/extraction/ffn_bottleneck_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/ffn_bottleneck_cmd.rs @@ -1,9 +1,7 @@ use std::time::Instant; use clap::Args; -use larql_inference::{ - trace_forward, InferenceModel, -}; +use larql_inference::{trace_forward, InferenceModel}; #[derive(Args)] pub struct FfnBottleneckArgs { @@ -29,7 +27,9 @@ pub fn run(args: FfnBottleneckArgs) -> Result<(), Box> { let model = InferenceModel::load(&args.model)?; let weights = model.weights(); - let encoding = model.tokenizer().encode(args.prompt.as_str(), true) + let encoding = model + .tokenizer() + .encode(args.prompt.as_str(), true) .map_err(|e| format!("tokenize error: {e}"))?; let token_ids: Vec = encoding.get_ids().to_vec(); let seq_len = token_ids.len(); @@ -63,13 +63,17 @@ pub fn run(args: FfnBottleneckArgs) -> Result<(), Box> { // 1. Gate matmul: x @ gate.T → (seq, intermediate) let _ = x.dot(&w_gate.t()); let start = Instant::now(); - for _ in 0..iters { let _ = x.dot(&w_gate.t()); } + for _ in 0..iters { + let _ = x.dot(&w_gate.t()); + } let gate_us = start.elapsed().as_micros() as f64 / iters as f64; // 2. Up matmul: x @ up.T → (seq, intermediate) let _ = x.dot(&w_up.t()); let start = Instant::now(); - for _ in 0..iters { let _ = x.dot(&w_up.t()); } + for _ in 0..iters { + let _ = x.dot(&w_up.t()); + } let up_us = start.elapsed().as_micros() as f64 / iters as f64; // 3. SiLU activation: element-wise on (seq, intermediate) @@ -87,7 +91,9 @@ pub fn run(args: FfnBottleneckArgs) -> Result<(), Box> { let activation = &activated * &up_proj; let _ = activation.dot(&w_down.t()); let start = Instant::now(); - for _ in 0..iters { let _ = activation.dot(&w_down.t()); } + for _ in 0..iters { + let _ = activation.dot(&w_down.t()); + } let down_us = start.elapsed().as_micros() as f64 / iters as f64; // 5. Top-K selection from gate activations (for sparse path) @@ -95,7 +101,8 @@ pub fn run(args: FfnBottleneckArgs) -> Result<(), Box> { let start = Instant::now(); for _ in 0..iters { for s in 0..seq_len { - let mut indexed: Vec<(usize, f32)> = gate_act.row(s).iter().copied().enumerate().collect(); + let mut indexed: Vec<(usize, f32)> = + gate_act.row(s).iter().copied().enumerate().collect(); indexed.select_nth_unstable_by(64, |a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap()); } } @@ -136,16 +143,23 @@ pub fn run(args: FfnBottleneckArgs) -> Result<(), Box> { let ffn = larql_inference::WeightFfn { weights }; let _ = larql_inference::FfnBackend::forward(&ffn, layer, &x); let start = Instant::now(); - for _ in 0..iters { let _ = larql_inference::FfnBackend::forward(&ffn, layer, &x); } + for _ in 0..iters { + let _ = larql_inference::FfnBackend::forward(&ffn, layer, &x); + } let total_us = start.elapsed().as_micros() as f64 / iters as f64; let total_parts = gate_us + up_us + silu_us + down_us; println!(); - println!("FFN Layer {} Bottleneck Analysis (seq_len={}, hidden={}, intermediate={})", - layer, seq_len, hidden, intermediate); + println!( + "FFN Layer {} Bottleneck Analysis (seq_len={}, hidden={}, intermediate={})", + layer, seq_len, hidden, intermediate + ); println!("{}", "=".repeat(65)); - println!("{:>30} {:>10} {:>10} {:>10}", "Component", "Time (us)", "% of FFN", "GFLOPS"); + println!( + "{:>30} {:>10} {:>10} {:>10}", + "Component", "Time (us)", "% of FFN", "GFLOPS" + ); println!("{}", "-".repeat(65)); let gate_flops = 2.0 * seq_len as f64 * hidden as f64 * intermediate as f64; @@ -153,40 +167,72 @@ pub fn run(args: FfnBottleneckArgs) -> Result<(), Box> { let silu_flops = 2.0 * seq_len as f64 * intermediate as f64; let down_flops = 2.0 * seq_len as f64 * intermediate as f64 * hidden as f64; - println!("{:>30} {:>8.0}us {:>9.1}% {:>9.1}", - "gate matmul (x @ gate.T)", gate_us, gate_us / total_parts * 100.0, - gate_flops / gate_us / 1000.0); - println!("{:>30} {:>8.0}us {:>9.1}% {:>9.1}", - "up matmul (x @ up.T)", up_us, up_us / total_parts * 100.0, - up_flops / up_us / 1000.0); - println!("{:>30} {:>8.0}us {:>9.1}% {:>9.1}", - "SiLU + element mul", silu_us, silu_us / total_parts * 100.0, - silu_flops / silu_us / 1000.0); - println!("{:>30} {:>8.0}us {:>9.1}% {:>9.1}", - "down matmul (act @ down.T)", down_us, down_us / total_parts * 100.0, - down_flops / down_us / 1000.0); + println!( + "{:>30} {:>8.0}us {:>9.1}% {:>9.1}", + "gate matmul (x @ gate.T)", + gate_us, + gate_us / total_parts * 100.0, + gate_flops / gate_us / 1000.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}% {:>9.1}", + "up matmul (x @ up.T)", + up_us, + up_us / total_parts * 100.0, + up_flops / up_us / 1000.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}% {:>9.1}", + "SiLU + element mul", + silu_us, + silu_us / total_parts * 100.0, + silu_flops / silu_us / 1000.0 + ); + println!( + "{:>30} {:>8.0}us {:>9.1}% {:>9.1}", + "down matmul (act @ down.T)", + down_us, + down_us / total_parts * 100.0, + down_flops / down_us / 1000.0 + ); println!("{}", "-".repeat(65)); - println!("{:>30} {:>8.0}us {:>9.1}%", - "Sum of parts", total_parts, 100.0); - println!("{:>30} {:>8.0}us", - "Actual dense FFN", total_us); + println!( + "{:>30} {:>8.0}us {:>9.1}%", + "Sum of parts", total_parts, 100.0 + ); + println!("{:>30} {:>8.0}us", "Actual dense FFN", total_us); println!(); println!("Sparse path components:"); println!("{}", "-".repeat(65)); - println!("{:>30} {:>8.0}us (gate matmul still required)", - "gate matmul", gate_us); - println!("{:>30} {:>8.0}us (select top-64 from {})", - "top-K selection", topk_us, intermediate); - println!("{:>30} {:>8.0}us (64 rows × {} dims)", - "gather rows", gather_us, hidden); - println!("{:>30} {:>8.0}us (64,{}) @ ({},) × {} pos", - "sparse gate+up gemv", sparse_gemv_us, hidden, hidden, seq_len); - println!("{:>30} {:>8.0}us (minimum sparse overhead)", - "sparse total (no down)", gate_us + topk_us + gather_us + sparse_gemv_us); + println!( + "{:>30} {:>8.0}us (gate matmul still required)", + "gate matmul", gate_us + ); + println!( + "{:>30} {:>8.0}us (select top-64 from {})", + "top-K selection", topk_us, intermediate + ); + println!( + "{:>30} {:>8.0}us (64 rows × {} dims)", + "gather rows", gather_us, hidden + ); + println!( + "{:>30} {:>8.0}us (64,{}) @ ({},) × {} pos", + "sparse gate+up gemv", sparse_gemv_us, hidden, hidden, seq_len + ); + println!( + "{:>30} {:>8.0}us (minimum sparse overhead)", + "sparse total (no down)", + gate_us + topk_us + gather_us + sparse_gemv_us + ); println!(); - println!("{:>30} {:>8.0}us ({:.0}% of FFN is gate+up matmul)", - "gate + up matmuls", gate_us + up_us, (gate_us + up_us) / total_parts * 100.0); + println!( + "{:>30} {:>8.0}us ({:.0}% of FFN is gate+up matmul)", + "gate + up matmuls", + gate_us + up_us, + (gate_us + up_us) / total_parts * 100.0 + ); Ok(()) } diff --git a/crates/larql-cli/src/commands/extraction/ffn_overlap_cmd.rs b/crates/larql-cli/src/commands/extraction/ffn_overlap_cmd.rs index e43f83b7..0ab491db 100644 --- a/crates/larql-cli/src/commands/extraction/ffn_overlap_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/ffn_overlap_cmd.rs @@ -1,9 +1,7 @@ use std::path::PathBuf; use clap::Args; -use larql_inference::{ - trace_forward, GateIndex, InferenceModel, -}; +use larql_inference::{trace_forward, GateIndex, InferenceModel}; #[derive(Args)] pub struct FfnOverlapArgs { @@ -30,11 +28,15 @@ pub fn run(args: FfnOverlapArgs) -> Result<(), Box> { let gi = GateIndex::load(&args.gate_index, 10)?; - let encoding = model.tokenizer().encode(args.prompt.as_str(), true) + let encoding = model + .tokenizer() + .encode(args.prompt.as_str(), true) .map_err(|e| format!("tokenize error: {e}"))?; let token_ids: Vec = encoding.get_ids().to_vec(); - let layers: Vec = args.layers.split(',') + let layers: Vec = args + .layers + .split(',') .map(|s| s.trim().parse().unwrap()) .collect(); @@ -44,8 +46,10 @@ pub fn run(args: FfnOverlapArgs) -> Result<(), Box> { // Entity tokens for gate index lookup let entity_tokens: Vec<(usize, f32)> = token_ids.iter().map(|&t| (t as usize, 1.0)).collect(); - println!("{:>5} {:>8} {:>8} {:>8} {:>8} {:>8}", - "Layer", "Entity", "Gate64", "Gate256", "Overlap64", "Overlap256"); + println!( + "{:>5} {:>8} {:>8} {:>8} {:>8} {:>8}", + "Layer", "Entity", "Gate64", "Gate256", "Overlap64", "Overlap256" + ); println!("{}", "-".repeat(55)); for (layer, residual_vec) in &trace.residuals { @@ -58,26 +62,41 @@ pub fn run(args: FfnOverlapArgs) -> Result<(), Box> { let gate_scores = w_gate.dot(&residual); // Top-64 and top-256 from actual gate matmul - let mut indexed: Vec<(usize, f32)> = gate_scores.iter().copied().enumerate() + let mut indexed: Vec<(usize, f32)> = gate_scores + .iter() + .copied() + .enumerate() .map(|(i, v)| (i, v * larql_inference::ffn::sigmoid(v))) .collect(); indexed.sort_unstable_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap()); - let gate_top64: std::collections::HashSet = indexed.iter().take(64).map(|x| x.0).collect(); - let gate_top256: std::collections::HashSet = indexed.iter().take(256).map(|x| x.0).collect(); + let gate_top64: std::collections::HashSet = + indexed.iter().take(64).map(|x| x.0).collect(); + let gate_top256: std::collections::HashSet = + indexed.iter().take(256).map(|x| x.0).collect(); // Entity-routed features from gate index let entity_feats64 = gi.lookup_from_tokens(&entity_tokens, *layer, 64); let entity_feats256 = gi.lookup_from_tokens(&entity_tokens, *layer, 256); - let entity_set64: std::collections::HashSet = entity_feats64.iter().copied().collect(); - let entity_set256: std::collections::HashSet = entity_feats256.iter().copied().collect(); + let entity_set64: std::collections::HashSet = + entity_feats64.iter().copied().collect(); + let entity_set256: std::collections::HashSet = + entity_feats256.iter().copied().collect(); let overlap64 = entity_set64.intersection(&gate_top64).count(); let overlap256 = entity_set256.intersection(&gate_top256).count(); - println!("{:>5} {:>8} {:>8} {:>8} {:>7}/{:<3} {:>7}/{:<3}", - layer, entity_feats64.len(), gate_top64.len(), gate_top256.len(), - overlap64, 64, overlap256, 256); + println!( + "{:>5} {:>8} {:>8} {:>8} {:>7}/{:<3} {:>7}/{:<3}", + layer, + entity_feats64.len(), + gate_top64.len(), + gate_top256.len(), + overlap64, + 64, + overlap256, + 256 + ); } Ok(()) diff --git a/crates/larql-cli/src/commands/extraction/fingerprint_extract_cmd.rs b/crates/larql-cli/src/commands/extraction/fingerprint_extract_cmd.rs index 9feb502d..4df7eb83 100644 --- a/crates/larql-cli/src/commands/extraction/fingerprint_extract_cmd.rs +++ b/crates/larql-cli/src/commands/extraction/fingerprint_extract_cmd.rs @@ -107,7 +107,11 @@ pub fn run(args: FingerprintExtractArgs) -> Result<(), Box