diff --git a/Cargo.lock b/Cargo.lock index 7153b0e4f..4fd7649ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1185,7 +1185,7 @@ dependencies = [ "criterion 0.5.1", "libm", "proptest", - "ruvector-mincut 2.0.5", + "ruvector-mincut 2.0.6", ] [[package]] @@ -5811,7 +5811,7 @@ dependencies = [ "ruqu-algorithms", "ruvector-attention", "ruvector-cluster", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-delta-core", "ruvector-filter", "ruvector-gnn", @@ -6589,11 +6589,11 @@ dependencies = [ "rkyv", "roaring", "ruvector-attention", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-gnn", "ruvector-graph", "ruvector-hyperbolic-hnsw", - "ruvector-mincut 2.0.5", + "ruvector-mincut 2.0.6", "ruvector-nervous-system", "ruvector-raft", "ruvector-sona 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", @@ -7374,7 +7374,7 @@ dependencies = [ "ndarray 0.16.1", "rand 0.8.5", "rand_distr 0.4.3", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -7611,7 +7611,7 @@ dependencies = [ [[package]] name = "ruqu" -version = "2.0.5" +version = "2.0.6" dependencies = [ "blake3", "cognitum-gate-tilezero 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -7885,7 +7885,7 @@ dependencies = [ [[package]] name = "ruvector-attention" -version = "2.0.5" +version = "2.0.6" dependencies = [ "approx", "criterion 0.5.1", @@ -7900,7 +7900,7 @@ dependencies = [ [[package]] name = "ruvector-attention-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "napi", "napi-build", @@ -7932,7 +7932,7 @@ dependencies = [ [[package]] name = "ruvector-attention-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "console_error_panic_hook", "getrandom 0.2.17", @@ -7947,7 +7947,7 @@ dependencies = [ [[package]] name = "ruvector-attn-mincut" -version = "2.0.5" +version = "2.0.6" dependencies = [ "serde", "serde_json", @@ -7956,7 +7956,7 @@ dependencies = [ [[package]] name = "ruvector-bench" -version = "2.0.5" 
+version = "2.0.6" dependencies = [ "anyhow", "byteorder", @@ -7977,8 +7977,8 @@ dependencies = [ "rayon", "ruvector-cognitive-container", "ruvector-coherence", - "ruvector-core 2.0.5", - "ruvector-mincut 2.0.5", + "ruvector-core 2.0.6", + "ruvector-mincut 2.0.6", "serde", "serde_json", "statistical", @@ -8007,7 +8007,7 @@ dependencies = [ "rand_distr 0.4.3", "rayon", "reqwest 0.11.27", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "rvf-crypto", "rvf-types", "rvf-wire", @@ -8024,7 +8024,7 @@ dependencies = [ [[package]] name = "ruvector-cli" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "assert_cmd", @@ -8049,7 +8049,7 @@ dependencies = [ "predicates", "prettytable-rs", "rand 0.8.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-gnn", "ruvector-graph", "serde", @@ -8082,7 +8082,7 @@ dependencies = [ "rand_distr 0.4.3", "rayon", "ruvector-attention", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-gnn", "ruvector-graph", "serde", @@ -8098,7 +8098,7 @@ dependencies = [ [[package]] name = "ruvector-cluster" -version = "2.0.5" +version = "2.0.6" dependencies = [ "async-trait", "bincode 2.0.1", @@ -8107,7 +8107,7 @@ dependencies = [ "futures", "parking_lot 0.12.5", "rand 0.8.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -8118,7 +8118,7 @@ dependencies = [ [[package]] name = "ruvector-cnn" -version = "2.0.5" +version = "2.0.6" dependencies = [ "criterion 0.5.1", "fastrand", @@ -8146,7 +8146,7 @@ dependencies = [ [[package]] name = "ruvector-cognitive-container" -version = "2.0.5" +version = "2.0.6" dependencies = [ "proptest", "serde", @@ -8156,7 +8156,7 @@ dependencies = [ [[package]] name = "ruvector-coherence" -version = "2.0.5" +version = "2.0.6" dependencies = [ "serde", "serde_json", @@ -8164,13 +8164,13 @@ dependencies = [ [[package]] name = "ruvector-collections" -version = "2.0.5" +version = "2.0.6" dependencies = [ "bincode 2.0.1", "chrono", "dashmap 6.1.0", 
"parking_lot 0.12.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -8231,7 +8231,7 @@ dependencies = [ [[package]] name = "ruvector-core" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "bincode 2.0.1", @@ -8269,7 +8269,7 @@ dependencies = [ "approx", "ruvector-attention", "ruvector-gnn", - "ruvector-mincut 2.0.5", + "ruvector-mincut 2.0.6", "serde", "serde_json", "thiserror 1.0.69", @@ -8277,7 +8277,7 @@ dependencies = [ [[package]] name = "ruvector-dag" -version = "2.0.5" +version = "2.0.6" dependencies = [ "criterion 0.5.1", "crossbeam", @@ -8289,7 +8289,7 @@ dependencies = [ "pqcrypto-kyber", "proptest", "rand 0.8.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "sha2", @@ -8413,7 +8413,7 @@ dependencies = [ [[package]] name = "ruvector-domain-expansion" -version = "2.0.5" +version = "2.0.6" dependencies = [ "criterion 0.5.1", "proptest", @@ -8456,7 +8456,7 @@ dependencies = [ [[package]] name = "ruvector-exotic-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "console_error_panic_hook", "getrandom 0.2.17", @@ -8472,12 +8472,12 @@ dependencies = [ [[package]] name = "ruvector-filter" -version = "2.0.5" +version = "2.0.6" dependencies = [ "chrono", "dashmap 6.1.0", "ordered-float", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -8523,7 +8523,7 @@ dependencies = [ [[package]] name = "ruvector-gnn" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "criterion 0.5.1", @@ -8539,7 +8539,7 @@ dependencies = [ "rand 0.8.5", "rand_distr 0.4.3", "rayon", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "tempfile", @@ -8548,7 +8548,7 @@ dependencies = [ [[package]] name = "ruvector-gnn-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "napi", "napi-build", @@ -8559,7 +8559,7 @@ dependencies = [ [[package]] name = "ruvector-gnn-wasm" -version = "2.0.5" +version = "2.0.6" 
dependencies = [ "console_error_panic_hook", "getrandom 0.2.17", @@ -8574,7 +8574,7 @@ dependencies = [ [[package]] name = "ruvector-graph" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "bincode 2.0.1", @@ -8614,7 +8614,7 @@ dependencies = [ "rkyv", "roaring", "ruvector-cluster", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-raft", "ruvector-replication", "serde", @@ -8635,14 +8635,14 @@ dependencies = [ [[package]] name = "ruvector-graph-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "futures", "napi", "napi-build", "napi-derive", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-graph", "serde", "serde_json", @@ -8654,14 +8654,14 @@ dependencies = [ [[package]] name = "ruvector-graph-transformer" -version = "2.0.5" +version = "2.0.6" dependencies = [ "proptest", "rand 0.8.5", "ruvector-attention", "ruvector-coherence", "ruvector-gnn", - "ruvector-mincut 2.0.5", + "ruvector-mincut 2.0.6", "ruvector-solver", "ruvector-verified", "serde", @@ -8670,7 +8670,7 @@ dependencies = [ [[package]] name = "ruvector-graph-transformer-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "napi", "napi-build", @@ -8682,7 +8682,7 @@ dependencies = [ [[package]] name = "ruvector-graph-transformer-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "js-sys", "serde", @@ -8694,7 +8694,7 @@ dependencies = [ [[package]] name = "ruvector-graph-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "console_error_panic_hook", @@ -8703,7 +8703,7 @@ dependencies = [ "js-sys", "parking_lot 0.12.5", "regex", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-graph", "serde", "serde-wasm-bindgen", @@ -8745,7 +8745,7 @@ dependencies = [ [[package]] name = "ruvector-math" -version = "2.0.5" +version = "2.0.6" dependencies = [ "approx", "criterion 0.5.1", @@ -8760,7 +8760,7 @@ dependencies = [ [[package]] name = "ruvector-math-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ 
"console_error_panic_hook", "getrandom 0.2.17", @@ -8778,7 +8778,7 @@ dependencies = [ [[package]] name = "ruvector-metrics" -version = "2.0.5" +version = "2.0.6" dependencies = [ "chrono", "lazy_static", @@ -8833,7 +8833,7 @@ dependencies = [ [[package]] name = "ruvector-mincut" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "criterion 0.5.1", @@ -8847,7 +8847,7 @@ dependencies = [ "rand 0.8.5", "rayon", "roaring", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-graph", "serde", "serde_json", @@ -8892,24 +8892,24 @@ dependencies = [ [[package]] name = "ruvector-mincut-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "napi", "napi-build", "napi-derive", - "ruvector-mincut 2.0.5", + "ruvector-mincut 2.0.6", "serde", "serde_json", ] [[package]] name = "ruvector-mincut-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "console_error_panic_hook", "getrandom 0.2.17", "js-sys", - "ruvector-mincut 2.0.5", + "ruvector-mincut 2.0.6", "serde", "serde-wasm-bindgen", "serde_json", @@ -8919,7 +8919,7 @@ dependencies = [ [[package]] name = "ruvector-nervous-system" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "approx", @@ -8953,14 +8953,14 @@ dependencies = [ [[package]] name = "ruvector-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "napi", "napi-build", "napi-derive", "ruvector-collections", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-filter", "ruvector-metrics", "serde", @@ -9013,7 +9013,7 @@ dependencies = [ [[package]] name = "ruvector-profiler" -version = "2.0.5" +version = "2.0.6" dependencies = [ "serde", "serde_json", @@ -9022,7 +9022,7 @@ dependencies = [ [[package]] name = "ruvector-raft" -version = "2.0.5" +version = "2.0.6" dependencies = [ "bincode 2.0.1", "chrono", @@ -9030,7 +9030,7 @@ dependencies = [ "futures", "parking_lot 0.12.5", "rand 0.8.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -9041,7 
+9041,7 @@ dependencies = [ [[package]] name = "ruvector-replication" -version = "2.0.5" +version = "2.0.6" dependencies = [ "bincode 2.0.1", "chrono", @@ -9049,7 +9049,7 @@ dependencies = [ "futures", "parking_lot 0.12.5", "rand 0.8.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -9084,7 +9084,7 @@ dependencies = [ [[package]] name = "ruvector-router-cli" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "chrono", @@ -9099,7 +9099,7 @@ dependencies = [ [[package]] name = "ruvector-router-core" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "bincode 2.0.1", @@ -9126,7 +9126,7 @@ dependencies = [ [[package]] name = "ruvector-router-ffi" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "chrono", @@ -9141,7 +9141,7 @@ dependencies = [ [[package]] name = "ruvector-router-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "js-sys", "ruvector-router-core", @@ -9155,7 +9155,7 @@ dependencies = [ [[package]] name = "ruvector-scipix" -version = "2.0.5" +version = "2.0.6" dependencies = [ "ab_glyph", "anyhow", @@ -9228,12 +9228,12 @@ dependencies = [ [[package]] name = "ruvector-server" -version = "2.0.5" +version = "2.0.6" dependencies = [ "axum", "dashmap 6.1.0", "parking_lot 0.12.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -9246,13 +9246,13 @@ dependencies = [ [[package]] name = "ruvector-snapshot" -version = "2.0.5" +version = "2.0.6" dependencies = [ "async-trait", "bincode 2.0.1", "chrono", "flate2", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "sha2", @@ -9263,7 +9263,7 @@ dependencies = [ [[package]] name = "ruvector-solver" -version = "2.0.5" +version = "2.0.6" dependencies = [ "approx", "criterion 0.5.1", @@ -9282,7 +9282,7 @@ dependencies = [ [[package]] name = "ruvector-solver-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "napi", "napi-build", @@ -9295,7 
+9295,7 @@ dependencies = [ [[package]] name = "ruvector-solver-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "getrandom 0.2.17", "js-sys", @@ -9345,7 +9345,7 @@ dependencies = [ [[package]] name = "ruvector-sparse-inference" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "byteorder", @@ -9368,11 +9368,11 @@ dependencies = [ [[package]] name = "ruvector-temporal-tensor" -version = "2.0.5" +version = "2.0.6" [[package]] name = "ruvector-tiny-dancer-core" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "bytemuck", @@ -9402,7 +9402,7 @@ dependencies = [ [[package]] name = "ruvector-tiny-dancer-node" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "chrono", @@ -9419,7 +9419,7 @@ dependencies = [ [[package]] name = "ruvector-tiny-dancer-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "js-sys", "ruvector-tiny-dancer-core", @@ -9440,7 +9440,7 @@ dependencies = [ "proptest", "ruvector-cognitive-container", "ruvector-coherence", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "serde", "serde_json", "thiserror 2.0.18", @@ -9462,7 +9462,7 @@ dependencies = [ [[package]] name = "ruvector-wasm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "base64 0.22.1", @@ -9475,7 +9475,7 @@ dependencies = [ "parking_lot 0.12.5", "rand 0.8.5", "ruvector-collections", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-filter", "serde", "serde-wasm-bindgen", @@ -9524,7 +9524,7 @@ dependencies = [ [[package]] name = "ruvllm" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "async-trait", @@ -9554,7 +9554,7 @@ dependencies = [ "rayon", "regex", "ruvector-attention", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-gnn", "ruvector-graph", "ruvector-sona 0.1.6", @@ -9574,7 +9574,7 @@ dependencies = [ [[package]] name = "ruvllm-cli" -version = "2.0.5" +version = "2.0.6" dependencies = [ "anyhow", "assert_cmd", @@ -9594,7 +9594,7 @@ dependencies = [ "predicates", 
"prettytable-rs", "rustyline", - "ruvllm 2.0.5", + "ruvllm 2.0.6", "serde", "serde_json", "tempfile", @@ -9635,7 +9635,7 @@ dependencies = [ "rand_distr 0.4.3", "ruvector-attention", "ruvector-collections", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "ruvector-dag", "ruvector-filter", "ruvector-gnn", @@ -9735,7 +9735,7 @@ dependencies = [ "js-sys", "once_cell", "parking_lot 0.12.5", - "ruvector-core 2.0.5", + "ruvector-core 2.0.6", "rvf-runtime", "rvf-types", "serde", @@ -10496,7 +10496,7 @@ name = "subpolynomial-time-mincut-demo" version = "0.1.0" dependencies = [ "rand 0.8.5", - "ruvector-mincut 2.0.5", + "ruvector-mincut 2.0.6", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 27b4e20d1..89d21c64d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,7 +108,7 @@ members = [ resolver = "2" [workspace.package] -version = "2.0.5" +version = "2.0.6" edition = "2021" rust-version = "1.77" license = "MIT" diff --git a/README.md b/README.md index f3578762b..0160071e3 100644 --- a/README.md +++ b/README.md @@ -706,6 +706,8 @@ Everything RuVector can do — organized by category. 
Vector search, graph queri | **RuvLTRA Models** | Pre-trained GGUF for routing & embeddings | <10ms inference → [HuggingFace](https://huggingface.co/ruv/ruvltra) | | **Streaming Tokens** | Real-time token generation | Responsive chat UX | | **Quantization** | Q4, Q5, Q8 model support | Run 7B models in 4GB RAM | +| **π-Quantization (ADR-090)** | 2-bit weights via π-transform + Hadamard rotation + QAT-STE | **10 GB/s** dequantization, 16x memory reduction | +| **MoE Memory-Aware Routing (ADR-092)** | Cache-aware expert selection with EMA affinity tracking | **70%+ cache hit rate**, <10µs routing latency | ```bash npm install @ruvector/ruvllm # Node.js @@ -754,6 +756,7 @@ cargo add ruvector-raft ruvector-cluster ruvector-replication | Feature | What It Does | Why It Matters | |---------|--------------|----------------| | **Tensor Compression** | f32→f16→PQ8→PQ4→Binary | 2-32x memory reduction | +| **INT8 CNN Quantization (ADR-091)** | Quantized Conv2D/Linear/Pooling with SIMD kernels | **4x memory reduction**, 2x faster CNN inference | | **Differentiable Search** | Soft attention k-NN | End-to-end trainable | | **Semantic Router** | Route queries to optimal endpoints | Multi-model AI orchestration | | **Hybrid Routing** | Keyword-first + embedding fallback | **90% accuracy** for agent routing | diff --git a/crates/neural-trader-coherence/src/lib.rs b/crates/neural-trader-coherence/src/lib.rs index 45da417aa..591685d11 100644 --- a/crates/neural-trader-coherence/src/lib.rs +++ b/crates/neural-trader-coherence/src/lib.rs @@ -130,8 +130,7 @@ impl CoherenceGate for ThresholdGate { let cut_ok = ctx.mincut_value >= floor; let cusum_ok = ctx.cusum_score < self.config.cusum_threshold; let drift_ok = ctx.drift_score < self.config.max_drift_score; - let boundary_ok = - ctx.boundary_stable_count >= self.config.boundary_stability_windows; + let boundary_ok = ctx.boundary_stable_count >= self.config.boundary_stability_windows; // Learning requires tighter drift margin (half the 
max). let learn_drift_ok = ctx.drift_score < self.config.max_drift_score * 0.5; diff --git a/crates/neural-trader-replay/src/lib.rs b/crates/neural-trader-replay/src/lib.rs index 5c383e133..471f9e060 100644 --- a/crates/neural-trader-replay/src/lib.rs +++ b/crates/neural-trader-replay/src/lib.rs @@ -84,11 +84,8 @@ pub trait MemoryStore { /// Attempts to write a segment. Returns `true` if the gate allowed /// admission, `false` if rejected. - fn maybe_write( - &mut self, - seg: ReplaySegment, - gate: &CoherenceDecision, - ) -> anyhow::Result; + fn maybe_write(&mut self, seg: ReplaySegment, gate: &CoherenceDecision) + -> anyhow::Result; } // --------------------------------------------------------------------------- diff --git a/crates/neural-trader-wasm/src/lib.rs b/crates/neural-trader-wasm/src/lib.rs index 2c564d675..9c99bb3f6 100644 --- a/crates/neural-trader-wasm/src/lib.rs +++ b/crates/neural-trader-wasm/src/lib.rs @@ -45,7 +45,10 @@ fn bytes16_to_hex(b: &[u8; 16]) -> String { fn hex_to_bytes16_inner(s: &str) -> Result<[u8; 16], String> { let s = s.trim(); // Strip optional 0x prefix for JS ergonomics. - let s = s.strip_prefix("0x").or_else(|| s.strip_prefix("0X")).unwrap_or(s); + let s = s + .strip_prefix("0x") + .or_else(|| s.strip_prefix("0X")) + .unwrap_or(s); if !s.is_ascii() || s.len() != 32 { return Err( "hex string must be exactly 32 ASCII hex chars (optional 0x prefix)".to_string(), @@ -53,8 +56,7 @@ fn hex_to_bytes16_inner(s: &str) -> Result<[u8; 16], String> { } let mut out = [0u8; 16]; for (i, byte) in out.iter_mut().enumerate() { - *byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16) - .map_err(|e| e.to_string())?; + *byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16).map_err(|e| e.to_string())?; } Ok(out) } @@ -66,7 +68,8 @@ fn hex_to_bytes16(s: &str) -> Result<[u8; 16], JsValue> { /// Serialize using BigInt-aware serializer to avoid u64 precision loss. 
fn to_js(v: &T) -> Result { let ser = serde_wasm_bindgen::Serializer::new().serialize_large_number_types_as_bigints(true); - v.serialize(&ser).map_err(|e| JsValue::from_str(&e.to_string())) + v.serialize(&ser) + .map_err(|e| JsValue::from_str(&e.to_string())) } // --------------------------------------------------------------------------- @@ -143,8 +146,16 @@ enum_convert!(SegmentKindWasm <=> neural_trader_replay::SegmentKind { #[wasm_bindgen] #[derive(Clone, Copy, Debug)] pub enum NodeKindWasm { - Symbol = 0, Venue = 1, PriceLevel = 2, Order = 3, Trade = 4, - Event = 5, Participant = 6, TimeBucket = 7, Regime = 8, StrategyState = 9, + Symbol = 0, + Venue = 1, + PriceLevel = 2, + Order = 3, + Trade = 4, + Event = 5, + Participant = 6, + TimeBucket = 7, + Regime = 8, + StrategyState = 9, } enum_convert!(NodeKindWasm <=> neural_trader_core::NodeKind { Symbol, Venue, PriceLevel, Order, Trade, Event, Participant, @@ -154,9 +165,18 @@ enum_convert!(NodeKindWasm <=> neural_trader_core::NodeKind { #[wasm_bindgen] #[derive(Clone, Copy, Debug)] pub enum EdgeKindWasm { - AtLevel = 0, NextTick = 1, Generated = 2, Matched = 3, ModifiedFrom = 4, - CanceledBy = 5, BelongsToSymbol = 6, OnVenue = 7, InWindow = 8, - CorrelatedWith = 9, InRegime = 10, AffectsState = 11, + AtLevel = 0, + NextTick = 1, + Generated = 2, + Matched = 3, + ModifiedFrom = 4, + CanceledBy = 5, + BelongsToSymbol = 6, + OnVenue = 7, + InWindow = 8, + CorrelatedWith = 9, + InRegime = 10, + AffectsState = 11, } enum_convert!(EdgeKindWasm <=> neural_trader_core::EdgeKind { AtLevel, NextTick, Generated, Matched, ModifiedFrom, CanceledBy, @@ -774,11 +794,7 @@ impl ReservoirStoreWasm { /// Retrieve segments matching a symbol, returned as JSON array. 
#[wasm_bindgen(js_name = "retrieveBySymbol")] - pub fn retrieve_by_symbol( - &self, - symbol_id: u32, - limit: usize, - ) -> Result { + pub fn retrieve_by_symbol(&self, symbol_id: u32, limit: usize) -> Result { let query = neural_trader_replay::MemoryQuery { symbol_id, embedding: vec![], diff --git a/crates/ruvector-cnn-wasm/src/lib.rs b/crates/ruvector-cnn-wasm/src/lib.rs index 73535ba53..b554236c9 100644 --- a/crates/ruvector-cnn-wasm/src/lib.rs +++ b/crates/ruvector-cnn-wasm/src/lib.rs @@ -10,9 +10,11 @@ #![allow(clippy::new_without_default)] -use wasm_bindgen::prelude::*; -use ruvector_cnn::contrastive::{InfoNCELoss as RustInfoNCE, TripletLoss as RustTriplet, TripletDistance}; +use ruvector_cnn::contrastive::{ + InfoNCELoss as RustInfoNCE, TripletDistance, TripletLoss as RustTriplet, +}; use ruvector_cnn::simd; +use wasm_bindgen::prelude::*; /// Initialize panic hook for better error messages #[wasm_bindgen(start)] @@ -94,9 +96,8 @@ impl WasmCnnEmbedder { let mean: f32 = channel_data.iter().sum::() / pixels_per_channel as f32; // Variance - let variance: f32 = channel_data.iter() - .map(|x| (x - mean).powi(2)) - .sum::() / pixels_per_channel as f32; + let variance: f32 = channel_data.iter().map(|x| (x - mean).powi(2)).sum::() + / pixels_per_channel as f32; // Store in embedding if c * 2 < self.embedding_dim { @@ -195,7 +196,12 @@ impl WasmInfoNCELoss { /// Compute loss for a batch of embedding pairs /// embeddings: [2N, D] flattened where (i, i+N) are positive pairs #[wasm_bindgen] - pub fn forward(&self, embeddings: &[f32], batch_size: usize, dim: usize) -> Result { + pub fn forward( + &self, + embeddings: &[f32], + batch_size: usize, + dim: usize, + ) -> Result { if embeddings.len() != 2 * batch_size * dim { return Err(JsValue::from_str(&format!( "Expected {} elements, got {}", @@ -269,7 +275,10 @@ impl WasmTripletLoss { negatives: &[f32], dim: usize, ) -> Result { - if anchors.len() % dim != 0 || positives.len() != anchors.len() || negatives.len() != 
anchors.len() { + if anchors.len() % dim != 0 + || positives.len() != anchors.len() + || negatives.len() != anchors.len() + { return Err(JsValue::from_str("Invalid triplet dimensions")); } @@ -277,9 +286,18 @@ impl WasmTripletLoss { let mut total_loss = 0.0f64; for i in 0..batch_size { - let a: Vec = anchors[i * dim..(i + 1) * dim].iter().map(|&x| x as f64).collect(); - let p: Vec = positives[i * dim..(i + 1) * dim].iter().map(|&x| x as f64).collect(); - let n: Vec = negatives[i * dim..(i + 1) * dim].iter().map(|&x| x as f64).collect(); + let a: Vec = anchors[i * dim..(i + 1) * dim] + .iter() + .map(|&x| x as f64) + .collect(); + let p: Vec = positives[i * dim..(i + 1) * dim] + .iter() + .map(|&x| x as f64) + .collect(); + let n: Vec = negatives[i * dim..(i + 1) * dim] + .iter() + .map(|&x| x as f64) + .collect(); total_loss += self.inner.forward(&a, &p, &n); } @@ -351,14 +369,28 @@ impl LayerOps { ) -> Vec { let channels = gamma.len(); let mut output = vec![0.0f32; input.len()]; - simd::batch_norm_simd(input, &mut output, gamma, beta, mean, var, epsilon, channels); + simd::batch_norm_simd( + input, + &mut output, + gamma, + beta, + mean, + var, + epsilon, + channels, + ); output } /// Apply global average pooling /// Returns one value per channel #[wasm_bindgen] - pub fn global_avg_pool(input: &[f32], height: usize, width: usize, channels: usize) -> Vec { + pub fn global_avg_pool( + input: &[f32], + height: usize, + width: usize, + channels: usize, + ) -> Vec { let mut output = vec![0.0f32; channels]; simd::global_avg_pool_simd(input, &mut output, height, width, channels); output @@ -382,7 +414,8 @@ mod tests { input_size: 8, embedding_dim: 64, normalize: true, - })).unwrap(); + })) + .unwrap(); let image_data = vec![128u8; 8 * 8 * 3]; let embedding = embedder.extract(&image_data, 8, 8).unwrap(); diff --git a/crates/ruvector-cnn/benches/cnn_benchmarks.rs b/crates/ruvector-cnn/benches/cnn_benchmarks.rs index f6f3b594f..7d934c69d 100644 --- 
a/crates/ruvector-cnn/benches/cnn_benchmarks.rs +++ b/crates/ruvector-cnn/benches/cnn_benchmarks.rs @@ -10,16 +10,13 @@ //! Run with: cargo bench --package ruvector-cnn //! View HTML report: open target/criterion/report/index.html -use criterion::{ - black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput, -}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use ruvector_cnn::{ layers::{ - BatchNorm, Conv2d, DepthwiseSeparableConv, GlobalAvgPool, Layer, MaxPool2d, - ReLU, ReLU6, HardSwish, Swish, + BatchNorm, Conv2d, DepthwiseSeparableConv, GlobalAvgPool, HardSwish, Layer, MaxPool2d, + ReLU, ReLU6, Swish, }, - simd, - Tensor, + simd, Tensor, }; // ============================================================================ @@ -95,14 +92,15 @@ fn bench_simd_dot_product(c: &mut Criterion) { let b = vec![2.0f32; size]; group.bench_with_input(BenchmarkId::new("simd", size), &size, |b_iter, _| { - b_iter.iter(|| { - black_box(simd::dot_product_simd(black_box(&a), black_box(&b))) - }) + b_iter.iter(|| black_box(simd::dot_product_simd(black_box(&a), black_box(&b)))) }); group.bench_with_input(BenchmarkId::new("scalar", size), &size, |b_iter, _| { b_iter.iter(|| { - black_box(simd::scalar::dot_product_scalar(black_box(&a), black_box(&b))) + black_box(simd::scalar::dot_product_scalar( + black_box(&a), + black_box(&b), + )) }) }); } @@ -112,12 +110,7 @@ fn bench_simd_dot_product(c: &mut Criterion) { fn bench_simd_batch_norm(c: &mut Criterion) { // (height, width, channels) - let configs = [ - (8, 8, 64), - (28, 28, 128), - (56, 56, 64), - (112, 112, 32), - ]; + let configs = [(8, 8, 64), (28, 28, 128), (56, 56, 64), (112, 112, 32)]; let mut group = c.benchmark_group("simd/batch_norm"); group.sample_size(50); @@ -247,8 +240,8 @@ fn bench_simd_conv_3x3(c: &mut Criterion) { fn bench_simd_global_avg_pool(c: &mut Criterion) { let configs = [ - (7, 7, 576), // MobileNetV3-Small final - (7, 7, 960), // 
MobileNetV3-Large final + (7, 7, 576), // MobileNetV3-Small final + (7, 7, 960), // MobileNetV3-Large final (14, 14, 256), (28, 28, 128), ]; @@ -268,13 +261,7 @@ fn bench_simd_global_avg_pool(c: &mut Criterion) { &(h, w, ch), |b, _| { b.iter(|| { - simd::global_avg_pool_simd( - black_box(&input), - black_box(&mut output), - h, - w, - ch, - ); + simd::global_avg_pool_simd(black_box(&input), black_box(&mut output), h, w, ch); }) }, ); @@ -306,10 +293,10 @@ fn bench_simd_global_avg_pool(c: &mut Criterion) { fn bench_conv2d_layer(c: &mut Criterion) { // (batch, height, width, in_channels, out_channels) let configs = [ - (1, 8, 8, 3, 16), // Small - (1, 32, 32, 3, 32), // Medium - (1, 56, 56, 16, 64), // Large (MobileNet-style) - (1, 112, 112, 3, 32), // Initial conv + (1, 8, 8, 3, 16), // Small + (1, 32, 32, 3, 32), // Medium + (1, 56, 56, 16, 64), // Large (MobileNet-style) + (1, 112, 112, 3, 32), // Initial conv ]; let mut group = c.benchmark_group("layers/conv2d_3x3"); @@ -325,11 +312,7 @@ fn bench_conv2d_layer(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("forward", format!("{}x{}x{}->{}", h, w, in_c, out_c)), &(n, h, w, in_c, out_c), - |b, _| { - b.iter(|| { - black_box(conv.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(conv.forward(black_box(&input)).unwrap())), ); } @@ -357,11 +340,7 @@ fn bench_depthwise_separable(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("forward", format!("{}x{}x{}->{}", h, w, in_c, out_c)), &(n, h, w, in_c, out_c), - |b, _| { - b.iter(|| { - black_box(conv.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(conv.forward(black_box(&input)).unwrap())), ); } @@ -389,11 +368,7 @@ fn bench_batch_norm_layer(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("forward", format!("{}x{}x{}", h, w, ch)), &(n, h, w, ch), - |b, _| { - b.iter(|| { - black_box(bn.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| 
black_box(bn.forward(black_box(&input)).unwrap())), ); } @@ -405,11 +380,7 @@ fn bench_batch_norm_layer(c: &mut Criterion) { // ============================================================================ fn bench_activations(c: &mut Criterion) { - let configs = [ - (1, 56, 56, 64), - (1, 28, 28, 128), - (1, 14, 14, 256), - ]; + let configs = [(1, 56, 56, 64), (1, 28, 28, 128), (1, 14, 14, 256)]; let mut group = c.benchmark_group("layers/activations"); group.sample_size(50); @@ -427,11 +398,7 @@ fn bench_activations(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("relu", &size_label), &(n, h, w, ch), - |b, _| { - b.iter(|| { - black_box(relu.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(relu.forward(black_box(&input)).unwrap())), ); // ReLU6 @@ -439,11 +406,7 @@ fn bench_activations(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("relu6", &size_label), &(n, h, w, ch), - |b, _| { - b.iter(|| { - black_box(relu6.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(relu6.forward(black_box(&input)).unwrap())), ); // Swish @@ -451,11 +414,7 @@ fn bench_activations(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("swish", &size_label), &(n, h, w, ch), - |b, _| { - b.iter(|| { - black_box(swish.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(swish.forward(black_box(&input)).unwrap())), ); // HardSwish @@ -463,11 +422,7 @@ fn bench_activations(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("hard_swish", &size_label), &(n, h, w, ch), - |b, _| { - b.iter(|| { - black_box(hard_swish.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(hard_swish.forward(black_box(&input)).unwrap())), ); } @@ -483,8 +438,8 @@ fn bench_pooling(c: &mut Criterion) { (1, 8, 8, 64), (1, 28, 28, 128), (1, 56, 56, 64), - (1, 7, 7, 576), // MobileNetV3-Small final - (1, 7, 7, 960), // MobileNetV3-Large final + (1, 7, 7, 576), // 
MobileNetV3-Small final + (1, 7, 7, 960), // MobileNetV3-Large final ]; let mut group = c.benchmark_group("layers/pooling"); @@ -502,11 +457,7 @@ fn bench_pooling(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("global_avg", &size_label), &(n, h, w, ch), - |b, _| { - b.iter(|| { - black_box(gap.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(gap.forward(black_box(&input)).unwrap())), ); // MaxPool2d (only for sizes >= 4) @@ -515,11 +466,7 @@ fn bench_pooling(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("max_2x2", &size_label), &(n, h, w, ch), - |b, _| { - b.iter(|| { - black_box(maxpool.forward(black_box(&input)).unwrap()) - }) - }, + |b, _| b.iter(|| black_box(maxpool.forward(black_box(&input)).unwrap())), ); } } @@ -609,15 +556,9 @@ fn bench_batch_scaling(c: &mut Criterion) { let input = Tensor::ones(&[batch, h, w, in_c]); - group.bench_with_input( - BenchmarkId::new("conv2d", batch), - &batch, - |b, _| { - b.iter(|| { - black_box(conv.forward(black_box(&input)).unwrap()) - }) - }, - ); + group.bench_with_input(BenchmarkId::new("conv2d", batch), &batch, |b, _| { + b.iter(|| black_box(conv.forward(black_box(&input)).unwrap())) + }); } group.finish(); @@ -632,11 +573,7 @@ fn bench_tensor_operations(c: &mut Criterion) { group.sample_size(50); // Test tensor creation - let shapes = [ - (1, 224, 224, 3), - (1, 56, 56, 64), - (16, 28, 28, 128), - ]; + let shapes = [(1, 224, 224, 3), (1, 56, 56, 64), (16, 28, 28, 128)]; for (n, h, w, c) in shapes { let elements = n * h * w * c; @@ -644,36 +581,18 @@ fn bench_tensor_operations(c: &mut Criterion) { group.throughput(Throughput::Elements(elements as u64)); - group.bench_with_input( - BenchmarkId::new("zeros", &label), - &(n, h, w, c), - |b, _| { - b.iter(|| { - black_box(Tensor::zeros(&[n, h, w, c])) - }) - }, - ); + group.bench_with_input(BenchmarkId::new("zeros", &label), &(n, h, w, c), |b, _| { + b.iter(|| black_box(Tensor::zeros(&[n, h, w, c]))) + }); - 
group.bench_with_input( - BenchmarkId::new("ones", &label), - &(n, h, w, c), - |b, _| { - b.iter(|| { - black_box(Tensor::ones(&[n, h, w, c])) - }) - }, - ); + group.bench_with_input(BenchmarkId::new("ones", &label), &(n, h, w, c), |b, _| { + b.iter(|| black_box(Tensor::ones(&[n, h, w, c]))) + }); let tensor = Tensor::ones(&[n, h, w, c]); - group.bench_with_input( - BenchmarkId::new("clone", &label), - &(n, h, w, c), - |b, _| { - b.iter(|| { - black_box(tensor.clone()) - }) - }, - ); + group.bench_with_input(BenchmarkId::new("clone", &label), &(n, h, w, c), |b, _| { + b.iter(|| black_box(tensor.clone())) + }); } group.finish(); @@ -702,15 +621,8 @@ criterion_group!( bench_pooling, ); -criterion_group!( - block_benches, - bench_full_block, - bench_batch_scaling, -); +criterion_group!(block_benches, bench_full_block, bench_batch_scaling,); -criterion_group!( - misc_benches, - bench_tensor_operations, -); +criterion_group!(misc_benches, bench_tensor_operations,); criterion_main!(simd_benches, layer_benches, block_benches, misc_benches); diff --git a/crates/ruvector-cnn/benches/int8_bench.rs b/crates/ruvector-cnn/benches/int8_bench.rs index ffcd9b5d3..a9276fc3a 100644 --- a/crates/ruvector-cnn/benches/int8_bench.rs +++ b/crates/ruvector-cnn/benches/int8_bench.rs @@ -4,8 +4,8 @@ //! //! 
Run with: `cargo bench --bench int8_bench` -use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId, Throughput}; -use ruvector_cnn::int8::{QuantParams, quantize_tensor, dequantize_tensor}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ruvector_cnn::int8::{dequantize_tensor, quantize_tensor, QuantParams}; #[cfg(target_arch = "x86_64")] use ruvector_cnn::int8::kernels::simd::{conv2d_int8_simd, matmul_int8_simd}; @@ -94,15 +94,7 @@ fn bench_conv2d_int8(c: &mut Criterion) { &(&input_fp32, &kernel_fp32), |b, (input, kernel)| { b.iter(|| { - conv2d_fp32_naive( - black_box(input), - black_box(kernel), - h, - w, - c, - k, - stride, - ) + conv2d_fp32_naive(black_box(input), black_box(kernel), h, w, c, k, stride) }) }, ); @@ -116,10 +108,10 @@ fn bench_matmul_int8(c: &mut Criterion) { // Test different matrix sizes let sizes = vec![ - (64, 64, 64), // Small - (128, 128, 128), // Medium - (256, 256, 256), // Large - (512, 512, 512), // XLarge + (64, 64, 64), // Small + (128, 128, 128), // Medium + (256, 256, 256), // Large + (512, 512, 512), // XLarge ]; for (m, n, k) in sizes { @@ -146,14 +138,7 @@ fn bench_matmul_int8(c: &mut Criterion) { &(&a, &b, params), |bench, (a, b, params)| { bench.iter(|| { - matmul_int8_scalar( - black_box(a), - black_box(b), - black_box(*params), - m, - n, - k, - ) + matmul_int8_scalar(black_box(a), black_box(b), black_box(*params), m, n, k) }) }, ); @@ -188,17 +173,7 @@ fn bench_matmul_int8(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("fp32_baseline", &bench_name), &(&a_fp32, &b_fp32), - |bench, (a, b)| { - bench.iter(|| { - matmul_fp32_naive( - black_box(a), - black_box(b), - m, - n, - k, - ) - }) - }, + |bench, (a, b)| bench.iter(|| matmul_fp32_naive(black_box(a), black_box(b), m, n, k)), ); } @@ -233,10 +208,7 @@ fn bench_mobilenetv3_int8(c: &mut Criterion) { for i in 0..embedding_size { let start = (i * input_size) / embedding_size; let end = 
((i + 1) * input_size) / embedding_size; - let sum: i32 = input_int8[start..end] - .iter() - .map(|&x| x as i32) - .sum(); + let sum: i32 = input_int8[start..end].iter().map(|&x| x as i32).sum(); embedding[i] = sum; } black_box(embedding) @@ -279,9 +251,7 @@ fn bench_quantization_dequantization(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("quantize", size), &(&fp32, params), - |b, (fp32, params)| { - b.iter(|| quantize_tensor(black_box(fp32), black_box(params))) - }, + |b, (fp32, params)| b.iter(|| quantize_tensor(black_box(fp32), black_box(params))), ); // Benchmark dequantization @@ -289,9 +259,7 @@ fn bench_quantization_dequantization(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("dequantize", size), &(&int8, params), - |b, (int8, params)| { - b.iter(|| dequantize_tensor(black_box(int8), black_box(params))) - }, + |b, (int8, params)| b.iter(|| dequantize_tensor(black_box(int8), black_box(params))), ); // Benchmark round-trip @@ -341,7 +309,8 @@ fn bench_memory_usage(c: &mut Criterion) { assert!( reduction >= 3.0, "GATE-4 FAILED: Memory reduction {:.2}x < 3.0x for size {}", - reduction, size + reduction, + size ); // Dummy benchmark to keep Criterion happy diff --git a/crates/ruvector-cnn/examples/graph_rewrite_demo.rs b/crates/ruvector-cnn/examples/graph_rewrite_demo.rs index 4ef418754..a0812f46c 100644 --- a/crates/ruvector-cnn/examples/graph_rewrite_demo.rs +++ b/crates/ruvector-cnn/examples/graph_rewrite_demo.rs @@ -6,8 +6,8 @@ //! - GR-3: Q/DQ node insertion //! 
- GR-4: Activation fusion (ReLU, HardSwish) -use ruvector_cnn::quantize::graph_rewrite::*; use ruvector_cnn::quantize::calibration::QuantizationParams; +use ruvector_cnn::quantize::graph_rewrite::*; use std::collections::HashMap; fn main() { @@ -53,7 +53,11 @@ fn demo_batchnorm_fusion() { println!("Before fusion: {} nodes", graph.nodes.len()); let fused = fuse_batchnorm_to_conv(&mut graph); - println!("After fusion: {} nodes (fused {} BatchNorm layers)", graph.nodes.len(), fused); + println!( + "After fusion: {} nodes (fused {} BatchNorm layers)", + graph.nodes.len(), + fused + ); if let Some(conv_node) = graph.get_node(conv) { if let NodeParams::Conv2d { weights, bias, .. } = &conv_node.params { @@ -150,7 +154,11 @@ fn demo_qdq_insertion() { println!("Graph: Input → Conv → Output"); let inserted = insert_qdq_nodes(&mut graph, &quant_params); - println!("After Q/DQ insertion: {} nodes ({} Q/DQ nodes added)", graph.nodes.len(), inserted); + println!( + "After Q/DQ insertion: {} nodes ({} Q/DQ nodes added)", + graph.nodes.len(), + inserted + ); println!("Graph: Input → Quantize → Conv → Dequantize → Output"); println!(); } @@ -175,7 +183,11 @@ fn demo_activation_fusion() { println!("ReLU Fusion:"); println!(" Before: {} nodes (Conv → ReLU)", graph1.nodes.len()); let fused_relu = fuse_relu(&mut graph1); - println!(" After: {} nodes (Conv with fused ReLU, {} activations fused)", graph1.nodes.len(), fused_relu); + println!( + " After: {} nodes (Conv with fused ReLU, {} activations fused)", + graph1.nodes.len(), + fused_relu + ); let mut graph2 = ComputationGraph::new(); let conv2 = graph2.add_node( @@ -194,7 +206,11 @@ fn demo_activation_fusion() { println!("\nHardSwish Fusion:"); println!(" Before: {} nodes (Conv → HardSwish)", graph2.nodes.len()); let fused_hs = fuse_hardswish(&mut graph2); - println!(" After: {} nodes (Conv with LUT-based HardSwish, {} activations fused)", graph2.nodes.len(), fused_hs); + println!( + " After: {} nodes (Conv with LUT-based HardSwish, 
{} activations fused)", + graph2.nodes.len(), + fused_hs + ); // Generate HardSwish LUT let lut = generate_hardswish_lut(0.1, 0); @@ -259,16 +275,31 @@ fn demo_complete_pipeline() { println!("\nApplying optimization passes:"); let bn_fused = fuse_batchnorm_to_conv(&mut graph); - println!(" ✓ GR-1: Fused {} BatchNorm layers → {} nodes", bn_fused, graph.nodes.len()); + println!( + " ✓ GR-1: Fused {} BatchNorm layers → {} nodes", + bn_fused, + graph.nodes.len() + ); let relu_fused = fuse_relu(&mut graph); - println!(" ✓ GR-4: Fused {} ReLU activations → {} nodes", relu_fused, graph.nodes.len()); + println!( + " ✓ GR-4: Fused {} ReLU activations → {} nodes", + relu_fused, + graph.nodes.len() + ); let hs_fused = fuse_hardswish(&mut graph); - println!(" ✓ GR-4: Fused {} HardSwish activations → {} nodes", hs_fused, graph.nodes.len()); + println!( + " ✓ GR-4: Fused {} HardSwish activations → {} nodes", + hs_fused, + graph.nodes.len() + ); println!("\nOptimized graph: {} nodes", graph.nodes.len()); println!(" Input → Conv1(+BN+ReLU) → Conv2(+HardSwish) → Output"); - println!("\nMemory savings: {} nodes eliminated", 7 - graph.nodes.len()); + println!( + "\nMemory savings: {} nodes eliminated", + 7 - graph.nodes.len() + ); println!("Runtime benefit: 3 fewer ops, fused activations"); } diff --git a/crates/ruvector-cnn/src/backbone/blocks.rs b/crates/ruvector-cnn/src/backbone/blocks.rs index 927364910..c080ef435 100644 --- a/crates/ruvector-cnn/src/backbone/blocks.rs +++ b/crates/ruvector-cnn/src/backbone/blocks.rs @@ -5,11 +5,11 @@ //! - SqueezeExcitation: Channel attention mechanism //! - InvertedResidual: The main building block with optional SE +use super::layer::Layer; use crate::error::CnnResult; use crate::layers::{ Activation, ActivationType, BatchNorm2d, Conv2d, GlobalAvgPool2d, Linear, TensorShape, }; -use super::layer::Layer; /// Convolution + BatchNorm + Activation block. 
/// @@ -83,7 +83,15 @@ impl ConvBNActivation { activation: ActivationType, ) -> CnnResult { let padding = kernel_size / 2; - Self::new(channels, channels, kernel_size, stride, padding, channels, activation) + Self::new( + channels, + channels, + kernel_size, + stride, + padding, + channels, + activation, + ) } /// Returns a reference to the convolution layer. @@ -361,7 +369,10 @@ impl InvertedResidual { // Optional SE let se = if config.use_se { let se_channels = (config.expanded_channels / 4).max(1); - Some(SqueezeExcitation::new(config.expanded_channels, se_channels)?) + Some(SqueezeExcitation::new( + config.expanded_channels, + se_channels, + )?) } else { None }; @@ -555,8 +566,13 @@ mod tests { #[test] fn test_inverted_residual_no_expansion() { let block = InvertedResidual::create( - 16, 16, 16, // in == exp == out (no expansion) - 3, 1, false, ActivationType::ReLU, + 16, + 16, + 16, // in == exp == out (no expansion) + 3, + 1, + false, + ActivationType::ReLU, ) .unwrap(); @@ -567,8 +583,13 @@ mod tests { #[test] fn test_inverted_residual_with_expansion() { let block = InvertedResidual::create( - 16, 64, 24, // expansion ratio 4 - 3, 1, true, ActivationType::HardSwish, + 16, + 64, + 24, // expansion ratio 4 + 3, + 1, + true, + ActivationType::HardSwish, ) .unwrap(); @@ -580,9 +601,13 @@ mod tests { #[test] fn test_inverted_residual_output_shape() { let block = InvertedResidual::create( - 16, 64, 24, - 3, 2, // stride 2 - true, ActivationType::HardSwish, + 16, + 64, + 24, + 3, + 2, // stride 2 + true, + ActivationType::HardSwish, ) .unwrap(); @@ -597,12 +622,8 @@ mod tests { #[test] fn test_inverted_residual_params() { - let block = InvertedResidual::create( - 16, 64, 24, - 3, 1, - true, ActivationType::HardSwish, - ) - .unwrap(); + let block = + InvertedResidual::create(16, 64, 24, 3, 1, true, ActivationType::HardSwish).unwrap(); // Should have params from: expand, depthwise, SE, project assert!(block.num_params() > 0); diff --git 
a/crates/ruvector-cnn/src/backbone/layer.rs b/crates/ruvector-cnn/src/backbone/layer.rs index bf676bc75..2fbd3f6e2 100644 --- a/crates/ruvector-cnn/src/backbone/layer.rs +++ b/crates/ruvector-cnn/src/backbone/layer.rs @@ -5,8 +5,8 @@ use crate::error::CnnResult; use crate::layers::{ - Activation, ActivationType, BatchNorm, Conv2d, GlobalAvgPool, Linear, TensorShape, - conv_output_size, + conv_output_size, Activation, ActivationType, BatchNorm, Conv2d, GlobalAvgPool, Linear, + TensorShape, }; use crate::Tensor; @@ -42,13 +42,26 @@ impl Layer for Conv2d { // Convert output back to NCHW let out_shape = output_tensor.shape(); - let out_tensor_shape = TensorShape::new(out_shape[0], out_shape[3], out_shape[1], out_shape[2]); + let out_tensor_shape = + TensorShape::new(out_shape[0], out_shape[3], out_shape[1], out_shape[2]); Ok(nhwc_to_nchw(output_tensor.data(), &out_tensor_shape)) } fn output_shape(&self, input_shape: &TensorShape) -> TensorShape { - let out_h = conv_output_size(input_shape.h, self.kernel_size(), self.stride(), self.padding(), 1); - let out_w = conv_output_size(input_shape.w, self.kernel_size(), self.stride(), self.padding(), 1); + let out_h = conv_output_size( + input_shape.h, + self.kernel_size(), + self.stride(), + self.padding(), + 1, + ); + let out_w = conv_output_size( + input_shape.w, + self.kernel_size(), + self.stride(), + self.padding(), + 1, + ); TensorShape::new(input_shape.n, self.out_channels(), out_h, out_w) } @@ -140,7 +153,11 @@ impl Layer for Linear { fn num_params(&self) -> usize { let weight_params = self.out_features() * self.in_features(); - let bias_params = if self.bias().is_some() { self.out_features() } else { 0 }; + let bias_params = if self.bias().is_some() { + self.out_features() + } else { + 0 + }; weight_params + bias_params } } diff --git a/crates/ruvector-cnn/src/backbone/mobilenet.rs b/crates/ruvector-cnn/src/backbone/mobilenet.rs index 1f60b2421..ccb75f319 100644 --- a/crates/ruvector-cnn/src/backbone/mobilenet.rs +++ 
b/crates/ruvector-cnn/src/backbone/mobilenet.rs @@ -16,9 +16,9 @@ //! - **Small**: ~2.5M params, 576 output channels, optimized for latency //! - **Large**: ~5.4M params, 960 output channels, optimized for accuracy -use super::{Backbone, BackboneExt, BackboneType}; use super::blocks::{ConvBNActivation, InvertedResidual as BlockInvertedResidual}; use super::layer::Layer; +use super::{Backbone, BackboneExt, BackboneType}; use crate::error::CnnResult; use crate::layers::{self, ActivationType, GlobalAvgPool2d, Linear, TensorShape}; @@ -49,7 +49,10 @@ impl Default for MobileNetConfig { /// /// **DEPRECATED**: Use [`MobileNetV3`] with `BackboneType::MobileNetV3Small` instead. /// This legacy implementation has limited functionality. -#[deprecated(since = "2.0.6", note = "Use MobileNetV3 with BackboneType::MobileNetV3Small instead")] +#[deprecated( + since = "2.0.6", + note = "Use MobileNetV3 with BackboneType::MobileNetV3Small instead" +)] #[derive(Debug, Clone)] pub struct MobileNetV3Small { config: MobileNetConfig, @@ -66,7 +69,10 @@ pub struct MobileNetV3Small { /// /// **DEPRECATED**: Use [`MobileNetV3`] with `BackboneType::MobileNetV3Large` instead. /// This legacy implementation has limited functionality. 
-#[deprecated(since = "2.0.6", note = "Use MobileNetV3 with BackboneType::MobileNetV3Large instead")] +#[deprecated( + since = "2.0.6", + note = "Use MobileNetV3 with BackboneType::MobileNetV3Large instead" +)] #[derive(Debug, Clone)] pub struct MobileNetV3Large { config: MobileNetConfig, @@ -137,7 +143,7 @@ impl MobileNetV3Small { fn create_blocks(config: &MobileNetConfig) -> Vec { // Simplified block configuration for MobileNet-V3 Small let block_configs = [ - (16, 16, 1, false), // in, out, expansion, se + (16, 16, 1, false), // in, out, expansion, se (16, 24, 4, false), (24, 24, 3, false), (24, 40, 3, true), @@ -149,41 +155,53 @@ impl MobileNetV3Small { (96, 96, 6, true), ]; - block_configs.iter().map(|&(in_c, out_c, exp, se)| { - let in_c = ((in_c as f32) * config.width_mult) as usize; - let out_c = ((out_c as f32) * config.width_mult) as usize; - let mid_c = in_c * exp; - - InvertedResidual { - expand_weights: if exp != 1 { Some(vec![0.0; in_c * mid_c]) } else { None }, - expand_bn: if exp != 1 { Some(BnParams::new(mid_c)) } else { None }, - dw_weights: vec![0.0; 9 * mid_c], - dw_bn: BnParams::new(mid_c), - se_reduce: if se { Some(vec![0.0; mid_c * (mid_c / 4)]) } else { None }, - se_expand: if se { Some(vec![0.0; (mid_c / 4) * mid_c]) } else { None }, - project_weights: vec![0.0; mid_c * out_c], - project_bn: BnParams::new(out_c), - in_channels: in_c, - out_channels: out_c, - expansion: exp, - use_se: se, - use_residual: in_c == out_c, - } - }).collect() + block_configs + .iter() + .map(|&(in_c, out_c, exp, se)| { + let in_c = ((in_c as f32) * config.width_mult) as usize; + let out_c = ((out_c as f32) * config.width_mult) as usize; + let mid_c = in_c * exp; + + InvertedResidual { + expand_weights: if exp != 1 { + Some(vec![0.0; in_c * mid_c]) + } else { + None + }, + expand_bn: if exp != 1 { + Some(BnParams::new(mid_c)) + } else { + None + }, + dw_weights: vec![0.0; 9 * mid_c], + dw_bn: BnParams::new(mid_c), + se_reduce: if se { + Some(vec![0.0; mid_c * 
(mid_c / 4)]) + } else { + None + }, + se_expand: if se { + Some(vec![0.0; (mid_c / 4) * mid_c]) + } else { + None + }, + project_weights: vec![0.0; mid_c * out_c], + project_bn: BnParams::new(out_c), + in_channels: in_c, + out_channels: out_c, + expansion: exp, + use_se: se, + use_residual: in_c == out_c, + } + }) + .collect() } } impl Backbone for MobileNetV3Small { fn forward(&self, input: &[f32], height: usize, width: usize) -> Vec { // Stem: 3x3 conv, stride 2 - let mut x = layers::conv2d_3x3( - input, - &self.stem_weights, - 3, - 16, - height, - width, - ); + let mut x = layers::conv2d_3x3(input, &self.stem_weights, 3, 16, height, width); x = layers::batch_norm( &x, &self.stem_bn.gamma, @@ -249,41 +267,53 @@ impl MobileNetV3Large { (160, 160, 6, true), ]; - block_configs.iter().map(|&(in_c, out_c, exp, se)| { - let in_c = ((in_c as f32) * config.width_mult) as usize; - let out_c = ((out_c as f32) * config.width_mult) as usize; - let mid_c = in_c * exp; - - InvertedResidual { - expand_weights: if exp != 1 { Some(vec![0.0; in_c * mid_c]) } else { None }, - expand_bn: if exp != 1 { Some(BnParams::new(mid_c)) } else { None }, - dw_weights: vec![0.0; 9 * mid_c], - dw_bn: BnParams::new(mid_c), - se_reduce: if se { Some(vec![0.0; mid_c * (mid_c / 4)]) } else { None }, - se_expand: if se { Some(vec![0.0; (mid_c / 4) * mid_c]) } else { None }, - project_weights: vec![0.0; mid_c * out_c], - project_bn: BnParams::new(out_c), - in_channels: in_c, - out_channels: out_c, - expansion: exp, - use_se: se, - use_residual: in_c == out_c, - } - }).collect() + block_configs + .iter() + .map(|&(in_c, out_c, exp, se)| { + let in_c = ((in_c as f32) * config.width_mult) as usize; + let out_c = ((out_c as f32) * config.width_mult) as usize; + let mid_c = in_c * exp; + + InvertedResidual { + expand_weights: if exp != 1 { + Some(vec![0.0; in_c * mid_c]) + } else { + None + }, + expand_bn: if exp != 1 { + Some(BnParams::new(mid_c)) + } else { + None + }, + dw_weights: vec![0.0; 9 * 
mid_c], + dw_bn: BnParams::new(mid_c), + se_reduce: if se { + Some(vec![0.0; mid_c * (mid_c / 4)]) + } else { + None + }, + se_expand: if se { + Some(vec![0.0; (mid_c / 4) * mid_c]) + } else { + None + }, + project_weights: vec![0.0; mid_c * out_c], + project_bn: BnParams::new(out_c), + in_channels: in_c, + out_channels: out_c, + expansion: exp, + use_se: se, + use_residual: in_c == out_c, + } + }) + .collect() } } impl Backbone for MobileNetV3Large { fn forward(&self, input: &[f32], height: usize, width: usize) -> Vec { // Same structure as Small but with more blocks - let mut x = layers::conv2d_3x3( - input, - &self.stem_weights, - 3, - 16, - height, - width, - ); + let mut x = layers::conv2d_3x3(input, &self.stem_weights, 3, 16, height, width); x = layers::batch_norm( &x, &self.stem_bn.gamma, @@ -316,7 +346,11 @@ impl Backbone for MobileNetV3Large { } /// Process a single inverted residual block - fn process_inverted_residual(input: &[f32], block: &InvertedResidual, in_channels: usize) -> Vec { + fn process_inverted_residual( + input: &[f32], + block: &InvertedResidual, + in_channels: usize, + ) -> Vec { let spatial = input.len() / in_channels; let h = (spatial as f32).sqrt() as usize; let w = h; @@ -365,8 +399,14 @@ impl Backbone for MobileNetV3Large { } } } - x = layers::batch_norm(&dw_out, &block.dw_bn.gamma, &block.dw_bn.beta, - &block.dw_bn.mean, &block.dw_bn.var, 1e-5); + x = layers::batch_norm( + &dw_out, + &block.dw_bn.gamma, + &block.dw_bn.beta, + &block.dw_bn.mean, + &block.dw_bn.var, + 1e-5, + ); x = layers::hard_swish(&x); // SE block (optional) @@ -419,8 +459,14 @@ impl Backbone for MobileNetV3Large { projected[s * out_c + oc] = sum; } } - let output = layers::batch_norm(&projected, &block.project_bn.gamma, &block.project_bn.beta, - &block.project_bn.mean, &block.project_bn.var, 1e-5); + let output = layers::batch_norm( + &projected, + &block.project_bn.gamma, + &block.project_bn.beta, + &block.project_bn.mean, + &block.project_bn.var, + 1e-5, + ); 
// Residual connection if block.use_residual && in_channels == out_c { @@ -438,7 +484,11 @@ impl Backbone for MobileNetV3Large { // Same helper for Small variant impl MobileNetV3Small { /// Process a single inverted residual block - fn process_inverted_residual(input: &[f32], block: &InvertedResidual, in_channels: usize) -> Vec { + fn process_inverted_residual( + input: &[f32], + block: &InvertedResidual, + in_channels: usize, + ) -> Vec { MobileNetV3Large::process_inverted_residual(input, block, in_channels) } } @@ -655,11 +705,8 @@ impl MobileNetV3 { // Last conv: 1x1 to expand features let feature_dim = config.scale_channels(config.feature_dim); - let last_conv = ConvBNActivation::pointwise( - in_channels, - feature_dim, - ActivationType::HardSwish, - )?; + let last_conv = + ConvBNActivation::pointwise(in_channels, feature_dim, ActivationType::HardSwish)?; // Global average pooling let pool = GlobalAvgPool2d::new(); @@ -725,7 +772,11 @@ impl MobileNetV3 { } /// Forward pass through feature layers only. - fn forward_features_impl(&self, input: &[f32], input_shape: &TensorShape) -> CnnResult> { + fn forward_features_impl( + &self, + input: &[f32], + input_shape: &TensorShape, + ) -> CnnResult> { let mut x = input.to_vec(); let mut shape = *input_shape; diff --git a/crates/ruvector-cnn/src/contrastive/augmentation.rs b/crates/ruvector-cnn/src/contrastive/augmentation.rs index f37615af3..374ea763a 100644 --- a/crates/ruvector-cnn/src/contrastive/augmentation.rs +++ b/crates/ruvector-cnn/src/contrastive/augmentation.rs @@ -118,7 +118,13 @@ impl ContrastiveAugmentationBuilder { } /// Set the color jitter parameters. 
- pub fn color_jitter(mut self, brightness: f64, contrast: f64, saturation: f64, hue: f64) -> Self { + pub fn color_jitter( + mut self, + brightness: f64, + contrast: f64, + saturation: f64, + hue: f64, + ) -> Self { self.config.brightness = brightness; self.config.contrast = contrast; self.config.saturation = saturation; @@ -277,10 +283,13 @@ impl ContrastiveAugmentation { // Try up to 10 times to find a valid crop for _ in 0..10 { // Sample scale and aspect ratio - let scale = self.rng.gen_range(self.config.crop_scale_min..=self.config.crop_scale_max); - let aspect = self.rng.gen_range( - self.config.aspect_ratio_min.ln()..=self.config.aspect_ratio_max.ln(), - ).exp(); + let scale = self + .rng + .gen_range(self.config.crop_scale_min..=self.config.crop_scale_max); + let aspect = self + .rng + .gen_range(self.config.aspect_ratio_min.ln()..=self.config.aspect_ratio_max.ln()) + .exp(); // Compute crop dimensions let crop_area = orig_area * scale; @@ -352,9 +361,18 @@ impl ContrastiveAugmentation { let mut result = image.clone(); // Sample jitter factors - let brightness_factor = 1.0 + self.rng.gen_range(-self.config.brightness..=self.config.brightness); - let contrast_factor = 1.0 + self.rng.gen_range(-self.config.contrast..=self.config.contrast); - let saturation_factor = 1.0 + self.rng.gen_range(-self.config.saturation..=self.config.saturation); + let brightness_factor = 1.0 + + self + .rng + .gen_range(-self.config.brightness..=self.config.brightness); + let contrast_factor = 1.0 + + self + .rng + .gen_range(-self.config.contrast..=self.config.contrast); + let saturation_factor = 1.0 + + self + .rng + .gen_range(-self.config.saturation..=self.config.saturation); let hue_shift = self.rng.gen_range(-self.config.hue..=self.config.hue); // Compute image mean for contrast adjustment @@ -363,7 +381,11 @@ impl ContrastiveAugmentation { for y in 0..height { for x in 0..width { let pixel = image.get_pixel(x, y); - let mut rgb = [pixel[0] as f64 / 255.0, pixel[1] as f64 / 
255.0, pixel[2] as f64 / 255.0]; + let mut rgb = [ + pixel[0] as f64 / 255.0, + pixel[1] as f64 / 255.0, + pixel[2] as f64 / 255.0, + ]; // Apply brightness for c in rgb.iter_mut() { @@ -419,14 +441,20 @@ impl ContrastiveAugmentation { /// Gaussian blur (simplified box blur implementation). #[cfg(feature = "augmentation")] pub fn gaussian_blur(&mut self, image: &RgbImage) -> CnnResult { - let sigma = self.rng.gen_range(self.config.blur_sigma_range.0..=self.config.blur_sigma_range.1); + let sigma = self + .rng + .gen_range(self.config.blur_sigma_range.0..=self.config.blur_sigma_range.1); // Use kernel size from config, or compute from sigma let kernel_size = if self.config.blur_kernel_size > 0 { self.config.blur_kernel_size } else { let k = (sigma * 6.0).ceil() as u32; - if k % 2 == 0 { k + 1 } else { k } + if k % 2 == 0 { + k + 1 + } else { + k + } }; // Generate Gaussian kernel @@ -475,17 +503,22 @@ impl ContrastiveAugmentation { for x in 0..width { let mut sum = [0.0, 0.0, 0.0]; for (i, &k) in kernel.iter().enumerate() { - let sx = (x as i32 + i as i32 - radius as i32).clamp(0, width as i32 - 1) as u32; + let sx = + (x as i32 + i as i32 - radius as i32).clamp(0, width as i32 - 1) as u32; let pixel = image.get_pixel(sx, y); sum[0] += pixel[0] as f64 * k; sum[1] += pixel[1] as f64 * k; sum[2] += pixel[2] as f64 * k; } - temp.put_pixel(x, y, Rgb([ - sum[0].clamp(0.0, 255.0) as u8, - sum[1].clamp(0.0, 255.0) as u8, - sum[2].clamp(0.0, 255.0) as u8, - ])); + temp.put_pixel( + x, + y, + Rgb([ + sum[0].clamp(0.0, 255.0) as u8, + sum[1].clamp(0.0, 255.0) as u8, + sum[2].clamp(0.0, 255.0) as u8, + ]), + ); } } @@ -495,17 +528,22 @@ impl ContrastiveAugmentation { for x in 0..width { let mut sum = [0.0, 0.0, 0.0]; for (i, &k) in kernel.iter().enumerate() { - let sy = (y as i32 + i as i32 - radius as i32).clamp(0, height as i32 - 1) as u32; + let sy = + (y as i32 + i as i32 - radius as i32).clamp(0, height as i32 - 1) as u32; let pixel = temp.get_pixel(x, sy); sum[0] += 
pixel[0] as f64 * k; sum[1] += pixel[1] as f64 * k; sum[2] += pixel[2] as f64 * k; } - result.put_pixel(x, y, Rgb([ - sum[0].clamp(0.0, 255.0) as u8, - sum[1].clamp(0.0, 255.0) as u8, - sum[2].clamp(0.0, 255.0) as u8, - ])); + result.put_pixel( + x, + y, + Rgb([ + sum[0].clamp(0.0, 255.0) as u8, + sum[1].clamp(0.0, 255.0) as u8, + sum[2].clamp(0.0, 255.0) as u8, + ]), + ); } } @@ -752,9 +790,27 @@ mod tests { let (h, s, v) = rgb_to_hsv(r, g, b); let (r2, g2, b2) = hsv_to_rgb(h, s, v); - assert!((r - r2).abs() < 1e-6, "R mismatch for ({}, {}, {})", r, g, b); - assert!((g - g2).abs() < 1e-6, "G mismatch for ({}, {}, {})", r, g, b); - assert!((b - b2).abs() < 1e-6, "B mismatch for ({}, {}, {})", r, g, b); + assert!( + (r - r2).abs() < 1e-6, + "R mismatch for ({}, {}, {})", + r, + g, + b + ); + assert!( + (g - g2).abs() < 1e-6, + "G mismatch for ({}, {}, {})", + r, + g, + b + ); + assert!( + (b - b2).abs() < 1e-6, + "B mismatch for ({}, {}, {})", + r, + g, + b + ); } } diff --git a/crates/ruvector-cnn/src/contrastive/infonce.rs b/crates/ruvector-cnn/src/contrastive/infonce.rs index 5abac714d..c9db31f07 100644 --- a/crates/ruvector-cnn/src/contrastive/infonce.rs +++ b/crates/ruvector-cnn/src/contrastive/infonce.rs @@ -142,7 +142,9 @@ impl InfoNCELoss { ) -> CnnResult { let n = embeddings.len(); if n == 0 { - return Err(CnnError::InvalidInput("embeddings cannot be empty".to_string())); + return Err(CnnError::InvalidInput( + "embeddings cannot be empty".to_string(), + )); } if n < 2 { return Err(CnnError::InvalidInput( @@ -258,7 +260,12 @@ impl InfoNCELoss { for i in 0..n { matrix[i][i] = 1.0; // Self-similarity for j in (i + 1)..n { - let sim = cosine_similarity_normalized(&embeddings[i], &embeddings[j], norms[i], norms[j]); + let sim = cosine_similarity_normalized( + &embeddings[i], + &embeddings[j], + norms[i], + norms[j], + ); matrix[i][j] = sim; matrix[j][i] = sim; } @@ -293,7 +300,9 @@ impl InfoNCELoss { } if anchors.is_empty() { - return 
Err(CnnError::InvalidInput("anchors cannot be empty".to_string())); + return Err(CnnError::InvalidInput( + "anchors cannot be empty".to_string(), + )); } let dim = anchors[0].len(); @@ -410,7 +419,10 @@ mod tests { let loss = loss_fn.forward(&embeddings, 2); // Loss should be low for identical pairs - assert!(loss < 5.0, "Loss should be relatively low for identical pairs"); + assert!( + loss < 5.0, + "Loss should be relatively low for identical pairs" + ); } #[test] @@ -433,11 +445,7 @@ mod tests { fn test_similarity_matrix() { let loss_fn = InfoNCELoss::new(0.07); - let embeddings = vec![ - vec![1.0, 0.0], - vec![0.0, 1.0], - vec![1.0, 1.0], - ]; + let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]; let sim_matrix = loss_fn.compute_similarity_matrix(&embeddings); diff --git a/crates/ruvector-cnn/src/contrastive/triplet.rs b/crates/ruvector-cnn/src/contrastive/triplet.rs index 759fc30fc..94477e0da 100644 --- a/crates/ruvector-cnn/src/contrastive/triplet.rs +++ b/crates/ruvector-cnn/src/contrastive/triplet.rs @@ -179,7 +179,11 @@ impl TripletLoss { } // Check for NaN/Inf - for (name, vec) in [("anchor", anchor), ("positive", positive), ("negative", negative)] { + for (name, vec) in [ + ("anchor", anchor), + ("positive", positive), + ("negative", negative), + ] { if vec.iter().any(|x| x.is_nan() || x.is_infinite()) { return Err(CnnError::InvalidInput(format!( "{} contains NaN or Inf", @@ -420,7 +424,9 @@ mod tests { let positive = vec![1.0, 0.0]; // identical let negative = vec![-1.0, 0.0]; // opposite - let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap(); + let result = triplet + .forward_detailed(&anchor, &positive, &negative) + .unwrap(); assert_eq!(result.loss, 0.0); assert!(!result.is_hard); } @@ -434,7 +440,9 @@ mod tests { let positive = vec![2.0, 0.0]; let negative = vec![1.0, 0.0]; - let result = triplet.forward_detailed(&anchor, &positive, &negative).unwrap(); + let result = triplet + .forward_detailed(&anchor, 
&positive, &negative) + .unwrap(); assert!(result.loss > 0.0); assert!(result.is_hard); assert!(result.violates_margin); @@ -490,7 +498,9 @@ mod tests { let positives = vec![vec![0.9, 0.1], vec![0.1, 0.9]]; let negatives = vec![vec![0.0, 1.0], vec![1.0, 0.0]]; - let loss = triplet.forward_batch(&anchors, &positives, &negatives).unwrap(); + let loss = triplet + .forward_batch(&anchors, &positives, &negatives) + .unwrap(); assert!(loss >= 0.0); } diff --git a/crates/ruvector-cnn/src/embedding.rs b/crates/ruvector-cnn/src/embedding.rs index 9c365f530..36ef2e693 100644 --- a/crates/ruvector-cnn/src/embedding.rs +++ b/crates/ruvector-cnn/src/embedding.rs @@ -407,7 +407,9 @@ mod tests { let batch_size = 2; let images = vec![0.5f32; batch_size * 3 * 224 * 224]; - let embeddings = embedder.extract_batch(&images, batch_size, 224, 224).unwrap(); + let embeddings = embedder + .extract_batch(&images, batch_size, 224, 224) + .unwrap(); assert_eq!(embeddings.len(), batch_size); for embedding in &embeddings { @@ -434,7 +436,9 @@ mod tests { #[test] fn test_without_normalization() { - let embedder = MobileNetEmbedder::v3_small().unwrap().without_normalization(); + let embedder = MobileNetEmbedder::v3_small() + .unwrap() + .without_normalization(); assert!(!embedder.is_normalized()); } } diff --git a/crates/ruvector-cnn/src/error.rs b/crates/ruvector-cnn/src/error.rs index 2b7abca0d..ef3f309f1 100644 --- a/crates/ruvector-cnn/src/error.rs +++ b/crates/ruvector-cnn/src/error.rs @@ -102,7 +102,9 @@ pub enum CnnError { NormalizationError(String), /// Invalid kernel configuration. 
- #[error("Invalid kernel: kernel_size={kernel_size}, but input spatial dims are ({height}, {width})")] + #[error( + "Invalid kernel: kernel_size={kernel_size}, but input spatial dims are ({height}, {width})" + )] InvalidKernel { /// Kernel size kernel_size: usize, diff --git a/crates/ruvector-cnn/src/kernels/int8_avx2.rs b/crates/ruvector-cnn/src/kernels/int8_avx2.rs index d9c5017bf..cb063657b 100644 --- a/crates/ruvector-cnn/src/kernels/int8_avx2.rs +++ b/crates/ruvector-cnn/src/kernels/int8_avx2.rs @@ -150,7 +150,9 @@ pub unsafe fn conv2d_int8_avx2( // Load 32 input activations let input_base = (ih * in_w + iw) * in_c + ic_base; - let va = _mm256_loadu_si256(input.as_ptr().add(input_base) as *const __m256i); + let va = _mm256_loadu_si256( + input.as_ptr().add(input_base) as *const __m256i + ); // For each output channel in this chunk for i in 0..8 { @@ -346,8 +348,8 @@ pub unsafe fn depthwise_conv2d_int8_avx2( #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] pub unsafe fn matmul_int8_avx2( - a: &[u8], // M x K (unsigned activations) - b: &[i8], // K x N (signed weights) + a: &[u8], // M x K (unsigned activations) + b: &[i8], // K x N (signed weights) output: &mut [i32], // M x N (int32 accumulators) m: usize, k: usize, @@ -535,7 +537,17 @@ mod tests { if is_x86_feature_detected!("avx2") { unsafe { conv2d_int8_avx2( - &input, 0, &kernel, &bias, &mut output, in_h, in_w, in_c, out_c, 1, 0, + &input, + 0, + &kernel, + &bias, + &mut output, + in_h, + in_w, + in_c, + out_c, + 1, + 0, ); // All outputs should be 10 * 9 = 90 (10 input * 9 weights) diff --git a/crates/ruvector-cnn/src/kernels/int8_neon.rs b/crates/ruvector-cnn/src/kernels/int8_neon.rs index 474770fb1..6c1ccd195 100644 --- a/crates/ruvector-cnn/src/kernels/int8_neon.rs +++ b/crates/ruvector-cnn/src/kernels/int8_neon.rs @@ -168,7 +168,8 @@ pub unsafe fn conv2d_int8_neon( let total = vaddq_s32(sum_low, sum_high); // Horizontal sum - let sum_pair = vpadd_s32(vget_low_s32(total), 
vget_high_s32(total)); + let sum_pair = + vpadd_s32(vget_low_s32(total), vget_high_s32(total)); let sum_final = vpadd_s32(sum_pair, sum_pair); acc[i] += vget_lane_s32(sum_final, 0); } diff --git a/crates/ruvector-cnn/src/kernels/int8_wasm.rs b/crates/ruvector-cnn/src/kernels/int8_wasm.rs index 67fc1f1fc..b731a8076 100644 --- a/crates/ruvector-cnn/src/kernels/int8_wasm.rs +++ b/crates/ruvector-cnn/src/kernels/int8_wasm.rs @@ -143,7 +143,8 @@ pub unsafe fn conv2d_int8_wasm( let input_base = (ih * in_w + iw) * in_c + ic_base; // Load 16 u8 inputs and convert to i8 - let input_u8 = v128_load(input.as_ptr().add(input_base) as *const v128); + let input_u8 = + v128_load(input.as_ptr().add(input_base) as *const v128); let offset = u8x16_splat(128); let input_shifted = u8x16_sub(input_u8, offset); diff --git a/crates/ruvector-cnn/src/layers/batchnorm.rs b/crates/ruvector-cnn/src/layers/batchnorm.rs index bfb221648..a14c9bc43 100644 --- a/crates/ruvector-cnn/src/layers/batchnorm.rs +++ b/crates/ruvector-cnn/src/layers/batchnorm.rs @@ -190,7 +190,8 @@ mod tests { let mut bn = BatchNorm::new(2); // Set mean=[1, 2], var=[1, 4] - bn.set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0]).unwrap(); + bn.set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0]) + .unwrap(); // Input: [[1, 2], [3, 4]] at each spatial location let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 1, 2]).unwrap(); diff --git a/crates/ruvector-cnn/src/layers/conv.rs b/crates/ruvector-cnn/src/layers/conv.rs index dc8e053ab..87b586346 100644 --- a/crates/ruvector-cnn/src/layers/conv.rs +++ b/crates/ruvector-cnn/src/layers/conv.rs @@ -90,18 +90,21 @@ impl Conv2dBuilder { /// Build the Conv2d layer pub fn build(self) -> CnnResult { if self.in_channels % self.groups != 0 { - return Err(CnnError::InvalidParameter( - format!("in_channels {} must be divisible by groups {}", self.in_channels, self.groups) - )); + return Err(CnnError::InvalidParameter(format!( + "in_channels {} must be divisible by groups {}", + 
self.in_channels, self.groups + ))); } if self.out_channels % self.groups != 0 { - return Err(CnnError::InvalidParameter( - format!("out_channels {} must be divisible by groups {}", self.out_channels, self.groups) - )); + return Err(CnnError::InvalidParameter(format!( + "out_channels {} must be divisible by groups {}", + self.out_channels, self.groups + ))); } let in_channels_per_group = self.in_channels / self.groups; - let num_weights = self.out_channels * self.kernel_size * self.kernel_size * in_channels_per_group; + let num_weights = + self.out_channels * self.kernel_size * self.kernel_size * in_channels_per_group; // Xavier/Glorot initialization let fan_in = in_channels_per_group * self.kernel_size * self.kernel_size; @@ -329,7 +332,10 @@ impl Layer for Conv2d { self.stride, self.padding, ); - } else if self.kernel_size == 3 && self.groups == self.in_channels && self.in_channels == self.out_channels { + } else if self.kernel_size == 3 + && self.groups == self.in_channels + && self.in_channels == self.out_channels + { // Depthwise 3x3 convolution (groups == in_channels == out_channels) simd::depthwise_conv_3x3_simd( input_slice, @@ -405,18 +411,15 @@ impl Conv2d { let ih = (oh * self.stride + kh) as isize - self.padding as isize; let iw = (ow * self.stride + kw) as isize - self.padding as isize; - if ih >= 0 - && ih < in_h as isize - && iw >= 0 - && iw < in_w as isize - { + if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize { let ih = ih as usize; let iw = iw as usize; for ic_local in 0..in_channels_per_group { let ic = in_c_start + ic_local; - let input_idx = - ih * in_w * self.in_channels + iw * self.in_channels + ic; + let input_idx = ih * in_w * self.in_channels + + iw * self.in_channels + + ic; // Kernel layout: [out_c, kh, kw, in_c_per_group] let kernel_idx = oc * ks * ks * in_channels_per_group + kh * ks * in_channels_per_group @@ -560,7 +563,8 @@ impl Layer for DepthwiseSeparableConv { for b in 0..batch { let input_slice = &input.data()[b 
* batch_in_size..(b + 1) * batch_in_size]; - let output_slice = &mut dw_output.data_mut()[b * batch_dw_size..(b + 1) * batch_dw_size]; + let output_slice = + &mut dw_output.data_mut()[b * batch_dw_size..(b + 1) * batch_dw_size]; if self.kernel_size == 3 { simd::depthwise_conv_3x3_simd( @@ -636,11 +640,7 @@ impl DepthwiseSeparableConv { let ih = (oh * self.stride + kh) as isize - self.padding as isize; let iw = (ow * self.stride + kw) as isize - self.padding as isize; - if ih >= 0 - && ih < in_h as isize - && iw >= 0 - && iw < in_w as isize - { + if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize { let ih = ih as usize; let iw = iw as usize; diff --git a/crates/ruvector-cnn/src/layers/linear.rs b/crates/ruvector-cnn/src/layers/linear.rs index c733f7df7..2a4b317c6 100644 --- a/crates/ruvector-cnn/src/layers/linear.rs +++ b/crates/ruvector-cnn/src/layers/linear.rs @@ -250,13 +250,8 @@ mod tests { #[test] fn test_linear_forward_with_bias() { - let linear = Linear::with_weights( - 2, - 2, - vec![1.0, 0.0, 0.0, 1.0], - Some(vec![5.0, 10.0]), - ) - .unwrap(); + let linear = + Linear::with_weights(2, 2, vec![1.0, 0.0, 0.0, 1.0], Some(vec![5.0, 10.0])).unwrap(); let input = vec![1.0, 2.0]; let output = linear.forward_vec(&input).unwrap(); @@ -267,13 +262,7 @@ mod tests { #[test] fn test_linear_forward_batch() { - let linear = Linear::with_weights( - 2, - 2, - vec![1.0, 0.0, 0.0, 1.0], - None, - ) - .unwrap(); + let linear = Linear::with_weights(2, 2, vec![1.0, 0.0, 0.0, 1.0], None).unwrap(); let input = vec![1.0, 2.0, 3.0, 4.0]; // batch of 2 let output = linear.forward_batch(&input, 2).unwrap(); diff --git a/crates/ruvector-cnn/src/layers/mod.rs b/crates/ruvector-cnn/src/layers/mod.rs index 35dfa3376..a49ae5888 100644 --- a/crates/ruvector-cnn/src/layers/mod.rs +++ b/crates/ruvector-cnn/src/layers/mod.rs @@ -78,7 +78,13 @@ impl std::fmt::Display for TensorShape { } /// Computes output size for convolution or pooling. 
-pub fn conv_output_size(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize { +pub fn conv_output_size( + input: usize, + kernel: usize, + stride: usize, + padding: usize, + dilation: usize, +) -> usize { let effective_kernel = dilation * (kernel - 1) + 1; (input + 2 * padding - effective_kernel) / stride + 1 } diff --git a/crates/ruvector-cnn/src/layers/pooling.rs b/crates/ruvector-cnn/src/layers/pooling.rs index 6d2fdc9ba..158ee2b3a 100644 --- a/crates/ruvector-cnn/src/layers/pooling.rs +++ b/crates/ruvector-cnn/src/layers/pooling.rs @@ -293,7 +293,7 @@ mod tests { // Create input with known values: channel 0 = 1, channel 1 = 2 let mut data = vec![0.0; 2 * 2 * 2]; for i in 0..4 { - data[i * 2] = 1.0; // channel 0 + data[i * 2] = 1.0; // channel 0 data[i * 2 + 1] = 2.0; // channel 1 } let input = Tensor::from_data(data, &[1, 2, 2, 2]).unwrap(); diff --git a/crates/ruvector-cnn/src/layers/quantized_conv2d.rs b/crates/ruvector-cnn/src/layers/quantized_conv2d.rs index 0395c5e4d..9d7c270eb 100644 --- a/crates/ruvector-cnn/src/layers/quantized_conv2d.rs +++ b/crates/ruvector-cnn/src/layers/quantized_conv2d.rs @@ -6,10 +6,7 @@ //! - Weight packing for SIMD efficiency //! 
- Fused bias and requantization -use crate::{ - simd::quantize::QuantParams, - CnnError, CnnResult, Tensor, -}; +use crate::{simd::quantize::QuantParams, CnnError, CnnResult, Tensor}; use super::{Conv2d, Layer, TensorShape}; @@ -51,11 +48,7 @@ impl QuantizedConv2d { /// * `conv` - FP32 convolution layer to quantize /// * `input_scale` - Expected input activation scale /// * `input_zero_point` - Expected input zero point - pub fn from_fp32( - conv: &Conv2d, - input_scale: f32, - input_zero_point: i32, - ) -> Self { + pub fn from_fp32(conv: &Conv2d, input_scale: f32, input_zero_point: i32) -> Self { let out_c = conv.out_channels(); let in_c = conv.in_channels(); let ks = conv.kernel_size(); @@ -99,7 +92,8 @@ impl QuantizedConv2d { } // Pre-compute bias in i32 accumulator space - let bias_f32 = conv.bias() + let bias_f32 = conv + .bias() .map(|b| b.to_vec()) .unwrap_or_else(|| vec![0.0; out_c]); let mut bias_q = vec![0i32; out_c]; @@ -147,7 +141,7 @@ impl QuantizedConv2d { if input_shape.len() != 4 { return Err(CnnError::invalid_shape( "4D input (NHWC)", - format!("{}D", input_shape.len()) + format!("{}D", input_shape.len()), )); } @@ -159,7 +153,7 @@ impl QuantizedConv2d { if in_c != self.in_channels { return Err(CnnError::invalid_shape( format!("{} input channels", self.in_channels), - format!("{} channels", in_c) + format!("{} channels", in_c), )); } @@ -185,7 +179,10 @@ impl QuantizedConv2d { input_slice, input_zero_point as i32, output_slice, - in_h, in_w, out_h, out_w, + in_h, + in_w, + out_h, + out_w, ); } } else { @@ -193,7 +190,10 @@ impl QuantizedConv2d { input_slice, input_zero_point as i32, output_slice, - in_h, in_w, out_h, out_w, + in_h, + in_w, + out_h, + out_w, ); } } @@ -204,7 +204,10 @@ impl QuantizedConv2d { input_slice, input_zero_point as i32, output_slice, - in_h, in_w, out_h, out_w, + in_h, + in_w, + out_h, + out_w, ); } } @@ -212,10 +215,7 @@ impl QuantizedConv2d { // Dequantize i32 accumulator to f32 let output_f32 = 
self.dequantize_output(&output_i32, input_scale); - Tensor::from_data( - output_f32, - &[batch, out_h, out_w, self.out_channels], - ) + Tensor::from_data(output_f32, &[batch, out_h, out_w, self.out_channels]) } /// Scalar INT8 convolution implementation @@ -264,9 +264,11 @@ impl QuantizedConv2d { for ic in 0..self.in_channels { let input_idx = (ih * in_w + iw) * self.in_channels + ic; - let weight_idx = (oc * self.in_channels + ic) * ks * ks + kh * ks + kw; + let weight_idx = + (oc * self.in_channels + ic) * ks * ks + kh * ks + kw; - acc += (input[input_idx] as i32) * (self.weights_q[weight_idx] as i32); + acc += (input[input_idx] as i32) + * (self.weights_q[weight_idx] as i32); } } } diff --git a/crates/ruvector-cnn/src/layers/quantized_depthwise.rs b/crates/ruvector-cnn/src/layers/quantized_depthwise.rs index 7f523069c..6887108e8 100644 --- a/crates/ruvector-cnn/src/layers/quantized_depthwise.rs +++ b/crates/ruvector-cnn/src/layers/quantized_depthwise.rs @@ -67,11 +67,7 @@ impl QuantizedDepthwiseConv2d { max_abs = max_abs.max(weights[idx].abs()); } } - weight_scales[c] = if max_abs > 0.0 { - max_abs / 127.0 - } else { - 1.0 - }; + weight_scales[c] = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 }; } // Quantize weights @@ -88,7 +84,9 @@ impl QuantizedDepthwiseConv2d { } // Pre-compute bias in i32 accumulator space - let bias_f32 = bias.map(|b| b.to_vec()).unwrap_or_else(|| vec![0.0; channels]); + let bias_f32 = bias + .map(|b| b.to_vec()) + .unwrap_or_else(|| vec![0.0; channels]); let mut bias_q = vec![0i32; channels]; for c in 0..channels { @@ -129,7 +127,7 @@ impl QuantizedDepthwiseConv2d { if input_shape.len() != 4 { return Err(CnnError::invalid_shape( "4D input (NHWC)", - format!("{}D", input_shape.len()) + format!("{}D", input_shape.len()), )); } @@ -141,7 +139,7 @@ impl QuantizedDepthwiseConv2d { if in_c != self.channels { return Err(CnnError::invalid_shape( format!("{} channels", self.channels), - format!("{} channels", in_c) + format!("{} channels", 
in_c), )); } @@ -162,17 +160,17 @@ impl QuantizedDepthwiseConv2d { input_slice, input_zero_point as i32, output_slice, - in_h, in_w, out_h, out_w, + in_h, + in_w, + out_h, + out_w, ); } // Dequantize to f32 let output_f32 = self.dequantize_output(&output_i32, input_scale); - Tensor::from_data( - output_f32, - &[batch, out_h, out_w, self.channels], - ) + Tensor::from_data(output_f32, &[batch, out_h, out_w, self.channels]) } /// Scalar depthwise convolution implementation @@ -221,7 +219,8 @@ impl QuantizedDepthwiseConv2d { let input_idx = (ih * in_w + iw) * self.channels + c; let weight_idx = c * ks * ks + kh * ks + kw; - acc += (input[input_idx] as i32) * (self.weights_q[weight_idx] as i32); + acc += + (input[input_idx] as i32) * (self.weights_q[weight_idx] as i32); } } } @@ -277,15 +276,8 @@ mod tests { let kernel_size = 3; let weights = vec![0.1f32; channels * kernel_size * kernel_size]; - let qconv = QuantizedDepthwiseConv2d::from_fp32( - channels, - kernel_size, - &weights, - None, - 1, - 1, - 0.01, - ); + let qconv = + QuantizedDepthwiseConv2d::from_fp32(channels, kernel_size, &weights, None, 1, 1, 0.01); let input = vec![128u8; 1 * 8 * 8 * channels]; let input_shape = &[1, 8, 8, channels]; diff --git a/crates/ruvector-cnn/src/layers/quantized_linear.rs b/crates/ruvector-cnn/src/layers/quantized_linear.rs index 6735c211c..cf6c5dd0d 100644 --- a/crates/ruvector-cnn/src/layers/quantized_linear.rs +++ b/crates/ruvector-cnn/src/layers/quantized_linear.rs @@ -52,11 +52,7 @@ impl QuantizedLinear { let idx = of * in_features + if_; max_abs = max_abs.max(weights[idx].abs()); } - weight_scales[of] = if max_abs > 0.0 { - max_abs / 127.0 - } else { - 1.0 - }; + weight_scales[of] = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 }; } // Quantize weights @@ -71,7 +67,8 @@ impl QuantizedLinear { } // Pre-compute bias in i32 accumulator space - let bias_f32 = linear.bias() + let bias_f32 = linear + .bias() .map(|b| b.to_vec()) .unwrap_or_else(|| vec![0.0; out_features]); let 
mut bias_q = vec![0i32; out_features]; @@ -112,7 +109,7 @@ impl QuantizedLinear { if input.len() != batch_size * self.in_features { return Err(CnnError::invalid_shape( format!("input size {}", batch_size * self.in_features), - format!("size {}", input.len()) + format!("size {}", input.len()), )); } diff --git a/crates/ruvector-cnn/src/layers/quantized_pooling.rs b/crates/ruvector-cnn/src/layers/quantized_pooling.rs index d1ac2ed51..4bd02e9bd 100644 --- a/crates/ruvector-cnn/src/layers/quantized_pooling.rs +++ b/crates/ruvector-cnn/src/layers/quantized_pooling.rs @@ -44,7 +44,7 @@ impl QuantizedMaxPool2d { if input_shape.len() != 4 { return Err(CnnError::invalid_shape( "4D input (NHWC)", - format!("{}D", input_shape.len()) + format!("{}D", input_shape.len()), )); } @@ -85,7 +85,12 @@ impl QuantizedMaxPool2d { } } - Ok((output, vec![batch, out_h, out_w, channels], scale, zero_point)) + Ok(( + output, + vec![batch, out_h, out_w, channels], + scale, + zero_point, + )) } } @@ -130,7 +135,7 @@ impl QuantizedAvgPool2d { if input_shape.len() != 4 { return Err(CnnError::invalid_shape( "4D input (NHWC)", - format!("{}D", input_shape.len()) + format!("{}D", input_shape.len()), )); } @@ -184,12 +189,15 @@ impl QuantizedAvgPool2d { } // Convert i16 back to u8 - let output: Vec = output_i16.iter() - .map(|&v| v.clamp(0, 255) as u8) - .collect(); + let output: Vec = output_i16.iter().map(|&v| v.clamp(0, 255) as u8).collect(); // Output scale remains the same as input for average pooling - Ok((output, vec![batch, out_h, out_w, channels], input_scale, input_zero_point)) + Ok(( + output, + vec![batch, out_h, out_w, channels], + input_scale, + input_zero_point, + )) } } @@ -202,14 +210,12 @@ mod tests { let pool = QuantizedMaxPool2d::new(2, 2, 0); let input = vec![ - 100, 150, 200, 255, - 120, 180, 210, 230, - 110, 140, 190, 240, - 130, 160, 220, 250, + 100, 150, 200, 255, 120, 180, 210, 230, 110, 140, 190, 240, 130, 160, 220, 250, ]; let input_shape = &[1, 4, 4, 1]; - let (output, 
output_shape, scale, _zp) = pool.forward_int8(&input, input_shape, 0.01, 0).unwrap(); + let (output, output_shape, scale, _zp) = + pool.forward_int8(&input, input_shape, 0.01, 0).unwrap(); assert_eq!(output_shape, vec![1, 2, 2, 1]); assert_eq!(scale, 0.01); @@ -223,14 +229,12 @@ mod tests { let pool = QuantizedAvgPool2d::new(2, 2, 0); let input = vec![ - 100, 100, 200, 200, - 100, 100, 200, 200, - 100, 100, 200, 200, - 100, 100, 200, 200, + 100, 100, 200, 200, 100, 100, 200, 200, 100, 100, 200, 200, 100, 100, 200, 200, ]; let input_shape = &[1, 4, 4, 1]; - let (output, output_shape, scale, _zp) = pool.forward_int8(&input, input_shape, 0.01, 0).unwrap(); + let (output, output_shape, scale, _zp) = + pool.forward_int8(&input, input_shape, 0.01, 0).unwrap(); assert_eq!(output_shape, vec![1, 2, 2, 1]); assert_eq!(scale, 0.01); @@ -247,7 +251,8 @@ mod tests { let input = vec![100u8; 1 * 4 * 4 * 1]; let input_shape = &[1, 4, 4, 1]; - let (_output, output_shape, _, _) = pool.forward_int8(&input, input_shape, 0.01, 50).unwrap(); + let (_output, output_shape, _, _) = + pool.forward_int8(&input, input_shape, 0.01, 50).unwrap(); assert_eq!(output_shape, vec![1, 4, 4, 1]); } diff --git a/crates/ruvector-cnn/src/layers/quantized_residual.rs b/crates/ruvector-cnn/src/layers/quantized_residual.rs index 13c084ee9..0f8c2a24d 100644 --- a/crates/ruvector-cnn/src/layers/quantized_residual.rs +++ b/crates/ruvector-cnn/src/layers/quantized_residual.rs @@ -67,7 +67,7 @@ impl QuantizedResidualAdd { if input1.len() != input2.len() { return Err(CnnError::invalid_shape( format!("input size {}", input1.len()), - format!("size {}", input2.len()) + format!("size {}", input2.len()), )); } @@ -89,7 +89,9 @@ impl QuantizedResidualAdd { let sum = val1 + val2; // Requantize to output - let output_q = (sum + self.output_zero_point as f32).round().clamp(0.0, 255.0); + let output_q = (sum + self.output_zero_point as f32) + .round() + .clamp(0.0, 255.0); output[i] = output_q as u8; } @@ -112,7 +114,7 @@ 
impl QuantizedResidualAdd { if input1.len() != input2.len() { return Err(CnnError::invalid_shape( format!("input size {}", input1.len()), - format!("size {}", input2.len()) + format!("size {}", input2.len()), )); } diff --git a/crates/ruvector-cnn/src/lib.rs b/crates/ruvector-cnn/src/lib.rs index c08bf7435..1e0bdc94b 100644 --- a/crates/ruvector-cnn/src/lib.rs +++ b/crates/ruvector-cnn/src/lib.rs @@ -38,13 +38,13 @@ mod error; mod tensor; // Core modules (always available) +pub mod kernels; pub mod layers; pub mod simd; -pub mod kernels; // Quantization support (INT8 optimization) -pub mod quantize; pub mod int8; +pub mod quantize; // Optional modules (require backbone feature due to API incompatibility) #[cfg(feature = "backbone")] @@ -61,19 +61,16 @@ pub use tensor::Tensor; // Re-export backbone types (only when feature enabled) #[cfg(feature = "backbone")] pub use backbone::{ - Backbone, BackboneExt, BackboneType, - MobileNetV3, MobileNetV3Config, - MobileNetV3Small, MobileNetV3Large, MobileNetConfig, - ConvBNActivation, InvertedResidual, SqueezeExcitation, - create_backbone, mobilenet_v3_small, mobilenet_v3_large, + create_backbone, mobilenet_v3_large, mobilenet_v3_small, Backbone, BackboneExt, BackboneType, + ConvBNActivation, InvertedResidual, MobileNetConfig, MobileNetV3, MobileNetV3Config, + MobileNetV3Large, MobileNetV3Small, SqueezeExcitation, }; // Re-export embedding types (only when feature enabled) #[cfg(feature = "backbone")] pub use embedding::{ - MobileNetEmbedder, EmbeddingExtractorExt, - EmbeddingConfig as MobileNetEmbeddingConfig, - cosine_similarity, euclidean_distance, + cosine_similarity, euclidean_distance, EmbeddingConfig as MobileNetEmbeddingConfig, + EmbeddingExtractorExt, MobileNetEmbedder, }; // ParallelEmbedding requires the `parallel` feature (not yet implemented) @@ -180,7 +177,10 @@ impl CnnEmbedder { if image_data.len() != expected_size { return Err(CnnError::InvalidInput(format!( "Expected {} bytes for {}x{} RGBA image, got {}", - 
expected_size, width, height, image_data.len() + expected_size, + width, + height, + image_data.len() ))); } @@ -336,7 +336,8 @@ mod tests { embedding_dim: 8, normalize: true, quantized: false, - }).unwrap(); + }) + .unwrap(); let image_data = vec![128u8; 4 * 4 * 4]; let embedding = embedder.extract(&image_data, 4, 4).unwrap(); diff --git a/crates/ruvector-cnn/src/quantize/graph_rewrite.rs b/crates/ruvector-cnn/src/quantize/graph_rewrite.rs index 8d96e4a64..c13cc95b9 100644 --- a/crates/ruvector-cnn/src/quantize/graph_rewrite.rs +++ b/crates/ruvector-cnn/src/quantize/graph_rewrite.rs @@ -329,9 +329,19 @@ pub fn insert_qdq_nodes( ); // Reconnect: input → Q → node - graph.nodes.get_mut(&input_id).unwrap().outputs.retain(|&x| x != node_id); + graph + .nodes + .get_mut(&input_id) + .unwrap() + .outputs + .retain(|&x| x != node_id); graph.nodes.get_mut(&input_id).unwrap().outputs.push(q_id); - graph.nodes.get_mut(&node_id).unwrap().inputs.retain(|&x| x != input_id); + graph + .nodes + .get_mut(&node_id) + .unwrap() + .inputs + .retain(|&x| x != input_id); graph.nodes.get_mut(&node_id).unwrap().inputs.push(q_id); graph.nodes.get_mut(&q_id).unwrap().inputs.push(input_id); graph.nodes.get_mut(&q_id).unwrap().outputs.push(node_id); @@ -363,9 +373,19 @@ pub fn insert_qdq_nodes( ); // Reconnect: node → DQ → output - graph.nodes.get_mut(&node_id).unwrap().outputs.retain(|&x| x != output_id); + graph + .nodes + .get_mut(&node_id) + .unwrap() + .outputs + .retain(|&x| x != output_id); graph.nodes.get_mut(&node_id).unwrap().outputs.push(dq_id); - graph.nodes.get_mut(&output_id).unwrap().inputs.retain(|&x| x != node_id); + graph + .nodes + .get_mut(&output_id) + .unwrap() + .inputs + .retain(|&x| x != node_id); graph.nodes.get_mut(&output_id).unwrap().inputs.push(dq_id); graph.nodes.get_mut(&dq_id).unwrap().inputs.push(node_id); graph.nodes.get_mut(&dq_id).unwrap().outputs.push(output_id); @@ -741,13 +761,16 @@ mod tests { let mut graph = ComputationGraph::new(); let id1 = 
graph.add_node(NodeType::Input, NodeParams::None); - let id2 = graph.add_node(NodeType::Conv2d, NodeParams::Conv2d { - weights: vec![1.0; 4], - bias: None, - in_channels: 1, - out_channels: 1, - kernel_size: 2, - }); + let id2 = graph.add_node( + NodeType::Conv2d, + NodeParams::Conv2d { + weights: vec![1.0; 4], + bias: None, + in_channels: 1, + out_channels: 1, + kernel_size: 2, + }, + ); let id3 = graph.add_node(NodeType::Output, NodeParams::None); graph.connect(id1, id2); diff --git a/crates/ruvector-cnn/src/quantize/mod.rs b/crates/ruvector-cnn/src/quantize/mod.rs index 2dbdb3783..c6bbaba5d 100644 --- a/crates/ruvector-cnn/src/quantize/mod.rs +++ b/crates/ruvector-cnn/src/quantize/mod.rs @@ -21,13 +21,12 @@ pub mod calibration; pub mod graph_rewrite; // Phase 1 exports -pub use params::{QuantizationParams as QuantParams, QuantizationScheme, QuantizationMode}; -pub use tensor::{QuantizedTensor, QuantizationMetadata}; +pub use params::{QuantizationMode, QuantizationParams as QuantParams, QuantizationScheme}; +pub use tensor::{QuantizationMetadata, QuantizedTensor}; // Existing exports (kept for backward compatibility) pub use calibration::{CalibrationHistogram, QuantizationParams, Quantizer}; pub use graph_rewrite::{ - ComputationGraph, GraphNode, NodeParams, NodeType, - fuse_batchnorm_to_conv, fuse_relu, fuse_hardswish, fuse_zp_to_bias, - generate_hardswish_lut, insert_qdq_nodes, + fuse_batchnorm_to_conv, fuse_hardswish, fuse_relu, fuse_zp_to_bias, generate_hardswish_lut, + insert_qdq_nodes, ComputationGraph, GraphNode, NodeParams, NodeType, }; diff --git a/crates/ruvector-cnn/src/quantize/params.rs b/crates/ruvector-cnn/src/quantize/params.rs index 1d61ad381..5e3cebbe6 100644 --- a/crates/ruvector-cnn/src/quantize/params.rs +++ b/crates/ruvector-cnn/src/quantize/params.rs @@ -85,7 +85,7 @@ impl QuantizationParams { // Asymmetric: Map [min_val, max_val] to [-127, 127] (255 bins) // to maintain compatibility with i8 storage let scale = if max_val > min_val { - 
(max_val - min_val) / 254.0 // Use 254 to avoid clipping at edges + (max_val - min_val) / 254.0 // Use 254 to avoid clipping at edges } else { 1.0 }; @@ -230,8 +230,8 @@ mod tests { #[test] fn test_symmetric_minmax() { - let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric) - .unwrap(); + let params = + QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap(); assert_eq!(params.zero_point, 0); assert!(params.scale > 0.0); @@ -244,8 +244,8 @@ mod tests { #[test] fn test_asymmetric_minmax() { - let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric) - .unwrap(); + let params = + QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric).unwrap(); // For [0, 10] range, zero_point should map 0.0 to a quantized value assert!(params.scale > 0.0); @@ -257,8 +257,8 @@ mod tests { #[test] fn test_quantize_dequantize_symmetric() { - let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric) - .unwrap(); + let params = + QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap(); let value = 5.0f32; let quantized = params.quantize_value(value); @@ -270,8 +270,8 @@ mod tests { #[test] fn test_quantize_dequantize_asymmetric() { - let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric) - .unwrap(); + let params = + QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric).unwrap(); let value = 5.0f32; let quantized = params.quantize_value(value); @@ -282,8 +282,8 @@ mod tests { #[test] fn test_zero_value_quantization() { - let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric) - .unwrap(); + let params = + QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap(); let quantized = params.quantize_value(0.0); assert_eq!(quantized, 0); @@ -294,8 +294,8 @@ mod tests { #[test] fn test_clipping() { - let params = 
QuantizationParams::from_minmax(-1.0, 1.0, QuantizationMode::Symmetric) - .unwrap(); + let params = + QuantizationParams::from_minmax(-1.0, 1.0, QuantizationMode::Symmetric).unwrap(); // Values outside range should be clipped let large = params.quantize_value(1000.0); @@ -313,8 +313,8 @@ mod tests { #[test] fn test_percentile_constructor() { - let params = QuantizationParams::from_percentile(-9.5, 9.5, QuantizationMode::Symmetric) - .unwrap(); + let params = + QuantizationParams::from_percentile(-9.5, 9.5, QuantizationMode::Symmetric).unwrap(); assert_eq!(params.zero_point, 0); params.validate().unwrap(); @@ -322,8 +322,8 @@ mod tests { #[test] fn test_validation_negative_scale() { - let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric) - .unwrap(); + let mut params = + QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap(); params.scale = -1.0; assert!(params.validate().is_err()); @@ -331,8 +331,8 @@ mod tests { #[test] fn test_validation_zero_scale() { - let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric) - .unwrap(); + let mut params = + QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap(); params.scale = 0.0; assert!(params.validate().is_err()); @@ -340,8 +340,8 @@ mod tests { #[test] fn test_validation_invalid_qmin_qmax() { - let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric) - .unwrap(); + let mut params = + QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap(); params.qmin = 127; params.qmax = -127; @@ -350,8 +350,8 @@ mod tests { #[test] fn test_validation_zero_point_out_of_range() { - let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric) - .unwrap(); + let mut params = + QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap(); params.zero_point = 200; assert!(params.validate().is_err()); diff 
--git a/crates/ruvector-cnn/src/quantize/tensor.rs b/crates/ruvector-cnn/src/quantize/tensor.rs index 2655c9d13..36a827bde 100644 --- a/crates/ruvector-cnn/src/quantize/tensor.rs +++ b/crates/ruvector-cnn/src/quantize/tensor.rs @@ -3,8 +3,8 @@ //! This module provides type-safe INT8 tensors with quantization metadata //! for efficient neural network inference. -use crate::error::{CnnError, CnnResult}; use super::params::QuantizationParams; +use crate::error::{CnnError, CnnResult}; use serde::{Deserialize, Serialize}; /// Metadata for a quantized tensor. @@ -49,13 +49,13 @@ impl QuantizationMetadata { if self.shape.is_empty() { return Err(CnnError::QuantizationError( - "shape cannot be empty".to_string() + "shape cannot be empty".to_string(), )); } if self.shape.iter().any(|&d| d == 0) { return Err(CnnError::QuantizationError( - "shape dimensions must be positive".to_string() + "shape dimensions must be positive".to_string(), )); } @@ -164,11 +164,7 @@ impl QuantizedTensor { .map(|&val| params.quantize_value(val)) .collect(); - let metadata = QuantizationMetadata::new( - params.scale, - params.zero_point, - shape.to_vec(), - ); + let metadata = QuantizationMetadata::new(params.scale, params.zero_point, shape.to_vec()); Ok(Self { data, metadata }) } @@ -195,7 +191,8 @@ impl QuantizedTensor { qmax: 127, }; - let fp32_data: Vec = self.data + let fp32_data: Vec = self + .data .iter() .map(|&val| params.dequantize_value(val)) .collect(); @@ -264,7 +261,7 @@ impl QuantizedTensor { // INV-3: Bounds check if !self.check_bounds(-127, 127) { return Err(CnnError::QuantizationError( - "INV-3 violation: some values outside [-127, 127]".to_string() + "INV-3 violation: some values outside [-127, 127]".to_string(), )); } @@ -421,8 +418,8 @@ mod tests { fn test_asymmetric_quantization() { let fp32_data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; let shape = vec![6]; - let params = QuantizationParams::from_minmax(0.0, 5.0, QuantizationMode::Asymmetric) - .unwrap(); + let params = + 
QuantizationParams::from_minmax(0.0, 5.0, QuantizationMode::Asymmetric).unwrap(); let quantized = QuantizedTensor::quantize(&fp32_data, &shape, ¶ms).unwrap(); assert!(quantized.validate().is_ok()); @@ -435,7 +432,10 @@ mod tests { assert!( error < 0.6, "Value mismatch at index {}: original={}, restored={}, error={}", - i, original, restored, error + i, + original, + restored, + error ); } } diff --git a/crates/ruvector-cnn/src/simd/avx2.rs b/crates/ruvector-cnn/src/simd/avx2.rs index 445e31b4e..1c11d4783 100644 --- a/crates/ruvector-cnn/src/simd/avx2.rs +++ b/crates/ruvector-cnn/src/simd/avx2.rs @@ -11,29 +11,29 @@ use std::arch::x86_64::*; #[target_feature(enable = "avx2", enable = "fma")] pub unsafe fn dot_product_avx2_fma(a: &[f32], b: &[f32]) -> f32 { debug_assert_eq!(a.len(), b.len()); - + let mut sum = _mm256_setzero_ps(); let chunks = a.len() / 8; - + for i in 0..chunks { let va = _mm256_loadu_ps(a.as_ptr().add(i * 8)); let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8)); sum = _mm256_fmadd_ps(va, vb, sum); } - + // Horizontal sum let sum128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1)); let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128)); let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1)); - + let mut result = 0.0f32; _mm_store_ss(&mut result, sum32); - + // Handle remainder for i in (chunks * 8)..a.len() { result += a[i] * b[i]; } - + result } @@ -42,30 +42,30 @@ pub unsafe fn dot_product_avx2_fma(a: &[f32], b: &[f32]) -> f32 { #[target_feature(enable = "avx2")] pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 { debug_assert_eq!(a.len(), b.len()); - + let mut sum = _mm256_setzero_ps(); let chunks = a.len() / 8; - + for i in 0..chunks { let va = _mm256_loadu_ps(a.as_ptr().add(i * 8)); let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8)); let prod = _mm256_mul_ps(va, vb); sum = _mm256_add_ps(sum, prod); } - + // Horizontal sum let sum128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), 
_mm256_extractf128_ps(sum, 1)); let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128)); let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1)); - + let mut result = 0.0f32; _mm_store_ss(&mut result, sum32); - + // Handle remainder for i in (chunks * 8)..a.len() { result += a[i] * b[i]; } - + result } @@ -74,23 +74,23 @@ pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 { #[target_feature(enable = "avx512f")] pub unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 { debug_assert_eq!(a.len(), b.len()); - + let mut sum = _mm512_setzero_ps(); let chunks = a.len() / 16; - + for i in 0..chunks { let va = _mm512_loadu_ps(a.as_ptr().add(i * 16)); let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16)); sum = _mm512_fmadd_ps(va, vb, sum); } - + let mut result = _mm512_reduce_add_ps(sum); - + // Handle remainder for i in (chunks * 16)..a.len() { result += a[i] * b[i]; } - + result } @@ -99,16 +99,16 @@ pub unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 { #[target_feature(enable = "avx2")] pub unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) { debug_assert_eq!(input.len(), output.len()); - + let zero = _mm256_setzero_ps(); let chunks = input.len() / 8; - + for i in 0..chunks { let v = _mm256_loadu_ps(input.as_ptr().add(i * 8)); let result = _mm256_max_ps(v, zero); _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), result); } - + // Handle remainder for i in (chunks * 8)..input.len() { output[i] = input[i].max(0.0); @@ -120,17 +120,17 @@ pub unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) { #[target_feature(enable = "avx2")] pub unsafe fn relu6_avx2(input: &[f32], output: &mut [f32]) { debug_assert_eq!(input.len(), output.len()); - + let zero = _mm256_setzero_ps(); let six = _mm256_set1_ps(6.0); let chunks = input.len() / 8; - + for i in 0..chunks { let v = _mm256_loadu_ps(input.as_ptr().add(i * 8)); let result = _mm256_min_ps(_mm256_max_ps(v, zero), six); _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), result); } - + // Handle 
remainder for i in (chunks * 8)..input.len() { output[i] = input[i].max(0.0).min(6.0); @@ -151,24 +151,24 @@ pub unsafe fn batch_norm_avx2( channels: usize, ) { debug_assert_eq!(input.len(), output.len()); - + // Pre-compute scale and shift for each channel let mut scale = vec![0.0f32; channels]; let mut shift = vec![0.0f32; channels]; - + for c in 0..channels { let inv_std = 1.0 / (var[c] + epsilon).sqrt(); scale[c] = gamma[c] * inv_std; shift[c] = beta[c] - mean[c] * scale[c]; } - + let spatial = input.len() / channels; - + // Process 8 spatial positions at a time if channels == 8 if channels == 8 { let scale_v = _mm256_loadu_ps(scale.as_ptr()); let shift_v = _mm256_loadu_ps(shift.as_ptr()); - + for s in 0..spatial { let offset = s * channels; let v = _mm256_loadu_ps(input.as_ptr().add(offset)); @@ -250,10 +250,14 @@ pub unsafe fn conv_3x3_avx2_fma( let ic_base = ic_chunk_idx * 4; // Load 4 input values and broadcast each - let input_val0 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base)); - let input_val1 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 1)); - let input_val2 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 2)); - let input_val3 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 3)); + let input_val0 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base)); + let input_val1 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 1)); + let input_val2 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 2)); + let input_val3 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 3)); // Gather 8 kernel weights for each of the 4 input channels let mut kv0 = [0.0f32; 8]; @@ -263,10 +267,18 @@ pub unsafe fn conv_3x3_avx2_fma( for i in 0..8 { let oc_idx = oc_base + i; - kv0[i] = *kernel.get_unchecked((oc_idx * in_c + ic_base) * 9 + kernel_offset); - kv1[i] = *kernel.get_unchecked((oc_idx * in_c + ic_base + 1) * 9 + kernel_offset); - kv2[i] = *kernel.get_unchecked((oc_idx * in_c 
+ ic_base + 2) * 9 + kernel_offset); - kv3[i] = *kernel.get_unchecked((oc_idx * in_c + ic_base + 3) * 9 + kernel_offset); + kv0[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base) * 9 + kernel_offset, + ); + kv1[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base + 1) * 9 + kernel_offset, + ); + kv2[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base + 2) * 9 + kernel_offset, + ); + kv3[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base + 3) * 9 + kernel_offset, + ); } let kernel_v0 = _mm256_loadu_ps(kv0.as_ptr()); @@ -283,11 +295,14 @@ pub unsafe fn conv_3x3_avx2_fma( // Handle remainder input channels (0-3 channels) for ic in ic_remainder_start..in_c { - let input_val = _mm256_set1_ps(*input.get_unchecked(input_base + ic)); + let input_val = + _mm256_set1_ps(*input.get_unchecked(input_base + ic)); let mut kernel_vals = [0.0f32; 8]; for i in 0..8 { - kernel_vals[i] = *kernel.get_unchecked(((oc_base + i) * in_c + ic) * 9 + kernel_offset); + kernel_vals[i] = *kernel.get_unchecked( + ((oc_base + i) * in_c + ic) * 9 + kernel_offset, + ); } let kernel_v = _mm256_loadu_ps(kernel_vals.as_ptr()); @@ -389,10 +404,14 @@ pub unsafe fn conv_3x3_avx2( for ic_chunk_idx in 0..ic_chunks { let ic_base = ic_chunk_idx * 4; - let input_val0 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base)); - let input_val1 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 1)); - let input_val2 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 2)); - let input_val3 = _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 3)); + let input_val0 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base)); + let input_val1 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 1)); + let input_val2 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 2)); + let input_val3 = + _mm256_set1_ps(*input.get_unchecked(input_base + ic_base + 3)); let mut kv0 = [0.0f32; 8]; let mut kv1 = [0.0f32; 8]; @@ -401,10 +420,18 @@ pub unsafe fn 
conv_3x3_avx2( for i in 0..8 { let oc_idx = oc_base + i; - kv0[i] = *kernel.get_unchecked((oc_idx * in_c + ic_base) * 9 + kernel_offset); - kv1[i] = *kernel.get_unchecked((oc_idx * in_c + ic_base + 1) * 9 + kernel_offset); - kv2[i] = *kernel.get_unchecked((oc_idx * in_c + ic_base + 2) * 9 + kernel_offset); - kv3[i] = *kernel.get_unchecked((oc_idx * in_c + ic_base + 3) * 9 + kernel_offset); + kv0[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base) * 9 + kernel_offset, + ); + kv1[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base + 1) * 9 + kernel_offset, + ); + kv2[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base + 2) * 9 + kernel_offset, + ); + kv3[i] = *kernel.get_unchecked( + (oc_idx * in_c + ic_base + 3) * 9 + kernel_offset, + ); } let kernel_v0 = _mm256_loadu_ps(kv0.as_ptr()); @@ -421,11 +448,14 @@ pub unsafe fn conv_3x3_avx2( // Remainder input channels for ic in ic_remainder_start..in_c { - let input_val = _mm256_set1_ps(*input.get_unchecked(input_base + ic)); + let input_val = + _mm256_set1_ps(*input.get_unchecked(input_base + ic)); let mut kernel_vals = [0.0f32; 8]; for i in 0..8 { - kernel_vals[i] = *kernel.get_unchecked(((oc_base + i) * in_c + ic) * 9 + kernel_offset); + kernel_vals[i] = *kernel.get_unchecked( + ((oc_base + i) * in_c + ic) * 9 + kernel_offset, + ); } let kernel_v = _mm256_loadu_ps(kernel_vals.as_ptr()); @@ -521,7 +551,10 @@ pub unsafe fn depthwise_conv_3x3_avx2( if iw >= 0 && iw < w as isize { let input_base = (ih0 * w + iw as usize) * c + c_base; let input_v = _mm256_loadu_ps(input.as_ptr().add(input_base)); - sum_row0 = _mm256_add_ps(sum_row0, _mm256_mul_ps(input_v, kernel_cache[0][kw])); + sum_row0 = _mm256_add_ps( + sum_row0, + _mm256_mul_ps(input_v, kernel_cache[0][kw]), + ); } } } @@ -535,7 +568,10 @@ pub unsafe fn depthwise_conv_3x3_avx2( if iw >= 0 && iw < w as isize { let input_base = (ih1 * w + iw as usize) * c + c_base; let input_v = _mm256_loadu_ps(input.as_ptr().add(input_base)); - sum_row1 = 
_mm256_add_ps(sum_row1, _mm256_mul_ps(input_v, kernel_cache[1][kw])); + sum_row1 = _mm256_add_ps( + sum_row1, + _mm256_mul_ps(input_v, kernel_cache[1][kw]), + ); } } } @@ -549,7 +585,10 @@ pub unsafe fn depthwise_conv_3x3_avx2( if iw >= 0 && iw < w as isize { let input_base = (ih2 * w + iw as usize) * c + c_base; let input_v = _mm256_loadu_ps(input.as_ptr().add(input_base)); - sum_row2 = _mm256_add_ps(sum_row2, _mm256_mul_ps(input_v, kernel_cache[2][kw])); + sum_row2 = _mm256_add_ps( + sum_row2, + _mm256_mul_ps(input_v, kernel_cache[2][kw]), + ); } } } @@ -589,7 +628,13 @@ pub unsafe fn depthwise_conv_3x3_avx2( /// Averages over H*W spatial dimensions, processing 8 channels at a time. #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] -pub unsafe fn global_avg_pool_avx2(input: &[f32], output: &mut [f32], h: usize, w: usize, c: usize) { +pub unsafe fn global_avg_pool_avx2( + input: &[f32], + output: &mut [f32], + h: usize, + w: usize, + c: usize, +) { let spatial = h * w; let c_chunks = c / 8; let inv_spatial = _mm256_set1_ps(1.0 / spatial as f32); @@ -682,24 +727,87 @@ pub unsafe fn max_pool_2x2_avx2( // Non-x86_64 stubs to allow compilation #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn dot_product_avx2_fma(_a: &[f32], _b: &[f32]) -> f32 { 0.0 } +pub unsafe fn dot_product_avx2_fma(_a: &[f32], _b: &[f32]) -> f32 { + 0.0 +} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn dot_product_avx2(_a: &[f32], _b: &[f32]) -> f32 { 0.0 } +pub unsafe fn dot_product_avx2(_a: &[f32], _b: &[f32]) -> f32 { + 0.0 +} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn dot_product_avx512(_a: &[f32], _b: &[f32]) -> f32 { 0.0 } +pub unsafe fn dot_product_avx512(_a: &[f32], _b: &[f32]) -> f32 { + 0.0 +} #[cfg(not(target_arch = "x86_64"))] pub unsafe fn relu_avx2(_input: &[f32], _output: &mut [f32]) {} #[cfg(not(target_arch = "x86_64"))] pub unsafe fn relu6_avx2(_input: &[f32], _output: &mut [f32]) {} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn 
batch_norm_avx2(_input: &[f32], _output: &mut [f32], _gamma: &[f32], _beta: &[f32], _mean: &[f32], _var: &[f32], _epsilon: f32, _channels: usize) {} +pub unsafe fn batch_norm_avx2( + _input: &[f32], + _output: &mut [f32], + _gamma: &[f32], + _beta: &[f32], + _mean: &[f32], + _var: &[f32], + _epsilon: f32, + _channels: usize, +) { +} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn conv_3x3_avx2_fma(_input: &[f32], _kernel: &[f32], _output: &mut [f32], _in_h: usize, _in_w: usize, _in_c: usize, _out_c: usize, _stride: usize, _padding: usize) {} +pub unsafe fn conv_3x3_avx2_fma( + _input: &[f32], + _kernel: &[f32], + _output: &mut [f32], + _in_h: usize, + _in_w: usize, + _in_c: usize, + _out_c: usize, + _stride: usize, + _padding: usize, +) { +} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn conv_3x3_avx2(_input: &[f32], _kernel: &[f32], _output: &mut [f32], _in_h: usize, _in_w: usize, _in_c: usize, _out_c: usize, _stride: usize, _padding: usize) {} +pub unsafe fn conv_3x3_avx2( + _input: &[f32], + _kernel: &[f32], + _output: &mut [f32], + _in_h: usize, + _in_w: usize, + _in_c: usize, + _out_c: usize, + _stride: usize, + _padding: usize, +) { +} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn depthwise_conv_3x3_avx2(_input: &[f32], _kernel: &[f32], _output: &mut [f32], _h: usize, _w: usize, _c: usize, _stride: usize, _padding: usize) {} +pub unsafe fn depthwise_conv_3x3_avx2( + _input: &[f32], + _kernel: &[f32], + _output: &mut [f32], + _h: usize, + _w: usize, + _c: usize, + _stride: usize, + _padding: usize, +) { +} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn global_avg_pool_avx2(_input: &[f32], _output: &mut [f32], _h: usize, _w: usize, _c: usize) {} +pub unsafe fn global_avg_pool_avx2( + _input: &[f32], + _output: &mut [f32], + _h: usize, + _w: usize, + _c: usize, +) { +} #[cfg(not(target_arch = "x86_64"))] -pub unsafe fn max_pool_2x2_avx2(_input: &[f32], _output: &mut [f32], _h: usize, _w: usize, _c: usize, _stride: usize) {} +pub unsafe fn 
max_pool_2x2_avx2( + _input: &[f32], + _output: &mut [f32], + _h: usize, + _w: usize, + _c: usize, + _stride: usize, +) { +} diff --git a/crates/ruvector-cnn/src/simd/mod.rs b/crates/ruvector-cnn/src/simd/mod.rs index 9b0868de9..5a9c39de2 100644 --- a/crates/ruvector-cnn/src/simd/mod.rs +++ b/crates/ruvector-cnn/src/simd/mod.rs @@ -21,12 +21,13 @@ pub mod wasm; // Re-export the dispatch functions pub use avx2::*; -pub use scalar::*; -pub use winograd::{conv_3x3_winograd, transform_filter, transform_input, transform_output, WinogradFilterCache}; pub use quantize::{ - QuantParams, QuantizedTensor, QuantizationType, PerChannelQuantParams, - quantize_simd, dequantize_simd, quantize_batch, dequantize_batch, - pi_constants, + dequantize_batch, dequantize_simd, pi_constants, quantize_batch, quantize_simd, + PerChannelQuantParams, QuantParams, QuantizationType, QuantizedTensor, +}; +pub use scalar::*; +pub use winograd::{ + conv_3x3_winograd, transform_filter, transform_input, transform_output, WinogradFilterCache, }; /// SIMD-accelerated dot product with automatic architecture dispatch @@ -55,7 +56,11 @@ pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 { wasm::dot_product_wasm(a, b) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32" + )))] { scalar::dot_product_scalar(a, b) } @@ -83,7 +88,11 @@ pub fn relu_simd(input: &[f32], output: &mut [f32]) { wasm::relu_wasm(input, output) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32" + )))] { scalar::relu_scalar(input, output) } @@ -111,7 +120,11 @@ pub fn relu6_simd(input: &[f32], output: &mut [f32]) { wasm::relu6_wasm(input, output) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + 
target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32" + )))] { scalar::relu6_scalar(input, output) } @@ -132,7 +145,9 @@ pub fn batch_norm_simd( #[cfg(target_arch = "x86_64")] { if is_x86_feature_detected!("avx2") { - unsafe { avx2::batch_norm_avx2(input, output, gamma, beta, mean, var, epsilon, channels) } + unsafe { + avx2::batch_norm_avx2(input, output, gamma, beta, mean, var, epsilon, channels) + } } else { scalar::batch_norm_scalar(input, output, gamma, beta, mean, var, epsilon, channels) } @@ -148,7 +163,11 @@ pub fn batch_norm_simd( wasm::batch_norm_wasm(input, output, gamma, beta, mean, var, epsilon, channels) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32" + )))] { scalar::batch_norm_scalar(input, output, gamma, beta, mean, var, epsilon, channels) } @@ -171,32 +190,48 @@ pub fn conv_3x3_simd( { if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { unsafe { - avx2::conv_3x3_avx2_fma(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding) + avx2::conv_3x3_avx2_fma( + input, kernel, output, in_h, in_w, in_c, out_c, stride, padding, + ) } } else if is_x86_feature_detected!("avx2") { unsafe { - avx2::conv_3x3_avx2(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding) + avx2::conv_3x3_avx2( + input, kernel, output, in_h, in_w, in_c, out_c, stride, padding, + ) } } else { - scalar::conv_3x3_scalar(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding) + scalar::conv_3x3_scalar( + input, kernel, output, in_h, in_w, in_c, out_c, stride, padding, + ) } } #[cfg(target_arch = "aarch64")] { unsafe { - neon::conv_3x3_neon(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding) + neon::conv_3x3_neon( + input, kernel, output, in_h, in_w, in_c, out_c, stride, padding, + ) } } #[cfg(target_arch = "wasm32")] { - wasm::conv_3x3_wasm(input, kernel, output, 
in_h, in_w, in_c, out_c, stride, padding) + wasm::conv_3x3_wasm( + input, kernel, output, in_h, in_w, in_c, out_c, stride, padding, + ) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32" + )))] { - scalar::conv_3x3_scalar(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding) + scalar::conv_3x3_scalar( + input, kernel, output, in_h, in_w, in_c, out_c, stride, padding, + ) } } @@ -215,7 +250,9 @@ pub fn depthwise_conv_3x3_simd( #[cfg(target_arch = "x86_64")] { if is_x86_feature_detected!("avx2") { - unsafe { avx2::depthwise_conv_3x3_avx2(input, kernel, output, h, w, c, stride, padding) } + unsafe { + avx2::depthwise_conv_3x3_avx2(input, kernel, output, h, w, c, stride, padding) + } } else { scalar::depthwise_conv_3x3_scalar(input, kernel, output, h, w, c, stride, padding) } @@ -231,7 +268,11 @@ pub fn depthwise_conv_3x3_simd( wasm::depthwise_conv_3x3_wasm(input, kernel, output, h, w, c, stride, padding) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32" + )))] { scalar::depthwise_conv_3x3_scalar(input, kernel, output, h, w, c, stride, padding) } @@ -259,7 +300,11 @@ pub fn global_avg_pool_simd(input: &[f32], output: &mut [f32], h: usize, w: usiz wasm::global_avg_pool_wasm(input, output, h, w, c) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32" + )))] { scalar::global_avg_pool_scalar(input, output, h, w, c) } @@ -294,7 +339,11 @@ pub fn max_pool_2x2_simd( wasm::max_pool_2x2_wasm(input, output, h, w, c, stride) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))] + #[cfg(not(any( + target_arch = "x86_64", + 
target_arch = "aarch64", + target_arch = "wasm32" + )))] { scalar::max_pool_2x2_scalar(input, output, h, w, c, stride) } diff --git a/crates/ruvector-cnn/src/simd/neon.rs b/crates/ruvector-cnn/src/simd/neon.rs index 9b2442692..7e0b5d05d 100644 --- a/crates/ruvector-cnn/src/simd/neon.rs +++ b/crates/ruvector-cnn/src/simd/neon.rs @@ -379,7 +379,13 @@ pub unsafe fn depthwise_conv_3x3_neon( /// NEON global average pooling #[cfg(target_arch = "aarch64")] #[inline] -pub unsafe fn global_avg_pool_neon(input: &[f32], output: &mut [f32], h: usize, w: usize, c: usize) { +pub unsafe fn global_avg_pool_neon( + input: &[f32], + output: &mut [f32], + h: usize, + w: usize, + c: usize, +) { let spatial_size = h * w; let inv_spatial = 1.0 / spatial_size as f32; let inv_spatial_v = vdupq_n_f32(inv_spatial); diff --git a/crates/ruvector-cnn/src/simd/quantize.rs b/crates/ruvector-cnn/src/simd/quantize.rs index fa9f2089f..c406da11d 100644 --- a/crates/ruvector-cnn/src/simd/quantize.rs +++ b/crates/ruvector-cnn/src/simd/quantize.rs @@ -238,7 +238,8 @@ impl QuantizedTensor { kernel_h: usize, kernel_w: usize, ) -> Self { - let per_channel = PerChannelQuantParams::symmetric_per_channel(weights, out_channels, in_channels); + let per_channel = + PerChannelQuantParams::symmetric_per_channel(weights, out_channels, in_channels); let kernel_size = kernel_h * kernel_w; let mut quantized = Vec::with_capacity(weights.len()); @@ -366,7 +367,8 @@ pub unsafe fn quantize_batch_avx2(input: &[f32], output: &mut [i8], params: &Qua // Handle remainder let remainder_start = chunks * 8; for i in remainder_start..len { - let scaled = input[i] / params.scale + params.zero_point as f32 + params.anti_resonance * 0.5; + let scaled = + input[i] / params.scale + params.zero_point as f32 + params.anti_resonance * 0.5; output[i] = scaled.round().clamp(-128.0, 127.0) as i8; } } diff --git a/crates/ruvector-cnn/src/simd/wasm.rs b/crates/ruvector-cnn/src/simd/wasm.rs index 7b142ddd2..069641f43 100644 --- 
a/crates/ruvector-cnn/src/simd/wasm.rs +++ b/crates/ruvector-cnn/src/simd/wasm.rs @@ -311,7 +311,8 @@ pub fn depthwise_conv_3x3_wasm( } unsafe { - let input_v = v128_load(input[input_base..].as_ptr() as *const v128); + let input_v = + v128_load(input[input_base..].as_ptr() as *const v128); let kernel_v = v128_load(kernel_vals.as_ptr() as *const v128); let prod = f32x4_mul(input_v, kernel_v); @@ -532,13 +533,7 @@ pub fn depthwise_conv_3x3_wasm( } #[cfg(not(target_arch = "wasm32"))] -pub fn global_avg_pool_wasm( - _input: &[f32], - _output: &mut [f32], - _h: usize, - _w: usize, - _c: usize, -) { +pub fn global_avg_pool_wasm(_input: &[f32], _output: &mut [f32], _h: usize, _w: usize, _c: usize) { unimplemented!("WASM SIMD not available on this architecture") } diff --git a/crates/ruvector-cnn/tests/acceptance_gates.rs b/crates/ruvector-cnn/tests/acceptance_gates.rs index 993c8cc2d..767535a57 100644 --- a/crates/ruvector-cnn/tests/acceptance_gates.rs +++ b/crates/ruvector-cnn/tests/acceptance_gates.rs @@ -10,7 +10,7 @@ //! - GATE-6: WASM build succeeds (placeholder) //! 
- GATE-7: CI pipeline passes (placeholder) -use ruvector_cnn::int8::{QuantParams, quantize_tensor, dequantize_tensor}; +use ruvector_cnn::int8::{dequantize_tensor, quantize_tensor, QuantParams}; #[cfg(test)] mod acceptance_gates { @@ -58,19 +58,22 @@ mod acceptance_gates { assert!( params.scale.is_finite(), "GATE-1 FAILED ({}): Scale is not finite: {}", - name, params.scale + name, + params.scale ); assert!( params.scale > 0.0, "GATE-1 FAILED ({}): Scale must be positive: {}", - name, params.scale + name, + params.scale ); // Validate zero_point assert!( params.zero_point >= -128 && params.zero_point <= 127, "GATE-1 FAILED ({}): Zero point {} out of range [-128, 127]", - name, params.zero_point + name, + params.zero_point ); println!( @@ -100,9 +103,7 @@ mod acceptance_gates { for (name, size) in test_cases { let mut rng = fastrand::Rng::with_seed(42 + size); - let fp32: Vec = (0..size) - .map(|_| rng.f32() * 2.0 - 1.0) - .collect(); + let fp32: Vec = (0..size).map(|_| rng.f32() * 2.0 - 1.0).collect(); let params = QuantParams::from_tensor(&fp32); let int8 = quantize_tensor(&fp32, ¶ms); @@ -267,7 +268,8 @@ mod acceptance_gates { .map(|_| { let u1 = rng.f32(); let u2 = rng.f32(); - ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()) * 0.5 + ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()) + * 0.5 }) .collect() }), @@ -278,7 +280,13 @@ mod acceptance_gates { generator: Box::new(|size| { let mut rng = fastrand::Rng::with_seed(789); (0..size) - .map(|_| if rng.f32() < 0.9 { 0.0 } else { rng.f32() * 2.0 - 1.0 }) + .map(|_| { + if rng.f32() < 0.9 { + 0.0 + } else { + rng.f32() * 2.0 - 1.0 + } + }) .collect() }), min_similarity: 0.990, @@ -338,15 +346,12 @@ mod acceptance_gates { let mut rng = fastrand::Rng::with_seed(42); let calibration_batch: Vec> = (0..batch_size) - .map(|_| { - (0..embedding_size) - .map(|_| rng.f32() * 2.0 - 1.0) - .collect() - }) + .map(|_| (0..embedding_size).map(|_| rng.f32() * 2.0 - 1.0).collect()) .collect(); 
// Flatten for global calibration - let flattened: Vec = calibration_batch.iter() + let flattened: Vec = calibration_batch + .iter() .flat_map(|v| v.iter().copied()) .collect(); @@ -372,17 +377,11 @@ mod acceptance_gates { min_similarity = min_similarity.min(similarity); if similarity < 0.99 { - println!( - "⚠ Batch item {} has lower similarity: {:.6}", - i, similarity - ); + println!("⚠ Batch item {} has lower similarity: {:.6}", i, similarity); } } - println!( - "✓ Minimum similarity across batch: {:.6}", - min_similarity - ); + println!("✓ Minimum similarity across batch: {:.6}", min_similarity); assert!( min_similarity >= 0.99, diff --git a/crates/ruvector-cnn/tests/backbone_test.rs b/crates/ruvector-cnn/tests/backbone_test.rs index 41eb54e90..4804a81d2 100644 --- a/crates/ruvector-cnn/tests/backbone_test.rs +++ b/crates/ruvector-cnn/tests/backbone_test.rs @@ -11,9 +11,8 @@ #![cfg(feature = "backbone")] use ruvector_cnn::backbone::{ - Backbone, BackboneExt, BackboneType, create_backbone, Layer, - MobileNetV3, MobileNetV3Config, - MobileNetConfig, MobileNetV3Small, MobileNetV3Large, + create_backbone, Backbone, BackboneExt, BackboneType, Layer, MobileNetConfig, MobileNetV3, + MobileNetV3Config, MobileNetV3Large, MobileNetV3Small, }; use ruvector_cnn::layers::TensorShape; @@ -181,7 +180,9 @@ fn test_mobilenet_v3_forward_features() { let input_shape = TensorShape::new(1, 3, 224, 224); let input = vec![0.5; input_shape.numel()]; - let output = model.forward_features(&input, &input_shape).expect("Forward failed"); + let output = model + .forward_features(&input, &input_shape) + .expect("Forward failed"); // Output should be [batch, feature_dim] assert_eq!(output.len(), 576); @@ -194,7 +195,9 @@ fn test_mobilenet_v3_forward_with_classifier() { let input_shape = TensorShape::new(1, 3, 224, 224); let input = vec![0.5; input_shape.numel()]; - let output = model.forward_with_shape(&input, &input_shape).expect("Forward failed"); + let output = model + 
.forward_with_shape(&input, &input_shape) + .expect("Forward failed"); // Output should be [batch, num_classes] assert_eq!(output.len(), 1000); @@ -208,7 +211,9 @@ fn test_mobilenet_v3_forward_batch() { let input_shape = TensorShape::new(batch_size, 3, 224, 224); let input = vec![0.5; input_shape.numel()]; - let output = model.forward_features(&input, &input_shape).expect("Forward failed"); + let output = model + .forward_features(&input, &input_shape) + .expect("Forward failed"); // Output should be [batch, feature_dim] assert_eq!(output.len(), batch_size * 576); @@ -220,8 +225,12 @@ fn test_mobilenet_v3_forward_deterministic() { let input_shape = TensorShape::new(1, 3, 224, 224); let input = vec![0.5; input_shape.numel()]; - let output1 = model.forward_features(&input, &input_shape).expect("Forward failed"); - let output2 = model.forward_features(&input, &input_shape).expect("Forward failed"); + let output1 = model + .forward_features(&input, &input_shape) + .expect("Forward failed"); + let output2 = model + .forward_features(&input, &input_shape) + .expect("Forward failed"); // Same input should produce same output (no randomness in inference) for (v1, v2) in output1.iter().zip(output2.iter()) { @@ -235,8 +244,8 @@ fn test_mobilenet_v3_forward_deterministic() { #[test] fn test_create_backbone_small() { - let backbone = create_backbone(BackboneType::MobileNetV3Small, 1000) - .expect("Failed to create backbone"); + let backbone = + create_backbone(BackboneType::MobileNetV3Small, 1000).expect("Failed to create backbone"); assert_eq!(backbone.backbone_type(), BackboneType::MobileNetV3Small); assert_eq!(backbone.output_dim(), 576); @@ -244,8 +253,8 @@ fn test_create_backbone_small() { #[test] fn test_create_backbone_large() { - let backbone = create_backbone(BackboneType::MobileNetV3Large, 1000) - .expect("Failed to create backbone"); + let backbone = + create_backbone(BackboneType::MobileNetV3Large, 1000).expect("Failed to create backbone"); 
assert_eq!(backbone.backbone_type(), BackboneType::MobileNetV3Large); assert_eq!(backbone.output_dim(), 960); @@ -253,8 +262,8 @@ fn test_create_backbone_large() { #[test] fn test_create_backbone_feature_extraction() { - let backbone = create_backbone(BackboneType::MobileNetV3Small, 0) - .expect("Failed to create backbone"); + let backbone = + create_backbone(BackboneType::MobileNetV3Small, 0).expect("Failed to create backbone"); assert_eq!(backbone.output_dim(), 576); } @@ -303,8 +312,8 @@ fn test_mobilenet_v3_last_conv() { #[test] fn test_feature_output_shape() { - let backbone = create_backbone(BackboneType::MobileNetV3Small, 0) - .expect("Failed to create backbone"); + let backbone = + create_backbone(BackboneType::MobileNetV3Small, 0).expect("Failed to create backbone"); let input_shape = TensorShape::new(1, 3, 224, 224); let output_shape = backbone.feature_output_shape(&input_shape); @@ -317,8 +326,8 @@ fn test_feature_output_shape() { #[test] fn test_feature_output_shape_batch() { - let backbone = create_backbone(BackboneType::MobileNetV3Small, 0) - .expect("Failed to create backbone"); + let backbone = + create_backbone(BackboneType::MobileNetV3Small, 0).expect("Failed to create backbone"); let input_shape = TensorShape::new(4, 3, 224, 224); let output_shape = backbone.feature_output_shape(&input_shape); @@ -337,7 +346,9 @@ fn test_backbone_forward_all_zeros() { let input_shape = TensorShape::new(1, 3, 224, 224); let input = vec![0.0; input_shape.numel()]; - let output = model.forward_features(&input, &input_shape).expect("Forward failed"); + let output = model + .forward_features(&input, &input_shape) + .expect("Forward failed"); // Output should be valid (all finite) assert!(output.iter().all(|x| x.is_finite())); @@ -349,7 +360,9 @@ fn test_backbone_forward_all_ones() { let input_shape = TensorShape::new(1, 3, 224, 224); let input = vec![1.0; input_shape.numel()]; - let output = model.forward_features(&input, &input_shape).expect("Forward failed"); + let 
output = model + .forward_features(&input, &input_shape) + .expect("Forward failed"); assert!(output.iter().all(|x| x.is_finite())); } @@ -360,7 +373,9 @@ fn test_backbone_forward_negative_values() { let input_shape = TensorShape::new(1, 3, 224, 224); let input = vec![-0.5; input_shape.numel()]; - let output = model.forward_features(&input, &input_shape).expect("Forward failed"); + let output = model + .forward_features(&input, &input_shape) + .expect("Forward failed"); assert!(output.iter().all(|x| x.is_finite())); } diff --git a/crates/ruvector-cnn/tests/contrastive_test.rs b/crates/ruvector-cnn/tests/contrastive_test.rs index c6912b183..0f62edb1d 100644 --- a/crates/ruvector-cnn/tests/contrastive_test.rs +++ b/crates/ruvector-cnn/tests/contrastive_test.rs @@ -7,7 +7,7 @@ //! - Training pipeline integration use ruvector_cnn::contrastive::{ - AugmentationConfig, ContrastiveAugmentation, InfoNCELoss, TripletLoss, TripletDistance, + AugmentationConfig, ContrastiveAugmentation, InfoNCELoss, TripletDistance, TripletLoss, }; // ============================================================================ @@ -75,7 +75,10 @@ fn test_infonce_temperature_effect() { // Both should be valid assert!(low_temp_loss.is_finite(), "Low temp loss should be finite"); - assert!(high_temp_loss.is_finite(), "High temp loss should be finite"); + assert!( + high_temp_loss.is_finite(), + "High temp loss should be finite" + ); } #[test] @@ -87,11 +90,7 @@ fn test_infonce_many_negatives() { for i in 0..10 { let angle = (i as f64) * 0.5; embeddings.push(vec![angle.cos(), angle.sin(), 0.0]); - embeddings.push(vec![ - (angle + 0.1).cos(), - (angle + 0.1).sin(), - 0.1, - ]); + embeddings.push(vec![(angle + 0.1).cos(), (angle + 0.1).sin(), 0.1]); } let loss = loss_fn.forward(&embeddings, 2); @@ -128,7 +127,10 @@ fn test_infonce_detailed_results() { // Self-similarity should be 1.0 for i in 0..4 { - assert!((sim_matrix[i][i] - 1.0).abs() < 1e-6, "Self-similarity should be 1.0"); + assert!( + 
(sim_matrix[i][i] - 1.0).abs() < 1e-6, + "Self-similarity should be 1.0" + ); } } @@ -140,7 +142,9 @@ fn test_infonce_forward_with_pairs() { let positives = vec![vec![0.9, 0.1, 0.0], vec![0.1, 0.9, 0.0]]; - let loss = loss_fn.forward_with_pairs(&anchors, &positives, None).unwrap(); + let loss = loss_fn + .forward_with_pairs(&anchors, &positives, None) + .unwrap(); assert!(loss > 0.0); assert!(loss.is_finite()); @@ -175,10 +179,7 @@ fn test_triplet_loss_zero_case() { let loss = loss_fn.forward(&anchor, &positive, &negative); - assert_eq!( - loss, 0.0, - "Loss should be zero when margin is satisfied" - ); + assert_eq!(loss, 0.0, "Loss should be zero when margin is satisfied"); } #[test] @@ -222,7 +223,9 @@ fn test_triplet_loss_batch() { let positives = vec![vec![0.9, 0.1], vec![0.1, 0.9]]; let negatives = vec![vec![-1.0, 0.0], vec![0.0, -1.0]]; - let loss = loss_fn.forward_batch(&anchors, &positives, &negatives).unwrap(); + let loss = loss_fn + .forward_batch(&anchors, &positives, &negatives) + .unwrap(); assert!(loss >= 0.0); assert!(loss.is_finite()); @@ -236,7 +239,9 @@ fn test_triplet_loss_detailed() { let positive = vec![1.0, 0.0]; let negative = vec![0.5, 0.0]; // Closer to anchor than positive - let result = loss_fn.forward_detailed(&anchor, &positive, &negative).unwrap(); + let result = loss_fn + .forward_detailed(&anchor, &positive, &negative) + .unwrap(); assert!(result.loss > 0.0); assert!(result.is_hard); @@ -290,7 +295,10 @@ fn test_augmentation_with_seed() { // Same seed should produce same config values assert_eq!(aug1.config().crop_scale_min, aug2.config().crop_scale_min); assert_eq!(aug1.config().crop_scale_max, aug2.config().crop_scale_max); - assert_eq!(aug1.config().horizontal_flip_prob, aug2.config().horizontal_flip_prob); + assert_eq!( + aug1.config().horizontal_flip_prob, + aug2.config().horizontal_flip_prob + ); } #[test] @@ -346,7 +354,10 @@ fn test_infonce_with_normalized_embeddings() { let loss = loss_fn.forward(&normalized, 2); - 
assert!(loss.is_finite(), "Loss with normalized vectors should be finite"); + assert!( + loss.is_finite(), + "Loss with normalized vectors should be finite" + ); assert!(loss > 0.0, "Loss should be positive"); } @@ -382,10 +393,10 @@ fn test_triplet_mine_hard_triplets() { // Create embeddings where hard triplets exist // Class 0 embeddings are close to class 1 embeddings, creating hard triplets let embeddings = vec![ - vec![1.0f64, 0.0], // class 0 - vec![0.95, 0.05], // class 0 - close to anchor - vec![0.9, 0.1], // class 1 - close to class 0 - vec![0.85, 0.15], // class 1 - also close + vec![1.0f64, 0.0], // class 0 + vec![0.95, 0.05], // class 0 - close to anchor + vec![0.9, 0.1], // class 1 - close to class 0 + vec![0.85, 0.15], // class 1 - also close ]; let labels = vec![0, 0, 1, 1]; @@ -393,8 +404,14 @@ fn test_triplet_mine_hard_triplets() { // Verify triplet structure for any hard triplets found for (a, p, n) in &hard_triplets { - assert_eq!(labels[*a], labels[*p], "anchor and positive should be same class"); - assert_ne!(labels[*a], labels[*n], "anchor and negative should be different class"); + assert_eq!( + labels[*a], labels[*p], + "anchor and positive should be same class" + ); + assert_ne!( + labels[*a], labels[*n], + "anchor and negative should be different class" + ); } // Note: depending on the margin and embeddings, hard triplets may or may not be found diff --git a/crates/ruvector-cnn/tests/graph_rewrite_integration.rs b/crates/ruvector-cnn/tests/graph_rewrite_integration.rs index 0f1e8244b..95d25f469 100644 --- a/crates/ruvector-cnn/tests/graph_rewrite_integration.rs +++ b/crates/ruvector-cnn/tests/graph_rewrite_integration.rs @@ -1,9 +1,9 @@ //! 
Integration tests for graph rewrite passes (ADR-091 Phase 3) use ruvector_cnn::quantize::{ - CalibrationHistogram, ComputationGraph, NodeParams, NodeType, QuantizationParams, - fuse_batchnorm_to_conv, fuse_hardswish, fuse_relu, fuse_zp_to_bias, - generate_hardswish_lut, insert_qdq_nodes, + fuse_batchnorm_to_conv, fuse_hardswish, fuse_relu, fuse_zp_to_bias, generate_hardswish_lut, + insert_qdq_nodes, CalibrationHistogram, ComputationGraph, NodeParams, NodeType, + QuantizationParams, }; use std::collections::HashMap; diff --git a/crates/ruvector-cnn/tests/integration_test.rs b/crates/ruvector-cnn/tests/integration_test.rs index 1762e5371..483c6d91a 100644 --- a/crates/ruvector-cnn/tests/integration_test.rs +++ b/crates/ruvector-cnn/tests/integration_test.rs @@ -14,8 +14,7 @@ use ruvector_cnn::{CnnEmbedder, EmbeddingConfig, EmbeddingExtractor}; #[test] fn test_cnn_embedder_creation() { - let embedder = CnnEmbedder::new(EmbeddingConfig::default()) - .expect("Failed to create embedder"); + let embedder = CnnEmbedder::new(EmbeddingConfig::default()).expect("Failed to create embedder"); assert_eq!(embedder.embedding_dim(), 512); assert_eq!(embedder.input_size(), 224); @@ -23,16 +22,14 @@ fn test_cnn_embedder_creation() { #[test] fn test_cnn_embedder_v3_small() { - let embedder = CnnEmbedder::new_v3_small() - .expect("Failed to create V3 Small embedder"); + let embedder = CnnEmbedder::new_v3_small().expect("Failed to create V3 Small embedder"); assert_eq!(embedder.embedding_dim(), 576); } #[test] fn test_cnn_embedder_v3_large() { - let embedder = CnnEmbedder::new_v3_large() - .expect("Failed to create V3 Large embedder"); + let embedder = CnnEmbedder::new_v3_large().expect("Failed to create V3 Large embedder"); assert_eq!(embedder.embedding_dim(), 960); } @@ -63,9 +60,7 @@ fn test_image_to_embedding_pipeline() { let embedder = CnnEmbedder::new(config).expect("Failed to create embedder"); // Create a test image (RGBA format, 64x64) - let image: Vec = (0..(64 * 64 * 4)) - 
.map(|i| (i % 256) as u8) - .collect(); + let image: Vec = (0..(64 * 64 * 4)).map(|i| (i % 256) as u8).collect(); let embedding = embedder.extract(&image, 64, 64).expect("Extraction failed"); @@ -73,7 +68,10 @@ fn test_image_to_embedding_pipeline() { assert_eq!(embedding.len(), 128, "Embedding dimension mismatch"); // Verify no NaN or Inf - assert!(embedding.iter().all(|x| x.is_finite()), "Embedding contains non-finite values"); + assert!( + embedding.iter().all(|x| x.is_finite()), + "Embedding contains non-finite values" + ); } #[test] @@ -106,7 +104,8 @@ fn test_different_image_sizes() { embedding_dim: 32, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Test with 64x64 image let image_64 = vec![128u8; 64 * 64 * 4]; @@ -121,15 +120,14 @@ fn test_grayscale_vs_color_images() { embedding_dim: 16, normalize: false, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Uniform gray image (all pixels same value) let gray_image: Vec = vec![128; 32 * 32 * 4]; // Colorful image (varying pixels) - let color_image: Vec = (0..(32 * 32 * 4)) - .map(|i| ((i * 37) % 256) as u8) - .collect(); + let color_image: Vec = (0..(32 * 32 * 4)).map(|i| ((i * 37) % 256) as u8).collect(); let emb_gray = embedder.extract(&gray_image, 32, 32).expect("Failed"); let emb_color = embedder.extract(&color_image, 32, 32).expect("Failed"); @@ -139,11 +137,16 @@ fn test_grayscale_vs_color_images() { assert_eq!(emb_color.len(), 16); // They should be different - let diff_count = emb_gray.iter().zip(emb_color.iter()) + let diff_count = emb_gray + .iter() + .zip(emb_color.iter()) .filter(|(a, b)| (*a - *b).abs() > 1e-10) .count(); - assert!(diff_count > 0, "Different images should produce different embeddings"); + assert!( + diff_count > 0, + "Different images should produce different embeddings" + ); } // 
============================================================================ @@ -157,16 +160,15 @@ fn test_similar_images_similar_embeddings() { embedding_dim: 32, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Two similar images (same content, slight variation) let image1: Vec = vec![128; 32 * 32 * 4]; let image2: Vec = vec![130; 32 * 32 * 4]; // Slightly brighter // Very different image - let image3: Vec = (0..(32 * 32 * 4)) - .map(|i| ((i * 37) % 256) as u8) - .collect(); + let image3: Vec = (0..(32 * 32 * 4)).map(|i| ((i * 37) % 256) as u8).collect(); let emb1 = embedder.extract(&image1, 32, 32).expect("Failed"); let emb2 = embedder.extract(&image2, 32, 32).expect("Failed"); @@ -193,14 +195,15 @@ fn test_embedding_extractor_trait() { embedding_dim: 64, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Use trait methods assert_eq!(embedder.embedding_dim(), 64); let image = vec![128u8; 32 * 32 * 4]; - let embedding = EmbeddingExtractor::extract(&embedder, &image, 32, 32) - .expect("Trait extraction failed"); + let embedding = + EmbeddingExtractor::extract(&embedder, &image, 32, 32).expect("Trait extraction failed"); assert_eq!(embedding.len(), 64); } @@ -216,7 +219,8 @@ fn test_invalid_image_dimensions() { embedding_dim: 32, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Image data doesn't match dimensions (too small) let image: Vec = vec![128; 100]; @@ -230,20 +234,20 @@ fn test_invalid_image_dimensions() { fn test_zero_dimension_image() { use std::panic; - let embedder = CnnEmbedder::new(EmbeddingConfig::default()) - .expect("Failed to create embedder"); + let embedder = CnnEmbedder::new(EmbeddingConfig::default()).expect("Failed to create embedder"); let image: Vec = vec![]; // Zero dimension should either return an error or panic // 
(currently panics due to index bounds check in SIMD code) - let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { - embedder.extract(&image, 0, 0) - })); + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| embedder.extract(&image, 0, 0))); // Either panicked or returned an error is acceptable for invalid input let failed = result.is_err() || result.map(|r| r.is_err()).unwrap_or(false); - assert!(failed, "Should fail with zero dimensions (either panic or error)"); + assert!( + failed, + "Should fail with zero dimensions (either panic or error)" + ); } // ============================================================================ @@ -257,7 +261,8 @@ fn test_extraction_deterministic() { embedding_dim: 16, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); let image = vec![128u8; 32 * 32 * 4]; @@ -285,7 +290,8 @@ fn test_concurrent_extraction() { embedding_dim: 16, normalize: true, quantized: false, - }).expect("Failed to create embedder") + }) + .expect("Failed to create embedder"), ); let handles: Vec<_> = (0..4) @@ -298,7 +304,8 @@ fn test_concurrent_extraction() { }) .collect(); - let results: Vec<_> = handles.into_iter() + let results: Vec<_> = handles + .into_iter() .map(|h| h.join().expect("Thread panicked")) .collect(); @@ -321,7 +328,8 @@ fn test_multiple_extractions_no_leak() { embedding_dim: 16, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Process many images to check for memory leaks for i in 0..100 { @@ -346,15 +354,15 @@ fn test_embedder_with_infonce() { embedding_dim: 8, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Create augmented pairs (simulate SimCLR) - let images: Vec> = (0..4) - .map(|i| vec![(i * 50) as u8; 16 * 16 * 4]) - .collect(); + let images: Vec> = (0..4).map(|i| vec![(i * 50) as u8; 16 * 16 * 
4]).collect(); // Extract embeddings (convert f32 to f64 for InfoNCE) - let embeddings: Vec> = images.iter() + let embeddings: Vec> = images + .iter() .map(|img| { let emb = embedder.extract(img, 16, 16).expect("Failed"); emb.into_iter().map(|x| x as f64).collect() @@ -378,22 +386,33 @@ fn test_embedder_with_triplet() { embedding_dim: 8, normalize: true, quantized: false, - }).expect("Failed to create embedder"); + }) + .expect("Failed to create embedder"); // Create anchor, positive, negative images let anchor_img = vec![128u8; 16 * 16 * 4]; - let positive_img = vec![130u8; 16 * 16 * 4]; // Similar to anchor - let negative_img: Vec = (0..(16 * 16 * 4)) - .map(|i| ((i * 37) % 256) as u8) - .collect(); + let positive_img = vec![130u8; 16 * 16 * 4]; // Similar to anchor + let negative_img: Vec = (0..(16 * 16 * 4)).map(|i| ((i * 37) % 256) as u8).collect(); // Extract embeddings (convert f32 to f64 for TripletLoss) - let anchor: Vec = embedder.extract(&anchor_img, 16, 16) - .expect("Failed").into_iter().map(|x| x as f64).collect(); - let positive: Vec = embedder.extract(&positive_img, 16, 16) - .expect("Failed").into_iter().map(|x| x as f64).collect(); - let negative: Vec = embedder.extract(&negative_img, 16, 16) - .expect("Failed").into_iter().map(|x| x as f64).collect(); + let anchor: Vec = embedder + .extract(&anchor_img, 16, 16) + .expect("Failed") + .into_iter() + .map(|x| x as f64) + .collect(); + let positive: Vec = embedder + .extract(&positive_img, 16, 16) + .expect("Failed") + .into_iter() + .map(|x| x as f64) + .collect(); + let negative: Vec = embedder + .extract(&negative_img, 16, 16) + .expect("Failed") + .into_iter() + .map(|x| x as f64) + .collect(); // Compute triplet loss let loss_fn = TripletLoss::new(0.5); @@ -418,7 +437,11 @@ fn test_simd_functions_available() { let dot = simd::dot_product_simd(&a, &b); // 1*4 + 2*3 + 3*2 + 4*1 = 4 + 6 + 6 + 4 = 20 - assert!((dot - 20.0).abs() < 1e-5, "Expected dot product to be 20.0, got {}", dot); + assert!( + 
(dot - 20.0).abs() < 1e-5, + "Expected dot product to be 20.0, got {}", + dot + ); } #[test] @@ -452,7 +475,7 @@ fn test_simd_relu6() { #[test] fn test_layers_module_available() { - use ruvector_cnn::layers::{conv2d_3x3, batch_norm, relu, relu6, hard_swish, global_avg_pool}; + use ruvector_cnn::layers::{batch_norm, conv2d_3x3, global_avg_pool, hard_swish, relu, relu6}; // Test standalone layer functions let input = vec![0.5f32; 3 * 8 * 8]; // 3 channels, 8x8 diff --git a/crates/ruvector-cnn/tests/kernel_equivalence.rs b/crates/ruvector-cnn/tests/kernel_equivalence.rs index 10de444a7..d61af9b6f 100644 --- a/crates/ruvector-cnn/tests/kernel_equivalence.rs +++ b/crates/ruvector-cnn/tests/kernel_equivalence.rs @@ -31,7 +31,11 @@ mod kernel_equivalence { assert!( diff <= tolerance, "{}: Element {} differs by {}: {} vs {}", - context, i, diff, va, vb + context, + i, + diff, + va, + vb ); } } diff --git a/crates/ruvector-cnn/tests/layers_test.rs b/crates/ruvector-cnn/tests/layers_test.rs index 23a8ea46a..40cbb72c8 100644 --- a/crates/ruvector-cnn/tests/layers_test.rs +++ b/crates/ruvector-cnn/tests/layers_test.rs @@ -7,9 +7,8 @@ //! 
- Pooling operations use ruvector_cnn::layers::{ - Activation, ActivationType, BatchNorm, Conv2d, DepthwiseSeparableConv, - GlobalAvgPool, HardSwish, Layer, MaxPool2d, AvgPool2d, ReLU, ReLU6, - Sigmoid, Swish, TensorShape, + Activation, ActivationType, AvgPool2d, BatchNorm, Conv2d, DepthwiseSeparableConv, + GlobalAvgPool, HardSwish, Layer, MaxPool2d, ReLU, ReLU6, Sigmoid, Swish, TensorShape, }; use ruvector_cnn::{simd, Tensor}; @@ -160,7 +159,8 @@ fn test_batchnorm_with_running_stats() { let mut bn = BatchNorm::new(2); // Set mean=[1, 2], var=[1, 4] - bn.set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0]).unwrap(); + bn.set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0]) + .unwrap(); let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 1, 2]).unwrap(); let output = bn.forward(&input).unwrap(); @@ -318,7 +318,7 @@ fn test_global_avg_pool_computes_average() { // Create input where channel 0 = 1, channel 1 = 2 let mut data = vec![0.0; 2 * 2 * 2]; for i in 0..4 { - data[i * 2] = 1.0; // channel 0 + data[i * 2] = 1.0; // channel 0 data[i * 2 + 1] = 2.0; // channel 1 } let input = Tensor::from_data(data, &[1, 2, 2, 2]).unwrap(); diff --git a/crates/ruvector-cnn/tests/quality_validation.rs b/crates/ruvector-cnn/tests/quality_validation.rs index 34b6f33c1..4c0cc7a88 100644 --- a/crates/ruvector-cnn/tests/quality_validation.rs +++ b/crates/ruvector-cnn/tests/quality_validation.rs @@ -5,7 +5,7 @@ //! - Per-layer MSE tracking //! 
- Embedding validation on test dataset -use ruvector_cnn::int8::{QuantParams, quantize_tensor, dequantize_tensor}; +use ruvector_cnn::int8::{dequantize_tensor, quantize_tensor, QuantParams}; #[cfg(test)] mod quality_tests { @@ -30,13 +30,15 @@ mod quality_tests { fn mean_squared_error(a: &[f32], b: &[f32]) -> f32 { assert_eq!(a.len(), b.len(), "Tensors must have same length"); - let mse: f32 = a.iter() + let mse: f32 = a + .iter() .zip(b.iter()) .map(|(x, y)| { let diff = x - y; diff * diff }) - .sum::() / a.len() as f32; + .sum::() + / a.len() as f32; mse } @@ -86,14 +88,15 @@ mod quality_tests { let mut rng = fastrand::Rng::with_seed(42); println!("\nPer-Layer MSE Analysis:"); - println!("{:<15} {:>10} {:>15} {:>15}", "Layer", "Size", "MSE", "Cosine Sim"); + println!( + "{:<15} {:>10} {:>15} {:>15}", + "Layer", "Size", "MSE", "Cosine Sim" + ); println!("{}", "-".repeat(60)); for (layer_name, size) in layer_sizes { // Generate random tensor - let fp32_tensor: Vec = (0..size) - .map(|_| rng.f32() * 2.0 - 1.0) - .collect(); + let fp32_tensor: Vec = (0..size).map(|_| rng.f32() * 2.0 - 1.0).collect(); // Quantize and dequantize let params = QuantParams::from_tensor(&fp32_tensor); @@ -113,7 +116,8 @@ mod quality_tests { assert!( similarity >= 0.99, "Layer {} has low similarity: {:.6}", - layer_name, similarity + layer_name, + similarity ); } } @@ -154,7 +158,8 @@ mod quality_tests { // Box-Muller transform for Gaussian let u1 = rng.f32(); let u2 = rng.f32(); - ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()) * 0.5 + ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()) + * 0.5 }) .collect() }, @@ -184,10 +189,7 @@ mod quality_tests { let similarity = cosine_similarity(&test_case.embedding, &dequantized); let mse = mean_squared_error(&test_case.embedding, &dequantized); - println!( - "{:<15} {:>15.6} {:>15.6e}", - test_case.name, similarity, mse - ); + println!("{:<15} {:>15.6} {:>15.6e}", test_case.name, similarity, mse); assert!( 
similarity >= test_case.expected_min_similarity, @@ -243,7 +245,10 @@ mod quality_tests { ]; println!("\nQuantization Range Edge Cases:"); - println!("{:<20} {:>15} {:>15}", "Edge Case", "Max Error", "Cosine Sim"); + println!( + "{:<20} {:>15} {:>15}", + "Edge Case", "Max Error", "Cosine Sim" + ); println!("{}", "-".repeat(55)); for edge_case in edge_cases { @@ -251,7 +256,9 @@ mod quality_tests { let int8_tensor = quantize_tensor(&edge_case.values, ¶ms); let dequantized = dequantize_tensor(&int8_tensor, ¶ms); - let max_error = edge_case.values.iter() + let max_error = edge_case + .values + .iter() .zip(dequantized.iter()) .map(|(a, b)| (a - b).abs()) .fold(0.0f32, f32::max); @@ -269,7 +276,8 @@ mod quality_tests { assert!( similarity >= 0.95, "Edge case '{}' has low similarity: {:.6}", - edge_case.name, similarity + edge_case.name, + similarity ); } } @@ -285,15 +293,12 @@ mod quality_tests { // Generate batch of embeddings let batch: Vec> = (0..batch_size) - .map(|_| { - (0..embedding_size) - .map(|_| rng.f32() * 2.0 - 1.0) - .collect() - }) + .map(|_| (0..embedding_size).map(|_| rng.f32() * 2.0 - 1.0).collect()) .collect(); // Quantize each independently - let individual_results: Vec> = batch.iter() + let individual_results: Vec> = batch + .iter() .map(|emb| { let params = QuantParams::from_tensor(emb); let int8 = quantize_tensor(emb, ¶ms); @@ -307,7 +312,8 @@ mod quality_tests { assert!( similarity >= 0.995, "Batch item {} has low similarity: {:.6}", - i, similarity + i, + similarity ); } @@ -320,9 +326,7 @@ mod quality_tests { let size = 1024; let mut rng = fastrand::Rng::with_seed(999); - let fp32_tensor: Vec = (0..size) - .map(|_| rng.f32() * 2.0 - 1.0) - .collect(); + let fp32_tensor: Vec = (0..size).map(|_| rng.f32() * 2.0 - 1.0).collect(); // Quantize twice let params1 = QuantParams::from_tensor(&fp32_tensor); @@ -332,8 +336,14 @@ mod quality_tests { let int8_2 = quantize_tensor(&fp32_tensor, ¶ms2); // Parameters should be identical - 
assert_eq!(params1.scale, params2.scale, "Scale should be deterministic"); - assert_eq!(params1.zero_point, params2.zero_point, "Zero point should be deterministic"); + assert_eq!( + params1.scale, params2.scale, + "Scale should be deterministic" + ); + assert_eq!( + params1.zero_point, params2.zero_point, + "Zero point should be deterministic" + ); // Quantized values should be identical assert_eq!(int8_1, int8_2, "Quantized tensors should be identical"); @@ -349,9 +359,7 @@ mod quality_tests { let mut rng = fastrand::Rng::with_seed(111); // Create symmetric tensor (mirrored around zero) - let positive: Vec = (0..size) - .map(|_| rng.f32()) - .collect(); + let positive: Vec = (0..size).map(|_| rng.f32()).collect(); let mut symmetric = positive.clone(); symmetric.extend(positive.iter().map(|&x| -x)); diff --git a/crates/ruvector-cnn/tests/simd_test.rs b/crates/ruvector-cnn/tests/simd_test.rs index a5283154d..4ae49e3f7 100644 --- a/crates/ruvector-cnn/tests/simd_test.rs +++ b/crates/ruvector-cnn/tests/simd_test.rs @@ -49,7 +49,9 @@ fn test_dot_product_large_vector() { #[test] fn test_dot_product_various_sizes() { // Test sizes that exercise different SIMD code paths - for size in [1, 3, 7, 8, 9, 15, 16, 17, 31, 32, 63, 64, 100, 128, 255, 256] { + for size in [ + 1, 3, 7, 8, 9, 15, 16, 17, 31, 32, 63, 64, 100, 128, 255, 256, + ] { let a: Vec = (0..size).map(|i| (i as f32) * 0.1).collect(); let b: Vec = (0..size).map(|i| ((size - i) as f32) * 0.1).collect(); @@ -311,7 +313,16 @@ fn test_batch_norm_identity() { let mut output = vec![0.0; input.len()]; - simd::batch_norm_simd(&input, &mut output, &gamma, &beta, &mean, &var, 1e-5, channels); + simd::batch_norm_simd( + &input, + &mut output, + &gamma, + &beta, + &mean, + &var, + 1e-5, + channels, + ); for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() { assert!( @@ -339,7 +350,16 @@ fn test_batch_norm_normalization() { let mut output = vec![0.0; 4]; - simd::batch_norm_simd(&input, &mut output, &gamma, 
&beta, &mean, &var, 1e-5, channels); + simd::batch_norm_simd( + &input, + &mut output, + &gamma, + &beta, + &mean, + &var, + 1e-5, + channels, + ); // (5 - 5) / sqrt(1 + eps) = 0 // (10 - 10) / sqrt(1 + eps) = 0 @@ -362,7 +382,9 @@ fn test_conv_3x3_simd_vs_scalar() { let padding = 1; let input: Vec = (0..in_h * in_w * in_c).map(|i| (i as f32) * 0.01).collect(); - let kernel: Vec = (0..out_c * 3 * 3 * in_c).map(|i| (i as f32) * 0.001).collect(); + let kernel: Vec = (0..out_c * 3 * 3 * in_c) + .map(|i| (i as f32) * 0.001) + .collect(); let out_h = (in_h + 2 * padding - 3) / stride + 1; let out_w = (in_w + 2 * padding - 3) / stride + 1; @@ -394,13 +416,7 @@ fn test_conv_3x3_simd_vs_scalar() { ); for (i, (&s, &r)) in simd_output.iter().zip(scalar_output.iter()).enumerate() { - assert!( - (s - r).abs() < 0.1, - "Index {}: SIMD={}, Scalar={}", - i, - s, - r - ); + assert!((s - r).abs() < 0.1, "Index {}: SIMD={}, Scalar={}", i, s, r); } } @@ -438,13 +454,7 @@ fn test_depthwise_conv_3x3_simd_vs_scalar() { ); for (i, (&s, &r)) in simd_output.iter().zip(scalar_output.iter()).enumerate() { - assert!( - (s - r).abs() < 0.1, - "Index {}: SIMD={}, Scalar={}", - i, - s, - r - ); + assert!((s - r).abs() < 0.1, "Index {}: SIMD={}, Scalar={}", i, s, r); } } @@ -579,7 +589,9 @@ fn test_simd_single_element() { fn test_simd_remainder_handling() { // Test sizes that don't align with SIMD width (not multiple of 8) for size in [3, 7, 9, 15, 17, 25, 33] { - let input: Vec = (0..size).map(|i| (i as f32) - (size as f32 / 2.0)).collect(); + let input: Vec = (0..size) + .map(|i| (i as f32) - (size as f32 / 2.0)) + .collect(); let mut simd_output = vec![0.0; size]; let mut scalar_output = vec![0.0; size]; @@ -702,7 +714,16 @@ fn test_dot_product_nan_propagation() { #[test] fn test_activation_with_special_values() { - let input = vec![f32::INFINITY, f32::NEG_INFINITY, f32::NAN, 0.0, 1.0, -1.0, 6.0, 100.0]; + let input = vec![ + f32::INFINITY, + f32::NEG_INFINITY, + f32::NAN, + 0.0, + 1.0, + 
-1.0, + 6.0, + 100.0, + ]; let mut output = vec![0.0; 8]; simd::relu_simd(&input, &mut output); diff --git a/crates/ruvector-postgres/src/graph/mod.rs b/crates/ruvector-postgres/src/graph/mod.rs index 7c1524fd9..32a0d77dd 100644 --- a/crates/ruvector-postgres/src/graph/mod.rs +++ b/crates/ruvector-postgres/src/graph/mod.rs @@ -92,12 +92,11 @@ fn load_graph_from_tables(name: &str) -> Option> { .get_by_name::("properties")? .unwrap_or(JsonB(serde_json::json!({}))); - let props: HashMap = - if let JsonValue::Object(map) = props_json.0 { - map.into_iter().collect() - } else { - HashMap::new() - }; + let props: HashMap = if let JsonValue::Object(map) = props_json.0 { + map.into_iter().collect() + } else { + HashMap::new() + }; let mut node = Node::new(id as u64); node.labels = labels; @@ -131,12 +130,11 @@ fn load_graph_from_tables(name: &str) -> Option> { .get_by_name::("properties")? .unwrap_or(JsonB(serde_json::json!({}))); - let props: HashMap = - if let JsonValue::Object(map) = props_json.0 { - map.into_iter().collect() - } else { - HashMap::new() - }; + let props: HashMap = if let JsonValue::Object(map) = props_json.0 { + map.into_iter().collect() + } else { + HashMap::new() + }; let mut edge = Edge::new(id as u64, source as u64, target as u64, edge_type); edge.properties = props; @@ -262,7 +260,11 @@ pub fn list_graphs() -> Vec { let mut names: Vec = Vec::new(); let _ = Spi::connect(|client| { - let tup_table = client.select("SELECT name FROM _ruvector_graphs ORDER BY name", None, None)?; + let tup_table = client.select( + "SELECT name FROM _ruvector_graphs ORDER BY name", + None, + None, + )?; for row in tup_table { if let Some(name) = row.get_by_name::("name")? 
{ names.push(name); diff --git a/crates/ruvector-postgres/src/graph/sparql/mod.rs b/crates/ruvector-postgres/src/graph/sparql/mod.rs index 96f8204ea..01b0387fe 100644 --- a/crates/ruvector-postgres/src/graph/sparql/mod.rs +++ b/crates/ruvector-postgres/src/graph/sparql/mod.rs @@ -93,7 +93,9 @@ fn load_store_from_tables(name: &str) -> Option> { for row in tup_table { let subject: String = row.get_by_name::("subject")?.unwrap_or_default(); - let predicate: String = row.get_by_name::("predicate")?.unwrap_or_default(); + let predicate: String = row + .get_by_name::("predicate")? + .unwrap_or_default(); let object: String = row.get_by_name::("object")?.unwrap_or_default(); let graph_name: Option = row.get_by_name::("graph_name")?; @@ -188,8 +190,11 @@ pub fn list_stores() -> Vec { let mut names: Vec = Vec::new(); let _ = Spi::connect(|client| { - let tup_table = - client.select("SELECT name FROM _ruvector_rdf_stores ORDER BY name", None, None)?; + let tup_table = client.select( + "SELECT name FROM _ruvector_rdf_stores ORDER BY name", + None, + None, + )?; for row in tup_table { if let Some(name) = row.get_by_name::("name")? { names.push(name); diff --git a/crates/ruvector-postgres/src/index/hnsw_am.rs b/crates/ruvector-postgres/src/index/hnsw_am.rs index db0352eb8..d1732131b 100644 --- a/crates/ruvector-postgres/src/index/hnsw_am.rs +++ b/crates/ruvector-postgres/src/index/hnsw_am.rs @@ -674,7 +674,8 @@ unsafe fn hnsw_search( pgrx::warning!( "HNSW search: entry_point is InvalidBlockNumber (node_count={}, dims={}). \ Index may need REINDEX. Check: SELECT ruvector_hnsw_debug('index_name')", - meta.node_count, meta.dimensions + meta.node_count, + meta.dimensions ); return Vec::new(); } @@ -2127,15 +2128,15 @@ fn ruvector_hnsw_debug(index_name: &str) -> pgrx::JsonB { ); let index_exists: bool = Spi::connect(|client| { - let row = client.select(&query, None, None)? 
- .first(); + let row = client.select(&query, None, None)?.first(); let found = match row.get_datum_by_ordinal(1) { Ok(Some(_)) => true, _ => false, }; Ok::(found) - }).unwrap_or(false); + }) + .unwrap_or(false); if !index_exists { return pgrx::JsonB(serde_json::json!({ "error": format!("Index '{}' not found or is not an HNSW index", index_name), @@ -2152,16 +2153,20 @@ fn ruvector_hnsw_debug(index_name: &str) -> pgrx::JsonB { ); let (rel_size, rel_path) = Spi::connect(|client| { - let row = client.select(&meta_query, None, None)? - .first(); - let size: Option = row.get_datum_by_ordinal(1) - .ok().flatten() + let row = client.select(&meta_query, None, None)?.first(); + let size: Option = row + .get_datum_by_ordinal(1) + .ok() + .flatten() .and_then(|d| unsafe { i64::from_polymorphic_datum(d, false, pg_sys::INT8OID) }); - let path: Option = row.get_datum_by_ordinal(2) - .ok().flatten() + let path: Option = row + .get_datum_by_ordinal(2) + .ok() + .flatten() .and_then(|d| unsafe { String::from_polymorphic_datum(d, false, pg_sys::TEXTOID) }); Ok::<_, pgrx::spi::SpiError>((size.unwrap_or(0), path.unwrap_or_default())) - }).unwrap_or((0, String::new())); + }) + .unwrap_or((0, String::new())); let pages = rel_size / 8192; // BLCKSZ let has_data = pages > 1; // More than just meta page diff --git a/crates/ruvector-robotics/benches/robotics_benchmarks.rs b/crates/ruvector-robotics/benches/robotics_benchmarks.rs index 749102c6a..cfd0534ce 100644 --- a/crates/ruvector-robotics/benches/robotics_benchmarks.rs +++ b/crates/ruvector-robotics/benches/robotics_benchmarks.rs @@ -17,11 +17,11 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use std::collections::HashMap; +use ruvector_robotics::bridge::gaussian::{gaussians_from_cloud, to_viewer_json}; use ruvector_robotics::bridge::{ GaussianConfig, Obstacle, OccupancyGrid, Point3D, PointCloud, RobotState, SceneObject, SpatialIndex, }; -use 
ruvector_robotics::bridge::gaussian::{gaussians_from_cloud, to_viewer_json}; use ruvector_robotics::cognitive::behavior_tree::{BehaviorNode, BehaviorStatus, BehaviorTree}; use ruvector_robotics::mcp::executor::ToolExecutor; use ruvector_robotics::mcp::ToolRequest; @@ -75,11 +75,8 @@ fn generate_scene_objects(n: usize) -> Vec { .map(|i| { let angle = (i as f64) * 2.0 * std::f64::consts::PI / (n as f64); let r = 5.0 + (i as f64) * 0.1; - let mut obj = SceneObject::new( - i, - [r * angle.cos(), r * angle.sin(), 0.0], - [0.5, 0.5, 1.8], - ); + let mut obj = + SceneObject::new(i, [r * angle.cos(), r * angle.sin(), 0.0], [0.5, 0.5, 1.8]); obj.label = if i % 3 == 0 { "person".to_string() } else if i % 3 == 1 { @@ -108,32 +105,21 @@ fn bench_point_cloud_conversion(c: &mut Criterion) { for size in [100, 1_000, 10_000, 100_000] { group.throughput(Throughput::Elements(size as u64)); - group.bench_with_input( - BenchmarkId::new("to_flat_vectors", size), - &size, - |b, &n| { - let points = generate_point_cloud(n); - b.iter(|| { - let vectors: Vec> = points.iter().map(|p| p.to_vec()).collect(); - black_box(vectors) - }) - }, - ); - - group.bench_with_input( - BenchmarkId::new("to_flat_array", size), - &size, - |b, &n| { - let points = generate_point_cloud(n); - b.iter(|| { - let flat: Vec = points - .iter() - .flat_map(|p| [p.x, p.y, p.z]) - .collect(); - black_box(flat) - }) - }, - ); + group.bench_with_input(BenchmarkId::new("to_flat_vectors", size), &size, |b, &n| { + let points = generate_point_cloud(n); + b.iter(|| { + let vectors: Vec> = points.iter().map(|p| p.to_vec()).collect(); + black_box(vectors) + }) + }); + + group.bench_with_input(BenchmarkId::new("to_flat_array", size), &size, |b, &n| { + let points = generate_point_cloud(n); + b.iter(|| { + let flat: Vec = points.iter().flat_map(|p| [p.x, p.y, p.z]).collect(); + black_box(flat) + }) + }); } group.finish(); @@ -331,11 +317,9 @@ fn bench_trajectory_prediction(c: &mut Criterion) { let s4: f64 = times.iter().map(|t| 
t * t * t * t).sum(); let sy: f64 = vals.iter().sum(); let sty: f64 = times.iter().zip(vals.iter()).map(|(t, y)| t * y).sum(); - let st2y: f64 = - times.iter().zip(vals.iter()).map(|(t, y)| t * t * y).sum(); + let st2y: f64 = times.iter().zip(vals.iter()).map(|(t, y)| t * t * y).sum(); - let det = nn * (s2 * s4 - s3 * s3) - - s1 * (s1 * s4 - s3 * s2) + let det = nn * (s2 * s4 - s3 * s3) - s1 * (s1 * s4 - s3 * s2) + s2 * (s1 * s3 - s2 * s2); if det.abs() > 1e-12 { coeffs[axis][0] = (sy * (s2 * s4 - s3 * s3) @@ -461,56 +445,48 @@ fn bench_behavior_tree_tick(c: &mut Criterion) { let n_leaves = 1usize << depth; group.throughput(Throughput::Elements(n_leaves as u64)); - group.bench_with_input( - BenchmarkId::new("sequence_tree", depth), - &depth, - |b, &d| { - fn build_seq(depth: usize) -> BehaviorNode { - if depth == 0 { - BehaviorNode::Action("leaf".into()) - } else { - BehaviorNode::Sequence(vec![ - build_seq(depth - 1), - BehaviorNode::Action("leaf".into()), - ]) - } + group.bench_with_input(BenchmarkId::new("sequence_tree", depth), &depth, |b, &d| { + fn build_seq(depth: usize) -> BehaviorNode { + if depth == 0 { + BehaviorNode::Action("leaf".into()) + } else { + BehaviorNode::Sequence(vec![ + build_seq(depth - 1), + BehaviorNode::Action("leaf".into()), + ]) } - let root = build_seq(d); - let mut tree = BehaviorTree::new(root); - tree.set_action_result("leaf", BehaviorStatus::Success); - - b.iter(|| { - let status = tree.tick(); - black_box(status) - }) - }, - ); - - group.bench_with_input( - BenchmarkId::new("selector_tree", depth), - &depth, - |b, &d| { - fn build_sel(depth: usize) -> BehaviorNode { - if depth == 0 { - BehaviorNode::Action("leaf".into()) - } else { - BehaviorNode::Selector(vec![ - BehaviorNode::Action("fail".into()), - build_sel(depth - 1), - ]) - } + } + let root = build_seq(d); + let mut tree = BehaviorTree::new(root); + tree.set_action_result("leaf", BehaviorStatus::Success); + + b.iter(|| { + let status = tree.tick(); + black_box(status) + }) 
+ }); + + group.bench_with_input(BenchmarkId::new("selector_tree", depth), &depth, |b, &d| { + fn build_sel(depth: usize) -> BehaviorNode { + if depth == 0 { + BehaviorNode::Action("leaf".into()) + } else { + BehaviorNode::Selector(vec![ + BehaviorNode::Action("fail".into()), + build_sel(depth - 1), + ]) } - let root = build_sel(d); - let mut tree = BehaviorTree::new(root); - tree.set_action_result("fail", BehaviorStatus::Failure); - tree.set_action_result("leaf", BehaviorStatus::Success); - - b.iter(|| { - let status = tree.tick(); - black_box(status) - }) - }, - ); + } + let root = build_sel(d); + let mut tree = BehaviorTree::new(root); + tree.set_action_result("fail", BehaviorStatus::Failure); + tree.set_action_result("leaf", BehaviorStatus::Success); + + b.iter(|| { + let status = tree.tick(); + black_box(status) + }) + }); } group.finish(); @@ -615,14 +591,12 @@ fn bench_memory_recall(c: &mut Criterion) { let query: Vec = (0..dim).map(|d| pseudo_random_f32(9999, d)).collect(); b.iter(|| { - let q_norm: f32 = - query.iter().map(|x| x * x).sum::().sqrt().max(1e-10); + let q_norm: f32 = query.iter().map(|x| x * x).sum::().sqrt().max(1e-10); let mut sims: Vec<(usize, f32)> = episodes .iter() .enumerate() .map(|(i, ep)| { - let dot: f32 = - query.iter().zip(ep.iter()).map(|(a, b)| a * b).sum(); + let dot: f32 = query.iter().zip(ep.iter()).map(|(a, b)| a * b).sum(); let ep_norm: f32 = ep.iter().map(|x| x * x).sum::().sqrt().max(1e-10); (i, dot / (q_norm * ep_norm)) @@ -647,8 +621,7 @@ fn bench_memory_recall(c: &mut Criterion) { .collect() }) .collect(); - let query: Vec = - (0..dim).map(|d| pseudo_random_f32(8888, d)).collect(); + let query: Vec = (0..dim).map(|d| pseudo_random_f32(8888, d)).collect(); b.iter(|| { let mut dists: Vec<(usize, f32)> = episodes @@ -850,8 +823,7 @@ fn bench_swarm_task_assignment(c: &mut Criterion) { for ti in 0..nt { let dx = robots[ri][0] - tasks[ti][0]; let dy = robots[ri][1] - tasks[ti][1]; - let value = - 100.0 - (dx * dx + dy * 
dy).sqrt() - prices[ti]; + let value = 100.0 - (dx * dx + dy * dy).sqrt() - prices[ti]; if value > best_val { second_val = best_val; best_val = value; @@ -1135,14 +1107,13 @@ fn bench_mcp_tool_execution(c: &mut Criterion) { // Benchmark predict_trajectory (lightweight) group.bench_function("predict_trajectory_10steps", |b| { let mut exec = ToolExecutor::new(); - let args: HashMap = - serde_json::from_value(serde_json::json!({ - "position": [0.0, 0.0, 0.0], - "velocity": [1.0, 0.5, 0.0], - "steps": 10, - "dt": 0.1, - })) - .unwrap(); + let args: HashMap = serde_json::from_value(serde_json::json!({ + "position": [0.0, 0.0, 0.0], + "velocity": [1.0, 0.5, 0.0], + "steps": 10, + "dt": 0.1, + })) + .unwrap(); let req = ToolRequest { tool_name: "predict_trajectory".to_string(), arguments: args, @@ -1201,12 +1172,11 @@ fn bench_mcp_tool_execution(c: &mut Criterion) { let points = generate_point_cloud(500); let cloud = PointCloud::new(points, 1000); let cloud_json = serde_json::to_string(&cloud).unwrap(); - let args: HashMap = - serde_json::from_value(serde_json::json!({ - "point_cloud_json": cloud_json, - "robot_position": [5.0, 5.0, 0.0], - })) - .unwrap(); + let args: HashMap = serde_json::from_value(serde_json::json!({ + "point_cloud_json": cloud_json, + "robot_position": [5.0, 5.0, 0.0], + })) + .unwrap(); let req = ToolRequest { tool_name: "detect_obstacles".to_string(), arguments: args, diff --git a/crates/ruvector-robotics/examples/behavior_tree.rs b/crates/ruvector-robotics/examples/behavior_tree.rs index 43e9b285e..cf6b7d469 100644 --- a/crates/ruvector-robotics/examples/behavior_tree.rs +++ b/crates/ruvector-robotics/examples/behavior_tree.rs @@ -5,9 +5,7 @@ //! - Ticking the tree and observing status changes //! 
- Using the blackboard for inter-node communication -use ruvector_robotics::cognitive::{ - BehaviorNode, BehaviorStatus, BehaviorTree, DecoratorType, -}; +use ruvector_robotics::cognitive::{BehaviorNode, BehaviorStatus, BehaviorTree, DecoratorType}; fn main() { println!("=== Behavior Tree Demo ===\n"); @@ -76,8 +74,15 @@ fn main() { for i in 1..=4 { let s = t2.tick(); - println!(" Tick {}: {:?}{}", i, s, - if s == BehaviorStatus::Failure { " (TIMED OUT)" } else { "" } + println!( + " Tick {}: {:?}{}", + i, + s, + if s == BehaviorStatus::Failure { + " (TIMED OUT)" + } else { + "" + } ); } diff --git a/crates/ruvector-robotics/examples/cognitive_loop.rs b/crates/ruvector-robotics/examples/cognitive_loop.rs index 79000155e..9ad2bdc6b 100644 --- a/crates/ruvector-robotics/examples/cognitive_loop.rs +++ b/crates/ruvector-robotics/examples/cognitive_loop.rs @@ -54,7 +54,10 @@ fn main() { // Act let cmd = core.act(decision); match &cmd.action { - ActionType::Move(pos) => println!(" Action: Move to [{:.1}, {:.1}, {:.1}]", pos[0], pos[1], pos[2]), + ActionType::Move(pos) => println!( + " Action: Move to [{:.1}, {:.1}, {:.1}]", + pos[0], pos[1], pos[2] + ), ActionType::Wait(ms) => println!(" Action: Wait {}ms", ms), _ => println!(" Action: {:?}", cmd.action), } diff --git a/crates/ruvector-robotics/examples/obstacle_detection.rs b/crates/ruvector-robotics/examples/obstacle_detection.rs index 3d6bebc29..c6080e443 100644 --- a/crates/ruvector-robotics/examples/obstacle_detection.rs +++ b/crates/ruvector-robotics/examples/obstacle_detection.rs @@ -6,8 +6,8 @@ //! 
- Classifying obstacles as Static, Dynamic, or Unknown use ruvector_robotics::bridge::{Point3D, PointCloud}; -use ruvector_robotics::perception::{ObstacleDetector, ObstacleClass}; use ruvector_robotics::perception::config::ObstacleConfig; +use ruvector_robotics::perception::{ObstacleClass, ObstacleDetector}; fn main() { println!("=== Obstacle Detection Demo ===\n"); diff --git a/crates/ruvector-robotics/examples/spatial_indexing.rs b/crates/ruvector-robotics/examples/spatial_indexing.rs index de49fe8cc..20fb5760f 100644 --- a/crates/ruvector-robotics/examples/spatial_indexing.rs +++ b/crates/ruvector-robotics/examples/spatial_indexing.rs @@ -16,11 +16,7 @@ fn main() { // Shelves along the x-axis. for row in 0..5 { for col in 0..10 { - points.push(Point3D::new( - col as f32 * 2.0, - row as f32 * 3.0, - 0.0, - )); + points.push(Point3D::new(col as f32 * 2.0, row as f32 * 3.0, 0.0)); } } // A few elevated points (items on shelves). diff --git a/crates/ruvector-robotics/examples/swarm_coordination.rs b/crates/ruvector-robotics/examples/swarm_coordination.rs index fdfc43183..829e615a5 100644 --- a/crates/ruvector-robotics/examples/swarm_coordination.rs +++ b/crates/ruvector-robotics/examples/swarm_coordination.rs @@ -50,8 +50,10 @@ fn main() { for robot in &robots { let registered = coordinator.register_robot(robot.clone()); - println!("Registered robot {} (speed={}, sensors={:?}): {}", - robot.id, robot.max_speed, robot.sensors, registered); + println!( + "Registered robot {} (speed={}, sensors={:?}): {}", + robot.id, robot.max_speed, robot.sensors, registered + ); } println!("\nActive robots: {}\n", coordinator.robot_count()); @@ -105,7 +107,10 @@ fn main() { let positions = coordinator.compute_formation(&formation); println!(" {} formation:", name); for (i, pos) in positions.iter().enumerate() { - println!(" Robot {}: [{:.2}, {:.2}, {:.2}]", i, pos[0], pos[1], pos[2]); + println!( + " Robot {}: [{:.2}, {:.2}, {:.2}]", + i, pos[0], pos[1], pos[2] + ); } } @@ -115,7 
+120,11 @@ fn main() { println!( " Proposal: '{}' -> {} (for={}, against={})", result.proposal, - if result.accepted { "ACCEPTED" } else { "REJECTED" }, + if result.accepted { + "ACCEPTED" + } else { + "REJECTED" + }, result.votes_for, result.votes_against, ); diff --git a/crates/ruvector-robotics/src/bridge/converters.rs b/crates/ruvector-robotics/src/bridge/converters.rs index d00daaf15..61cb3f9b1 100644 --- a/crates/ruvector-robotics/src/bridge/converters.rs +++ b/crates/ruvector-robotics/src/bridge/converters.rs @@ -90,10 +90,7 @@ pub fn robot_state_to_vector(state: &RobotState) -> Vec { } /// Reconstruct a [`RobotState`] from a 9-element vector and a timestamp. -pub fn vector_to_robot_state( - v: &[f64], - timestamp: i64, -) -> Result { +pub fn vector_to_robot_state(v: &[f64], timestamp: i64) -> Result { if v.len() != 9 { return Err(ConversionError::LengthMismatch { expected: 9, @@ -144,9 +141,7 @@ pub fn occupancy_grid_to_vectors(grid: &OccupancyGrid) -> Vec> { type NodeFeatures = Vec>; type EdgeList = Vec<(usize, usize, f64)>; -pub fn scene_graph_to_adjacency( - scene: &SceneGraph, -) -> (NodeFeatures, EdgeList) { +pub fn scene_graph_to_adjacency(scene: &SceneGraph) -> (NodeFeatures, EdgeList) { let nodes: Vec> = scene .objects .iter() @@ -208,7 +203,13 @@ mod tests { let mut cloud = PointCloud::new(vec![Point3D::new(1.0, 2.0, 3.0)], 0); cloud.intensities = vec![]; let err = point_cloud_to_vectors_with_intensity(&cloud).unwrap_err(); - assert_eq!(err, ConversionError::LengthMismatch { expected: 1, got: 0 }); + assert_eq!( + err, + ConversionError::LengthMismatch { + expected: 1, + got: 0 + } + ); } #[test] @@ -231,7 +232,13 @@ mod tests { fn test_vectors_to_point_cloud_wrong_dim() { let vecs = vec![vec![1.0, 2.0]]; let err = vectors_to_point_cloud(&vecs, 0).unwrap_err(); - assert_eq!(err, ConversionError::LengthMismatch { expected: 3, got: 2 }); + assert_eq!( + err, + ConversionError::LengthMismatch { + expected: 3, + got: 2 + } + ); } #[test] @@ -269,7 
+276,13 @@ mod tests { fn test_vector_to_robot_state_wrong_len() { let v = vec![1.0, 2.0, 3.0]; let err = vector_to_robot_state(&v, 0).unwrap_err(); - assert_eq!(err, ConversionError::LengthMismatch { expected: 9, got: 3 }); + assert_eq!( + err, + ConversionError::LengthMismatch { + expected: 9, + got: 3 + } + ); } #[test] @@ -383,7 +396,10 @@ mod tests { #[test] fn test_conversion_error_display() { - let e1 = ConversionError::LengthMismatch { expected: 3, got: 5 }; + let e1 = ConversionError::LengthMismatch { + expected: 3, + got: 5, + }; assert!(format!("{e1}").contains("3") && format!("{e1}").contains("5")); assert!(format!("{}", ConversionError::EmptyInput).contains("empty")); } diff --git a/crates/ruvector-robotics/src/bridge/gaussian.rs b/crates/ruvector-robotics/src/bridge/gaussian.rs index e6bf50552..e3aed2162 100644 --- a/crates/ruvector-robotics/src/bridge/gaussian.rs +++ b/crates/ruvector-robotics/src/bridge/gaussian.rs @@ -75,10 +75,7 @@ impl Default for GaussianConfig { /// Convert a [`PointCloud`] into a [`GaussianSplatCloud`] by clustering nearby /// points and computing per-cluster statistics. 
-pub fn gaussians_from_cloud( - cloud: &PointCloud, - config: &GaussianConfig, -) -> GaussianSplatCloud { +pub fn gaussians_from_cloud(cloud: &PointCloud, config: &GaussianConfig) -> GaussianSplatCloud { if cloud.is_empty() || config.cell_size <= 0.0 { return GaussianSplatCloud { gaussians: Vec::new(), @@ -191,10 +188,7 @@ mod tests { #[test] fn test_single_cluster() { - let cloud = make_cloud( - &[[1.0, 0.0, 0.0], [1.1, 0.0, 0.0], [1.0, 0.1, 0.0]], - 1000, - ); + let cloud = make_cloud(&[[1.0, 0.0, 0.0], [1.1, 0.0, 0.0], [1.0, 0.1, 0.0]], 1000); let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default()); assert_eq!(gs.len(), 1); let g = &gs.gaussians[0]; @@ -206,8 +200,10 @@ mod tests { fn test_two_clusters() { let cloud = make_cloud( &[ - [0.0, 0.0, 0.0], [0.1, 0.0, 0.0], - [10.0, 10.0, 0.0], [10.1, 10.0, 0.0], + [0.0, 0.0, 0.0], + [0.1, 0.0, 0.0], + [10.0, 10.0, 0.0], + [10.1, 10.0, 0.0], ], 2000, ); @@ -217,11 +213,11 @@ mod tests { #[test] fn test_min_cluster_size_filtering() { - let cloud = make_cloud( - &[[0.0, 0.0, 0.0], [10.0, 10.0, 0.0]], - 0, - ); - let config = GaussianConfig { min_cluster_size: 3, ..Default::default() }; + let cloud = make_cloud(&[[0.0, 0.0, 0.0], [10.0, 10.0, 0.0]], 0); + let config = GaussianConfig { + min_cluster_size: 3, + ..Default::default() + }; let gs = gaussians_from_cloud(&cloud, &config); assert!(gs.is_empty()); } @@ -229,10 +225,7 @@ mod tests { #[test] fn test_scale_reflects_spread() { // Use a larger cell size so all three points end up in one cluster. 
- let cloud = make_cloud( - &[[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [0.15, 0.0, 0.0]], - 0, - ); + let cloud = make_cloud(&[[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [0.15, 0.0, 0.0]], 0); let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default()); assert_eq!(gs.len(), 1); let g = &gs.gaussians[0]; @@ -265,7 +258,10 @@ mod tests { #[test] fn test_zero_cell_size() { let cloud = make_cloud(&[[1.0, 0.0, 0.0]], 0); - let config = GaussianConfig { cell_size: 0.0, ..Default::default() }; + let config = GaussianConfig { + cell_size: 0.0, + ..Default::default() + }; let gs = gaussians_from_cloud(&cloud, &config); assert!(gs.is_empty()); } diff --git a/crates/ruvector-robotics/src/bridge/indexing.rs b/crates/ruvector-robotics/src/bridge/indexing.rs index e7cb7a87c..3932c2c0c 100644 --- a/crates/ruvector-robotics/src/bridge/indexing.rs +++ b/crates/ruvector-robotics/src/bridge/indexing.rs @@ -96,10 +96,7 @@ fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { #[inline] fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 { - a.iter() - .zip(b.iter()) - .map(|(x, y)| (x - y).abs()) - .sum() + a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() } /// Cosine distance in a single fused loop (dot + both norms together). @@ -174,11 +171,7 @@ impl SpatialIndex { /// /// For the Euclidean metric, squared distances are used internally and /// only the final `k` results are square-rooted. 
- pub fn search_nearest( - &self, - query: &[f32], - k: usize, - ) -> Result, IndexError> { + pub fn search_nearest(&self, query: &[f32], k: usize) -> Result, IndexError> { let n = self.len(); if n == 0 { return Err(IndexError::EmptyIndex); @@ -220,7 +213,11 @@ impl SpatialIndex { let mut result: Vec<(usize, f32)> = heap .into_iter() .map(|e| { - let dist = if use_sq { e.distance.sqrt() } else { e.distance }; + let dist = if use_sq { + e.distance.sqrt() + } else { + e.distance + }; (e.index, dist) }) .collect(); @@ -401,7 +398,13 @@ mod tests { let mut idx = SpatialIndex::new(3); idx.insert_vectors(&[vec![1.0, 2.0, 3.0]]); let err = idx.search_nearest(&[0.0, 0.0], 1).unwrap_err(); - assert_eq!(err, IndexError::DimensionMismatch { expected: 3, got: 2 }); + assert_eq!( + err, + IndexError::DimensionMismatch { + expected: 3, + got: 2 + } + ); } #[test] @@ -499,7 +502,10 @@ mod tests { #[test] fn test_index_error_display() { - let e = IndexError::DimensionMismatch { expected: 3, got: 5 }; + let e = IndexError::DimensionMismatch { + expected: 3, + got: 5, + }; assert!(format!("{e}").contains("3")); assert!(format!("{}", IndexError::EmptyIndex).contains("empty")); } diff --git a/crates/ruvector-robotics/src/bridge/mod.rs b/crates/ruvector-robotics/src/bridge/mod.rs index e5d8b0dbc..f6f674053 100644 --- a/crates/ruvector-robotics/src/bridge/mod.rs +++ b/crates/ruvector-robotics/src/bridge/mod.rs @@ -72,7 +72,12 @@ pub struct Quaternion { impl Default for Quaternion { fn default() -> Self { - Self { x: 0.0, y: 0.0, z: 0.0, w: 1.0 } + Self { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + } } } @@ -250,7 +255,11 @@ pub struct SceneGraph { impl SceneGraph { pub fn new(objects: Vec, edges: Vec, timestamp: i64) -> Self { - Self { objects, edges, timestamp } + Self { + objects, + edges, + timestamp, + } } } @@ -264,7 +273,11 @@ pub struct Trajectory { impl Trajectory { pub fn new(waypoints: Vec<[f64; 3]>, timestamps: Vec, confidence: f64) -> Self { - Self { waypoints, timestamps, 
confidence } + Self { + waypoints, + timestamps, + confidence, + } } pub fn len(&self) -> usize { @@ -364,11 +377,7 @@ mod tests { #[test] fn test_trajectory() { - let t = Trajectory::new( - vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - vec![100, 200], - 0.95, - ); + let t = Trajectory::new(vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], vec![100, 200], 0.95); assert_eq!(t.len(), 2); assert!(!t.is_empty()); } diff --git a/crates/ruvector-robotics/src/bridge/pipeline.rs b/crates/ruvector-robotics/src/bridge/pipeline.rs index d5ee39fdf..cf4429ba3 100644 --- a/crates/ruvector-robotics/src/bridge/pipeline.rs +++ b/crates/ruvector-robotics/src/bridge/pipeline.rs @@ -185,11 +185,7 @@ impl PerceptionPipeline { let n = self.position_history.len(); let prev = &self.position_history[n - 2]; let curr = &self.position_history[n - 1]; - let vel = [ - curr[0] - prev[0], - curr[1] - prev[1], - curr[2] - prev[2], - ]; + let vel = [curr[0] - prev[0], curr[1] - prev[1], curr[2] - prev[2]]; // Predict 5 steps into the future with constant velocity. 
let steps = 5; @@ -215,11 +211,7 @@ mod tests { use super::*; use crate::bridge::{Point3D, PointCloud, RobotState, SensorFrame}; - fn make_frame( - points: Vec, - position: [f64; 3], - ts: i64, - ) -> SensorFrame { + fn make_frame(points: Vec, position: [f64; 3], ts: i64) -> SensorFrame { SensorFrame { cloud: Some(PointCloud::new(points, ts)), state: Some(RobotState { @@ -272,11 +264,7 @@ mod tests { ..Default::default() }; let mut pipeline = PerceptionPipeline::new(config); - let frame = make_frame( - vec![Point3D::new(10.0, 10.0, 10.0)], - [0.0, 0.0, 0.0], - 1000, - ); + let frame = make_frame(vec![Point3D::new(10.0, 10.0, 10.0)], [0.0, 0.0, 0.0], 1000); let result = pipeline.process_frame(&frame); assert!(result.obstacles.is_empty()); } @@ -335,10 +323,7 @@ mod tests { }; let mut pipeline = PerceptionPipeline::new(config); let frame = make_frame( - vec![ - Point3D::new(1.0, 0.0, 0.0), - Point3D::new(2.0, 0.0, 0.0), - ], + vec![Point3D::new(1.0, 0.0, 0.0), Point3D::new(2.0, 0.0, 0.0)], [0.0, 0.0, 0.0], 0, ); @@ -368,11 +353,7 @@ mod tests { }; let mut pipeline = PerceptionPipeline::new(config); - let f1 = make_frame( - vec![Point3D::new(1.0, 0.0, 0.0)], - [0.0, 0.0, 0.0], - 0, - ); + let f1 = make_frame(vec![Point3D::new(1.0, 0.0, 0.0)], [0.0, 0.0, 0.0], 0); pipeline.process_frame(&f1); assert_eq!(pipeline.index().len(), 1); diff --git a/crates/ruvector-robotics/src/cognitive/behavior_tree.rs b/crates/ruvector-robotics/src/cognitive/behavior_tree.rs index 523f9eb0b..369ccbee7 100644 --- a/crates/ruvector-robotics/src/cognitive/behavior_tree.rs +++ b/crates/ruvector-robotics/src/cognitive/behavior_tree.rs @@ -112,9 +112,7 @@ impl BehaviorTree { /// Set the result that a named action should return. pub fn set_action_result(&mut self, name: &str, status: BehaviorStatus) { - self.context - .action_results - .insert(name.to_string(), status); + self.context.action_results.insert(name.to_string(), status); } /// Read-only access to the context. 
diff --git a/crates/ruvector-robotics/src/cognitive/cognitive_core.rs b/crates/ruvector-robotics/src/cognitive/cognitive_core.rs index 513b6022a..511730667 100644 --- a/crates/ruvector-robotics/src/cognitive/cognitive_core.rs +++ b/crates/ruvector-robotics/src/cognitive/cognitive_core.rs @@ -155,10 +155,11 @@ impl CognitiveCore { } // Simple heuristic: pick the most confident percept and derive an action. - let best = self - .percept_buffer - .iter() - .max_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap_or(std::cmp::Ordering::Equal))?; + let best = self.percept_buffer.iter().max_by(|a, b| { + a.confidence + .partial_cmp(&b.confidence) + .unwrap_or(std::cmp::Ordering::Equal) + })?; let action_type = if best.data.len() >= 3 { ActionType::Move([best.data[0], best.data[1], best.data[2]]) @@ -176,7 +177,10 @@ impl CognitiveCore { }, confidence: best.confidence, }, - reasoning: format!("Best percept from '{}' (conf={:.2})", best.source, best.confidence), + reasoning: format!( + "Best percept from '{}' (conf={:.2})", + best.source, best.confidence + ), utility: best.confidence, }; @@ -197,11 +201,9 @@ impl CognitiveCore { // Adjust attention threshold based on success/failure. if outcome.success { - self.config.attention_threshold = - (self.config.attention_threshold - 0.01).max(0.1); + self.config.attention_threshold = (self.config.attention_threshold - 0.01).max(0.1); } else { - self.config.attention_threshold = - (self.config.attention_threshold + 0.01).min(0.9); + self.config.attention_threshold = (self.config.attention_threshold + 0.01).min(0.9); } // Clear processed percepts so the next cycle starts fresh. 
diff --git a/crates/ruvector-robotics/src/cognitive/memory_system.rs b/crates/ruvector-robotics/src/cognitive/memory_system.rs index a17005076..31de79092 100644 --- a/crates/ruvector-robotics/src/cognitive/memory_system.rs +++ b/crates/ruvector-robotics/src/cognitive/memory_system.rs @@ -42,16 +42,11 @@ impl WorkingMemory { pub fn add(&mut self, item: MemoryItem) { if self.items.len() >= self.max_size { // Evict least important. - if let Some((idx, _)) = self - .items - .iter() - .enumerate() - .min_by(|(_, a), (_, b)| { - a.importance - .partial_cmp(&b.importance) - .unwrap_or(std::cmp::Ordering::Equal) - }) - { + if let Some((idx, _)) = self.items.iter().enumerate().min_by(|(_, a), (_, b)| { + a.importance + .partial_cmp(&b.importance) + .unwrap_or(std::cmp::Ordering::Equal) + }) { self.items.remove(idx); } } diff --git a/crates/ruvector-robotics/src/domain_expansion.rs b/crates/ruvector-robotics/src/domain_expansion.rs index dcc181f3c..efec4fc15 100644 --- a/crates/ruvector-robotics/src/domain_expansion.rs +++ b/crates/ruvector-robotics/src/domain_expansion.rs @@ -25,7 +25,9 @@ //! graph traversal) directly appear in perception and planning kernels. 
use rand::Rng; -use ruvector_domain_expansion::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task}; +use ruvector_domain_expansion::domain::{ + Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task, +}; use serde::{Deserialize, Serialize}; const EMBEDDING_DIM: usize = 64; @@ -120,7 +122,13 @@ impl RoboticsDomain { // -- task generators --------------------------------------------------- fn gen_clustering(&self, difficulty: f32, rng: &mut impl Rng) -> RoboticsTaskSpec { - let num_clusters = if difficulty < 0.3 { 2 } else if difficulty < 0.7 { 5 } else { 10 }; + let num_clusters = if difficulty < 0.3 { + 2 + } else if difficulty < 0.7 { + 5 + } else { + 10 + }; let pts_per_cluster = if difficulty < 0.3 { 10 } else { 20 }; let spread = if difficulty < 0.5 { 0.5 } else { 2.0 }; @@ -155,15 +163,17 @@ impl RoboticsDomain { } fn gen_avoidance(&self, difficulty: f32, rng: &mut impl Rng) -> RoboticsTaskSpec { - let num_obstacles = if difficulty < 0.3 { 3 } else if difficulty < 0.7 { 8 } else { 15 }; + let num_obstacles = if difficulty < 0.3 { + 3 + } else if difficulty < 0.7 { + 8 + } else { + 15 + }; let mut obstacles = Vec::new(); for _ in 0..num_obstacles { obstacles.push(TaskObstacle { - center: [ - rng.gen_range(1.0..9.0), - rng.gen_range(1.0..9.0), - 0.0, - ], + center: [rng.gen_range(1.0..9.0), rng.gen_range(1.0..9.0), 0.0], radius: rng.gen_range(0.3..1.5), }); } @@ -185,7 +195,13 @@ impl RoboticsDomain { } fn gen_scene_graph(&self, difficulty: f32, rng: &mut impl Rng) -> RoboticsTaskSpec { - let num_objects = if difficulty < 0.3 { 3 } else if difficulty < 0.7 { 8 } else { 15 }; + let num_objects = if difficulty < 0.3 { + 3 + } else if difficulty < 0.7 { + 8 + } else { + 15 + }; let mut obstacles = Vec::new(); for _ in 0..num_objects { obstacles.push(TaskObstacle { @@ -221,8 +237,8 @@ impl RoboticsDomain { vec!["scan", "approach", "align", "grasp", "lift", "place"] } else { vec![ - "scan", "classify", "approach", "align", "grasp", - 
"lift", "navigate", "place", "verify", "retreat", + "scan", "classify", "approach", "align", "grasp", "lift", "navigate", "place", + "verify", "retreat", ] }; @@ -258,15 +274,18 @@ impl RoboticsDomain { } fn gen_swarm_formation(&self, difficulty: f32, _rng: &mut impl Rng) -> RoboticsTaskSpec { - let num_robots = if difficulty < 0.3 { 4 } else if difficulty < 0.7 { 8 } else { 16 }; + let num_robots = if difficulty < 0.3 { + 4 + } else if difficulty < 0.7 { + 8 + } else { + 16 + }; let formation = if difficulty < 0.5 { "circle" } else { "grid" }; RoboticsTaskSpec { category: RoboticsCategory::SwarmFormation, - description: format!( - "Assign {} robots to a {} formation.", - num_robots, formation, - ), + description: format!("Assign {} robots to a {} formation.", num_robots, formation,), size: num_robots, world_bounds: [20.0, 20.0, 1.0], obstacles: Vec::new(), @@ -297,8 +316,16 @@ impl RoboticsDomain { }; let correctness = cluster_accuracy; - let efficiency = if sol.cluster_ids.len() == spec.size { 1.0 } else { 0.5 }; - let elegance = if actual_clusters <= expected_clusters * 2 { 0.8 } else { 0.3 }; + let efficiency = if sol.cluster_ids.len() == spec.size { + 1.0 + } else { + 0.5 + }; + let elegance = if actual_clusters <= expected_clusters * 2 { + 0.8 + } else { + 0.3 + }; if (actual_clusters as i32 - expected_clusters as i32).unsigned_abs() > 2 { notes.push(format!( @@ -342,7 +369,10 @@ impl RoboticsDomain { } let correctness = if reaches_goal { 0.6 } else { 0.2 } - + (1.0 - (collisions as f32 / (sol.waypoints.len() * spec.obstacles.len()).max(1) as f32).min(1.0)) * 0.4; + + (1.0 + - (collisions as f32 / (sol.waypoints.len() * spec.obstacles.len()).max(1) as f32) + .min(1.0)) + * 0.4; let efficiency = 1.0 - (sol.waypoints.len() as f32 / 100.0).min(1.0); let elegance = if collisions == 0 { 0.9 } else { 0.3 }; @@ -439,7 +469,11 @@ impl RoboticsDomain { let dep_penalty = violations as f32 / expected_skills.max(1) as f32; let correctness = (coverage.min(1.0) * (1.0 - 
dep_penalty.min(1.0))).max(0.0); - let efficiency = if sol.skill_sequence.len() <= expected_skills + 2 { 0.9 } else { 0.5 }; + let efficiency = if sol.skill_sequence.len() <= expected_skills + 2 { + 0.9 + } else { + 0.5 + }; let elegance = if violations == 0 { 0.9 } else { 0.3 }; Evaluation { @@ -800,7 +834,11 @@ mod tests { let tasks = domain.generate_tasks(20, 0.5); for task in &tasks { let ref_sol = domain.reference_solution(task); - assert!(ref_sol.is_some(), "Reference solution missing for {}", task.id); + assert!( + ref_sol.is_some(), + "Reference solution missing for {}", + task.id + ); } } diff --git a/crates/ruvector-robotics/src/lib.rs b/crates/ruvector-robotics/src/lib.rs index bc56e8040..403bc3af5 100644 --- a/crates/ruvector-robotics/src/lib.rs +++ b/crates/ruvector-robotics/src/lib.rs @@ -50,12 +50,10 @@ pub mod rvf; // Convenience re-exports of the most commonly used types. pub use bridge::{ - BridgeConfig, DistanceMetric, OccupancyGrid, Obstacle as BridgeObstacle, Point3D, PointCloud, + BridgeConfig, DistanceMetric, Obstacle as BridgeObstacle, OccupancyGrid, Point3D, PointCloud, Pose, Quaternion, RobotState, SceneEdge, SceneGraph, SceneObject, SensorFrame, SpatialIndex, Trajectory, }; pub use cognitive::{BehaviorNode, BehaviorStatus, BehaviorTree, CognitiveCore, CognitiveState}; -pub use perception::{ - ObstacleDetector, PerceptionConfig, PerceptionPipeline, SceneGraphBuilder, -}; +pub use perception::{ObstacleDetector, PerceptionConfig, PerceptionPipeline, SceneGraphBuilder}; pub use planning::{GridPath, VelocityCommand}; diff --git a/crates/ruvector-robotics/src/mcp/executor.rs b/crates/ruvector-robotics/src/mcp/executor.rs index ef83b4ee0..9327f014d 100644 --- a/crates/ruvector-robotics/src/mcp/executor.rs +++ b/crates/ruvector-robotics/src/mcp/executor.rs @@ -157,13 +157,13 @@ impl ToolExecutor { .arguments .get("query") .and_then(|v| v.as_array()) - .map(|a| a.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect()) + .map(|a| { + 
a.iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect() + }) .ok_or("missing 'query'")?; - let k = req - .arguments - .get("k") - .and_then(|v| v.as_u64()) - .unwrap_or(5) as usize; + let k = req.arguments.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize; let results = self .index @@ -196,10 +196,7 @@ impl Default for ToolExecutor { // -- argument parsers ------------------------------------------------------- -fn parse_point_cloud( - req: &ToolRequest, - key: &str, -) -> std::result::Result { +fn parse_point_cloud(req: &ToolRequest, key: &str) -> std::result::Result { let raw = req .arguments .get(key) @@ -212,10 +209,7 @@ fn parse_point_cloud( } } -fn parse_position( - req: &ToolRequest, - key: &str, -) -> std::result::Result<[f64; 3], String> { +fn parse_position(req: &ToolRequest, key: &str) -> std::result::Result<[f64; 3], String> { let arr = req .arguments .get(key) @@ -253,9 +247,11 @@ mod tests { use std::collections::HashMap; fn make_request(tool: &str, args: serde_json::Value) -> ToolRequest { - let arguments: HashMap = - serde_json::from_value(args).unwrap(); - ToolRequest { tool_name: tool.to_string(), arguments } + let arguments: HashMap = serde_json::from_value(args).unwrap(); + ToolRequest { + tool_name: tool.to_string(), + arguments, + } } #[test] @@ -270,10 +266,13 @@ mod tests { 1000, ); let cloud_json = serde_json::to_string(&cloud).unwrap(); - let req = make_request("detect_obstacles", serde_json::json!({ - "point_cloud_json": cloud_json, - "robot_position": [0.0, 0.0, 0.0], - })); + let req = make_request( + "detect_obstacles", + serde_json::json!({ + "point_cloud_json": cloud_json, + "robot_position": [0.0, 0.0, 0.0], + }), + ); let resp = exec.execute(&req); assert!(resp.success); } @@ -281,12 +280,15 @@ mod tests { #[test] fn test_predict_trajectory() { let mut exec = ToolExecutor::new(); - let req = make_request("predict_trajectory", serde_json::json!({ - "position": [0.0, 0.0, 0.0], - "velocity": [1.0, 0.0, 0.0], - 
"steps": 5, - "dt": 0.5, - })); + let req = make_request( + "predict_trajectory", + serde_json::json!({ + "position": [0.0, 0.0, 0.0], + "velocity": [1.0, 0.0, 0.0], + "steps": 5, + "dt": 0.5, + }), + ); let resp = exec.execute(&req); assert!(resp.success); let traj = resp.result; @@ -304,18 +306,24 @@ mod tests { Point3D::new(10.0, 0.0, 0.0), ]; let points_json = serde_json::to_string(&points).unwrap(); - let req = make_request("insert_points", serde_json::json!({ - "points_json": points_json, - })); + let req = make_request( + "insert_points", + serde_json::json!({ + "points_json": points_json, + }), + ); let resp = exec.execute(&req); assert!(resp.success); assert_eq!(resp.result["total"], 3); // Search - let req = make_request("spatial_search", serde_json::json!({ - "query": [1.0, 0.0, 0.0], - "k": 2, - })); + let req = make_request( + "spatial_search", + serde_json::json!({ + "query": [1.0, 0.0, 0.0], + "k": 2, + }), + ); let resp = exec.execute(&req); assert!(resp.success); let results = resp.result.as_array().unwrap(); @@ -339,10 +347,13 @@ mod tests { SceneObject::new(1, [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]), ]; let objects_json = serde_json::to_string(&objects).unwrap(); - let req = make_request("build_scene_graph", serde_json::json!({ - "objects_json": objects_json, - "max_edge_distance": 5.0, - })); + let req = make_request( + "build_scene_graph", + serde_json::json!({ + "objects_json": objects_json, + "max_edge_distance": 5.0, + }), + ); let resp = exec.execute(&req); assert!(resp.success); assert_eq!(resp.result["edges"].as_array().unwrap().len(), 1); diff --git a/crates/ruvector-robotics/src/mcp/mod.rs b/crates/ruvector-robotics/src/mcp/mod.rs index a406d3d9a..a9b19e871 100644 --- a/crates/ruvector-robotics/src/mcp/mod.rs +++ b/crates/ruvector-robotics/src/mcp/mod.rs @@ -159,7 +159,12 @@ pub struct ToolResponse { impl ToolResponse { /// Convenience constructor for a successful response. 
pub fn ok(result: serde_json::Value, latency_us: u64) -> Self { - Self { success: true, result, error: None, latency_us } + Self { + success: true, + result, + error: None, + latency_us, + } } /// Convenience constructor for a failed response. @@ -196,14 +201,18 @@ impl Default for RoboticsToolRegistry { impl RoboticsToolRegistry { /// Create a registry pre-populated with all built-in robotics tools. pub fn new() -> Self { - let mut registry = Self { tools: HashMap::new() }; + let mut registry = Self { + tools: HashMap::new(), + }; registry.register_defaults(); registry } /// Create an empty registry with no tools registered. pub fn empty() -> Self { - Self { tools: HashMap::new() } + Self { + tools: HashMap::new(), + } } /// Register a single tool. Overwrites any existing tool with the same name. @@ -223,7 +232,10 @@ impl RoboticsToolRegistry { /// Return all tools belonging to the given category. pub fn list_by_category(&self, category: ToolCategory) -> Vec<&ToolDefinition> { - self.tools.values().filter(|t| t.category == category).collect() + self.tools + .values() + .filter(|t| t.category == category) + .collect() } /// Produce a full MCP-compatible JSON schema describing every tool. 
@@ -247,13 +259,22 @@ impl RoboticsToolRegistry { "Detect obstacles in a point cloud relative to the robot position", vec![ ToolParameter::new( - "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true, + "point_cloud_json", + "JSON-encoded point cloud", + ParamType::String, + true, ), ToolParameter::new( - "robot_position", "Robot [x,y,z] position", ParamType::Array, true, + "robot_position", + "Robot [x,y,z] position", + ParamType::Array, + true, ), ToolParameter::new( - "max_distance", "Maximum detection distance in meters", ParamType::Number, false, + "max_distance", + "Maximum detection distance in meters", + ParamType::Number, + false, ), ], ToolCategory::Perception, @@ -264,10 +285,16 @@ impl RoboticsToolRegistry { "Build a scene graph from detected objects with spatial edges", vec![ ToolParameter::new( - "objects_json", "JSON array of scene objects", ParamType::String, true, + "objects_json", + "JSON array of scene objects", + ParamType::String, + true, ), ToolParameter::new( - "max_edge_distance", "Maximum edge distance in meters", ParamType::Number, false, + "max_edge_distance", + "Maximum edge distance in meters", + ParamType::Number, + false, ), ], ToolCategory::Perception, @@ -277,9 +304,24 @@ impl RoboticsToolRegistry { "predict_trajectory", "Predict future trajectory from current position and velocity", vec![ - ToolParameter::new("position", "Current [x,y,z] position", ParamType::Array, true), - ToolParameter::new("velocity", "Current [vx,vy,vz] velocity", ParamType::Array, true), - ToolParameter::new("steps", "Number of prediction steps", ParamType::Integer, true), + ToolParameter::new( + "position", + "Current [x,y,z] position", + ParamType::Array, + true, + ), + ToolParameter::new( + "velocity", + "Current [vx,vy,vz] velocity", + ParamType::Array, + true, + ), + ToolParameter::new( + "steps", + "Number of prediction steps", + ParamType::Integer, + true, + ), ToolParameter::new("dt", "Time step in seconds", ParamType::Number, 
false), ], ToolCategory::Navigation, @@ -290,10 +332,18 @@ impl RoboticsToolRegistry { "Extract a region of interest from a point cloud by center and radius", vec![ ToolParameter::new( - "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true, + "point_cloud_json", + "JSON-encoded point cloud", + ParamType::String, + true, ), ToolParameter::new("center", "Focus center [x,y,z]", ParamType::Array, true), - ToolParameter::new("radius", "Attention radius in meters", ParamType::Number, true), + ToolParameter::new( + "radius", + "Attention radius in meters", + ParamType::Number, + true, + ), ], ToolCategory::Perception, )); @@ -301,11 +351,12 @@ impl RoboticsToolRegistry { self.register_tool(ToolDefinition::new( "detect_anomalies", "Detect anomalous points in a point cloud using statistical analysis", - vec![ - ToolParameter::new( - "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true, - ), - ], + vec![ToolParameter::new( + "point_cloud_json", + "JSON-encoded point cloud", + ParamType::String, + true, + )], ToolCategory::Perception, )); @@ -314,7 +365,12 @@ impl RoboticsToolRegistry { "Search for nearest neighbours in the spatial index", vec![ ToolParameter::new("query", "Query vector [x,y,z]", ParamType::Array, true), - ToolParameter::new("k", "Number of neighbours to return", ParamType::Integer, true), + ToolParameter::new( + "k", + "Number of neighbours to return", + ParamType::Integer, + true, + ), ], ToolCategory::Perception, )); @@ -322,11 +378,12 @@ impl RoboticsToolRegistry { self.register_tool(ToolDefinition::new( "insert_points", "Insert points into the spatial index for later retrieval", - vec![ - ToolParameter::new( - "points_json", "JSON array of [x,y,z] points", ParamType::String, true, - ), - ], + vec![ToolParameter::new( + "points_json", + "JSON array of [x,y,z] points", + ParamType::String, + true, + )], ToolCategory::Perception, )); @@ -337,7 +394,10 @@ impl RoboticsToolRegistry { ToolParameter::new("key", "Unique 
memory key", ParamType::String, true), ToolParameter::new("data", "Data vector to store", ParamType::Array, true), ToolParameter::new( - "importance", "Importance weight 0.0-1.0", ParamType::Number, false, + "importance", + "Importance weight 0.0-1.0", + ParamType::Number, + false, ), ], ToolCategory::Memory, @@ -348,9 +408,17 @@ impl RoboticsToolRegistry { "Recall the k most similar memories to a query vector", vec![ ToolParameter::new( - "query", "Query vector for similarity search", ParamType::Array, true, + "query", + "Query vector for similarity search", + ParamType::Array, + true, + ), + ToolParameter::new( + "k", + "Number of memories to recall", + ParamType::Integer, + true, ), - ToolParameter::new("k", "Number of memories to recall", ParamType::Integer, true), ], ToolCategory::Memory, )); @@ -373,9 +441,12 @@ impl RoboticsToolRegistry { self.register_tool(ToolDefinition::new( "execute_skill", "Execute a previously learned skill by name", - vec![ - ToolParameter::new("name", "Name of the skill to execute", ParamType::String, true), - ], + vec![ToolParameter::new( + "name", + "Name of the skill to execute", + ParamType::String, + true, + )], ToolCategory::Cognition, )); @@ -397,33 +468,36 @@ impl RoboticsToolRegistry { self.register_tool(ToolDefinition::new( "coordinate_swarm", "Coordinate a multi-robot swarm for a given task", - vec![ - ToolParameter::new( - "task_json", "JSON-encoded task specification", ParamType::String, true, - ), - ], + vec![ToolParameter::new( + "task_json", + "JSON-encoded task specification", + ParamType::String, + true, + )], ToolCategory::Swarm, )); self.register_tool(ToolDefinition::new( "update_world_model", "Update the internal world model with a new or changed object", - vec![ - ToolParameter::new( - "object_json", "JSON-encoded object to upsert", ParamType::String, true, - ), - ], + vec![ToolParameter::new( + "object_json", + "JSON-encoded object to upsert", + ParamType::String, + true, + )], ToolCategory::Cognition, )); 
self.register_tool(ToolDefinition::new( "get_world_state", "Retrieve the current world model state, optionally filtered by object id", - vec![ - ToolParameter::new( - "object_id", "Optional object id to filter", ParamType::Integer, false, - ), - ], + vec![ToolParameter::new( + "object_id", + "Optional object id to filter", + ParamType::Integer, + false, + )], ToolCategory::Cognition, )); } @@ -477,7 +551,10 @@ mod tests { let tool = registry.get_tool("detect_obstacles").unwrap(); assert_eq!(tool.category, ToolCategory::Perception); assert_eq!(tool.parameters.len(), 3); - assert!(tool.parameters.iter().any(|p| p.name == "point_cloud_json" && p.required)); + assert!(tool + .parameters + .iter() + .any(|p| p.name == "point_cloud_json" && p.required)); let tool = registry.get_tool("predict_trajectory").unwrap(); assert_eq!(tool.category, ToolCategory::Navigation); @@ -541,7 +618,10 @@ mod tests { let schema = registry.to_mcp_schema(); let tools = schema["tools"].as_array().unwrap(); - let obs = tools.iter().find(|t| t["name"] == "detect_obstacles").unwrap(); + let obs = tools + .iter() + .find(|t| t["name"] == "detect_obstacles") + .unwrap(); let required = obs["inputSchema"]["required"].as_array().unwrap(); let req_names: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect(); assert!(req_names.contains(&"point_cloud_json")); @@ -555,7 +635,10 @@ mod tests { args.insert("k".to_string(), serde_json::json!(5)); args.insert("query".to_string(), serde_json::json!([1.0, 2.0, 3.0])); - let req = ToolRequest { tool_name: "spatial_search".to_string(), arguments: args }; + let req = ToolRequest { + tool_name: "spatial_search".to_string(), + arguments: args, + }; let json = serde_json::to_string(&req).unwrap(); let deserialized: ToolRequest = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized.tool_name, "spatial_search"); @@ -591,7 +674,12 @@ mod tests { let custom = ToolDefinition::new( "my_custom_tool", "A custom tool for testing", - 
vec![ToolParameter::new("input", "The input data", ParamType::String, true)], + vec![ToolParameter::new( + "input", + "The input data", + ParamType::String, + true, + )], ToolCategory::Cognition, ); registry.register_tool(custom); diff --git a/crates/ruvector-robotics/src/perception/clustering.rs b/crates/ruvector-robotics/src/perception/clustering.rs index fb0381c73..32ab9daa9 100644 --- a/crates/ruvector-robotics/src/perception/clustering.rs +++ b/crates/ruvector-robotics/src/perception/clustering.rs @@ -32,11 +32,8 @@ pub fn cluster_point_cloud(cloud: &PointCloud, cell_size: f64) -> Vec = cell_map.keys().copied().collect(); let cell_count = cells.len(); - let cell_idx: HashMap<(i64, i64, i64), usize> = cells - .iter() - .enumerate() - .map(|(i, &k)| (k, i)) - .collect(); + let cell_idx: HashMap<(i64, i64, i64), usize> = + cells.iter().enumerate().map(|(i, &k)| (k, i)).collect(); let mut parent: Vec = (0..cell_count).collect(); let mut rank: Vec = vec![0; cell_count]; @@ -124,11 +121,7 @@ mod tests { #[test] fn test_single_cluster() { - let cloud = make_cloud(&[ - [1.0, 1.0, 0.0], - [1.1, 1.0, 0.0], - [1.0, 1.1, 0.0], - ]); + let cloud = make_cloud(&[[1.0, 1.0, 0.0], [1.1, 1.0, 0.0], [1.0, 1.1, 0.0]]); let clusters = cluster_point_cloud(&cloud, 0.5); assert_eq!(clusters.len(), 1); assert_eq!(clusters[0].len(), 3); @@ -148,11 +141,7 @@ mod tests { #[test] fn test_negative_coordinates() { - let cloud = make_cloud(&[ - [-1.0, -1.0, 0.0], - [-0.9, -1.0, 0.0], - [1.0, 1.0, 0.0], - ]); + let cloud = make_cloud(&[[-1.0, -1.0, 0.0], [-0.9, -1.0, 0.0], [1.0, 1.0, 0.0]]); let clusters = cluster_point_cloud(&cloud, 0.5); assert_eq!(clusters.len(), 2); } diff --git a/crates/ruvector-robotics/src/perception/mod.rs b/crates/ruvector-robotics/src/perception/mod.rs index 5aa49740e..21f21ebd7 100644 --- a/crates/ruvector-robotics/src/perception/mod.rs +++ b/crates/ruvector-robotics/src/perception/mod.rs @@ -10,7 +10,9 @@ pub mod scene_graph; pub mod sensor_fusion; pub use 
config::{ObstacleConfig, PerceptionConfig, SceneGraphConfig}; -pub use obstacle_detector::{ClassifiedObstacle, DetectedObstacle, ObstacleClass, ObstacleDetector}; +pub use obstacle_detector::{ + ClassifiedObstacle, DetectedObstacle, ObstacleClass, ObstacleDetector, +}; pub use scene_graph::PointCloudSceneGraphBuilder; use serde::{Deserialize, Serialize}; @@ -199,10 +201,8 @@ impl PerceptionPipeline { let obstacle_cfg = ObstacleConfig::default(); let scene_cfg = SceneGraphConfig::default(); let detector = ObstacleDetector::new(obstacle_cfg.clone()); - let graph_builder = SceneGraphBuilder::new( - scene_cfg.edge_distance_threshold, - scene_cfg.max_objects, - ); + let graph_builder = + SceneGraphBuilder::new(scene_cfg.edge_distance_threshold, scene_cfg.max_objects); Self { detector, graph_builder, @@ -221,15 +221,14 @@ impl PerceptionPipeline { ) -> (Vec, SceneGraph) { self.frames_processed += 1; let obstacles = self.detector.detect(cloud, robot_pos); - let graph = self.graph_builder.build_from_obstacles(&obstacles, cloud.timestamp_us); + let graph = self + .graph_builder + .build_from_obstacles(&obstacles, cloud.timestamp_us); (obstacles, graph) } /// Classify previously detected obstacles. 
- pub fn classify( - &self, - obstacles: &[DetectedObstacle], - ) -> Vec { + pub fn classify(&self, obstacles: &[DetectedObstacle]) -> Vec { self.detector.classify_obstacles(obstacles) } @@ -272,8 +271,7 @@ impl PerceptionPipeline { continue; } - let confidence = (cluster.len() as f32 / cloud.points.len() as f32) - .clamp(0.1, 1.0); + let confidence = (cluster.len() as f32 / cloud.points.len() as f32).clamp(0.1, 1.0); obstacles.push(Obstacle { id: next_id, @@ -487,7 +485,10 @@ impl PerceptionPipeline { // -- private helpers ---------------------------------------------------- fn bounding_sphere(points: &[Point3D]) -> ([f64; 3], f64) { - debug_assert!(!points.is_empty(), "bounding_sphere called with empty slice"); + debug_assert!( + !points.is_empty(), + "bounding_sphere called with empty slice" + ); let n = points.len() as f64; let (mut sx, mut sy, mut sz) = (0.0_f64, 0.0_f64, 0.0_f64); for p in points { @@ -527,8 +528,7 @@ mod tests { use crate::bridge::Point3D; fn make_cloud(pts: &[[f32; 3]]) -> PointCloud { - let points: Vec = - pts.iter().map(|a| Point3D::new(a[0], a[1], a[2])).collect(); + let points: Vec = pts.iter().map(|a| Point3D::new(a[0], a[1], a[2])).collect(); PointCloud::new(points, 1000) } @@ -603,11 +603,7 @@ mod tests { #[test] fn test_detect_obstacles_single_cluster() { let pipe = PerceptionPipeline::with_thresholds(1.0, 2.0); - let cloud = make_cloud(&[ - [2.0, 0.0, 0.0], - [2.1, 0.0, 0.0], - [2.0, 0.1, 0.0], - ]); + let cloud = make_cloud(&[[2.0, 0.0, 0.0], [2.1, 0.0, 0.0], [2.0, 0.1, 0.0]]); let obs = pipe.detect_obstacles(&cloud, [0.0; 3], 10.0).unwrap(); assert_eq!(obs.len(), 1); assert!(obs[0].distance > 1.0); @@ -618,11 +614,7 @@ mod tests { #[test] fn test_detect_obstacles_filters_distant() { let pipe = PerceptionPipeline::with_thresholds(1.0, 2.0); - let cloud = make_cloud(&[ - [50.0, 0.0, 0.0], - [50.1, 0.0, 0.0], - [50.0, 0.1, 0.0], - ]); + let cloud = make_cloud(&[[50.0, 0.0, 0.0], [50.1, 0.0, 0.0], [50.0, 0.1, 0.0]]); let obs = 
pipe.detect_obstacles(&cloud, [0.0; 3], 5.0).unwrap(); assert!(obs.is_empty()); } @@ -692,14 +684,8 @@ mod tests { #[test] fn test_focus_attention_filters() { let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0); - let cloud = make_cloud(&[ - [0.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - [10.0, 0.0, 0.0], - ]); - let focused = pipe - .focus_attention(&cloud, [0.0, 0.0, 0.0], 2.0) - .unwrap(); + let cloud = make_cloud(&[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]]); + let focused = pipe.focus_attention(&cloud, [0.0, 0.0, 0.0], 2.0).unwrap(); assert_eq!(focused.len(), 2); } @@ -716,8 +702,7 @@ mod tests { #[test] fn test_detect_anomalies_outlier() { let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0); - let mut pts: Vec<[f32; 3]> = - (0..20).map(|i| [i as f32 * 0.1, 0.0, 0.0]).collect(); + let mut pts: Vec<[f32; 3]> = (0..20).map(|i| [i as f32 * 0.1, 0.0, 0.0]).collect(); pts.push([100.0, 100.0, 100.0]); let cloud = make_cloud(&pts); let anomalies = pipe.detect_anomalies(&cloud).unwrap(); @@ -728,11 +713,7 @@ mod tests { #[test] fn test_detect_anomalies_no_outliers() { let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0); - let cloud = make_cloud(&[ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ]); + let cloud = make_cloud(&[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); let anomalies = pipe.detect_anomalies(&cloud).unwrap(); assert!(anomalies.is_empty()); } diff --git a/crates/ruvector-robotics/src/perception/obstacle_detector.rs b/crates/ruvector-robotics/src/perception/obstacle_detector.rs index ed8ebe02a..e0270fab3 100644 --- a/crates/ruvector-robotics/src/perception/obstacle_detector.rs +++ b/crates/ruvector-robotics/src/perception/obstacle_detector.rs @@ -69,11 +69,7 @@ impl ObstacleDetector { /// 4. Compute bounding box and centroid per cluster. /// 5. Filter by `max_detection_range` from the robot. /// 6. Sort results by distance (ascending). 
- pub fn detect( - &self, - cloud: &PointCloud, - robot_pos: &[f64; 3], - ) -> Vec { + pub fn detect(&self, cloud: &PointCloud, robot_pos: &[f64; 3]) -> Vec { if cloud.is_empty() { return Vec::new(); } @@ -104,10 +100,7 @@ impl ObstacleDetector { /// (wall-like). /// * **Dynamic** -- the largest-to-smallest ratio is <= 2 (compact). /// * **Unknown** -- everything else. - pub fn classify_obstacles( - &self, - obstacles: &[DetectedObstacle], - ) -> Vec { + pub fn classify_obstacles(&self, obstacles: &[DetectedObstacle]) -> Vec { obstacles .iter() .map(|o| { @@ -243,11 +236,7 @@ mod tests { safety_margin: 0.1, }); // Cluster at ~10 units away -- should be filtered out. - let cloud = make_cloud(&[ - [10.0, 0.0, 0.0], - [10.1, 0.0, 0.0], - [10.0, 0.1, 0.0], - ]); + let cloud = make_cloud(&[[10.0, 0.0, 0.0], [10.1, 0.0, 0.0], [10.0, 0.1, 0.0]]); let result = det.detect(&cloud, &[0.0, 0.0, 0.0]); assert!(result.is_empty()); } @@ -260,11 +249,7 @@ mod tests { safety_margin: 0.1, }); // Only 3 points -- below minimum. 
- let cloud = make_cloud(&[ - [1.0, 1.0, 0.0], - [1.1, 1.0, 0.0], - [1.0, 1.1, 0.0], - ]); + let cloud = make_cloud(&[[1.0, 1.0, 0.0], [1.1, 1.0, 0.0], [1.0, 1.1, 0.0]]); let result = det.detect(&cloud, &[0.0, 0.0, 0.0]); assert!(result.is_empty()); } diff --git a/crates/ruvector-robotics/src/perception/scene_graph.rs b/crates/ruvector-robotics/src/perception/scene_graph.rs index 6227aa1ec..1b74b755d 100644 --- a/crates/ruvector-robotics/src/perception/scene_graph.rs +++ b/crates/ruvector-robotics/src/perception/scene_graph.rs @@ -87,10 +87,8 @@ impl PointCloudSceneGraphBuilder { let mut objects: Vec = seen_ids.into_values().collect(); objects.sort_by(|a, b| a.id.cmp(&b.id)); - let truncated: Vec = objects - .into_iter() - .take(self.config.max_objects) - .collect(); + let truncated: Vec = + objects.into_iter().take(self.config.max_objects).collect(); let edges = self.create_edges(&truncated); SceneGraph::new(truncated, edges, latest_ts) @@ -99,7 +97,10 @@ impl PointCloudSceneGraphBuilder { // -- private helpers ---------------------------------------------------- fn cluster_to_object(id: usize, points: &[Point3D]) -> SceneObject { - debug_assert!(!points.is_empty(), "cluster_to_object called with empty slice"); + debug_assert!( + !points.is_empty(), + "cluster_to_object called with empty slice" + ); let (mut min_x, mut min_y, mut min_z) = (f64::MAX, f64::MAX, f64::MAX); let (mut max_x, mut max_y, mut max_z) = (f64::MIN, f64::MIN, f64::MIN); let (mut sum_x, mut sum_y, mut sum_z) = (0.0_f64, 0.0_f64, 0.0_f64); @@ -334,7 +335,7 @@ mod tests { let objects = vec![ SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), - SceneObject::new(1, [5.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // ~5 < 9.9 => adjacent + SceneObject::new(1, [5.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // ~5 < 9.9 => adjacent SceneObject::new(2, [15.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // ~15 < 19.8 => near SceneObject::new(3, [25.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // ~25 < 30 => far ]; @@ -342,10 +343,7 @@ mod tests { let 
graph = builder.build_from_objects(&objects); // Check that adjacent relation exists for objects 0 and 1. - let edge_0_1 = graph - .edges - .iter() - .find(|e| e.from == 0 && e.to == 1); + let edge_0_1 = graph.edges.iter().find(|e| e.from == 0 && e.to == 1); assert!(edge_0_1.is_some()); assert_eq!(edge_0_1.unwrap().relation, "adjacent"); } @@ -359,11 +357,7 @@ mod tests { edge_distance_threshold: 100.0, }); - let cloud = make_cloud(&[ - [0.0, 0.0, 0.0], - [50.0, 0.0, 0.0], - [100.0, 0.0, 0.0], - ]); + let cloud = make_cloud(&[[0.0, 0.0, 0.0], [50.0, 0.0, 0.0], [100.0, 0.0, 0.0]]); let graph = builder.build_from_point_cloud(&cloud); // min_cluster_size=1, so each point is its own cluster. diff --git a/crates/ruvector-robotics/src/perception/sensor_fusion.rs b/crates/ruvector-robotics/src/perception/sensor_fusion.rs index 84c7db0f5..342b3cfbc 100644 --- a/crates/ruvector-robotics/src/perception/sensor_fusion.rs +++ b/crates/ruvector-robotics/src/perception/sensor_fusion.rs @@ -150,7 +150,10 @@ mod tests { fn test_fuse_skips_stale() { let c1 = make_cloud(&[[1.0, 0.0, 0.0]], 0); let c2 = make_cloud(&[[2.0, 0.0, 0.0]], 100_000); // 100ms apart - let config = FusionConfig { max_time_delta_us: 50_000, ..Default::default() }; + let config = FusionConfig { + max_time_delta_us: 50_000, + ..Default::default() + }; let result = fuse_clouds(&[c1, c2], &config); assert_eq!(result.len(), 1); // c2 skipped } @@ -159,12 +162,16 @@ mod tests { fn test_voxel_downsample() { let c1 = make_cloud( &[ - [0.0, 0.0, 0.0], [0.01, 0.01, 0.01], // same voxel - [5.0, 5.0, 5.0], // different voxel + [0.0, 0.0, 0.0], + [0.01, 0.01, 0.01], // same voxel + [5.0, 5.0, 5.0], // different voxel ], 0, ); - let config = FusionConfig { voxel_size: 1.0, ..Default::default() }; + let config = FusionConfig { + voxel_size: 1.0, + ..Default::default() + }; let result = fuse_clouds(&[c1], &config); assert_eq!(result.len(), 2); } @@ -172,7 +179,10 @@ mod tests { #[test] fn test_density_weighting() { let c1 = 
make_cloud(&[[1.0, 0.0, 0.0]], 0); - let config = FusionConfig { density_weighting: true, ..Default::default() }; + let config = FusionConfig { + density_weighting: true, + ..Default::default() + }; let result = fuse_clouds(&[c1], &config); assert_eq!(result.len(), 1); // With 1 point, weight = 1/sqrt(1) = 1.0, so intensity unchanged. diff --git a/crates/ruvector-robotics/src/planning.rs b/crates/ruvector-robotics/src/planning.rs index 28cc03b55..ba3eaaa4c 100644 --- a/crates/ruvector-robotics/src/planning.rs +++ b/crates/ruvector-robotics/src/planning.rs @@ -66,11 +66,7 @@ impl PartialOrd for AStarEntry { /// `goal`. Cells with occupancy >= 0.5 are treated as impassable. /// /// Diagonal moves cost √2, cardinal moves cost 1. -pub fn astar( - grid: &OccupancyGrid, - start: Cell, - goal: Cell, -) -> Result { +pub fn astar(grid: &OccupancyGrid, start: Cell, goal: Cell) -> Result { if !cell_free(grid, start) { return Err(PlanningError::InvalidStart(start.0, start.1)); } @@ -78,7 +74,10 @@ pub fn astar( return Err(PlanningError::InvalidGoal(goal.0, goal.1)); } if start == goal { - return Ok(GridPath { cells: vec![start], cost: 0.0 }); + return Ok(GridPath { + cells: vec![start], + cost: 0.0, + }); } let mut g_score: HashMap = HashMap::with_capacity(128); @@ -88,7 +87,10 @@ pub fn astar( let mut neighbor_buf: Vec<(usize, usize, f64)> = Vec::with_capacity(8); g_score.insert(start, 0.0); - open.push(AStarEntry { cell: start, f: heuristic(start, goal) }); + open.push(AStarEntry { + cell: start, + f: heuristic(start, goal), + }); while let Some(AStarEntry { cell, .. 
}) = open.pop() { if cell == goal { @@ -241,8 +243,8 @@ pub fn potential_field( let dist = (dx * dx + dy * dy + dz * dz).sqrt().max(0.01); if dist < config.obstacle_influence { - let strength = - config.repulsive_gain * (1.0 / dist - 1.0 / config.obstacle_influence) / (dist * dist); + let strength = config.repulsive_gain * (1.0 / dist - 1.0 / config.obstacle_influence) + / (dist * dist); fx += strength * dx / dist; fy += strength * dy / dist; fz += strength * dz / dist; @@ -258,7 +260,11 @@ pub fn potential_field( fz *= s; } - VelocityCommand { vx: fx, vy: fy, vz: fz } + VelocityCommand { + vx: fx, + vy: fy, + vz: fz, + } } // --------------------------------------------------------------------------- @@ -377,13 +383,11 @@ mod tests { #[test] fn test_potential_field_max_speed() { - let config = PotentialFieldConfig { max_speed: 1.0, ..Default::default() }; - let cmd = potential_field( - &[0.0, 0.0, 0.0], - &[100.0, 100.0, 0.0], - &[], - &config, - ); + let config = PotentialFieldConfig { + max_speed: 1.0, + ..Default::default() + }; + let cmd = potential_field(&[0.0, 0.0, 0.0], &[100.0, 100.0, 0.0], &[], &config); let speed = (cmd.vx * cmd.vx + cmd.vy * cmd.vy + cmd.vz * cmd.vz).sqrt(); assert!((speed - 1.0).abs() < 1e-9); } diff --git a/crates/ruvector-robotics/src/rvf.rs b/crates/ruvector-robotics/src/rvf.rs index 4f58cdd65..a0858d501 100644 --- a/crates/ruvector-robotics/src/rvf.rs +++ b/crates/ruvector-robotics/src/rvf.rs @@ -22,14 +22,10 @@ use std::path::Path; use rvf_runtime::options::DistanceMetric; -use rvf_runtime::{ - IngestResult, QueryOptions, RvfOptions, RvfStore, SearchResult, -}; +use rvf_runtime::{IngestResult, QueryOptions, RvfOptions, RvfStore, SearchResult}; -use crate::bridge::{ - GaussianConfig, Obstacle, PointCloud, SceneGraph, SceneObject, Trajectory, -}; use crate::bridge::gaussian::{gaussians_from_cloud, GaussianSplatCloud}; +use crate::bridge::{GaussianConfig, Obstacle, PointCloud, SceneGraph, SceneObject, Trajectory}; // 
--------------------------------------------------------------------------- // Errors @@ -95,21 +91,33 @@ impl RoboticsRvf { ..Default::default() }; let store = RvfStore::create(path.as_ref(), options)?; - Ok(Self { store, dimension, next_id: 1 }) + Ok(Self { + store, + dimension, + next_id: 1, + }) } /// Open an existing `.rvf` file for read-write access. pub fn open>(path: P) -> Result { let store = RvfStore::open(path.as_ref())?; let dim = store.dimension(); - Ok(Self { store, dimension: dim, next_id: 1_000_000 }) + Ok(Self { + store, + dimension: dim, + next_id: 1_000_000, + }) } /// Open an existing `.rvf` file for read-only queries. pub fn open_readonly>(path: P) -> Result { let store = RvfStore::open_readonly(path.as_ref())?; let dim = store.dimension(); - Ok(Self { store, dimension: dim, next_id: 0 }) + Ok(Self { + store, + dimension: dim, + next_id: 0, + }) } /// Current store status. @@ -131,11 +139,7 @@ impl RoboticsRvf { return Err(RvfPackError::EmptyData("point cloud is empty")); } - let vectors: Vec> = cloud - .points - .iter() - .map(|p| vec![p.x, p.y, p.z]) - .collect(); + let vectors: Vec> = cloud.points.iter().map(|p| vec![p.x, p.y, p.z]).collect(); let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect(); let ids: Vec = (0..cloud.len()) .map(|_| { @@ -294,11 +298,7 @@ impl RoboticsRvf { // -- querying --------------------------------------------------------- /// Query the store for the `k` nearest vectors to `query`. 
- pub fn query_nearest( - &self, - query: &[f32], - k: usize, - ) -> Result> { + pub fn query_nearest(&self, query: &[f32], k: usize) -> Result> { if query.len() != self.dimension as usize { return Err(RvfPackError::DimensionMismatch { expected: self.dimension as usize, @@ -390,7 +390,9 @@ mod tests { let result = rvf.pack_scene_objects(&objects).unwrap(); assert_eq!(result.accepted, 2); - let hits = rvf.query_nearest(&[1.0, 2.0, 0.0, 0.5, 0.5, 1.8, 1.0, 0.0, 0.0], 1).unwrap(); + let hits = rvf + .query_nearest(&[1.0, 2.0, 0.0, 0.5, 0.5, 1.8, 1.0, 0.0, 0.0], 1) + .unwrap(); assert_eq!(hits.len(), 1); rvf.close().unwrap(); @@ -430,7 +432,10 @@ mod tests { ], 1000, ); - let config = GaussianConfig { min_cluster_size: 3, ..Default::default() }; + let config = GaussianConfig { + min_cluster_size: 3, + ..Default::default() + }; let (splat_cloud, result) = rvf.pack_gaussians(&cloud, &config).unwrap(); assert!(!splat_cloud.is_empty()); assert!(result.accepted > 0); @@ -444,16 +449,14 @@ mod tests { let path = tmp_path(); let mut rvf = RoboticsRvf::create(&path, 6).unwrap(); - let obstacles = vec![ - Obstacle { - id: 0, - position: [2.0, 0.0, 0.0], - distance: 2.0, - radius: 0.5, - label: "person".into(), - confidence: 0.9, - }, - ]; + let obstacles = vec![Obstacle { + id: 0, + position: [2.0, 0.0, 0.0], + distance: 2.0, + radius: 0.5, + label: "person".into(), + confidence: 0.9, + }]; let result = rvf.pack_obstacles(&obstacles).unwrap(); assert_eq!(result.accepted, 1); diff --git a/crates/ruvector-robotics/tests/integration.rs b/crates/ruvector-robotics/tests/integration.rs index b4f493f2d..04642b7aa 100644 --- a/crates/ruvector-robotics/tests/integration.rs +++ b/crates/ruvector-robotics/tests/integration.rs @@ -8,9 +8,9 @@ use ruvector_robotics::bridge::{ }; use ruvector_robotics::cognitive::{ BehaviorNode, BehaviorStatus, BehaviorTree, CognitiveConfig, CognitiveCore, CognitiveMode, - Demonstration, EpisodicMemory, Episode, Formation, FormationType, MemoryItem, Outcome, 
- Percept, RobotCapabilities, SkillLibrary, SwarmConfig, SwarmCoordinator, SwarmTask, - TrackedObject, WorkingMemory, WorldModel, + Demonstration, Episode, EpisodicMemory, Formation, FormationType, MemoryItem, Outcome, Percept, + RobotCapabilities, SkillLibrary, SwarmConfig, SwarmCoordinator, SwarmTask, TrackedObject, + WorkingMemory, WorldModel, }; use ruvector_robotics::mcp::{RoboticsToolRegistry, ToolCategory}; use ruvector_robotics::perception::{PerceptionConfig, PerceptionPipeline}; diff --git a/crates/ruvector-robotics/tests/robotics_integration.rs b/crates/ruvector-robotics/tests/robotics_integration.rs index 755249071..d9b4d94d9 100644 --- a/crates/ruvector-robotics/tests/robotics_integration.rs +++ b/crates/ruvector-robotics/tests/robotics_integration.rs @@ -9,22 +9,20 @@ use std::sync::{Arc, Mutex}; use std::time::Instant; -use ruvector_robotics::bridge::{ - GaussianConfig, Point3D, PointCloud, SceneObject, SpatialIndex, -}; use ruvector_robotics::bridge::gaussian::gaussians_from_cloud; +use ruvector_robotics::bridge::{GaussianConfig, Point3D, PointCloud, SceneObject, SpatialIndex}; use ruvector_robotics::cognitive::behavior_tree::{ BehaviorNode, BehaviorStatus, BehaviorTree, DecoratorType, }; use ruvector_robotics::cognitive::{ - ActionOption, DecisionConfig, DecisionEngine, Demonstration, EpisodicMemory, - Episode, SkillLibrary, SwarmConfig, SwarmCoordinator, SwarmTask, RobotCapabilities, - TrackedObject, WorldModel, + ActionOption, DecisionConfig, DecisionEngine, Demonstration, Episode, EpisodicMemory, + RobotCapabilities, SkillLibrary, SwarmConfig, SwarmCoordinator, SwarmTask, TrackedObject, + WorldModel, }; -use ruvector_robotics::mcp::{RoboticsToolRegistry, ToolRequest}; use ruvector_robotics::mcp::executor::ToolExecutor; -use ruvector_robotics::perception::PerceptionPipeline; +use ruvector_robotics::mcp::{RoboticsToolRegistry, ToolRequest}; use ruvector_robotics::perception::sensor_fusion::{fuse_clouds, FusionConfig}; +use 
ruvector_robotics::perception::PerceptionPipeline; use ruvector_robotics::planning; // --------------------------------------------------------------------------- @@ -59,13 +57,19 @@ fn generate_point_cloud_around(center: Point3D, n: usize, spread: f32) -> Vec 1 { - assert!(!graph.edges.is_empty(), "Should have edges between nearby objects"); + assert!( + !graph.edges.is_empty(), + "Should have edges between nearby objects" + ); } // Predict trajectory using the real API @@ -93,18 +100,45 @@ fn test_cognitive_perceive_think_act() { let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0); let mut points = generate_point_cloud_around(Point3D::new(6.0, 5.0, 0.0), 20, 0.3); - points.extend(generate_point_cloud_around(Point3D::new(10.0, 10.0, 0.0), 15, 0.3)); + points.extend(generate_point_cloud_around( + Point3D::new(10.0, 10.0, 0.0), + 15, + 0.3, + )); let cloud = PointCloud::new(points, 0); - let obstacles = pipe.detect_obstacles(&cloud, [5.0, 5.0, 0.0], 20.0).unwrap(); - let min_dist = obstacles.iter().map(|o| o.distance).fold(f64::MAX, f64::min); + let obstacles = pipe + .detect_obstacles(&cloud, [5.0, 5.0, 0.0], 20.0) + .unwrap(); + let min_dist = obstacles + .iter() + .map(|o| o.distance) + .fold(f64::MAX, f64::min); // Use the real DecisionEngine let engine = DecisionEngine::new(DecisionConfig::default()); let options = vec![ - ActionOption { name: "proceed_fast".into(), reward: 8.0, risk: 0.9, energy_cost: 0.5, novelty: 0.0 }, - ActionOption { name: "slow_down".into(), reward: 5.0, risk: 0.2, energy_cost: 0.3, novelty: 0.0 }, - ActionOption { name: "stop".into(), reward: 2.0, risk: 0.0, energy_cost: 0.0, novelty: 0.0 }, + ActionOption { + name: "proceed_fast".into(), + reward: 8.0, + risk: 0.9, + energy_cost: 0.5, + novelty: 0.0, + }, + ActionOption { + name: "slow_down".into(), + reward: 5.0, + risk: 0.2, + energy_cost: 0.3, + novelty: 0.0, + }, + ActionOption { + name: "stop".into(), + reward: 2.0, + risk: 0.0, + energy_cost: 0.0, + novelty: 0.0, + }, ]; let 
(best_idx, _utility) = engine.evaluate(&options).unwrap(); assert!(!options[best_idx].name.is_empty()); @@ -229,12 +263,22 @@ fn test_skill_learning_from_demo() { let demos = vec![ Demonstration { - trajectory: vec![[0.0, 0.0, 0.0], [0.5, 0.0, 0.2], [1.0, 0.0, 0.5], [1.0, 0.0, 0.0]], + trajectory: vec![ + [0.0, 0.0, 0.0], + [0.5, 0.0, 0.2], + [1.0, 0.0, 0.5], + [1.0, 0.0, 0.0], + ], timestamps: vec![0, 100, 200, 300], metadata: "demo1".into(), }, Demonstration { - trajectory: vec![[0.0, 0.1, 0.0], [0.5, 0.1, 0.25], [1.0, 0.1, 0.55], [1.0, 0.1, 0.0]], + trajectory: vec![ + [0.0, 0.1, 0.0], + [0.5, 0.1, 0.25], + [1.0, 0.1, 0.55], + [1.0, 0.1, 0.0], + ], timestamps: vec![0, 100, 200, 300], metadata: "demo2".into(), }, @@ -272,9 +316,27 @@ fn test_decision_engine_selects_best() { curiosity_weight: 0.0, }); let options = vec![ - ActionOption { name: "proceed_fast".into(), reward: 8.0, risk: 0.7, energy_cost: 0.5, novelty: 0.0 }, - ActionOption { name: "detour".into(), reward: 7.0, risk: 0.3, energy_cost: 0.3, novelty: 0.0 }, - ActionOption { name: "stop".into(), reward: 3.0, risk: 0.0, energy_cost: 0.0, novelty: 0.0 }, + ActionOption { + name: "proceed_fast".into(), + reward: 8.0, + risk: 0.7, + energy_cost: 0.5, + novelty: 0.0, + }, + ActionOption { + name: "detour".into(), + reward: 7.0, + risk: 0.3, + energy_cost: 0.3, + novelty: 0.0, + }, + ActionOption { + name: "stop".into(), + reward: 3.0, + risk: 0.0, + energy_cost: 0.0, + novelty: 0.0, + }, ]; let (best_idx, _) = engine.evaluate(&options).unwrap(); // With risk_aversion=1, proceed_fast (reward=8 - risk*1=7.3) beats detour (7 - 0.3=6.7) @@ -298,13 +360,28 @@ fn test_mcp_tool_listing() { assert_eq!(registry.list_tools().len(), 15); let expected = [ - "detect_obstacles", "build_scene_graph", "predict_trajectory", - "focus_attention", "detect_anomalies", "spatial_search", "insert_points", - "store_memory", "recall_memory", "learn_skill", "execute_skill", - "plan_behavior", "coordinate_swarm", "update_world_model", 
"get_world_state", + "detect_obstacles", + "build_scene_graph", + "predict_trajectory", + "focus_attention", + "detect_anomalies", + "spatial_search", + "insert_points", + "store_memory", + "recall_memory", + "learn_skill", + "execute_skill", + "plan_behavior", + "coordinate_swarm", + "update_world_model", + "get_world_state", ]; for name in &expected { - assert!(registry.get_tool(name).is_some(), "Tool '{}' should be registered", name); + assert!( + registry.get_tool(name).is_some(), + "Tool '{}' should be registered", + name + ); } } @@ -321,13 +398,21 @@ fn test_mcp_tool_execution() { ("velocity".into(), serde_json::json!([1.0, 0.0, 0.0])), ("steps".into(), serde_json::json!(5)), ("dt".into(), serde_json::json!(0.5)), - ].into(), + ] + .into(), }; let resp = executor.execute(&req); - assert!(resp.success, "predict_trajectory should succeed: {:?}", resp.error); + assert!( + resp.success, + "predict_trajectory should succeed: {:?}", + resp.error + ); // Unknown tool - let req = ToolRequest { tool_name: "nonexistent".into(), arguments: Default::default() }; + let req = ToolRequest { + tool_name: "nonexistent".into(), + arguments: Default::default(), + }; let resp = executor.execute(&req); assert!(!resp.success, "nonexistent tool should fail"); } @@ -336,11 +421,19 @@ fn test_mcp_tool_execution() { #[test] fn test_gaussian_splatting() { let mut points = generate_point_cloud_around(Point3D::new(2.0, 0.0, 0.0), 30, 0.5); - points.extend(generate_point_cloud_around(Point3D::new(8.0, 0.0, 0.0), 30, 0.5)); + points.extend(generate_point_cloud_around( + Point3D::new(8.0, 0.0, 0.0), + 30, + 0.5, + )); let cloud = PointCloud::new(points, 1000); let gaussians = gaussians_from_cloud(&cloud, &GaussianConfig::default()); - assert!(gaussians.len() >= 2, "Should produce at least 2 Gaussians, got {}", gaussians.len()); + assert!( + gaussians.len() >= 2, + "Should produce at least 2 Gaussians, got {}", + gaussians.len() + ); for g in &gaussians.gaussians { assert!(g.point_count >= 
2); @@ -366,11 +459,19 @@ fn test_astar_planning() { let path = planning::astar(&grid, (5, 5), (15, 5)).unwrap(); assert_eq!(*path.cells.first().unwrap(), (5, 5)); assert_eq!(*path.cells.last().unwrap(), (15, 5)); - assert!(path.cost > 10.0, "Path around wall should be longer than straight line"); + assert!( + path.cost > 10.0, + "Path around wall should be longer than straight line" + ); // Verify no cell in path is occupied for &(x, y) in &path.cells { - assert!(grid.get(x, y).unwrap() < 0.5, "Path cell ({},{}) is occupied", x, y); + assert!( + grid.get(x, y).unwrap() < 0.5, + "Path cell ({},{}) is occupied", + x, + y + ); } } @@ -403,16 +504,16 @@ fn test_sensor_fusion() { vec![Point3D::new(1.0, 0.0, 0.0), Point3D::new(2.0, 0.0, 0.0)], 1000, ); - let c2 = PointCloud::new( - vec![Point3D::new(3.0, 0.0, 0.0)], - 1010, - ); + let c2 = PointCloud::new(vec![Point3D::new(3.0, 0.0, 0.0)], 1010); let c3_stale = PointCloud::new( vec![Point3D::new(99.0, 0.0, 0.0)], 200_000, // 199ms later — too stale ); - let config = FusionConfig { max_time_delta_us: 50_000, ..Default::default() }; + let config = FusionConfig { + max_time_delta_us: 50_000, + ..Default::default() + }; let fused = fuse_clouds(&[c1, c2, c3_stale], &config); assert_eq!(fused.len(), 3, "Should include c1+c2 but skip c3"); } @@ -440,8 +541,15 @@ fn test_full_pipeline_100_frames() { } let elapsed = start.elapsed(); - assert!(total_obstacles > 0, "Should detect obstacles across 100 frames"); - assert!(elapsed.as_secs() < 5, "100 frames should complete in < 5s, took {:?}", elapsed); + assert!( + total_obstacles > 0, + "Should detect obstacles across 100 frames" + ); + assert!( + elapsed.as_secs() < 5, + "100 frames should complete in < 5s, took {:?}", + elapsed + ); } /// Test 17: Concurrent spatial search from multiple threads. 
@@ -482,9 +590,18 @@ fn test_concurrent_spatial_search() { let final_results = results.lock().unwrap(); assert_eq!(final_results.len(), 4, "All 4 threads should complete"); for (tid, neighbors) in final_results.iter() { - assert_eq!(neighbors.len(), 5, "Thread {} should return 5 neighbors", tid); + assert_eq!( + neighbors.len(), + 5, + "Thread {} should return 5 neighbors", + tid + ); for window in neighbors.windows(2) { - assert!(window[0].1 <= window[1].1, "Thread {} results should be distance-sorted", tid); + assert!( + window[0].1 <= window[1].1, + "Thread {} results should be distance-sorted", + tid + ); } } } @@ -503,7 +620,9 @@ fn test_edge_cases() { assert!(pipe.build_scene_graph(&[], -1.0).is_err()); // Trajectory with zero steps - assert!(pipe.predict_trajectory([0.0; 3], [1.0, 0.0, 0.0], 0, 1.0).is_err()); + assert!(pipe + .predict_trajectory([0.0; 3], [1.0, 0.0, 0.0], 0, 1.0) + .is_err()); // Attention with negative radius assert!(pipe.focus_attention(&empty, [0.0; 3], -1.0).is_err()); @@ -516,7 +635,10 @@ fn test_edge_cases() { let index = SpatialIndex::new(3); assert!(index.search_nearest(&[0.0_f32, 0.0, 0.0], 5).is_err()); // Radius search on empty index returns Ok(empty) - assert!(index.search_radius(&[0.0_f32, 0.0, 0.0], 1.0).unwrap().is_empty()); + assert!(index + .search_radius(&[0.0_f32, 0.0, 0.0], 1.0) + .unwrap() + .is_empty()); // Behavior tree decorator let node = BehaviorNode::Decorator( diff --git a/crates/ruvllm-wasm/src/pi_quant_wasm.rs b/crates/ruvllm-wasm/src/pi_quant_wasm.rs index 1e01114fe..158fd02c9 100644 --- a/crates/ruvllm-wasm/src/pi_quant_wasm.rs +++ b/crates/ruvllm-wasm/src/pi_quant_wasm.rs @@ -774,7 +774,16 @@ mod tests { let step = q.step_size(); // Values well outside the range [-4, 3] * step - let weights = vec![step * 10.0, step * -10.0, step * 5.0, step * -6.0, 0.0, 0.0, 0.0, 0.0]; + let weights = vec![ + step * 10.0, + step * -10.0, + step * 5.0, + step * -6.0, + 0.0, + 0.0, + 0.0, + 0.0, + ]; let packed = 
q.quantize(&weights); let reconstructed = q.dequantize(&packed); diff --git a/crates/ruvllm-wasm/src/quant_bench_wasm.rs b/crates/ruvllm-wasm/src/quant_bench_wasm.rs index dbc5e6de1..56ee89c3b 100644 --- a/crates/ruvllm-wasm/src/quant_bench_wasm.rs +++ b/crates/ruvllm-wasm/src/quant_bench_wasm.rs @@ -363,7 +363,11 @@ impl QuantBenchWasm { let best_compression = formats .iter() - .max_by(|a, b| a.compression_ratio.partial_cmp(&b.compression_ratio).unwrap()) + .max_by(|a, b| { + a.compression_ratio + .partial_cmp(&b.compression_ratio) + .unwrap() + }) .map(|r| r.format.clone()) .unwrap_or_default(); diff --git a/crates/ruvllm/.reasoning_bank_patterns b/crates/ruvllm/.reasoning_bank_patterns index 7e48c93d5..94f1e57c4 100644 Binary files a/crates/ruvllm/.reasoning_bank_patterns and b/crates/ruvllm/.reasoning_bank_patterns differ diff --git a/crates/ruvllm/Cargo.toml b/crates/ruvllm/Cargo.toml index 0e372f9f5..e3bca4750 100644 --- a/crates/ruvllm/Cargo.toml +++ b/crates/ruvllm/Cargo.toml @@ -222,6 +222,10 @@ harness = false name = "pi_quant_bench" harness = false +[[bench]] +name = "moe_bench" +harness = false + # Test configurations [[test]] name = "real_model_test" diff --git a/crates/ruvllm/benches/moe_bench.rs b/crates/ruvllm/benches/moe_bench.rs new file mode 100644 index 000000000..2be04941b --- /dev/null +++ b/crates/ruvllm/benches/moe_bench.rs @@ -0,0 +1,958 @@ +//! MoE Memory-Aware Routing Benchmarks (ADR-092) +//! +//! Criterion benchmarks for validating ADR-092 performance targets: +//! +//! - **Routing overhead**: Target <= 15 us (baseline ~5 us) +//! - **Affinity update**: EMA computation performance +//! - **Precision allocation**: Lookup and allocation performance +//! - **Cache hit rate simulation**: Compare baseline LRU vs affinity-aware +//! - **Paging simulation**: Expert paging latency +//! +//! 
Run with: `cargo bench --bench moe_bench` + +#![allow( + clippy::all, + unused_imports, + unused_variables, + dead_code, + unused_mut, + unused_assignments, + unexpected_cfgs, + unused_must_use +)] + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ruvllm::bitnet::expert_cache::{ + align_to_cache_line, expert_memory_footprint, EvictionPolicy, ExpertBatch, ExpertCache, + ExpertCacheConfig, MoeBatchScheduler, NullPrefetcher, Prefetcher, +}; +use std::collections::HashMap; +use std::time::Duration; + +// ============================================================================ +// Configuration Constants +// ============================================================================ + +/// Number of experts in Mixtral-style model +const NUM_EXPERTS: usize = 8; + +/// Top-K experts per token +const TOP_K: usize = 2; + +/// Hot-set size +const HOT_SET_SIZE: usize = 4; + +/// Benchmark iterations for statistical significance +const BENCH_ITERS: usize = 10_000; + +// ============================================================================ +// Routing Overhead Benchmark +// ============================================================================ + +/// Benchmark: Compare standard vs memory-aware routing latency +/// +/// Target: Memory-aware routing overhead <= 15 us (baseline ~5 us) +fn bench_routing_overhead(c: &mut Criterion) { + let mut group = c.benchmark_group("routing_overhead"); + group.measurement_time(Duration::from_secs(5)); + + // Baseline: Simple LRU cache access + group.bench_function("baseline_lru", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.0, // No prefetch + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Warm up + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + let mut expert_id = 0; + b.iter(|| { + expert_id = (expert_id + 1) % NUM_EXPERTS; + 
black_box(cache.access(expert_id)) + }); + }); + + // Memory-aware: Adaptive eviction + prefetch checking + group.bench_function("memory_aware_adaptive", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Warm up + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + let mut expert_id = 0; + b.iter(|| { + expert_id = (expert_id + 1) % NUM_EXPERTS; + let hit = cache.access(expert_id); + let should_prefetch = cache.should_prefetch((expert_id + 1) % NUM_EXPERTS, 0.15); + black_box((hit, should_prefetch)) + }); + }); + + // Full routing with prefetch admission + group.bench_function("full_routing_with_prefetch", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Warm up + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + let mut expert_id = 0; + b.iter(|| { + expert_id = (expert_id + 1) % NUM_EXPERTS; + let hit = cache.access(expert_id); + let next_expert = (expert_id + 1) % NUM_EXPERTS; + if cache.should_prefetch(next_expert, 0.15) { + cache.prefetch_admit(next_expert); + } + black_box(hit) + }); + }); + + group.finish(); +} + +// ============================================================================ +// Affinity Update Benchmark +// ============================================================================ + +/// Benchmark: EMA-based affinity score updates +fn bench_affinity_update(c: &mut Criterion) { + let mut group = c.benchmark_group("affinity_update"); + + // Simulate EMA-based affinity tracking + struct AffinityTracker { + scores: Vec, + decay: f32, + } + + impl AffinityTracker { + fn new(num_experts: usize, decay: f32) -> Self { + Self { + scores: vec![0.5; num_experts], + decay, + } + } + + #[inline] + fn activate(&mut self, 
expert_id: usize) { + if expert_id < self.scores.len() { + // EMA update: score = decay * score + (1 - decay) * 1.0 + self.scores[expert_id] = self.decay * self.scores[expert_id] + (1.0 - self.decay); + } + } + + #[inline] + fn decay_step(&mut self, expert_id: usize) { + if expert_id < self.scores.len() { + self.scores[expert_id] *= self.decay; + } + } + + #[inline] + fn decay_all(&mut self) { + for score in &mut self.scores { + *score *= self.decay; + } + } + + #[inline] + fn get_top_k(&self, k: usize) -> Vec { + let mut indexed: Vec<(usize, f32)> = self.scores.iter().copied().enumerate().collect(); + indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + indexed.into_iter().take(k).map(|(idx, _)| idx).collect() + } + } + + // Single activation + group.bench_function("single_activation", |b| { + let mut tracker = AffinityTracker::new(NUM_EXPERTS, 0.9); + let mut expert_id = 0; + + b.iter(|| { + expert_id = (expert_id + 1) % NUM_EXPERTS; + tracker.activate(expert_id); + black_box(tracker.scores[expert_id]) + }); + }); + + // Decay all experts + group.bench_function("decay_all_experts", |b| { + let mut tracker = AffinityTracker::new(NUM_EXPERTS, 0.9); + + b.iter(|| { + tracker.decay_all(); + black_box(tracker.scores[0]) + }); + }); + + // Get top-K by affinity + group.bench_function("get_top_k", |b| { + let mut tracker = AffinityTracker::new(NUM_EXPERTS, 0.9); + + // Set up varied affinities + for i in 0..100 { + tracker.activate(i % NUM_EXPERTS); + } + + b.iter(|| black_box(tracker.get_top_k(HOT_SET_SIZE))); + }); + + // Combined: activate + decay + get_top_k (full routing step) + group.bench_function("full_routing_step", |b| { + let mut tracker = AffinityTracker::new(NUM_EXPERTS, 0.9); + let mut step = 0; + + b.iter(|| { + step += 1; + let expert_id = step % NUM_EXPERTS; + tracker.activate(expert_id); + tracker.decay_all(); + let top_k = tracker.get_top_k(HOT_SET_SIZE); + black_box(top_k) + }); + }); + + // Larger expert counts (Mixtral 8x22B has 8 experts, but 
future models may have more) + for num_experts in [8, 16, 32, 64] { + group.bench_with_input( + BenchmarkId::new("decay_all_scaled", num_experts), + &num_experts, + |b, &n| { + let mut tracker = AffinityTracker::new(n, 0.9); + + b.iter(|| { + tracker.decay_all(); + black_box(tracker.scores[0]) + }); + }, + ); + } + + group.finish(); +} + +// ============================================================================ +// Precision Allocation Benchmark +// ============================================================================ + +/// Benchmark: Precision allocation lookup and decision +fn bench_precision_allocation(c: &mut Criterion) { + let mut group = c.benchmark_group("precision_allocation"); + + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + enum Precision { + FP16 = 0, + INT8 = 1, + INT4 = 2, + } + + // Simple lookup-based allocation + group.bench_function("lookup_allocation", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Warm up cache + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + let mut expert_id = 0; + b.iter(|| { + expert_id = (expert_id + 1) % NUM_EXPERTS; + let precision = if cache.is_hot(expert_id) { + Precision::FP16 + } else { + Precision::INT4 + }; + black_box(precision) + }); + }); + + // Batch allocation for all experts + group.bench_function("batch_allocation", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Warm up cache + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + b.iter(|| { + let allocations: Vec = (0..NUM_EXPERTS) + .map(|expert_id| { + if cache.is_hot(expert_id) { + Precision::FP16 + } else { + Precision::INT4 + } + }) + .collect(); + black_box(allocations) + }); + }); + + // Memory 
budget calculation + group.bench_function("memory_budget_check", |b| { + let intermediate_size = 11008; + let hidden_size = 4096; + let block_size = 256; + + b.iter(|| { + let expert_footprint = + expert_memory_footprint(intermediate_size, hidden_size, block_size) * 3; + let hot_set_budget = expert_footprint * HOT_SET_SIZE; + black_box(hot_set_budget) + }); + }); + + group.finish(); +} + +// ============================================================================ +// Cache Hit Rate Simulation Benchmark +// ============================================================================ + +/// Benchmark: Simulate workload and measure hit rate +fn bench_cache_hit_rate_simulation(c: &mut Criterion) { + let mut group = c.benchmark_group("cache_hit_rate"); + group.measurement_time(Duration::from_secs(5)); + + // Pre-generate routing decisions for consistent benchmarking + let routing_decisions: Vec> = (0..1000) + .map(|token_idx| { + let expert1 = (token_idx * 3 + token_idx / 10) % NUM_EXPERTS; + let expert2 = (expert1 + 1 + token_idx % 3) % NUM_EXPERTS; + vec![(expert1, 0.6), (expert2, 0.4)] + }) + .collect(); + + // Baseline LRU + group.bench_function("baseline_lru", |b| { + b.iter(|| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.0, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + for experts in &routing_decisions { + for &(expert_id, _) in experts { + cache.access(expert_id); + } + } + + black_box(cache.stats().hit_rate()) + }); + }); + + // LFU eviction + group.bench_function("lfu_eviction", |b| { + b.iter(|| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.0, + eviction_policy: EvictionPolicy::Lfu, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + for experts in &routing_decisions { + for &(expert_id, _) in experts { + cache.access(expert_id); + } + } + + black_box(cache.stats().hit_rate()) + }); + }); 
+ + // Adaptive eviction (memory-aware) + group.bench_function("adaptive_eviction", |b| { + b.iter(|| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + for experts in &routing_decisions { + for &(expert_id, weight) in experts { + cache.access(expert_id); + // Also check prefetch candidates + if cache.should_prefetch((expert_id + 1) % NUM_EXPERTS, weight) { + cache.prefetch_admit((expert_id + 1) % NUM_EXPERTS); + } + } + } + + black_box(cache.stats().hit_rate()) + }); + }); + + // Skewed workload (tests adaptive vs LRU more clearly) + let skewed_routing: Vec> = (0..1000) + .map(|token_idx| { + // 80% of accesses to experts 0, 1, 2 + let primary = if token_idx % 10 < 8 { + token_idx % 3 + } else { + 3 + token_idx % (NUM_EXPERTS - 3) + }; + let secondary = (primary + 1) % NUM_EXPERTS; + vec![(primary, 0.7), (secondary, 0.3)] + }) + .collect(); + + group.bench_function("skewed_workload_lru", |b| { + b.iter(|| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.0, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + for experts in &skewed_routing { + for &(expert_id, _) in experts { + cache.access(expert_id); + } + } + + black_box(cache.stats().hit_rate()) + }); + }); + + group.bench_function("skewed_workload_adaptive", |b| { + b.iter(|| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + for experts in &skewed_routing { + for &(expert_id, weight) in experts { + cache.access(expert_id); + if cache.should_prefetch((expert_id + 1) % NUM_EXPERTS, weight) { + cache.prefetch_admit((expert_id + 1) % NUM_EXPERTS); + } + } + } + + black_box(cache.stats().hit_rate()) + }); + 
}); + + group.finish(); +} + +// ============================================================================ +// Paging Simulation Benchmark +// ============================================================================ + +/// Benchmark: Expert paging latency simulation +fn bench_paging_simulation(c: &mut Criterion) { + let mut group = c.benchmark_group("paging_simulation"); + group.measurement_time(Duration::from_secs(3)); + + // Simulate paging overhead based on expert memory footprint + let intermediate_size = 11008; + let hidden_size = 4096; + let block_size = 256; + let expert_size = expert_memory_footprint(intermediate_size, hidden_size, block_size) * 3; + + // Memory throughput assumptions (GB/s) + // DDR5-4800: ~38 GB/s + // Apple M4 unified memory: ~120 GB/s + let memory_bandwidth_gbps = 120.0; // M4 assumption + + fn simulate_page_in(expert_size: usize, bandwidth_gbps: f64) -> Duration { + let bytes = expert_size as f64; + let gb = bytes / 1e9; + let seconds = gb / bandwidth_gbps; + Duration::from_secs_f64(seconds) + } + + // Single expert page-in + group.bench_function("single_expert_page_in", |b| { + b.iter(|| { + let latency = simulate_page_in(expert_size, memory_bandwidth_gbps); + black_box(latency) + }); + }); + + // Batch expert page-in (amortized cost) + for batch_size in [1, 2, 4, 8] { + group.bench_with_input( + BenchmarkId::new("batch_page_in", batch_size), + &batch_size, + |b, &size| { + b.iter(|| { + let total_size = expert_size * size; + let latency = simulate_page_in(total_size, memory_bandwidth_gbps); + let per_expert = latency / size as u32; + black_box(per_expert) + }); + }, + ); + } + + // Cache management overhead during paging + group.bench_function("paging_with_cache_update", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + let mut expert_id = 0; + b.iter(|| { + expert_id 
= (expert_id + 1) % NUM_EXPERTS; + + // Check if page-in needed + let needs_page_in = !cache.is_hot(expert_id); + + // Simulate page-in latency calculation + let latency = if needs_page_in { + simulate_page_in(expert_size, memory_bandwidth_gbps) + } else { + Duration::ZERO + }; + + // Update cache + cache.access(expert_id); + + black_box(latency) + }); + }); + + group.finish(); +} + +// ============================================================================ +// Batch Scheduler Benchmark +// ============================================================================ + +/// Benchmark: MoE batch scheduling performance +fn bench_batch_scheduler(c: &mut Criterion) { + let mut group = c.benchmark_group("batch_scheduler"); + + for batch_size in [1, 8, 32, 128, 512] { + let routing_decisions: Vec<(usize, Vec<(usize, f32)>)> = (0..batch_size) + .map(|token_idx| { + let expert1 = (token_idx * 3) % NUM_EXPERTS; + let expert2 = (expert1 + 1 + token_idx % 2) % NUM_EXPERTS; + (token_idx, vec![(expert1, 0.6), (expert2, 0.4)]) + }) + .collect(); + + group.throughput(Throughput::Elements(batch_size as u64)); + + group.bench_with_input( + BenchmarkId::new("schedule", batch_size), + &routing_decisions, + |b, routing| { + b.iter(|| { + let batches = MoeBatchScheduler::schedule(routing); + black_box(batches) + }); + }, + ); + } + + group.finish(); +} + +// ============================================================================ +// Memory Footprint Calculation Benchmark +// ============================================================================ + +/// Benchmark: Memory footprint calculations +fn bench_memory_footprint(c: &mut Criterion) { + let mut group = c.benchmark_group("memory_footprint"); + + // Various model sizes + let model_configs = [ + ("small", 4096, 2048, 256), + ("medium", 8192, 4096, 256), + ("mixtral", 11008, 4096, 256), + ("large", 14336, 8192, 256), + ]; + + for (name, intermediate, hidden, block) in model_configs { + group.bench_with_input( + 
BenchmarkId::new("calculate", name), + &(intermediate, hidden, block), + |b, &(i, h, bs)| { + b.iter(|| { + let gate = expert_memory_footprint(i, h, bs); + let up = expert_memory_footprint(i, h, bs); + let down = expert_memory_footprint(h, i, bs); + black_box(gate + up + down) + }); + }, + ); + } + + // Cache line alignment + group.bench_function("cache_line_align", |b| { + let mut offset = 0usize; + b.iter(|| { + offset = (offset + 137) % 10000; // Varied offsets + black_box(align_to_cache_line(offset)) + }); + }); + + group.finish(); +} + +// ============================================================================ +// Prefetch Decision Benchmark +// ============================================================================ + +/// Benchmark: Prefetch decision making +fn bench_prefetch_decision(c: &mut Criterion) { + let mut group = c.benchmark_group("prefetch_decision"); + + group.bench_function("should_prefetch", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Warm up + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + let weights: Vec = (0..NUM_EXPERTS).map(|i| 0.05 + (i as f32) * 0.03).collect(); + + let mut idx = 0; + b.iter(|| { + idx = (idx + 1) % NUM_EXPERTS; + black_box(cache.should_prefetch(idx, weights[idx])) + }); + }); + + group.bench_function("prefetch_admit", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + let mut idx = 0; + b.iter(|| { + idx = (idx + 1) % NUM_EXPERTS; + cache.prefetch_admit(idx); + black_box(()) + }); + }); + + group.bench_function("full_prefetch_cycle", |b| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, 
+ }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Simulate router weights + let weights: Vec = vec![0.35, 0.25, 0.15, 0.10, 0.06, 0.04, 0.03, 0.02]; + + let mut token = 0; + b.iter(|| { + token += 1; + + // Current routing decision + let top_2: Vec = (0..NUM_EXPERTS) + .filter(|&i| weights[i] > 0.1) + .take(2) + .collect(); + + // Access current experts + for &expert_id in &top_2 { + cache.access(expert_id); + } + + // Check prefetch candidates + for (expert_id, &weight) in weights.iter().enumerate() { + if cache.should_prefetch(expert_id, weight) { + cache.prefetch_admit(expert_id); + } + } + + black_box(top_2) + }); + }); + + group.finish(); +} + +// ============================================================================ +// Eviction Policy Comparison Benchmark +// ============================================================================ + +/// Benchmark: Compare eviction policies +fn bench_eviction_policies(c: &mut Criterion) { + let mut group = c.benchmark_group("eviction_policies"); + group.measurement_time(Duration::from_secs(3)); + + let policies = [ + ("lru", EvictionPolicy::Lru), + ("lfu", EvictionPolicy::Lfu), + ("adaptive", EvictionPolicy::Adaptive), + ]; + + // Generate access pattern with locality + let access_pattern: Vec = (0..1000) + .map(|i| { + // 70% local, 30% random + if i % 10 < 7 { + i % 3 // Local accesses to experts 0, 1, 2 + } else { + (i * 7) % NUM_EXPERTS // Pseudo-random + } + }) + .collect(); + + for (name, policy) in policies { + group.bench_with_input( + BenchmarkId::new("access_pattern", name), + &policy, + |b, &policy| { + b.iter(|| { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: policy, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + for &expert_id in &access_pattern { + cache.access(expert_id); + } + + black_box(cache.stats().hit_rate()) + }); + }, + ); + } + + group.finish(); +} + +// 
============================================================================ +// Memory-Aware Router Benchmarks (P1-P4 Optimizations) +// ============================================================================ + +/// Benchmark: MemoryAwareRouter performance (ADR-092) +fn bench_memory_aware_router(c: &mut Criterion) { + use ruvllm::moe::{AffinityConfig, ExpertAffinity, MemoryAwareRouter, RouterConfig}; + + let mut group = c.benchmark_group("memory_aware_router"); + group.measurement_time(Duration::from_secs(5)); + + // Test various expert counts + for num_experts in [8, 16, 32, 64] { + // P4: Top-2 unrolled optimization + group.bench_with_input( + BenchmarkId::new("route_top2", num_experts), + &num_experts, + |b, &n| { + let config = RouterConfig::new(n, 2).with_cache_bonus(0.15); + let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap(); + + // Set half experts as resident + let resident: Vec = (0..n / 2).collect(); + router.update_cache_state(&resident); + + // Generate gate logits + let gate_logits: Vec = (0..n).map(|i| 0.1 + (i as f32) * 0.01).collect(); + + b.iter(|| { + let (selected, _paging) = router.route(black_box(&gate_logits)); + black_box(selected) + }); + }, + ); + + // P2: Batch routing optimization + group.bench_with_input( + BenchmarkId::new("route_batch_8", num_experts), + &num_experts, + |b, &n| { + let config = RouterConfig::new(n, 2).with_cache_bonus(0.15); + let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap(); + + let resident: Vec = (0..n / 2).collect(); + router.update_cache_state(&resident); + + // Generate batch of 8 tokens + let batch_logits: Vec> = (0..8) + .map(|t| (0..n).map(|i| 0.1 + (i as f32 + t as f32) * 0.01).collect()) + .collect(); + let batch_refs: Vec<&[f32]> = batch_logits.iter().map(|v| v.as_slice()).collect(); + + b.iter(|| { + let results = router.route_batch(black_box(&batch_refs)); + black_box(results) + }); + }, + ); + } + + // P1: Bitmask cache check overhead + 
group.bench_function("cache_mask_check_64", |b| { + let config = RouterConfig::new(64, 2); + let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap(); + + // Set alternating experts as resident + let resident: Vec = (0..64).step_by(2).collect(); + router.update_cache_state(&resident); + + let mut id = 0usize; + b.iter(|| { + id = (id + 1) % 64; + black_box(router.is_resident(id)) + }); + }); + + // P1: Large expert count (>64, uses extended bitmask) + group.bench_function("cache_mask_check_128", |b| { + let config = RouterConfig::new(128, 4); + let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap(); + + let resident: Vec = (0..128).step_by(2).collect(); + router.update_cache_state(&resident); + + let mut id = 0usize; + b.iter(|| { + id = (id + 1) % 128; + black_box(router.is_resident(id)) + }); + }); + + // Compare top-2 vs top-4 selection + group.bench_function("select_top2_vs_sort", |b| { + let config = RouterConfig::new(64, 2).with_cache_bonus(0.15); + let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap(); + + let gate_logits: Vec = (0..64).map(|i| (i as f32 * 0.7).sin()).collect(); + + b.iter(|| black_box(router.select_top_k(black_box(&gate_logits)))); + }); + + group.bench_function("select_top4_partial_sort", |b| { + let config = RouterConfig::new(64, 4).with_cache_bonus(0.15); + let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap(); + + let gate_logits: Vec = (0..64).map(|i| (i as f32 * 0.7).sin()).collect(); + + b.iter(|| black_box(router.select_top_k(black_box(&gate_logits)))); + }); + + group.finish(); +} + +/// Benchmark: SIMD affinity decay (P1 optimization) +fn bench_simd_affinity_decay(c: &mut Criterion) { + use ruvllm::moe::{AffinityConfig, ExpertAffinity}; + + let mut group = c.benchmark_group("simd_affinity_decay"); + + for num_experts in [8, 16, 32, 64, 128, 256] { + group.throughput(Throughput::Elements(num_experts as u64)); + + group.bench_with_input( + 
BenchmarkId::new("decay_all", num_experts), + &num_experts, + |b, &n| { + let config = AffinityConfig::with_num_experts(n).with_decay(0.95); + let mut affinity = ExpertAffinity::new(config); + + // Activate all experts initially + let all: Vec = (0..n).collect(); + affinity.update(&all); + + b.iter(|| { + affinity.update(&[]); // Decay-only update + black_box(affinity.score(0)) + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new("update_with_activation", num_experts), + &num_experts, + |b, &n| { + let config = AffinityConfig::with_num_experts(n).with_decay(0.95); + let mut affinity = ExpertAffinity::new(config); + + let activated = vec![0, 1]; // Activate 2 experts per call + + b.iter(|| { + affinity.update(&activated); + black_box(affinity.score(0)) + }); + }, + ); + } + + group.finish(); +} + +// ============================================================================ +// Criterion Groups +// ============================================================================ + +criterion_group!( + benches, + bench_routing_overhead, + bench_affinity_update, + bench_precision_allocation, + bench_cache_hit_rate_simulation, + bench_paging_simulation, + bench_batch_scheduler, + bench_memory_footprint, + bench_prefetch_decision, + bench_eviction_policies, + bench_memory_aware_router, + bench_simd_affinity_decay, +); + +criterion_main!(benches); diff --git a/crates/ruvllm/benches/pi_quant_bench.rs b/crates/ruvllm/benches/pi_quant_bench.rs index 2c33ef9f8..94b91c4ae 100644 --- a/crates/ruvllm/benches/pi_quant_bench.rs +++ b/crates/ruvllm/benches/pi_quant_bench.rs @@ -18,9 +18,7 @@ #![allow(unused_imports, dead_code, unused_variables)] -use criterion::{ - black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput, -}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use rand::prelude::*; use std::f32::consts::PI; @@ -52,18 +50,14 @@ impl Pi3BitBlock { u[i] = (v + 4) as u8; } - self.data[0] = 
(u[0] & 0x07) - | ((u[1] & 0x07) << 3) - | ((u[2] & 0x07) << 6); + self.data[0] = (u[0] & 0x07) | ((u[1] & 0x07) << 3) | ((u[2] & 0x07) << 6); self.data[1] = ((u[2] >> 2) & 0x01) | ((u[3] & 0x07) << 1) | ((u[4] & 0x07) << 4) | ((u[5] & 0x07) << 7); - self.data[2] = ((u[5] >> 1) & 0x03) - | ((u[6] & 0x07) << 2) - | ((u[7] & 0x07) << 5); + self.data[2] = ((u[5] >> 1) & 0x03) | ((u[6] & 0x07) << 2) | ((u[7] & 0x07) << 5); } fn unpack(&self) -> [i8; 8] { @@ -484,9 +478,7 @@ impl SteVariant { 0.0 } } - SteVariant::Ewgs { lambda } => { - grad_out * (1.0 + lambda * (w - q).abs()) - } + SteVariant::Ewgs { lambda } => grad_out * (1.0 + lambda * (w - q).abs()), } } } @@ -959,13 +951,17 @@ unsafe fn quantize_2bit_avx2(weights: &[f32], step: f32, output: &mut [u8]) -> u fn quantize_3bit_dispatch(weights: &[f32], step: f32, output: &mut [u8]) -> usize { #[cfg(target_arch = "aarch64")] { - unsafe { return quantize_3bit_neon(weights, step, output); } + unsafe { + return quantize_3bit_neon(weights, step, output); + } } #[cfg(target_arch = "x86_64")] { if is_x86_feature_detected!("avx2") { - unsafe { return quantize_3bit_avx2(weights, step, output); } + unsafe { + return quantize_3bit_avx2(weights, step, output); + } } } @@ -975,13 +971,17 @@ fn quantize_3bit_dispatch(weights: &[f32], step: f32, output: &mut [u8]) -> usiz fn quantize_2bit_dispatch(weights: &[f32], step: f32, output: &mut [u8]) -> usize { #[cfg(target_arch = "aarch64")] { - unsafe { return quantize_2bit_neon(weights, step, output); } + unsafe { + return quantize_2bit_neon(weights, step, output); + } } #[cfg(target_arch = "x86_64")] { if is_x86_feature_detected!("avx2") { - unsafe { return quantize_2bit_avx2(weights, step, output); } + unsafe { + return quantize_2bit_avx2(weights, step, output); + } } } @@ -1004,9 +1004,7 @@ fn bench_pi_quantize_3bit_fast(c: &mut Criterion) { group.throughput(Throughput::Bytes(output_bytes as u64)); group.bench_with_input(BenchmarkId::new("size", size), &weights, |b, w| { - b.iter(|| 
{ - quantize_3bit_fast(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| quantize_3bit_fast(black_box(w), step, black_box(&mut output))) }); } @@ -1028,9 +1026,7 @@ fn bench_pi_quantize_2bit_fast(c: &mut Criterion) { group.throughput(Throughput::Bytes(num_blocks as u64)); group.bench_with_input(BenchmarkId::new("size", size), &weights, |b, w| { - b.iter(|| { - quantize_2bit_fast(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| quantize_2bit_fast(black_box(w), step, black_box(&mut output))) }); } @@ -1053,9 +1049,7 @@ fn bench_pi_quantize_3bit_simd(c: &mut Criterion) { group.throughput(Throughput::Bytes(output_bytes as u64)); group.bench_with_input(BenchmarkId::new("size", size), &weights, |b, w| { - b.iter(|| { - quantize_3bit_dispatch(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| quantize_3bit_dispatch(black_box(w), step, black_box(&mut output))) }); } @@ -1077,9 +1071,7 @@ fn bench_pi_quantize_2bit_simd(c: &mut Criterion) { group.throughput(Throughput::Bytes(num_blocks as u64)); group.bench_with_input(BenchmarkId::new("size", size), &weights, |b, w| { - b.iter(|| { - quantize_2bit_dispatch(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| quantize_2bit_dispatch(black_box(w), step, black_box(&mut output))) }); } @@ -1102,9 +1094,7 @@ fn bench_pi_quantize_3bit_neon(c: &mut Criterion) { group.throughput(Throughput::Bytes(output_bytes as u64)); group.bench_with_input(BenchmarkId::new("weights", size), &weights, |b, w| { - b.iter(|| unsafe { - quantize_3bit_neon(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| unsafe { quantize_3bit_neon(black_box(w), step, black_box(&mut output)) }) }); } @@ -1126,9 +1116,7 @@ fn bench_pi_quantize_2bit_neon(c: &mut Criterion) { group.throughput(Throughput::Bytes(num_blocks as u64)); group.bench_with_input(BenchmarkId::new("weights", size), &weights, |b, w| { - b.iter(|| unsafe { - quantize_2bit_neon(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| unsafe { 
quantize_2bit_neon(black_box(w), step, black_box(&mut output)) }) }); } @@ -1155,9 +1143,7 @@ fn bench_pi_quantize_3bit_avx2(c: &mut Criterion) { group.throughput(Throughput::Bytes(output_bytes as u64)); group.bench_with_input(BenchmarkId::new("weights", size), &weights, |b, w| { - b.iter(|| unsafe { - quantize_3bit_avx2(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| unsafe { quantize_3bit_avx2(black_box(w), step, black_box(&mut output)) }) }); } @@ -1183,9 +1169,7 @@ fn bench_pi_quantize_2bit_avx2(c: &mut Criterion) { group.throughput(Throughput::Bytes(num_blocks as u64)); group.bench_with_input(BenchmarkId::new("weights", size), &weights, |b, w| { - b.iter(|| unsafe { - quantize_2bit_avx2(black_box(w), step, black_box(&mut output)) - }) + b.iter(|| unsafe { quantize_2bit_avx2(black_box(w), step, black_box(&mut output)) }) }); } @@ -1234,15 +1218,11 @@ fn bench_pi_dequantize_scalar(c: &mut Criterion) { let input_bytes = packed.len(); group.throughput(Throughput::Bytes(input_bytes as u64)); - group.bench_with_input( - BenchmarkId::new("weights", num_weights), - &packed, - |b, p| { - b.iter(|| { - pi_dequantize_scalar(black_box(p), scale, black_box(&mut output)); - }) - }, - ); + group.bench_with_input(BenchmarkId::new("weights", num_weights), &packed, |b, p| { + b.iter(|| { + pi_dequantize_scalar(black_box(p), scale, black_box(&mut output)); + }) + }); } group.finish(); @@ -1263,15 +1243,11 @@ fn bench_pi_dequantize_neon(c: &mut Criterion) { let input_bytes = packed.len(); group.throughput(Throughput::Bytes(input_bytes as u64)); - group.bench_with_input( - BenchmarkId::new("weights", num_weights), - &packed, - |b, p| { - b.iter(|| unsafe { - pi_dequantize_neon(black_box(p), scale, black_box(&mut output)); - }) - }, - ); + group.bench_with_input(BenchmarkId::new("weights", num_weights), &packed, |b, p| { + b.iter(|| unsafe { + pi_dequantize_neon(black_box(p), scale, black_box(&mut output)); + }) + }); } group.finish(); @@ -1296,15 +1272,11 @@ fn 
bench_pi_dequantize_avx2(c: &mut Criterion) { let input_bytes = packed.len(); group.throughput(Throughput::Bytes(input_bytes as u64)); - group.bench_with_input( - BenchmarkId::new("weights", num_weights), - &packed, - |b, p| { - b.iter(|| unsafe { - pi_dequantize_avx2(black_box(p), scale, black_box(&mut output)); - }) - }, - ); + group.bench_with_input(BenchmarkId::new("weights", num_weights), &packed, |b, p| { + b.iter(|| unsafe { + pi_dequantize_avx2(black_box(p), scale, black_box(&mut output)); + }) + }); } group.finish(); @@ -1363,7 +1335,9 @@ fn bench_hadamard_layer_sizes(c: &mut Criterion) { // Common layer dimensions (rounded to power of 2) for &size in &[256, 4096, 8192, 16384] { - let data: Vec = (0..size).map(|i| (i as f32 - size as f32 / 2.0) / 100.0).collect(); + let data: Vec = (0..size) + .map(|i| (i as f32 - size as f32 / 2.0) / 100.0) + .collect(); group.throughput(Throughput::Elements(size as u64)); group.bench_with_input(BenchmarkId::new("dim", size), &data, |b, d| { @@ -1453,9 +1427,7 @@ fn bench_mse_computation(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("weights", size), &(&original, &quantized), - |b, (o, q)| { - b.iter(|| compute_mse(black_box(*o), black_box(*q))) - }, + |b, (o, q)| b.iter(|| compute_mse(black_box(*o), black_box(*q))), ); } @@ -1475,9 +1447,7 @@ fn bench_spectral_distortion(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("weights", size), &(&original, &quantized), - |b, (o, q)| { - b.iter(|| compute_spectral_distortion(black_box(*o), black_box(*q))) - }, + |b, (o, q)| b.iter(|| compute_spectral_distortion(black_box(*o), black_box(*q))), ); } diff --git a/crates/ruvllm/src/bitnet/expert_cache.rs b/crates/ruvllm/src/bitnet/expert_cache.rs index b91be2b77..04d26904b 100644 --- a/crates/ruvllm/src/bitnet/expert_cache.rs +++ b/crates/ruvllm/src/bitnet/expert_cache.rs @@ -321,8 +321,299 @@ impl ExpertCache { self.config.max_hot_experts } + /// Get list of currently hot experts. 
+ /// + /// Returns the expert IDs currently in the hot set, in no particular order. + /// Useful for prefetch decisions and cache diagnostics. + /// + /// # Example + /// + /// ```rust,ignore + /// use ruvllm::bitnet::expert_cache::{ExpertCache, ExpertCacheConfig}; + /// + /// let mut cache = ExpertCache::new(8, ExpertCacheConfig::default()); + /// cache.access(2); + /// cache.access(5); + /// + /// let hot = cache.hot_experts(); + /// assert!(hot.contains(&2)); + /// assert!(hot.contains(&5)); + /// ``` + pub fn hot_experts(&self) -> Vec { + self.hot_set.iter().map(|&(id, _)| id).collect() + } + + /// Suggest eviction with affinity awareness. + /// + /// Combines the base eviction score (from LRU/LFU/Adaptive policy) with + /// affinity scores to make better eviction decisions. Experts with high + /// affinity are less likely to be evicted even if they have low frequency + /// or old access times. + /// + /// # Algorithm + /// + /// For each hot expert, compute a combined score: + /// ```text + /// eviction_score = (1 - affinity_weight) * base_score + affinity_weight * (1 - affinity) + /// ``` + /// + /// Where: + /// - `base_score` is normalized LRU/LFU score (0=least likely to evict, 1=most likely) + /// - `affinity` is the expert's affinity score from `ExpertAffinity` + /// - `affinity_weight` controls the influence of affinity (0.0-1.0) + /// + /// The expert with the **highest** eviction_score is suggested for eviction. + /// + /// # Arguments + /// + /// * `affinity` - The expert affinity tracker (from `moe::ExpertAffinity`) + /// * `affinity_weight` - How much affinity influences eviction (0.0-1.0) + /// - 0.0 = pure base policy (LRU/LFU/Adaptive) + /// - 1.0 = pure affinity-based (evict lowest affinity) + /// - 0.3-0.5 = recommended balance + /// + /// # Returns + /// + /// `Some(expert_id)` if the hot set is full and an expert should be evicted. + /// `None` if the hot set is not full. 
+ /// + /// # Example + /// + /// ```rust,ignore + /// use ruvllm::bitnet::expert_cache::{ExpertCache, ExpertCacheConfig}; + /// use ruvllm::moe::{ExpertAffinity, AffinityConfig}; + /// + /// let mut cache = ExpertCache::new(8, ExpertCacheConfig::default()); + /// let mut affinity = ExpertAffinity::new(AffinityConfig::with_num_experts(8)); + /// + /// // Fill the hot set + /// for i in 0..4 { cache.access(i); } + /// + /// // Update affinity - expert 0 has high affinity + /// for _ in 0..10 { affinity.update(&[0]); } + /// + /// // Should NOT suggest expert 0 despite being LRU + /// let victim = cache.suggest_eviction_with_affinity(&affinity, 0.5); + /// assert_ne!(victim, Some(0)); + /// ``` + pub fn suggest_eviction_with_affinity( + &self, + affinity: &crate::moe::ExpertAffinity, + affinity_weight: f32, + ) -> Option { + if self.hot_set.len() < self.config.max_hot_experts { + return None; + } + + // Clamp weight to valid range + let weight = affinity_weight.clamp(0.0, 1.0); + + // If weight is 0, just use base policy + if weight < 1e-6 { + return self.suggest_eviction(); + } + + // Compute base scores based on policy + let base_scores = self.compute_base_eviction_scores(); + + if base_scores.is_empty() { + return None; + } + + // Find expert with highest combined eviction score + let mut best_victim: Option = None; + let mut best_score: f32 = f32::MIN; + + for &(id, _) in &self.hot_set { + let base_score = base_scores.get(&id).copied().unwrap_or(0.5); + let expert_affinity = affinity.score(id); + + // Combined score: higher = more likely to evict + // (1 - affinity) means low affinity -> high eviction likelihood + let combined = (1.0 - weight) * base_score + weight * (1.0 - expert_affinity); + + if combined > best_score { + best_score = combined; + best_victim = Some(id); + } + } + + best_victim + } + + /// Prefetch experts based on affinity predictions. 
+ /// + /// Selects the top experts by affinity score that are not already in the + /// hot set, up to the given budget, and admits them via prefetch. + /// + /// # Arguments + /// + /// * `affinity` - The expert affinity tracker + /// * `budget` - Maximum number of experts to prefetch + /// + /// # Returns + /// + /// Vector of expert IDs that were actually prefetched (may be fewer than + /// `budget` if the hot set is nearly full or all high-affinity experts + /// are already hot). + /// + /// # Example + /// + /// ```rust,ignore + /// use ruvllm::bitnet::expert_cache::{ExpertCache, ExpertCacheConfig}; + /// use ruvllm::moe::{ExpertAffinity, AffinityConfig}; + /// + /// let config = ExpertCacheConfig { max_hot_experts: 4, ..Default::default() }; + /// let mut cache = ExpertCache::new(8, config); + /// let mut affinity = ExpertAffinity::new(AffinityConfig::with_num_experts(8)); + /// + /// // Build up affinity for experts 3 and 5 + /// for _ in 0..5 { affinity.update(&[3, 5]); } + /// + /// // Prefetch top 2 by affinity + /// let prefetched = cache.prefetch_by_affinity(&affinity, 2); + /// + /// assert!(prefetched.contains(&3) || prefetched.contains(&5)); + /// assert!(cache.is_hot(3) || cache.is_hot(5)); + /// ``` + pub fn prefetch_by_affinity( + &mut self, + affinity: &crate::moe::ExpertAffinity, + budget: usize, + ) -> Vec { + if budget == 0 { + return Vec::new(); + } + + // Get top experts by affinity + let top_experts = affinity.top_k_by_affinity(self.num_experts); + + let mut prefetched = Vec::with_capacity(budget); + + for expert_id in top_experts { + if prefetched.len() >= budget { + break; + } + + // Skip if already hot + if self.is_hot(expert_id) { + continue; + } + + // Skip if hot set is full and we can't make room + if self.hot_set.len() >= self.config.max_hot_experts { + // Try to evict using affinity-aware policy + if let Some(victim) = self.suggest_eviction_with_affinity(affinity, 0.5) { + self.evict(victim); + } else { + break; // Can't make room 
+ } + } + + // Admit via prefetch + self.prefetch_admit(expert_id); + prefetched.push(expert_id); + } + + prefetched + } + // --- Private helpers --- + /// Compute normalized base eviction scores for all hot experts. + /// + /// Returns a map of expert_id -> score where: + /// - 0.0 = least likely to evict + /// - 1.0 = most likely to evict + fn compute_base_eviction_scores(&self) -> HashMap { + let mut scores = HashMap::new(); + + if self.hot_set.is_empty() { + return scores; + } + + match self.config.eviction_policy { + EvictionPolicy::Lru => { + // LRU: older timestamp = higher eviction score + let timestamps: Vec = self.hot_set.iter().map(|&(_, ts)| ts).collect(); + let min_ts = timestamps.iter().copied().min().unwrap_or(0); + let max_ts = timestamps.iter().copied().max().unwrap_or(1); + let range = (max_ts - min_ts) as f32; + + for &(id, ts) in &self.hot_set { + let score = if range > 0.0 { + 1.0 - ((ts - min_ts) as f32 / range) + } else { + 0.5 + }; + scores.insert(id, score); + } + } + EvictionPolicy::Lfu => { + // LFU: lower frequency = higher eviction score + let freqs: Vec = self + .hot_set + .iter() + .map(|&(id, _)| self.frequency.get(id).copied().unwrap_or(0)) + .collect(); + let min_freq = freqs.iter().copied().min().unwrap_or(0); + let max_freq = freqs.iter().copied().max().unwrap_or(1); + let range = (max_freq - min_freq) as f32; + + for &(id, _) in &self.hot_set { + let freq = self.frequency.get(id).copied().unwrap_or(0); + let score = if range > 0.0 { + 1.0 - ((freq - min_freq) as f32 / range) + } else { + 0.5 + }; + scores.insert(id, score); + } + } + EvictionPolicy::Adaptive => { + // Adaptive: check skewness and use appropriate policy + let freqs: Vec = self + .hot_set + .iter() + .map(|&(id, _)| self.frequency.get(id).copied().unwrap_or(0)) + .collect(); + let max_freq = freqs.iter().copied().max().unwrap_or(0); + let min_freq = freqs.iter().copied().min().unwrap_or(0); + + if min_freq > 0 && max_freq >= 3 * min_freq { + // Skewed: use LFU 
scores + let range = (max_freq - min_freq) as f32; + for &(id, _) in &self.hot_set { + let freq = self.frequency.get(id).copied().unwrap_or(0); + let score = if range > 0.0 { + 1.0 - ((freq - min_freq) as f32 / range) + } else { + 0.5 + }; + scores.insert(id, score); + } + } else { + // Not skewed: use LRU scores + let timestamps: Vec = self.hot_set.iter().map(|&(_, ts)| ts).collect(); + let min_ts = timestamps.iter().copied().min().unwrap_or(0); + let max_ts = timestamps.iter().copied().max().unwrap_or(1); + let range = (max_ts - min_ts) as f32; + + for &(id, ts) in &self.hot_set { + let score = if range > 0.0 { + 1.0 - ((ts - min_ts) as f32 / range) + } else { + 0.5 + }; + scores.insert(id, score); + } + } + } + } + + scores + } + /// LRU eviction: pick the expert with the smallest (oldest) timestamp. fn suggest_lru_eviction(&self) -> Option { self.hot_set @@ -1061,4 +1352,225 @@ mod tests { cache.admit(1); assert_eq!(cache.hot_count(), 2); } + + // --------------------------------------------------------------- + // 21. hot_experts returns current hot set + // --------------------------------------------------------------- + + #[test] + fn test_hot_experts_list() { + let mut cache = make_cache(8, 4, EvictionPolicy::Lru); + + cache.access(2); + cache.access(5); + cache.access(7); + + let hot = cache.hot_experts(); + + assert_eq!(hot.len(), 3); + assert!(hot.contains(&2)); + assert!(hot.contains(&5)); + assert!(hot.contains(&7)); + assert!(!hot.contains(&0)); + } + + // --------------------------------------------------------------- + // 22. 
Eviction with affinity prefers low affinity experts + // --------------------------------------------------------------- + + #[test] + fn test_eviction_with_affinity_prefers_low_affinity() { + use crate::moe::{AffinityConfig, ExpertAffinity}; + + let mut cache = make_cache(8, 3, EvictionPolicy::Lru); + // Use small activation_boost to create meaningful differences + let mut affinity = ExpertAffinity::new( + AffinityConfig::with_num_experts(8) + .with_decay(1.0) + .with_activation_boost(0.1), + ); + + // Fill cache with experts 0, 1, 2 + cache.access(0); + cache.access(1); + cache.access(2); + + // Expert 0 has high affinity (many activations): 10 * 0.1 = 1.0 (clamped) + for _ in 0..10 { + affinity.update(&[0]); + } + + // Expert 2 has low affinity (few activations): 1 * 0.1 = 0.1 + affinity.update(&[2]); + + // Expert 1 has medium affinity: 5 * 0.1 = 0.5 + for _ in 0..5 { + affinity.update(&[1]); + } + + // With pure affinity_weight=1.0, should suggest evicting expert 2 (lowest affinity) + let victim = cache.suggest_eviction_with_affinity(&affinity, 1.0); + + // Expert 2 should be evicted (lowest affinity=0.1) + assert_eq!(victim, Some(2), "Should evict lowest affinity expert"); + } + + // --------------------------------------------------------------- + // 23. 
Prefetch by affinity respects budget + // --------------------------------------------------------------- + + #[test] + fn test_prefetch_by_affinity_respects_budget() { + use crate::moe::{AffinityConfig, ExpertAffinity}; + + let config = ExpertCacheConfig { + max_hot_experts: 6, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(8, config); + let mut affinity = ExpertAffinity::new(AffinityConfig::with_num_experts(8).with_decay(1.0)); + + // Build affinity for experts 3, 5, 7 + for _ in 0..5 { + affinity.update(&[3, 5, 7]); + } + + // Prefetch with budget of 2 + let prefetched = cache.prefetch_by_affinity(&affinity, 2); + + // Should prefetch at most 2 experts + assert!(prefetched.len() <= 2, "Should respect budget"); + assert!( + prefetched.len() >= 1, + "Should prefetch at least 1 high-affinity expert" + ); + + // All prefetched should now be hot + for &id in &prefetched { + assert!(cache.is_hot(id), "Prefetched expert should be hot"); + } + } + + // --------------------------------------------------------------- + // 24. 
Prefetch skips already hot experts + // --------------------------------------------------------------- + + #[test] + fn test_prefetch_skips_already_hot() { + use crate::moe::{AffinityConfig, ExpertAffinity}; + + let config = ExpertCacheConfig { + max_hot_experts: 4, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(8, config); + let mut affinity = ExpertAffinity::new(AffinityConfig::with_num_experts(8).with_decay(1.0)); + + // Make expert 3 hot via access + cache.access(3); + + // Build highest affinity for expert 3 + for _ in 0..10 { + affinity.update(&[3]); + } + + // Build lower affinity for expert 5 + for _ in 0..5 { + affinity.update(&[5]); + } + + // Prefetch with budget of 2 + let prefetched = cache.prefetch_by_affinity(&affinity, 2); + + // Expert 3 should NOT be in prefetched (already hot) + assert!( + !prefetched.contains(&3), + "Should not prefetch already-hot expert" + ); + + // Expert 5 should be prefetched + assert!( + prefetched.contains(&5), + "Should prefetch next highest affinity expert" + ); + } + + // --------------------------------------------------------------- + // 25. 
Affinity weighted eviction blends scores correctly + // --------------------------------------------------------------- + + #[test] + fn test_affinity_weighted_eviction() { + use crate::moe::{AffinityConfig, ExpertAffinity}; + + let mut cache = make_cache(8, 3, EvictionPolicy::Lru); + let mut affinity = ExpertAffinity::new(AffinityConfig::with_num_experts(8).with_decay(1.0)); + + // Fill cache: 0, 1, 2 in that order (LRU order: 0 is oldest) + cache.access(0); + cache.access(1); + cache.access(2); + + // Give expert 0 very high affinity + for _ in 0..20 { + affinity.update(&[0]); + } + + // Expert 1 and 2 have zero affinity + + // With weight=0.0 (pure LRU), should evict expert 0 (oldest) + let victim_lru = cache.suggest_eviction_with_affinity(&affinity, 0.0); + assert_eq!(victim_lru, Some(0), "Weight 0 should use pure LRU"); + + // With weight=1.0 (pure affinity), should evict expert 1 or 2 (lowest affinity) + let victim_affinity = cache.suggest_eviction_with_affinity(&affinity, 1.0); + assert!( + victim_affinity == Some(1) || victim_affinity == Some(2), + "Weight 1.0 should evict lowest affinity" + ); + + // With weight=0.5 (balanced), expert 0's high affinity should protect it + let victim_balanced = cache.suggest_eviction_with_affinity(&affinity, 0.5); + assert_ne!( + victim_balanced, + Some(0), + "Balanced weight should protect high-affinity expert" + ); + } + + // --------------------------------------------------------------- + // 26. 
Zero affinity weight falls back to base policy + // --------------------------------------------------------------- + + #[test] + fn test_zero_affinity_weight_fallback() { + use crate::moe::{AffinityConfig, ExpertAffinity}; + + let mut cache = make_cache(8, 3, EvictionPolicy::Lfu); + let affinity = ExpertAffinity::new(AffinityConfig::with_num_experts(8)); + + // Expert 0: accessed 1 time (lowest freq) + cache.access(0); + + // Expert 1: accessed 3 times + cache.access(1); + cache.access(1); + cache.access(1); + + // Expert 2: accessed 2 times + cache.access(2); + cache.access(2); + + // With weight=0, should behave exactly like base policy (LFU) + let victim_base = cache.suggest_eviction(); + let victim_zero_weight = cache.suggest_eviction_with_affinity(&affinity, 0.0); + + assert_eq!( + victim_base, victim_zero_weight, + "Zero weight should match base policy" + ); + assert_eq!(victim_base, Some(0), "LFU should evict lowest frequency"); + } } diff --git a/crates/ruvllm/src/lib.rs b/crates/ruvllm/src/lib.rs index 9cfe6192f..202e23e5c 100644 --- a/crates/ruvllm/src/lib.rs +++ b/crates/ruvllm/src/lib.rs @@ -130,6 +130,7 @@ pub mod memory_pool; #[cfg(all(target_os = "macos", feature = "metal-compute"))] pub mod metal; pub mod models; +pub mod moe; pub mod optimization; pub mod paged_attention; pub mod policy_store; @@ -322,6 +323,11 @@ pub use memory_pool::{ MemoryManagerConfig, MemoryManagerStats, PooledBuffer, ScratchSpace, ScratchSpaceManager, ScratchStats, CACHE_LINE_SIZE, DEFAULT_ALIGNMENT, }; +// MoE (Mixture of Experts) - ADR-092 +pub use moe::{ + AffinityConfig, ExpertAffinity, ExpertId, ExpertPrecision, MoeMetrics, MoeMetricsSummary, + PrecisionAllocator, PrecisionConfig, +}; pub use optimization::{ AdaptationResult, BatchSizeStrategy, ConsolidationStrategy, InferenceMetrics, KvCachePressurePolicy, LatencyHistogram, LearningLoopStats, MetricsCollector, MetricsSnapshot, @@ -337,15 +343,30 @@ pub use qat::{ UniformQuantizer, DEFAULT_BITS, DEFAULT_QAT_LR, 
MAX_BITS, MIN_BITS, }; pub use quantize::{ + // Incoherence transform (ADR-090 Phase 3) + apply_incoherence, dequantize_for_ane, // Memory estimation estimate_memory_q4, estimate_memory_q5, estimate_memory_q8, + // Hadamard transform (ADR-090 Phase 3) + hadamard_batch_inverse, + hadamard_batch_transform, + log2_exact, + next_power_of_2, + pad_to_power_of_2, // Quantization functions quantize_ruvltra_q4, quantize_ruvltra_q5, quantize_ruvltra_q8, + restore_incoherence, + HadamardTransform, + IncoherenceConfig, + IncoherenceEvent, + IncoherencePhase, + IncoherenceStats, + IncoherenceTransform, MemoryEstimate, // Block types Q4KMBlock, @@ -358,23 +379,8 @@ pub use quantize::{ // Core quantizer RuvltraQuantizer, TargetFormat, - // Hadamard transform (ADR-090 Phase 3) - hadamard_batch_inverse, - hadamard_batch_transform, - log2_exact, - next_power_of_2, - pad_to_power_of_2, - HadamardTransform, MAX_LOG_DIM, SIMD_LANES, - // Incoherence transform (ADR-090 Phase 3) - apply_incoherence, - restore_incoherence, - IncoherenceConfig, - IncoherenceEvent, - IncoherencePhase, - IncoherenceStats, - IncoherenceTransform, }; pub use serving::{ BatchStats, diff --git a/crates/ruvllm/src/moe/affinity.rs b/crates/ruvllm/src/moe/affinity.rs new file mode 100644 index 000000000..9a61de2db --- /dev/null +++ b/crates/ruvllm/src/moe/affinity.rs @@ -0,0 +1,990 @@ +//! Expert Affinity Tracking (ADR-092) +//! +//! This module implements EMA-based expert affinity tracking for memory-aware +//! MoE routing. The affinity scores track which experts are frequently activated, +//! enabling: +//! +//! - **Predictive Prefetching**: Load experts with high affinity before they're needed +//! - **Affinity-Aware Eviction**: Evict low-affinity experts first +//! - **Precision Allocation**: Assign higher precision to frequently-used experts +//! +//! ## Key Invariant (INV-2 from ADR-092) +//! +//! **Affinity Monotonicity**: EMA-based affinity scores decrease monotonically +//! without new activations. 
This ensures predictable eviction behavior. +//! +//! ## Algorithm +//! +//! On each update: +//! 1. All scores are decayed: `score = score * decay` +//! 2. Activated experts receive a boost: `score = min(score + boost, 1.0)` +//! +//! The decay factor (typically 0.95-0.99) controls how quickly old activations +//! are "forgotten". Higher values provide longer memory. + +use super::ExpertId; + +/// SIMD-optimized decay for f32 scores. +/// +/// Applies `scores[i] *= decay` for all elements using platform-specific +/// SIMD intrinsics when available, with scalar fallback. +#[inline] +fn decay_scores_simd(scores: &mut [f32], decay: f32) { + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + decay_scores_neon(scores, decay); + } + + #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] + { + decay_scores_avx2(scores, decay); + } + + #[cfg(not(any( + all(target_arch = "aarch64", target_feature = "neon"), + all(target_arch = "x86_64", target_feature = "avx2") + )))] + { + decay_scores_scalar(scores, decay); + } +} + +/// Scalar fallback for decay +#[inline] +fn decay_scores_scalar(scores: &mut [f32], decay: f32) { + for score in scores.iter_mut() { + *score *= decay; + } +} + +/// NEON-optimized decay for ARM64 +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +#[inline] +fn decay_scores_neon(scores: &mut [f32], decay: f32) { + use std::arch::aarch64::*; + + let len = scores.len(); + let chunks = len / 4; + let remainder = len % 4; + + unsafe { + let decay_vec = vdupq_n_f32(decay); + let ptr = scores.as_mut_ptr(); + + for i in 0..chunks { + let offset = i * 4; + let vals = vld1q_f32(ptr.add(offset)); + let result = vmulq_f32(vals, decay_vec); + vst1q_f32(ptr.add(offset), result); + } + + // Handle remainder with scalar + for i in (chunks * 4)..len { + *scores.get_unchecked_mut(i) *= decay; + } + } +} + +/// AVX2-optimized decay for x86_64 +#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] +#[inline] +fn 
decay_scores_avx2(scores: &mut [f32], decay: f32) { + use std::arch::x86_64::*; + + let len = scores.len(); + let chunks = len / 8; + + unsafe { + let decay_vec = _mm256_set1_ps(decay); + let ptr = scores.as_mut_ptr(); + + for i in 0..chunks { + let offset = i * 8; + let vals = _mm256_loadu_ps(ptr.add(offset)); + let result = _mm256_mul_ps(vals, decay_vec); + _mm256_storeu_ps(ptr.add(offset), result); + } + + // Handle remainder with scalar + for i in (chunks * 8)..len { + *scores.get_unchecked_mut(i) *= decay; + } + } +} + +/// Configuration for expert affinity tracking. +/// +/// # Example +/// +/// ```rust +/// use ruvllm::moe::AffinityConfig; +/// +/// let config = AffinityConfig::with_num_experts(8) +/// .with_decay(0.95) +/// .with_activation_boost(1.0); +/// ``` +#[derive(Debug, Clone)] +pub struct AffinityConfig { + /// Number of experts in the model. + pub num_experts: usize, + + /// EMA decay factor applied to all scores on each update. + /// + /// Range: `0.0 < decay < 1.0` + /// - Higher values (e.g., 0.99) = longer memory, slower forgetting + /// - Lower values (e.g., 0.95) = shorter memory, faster adaptation + /// + /// Default: 0.99 + pub decay: f32, + + /// Boost value added to activated experts. + /// + /// The score after boosting is clamped to `[0.0, max_score]`. + /// + /// Default: 1.0 + pub activation_boost: f32, + + /// Maximum affinity score (clamping bound). + /// + /// Scores are clamped to `[0.0, max_score]` after boosting. + /// + /// Default: 1.0 + pub max_score: f32, +} + +impl Default for AffinityConfig { + fn default() -> Self { + Self { + num_experts: 8, + decay: 0.99, + activation_boost: 1.0, + max_score: 1.0, + } + } +} + +impl AffinityConfig { + /// Create config for a specific number of experts with default decay and boost. + pub fn with_num_experts(num_experts: usize) -> Self { + Self { + num_experts, + ..Default::default() + } + } + + /// Builder: set the decay factor. + /// + /// Values are clamped to `[0.0, 1.0]`. 
+ pub fn with_decay(mut self, decay: f32) -> Self { + self.decay = decay.clamp(0.0, 1.0); + self + } + + /// Builder: set the activation boost. + /// + /// Negative values are clamped to 0. + pub fn with_activation_boost(mut self, boost: f32) -> Self { + self.activation_boost = boost.max(0.0); + self + } + + /// Builder: set the maximum score. + /// + /// Values are clamped to be at least 0.0. + pub fn with_max_score(mut self, max_score: f32) -> Self { + self.max_score = max_score.max(0.0); + self + } +} + +/// EMA-based expert affinity tracker (ADR-092). +/// +/// Tracks which experts are frequently activated using Exponential Moving Average +/// scores. This enables memory-aware routing decisions: +/// +/// - Experts with high affinity should be kept in cache +/// - Experts with low affinity can be evicted or use lower precision +/// - High-affinity experts are good prefetch candidates +/// +/// # Invariant INV-2: Affinity Monotonicity +/// +/// Without new activations, all affinity scores decrease monotonically +/// according to the decay factor. This property is critical for predictable +/// eviction behavior. +/// +/// # Example +/// +/// ```rust +/// use ruvllm::moe::{ExpertAffinity, AffinityConfig}; +/// +/// let config = AffinityConfig::with_num_experts(8).with_decay(0.95); +/// let mut affinity = ExpertAffinity::new(config); +/// +/// // Experts 2 and 5 were selected this round +/// affinity.update(&[2, 5]); +/// +/// // Get current affinity scores +/// assert!(affinity.get_score(2) > affinity.get_score(0)); +/// +/// // Get top experts for prefetching +/// let top3 = affinity.top_k_by_affinity(3); +/// ``` +#[derive(Debug, Clone)] +pub struct ExpertAffinity { + /// EMA scores per expert, range `[0.0, 1.0]`. + scores: Vec, + /// Configuration parameters. + config: AffinityConfig, + /// Total activation count per expert (for precision allocation). 
+ total_activations: Vec, +} + +impl ExpertAffinity { + /// Create a new affinity tracker with the given configuration + pub fn new(config: AffinityConfig) -> Self { + Self { + scores: vec![0.0; config.num_experts], + total_activations: vec![0; config.num_experts], + config, + } + } + + /// Update affinity for activated experts + /// + /// This method: + /// 1. Applies decay to ALL expert scores (INV-2: monotonic decay) + /// 2. Boosts scores for activated experts + /// + /// # Arguments + /// + /// * `activated` - Expert IDs that were activated this step + pub fn update(&mut self, activated: &[ExpertId]) { + // Step 1: Decay all scores (INV-2: monotonic without activation) + // Use SIMD-optimized decay when available + decay_scores_simd(&mut self.scores, self.config.decay); + + // Step 2: Boost activated experts + for &id in activated { + if id < self.scores.len() { + self.scores[id] = + (self.scores[id] + self.config.activation_boost).min(self.config.max_score); + self.total_activations[id] += 1; + } + } + } + + /// Get the affinity score for a specific expert + pub fn score(&self, expert_id: ExpertId) -> f32 { + self.scores.get(expert_id).copied().unwrap_or(0.0) + } + + /// Alias for [`score`](Self::score) for API consistency. + #[inline] + pub fn get_score(&self, expert_id: ExpertId) -> f32 { + self.score(expert_id) + } + + /// Get all affinity scores. + /// + /// The returned slice is indexed by expert ID. + #[inline] + pub fn scores(&self) -> &[f32] { + &self.scores + } + + /// Alias for [`scores`](Self::scores) for API consistency. + #[inline] + pub fn get_scores(&self) -> &[f32] { + self.scores() + } + + /// Get total activation count for an expert. + /// + /// This count is never reset by `update()` and is useful for + /// long-term precision allocation decisions. + /// + /// # Returns + /// + /// Total number of times this expert has been activated, or `0` if + /// the expert ID is out of range. 
+ #[inline] + pub fn activation_count(&self, expert_id: ExpertId) -> u64 { + self.total_activations.get(expert_id).copied().unwrap_or(0) + } + + /// Get all activation counts. + /// + /// The returned slice is indexed by expert ID. + #[inline] + pub fn get_activation_counts(&self) -> &[u64] { + &self.total_activations + } + + /// Get experts sorted by affinity score (highest first) + /// + /// Useful for prefetching decisions. NaN values are treated as lowest priority. + pub fn top_k_by_affinity(&self, k: usize) -> Vec { + let mut indexed: Vec<(ExpertId, f32)> = self + .scores + .iter() + .enumerate() + .map(|(id, &s)| (id, if s.is_finite() { s } else { f32::NEG_INFINITY })) + .collect(); + indexed.sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.0.cmp(&b.0)) // Deterministic tie-breaking by ID + }); + indexed.into_iter().take(k).map(|(id, _)| id).collect() + } + + /// Get experts sorted by total activation count (highest first) + /// + /// Useful for precision allocation decisions. + pub fn top_k_by_frequency(&self, k: usize) -> Vec { + let mut indexed: Vec<(ExpertId, u64)> = + self.total_activations.iter().copied().enumerate().collect(); + indexed.sort_by(|a, b| b.1.cmp(&a.1)); + indexed.into_iter().take(k).map(|(id, _)| id).collect() + } + + /// Get the least-affinity expert from a set of candidates + /// + /// Useful for eviction decisions. NaN values are treated as lowest (evict first). 
+ pub fn least_affinity(&self, candidates: &[ExpertId]) -> Option { + candidates.iter().copied().min_by(|&a, &b| { + let score_a = self.score(a); + let score_b = self.score(b); + // NaN handling: treat NaN as NEG_INFINITY for eviction priority + let sa = if score_a.is_finite() { + score_a + } else { + f32::NEG_INFINITY + }; + let sb = if score_b.is_finite() { + score_b + } else { + f32::NEG_INFINITY + }; + sa.partial_cmp(&sb) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.cmp(&b)) // Deterministic tie-breaking + }) + } + + /// Compute percentile rank of an expert's activation frequency + /// + /// Returns a value in [0.0, 1.0] where 1.0 means highest frequency. + pub fn frequency_percentile(&self, expert_id: ExpertId) -> f32 { + let count = self.activation_count(expert_id); + let lower = self + .total_activations + .iter() + .filter(|&&c| c < count) + .count(); + let equal = self + .total_activations + .iter() + .filter(|&&c| c == count) + .count(); + let n = self.total_activations.len(); + if n == 0 { + return 0.5; + } + (lower as f32 + 0.5 * equal as f32) / n as f32 + } + + /// Reset all affinity scores to zero + pub fn reset(&mut self) { + self.scores.fill(0.0); + self.total_activations.fill(0); + } + + /// Get the number of experts tracked + pub fn num_experts(&self) -> usize { + self.config.num_experts + } + + /// Get the configuration + pub fn config(&self) -> &AffinityConfig { + &self.config + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ===================================================================== + // Tests following ADR-092 specification (10+ required tests) + // ===================================================================== + + /// Test 1: Affinity creation initializes all scores to zero + #[test] + fn test_affinity_creation() { + let config = AffinityConfig::with_num_experts(8); + let affinity = ExpertAffinity::new(config); + + assert_eq!(affinity.num_experts(), 8); + assert_eq!(affinity.scores().len(), 8); + 
assert_eq!(affinity.get_scores().len(), 8); + assert!(affinity.scores().iter().all(|&s| s == 0.0)); + assert!(affinity.get_activation_counts().iter().all(|&c| c == 0)); + } + + /// Test 2: Update decays ALL scores, even those not activated + #[test] + fn test_update_decays_all() { + let config = AffinityConfig::with_num_experts(4) + .with_decay(0.5) + .with_activation_boost(1.0); + let mut affinity = ExpertAffinity::new(config); + + // Activate all experts to set initial scores + affinity.update(&[0, 1, 2, 3]); + + // All should be at 1.0 (boost clamped to max_score) + for &score in affinity.scores() { + assert!((score - 1.0).abs() < 1e-6); + } + + // Update with ONLY expert 0 -> all others should decay + affinity.update(&[0]); + + // Expert 0: 1.0 * 0.5 + 1.0 = 1.5 -> clamped to 1.0 + assert!((affinity.score(0) - 1.0).abs() < 1e-6); + + // Experts 1, 2, 3: 1.0 * 0.5 = 0.5 (no boost) + for id in 1..4 { + assert!( + (affinity.score(id) - 0.5).abs() < 1e-6, + "Expert {} should decay to 0.5, got {}", + id, + affinity.score(id) + ); + } + } + + /// Test 3: Update boosts activated experts + #[test] + fn test_update_boosts_activated() { + let config = AffinityConfig::with_num_experts(4) + .with_decay(0.9) + .with_activation_boost(0.5); + let mut affinity = ExpertAffinity::new(config); + + // Activate experts 1 and 3 + affinity.update(&[1, 3]); + + // Experts 1 and 3 should have boost + // Score = 0.0 * 0.9 + 0.5 = 0.5 + assert!((affinity.score(1) - 0.5).abs() < 1e-6); + assert!((affinity.score(3) - 0.5).abs() < 1e-6); + + // Others should be 0 (0.0 * 0.9 = 0.0, no boost) + assert_eq!(affinity.score(0), 0.0); + assert_eq!(affinity.score(2), 0.0); + } + + /// Test 4: INV-2 Property - Monotonic decay without activations + #[test] + fn test_monotonic_decay() { + let config = AffinityConfig::with_num_experts(8).with_decay(0.95); + let mut affinity = ExpertAffinity::new(config); + + // Activate some experts + affinity.update(&[1, 3, 5, 7]); + + // Record initial scores + let 
scores_t0 = affinity.scores().to_vec(); + + // Multiple updates with NO activations + for iteration in 0..10 { + let scores_before = affinity.scores().to_vec(); + affinity.update(&[]); // Empty update + let scores_after = affinity.scores().to_vec(); + + // INV-2: All scores must decrease monotonically + for (i, (&before, &after)) in scores_before.iter().zip(scores_after.iter()).enumerate() + { + assert!( + after <= before, + "INV-2 violated at iteration {}: score[{}] increased from {} to {}", + iteration, + i, + before, + after + ); + } + } + + // All activated scores should have decayed significantly from t0 + for (i, (&t0, ¤t)) in scores_t0.iter().zip(affinity.scores().iter()).enumerate() { + if t0 > 0.0 { + assert!( + current < t0, + "Score[{}] did not decay: {} -> {}", + i, + t0, + current + ); + } + } + } + + /// Test 5: Top-K by affinity returns correct experts in order + #[test] + fn test_top_k_by_affinity() { + // Use decay=1.0 (no decay) to test pure ordering by activation count + let config = AffinityConfig::with_num_experts(6) + .with_decay(1.0) + .with_activation_boost(0.1); + let mut affinity = ExpertAffinity::new(config); + + // Create distinct affinity levels by activating experts different times + // Expert 3: activated 5 times -> score = 0.5 + for _ in 0..5 { + affinity.update(&[3]); + } + + // Expert 1: activated 3 times -> score = 0.3 + for _ in 0..3 { + affinity.update(&[1]); + } + + // Expert 5: activated 1 time -> score = 0.1 + affinity.update(&[5]); + + // Verify scores are as expected (no decay) + assert!( + (affinity.score(3) - 0.5).abs() < 1e-6, + "Expert 3 score: {}", + affinity.score(3) + ); + assert!( + (affinity.score(1) - 0.3).abs() < 1e-6, + "Expert 1 score: {}", + affinity.score(1) + ); + assert!( + (affinity.score(5) - 0.1).abs() < 1e-6, + "Expert 5 score: {}", + affinity.score(5) + ); + + // Top-2 should be [3, 1] (highest scores) + let top2 = affinity.top_k_by_affinity(2); + assert_eq!(top2.len(), 2); + assert_eq!(top2[0], 3, 
"Expert 3 should be top"); + assert_eq!(top2[1], 1, "Expert 1 should be second"); + + // Top-4 should include 3, 1, 5, and one of the zeros + let top4 = affinity.top_k_by_affinity(4); + assert_eq!(top4.len(), 4); + assert_eq!(top4[0], 3); + assert_eq!(top4[1], 1); + assert_eq!(top4[2], 5); + + // Top-10 (more than available) should return all 6 + let top10 = affinity.top_k_by_affinity(10); + assert_eq!(top10.len(), 6); + } + + /// Test 6: Score is clamped to max_score (default 1.0) + #[test] + fn test_score_clamped_to_one() { + let config = AffinityConfig::with_num_experts(4) + .with_decay(0.99) + .with_activation_boost(1.0); + let mut affinity = ExpertAffinity::new(config); + + // Activate expert 0 many times + for _ in 0..100 { + affinity.update(&[0]); + } + + // Score should be clamped at max_score (1.0) + assert!( + (affinity.score(0) - 1.0).abs() < 1e-6, + "Score should be clamped to 1.0, got {}", + affinity.score(0) + ); + + // Should never exceed 1.0 + assert!( + affinity.score(0) <= 1.0, + "Score {} exceeds max_score", + affinity.score(0) + ); + } + + /// Test 7: Activation counting tracks total activations + #[test] + fn test_activation_counting() { + let config = AffinityConfig::with_num_experts(4); + let mut affinity = ExpertAffinity::new(config); + + // Activate experts with different frequencies + affinity.update(&[0, 1]); // +1 each + affinity.update(&[0, 2]); // +1 to 0, 2 + affinity.update(&[0]); // +1 to 0 + + assert_eq!(affinity.activation_count(0), 3); + assert_eq!(affinity.activation_count(1), 1); + assert_eq!(affinity.activation_count(2), 1); + assert_eq!(affinity.activation_count(3), 0); + + // Out of range should return 0 + assert_eq!(affinity.activation_count(100), 0); + } + + /// Test 8: Reset clears all scores and activation counts + #[test] + fn test_reset() { + let config = AffinityConfig::with_num_experts(4); + let mut affinity = ExpertAffinity::new(config); + + // Build up some state + for _ in 0..10 { + affinity.update(&[0, 1, 2, 3]); 
+ } + + // Verify state is non-zero + assert!(affinity.score(0) > 0.0); + assert!(affinity.activation_count(0) > 0); + + // Reset + affinity.reset(); + + // All scores should be 0 + for &score in affinity.scores() { + assert_eq!(score, 0.0); + } + + // All activation counts should be 0 + for &count in affinity.get_activation_counts() { + assert_eq!(count, 0); + } + } + + /// Test 9: Empty update only decays, no boosts + #[test] + fn test_empty_update() { + let config = AffinityConfig::with_num_experts(4).with_decay(0.9); + let mut affinity = ExpertAffinity::new(config); + + // Set initial state + affinity.update(&[0, 1, 2, 3]); + + let counts_before = affinity.get_activation_counts().to_vec(); + + // Empty update + affinity.update(&[]); + + // Scores should decay (we verified in monotonic decay test) + // Activation counts should NOT change + assert_eq!(affinity.get_activation_counts(), &counts_before); + } + + /// Test 10: Multiple updates sequence produces correct state + #[test] + fn test_multiple_updates_sequence() { + let config = AffinityConfig::with_num_experts(8) + .with_decay(0.8) + .with_activation_boost(0.5); + let mut affinity = ExpertAffinity::new(config); + + // Simulate a realistic workload: + // Expert 0 and 1 are "hot" (activated frequently) + // Expert 7 is activated once then never + + // Round 1: Activate 0, 1, 7 + affinity.update(&[0, 1, 7]); + assert!((affinity.score(0) - 0.5).abs() < 1e-6); + assert!((affinity.score(7) - 0.5).abs() < 1e-6); + + // Round 2: Activate 0, 1 only + affinity.update(&[0, 1]); + // Expert 0: 0.5 * 0.8 + 0.5 = 0.9 + assert!((affinity.score(0) - 0.9).abs() < 1e-6); + // Expert 7: 0.5 * 0.8 = 0.4 (no boost) + assert!((affinity.score(7) - 0.4).abs() < 1e-6); + + // Round 3: Activate 0, 1 again + affinity.update(&[0, 1]); + // Expert 0: 0.9 * 0.8 + 0.5 = 1.22 -> clamped to 1.0 + assert!((affinity.score(0) - 1.0).abs() < 1e-6); + // Expert 7: 0.4 * 0.8 = 0.32 + assert!((affinity.score(7) - 0.32).abs() < 1e-6); + + // After 
3 rounds: + // - Expert 0, 1 should be top (activated every round) + // - Expert 7 should be decaying + let top2 = affinity.top_k_by_affinity(2); + assert!(top2.contains(&0)); + assert!(top2.contains(&1)); + + // Activation counts + assert_eq!(affinity.activation_count(0), 3); + assert_eq!(affinity.activation_count(1), 3); + assert_eq!(affinity.activation_count(7), 1); + } + + // ===================================================================== + // Additional tests beyond the 10 required + // ===================================================================== + + /// Test: Out-of-bounds expert IDs are silently ignored + #[test] + fn test_out_of_bounds_experts_ignored() { + let config = AffinityConfig::with_num_experts(4); + let mut affinity = ExpertAffinity::new(config); + + // Include invalid expert IDs + affinity.update(&[0, 1, 100, 200, 3]); + + // Valid experts should be updated + assert!(affinity.score(0) > 0.0); + assert!(affinity.score(1) > 0.0); + assert!(affinity.score(3) > 0.0); + + // Expert 2 was not activated + assert_eq!(affinity.score(2), 0.0); + + // Activation counts for valid experts + assert_eq!(affinity.activation_count(0), 1); + assert_eq!(affinity.activation_count(100), 0); + } + + /// Test: Config builder methods work correctly + #[test] + fn test_config_builders() { + let config = AffinityConfig::with_num_experts(16) + .with_decay(0.95) + .with_activation_boost(0.75) + .with_max_score(2.0); + + assert_eq!(config.num_experts, 16); + assert!((config.decay - 0.95).abs() < 1e-6); + assert!((config.activation_boost - 0.75).abs() < 1e-6); + assert!((config.max_score - 2.0).abs() < 1e-6); + } + + /// Test: Decay is clamped to [0.0, 1.0] + #[test] + fn test_decay_clamp() { + let config = AffinityConfig::with_num_experts(4).with_decay(1.5); + assert!( + (config.decay - 1.0).abs() < 1e-6, + "Decay should be clamped to 1.0" + ); + + let config2 = AffinityConfig::with_num_experts(4).with_decay(-0.5); + assert!( + (config2.decay - 0.0).abs() < 
1e-6, + "Decay should be clamped to 0.0" + ); + } + + /// Test: Frequency percentile calculation + #[test] + fn test_frequency_percentile() { + let config = AffinityConfig::with_num_experts(4); + let mut affinity = ExpertAffinity::new(config); + + // Expert 0: 1 activation + // Expert 1: 5 activations + // Expert 2: 3 activations + // Expert 3: 0 activations + affinity.update(&[0]); + for _ in 0..5 { + affinity.update(&[1]); + } + for _ in 0..3 { + affinity.update(&[2]); + } + + // Expert 1 should have highest percentile + let pct_1 = affinity.frequency_percentile(1); + let pct_3 = affinity.frequency_percentile(3); + + assert!( + pct_1 > pct_3, + "Expert 1 should have higher percentile than 3" + ); + assert!(pct_1 > 0.5, "Expert 1 should be above median"); + } + + /// Test: Least affinity from candidates + #[test] + fn test_least_affinity() { + // Use decay=1.0 (no decay) and small boost to get distinct scores + let config = AffinityConfig::with_num_experts(4) + .with_decay(1.0) + .with_activation_boost(0.1); + let mut affinity = ExpertAffinity::new(config); + + // Different activation levels -> different scores + // Expert 0: 5 activations -> score = 0.5 + for _ in 0..5 { + affinity.update(&[0]); + } + // Expert 1: 2 activations -> score = 0.2 + for _ in 0..2 { + affinity.update(&[1]); + } + // Expert 2: 1 activation -> score = 0.1 + affinity.update(&[2]); + + // Verify scores + assert!((affinity.score(0) - 0.5).abs() < 1e-6); + assert!((affinity.score(1) - 0.2).abs() < 1e-6); + assert!((affinity.score(2) - 0.1).abs() < 1e-6); + + let candidates = vec![0, 1, 2]; + let least = affinity.least_affinity(&candidates); + + // Expert 2 has lowest affinity (0.1) + assert_eq!(least, Some(2)); + + // Empty candidates + let empty: Vec = vec![]; + assert_eq!(affinity.least_affinity(&empty), None); + } + + /// Test: Top-K by frequency + #[test] + fn test_top_k_by_frequency() { + let config = AffinityConfig::with_num_experts(4); + let mut affinity = ExpertAffinity::new(config); 
+ + affinity.update(&[0]); + affinity.update(&[1]); + affinity.update(&[1]); + affinity.update(&[2]); + affinity.update(&[2]); + affinity.update(&[2]); + + let top_2 = affinity.top_k_by_frequency(2); + assert_eq!(top_2.len(), 2); + // Expert 2 activated 3 times, highest + assert_eq!(top_2[0], 2); + // Expert 1 activated 2 times, second + assert_eq!(top_2[1], 1); + } + + /// Test: Default config values + #[test] + fn test_default_config() { + let config = AffinityConfig::default(); + assert_eq!(config.num_experts, 8); + assert!((config.decay - 0.99).abs() < 1e-6); + assert!((config.activation_boost - 1.0).abs() < 1e-6); + assert!((config.max_score - 1.0).abs() < 1e-6); + } + + // ===================================================================== + // P1 Optimization Tests: SIMD decay + // ===================================================================== + + /// Test: SIMD decay with non-aligned sizes (tests remainder handling) + #[test] + fn test_simd_decay_non_aligned() { + // Test sizes that don't align to SIMD widths (4 for NEON, 8 for AVX2) + for size in [1, 3, 5, 7, 9, 15, 17, 33, 65] { + let config = AffinityConfig::with_num_experts(size).with_decay(0.5); + let mut affinity = ExpertAffinity::new(config); + + // Activate all experts + let all_experts: Vec = (0..size).collect(); + affinity.update(&all_experts); + + // All should be at 1.0 + for &score in affinity.scores() { + assert!((score - 1.0).abs() < 1e-6); + } + + // Decay once + affinity.update(&[]); + + // All should be at 0.5 + for (i, &score) in affinity.scores().iter().enumerate() { + assert!( + (score - 0.5).abs() < 1e-6, + "Expert {} score should be 0.5, got {}", + i, + score + ); + } + } + } + + /// Test: SIMD decay with large expert count + #[test] + fn test_simd_decay_large() { + let config = AffinityConfig::with_num_experts(256).with_decay(0.9); + let mut affinity = ExpertAffinity::new(config); + + // Activate first 128 experts + let activated: Vec = (0..128).collect(); + 
affinity.update(&activated); + + // Decay multiple times + for _ in 0..10 { + affinity.update(&[]); + } + + // All activated scores should have decayed significantly + let expected = 0.9f32.powi(10); + for i in 0..128 { + let score = affinity.score(i); + assert!( + (score - expected).abs() < 1e-5, + "Expert {} score should be ~{}, got {}", + i, + expected, + score + ); + } + + // Non-activated should still be 0 + for i in 128..256 { + assert_eq!(affinity.score(i), 0.0); + } + } + + /// Test: SIMD decay correctness (scalar vs SIMD equivalence) + #[test] + fn test_simd_decay_correctness() { + let config = AffinityConfig::with_num_experts(64) + .with_decay(0.87) + .with_activation_boost(0.33); + let mut affinity = ExpertAffinity::new(config); + + // Activate various experts with a pattern + affinity.update(&[0, 7, 15, 23, 31, 39, 47, 55, 63]); + + // Record scores + let scores_before: Vec = affinity.scores().to_vec(); + + // Decay + affinity.update(&[]); + + // Verify each score decayed correctly + for (i, (&before, &after)) in scores_before + .iter() + .zip(affinity.scores().iter()) + .enumerate() + { + let expected = before * 0.87; + assert!( + (after - expected).abs() < 1e-6, + "Expert {} decay incorrect: {} * 0.87 = {}, got {}", + i, + before, + expected, + after + ); + } + } +} diff --git a/crates/ruvllm/src/moe/metrics.rs b/crates/ruvllm/src/moe/metrics.rs new file mode 100644 index 000000000..8ff5c5bdc --- /dev/null +++ b/crates/ruvllm/src/moe/metrics.rs @@ -0,0 +1,341 @@ +//! MoE Metrics Collection (ADR-092) +//! +//! Tracks cache hit rate, paging latency, and routing performance for +//! memory-aware expert routing. + +use std::time::{Duration, Instant}; + +/// MoE routing and caching metrics. +/// +/// Tracks cache hits, misses, paging operations, and timing information +/// to enable tuning of routing parameters. 
+#[derive(Debug, Clone, Default)] +pub struct MoeMetrics { + /// Number of routing decisions where selected experts were cache-resident + pub cache_hits: u64, + /// Number of routing decisions requiring expert paging + pub cache_misses: u64, + /// Total experts paged in + pub experts_paged_in: u64, + /// Total experts paged out (evicted) + pub experts_paged_out: u64, + /// Total routing decisions made + pub routing_decisions: u64, + /// Cumulative routing latency in microseconds + pub routing_latency_us: u64, + /// Maximum routing latency in microseconds + pub max_routing_latency_us: u64, + /// Cumulative paging latency in microseconds + pub paging_latency_us: u64, + /// Maximum paging latency in microseconds + pub max_paging_latency_us: u64, + /// Number of prefetch operations + pub prefetch_operations: u64, + /// Successful prefetch hits (prefetched expert was subsequently used) + pub prefetch_hits: u64, + /// Affinity-based evictions (vs random/LRU) + pub affinity_evictions: u64, +} + +impl MoeMetrics { + /// Create new metrics instance + pub fn new() -> Self { + Self::default() + } + + /// Record a cache hit + pub fn record_cache_hit(&mut self) { + self.cache_hits += 1; + } + + /// Record a cache miss + pub fn record_cache_miss(&mut self) { + self.cache_misses += 1; + } + + /// Record multiple cache hits (P2 batch optimization) + #[inline] + pub fn record_cache_hits(&mut self, count: usize) { + self.cache_hits += count as u64; + } + + /// Record multiple cache misses (P2 batch optimization) + #[inline] + pub fn record_cache_misses(&mut self, count: usize) { + self.cache_misses += count as u64; + } + + /// Record expert paged in + pub fn record_page_in(&mut self, latency: Duration) { + self.experts_paged_in += 1; + let latency_us = latency.as_micros() as u64; + self.paging_latency_us += latency_us; + self.max_paging_latency_us = self.max_paging_latency_us.max(latency_us); + } + + /// Record expert paged out (evicted) + pub fn record_page_out(&mut self) { + 
self.experts_paged_out += 1; + } + + /// Record a routing decision with latency + pub fn record_routing(&mut self, latency: Duration) { + self.routing_decisions += 1; + let latency_us = latency.as_micros() as u64; + self.routing_latency_us += latency_us; + self.max_routing_latency_us = self.max_routing_latency_us.max(latency_us); + } + + /// Record a prefetch operation + pub fn record_prefetch(&mut self) { + self.prefetch_operations += 1; + } + + /// Record a successful prefetch hit + pub fn record_prefetch_hit(&mut self) { + self.prefetch_hits += 1; + } + + /// Record an affinity-based eviction + pub fn record_affinity_eviction(&mut self) { + self.affinity_evictions += 1; + } + + /// Compute the cache hit rate (0.0 - 1.0) + /// + /// Returns 0.0 if no routing decisions have been made. + pub fn hit_rate(&self) -> f32 { + let total = self.cache_hits + self.cache_misses; + if total == 0 { + return 0.0; + } + self.cache_hits as f32 / total as f32 + } + + /// Compute average routing latency in microseconds + /// + /// Returns 0.0 if no routing decisions have been made. + pub fn avg_routing_latency_us(&self) -> f64 { + if self.routing_decisions == 0 { + return 0.0; + } + self.routing_latency_us as f64 / self.routing_decisions as f64 + } + + /// Compute average paging latency in microseconds + /// + /// Returns 0.0 if no paging operations have been recorded. + pub fn avg_paging_latency_us(&self) -> f64 { + if self.experts_paged_in == 0 { + return 0.0; + } + self.paging_latency_us as f64 / self.experts_paged_in as f64 + } + + /// Compute prefetch accuracy (0.0 - 1.0) + /// + /// Returns 0.0 if no prefetch operations have been recorded. 
+ pub fn prefetch_accuracy(&self) -> f32 { + if self.prefetch_operations == 0 { + return 0.0; + } + self.prefetch_hits as f32 / self.prefetch_operations as f32 + } + + /// Generate a summary of current metrics + pub fn summary(&self) -> MoeMetricsSummary { + MoeMetricsSummary { + hit_rate: self.hit_rate(), + avg_routing_latency_us: self.avg_routing_latency_us(), + max_routing_latency_us: self.max_routing_latency_us, + avg_paging_latency_us: self.avg_paging_latency_us(), + max_paging_latency_us: self.max_paging_latency_us, + prefetch_accuracy: self.prefetch_accuracy(), + total_routing_decisions: self.routing_decisions, + total_page_operations: self.experts_paged_in + self.experts_paged_out, + } + } + + /// Reset all metrics to zero + pub fn reset(&mut self) { + *self = Self::default(); + } +} + +/// Summary of MoE metrics for reporting +#[derive(Debug, Clone)] +pub struct MoeMetricsSummary { + /// Cache hit rate (0.0 - 1.0) + pub hit_rate: f32, + /// Average routing latency in microseconds + pub avg_routing_latency_us: f64, + /// Maximum routing latency in microseconds + pub max_routing_latency_us: u64, + /// Average paging latency in microseconds + pub avg_paging_latency_us: f64, + /// Maximum paging latency in microseconds + pub max_paging_latency_us: u64, + /// Prefetch accuracy (0.0 - 1.0) + pub prefetch_accuracy: f32, + /// Total number of routing decisions + pub total_routing_decisions: u64, + /// Total paging operations (in + out) + pub total_page_operations: u64, +} + +impl MoeMetricsSummary { + /// Check if metrics meet ADR-092 targets + /// + /// Returns true if: + /// - Cache hit rate >= 70% + /// - Max routing latency <= 15 us (10 us target with some margin) + pub fn meets_targets(&self) -> bool { + self.hit_rate >= 0.70 && self.max_routing_latency_us <= 15 + } +} + +/// Timer for measuring operation durations +pub struct MetricsTimer { + start: Instant, +} + +impl MetricsTimer { + /// Start a new timer + pub fn start() -> Self { + Self { + start: 
Instant::now(), + } + } + + /// Get elapsed duration + pub fn elapsed(&self) -> Duration { + self.start.elapsed() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metrics_new() { + let metrics = MoeMetrics::new(); + assert_eq!(metrics.cache_hits, 0); + assert_eq!(metrics.cache_misses, 0); + assert_eq!(metrics.hit_rate(), 0.0); + } + + #[test] + fn test_hit_rate_calculation() { + let mut metrics = MoeMetrics::new(); + + // 3 hits, 1 miss = 75% hit rate + metrics.record_cache_hit(); + metrics.record_cache_hit(); + metrics.record_cache_hit(); + metrics.record_cache_miss(); + + assert!((metrics.hit_rate() - 0.75).abs() < 1e-6); + } + + #[test] + fn test_routing_latency() { + let mut metrics = MoeMetrics::new(); + + metrics.record_routing(Duration::from_micros(10)); + metrics.record_routing(Duration::from_micros(20)); + + assert_eq!(metrics.routing_decisions, 2); + assert!((metrics.avg_routing_latency_us() - 15.0).abs() < 1e-6); + assert_eq!(metrics.max_routing_latency_us, 20); + } + + #[test] + fn test_prefetch_accuracy() { + let mut metrics = MoeMetrics::new(); + + metrics.record_prefetch(); + metrics.record_prefetch(); + metrics.record_prefetch(); + metrics.record_prefetch_hit(); + metrics.record_prefetch_hit(); + + // 2 hits out of 3 prefetches = 66.67% + assert!((metrics.prefetch_accuracy() - 0.6666667).abs() < 1e-6); + } + + #[test] + fn test_summary_meets_targets() { + let summary = MoeMetricsSummary { + hit_rate: 0.75, + avg_routing_latency_us: 8.0, + max_routing_latency_us: 12, + avg_paging_latency_us: 100.0, + max_paging_latency_us: 200, + prefetch_accuracy: 0.6, + total_routing_decisions: 100, + total_page_operations: 20, + }; + + assert!(summary.meets_targets()); + } + + #[test] + fn test_summary_fails_targets() { + let summary = MoeMetricsSummary { + hit_rate: 0.50, // Below 70% + avg_routing_latency_us: 8.0, + max_routing_latency_us: 12, + avg_paging_latency_us: 100.0, + max_paging_latency_us: 200, + prefetch_accuracy: 0.6, + 
total_routing_decisions: 100, + total_page_operations: 20, + }; + + assert!(!summary.meets_targets()); + } + + #[test] + fn test_metrics_reset() { + let mut metrics = MoeMetrics::new(); + metrics.record_cache_hit(); + metrics.record_cache_miss(); + metrics.record_routing(Duration::from_micros(10)); + + metrics.reset(); + + assert_eq!(metrics.cache_hits, 0); + assert_eq!(metrics.cache_misses, 0); + assert_eq!(metrics.routing_decisions, 0); + } + + #[test] + fn test_metrics_timer() { + let timer = MetricsTimer::start(); + // Just verify it doesn't panic + let _elapsed = timer.elapsed(); + } + + #[test] + fn test_bulk_cache_recording() { + let mut metrics = MoeMetrics::new(); + + // P2 optimization: bulk recording + metrics.record_cache_hits(5); + metrics.record_cache_misses(2); + + assert_eq!(metrics.cache_hits, 5); + assert_eq!(metrics.cache_misses, 2); + + // Mix with single recording + metrics.record_cache_hit(); + metrics.record_cache_miss(); + + assert_eq!(metrics.cache_hits, 6); + assert_eq!(metrics.cache_misses, 3); + + // Hit rate should be 6/9 = 66.67% + assert!((metrics.hit_rate() - 0.6666667).abs() < 1e-5); + } +} diff --git a/crates/ruvllm/src/moe/mod.rs b/crates/ruvllm/src/moe/mod.rs new file mode 100644 index 000000000..7f7745320 --- /dev/null +++ b/crates/ruvllm/src/moe/mod.rs @@ -0,0 +1,92 @@ +//! MoE (Mixture of Experts) Module +//! +//! This module provides components for efficient Mixture of Experts inference, +//! including routing metrics tracking, expert affinity tracking, and performance +//! monitoring. +//! +//! ## Overview +//! +//! MoE architectures use sparse expert activation to achieve high model capacity +//! while keeping compute costs manageable. Key challenges include: +//! +//! - **Expert Cache Management**: Keeping frequently-used experts in memory +//! - **Routing Efficiency**: Minimizing overhead from expert selection +//! - **Paging Overhead**: Managing memory transfers for cold experts +//! 
- **Affinity Tracking**: Understanding expert co-activation patterns +//! +//! ## Key Components +//! +//! - [`MemoryAwareRouter`]: Memory-aware expert routing with cache residency bonus (ADR-092) +//! - [`RouterConfig`]: Configuration for router behavior and cache bonus parameters +//! - [`ExpertAffinity`]: EMA-based affinity tracking for memory-aware routing (ADR-092) +//! - [`AffinityConfig`]: Configuration for affinity tracking parameters +//! - [`MoeMetrics`]: Real-time tracking of cache hits, misses, paging, and routing +//! - [`MoeMetricsSummary`]: Aggregated performance summary statistics +//! - [`PagingRequest`]: Request to page experts in/out of cache +//! +//! ## ADR-092 Compliance +//! +//! This module implements memory-aware expert routing as specified in ADR-092: +//! +//! - **INV-2: Affinity Monotonicity**: EMA-based affinity scores decrease monotonically +//! without new activations, ensuring predictable eviction behavior. +//! - **INV-6: Router Determinism**: Same input + cache state always produces same result +//! - Cache hit rates and prefetch effectiveness tracking +//! - Paging latency distribution +//! - Routing decision throughput +//! - Target: >=70% cache hit rate (vs 34% baseline) +//! +//! ## Example: Affinity Tracking +//! +//! ```rust +//! use ruvllm::moe::{ExpertAffinity, AffinityConfig}; +//! +//! let config = AffinityConfig::with_num_experts(8).with_decay(0.95); +//! let mut affinity = ExpertAffinity::new(config); +//! +//! // Experts 2 and 5 were selected this round +//! affinity.update(&[2, 5]); +//! +//! // Get current affinity scores +//! let scores = affinity.scores(); +//! assert!(scores[2] > scores[0]); // Expert 2 was activated +//! +//! // Get top-3 experts for prefetching +//! let top3 = affinity.top_k_by_affinity(3); +//! ``` +//! +//! ## Example: Metrics Tracking +//! +//! ```rust,ignore +//! use ruvllm::moe::{MoeMetrics, MoeMetricsSummary}; +//! +//! let mut metrics = MoeMetrics::new(); +//! +//! 
// Track cache operations +//! metrics.record_cache_hit(); +//! metrics.record_cache_miss(); +//! metrics.record_page_in(Duration::from_micros(150)); +//! +//! // Get summary statistics +//! let summary = metrics.summary(); +//! println!("Hit rate: {:.2}%", summary.hit_rate * 100.0); +//! println!("Avg paging latency: {:.1}us", summary.avg_paging_latency_us); +//! ``` + +pub mod affinity; +pub mod metrics; +pub mod precision_allocator; +pub mod router; +pub mod sram_mapper; + +/// Expert identifier type (matches bitnet/expert_cache.rs convention). +/// +/// Expert IDs are zero-indexed integers. For a model with `num_experts=8`, +/// valid IDs are `0..8`. +pub type ExpertId = usize; + +pub use affinity::{AffinityConfig, ExpertAffinity}; +pub use metrics::{MoeMetrics, MoeMetricsSummary}; +pub use precision_allocator::{ExpertPrecision, PrecisionAllocator, PrecisionConfig}; +pub use router::{MemoryAwareRouter, PagingDirection, PagingPriority, PagingRequest, RouterConfig}; +pub use sram_mapper::{HardwareConfig, HardwarePreset, MemoryTier, SramExpertAffinity, SramMapper}; diff --git a/crates/ruvllm/src/moe/precision_allocator.rs b/crates/ruvllm/src/moe/precision_allocator.rs new file mode 100644 index 000000000..d2d681a12 --- /dev/null +++ b/crates/ruvllm/src/moe/precision_allocator.rs @@ -0,0 +1,1090 @@ +//! Frequency-Based Precision Allocation for MoE Experts +//! +//! This module implements ADR-092's precision allocation strategy, which assigns +//! different quantization formats to experts based on their activation frequency: +//! +//! - **Hot experts** (high activation): Higher precision (e.g., Q4_K_M) +//! - **Warm experts** (medium activation): Medium precision (e.g., Q3_K or PiQ3) +//! - **Cold experts** (low activation): Lower precision (e.g., Q2_K or PiQ2) +//! +//! This approach preserves model quality by keeping frequently-used experts at +//! higher precision while aggressively compressing rarely-used experts. +//! +//! 
## Invariant: INV-4 Precision Preservation +//! +//! Expert precision metadata travels with cached weights. The `PrecisionAllocator` +//! tracks and exposes the assigned precision for each expert so that dequantization +//! uses the correct format. +//! +//! ## Example +//! +//! ```rust,ignore +//! use ruvllm::moe::precision_allocator::{PrecisionAllocator, PrecisionConfig, ExpertPrecision}; +//! use ruvllm::gguf::GgufQuantType; +//! +//! let config = PrecisionConfig::default(); +//! let mut allocator = PrecisionAllocator::new(8, config).unwrap(); +//! +//! // Record activations as experts are used +//! allocator.record_activation(2); +//! allocator.record_activation(2); +//! allocator.record_activation(5); +//! +//! // Recompute thresholds periodically +//! allocator.recompute_thresholds(); +//! +//! // Get precision level for routing decisions +//! let precision = allocator.allocate(2); +//! let format = allocator.get_format(2); +//! ``` + +use crate::gguf::GgufQuantType; + +// Re-export ExpertId from parent module +pub use super::ExpertId; + +// ============================================================================ +// Precision Level +// ============================================================================ + +/// Precision level assigned to an expert based on activation frequency. +/// +/// The three tiers enable differentiated memory/quality tradeoffs: +/// - Hot: Maximum quality, higher memory usage +/// - Warm: Balanced quality/memory +/// - Cold: Aggressive compression, lower memory usage +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExpertPrecision { + /// High precision for frequently-activated (hot) experts. + /// + /// Typically Q4_K_M or higher for best quality on important experts. + Hot, + + /// Medium precision for moderately-activated (warm) experts. + /// + /// Typically Q3_K or PiQ3 for balanced quality/memory. + Warm, + + /// Low precision for rarely-activated (cold) experts. 
+ /// + /// Typically Q2_K or PiQ2 for maximum compression on seldom-used experts. + Cold, +} + +impl ExpertPrecision { + /// Returns a human-readable name for the precision level. + pub fn name(&self) -> &'static str { + match self { + ExpertPrecision::Hot => "hot", + ExpertPrecision::Warm => "warm", + ExpertPrecision::Cold => "cold", + } + } +} + +// ============================================================================ +// Configuration +// ============================================================================ + +/// Configuration for frequency-based precision allocation. +/// +/// The percentile thresholds determine how experts are classified: +/// - Experts with activation count >= hot_percentile of max are "hot" +/// - Experts with activation count >= cold_percentile but < hot_percentile are "warm" +/// - Experts with activation count < cold_percentile are "cold" +#[derive(Debug, Clone)] +pub struct PrecisionConfig { + /// Percentile threshold for hot experts (default: 0.9 = top 10% by frequency). + /// + /// Experts whose activation count is at or above this percentile of the + /// maximum activation count are classified as hot. + pub hot_percentile: f32, + + /// Percentile threshold for cold experts (default: 0.3 = bottom 30% by frequency). + /// + /// Experts whose activation count is below this percentile of the + /// maximum activation count are classified as cold. + pub cold_percentile: f32, + + /// GGUF quantization format for hot experts. + /// + /// Default: Q4_K (4-bit k-quant) for good quality. + pub hot_format: GgufQuantType, + + /// GGUF quantization format for warm experts. + /// + /// Default: Q3_K (3-bit k-quant) for balanced quality/size. + pub warm_format: GgufQuantType, + + /// GGUF quantization format for cold experts. + /// + /// Default: Q2_K (2-bit k-quant) for maximum compression. 
+ pub cold_format: GgufQuantType, +} + +impl Default for PrecisionConfig { + fn default() -> Self { + Self { + hot_percentile: 0.9, + cold_percentile: 0.3, + hot_format: GgufQuantType::Q4_K, + warm_format: GgufQuantType::Q3_K, + cold_format: GgufQuantType::Q2_K, + } + } +} + +impl PrecisionConfig { + /// Create a config optimized for memory-constrained devices. + /// + /// Uses more aggressive thresholds and lower precision formats. + pub fn memory_constrained() -> Self { + Self { + hot_percentile: 0.95, + cold_percentile: 0.4, + hot_format: GgufQuantType::Q4_K, + warm_format: GgufQuantType::Q2_K, + cold_format: GgufQuantType::Q2_K, + } + } + + /// Create a config optimized for quality preservation. + /// + /// Uses higher precision formats across all tiers. + pub fn quality_focused() -> Self { + Self { + hot_percentile: 0.8, + cold_percentile: 0.2, + hot_format: GgufQuantType::Q5_K, + warm_format: GgufQuantType::Q4_K, + cold_format: GgufQuantType::Q3_K, + } + } + + /// Validate the configuration. + /// + /// Returns an error message if the configuration is invalid. + pub fn validate(&self) -> Result<(), &'static str> { + if self.hot_percentile <= 0.0 || self.hot_percentile > 1.0 { + return Err("hot_percentile must be in (0.0, 1.0]"); + } + if self.cold_percentile < 0.0 || self.cold_percentile >= 1.0 { + return Err("cold_percentile must be in [0.0, 1.0)"); + } + if self.cold_percentile >= self.hot_percentile { + return Err("cold_percentile must be less than hot_percentile"); + } + Ok(()) + } +} + +// ============================================================================ +// Precision Allocator +// ============================================================================ + +/// Frequency-based precision allocator for MoE experts. +/// +/// Tracks activation counts for each expert and assigns precision levels +/// based on relative frequency. 
Hot experts (frequently used) get higher +/// precision to preserve quality, while cold experts (rarely used) get +/// lower precision to save memory. +/// +/// # INV-4: Precision Preservation +/// +/// This allocator maintains the mapping from expert ID to precision level, +/// ensuring that the correct quantization format is used when dequantizing +/// cached expert weights. +/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::moe::precision_allocator::{PrecisionAllocator, PrecisionConfig}; +/// +/// let config = PrecisionConfig::default(); +/// let mut allocator = PrecisionAllocator::new(8, config).unwrap(); +/// +/// // Simulate expert activations +/// for _ in 0..100 { allocator.record_activation(0); } // Hot +/// for _ in 0..50 { allocator.record_activation(1); } // Warm +/// allocator.record_activation(7); // Cold +/// +/// allocator.recompute_thresholds(); +/// +/// assert_eq!(allocator.allocate(0), ExpertPrecision::Hot); +/// assert_eq!(allocator.allocate(1), ExpertPrecision::Warm); +/// assert_eq!(allocator.allocate(7), ExpertPrecision::Cold); +/// ``` +pub struct PrecisionAllocator { + /// Number of experts tracked. + num_experts: usize, + + /// Activation counts per expert, indexed by ExpertId. + counts: Vec<u64>, + + /// Configuration for precision allocation. + config: PrecisionConfig, + + /// Cached threshold for hot classification. + /// + /// Experts with count >= hot_threshold are hot. + hot_threshold: u64, + + /// Cached threshold for cold classification. + /// + /// Experts with count < cold_threshold are cold. + cold_threshold: u64, +} + +impl PrecisionAllocator { + /// Create a new precision allocator. + /// + /// # Arguments + /// + /// * `num_experts` - Total number of experts to track. + /// * `config` - Configuration for precision allocation. + /// + /// # Returns + /// + /// Returns `Err` if the configuration is invalid. 
+ pub fn new(num_experts: usize, config: PrecisionConfig) -> Result<Self, &'static str> { + config.validate()?; + + Ok(Self { + num_experts, + counts: vec![0; num_experts], + config, + hot_threshold: 0, + cold_threshold: 0, + }) + } + + /// Create a new precision allocator, panicking on invalid config. + /// + /// # Panics + /// + /// Panics if the configuration is invalid. + pub fn new_unchecked(num_experts: usize, config: PrecisionConfig) -> Self { + Self::new(num_experts, config).expect("PrecisionConfig validation failed") + } + + /// Record an activation for the given expert. + /// + /// Increments the activation counter for the specified expert. + /// This should be called each time an expert is selected by the router. + /// + /// # Arguments + /// + /// * `expert_id` - The ID of the activated expert. + /// + /// # Notes + /// + /// Out-of-bounds expert IDs are silently ignored. + #[inline] + pub fn record_activation(&mut self, expert_id: ExpertId) { + if expert_id < self.num_experts { + self.counts[expert_id] = self.counts[expert_id].saturating_add(1); + } + } + + /// Record multiple activations in a batch. + /// + /// More efficient than calling `record_activation` in a loop when + /// processing batched routing decisions. + /// + /// # Arguments + /// + /// * `expert_ids` - Slice of activated expert IDs. + pub fn record_activations(&mut self, expert_ids: &[ExpertId]) { + for &expert_id in expert_ids { + self.record_activation(expert_id); + } + } + + /// Get the precision level for the given expert. + /// + /// Returns the precision classification (Hot, Warm, or Cold) based on + /// the expert's activation count relative to the computed thresholds. + /// + /// # Arguments + /// + /// * `expert_id` - The ID of the expert to classify. + /// + /// # Returns + /// + /// The precision level for the expert. Returns `Cold` for out-of-bounds IDs. 
+ pub fn allocate(&self, expert_id: ExpertId) -> ExpertPrecision { + if expert_id >= self.num_experts { + return ExpertPrecision::Cold; + } + + let count = self.counts[expert_id]; + + // If no activations have occurred yet, all experts are cold + if self.hot_threshold == 0 && self.cold_threshold == 0 { + return ExpertPrecision::Cold; + } + + if count >= self.hot_threshold && self.hot_threshold > 0 { + ExpertPrecision::Hot + } else if count >= self.cold_threshold && count > 0 { + ExpertPrecision::Warm + } else { + ExpertPrecision::Cold + } + } + + /// Get the GGUF quantization format for the given expert. + /// + /// Returns the appropriate quantization format based on the expert's + /// precision classification. + /// + /// # Arguments + /// + /// * `expert_id` - The ID of the expert. + /// + /// # Returns + /// + /// The GGUF quantization type to use for this expert. + pub fn get_format(&self, expert_id: ExpertId) -> GgufQuantType { + match self.allocate(expert_id) { + ExpertPrecision::Hot => self.config.hot_format, + ExpertPrecision::Warm => self.config.warm_format, + ExpertPrecision::Cold => self.config.cold_format, + } + } + + /// Recompute the threshold values based on current activation counts. + /// + /// Should be called periodically (e.g., every N tokens or at batch boundaries) + /// to update the precision classifications as activation patterns change. 
+ /// + /// The thresholds are computed from the maximum activation count: + /// - `hot_threshold = max_count * hot_percentile` + /// - `cold_threshold = max_count * cold_percentile` + pub fn recompute_thresholds(&mut self) { + let max_count = self.counts.iter().copied().max().unwrap_or(0); + + if max_count == 0 { + self.hot_threshold = 0; + self.cold_threshold = 0; + return; + } + + // Compute thresholds as fractions of max count + self.hot_threshold = (max_count as f64 * self.config.hot_percentile as f64).ceil() as u64; + self.cold_threshold = + (max_count as f64 * self.config.cold_percentile as f64).floor() as u64; + + // Ensure cold_threshold is at least 1 if there are any activations + if self.cold_threshold == 0 && max_count > 0 { + self.cold_threshold = 1; + } + } + + /// Get the precision map for all experts. + /// + /// Returns a vector of (ExpertId, ExpertPrecision) tuples for all tracked + /// experts. Useful for bulk operations or serialization. + /// + /// # Returns + /// + /// Vector of tuples containing each expert's ID and precision level. + pub fn get_precision_map(&self) -> Vec<(ExpertId, ExpertPrecision)> { + (0..self.num_experts) + .map(|id| (id, self.allocate(id))) + .collect() + } + + /// Get the activation count for a specific expert. + /// + /// # Arguments + /// + /// * `expert_id` - The ID of the expert. + /// + /// # Returns + /// + /// The activation count, or 0 for out-of-bounds IDs. + pub fn get_count(&self, expert_id: ExpertId) -> u64 { + self.counts.get(expert_id).copied().unwrap_or(0) + } + + /// Get the total number of activations across all experts. + pub fn total_activations(&self) -> u64 { + self.counts.iter().sum() + } + + /// Get the number of experts in each precision tier. + /// + /// # Returns + /// + /// Tuple of (hot_count, warm_count, cold_count). 
+ pub fn tier_counts(&self) -> (usize, usize, usize) { + let mut hot = 0; + let mut warm = 0; + let mut cold = 0; + + for id in 0..self.num_experts { + match self.allocate(id) { + ExpertPrecision::Hot => hot += 1, + ExpertPrecision::Warm => warm += 1, + ExpertPrecision::Cold => cold += 1, + } + } + + (hot, warm, cold) + } + + /// Reset all activation counts to zero. + /// + /// Also resets the thresholds. Useful when starting a new evaluation + /// period or when the model's usage patterns have changed significantly. + pub fn reset(&mut self) { + self.counts.fill(0); + self.hot_threshold = 0; + self.cold_threshold = 0; + } + + /// Get the number of experts being tracked. + pub fn num_experts(&self) -> usize { + self.num_experts + } + + /// Get the current hot threshold. + pub fn hot_threshold(&self) -> u64 { + self.hot_threshold + } + + /// Get the current cold threshold. + pub fn cold_threshold(&self) -> u64 { + self.cold_threshold + } + + /// Get a reference to the configuration. + pub fn config(&self) -> &PrecisionConfig { + &self.config + } + + /// Get experts by precision level. + /// + /// # Arguments + /// + /// * `precision` - The precision level to filter by. + /// + /// # Returns + /// + /// Vector of expert IDs with the specified precision level. + pub fn experts_by_precision(&self, precision: ExpertPrecision) -> Vec<ExpertId> { + (0..self.num_experts) + .filter(|&id| self.allocate(id) == precision) + .collect() + } + + /// Compute the percentile rank for a given expert. + /// + /// Returns a value in [0.0, 1.0] representing where this expert's + /// activation count falls relative to the maximum. + /// + /// # Arguments + /// + /// * `expert_id` - The ID of the expert. + /// + /// # Returns + /// + /// Percentile rank (0.0 = no activations, 1.0 = max activations). 
+ pub fn compute_percentile(&self, expert_id: ExpertId) -> f32 { + if expert_id >= self.num_experts { + return 0.0; + } + + let max_count = self.counts.iter().copied().max().unwrap_or(0); + if max_count == 0 { + return 0.0; + } + + self.counts[expert_id] as f32 / max_count as f32 + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + // --------------------------------------------------------------- + // test_allocator_creation + // --------------------------------------------------------------- + + #[test] + fn test_allocator_creation() { + let config = PrecisionConfig::default(); + let allocator = PrecisionAllocator::new(8, config).unwrap(); + + assert_eq!(allocator.num_experts(), 8); + assert_eq!(allocator.total_activations(), 0); + assert_eq!(allocator.hot_threshold(), 0); + assert_eq!(allocator.cold_threshold(), 0); + + // All experts should be cold initially + for id in 0..8 { + assert_eq!(allocator.allocate(id), ExpertPrecision::Cold); + } + } + + // --------------------------------------------------------------- + // test_hot_expert_allocation + // --------------------------------------------------------------- + + #[test] + fn test_hot_expert_allocation() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + ..Default::default() + }; + let mut allocator = PrecisionAllocator::new(8, config).unwrap(); + + // Expert 0 gets 100 activations (max) + for _ in 0..100 { + allocator.record_activation(0); + } + + // Other experts get 10 activations each + for id in 1..8 { + for _ in 0..10 { + allocator.record_activation(id); + } + } + + allocator.recompute_thresholds(); + + // Expert 0 should be hot (100 >= 90% of 100 = 90) + assert_eq!(allocator.allocate(0), ExpertPrecision::Hot); + assert_eq!(allocator.get_format(0), GgufQuantType::Q4_K); + } + + // 
--------------------------------------------------------------- + // test_warm_expert_allocation + // --------------------------------------------------------------- + + #[test] + fn test_warm_expert_allocation() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + ..Default::default() + }; + let mut allocator = PrecisionAllocator::new(8, config).unwrap(); + + // Expert 0 gets 100 activations (max) + for _ in 0..100 { + allocator.record_activation(0); + } + + // Expert 1 gets 50 activations (warm: 30-89% of max) + for _ in 0..50 { + allocator.record_activation(1); + } + + allocator.recompute_thresholds(); + + // Expert 1 should be warm (50 >= 30% of 100 = 30, but < 90) + assert_eq!(allocator.allocate(1), ExpertPrecision::Warm); + assert_eq!(allocator.get_format(1), GgufQuantType::Q3_K); + } + + // --------------------------------------------------------------- + // test_cold_expert_allocation + // --------------------------------------------------------------- + + #[test] + fn test_cold_expert_allocation() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + ..Default::default() + }; + let mut allocator = PrecisionAllocator::new(8, config).unwrap(); + + // Expert 0 gets 100 activations (max) + for _ in 0..100 { + allocator.record_activation(0); + } + + // Expert 7 gets 5 activations (cold: < 30% of max) + for _ in 0..5 { + allocator.record_activation(7); + } + + allocator.recompute_thresholds(); + + // Expert 7 should be cold (5 < 30% of 100 = 30) + assert_eq!(allocator.allocate(7), ExpertPrecision::Cold); + assert_eq!(allocator.get_format(7), GgufQuantType::Q2_K); + } + + // --------------------------------------------------------------- + // test_percentile_thresholds + // --------------------------------------------------------------- + + #[test] + fn test_percentile_thresholds() { + let config = PrecisionConfig { + hot_percentile: 0.8, + cold_percentile: 0.2, + ..Default::default() + }; + let mut 
allocator = PrecisionAllocator::new(4, config).unwrap(); + + // Set up activation counts: 100, 75, 25, 5 + for _ in 0..100 { + allocator.record_activation(0); + } + for _ in 0..75 { + allocator.record_activation(1); + } + for _ in 0..25 { + allocator.record_activation(2); + } + for _ in 0..5 { + allocator.record_activation(3); + } + + allocator.recompute_thresholds(); + + // hot_threshold = ceil(100 * 0.8) = 80 (or 81 due to f32->f64 precision) + // cold_threshold = floor(100 * 0.2) = 20 + // Allow for minor floating-point variance + assert!( + allocator.hot_threshold() >= 80 && allocator.hot_threshold() <= 81, + "hot_threshold {} should be 80 or 81", + allocator.hot_threshold() + ); + assert_eq!(allocator.cold_threshold(), 20); + + // Expert 0: 100 >= hot_threshold -> Hot + assert_eq!(allocator.allocate(0), ExpertPrecision::Hot); + + // Expert 1: 75 >= 20 but < hot_threshold -> Warm + assert_eq!(allocator.allocate(1), ExpertPrecision::Warm); + + // Expert 2: 25 >= 20 but < hot_threshold -> Warm + assert_eq!(allocator.allocate(2), ExpertPrecision::Warm); + + // Expert 3: 5 < 20 -> Cold + assert_eq!(allocator.allocate(3), ExpertPrecision::Cold); + } + + // --------------------------------------------------------------- + // test_activation_recording + // --------------------------------------------------------------- + + #[test] + fn test_activation_recording() { + let config = PrecisionConfig::default(); + let mut allocator = PrecisionAllocator::new(4, config).unwrap(); + + // Record individual activations + allocator.record_activation(0); + allocator.record_activation(0); + allocator.record_activation(1); + + assert_eq!(allocator.get_count(0), 2); + assert_eq!(allocator.get_count(1), 1); + assert_eq!(allocator.get_count(2), 0); + assert_eq!(allocator.total_activations(), 3); + + // Record batch activations + allocator.record_activations(&[2, 2, 3, 0]); + + assert_eq!(allocator.get_count(0), 3); + assert_eq!(allocator.get_count(2), 2); + 
assert_eq!(allocator.get_count(3), 1); + assert_eq!(allocator.total_activations(), 7); + + // Out-of-bounds should be ignored + allocator.record_activation(100); + assert_eq!(allocator.total_activations(), 7); + } + + // --------------------------------------------------------------- + // test_format_mapping + // --------------------------------------------------------------- + + #[test] + fn test_format_mapping() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + hot_format: GgufQuantType::Q5_K, + warm_format: GgufQuantType::Q4_K, + cold_format: GgufQuantType::Q3_K, + }; + let mut allocator = PrecisionAllocator::new(3, config).unwrap(); + + // Set up: 100 (hot), 50 (warm), 10 (cold) + for _ in 0..100 { + allocator.record_activation(0); + } + for _ in 0..50 { + allocator.record_activation(1); + } + for _ in 0..10 { + allocator.record_activation(2); + } + + allocator.recompute_thresholds(); + + assert_eq!(allocator.get_format(0), GgufQuantType::Q5_K); + assert_eq!(allocator.get_format(1), GgufQuantType::Q4_K); + assert_eq!(allocator.get_format(2), GgufQuantType::Q3_K); + } + + // --------------------------------------------------------------- + // test_recompute_thresholds + // --------------------------------------------------------------- + + #[test] + fn test_recompute_thresholds() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + ..Default::default() + }; + let mut allocator = PrecisionAllocator::new(4, config).unwrap(); + + // Initially all zeros + allocator.recompute_thresholds(); + assert_eq!(allocator.hot_threshold(), 0); + assert_eq!(allocator.cold_threshold(), 0); + + // Add some activations + for _ in 0..100 { + allocator.record_activation(0); + } + allocator.recompute_thresholds(); + + // hot_threshold = ceil(100 * 0.9) = 90 + // cold_threshold = max(1, floor(100 * 0.3)) = 30 + assert_eq!(allocator.hot_threshold(), 90); + assert_eq!(allocator.cold_threshold(), 30); + + // Add more 
activations and recompute + for _ in 0..100 { + allocator.record_activation(0); + } + allocator.recompute_thresholds(); + + // hot_threshold = ceil(200 * 0.9) = 180 + // cold_threshold = floor(200 * 0.3) = 60 + assert_eq!(allocator.hot_threshold(), 180); + assert_eq!(allocator.cold_threshold(), 60); + } + + // --------------------------------------------------------------- + // test_precision_map + // --------------------------------------------------------------- + + #[test] + fn test_precision_map() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + ..Default::default() + }; + let mut allocator = PrecisionAllocator::new(4, config).unwrap(); + + for _ in 0..100 { + allocator.record_activation(0); + } + for _ in 0..50 { + allocator.record_activation(1); + } + for _ in 0..35 { + allocator.record_activation(2); + } + for _ in 0..10 { + allocator.record_activation(3); + } + + allocator.recompute_thresholds(); + + let map = allocator.get_precision_map(); + assert_eq!(map.len(), 4); + assert_eq!(map[0], (0, ExpertPrecision::Hot)); + assert_eq!(map[1], (1, ExpertPrecision::Warm)); + assert_eq!(map[2], (2, ExpertPrecision::Warm)); + assert_eq!(map[3], (3, ExpertPrecision::Cold)); + } + + // --------------------------------------------------------------- + // test_tier_counts + // --------------------------------------------------------------- + + #[test] + fn test_tier_counts() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + ..Default::default() + }; + let mut allocator = PrecisionAllocator::new(8, config).unwrap(); + + // 2 hot, 3 warm, 3 cold + for _ in 0..100 { + allocator.record_activation(0); + } + for _ in 0..95 { + allocator.record_activation(1); + } + for _ in 0..50 { + allocator.record_activation(2); + } + for _ in 0..40 { + allocator.record_activation(3); + } + for _ in 0..35 { + allocator.record_activation(4); + } + for _ in 0..10 { + allocator.record_activation(5); + } + for _ in 0..5 { + 
allocator.record_activation(6); + } + // Expert 7 has 0 activations + + allocator.recompute_thresholds(); + + let (hot, warm, cold) = allocator.tier_counts(); + assert_eq!(hot, 2, "Expected 2 hot experts"); + assert!(warm >= 2, "Expected at least 2 warm experts"); + assert!(cold >= 2, "Expected at least 2 cold experts"); + assert_eq!(hot + warm + cold, 8, "Total should equal num_experts"); + } + + // --------------------------------------------------------------- + // test_reset + // --------------------------------------------------------------- + + #[test] + fn test_reset() { + let config = PrecisionConfig::default(); + let mut allocator = PrecisionAllocator::new(4, config).unwrap(); + + // Add activations + for _ in 0..100 { + allocator.record_activation(0); + } + allocator.recompute_thresholds(); + + assert!(allocator.total_activations() > 0); + assert!(allocator.hot_threshold() > 0); + + // Reset + allocator.reset(); + + assert_eq!(allocator.total_activations(), 0); + assert_eq!(allocator.hot_threshold(), 0); + assert_eq!(allocator.cold_threshold(), 0); + for id in 0..4 { + assert_eq!(allocator.get_count(id), 0); + } + } + + // --------------------------------------------------------------- + // test_experts_by_precision + // --------------------------------------------------------------- + + #[test] + fn test_experts_by_precision() { + let config = PrecisionConfig { + hot_percentile: 0.9, + cold_percentile: 0.3, + ..Default::default() + }; + let mut allocator = PrecisionAllocator::new(6, config).unwrap(); + + // Set up known distribution + for _ in 0..100 { + allocator.record_activation(0); + } // Hot + for _ in 0..92 { + allocator.record_activation(1); + } // Hot + for _ in 0..50 { + allocator.record_activation(2); + } // Warm + for _ in 0..40 { + allocator.record_activation(3); + } // Warm + for _ in 0..10 { + allocator.record_activation(4); + } // Cold + // Expert 5 has 0 activations -> Cold + + allocator.recompute_thresholds(); + + let hot_experts = 
allocator.experts_by_precision(ExpertPrecision::Hot); + let warm_experts = allocator.experts_by_precision(ExpertPrecision::Warm); + let cold_experts = allocator.experts_by_precision(ExpertPrecision::Cold); + + assert!(hot_experts.contains(&0)); + assert!(hot_experts.contains(&1)); + assert!(warm_experts.contains(&2) || warm_experts.contains(&3)); + assert!(cold_experts.contains(&4) || cold_experts.contains(&5)); + } + + // --------------------------------------------------------------- + // test_compute_percentile + // --------------------------------------------------------------- + + #[test] + fn test_compute_percentile() { + let config = PrecisionConfig::default(); + let mut allocator = PrecisionAllocator::new(4, config).unwrap(); + + // No activations -> 0.0 percentile + assert_eq!(allocator.compute_percentile(0), 0.0); + + // Set up: 100, 50, 25, 0 + for _ in 0..100 { + allocator.record_activation(0); + } + for _ in 0..50 { + allocator.record_activation(1); + } + for _ in 0..25 { + allocator.record_activation(2); + } + + // Expert 0: 100/100 = 1.0 + assert!((allocator.compute_percentile(0) - 1.0).abs() < f32::EPSILON); + + // Expert 1: 50/100 = 0.5 + assert!((allocator.compute_percentile(1) - 0.5).abs() < f32::EPSILON); + + // Expert 2: 25/100 = 0.25 + assert!((allocator.compute_percentile(2) - 0.25).abs() < f32::EPSILON); + + // Expert 3: 0/100 = 0.0 + assert!((allocator.compute_percentile(3) - 0.0).abs() < f32::EPSILON); + + // Out-of-bounds + assert_eq!(allocator.compute_percentile(100), 0.0); + } + + // --------------------------------------------------------------- + // test_config_validation + // --------------------------------------------------------------- + + #[test] + fn test_config_validation() { + // Valid config + let valid = PrecisionConfig::default(); + assert!(valid.validate().is_ok()); + + // Invalid: hot_percentile > 1.0 + let invalid1 = PrecisionConfig { + hot_percentile: 1.5, + ..Default::default() + }; + 
assert!(invalid1.validate().is_err()); + + // Invalid: cold_percentile >= hot_percentile + let invalid2 = PrecisionConfig { + hot_percentile: 0.5, + cold_percentile: 0.6, + ..Default::default() + }; + assert!(invalid2.validate().is_err()); + + // Invalid: cold_percentile negative + let invalid3 = PrecisionConfig { + cold_percentile: -0.1, + ..Default::default() + }; + assert!(invalid3.validate().is_err()); + } + + // --------------------------------------------------------------- + // test_precision_name + // --------------------------------------------------------------- + + #[test] + fn test_precision_name() { + assert_eq!(ExpertPrecision::Hot.name(), "hot"); + assert_eq!(ExpertPrecision::Warm.name(), "warm"); + assert_eq!(ExpertPrecision::Cold.name(), "cold"); + } + + // --------------------------------------------------------------- + // test_out_of_bounds_expert_id + // --------------------------------------------------------------- + + #[test] + fn test_out_of_bounds_expert_id() { + let config = PrecisionConfig::default(); + let allocator = PrecisionAllocator::new(4, config).unwrap(); + + // Out-of-bounds should return Cold + assert_eq!(allocator.allocate(100), ExpertPrecision::Cold); + assert_eq!(allocator.get_format(100), GgufQuantType::Q2_K); + assert_eq!(allocator.get_count(100), 0); + } + + // --------------------------------------------------------------- + // test_memory_constrained_config + // --------------------------------------------------------------- + + #[test] + fn test_memory_constrained_config() { + let config = PrecisionConfig::memory_constrained(); + + assert!(config.validate().is_ok()); + assert_eq!(config.hot_percentile, 0.95); + assert_eq!(config.cold_percentile, 0.4); + // More aggressive compression for warm/cold + assert_eq!(config.warm_format, GgufQuantType::Q2_K); + assert_eq!(config.cold_format, GgufQuantType::Q2_K); + } + + // --------------------------------------------------------------- + // test_quality_focused_config + // 
--------------------------------------------------------------- + + #[test] + fn test_quality_focused_config() { + let config = PrecisionConfig::quality_focused(); + + assert!(config.validate().is_ok()); + assert_eq!(config.hot_percentile, 0.8); + assert_eq!(config.cold_percentile, 0.2); + // Higher precision formats + assert_eq!(config.hot_format, GgufQuantType::Q5_K); + assert_eq!(config.warm_format, GgufQuantType::Q4_K); + assert_eq!(config.cold_format, GgufQuantType::Q3_K); + } + + // --------------------------------------------------------------- + // test_saturating_add_for_counts + // --------------------------------------------------------------- + + #[test] + fn test_saturating_add_for_counts() { + let config = PrecisionConfig::default(); + let mut allocator = PrecisionAllocator::new(1, config).unwrap(); + + // Set count close to max + allocator.counts[0] = u64::MAX - 1; + + // Should saturate instead of overflow + allocator.record_activation(0); + assert_eq!(allocator.get_count(0), u64::MAX); + + allocator.record_activation(0); + assert_eq!(allocator.get_count(0), u64::MAX); + } +} diff --git a/crates/ruvllm/src/moe/router.rs b/crates/ruvllm/src/moe/router.rs new file mode 100644 index 000000000..5ac61d98c --- /dev/null +++ b/crates/ruvllm/src/moe/router.rs @@ -0,0 +1,1341 @@ +//! Memory-Aware MoE Router (ADR-092) +//! +//! Expert selection with cache residency bonus for >=70% cache hit rate. +//! Implements INV-6: Router Determinism - same input + cache state = same result. +//! +//! ## Algorithm +//! +//! 1. Compute base scores from gate network logits +//! 2. Add cache residency bonus to resident experts +//! 3. Select top-K experts +//! 4. Update affinity tracking +//! 5. Generate paging requests for non-resident experts +//! +//! ## Configuration +//! +//! The `cache_bonus` parameter (0.0-1.0) controls how much to favor resident experts: +//! - 0.0: Pure accuracy (ignore cache state, baseline 34% hit rate) +//! 
- 0.15: Recommended balance (>=70% hit rate with <1% accuracy loss) +//! - 0.3+: Aggressive caching (may degrade accuracy) + +use super::{ExpertAffinity, ExpertId, MoeMetrics}; +use std::time::Instant; + +// ============================================================================ +// CacheMask: Bitmask-based cache residency tracking (P1 optimization) +// ============================================================================ + +/// Bitmask-based cache residency tracking for efficient memory access patterns. +/// +/// Uses a u64 for up to 64 experts (most common case: 8, 16, 32, 64 experts). +/// Falls back to Vec for larger models. +#[derive(Debug, Clone)] +struct CacheMask { + /// Bitmask for small models (up to 64 experts) + small: u64, + /// Extended bitmask for larger models (>64 experts) + extended: Option>, + /// Number of experts tracked + num_experts: usize, +} + +impl CacheMask { + /// Create a new cache mask for the given number of experts + fn new(num_experts: usize) -> Self { + if num_experts <= 64 { + Self { + small: 0, + extended: None, + num_experts, + } + } else { + let num_words = (num_experts + 63) / 64; + Self { + small: 0, + extended: Some(vec![0u64; num_words]), + num_experts, + } + } + } + + /// Check if an expert is resident + #[inline] + fn is_set(&self, id: ExpertId) -> bool { + if id >= self.num_experts { + return false; + } + if self.num_experts <= 64 { + (self.small & (1u64 << id)) != 0 + } else { + let word = id / 64; + let bit = id % 64; + self.extended + .as_ref() + .map(|v| (v[word] & (1u64 << bit)) != 0) + .unwrap_or(false) + } + } + + /// Set an expert as resident or non-resident + #[inline] + fn set(&mut self, id: ExpertId, resident: bool) { + if id >= self.num_experts { + return; + } + if self.num_experts <= 64 { + if resident { + self.small |= 1u64 << id; + } else { + self.small &= !(1u64 << id); + } + } else if let Some(ref mut v) = self.extended { + let word = id / 64; + let bit = id % 64; + if resident { + v[word] |= 
1u64 << bit; + } else { + v[word] &= !(1u64 << bit); + } + } + } + + /// Clear all bits (no experts resident) + #[inline] + fn clear(&mut self) { + self.small = 0; + if let Some(ref mut v) = self.extended { + v.fill(0); + } + } + + /// Get list of resident expert IDs + fn resident_list(&self) -> Vec { + let mut result = Vec::new(); + if self.num_experts <= 64 { + let mut bits = self.small; + while bits != 0 { + let trailing = bits.trailing_zeros() as usize; + result.push(trailing); + bits &= bits - 1; // Clear lowest set bit + } + } else if let Some(ref v) = self.extended { + for (word_idx, &word) in v.iter().enumerate() { + let mut bits = word; + while bits != 0 { + let trailing = bits.trailing_zeros() as usize; + let id = word_idx * 64 + trailing; + if id < self.num_experts { + result.push(id); + } + bits &= bits - 1; + } + } + } + result + } + + /// Count number of resident experts (popcount) + #[inline] + fn count(&self) -> usize { + if self.num_experts <= 64 { + self.small.count_ones() as usize + } else { + self.extended + .as_ref() + .map(|v| v.iter().map(|w| w.count_ones() as usize).sum()) + .unwrap_or(0) + } + } +} + +/// Paging direction for expert load/evict operations +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PagingDirection { + /// Load expert into cache + In, + /// Evict expert from cache + Out, +} + +/// Priority level for paging operations +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum PagingPriority { + /// Normal priority (can be delayed) + Normal, + /// Urgent (needed for current inference) + Urgent, + /// Prefetch (speculative, can be cancelled) + Prefetch, +} + +/// Request to page an expert in or out of cache +#[derive(Debug, Clone)] +pub struct PagingRequest { + /// Expert ID to page + pub expert_id: ExpertId, + /// Direction (In = load, Out = evict) + pub direction: PagingDirection, + /// Priority level + pub priority: PagingPriority, +} + +impl PagingRequest { + /// Create a new paging request + pub fn 
new(expert_id: ExpertId, direction: PagingDirection, priority: PagingPriority) -> Self { + Self { + expert_id, + direction, + priority, + } + } + + /// Create an urgent page-in request + pub fn page_in_urgent(expert_id: ExpertId) -> Self { + Self::new(expert_id, PagingDirection::In, PagingPriority::Urgent) + } + + /// Create a prefetch request + pub fn prefetch(expert_id: ExpertId) -> Self { + Self::new(expert_id, PagingDirection::In, PagingPriority::Prefetch) + } + + /// Create a page-out request + pub fn page_out(expert_id: ExpertId) -> Self { + Self::new(expert_id, PagingDirection::Out, PagingPriority::Normal) + } +} + +/// Configuration for the memory-aware router +#[derive(Debug, Clone)] +pub struct RouterConfig { + /// Cache residency bonus weight (0.0-1.0) + /// + /// Added to gate scores for experts currently in cache. + /// Default: 0.15 (achieves >=70% hit rate with <1% accuracy loss) + pub cache_bonus: f32, + + /// Top-K experts to select per token + /// + /// Typical values: 1 (Switch), 2 (Mixtral), 4 (GShard) + pub top_k: usize, + + /// Number of total experts in the model + pub num_experts: usize, + + /// Enable memory-aware routing (feature flag) + /// + /// When false, the router ignores cache state and uses pure accuracy mode. + pub memory_aware: bool, + + /// Prefetch threshold (router weight to trigger speculative prefetch) + /// + /// Experts with weight >= this but not selected may be prefetched. 
+ /// Default: 0.1 (10%) + pub prefetch_threshold: f32, +} + +impl Default for RouterConfig { + fn default() -> Self { + Self { + cache_bonus: 0.15, + top_k: 2, + num_experts: 8, + memory_aware: true, + prefetch_threshold: 0.1, + } + } +} + +impl RouterConfig { + /// Create config with specified parameters + pub fn new(num_experts: usize, top_k: usize) -> Self { + Self { + num_experts, + top_k, + ..Default::default() + } + } + + /// Set cache bonus weight + pub fn with_cache_bonus(mut self, bonus: f32) -> Self { + self.cache_bonus = bonus.clamp(0.0, 1.0); + self + } + + /// Set memory-aware mode + pub fn with_memory_aware(mut self, enabled: bool) -> Self { + self.memory_aware = enabled; + self + } + + /// Set prefetch threshold + pub fn with_prefetch_threshold(mut self, threshold: f32) -> Self { + self.prefetch_threshold = threshold.clamp(0.0, 1.0); + self + } + + /// Validate configuration + pub fn validate(&self) -> Result<(), &'static str> { + if self.top_k == 0 { + return Err("top_k must be at least 1"); + } + if self.top_k > self.num_experts { + return Err("top_k cannot exceed num_experts"); + } + if self.num_experts == 0 { + return Err("num_experts must be at least 1"); + } + Ok(()) + } +} + +/// Memory-aware MoE router with cache residency bonus +/// +/// Implements the memory-aware routing algorithm from ADR-092: +/// 1. Add cache residency bonus to gate scores +/// 2. Select top-K experts with adjusted scores +/// 3. Generate paging requests for non-resident selected experts +/// +/// # Invariant INV-6: Router Determinism +/// +/// Given the same input (gate_logits) and same cache state (cache_resident), +/// the router always produces the same output (selected experts, paging requests). 
+/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::moe::{MemoryAwareRouter, RouterConfig, ExpertAffinity, AffinityConfig}; +/// +/// let config = RouterConfig { +/// cache_bonus: 0.15, +/// top_k: 2, +/// num_experts: 8, +/// memory_aware: true, +/// prefetch_threshold: 0.1, +/// }; +/// +/// let affinity = ExpertAffinity::new(AffinityConfig::with_num_experts(8)); +/// let mut router = MemoryAwareRouter::new(config, affinity); +/// +/// // Update which experts are currently cached +/// router.update_cache_state(&[0, 1, 2, 3]); +/// +/// // Route based on gate logits +/// let gate_logits = vec![0.1, 0.3, 0.5, 0.2, 0.4, 0.1, 0.2, 0.15]; +/// let (selected, paging_requests) = router.route(&gate_logits); +/// ``` +pub struct MemoryAwareRouter { + /// Router configuration + config: RouterConfig, + /// Expert affinity tracker + affinity: ExpertAffinity, + /// Bitmask tracking which experts are currently in cache (P1 optimization) + cache_resident: CacheMask, + /// Routing and caching metrics + metrics: MoeMetrics, + /// Reusable score buffer to avoid allocations (P2 optimization) + score_buffer: Vec, + /// Reusable indexed buffer for sorting (P2 optimization) + index_buffer: Vec<(ExpertId, f32)>, +} + +impl MemoryAwareRouter { + /// Create a new memory-aware router + /// + /// # Arguments + /// + /// * `config` - Router configuration + /// * `affinity` - Expert affinity tracker (can be shared) + /// + /// # Returns + /// + /// Returns `Err` if the configuration is invalid. 
+ pub fn new(config: RouterConfig, affinity: ExpertAffinity) -> Result { + config.validate()?; + + let num_experts = config.num_experts; + Ok(Self { + cache_resident: CacheMask::new(num_experts), + // P2: Pre-allocate buffers to avoid allocations in hot path + score_buffer: vec![0.0; num_experts], + index_buffer: Vec::with_capacity(num_experts), + config, + affinity, + metrics: MoeMetrics::new(), + }) + } + + /// Create router with default affinity tracker + /// + /// # Returns + /// + /// Returns `Err` if the configuration is invalid. + pub fn with_default_affinity(config: RouterConfig) -> Result { + let affinity = + ExpertAffinity::new(super::AffinityConfig::with_num_experts(config.num_experts)); + Self::new(config, affinity) + } + + /// Main routing function with cache bonus + /// + /// Returns selected experts and any paging requests needed. + /// + /// # Arguments + /// + /// * `gate_logits` - Raw logits from the gate network (length = num_experts) + /// + /// # Returns + /// + /// Tuple of (selected_expert_ids, paging_requests) + /// + /// # INV-6: Determinism + /// + /// This function is deterministic: same inputs produce same outputs. + /// No random sampling is used. 
+ #[inline] + pub fn route(&mut self, gate_logits: &[f32]) -> (Vec, Vec) { + let start = Instant::now(); + + // Validate input length (P3: early exit for invalid input) + if gate_logits.len() != self.config.num_experts { + let selected: Vec = + (0..self.config.top_k.min(self.config.num_experts)).collect(); + return (selected, Vec::new()); + } + + // P2: Use pre-allocated buffer instead of allocating + let selected = self.route_into_buffer(gate_logits); + + // Step 3: Update affinity for selected experts + self.affinity.update(&selected); + + // Step 4: Generate paging requests for non-resident selected experts + let paging_requests = self.generate_paging_requests(&selected); + + // Step 5: Record metrics (P3: unroll small loops) + let mut hits = 0usize; + for &id in &selected { + if self.cache_resident.is_set(id) { + hits += 1; + } + } + let misses = selected.len() - hits; + self.metrics.record_cache_hits(hits); + self.metrics.record_cache_misses(misses); + self.metrics.record_routing(start.elapsed()); + + (selected, paging_requests) + } + + /// P2 Optimization: Route using pre-allocated buffers + /// + /// Avoids allocation in the hot path by reusing internal buffers. 
+ #[inline] + fn route_into_buffer(&mut self, gate_logits: &[f32]) -> Vec { + let n = gate_logits.len(); + + // Copy scores into buffer and apply cache bonus in-place + self.score_buffer.clear(); + self.score_buffer.extend_from_slice(gate_logits); + + if self.config.memory_aware { + self.apply_cache_bonus_inplace_buffer(); + } + + // Select top-K using index buffer + self.select_top_k_buffered(n) + } + + /// P2: Apply cache bonus using internal buffer + #[inline] + fn apply_cache_bonus_inplace_buffer(&mut self) { + let bonus = self.config.cache_bonus; + for (id, score) in self.score_buffer.iter_mut().enumerate() { + if !score.is_finite() { + *score = 0.0; + continue; + } + if self.cache_resident.is_set(id) { + *score += bonus; + } + } + } + + /// P2: Select top-K using pre-allocated index buffer + #[inline] + fn select_top_k_buffered(&mut self, n: usize) -> Vec { + let k = self.config.top_k.min(n); + if k == 0 || n == 0 { + return Vec::new(); + } + + // Reuse index buffer + self.index_buffer.clear(); + self.index_buffer.extend( + self.score_buffer + .iter() + .enumerate() + .map(|(id, &s)| (id, if s.is_finite() { s } else { f32::NEG_INFINITY })), + ); + + // P4: Unroll for small k (common case: top-2) + if k == 2 && n >= 2 { + return self.select_top_2_unrolled(); + } + + // Use partial sort for larger k + if k < n / 2 { + self.index_buffer.select_nth_unstable_by(k - 1, |a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.0.cmp(&b.0)) + }); + self.index_buffer[..k].sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.0.cmp(&b.0)) + }); + } else { + self.index_buffer.sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.0.cmp(&b.0)) + }); + } + + self.index_buffer + .iter() + .take(k) + .map(|(id, _)| *id) + .collect() + } + + /// P4: Unrolled top-2 selection (most common MoE configuration) + #[inline] + fn select_top_2_unrolled(&self) -> 
Vec { + let mut best = (0, f32::NEG_INFINITY); + let mut second = (0, f32::NEG_INFINITY); + + for &(id, score) in &self.index_buffer { + if score > best.1 || (score == best.1 && id < best.0) { + second = best; + best = (id, score); + } else if score > second.1 || (score == second.1 && id < second.0) { + second = (id, score); + } + } + + vec![best.0, second.0] + } + + /// Batch routing for multiple tokens (P2 optimization) + /// + /// Routes multiple tokens in a single call, reusing buffers across tokens. + /// More efficient than calling `route()` multiple times. + /// + /// # Arguments + /// + /// * `batch_logits` - Slice of gate logits for each token (shape: [batch_size][num_experts]) + /// + /// # Returns + /// + /// Vector of (selected_experts, paging_requests) for each token + pub fn route_batch( + &mut self, + batch_logits: &[&[f32]], + ) -> Vec<(Vec, Vec)> { + let mut results = Vec::with_capacity(batch_logits.len()); + + for logits in batch_logits { + results.push(self.route(logits)); + } + + results + } + + /// Apply cache residency bonus to scores (in-place mutation for P0 optimization) + /// + /// For each expert currently in cache, adds `cache_bonus` to its score. + /// This biases the selection toward cached experts without completely + /// overriding the gate network's decisions. + /// + /// # Arguments + /// + /// * `scores` - Mutable slice of scores to modify in-place + pub fn apply_cache_bonus_inplace(&self, scores: &mut [f32]) { + for (id, score) in scores.iter_mut().enumerate() { + // Validate score is not NaN/Inf before processing + if !score.is_finite() { + *score = 0.0; + continue; + } + if self.cache_resident.is_set(id) { + *score += self.config.cache_bonus; + } + } + } + + /// Apply cache residency bonus to scores (allocating version for API compatibility) + /// + /// For each expert currently in cache, adds `cache_bonus` to its score. 
+ /// This biases the selection toward cached experts without completely + /// overriding the gate network's decisions. + pub fn apply_cache_bonus(&self, scores: &[f32]) -> Vec { + let mut result = scores.to_vec(); + self.apply_cache_bonus_inplace(&mut result); + result + } + + /// Select top-K experts by score + /// + /// Returns expert IDs sorted by descending score. + /// Ties are broken by expert ID (lower ID wins) for determinism. + /// + /// Uses partial sort (P0 optimization) for better performance when + /// top_k << num_experts. + pub fn select_top_k(&self, scores: &[f32]) -> Vec { + let n = scores.len(); + let k = self.config.top_k.min(n); + + if k == 0 || n == 0 { + return Vec::new(); + } + + // Create indexed scores, handling NaN/Inf values + let mut indexed: Vec<(ExpertId, f32)> = scores + .iter() + .enumerate() + .map(|(id, &s)| (id, if s.is_finite() { s } else { f32::NEG_INFINITY })) + .collect(); + + // Use partial sort for better performance when k << n + if k < n / 2 { + // Partition to get top-k elements (unordered) + indexed.select_nth_unstable_by(k - 1, |a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.0.cmp(&b.0)) + }); + // Sort only the top-k portion + indexed[..k].sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.0.cmp(&b.0)) + }); + } else { + // Full sort when k is close to n + indexed.sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| a.0.cmp(&b.0)) + }); + } + + // Take top-K + indexed.into_iter().take(k).map(|(id, _)| id).collect() + } + + /// Update cache residency state + /// + /// Call this when experts are paged in or out. 
+ /// + /// # Arguments + /// + /// * `resident` - List of expert IDs currently in cache + pub fn update_cache_state(&mut self, resident: &[ExpertId]) { + // Clear all + self.cache_resident.clear(); + + // Set resident experts + for &id in resident { + self.cache_resident.set(id, true); + } + } + + /// Mark a single expert as resident or non-resident + pub fn set_resident(&mut self, expert_id: ExpertId, resident: bool) { + self.cache_resident.set(expert_id, resident); + } + + /// Check if an expert is currently resident + pub fn is_resident(&self, expert_id: ExpertId) -> bool { + self.cache_resident.is_set(expert_id) + } + + /// Generate paging requests for selected experts + /// + /// Creates urgent page-in requests for non-resident selected experts. + /// Also generates prefetch requests for high-scoring non-selected experts. + pub fn generate_paging_requests(&self, selected: &[ExpertId]) -> Vec { + let mut requests = Vec::new(); + + // Urgent page-in for non-resident selected experts + for &expert_id in selected { + if !self.is_resident(expert_id) { + requests.push(PagingRequest::page_in_urgent(expert_id)); + } + } + + requests + } + + /// Generate prefetch requests based on affinity + /// + /// Returns prefetch requests for high-affinity non-resident experts. 
+ /// + /// # Arguments + /// + /// * `budget` - Maximum number of prefetch requests to generate + pub fn generate_prefetch_requests(&self, budget: usize) -> Vec { + // Get top experts by affinity that are not currently resident + let candidates = self.affinity.top_k_by_affinity(budget * 2); + + candidates + .into_iter() + .filter(|&id| !self.is_resident(id)) + .take(budget) + .map(PagingRequest::prefetch) + .collect() + } + + /// Get a reference to the current metrics + pub fn metrics(&self) -> &MoeMetrics { + &self.metrics + } + + /// Reset metrics + pub fn reset_metrics(&mut self) { + self.metrics.reset(); + } + + /// Get a reference to the affinity tracker + pub fn affinity(&self) -> &ExpertAffinity { + &self.affinity + } + + /// Get a mutable reference to the affinity tracker + pub fn affinity_mut(&mut self) -> &mut ExpertAffinity { + &mut self.affinity + } + + /// Get a reference to the configuration + pub fn config(&self) -> &RouterConfig { + &self.config + } + + /// Get the current cache hit rate + pub fn hit_rate(&self) -> f32 { + self.metrics.hit_rate() + } + + /// Get list of currently resident experts + pub fn resident_experts(&self) -> Vec { + self.cache_resident.resident_list() + } + + /// Get number of experts + pub fn num_experts(&self) -> usize { + self.config.num_experts + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::moe::AffinityConfig; + + fn make_router(num_experts: usize, top_k: usize, cache_bonus: f32) -> MemoryAwareRouter { + let config = RouterConfig::new(num_experts, top_k).with_cache_bonus(cache_bonus); + MemoryAwareRouter::with_default_affinity(config).expect("test config should be valid") + } + + // --------------------------------------------------------------- + // test_routing_basic + // --------------------------------------------------------------- + + #[test] + fn test_routing_basic() { + let mut router = make_router(8, 2, 0.0); + + // No cache bonus, pure selection + let gate_logits = vec![0.1, 0.3, 0.5, 0.2, 
0.4, 0.1, 0.2, 0.15]; + let (selected, _) = router.route(&gate_logits); + + assert_eq!(selected.len(), 2); + // Experts 2 (0.5) and 4 (0.4) should be selected + assert!(selected.contains(&2)); + assert!(selected.contains(&4)); + } + + // --------------------------------------------------------------- + // test_cache_bonus_increases_resident_score + // --------------------------------------------------------------- + + #[test] + fn test_cache_bonus_increases_resident_score() { + let mut router = make_router(4, 1, 0.3); + + // Experts: 0=0.4, 1=0.3, 2=0.2, 3=0.1 + // Without bonus: expert 0 selected + // With bonus on expert 1: 0.3 + 0.3 = 0.6 > 0.4 + + router.update_cache_state(&[1]); // Expert 1 is resident + + let gate_logits = vec![0.4, 0.3, 0.2, 0.1]; + let (selected, _) = router.route(&gate_logits); + + // Expert 1 should be selected because of cache bonus + assert_eq!(selected, vec![1]); + } + + // --------------------------------------------------------------- + // test_top_k_selection + // --------------------------------------------------------------- + + #[test] + fn test_top_k_selection() { + let mut router = make_router(8, 3, 0.0); + + let gate_logits = vec![0.8, 0.1, 0.2, 0.7, 0.3, 0.6, 0.4, 0.5]; + let (selected, _) = router.route(&gate_logits); + + assert_eq!(selected.len(), 3); + // Top 3: expert 0 (0.8), expert 3 (0.7), expert 5 (0.6) + assert_eq!(selected[0], 0); + assert_eq!(selected[1], 3); + assert_eq!(selected[2], 5); + } + + // --------------------------------------------------------------- + // test_paging_requests_for_non_resident + // --------------------------------------------------------------- + + #[test] + fn test_paging_requests_for_non_resident() { + let mut router = make_router(4, 2, 0.0); + + // Only expert 0 is resident + router.update_cache_state(&[0]); + + let gate_logits = vec![0.5, 0.6, 0.4, 0.3]; + let (selected, paging) = router.route(&gate_logits); + + // Selected: experts 1 (0.6) and 0 (0.5) + 
assert!(selected.contains(&0)); + assert!(selected.contains(&1)); + + // Expert 1 is not resident, should have paging request + assert_eq!(paging.len(), 1); + assert_eq!(paging[0].expert_id, 1); + assert_eq!(paging[0].direction, PagingDirection::In); + assert_eq!(paging[0].priority, PagingPriority::Urgent); + } + + // --------------------------------------------------------------- + // test_router_determinism (INV-6) + // --------------------------------------------------------------- + + #[test] + fn test_router_determinism() { + // INV-6: Same input + cache state = same result + + let mut router1 = make_router(8, 2, 0.15); + let mut router2 = make_router(8, 2, 0.15); + + // Same cache state + router1.update_cache_state(&[0, 3, 5]); + router2.update_cache_state(&[0, 3, 5]); + + let gate_logits = vec![0.1, 0.3, 0.5, 0.2, 0.4, 0.1, 0.2, 0.15]; + + let (selected1, paging1) = router1.route(&gate_logits); + let (selected2, paging2) = router2.route(&gate_logits); + + // Results must be identical + assert_eq!( + selected1, selected2, + "INV-6 violation: different expert selection" + ); + assert_eq!( + paging1.len(), + paging2.len(), + "INV-6 violation: different paging count" + ); + + // Run multiple times on same router + router1.reset_metrics(); + let (selected3, _) = router1.route(&gate_logits); + assert_eq!( + selected1, selected3, + "INV-6 violation: non-deterministic routing" + ); + } + + // --------------------------------------------------------------- + // test_affinity_updates + // --------------------------------------------------------------- + + #[test] + fn test_affinity_updates() { + let mut router = make_router(4, 2, 0.0); + + // Route multiple times to build affinity + let gate_logits = vec![0.4, 0.3, 0.5, 0.1]; + + for _ in 0..5 { + router.route(&gate_logits); + } + + // Experts 2 and 0 should have highest affinity (selected 5 times) + let top = router.affinity().top_k_by_affinity(2); + assert!(top.contains(&2), "Expert 2 should have high affinity"); + 
assert!(top.contains(&0), "Expert 0 should have high affinity"); + } + + // --------------------------------------------------------------- + // test_zero_cache_bonus_fallback + // --------------------------------------------------------------- + + #[test] + fn test_zero_cache_bonus_fallback() { + let mut router = make_router(4, 2, 0.0); + + // All experts resident + router.update_cache_state(&[0, 1, 2, 3]); + + let gate_logits = vec![0.1, 0.4, 0.3, 0.2]; + let (selected, _) = router.route(&gate_logits); + + // Should select purely by score: experts 1 (0.4) and 2 (0.3) + assert_eq!(selected[0], 1); + assert_eq!(selected[1], 2); + } + + // --------------------------------------------------------------- + // test_all_experts_resident + // --------------------------------------------------------------- + + #[test] + fn test_all_experts_resident() { + let mut router = make_router(4, 2, 0.15); + + // All experts resident + router.update_cache_state(&[0, 1, 2, 3]); + + let gate_logits = vec![0.1, 0.4, 0.3, 0.2]; + let (selected, paging) = router.route(&gate_logits); + + assert_eq!(selected.len(), 2); + // No paging needed + assert!( + paging.is_empty(), + "No paging should be needed when all selected are resident" + ); + + // All should be cache hits + assert_eq!(router.metrics().cache_hits, 2); + assert_eq!(router.metrics().cache_misses, 0); + } + + // --------------------------------------------------------------- + // test_no_experts_resident + // --------------------------------------------------------------- + + #[test] + fn test_no_experts_resident() { + let mut router = make_router(4, 2, 0.15); + + // No experts resident (cold start) + router.update_cache_state(&[]); + + let gate_logits = vec![0.1, 0.4, 0.3, 0.2]; + let (selected, paging) = router.route(&gate_logits); + + assert_eq!(selected.len(), 2); + // Should need paging for all selected + assert_eq!( + paging.len(), + 2, + "Should need to page in all selected experts" + ); + + // All should be cache misses + 
assert_eq!(router.metrics().cache_misses, 2); + assert_eq!(router.metrics().cache_hits, 0); + } + + // --------------------------------------------------------------- + // test_config_validation + // --------------------------------------------------------------- + + #[test] + fn test_config_validation() { + // Valid config + let valid = RouterConfig::new(8, 2); + assert!(valid.validate().is_ok()); + + // Invalid: top_k = 0 + let invalid1 = RouterConfig { + top_k: 0, + ..RouterConfig::default() + }; + assert!(invalid1.validate().is_err()); + + // Invalid: top_k > num_experts + let invalid2 = RouterConfig { + top_k: 10, + num_experts: 8, + ..RouterConfig::default() + }; + assert!(invalid2.validate().is_err()); + + // Invalid: num_experts = 0 + let invalid3 = RouterConfig { + num_experts: 0, + ..RouterConfig::default() + }; + assert!(invalid3.validate().is_err()); + } + + // --------------------------------------------------------------- + // test_memory_aware_disabled + // --------------------------------------------------------------- + + #[test] + fn test_memory_aware_disabled() { + let config = RouterConfig::new(4, 2) + .with_memory_aware(false) + .with_cache_bonus(0.5); + let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap(); + + // Even with high cache bonus, should not apply it when disabled + router.update_cache_state(&[3]); // Expert 3 resident + + let gate_logits = vec![0.4, 0.3, 0.5, 0.2]; + let (selected, _) = router.route(&gate_logits); + + // Should select by pure score: experts 2 (0.5) and 0 (0.4) + assert_eq!(selected[0], 2); + assert_eq!(selected[1], 0); + } + + // --------------------------------------------------------------- + // test_hit_rate_tracking + // --------------------------------------------------------------- + + #[test] + fn test_hit_rate_tracking() { + let mut router = make_router(4, 2, 0.0); + + // 50% resident + router.update_cache_state(&[0, 2]); + + let gate_logits = vec![0.4, 0.3, 0.5, 0.2]; + // Will select 
experts 2 (resident) and 0 (resident) + router.route(&gate_logits); + + assert_eq!(router.hit_rate(), 1.0); // Both selected are resident + + router.reset_metrics(); + router.update_cache_state(&[1, 3]); + router.route(&gate_logits); + + assert_eq!(router.hit_rate(), 0.0); // Neither selected is resident + } + + // --------------------------------------------------------------- + // test_prefetch_requests + // --------------------------------------------------------------- + + #[test] + fn test_prefetch_requests() { + let config = RouterConfig::new(4, 2).with_cache_bonus(0.0); + let affinity_config = AffinityConfig::with_num_experts(4).with_decay(1.0); + let affinity = ExpertAffinity::new(affinity_config); + let mut router = MemoryAwareRouter::new(config, affinity).unwrap(); + + // Build affinity + let gate_logits = vec![0.4, 0.3, 0.5, 0.2]; + for _ in 0..10 { + router.route(&gate_logits); + } + + // Only expert 1 is resident + router.update_cache_state(&[1]); + + // Should suggest prefetching high-affinity non-resident experts + let prefetch = router.generate_prefetch_requests(2); + + // Should not include expert 1 (already resident) + for req in &prefetch { + assert_ne!(req.expert_id, 1); + assert_eq!(req.priority, PagingPriority::Prefetch); + } + } + + // --------------------------------------------------------------- + // test_resident_experts_list + // --------------------------------------------------------------- + + #[test] + fn test_resident_experts_list() { + let mut router = make_router(8, 2, 0.15); + + router.update_cache_state(&[1, 3, 5, 7]); + + let resident = router.resident_experts(); + assert_eq!(resident.len(), 4); + assert!(resident.contains(&1)); + assert!(resident.contains(&3)); + assert!(resident.contains(&5)); + assert!(resident.contains(&7)); + assert!(!resident.contains(&0)); + } + + // --------------------------------------------------------------- + // test_set_resident + // --------------------------------------------------------------- 
+ + #[test] + fn test_set_resident() { + let mut router = make_router(4, 2, 0.15); + + assert!(!router.is_resident(0)); + + router.set_resident(0, true); + assert!(router.is_resident(0)); + + router.set_resident(0, false); + assert!(!router.is_resident(0)); + } + + // --------------------------------------------------------------- + // test_tie_breaking_determinism + // --------------------------------------------------------------- + + #[test] + fn test_tie_breaking_determinism() { + let mut router = make_router(4, 2, 0.0); + + // All experts have same score + let gate_logits = vec![0.5, 0.5, 0.5, 0.5]; + let (selected1, _) = router.route(&gate_logits); + let (selected2, _) = router.route(&gate_logits); + + // Should consistently select lowest IDs on ties + assert_eq!(selected1, selected2); + assert_eq!(selected1, vec![0, 1]); // Lowest IDs win ties + } + + // --------------------------------------------------------------- + // test_invalid_gate_logits_length + // --------------------------------------------------------------- + + #[test] + fn test_invalid_gate_logits_length() { + let mut router = make_router(4, 2, 0.15); + + // Wrong length input + let gate_logits = vec![0.5, 0.3]; // Only 2 instead of 4 + let (selected, paging) = router.route(&gate_logits); + + // Should fallback gracefully + assert_eq!(selected.len(), 2); + assert!(paging.is_empty() || paging.len() <= 2); + } + + // --------------------------------------------------------------- + // test_apply_cache_bonus + // --------------------------------------------------------------- + + #[test] + fn test_apply_cache_bonus() { + let mut router = make_router(4, 2, 0.2); + router.update_cache_state(&[1, 2]); + + let scores = vec![0.1, 0.3, 0.4, 0.5]; + let adjusted = router.apply_cache_bonus(&scores); + + // Expert 0: 0.1 + 0 = 0.1 + // Expert 1: 0.3 + 0.2 = 0.5 (resident) + // Expert 2: 0.4 + 0.2 = 0.6 (resident) + // Expert 3: 0.5 + 0 = 0.5 + assert!((adjusted[0] - 0.1).abs() < 1e-6); + 
assert!((adjusted[1] - 0.5).abs() < 1e-6); + assert!((adjusted[2] - 0.6).abs() < 1e-6); + assert!((adjusted[3] - 0.5).abs() < 1e-6); + } + + // --------------------------------------------------------------- + // test_paging_request_constructors + // --------------------------------------------------------------- + + #[test] + fn test_paging_request_constructors() { + let req1 = PagingRequest::page_in_urgent(5); + assert_eq!(req1.expert_id, 5); + assert_eq!(req1.direction, PagingDirection::In); + assert_eq!(req1.priority, PagingPriority::Urgent); + + let req2 = PagingRequest::prefetch(3); + assert_eq!(req2.expert_id, 3); + assert_eq!(req2.direction, PagingDirection::In); + assert_eq!(req2.priority, PagingPriority::Prefetch); + + let req3 = PagingRequest::page_out(7); + assert_eq!(req3.expert_id, 7); + assert_eq!(req3.direction, PagingDirection::Out); + assert_eq!(req3.priority, PagingPriority::Normal); + } + + // --------------------------------------------------------------- + // test_config_builder + // --------------------------------------------------------------- + + #[test] + fn test_config_builder() { + let config = RouterConfig::new(16, 4) + .with_cache_bonus(0.25) + .with_memory_aware(true) + .with_prefetch_threshold(0.15); + + assert_eq!(config.num_experts, 16); + assert_eq!(config.top_k, 4); + assert!((config.cache_bonus - 0.25).abs() < 1e-6); + assert!(config.memory_aware); + assert!((config.prefetch_threshold - 0.15).abs() < 1e-6); + } + + // --------------------------------------------------------------- + // test_cache_bonus_clamping + // --------------------------------------------------------------- + + #[test] + fn test_cache_bonus_clamping() { + let config = RouterConfig::new(8, 2).with_cache_bonus(1.5); + assert!( + (config.cache_bonus - 1.0).abs() < 1e-6, + "cache_bonus should be clamped to 1.0" + ); + + let config2 = RouterConfig::new(8, 2).with_cache_bonus(-0.5); + assert!( + (config2.cache_bonus - 0.0).abs() < 1e-6, + "cache_bonus should be 
clamped to 0.0" + ); + } + + // --------------------------------------------------------------- + // P1 Optimization Tests: CacheMask bitmask + // --------------------------------------------------------------- + + #[test] + fn test_cache_mask_small() { + let mut mask = CacheMask::new(64); + + // Initially all clear + for i in 0..64 { + assert!(!mask.is_set(i), "Bit {} should be clear initially", i); + } + + // Set some bits + mask.set(0, true); + mask.set(31, true); + mask.set(63, true); + + assert!(mask.is_set(0)); + assert!(mask.is_set(31)); + assert!(mask.is_set(63)); + assert!(!mask.is_set(1)); + assert!(!mask.is_set(32)); + + // Count should be 3 + assert_eq!(mask.count(), 3); + + // Resident list + let list = mask.resident_list(); + assert_eq!(list.len(), 3); + assert!(list.contains(&0)); + assert!(list.contains(&31)); + assert!(list.contains(&63)); + + // Clear and verify + mask.clear(); + assert_eq!(mask.count(), 0); + assert!(!mask.is_set(0)); + } + + #[test] + fn test_cache_mask_large() { + // Test with >64 experts (uses extended Vec) + let mut mask = CacheMask::new(128); + + // Set bits across word boundaries + mask.set(0, true); + mask.set(63, true); + mask.set(64, true); // First bit of second word + mask.set(127, true); + + assert!(mask.is_set(0)); + assert!(mask.is_set(63)); + assert!(mask.is_set(64)); + assert!(mask.is_set(127)); + assert!(!mask.is_set(65)); + + assert_eq!(mask.count(), 4); + + let list = mask.resident_list(); + assert_eq!(list.len(), 4); + + // Clear + mask.clear(); + assert_eq!(mask.count(), 0); + } + + #[test] + fn test_cache_mask_out_of_bounds() { + let mut mask = CacheMask::new(8); + + // Out of bounds should be no-op and return false + mask.set(100, true); + assert!(!mask.is_set(100)); + assert_eq!(mask.count(), 0); + } + + #[test] + fn test_router_with_many_experts() { + // Test router with >64 experts to exercise extended bitmask + let config = RouterConfig::new(128, 4); + let mut router = 
MemoryAwareRouter::with_default_affinity(config).unwrap(); + + // Set some residents across the full range + router.update_cache_state(&[0, 32, 64, 96, 127]); + + assert!(router.is_resident(0)); + assert!(router.is_resident(64)); + assert!(router.is_resident(127)); + assert!(!router.is_resident(1)); + + let resident = router.resident_experts(); + assert_eq!(resident.len(), 5); + } + + #[test] + fn test_empty_cache_state() { + let mut router = make_router(8, 2, 0.15); + + // Empty update + router.update_cache_state(&[]); + + // No experts should be resident + for i in 0..8 { + assert!( + !router.is_resident(i), + "Expert {} should not be resident", + i + ); + } + + assert!(router.resident_experts().is_empty()); + } +} diff --git a/crates/ruvllm/src/moe/sram_mapper.rs b/crates/ruvllm/src/moe/sram_mapper.rs new file mode 100644 index 000000000..3e4507841 --- /dev/null +++ b/crates/ruvllm/src/moe/sram_mapper.rs @@ -0,0 +1,1150 @@ +//! SRAM Mapper for Hardware Memory Hierarchy Configuration (ADR-092) +//! +//! This module provides platform-specific memory hierarchy configuration for +//! MoE (Mixture of Experts) expert placement across different memory tiers. +//! +//! ## Memory Hierarchy +//! +//! Modern systems have a three-tier memory hierarchy with vastly different +//! latencies and capacities: +//! +//! | Tier | Type | Latency | Capacity | Use Case | +//! |------|------|---------|----------|----------| +//! | SRAM | L2/L3 Cache | ~10-40ns | 4-64MB | Hot experts | +//! | DRAM | Main Memory | ~50-100ns | 2-64GB | Warm experts | +//! | Storage | Flash/NVMe | ~50-200us | Unlimited | Cold experts | +//! +//! ## Expert Placement Strategy +//! +//! For optimal MoE inference performance: +//! 1. **SRAM**: Keep top-K active experts and likely-next-picks in cache +//! 2. **DRAM**: Cache frequently-accessed experts not in SRAM +//! 3. **Storage**: Page in rarely-used experts on demand +//! +//! ## Platform Considerations +//! +//! 
- **Raspberry Pi 5**: 8GB RAM, small L2 cache, optimize for DRAM +//! - **Mobile**: 2-4GB available, aggressive SRAM management +//! - **Desktop**: 16GB+ RAM, larger caches, more flexibility +//! - **WASM/Browser**: Configurable heap, plan for limited memory +//! +//! ## Example +//! +//! ```rust,ignore +//! use ruvllm::moe::{SramMapper, HardwarePreset, MemoryTier}; +//! +//! // Create mapper for Raspberry Pi 5 with 8 experts +//! let mut mapper = SramMapper::from_preset(HardwarePreset::RaspberryPi5, 8, 34_000_000); +//! +//! // Assign hot experts to SRAM +//! mapper.assign_tier(0, MemoryTier::Sram); +//! mapper.assign_tier(1, MemoryTier::Sram); +//! +//! // Check paging latency estimate +//! let latency = mapper.estimate_paging_latency(5); // Returns microseconds +//! ``` + +use std::collections::HashMap; + +// Use ExpertId from parent module +use super::ExpertId; + +// ============================================================================ +// Types +// ============================================================================ + +/// Memory tier classification for expert placement. +/// +/// Each tier has different characteristics in terms of latency, bandwidth, +/// and capacity. The SramMapper assigns experts to tiers based on access +/// patterns and hardware constraints. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum MemoryTier { + /// L2/L3 cache tier (fastest, smallest capacity). + /// + /// Experts in SRAM tier have sub-microsecond access latency but limited + /// slots. Reserved for the most frequently accessed experts. + Sram, + + /// Main memory tier (DRAM). + /// + /// Moderate latency (~100ns) but significantly larger capacity. Good for + /// experts that are accessed regularly but not in the hot path. + Dram, + + /// Storage tier (Flash/NVMe, slowest, largest capacity). + /// + /// High latency (50-200+ microseconds) but virtually unlimited capacity. + /// Used for cold experts that are rarely accessed. 
+ Storage, +} + +impl MemoryTier { + /// Return a human-readable name for the tier. + pub fn name(&self) -> &'static str { + match self { + MemoryTier::Sram => "SRAM (L2/L3 Cache)", + MemoryTier::Dram => "DRAM (Main Memory)", + MemoryTier::Storage => "Storage (Flash/NVMe)", + } + } + + /// Return the tier as an index for array lookups. + pub fn index(&self) -> usize { + match self { + MemoryTier::Sram => 0, + MemoryTier::Dram => 1, + MemoryTier::Storage => 2, + } + } +} + +/// Expert affinity information for eviction decisions. +/// +/// Tracks access patterns and preferences to help the SRAM mapper make +/// intelligent tier assignment and eviction decisions. +#[derive(Debug, Clone)] +pub struct SramExpertAffinity { + /// Expert identifier. + pub expert_id: ExpertId, + + /// Total access count (frequency). + pub access_count: usize, + + /// Last access timestamp (monotonic counter). + pub last_access: u64, + + /// Average router weight when selected (0.0 - 1.0). + pub avg_router_weight: f32, + + /// Number of tokens that selected this expert recently. + pub recent_selections: usize, + + /// Whether this expert is currently "pinned" to its tier. + pub pinned: bool, +} + +impl Default for SramExpertAffinity { + fn default() -> Self { + Self { + expert_id: 0, + access_count: 0, + last_access: 0, + avg_router_weight: 0.0, + recent_selections: 0, + pinned: false, + } + } +} + +impl SramExpertAffinity { + /// Create new affinity tracking for an expert. + pub fn new(expert_id: ExpertId) -> Self { + Self { + expert_id, + ..Default::default() + } + } + + /// Compute a priority score for eviction decisions (higher = less likely to evict). + /// + /// The score combines frequency, recency, and router weight into a single + /// metric used for tier assignment decisions. 
+ pub fn priority_score(&self) -> f32 { + // Weight the factors: + // - Frequency has diminishing returns (log scale) + // - Recency is important for temporal locality + // - Router weight indicates model preference + let freq_factor = (self.access_count as f32 + 1.0).ln(); + + // Guard against division by zero when last_access is 0 + let recency_factor = if self.last_access == 0 { + 0.0 + } else { + 1.0 / (1.0 + 0.001 / self.last_access as f32) + }; + + let weight_factor = self.avg_router_weight * 2.0; + + freq_factor + recency_factor + weight_factor + } +} + +// ============================================================================ +// Hardware Configuration +// ============================================================================ + +/// Hardware configuration for a specific platform. +/// +/// Describes the memory hierarchy constraints that guide expert placement +/// decisions. Can be created from presets or with custom values. +#[derive(Debug, Clone)] +pub struct HardwareConfig { + /// L2+L3 cache size in bytes. + /// + /// This is the effective SRAM available for expert caching. On most + /// systems this is 4-64MB shared across all cores. + pub sram_bytes: usize, + + /// Available DRAM budget for expert caching in bytes. + /// + /// This is the portion of main memory allocated for keeping experts + /// resident. Should be less than total system RAM to leave room for + /// other allocations. + pub dram_budget_bytes: usize, + + /// Number of expert slots that fit in SRAM. + /// + /// Computed as `sram_bytes / expert_size_bytes`, possibly with some + /// slack for cache line alignment and other overhead. + pub sram_expert_slots: usize, + + /// Number of expert slots that fit in DRAM budget. + /// + /// Computed as `dram_budget_bytes / expert_size_bytes`. + pub dram_expert_slots: usize, + + /// Expert size in bytes (packed weights for one expert). 
+ /// + /// Includes all three projections (gate_proj, up_proj, down_proj) with + /// packed ternary weights and scale factors. + pub expert_size_bytes: usize, +} + +impl Default for HardwareConfig { + fn default() -> Self { + Self { + sram_bytes: 8 * 1024 * 1024, // 8 MB typical L3 + dram_budget_bytes: 4 * 1024 * 1024 * 1024, // 4 GB DRAM budget + sram_expert_slots: 2, + dram_expert_slots: 8, + expert_size_bytes: 34_000_000, // ~34 MB per expert + } + } +} + +impl HardwareConfig { + /// Create a new hardware configuration. + /// + /// Automatically computes slot counts from byte budgets and expert size. + pub fn new(sram_bytes: usize, dram_budget_bytes: usize, expert_size_bytes: usize) -> Self { + let sram_expert_slots = sram_bytes / expert_size_bytes.max(1); + let dram_expert_slots = dram_budget_bytes / expert_size_bytes.max(1); + + Self { + sram_bytes, + dram_budget_bytes, + sram_expert_slots, + dram_expert_slots, + expert_size_bytes, + } + } + + /// Total memory budget across SRAM and DRAM tiers. + pub fn total_budget(&self) -> usize { + self.sram_bytes + self.dram_budget_bytes + } + + /// Total expert slots available (SRAM + DRAM). + pub fn total_slots(&self) -> usize { + self.sram_expert_slots + self.dram_expert_slots + } +} + +// ============================================================================ +// Hardware Presets +// ============================================================================ + +/// Known hardware presets for common deployment targets. +/// +/// Each preset provides sensible defaults for the memory hierarchy based on +/// typical hardware configurations of that platform class. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HardwarePreset { + /// Raspberry Pi 5 (8GB RAM, ARM Cortex-A76). + /// + /// - L2: 512KB per core (4 cores) + /// - L3: None + /// - RAM: 8GB LPDDR4X + /// - Typical: 4-6 experts in DRAM, 1-2 in cache + RaspberryPi5, + + /// Mobile device (2-4GB available memory). 
+ /// + /// Aggressive memory management due to system constraints. + /// - L2/L3: ~2-8MB shared + /// - RAM: 2-3GB available after OS + /// - Typical: 2-4 experts in DRAM, 1 in cache + Mobile, + + /// Desktop workstation (16GB+ RAM, modern x86_64). + /// + /// - L2: 1-2MB per core + /// - L3: 16-64MB shared + /// - RAM: 16GB+ available + /// - Typical: 8+ experts in DRAM, 2-4 in cache + Desktop, + + /// WebAssembly browser environment. + /// + /// Configurable heap size, typically limited to 1-4GB depending on + /// browser and device. Conservative defaults. + /// - "Cache": WASM linear memory hot region + /// - "DRAM": Rest of WASM heap + /// - Typical: 1-2 experts warm + WasmBrowser, + + /// Custom configuration (use HardwareConfig directly). + Custom, +} + +impl HardwarePreset { + /// Get the default HardwareConfig for this preset. + /// + /// Note: `expert_size_bytes` must be provided separately as it depends + /// on the specific model architecture. + pub fn default_config(&self, expert_size_bytes: usize) -> HardwareConfig { + match self { + HardwarePreset::RaspberryPi5 => { + // RPi5: 512KB L2 per core (effectively ~1-2MB usable), 8GB RAM + let sram_bytes = 2 * 1024 * 1024; // ~2MB effective cache + let dram_budget = 6 * 1024 * 1024 * 1024; // 6GB budget + HardwareConfig::new(sram_bytes, dram_budget, expert_size_bytes) + } + HardwarePreset::Mobile => { + // Mobile: 4MB L3, 2-3GB available + let sram_bytes = 4 * 1024 * 1024; + let dram_budget = 2 * 1024 * 1024 * 1024; + HardwareConfig::new(sram_bytes, dram_budget, expert_size_bytes) + } + HardwarePreset::Desktop => { + // Desktop: 32MB L3, 16GB+ available + let sram_bytes = 32 * 1024 * 1024; + let dram_budget = 12 * 1024 * 1024 * 1024; + HardwareConfig::new(sram_bytes, dram_budget, expert_size_bytes) + } + HardwarePreset::WasmBrowser => { + // WASM: ~2MB hot region, 1GB heap budget + let sram_bytes = 2 * 1024 * 1024; + let dram_budget = 1024 * 1024 * 1024; + HardwareConfig::new(sram_bytes, dram_budget, 
expert_size_bytes) + } + HardwarePreset::Custom => HardwareConfig::default(), + } + } + + /// Get a human-readable name for the preset. + pub fn name(&self) -> &'static str { + match self { + HardwarePreset::RaspberryPi5 => "Raspberry Pi 5", + HardwarePreset::Mobile => "Mobile Device", + HardwarePreset::Desktop => "Desktop Workstation", + HardwarePreset::WasmBrowser => "WASM Browser", + HardwarePreset::Custom => "Custom", + } + } +} + +// ============================================================================ +// SRAM Mapper +// ============================================================================ + +/// SRAM Mapper for hardware memory hierarchy configuration. +/// +/// Manages expert placement across memory tiers (SRAM/Cache, DRAM, Storage) +/// based on access patterns and hardware constraints. Provides latency +/// estimates and eviction suggestions for optimal MoE inference performance. +/// +/// # Usage +/// +/// ```rust,ignore +/// use ruvllm::moe::{SramMapper, HardwarePreset, MemoryTier}; +/// +/// // Create from preset +/// let mut mapper = SramMapper::from_preset(HardwarePreset::RaspberryPi5, 8, 34_000_000); +/// +/// // Assign experts to tiers +/// mapper.assign_tier(0, MemoryTier::Sram); +/// mapper.assign_tier(1, MemoryTier::Sram); +/// mapper.assign_tier(2, MemoryTier::Dram); +/// +/// // Query tier assignments +/// assert_eq!(mapper.get_tier(0), MemoryTier::Sram); +/// +/// // Get latency estimate (microseconds) +/// let latency = mapper.estimate_paging_latency(5); +/// ``` +pub struct SramMapper { + /// Hardware configuration. + config: HardwareConfig, + + /// Total number of experts in the model. + num_experts: usize, + + /// Current tier assignment for each expert (indexed by ExpertId). + tier_map: Vec, + + /// Expert affinity tracking for eviction decisions. + affinity: Vec, + + /// Estimated paging latency per tier in microseconds. + /// + /// Index 0 = SRAM, 1 = DRAM, 2 = Storage. 
+ tier_latency: [u64; 3], + + /// Current SRAM slot usage. + sram_used: usize, + + /// Current DRAM slot usage. + dram_used: usize, + + /// Monotonic counter for LRU tracking. + access_counter: u64, +} + +impl SramMapper { + /// Create a new SRAM mapper from a hardware preset. + /// + /// # Arguments + /// + /// * `preset` - Hardware preset to use for configuration + /// * `num_experts` - Total number of experts in the model + /// * `expert_size_bytes` - Size of each expert in bytes + /// + /// # Example + /// + /// ```rust,ignore + /// let mapper = SramMapper::from_preset(HardwarePreset::RaspberryPi5, 8, 34_000_000); + /// ``` + pub fn from_preset( + preset: HardwarePreset, + num_experts: usize, + expert_size_bytes: usize, + ) -> Self { + let config = preset.default_config(expert_size_bytes); + Self::from_config(config, num_experts) + } + + /// Create a new SRAM mapper from a custom hardware configuration. + /// + /// # Arguments + /// + /// * `config` - Custom hardware configuration + /// * `num_experts` - Total number of experts in the model + /// + /// # Example + /// + /// ```rust,ignore + /// let config = HardwareConfig::new(16 * 1024 * 1024, 8 * 1024 * 1024 * 1024, 34_000_000); + /// let mapper = SramMapper::from_config(config, 8); + /// ``` + pub fn from_config(config: HardwareConfig, num_experts: usize) -> Self { + // Initialize all experts to Storage tier (cold start) + let tier_map = vec![MemoryTier::Storage; num_experts]; + + // Initialize affinity tracking + let affinity = (0..num_experts).map(SramExpertAffinity::new).collect(); + + // Default latency estimates (microseconds) + // SRAM: ~0.04us (40ns), DRAM: ~0.1us (100ns), Storage: ~100us + let tier_latency = [0, 0, 100]; + + Self { + config, + num_experts, + tier_map, + affinity, + tier_latency, + sram_used: 0, + dram_used: 0, + access_counter: 0, + } + } + + /// Assign an expert to a specific memory tier. + /// + /// This updates the internal tracking and slot usage. 
If the expert was + /// previously in a different tier, the old slot is freed. + /// + /// # Arguments + /// + /// * `expert_id` - Expert to assign + /// * `tier` - Target memory tier + /// + /// # Returns + /// + /// Returns `false` if `expert_id >= num_experts`, `true` otherwise. + pub fn assign_tier(&mut self, expert_id: ExpertId, tier: MemoryTier) -> bool { + if expert_id >= self.num_experts { + return false; + } + + let old_tier = self.tier_map[expert_id]; + + // Free old slot + match old_tier { + MemoryTier::Sram => { + if self.sram_used > 0 { + self.sram_used -= 1; + } + } + MemoryTier::Dram => { + if self.dram_used > 0 { + self.dram_used -= 1; + } + } + MemoryTier::Storage => {} + } + + // Allocate new slot + match tier { + MemoryTier::Sram => self.sram_used += 1, + MemoryTier::Dram => self.dram_used += 1, + MemoryTier::Storage => {} + } + + self.tier_map[expert_id] = tier; + true + } + + /// Get the current memory tier for an expert. + /// + /// # Arguments + /// + /// * `expert_id` - Expert to query + /// + /// # Returns + /// + /// The current memory tier assignment. Returns `Storage` for out-of-range IDs. + pub fn get_tier(&self, expert_id: ExpertId) -> MemoryTier { + self.tier_map + .get(expert_id) + .copied() + .unwrap_or(MemoryTier::Storage) + } + + /// Estimate the paging latency for accessing an expert in microseconds. + /// + /// The latency depends on the expert's current memory tier: + /// - SRAM: ~0 microseconds (cache hit) + /// - DRAM: ~0 microseconds (memory access) + /// - Storage: ~100+ microseconds (page fault / disk access) + /// + /// # Arguments + /// + /// * `expert_id` - Expert to estimate latency for + /// + /// # Returns + /// + /// Estimated latency in microseconds. + pub fn estimate_paging_latency(&self, expert_id: ExpertId) -> u64 { + let tier = self.get_tier(expert_id); + self.tier_latency[tier.index()] + } + + /// Get the number of experts that fit in SRAM. 
+ pub fn sram_capacity(&self) -> usize { + self.config.sram_expert_slots + } + + /// Get the number of experts that fit in DRAM budget. + pub fn dram_capacity(&self) -> usize { + self.config.dram_expert_slots + } + + /// Get current SRAM slot usage. + pub fn sram_used(&self) -> usize { + self.sram_used + } + + /// Get current DRAM slot usage. + pub fn dram_used(&self) -> usize { + self.dram_used + } + + /// Get available SRAM slots. + pub fn sram_available(&self) -> usize { + self.config.sram_expert_slots.saturating_sub(self.sram_used) + } + + /// Get available DRAM slots. + pub fn dram_available(&self) -> usize { + self.config.dram_expert_slots.saturating_sub(self.dram_used) + } + + /// Record an access to an expert for affinity tracking. + /// + /// # Arguments + /// + /// * `expert_id` - Expert that was accessed + /// * `router_weight` - Router softmax weight for this expert (0.0 - 1.0) + pub fn record_access(&mut self, expert_id: ExpertId, router_weight: f32) { + if expert_id >= self.num_experts { + return; + } + + self.access_counter += 1; + + let affinity = &mut self.affinity[expert_id]; + affinity.access_count += 1; + affinity.last_access = self.access_counter; + affinity.recent_selections += 1; + + // Exponential moving average for router weight + let alpha = 0.1; + affinity.avg_router_weight = + alpha * router_weight + (1.0 - alpha) * affinity.avg_router_weight; + } + + /// Suggest tier changes based on current affinity data. + /// + /// Analyzes expert access patterns and suggests promotions (to faster tiers) + /// or demotions (to slower tiers) to optimize the memory hierarchy. + /// + /// # Arguments + /// + /// * `affinity_data` - Optional external affinity data (uses internal if None) + /// + /// # Returns + /// + /// A vector of `(ExpertId, MemoryTier)` pairs suggesting new tier assignments. 
+ pub fn suggest_eviction_tier( + &self, + _affinity_data: &SramExpertAffinity, + ) -> Vec<(ExpertId, MemoryTier)> { + self.suggest_tier_changes() + } + + /// Suggest tier changes based on internal affinity tracking. + /// + /// Implements a simple policy: + /// 1. Promote high-priority experts to SRAM (if slots available) + /// 2. Demote low-priority SRAM experts to DRAM + /// 3. Demote rarely-used DRAM experts to Storage + /// + /// # Returns + /// + /// A vector of `(ExpertId, MemoryTier)` pairs suggesting new tier assignments. + pub fn suggest_tier_changes(&self) -> Vec<(ExpertId, MemoryTier)> { + let mut suggestions = Vec::new(); + + // Collect experts with their priority scores and current tiers + let mut experts: Vec<(ExpertId, f32, MemoryTier)> = self + .affinity + .iter() + .enumerate() + .filter(|(_, aff)| !aff.pinned) + .map(|(id, aff)| (id, aff.priority_score(), self.tier_map[id])) + .collect(); + + // Sort by priority (highest first) + experts.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + + // Suggest promotions to SRAM for top experts currently in DRAM/Storage + let sram_available = self.sram_available(); + let mut promoted_to_sram = 0; + + for &(expert_id, _priority, current_tier) in &experts { + if promoted_to_sram >= sram_available { + break; + } + if current_tier != MemoryTier::Sram { + suggestions.push((expert_id, MemoryTier::Sram)); + promoted_to_sram += 1; + } + } + + // Suggest demotions for low-priority SRAM experts + // (process from lowest priority) + for &(expert_id, _priority, current_tier) in experts.iter().rev() { + if current_tier == MemoryTier::Sram + && suggestions.iter().all(|(id, _)| *id != expert_id) + { + if self.dram_available() > 0 { + suggestions.push((expert_id, MemoryTier::Dram)); + } else { + suggestions.push((expert_id, MemoryTier::Storage)); + } + } + } + + suggestions + } + + /// Pin an expert to its current tier (prevent automatic eviction). 
+ pub fn pin(&mut self, expert_id: ExpertId) { + if expert_id < self.num_experts { + self.affinity[expert_id].pinned = true; + } + } + + /// Unpin an expert (allow automatic tier changes). + pub fn unpin(&mut self, expert_id: ExpertId) { + if expert_id < self.num_experts { + self.affinity[expert_id].pinned = false; + } + } + + /// Get a reference to the hardware configuration. + pub fn config(&self) -> &HardwareConfig { + &self.config + } + + /// Get the total number of experts. + pub fn num_experts(&self) -> usize { + self.num_experts + } + + /// Set custom tier latency estimates. + /// + /// # Arguments + /// + /// * `sram_us` - SRAM tier latency in microseconds + /// * `dram_us` - DRAM tier latency in microseconds + /// * `storage_us` - Storage tier latency in microseconds + pub fn set_tier_latencies(&mut self, sram_us: u64, dram_us: u64, storage_us: u64) { + self.tier_latency = [sram_us, dram_us, storage_us]; + } + + /// Get experts currently in a specific tier. + pub fn experts_in_tier(&self, tier: MemoryTier) -> Vec { + self.tier_map + .iter() + .enumerate() + .filter(|(_, &t)| t == tier) + .map(|(id, _)| id) + .collect() + } + + /// Get the affinity data for an expert. + pub fn get_affinity(&self, expert_id: ExpertId) -> Option<&SramExpertAffinity> { + self.affinity.get(expert_id) + } + + /// Reset all affinity tracking data. + pub fn reset_affinity(&mut self) { + for (id, aff) in self.affinity.iter_mut().enumerate() { + *aff = SramExpertAffinity::new(id); + } + self.access_counter = 0; + } + + /// Get a summary of current tier distribution. 
pub fn tier_summary(&self) -> HashMap&lt;MemoryTier, usize&gt; {
SramMapper::from_preset(HardwarePreset::Mobile, 8, expert_size); + + // Mobile: 4MiB SRAM, 2GiB DRAM + // SRAM slots: 4MiB / 1MiB = 4 + // DRAM slots: 2GiB / 1MiB = 2048 + assert_eq!(mapper.sram_capacity(), 4); + assert_eq!(mapper.dram_capacity(), 2048); + assert_eq!(mapper.num_experts(), 8); + } + + // --------------------------------------------------------------- + // test_from_preset_desktop + // --------------------------------------------------------------- + + #[test] + fn test_from_preset_desktop() { + let expert_size = 8 * 1024 * 1024; // 8 MiB per expert (binary units) + let mapper = SramMapper::from_preset(HardwarePreset::Desktop, 16, expert_size); + + // Desktop: 32MiB SRAM, 12GiB DRAM + // SRAM slots: 32MiB / 8MiB = 4 + // DRAM slots: 12GiB / 8MiB = 12 * 1024 / 8 = 1536 + assert_eq!(mapper.sram_capacity(), 4); + assert_eq!(mapper.dram_capacity(), 1536); + assert_eq!(mapper.num_experts(), 16); + } + + // --------------------------------------------------------------- + // test_tier_assignment + // --------------------------------------------------------------- + + #[test] + fn test_tier_assignment() { + let config = HardwareConfig::new( + 16 * 1024 * 1024, // 16 MB SRAM + 4 * 1024 * 1024 * 1024, // 4 GB DRAM + 4 * 1024 * 1024, // 4 MB per expert + ); + let mut mapper = SramMapper::from_config(config, 8); + + // Initially all in Storage + assert_eq!(mapper.get_tier(0), MemoryTier::Storage); + assert_eq!(mapper.sram_used(), 0); + assert_eq!(mapper.dram_used(), 0); + + // Assign expert 0 to SRAM + mapper.assign_tier(0, MemoryTier::Sram); + assert_eq!(mapper.get_tier(0), MemoryTier::Sram); + assert_eq!(mapper.sram_used(), 1); + + // Assign expert 1 to DRAM + mapper.assign_tier(1, MemoryTier::Dram); + assert_eq!(mapper.get_tier(1), MemoryTier::Dram); + assert_eq!(mapper.dram_used(), 1); + + // Move expert 0 from SRAM to DRAM + mapper.assign_tier(0, MemoryTier::Dram); + assert_eq!(mapper.get_tier(0), MemoryTier::Dram); + assert_eq!(mapper.sram_used(), 0); + 
assert_eq!(mapper.dram_used(), 2); + + // Move expert 1 back to Storage + mapper.assign_tier(1, MemoryTier::Storage); + assert_eq!(mapper.get_tier(1), MemoryTier::Storage); + assert_eq!(mapper.dram_used(), 1); + } + + // --------------------------------------------------------------- + // test_paging_latency_estimates + // --------------------------------------------------------------- + + #[test] + fn test_paging_latency_estimates() { + let config = HardwareConfig::new(16 * 1024 * 1024, 4 * 1024 * 1024 * 1024, 4 * 1024 * 1024); + let mut mapper = SramMapper::from_config(config, 4); + + // Set custom latencies + mapper.set_tier_latencies(1, 10, 200); + + mapper.assign_tier(0, MemoryTier::Sram); + mapper.assign_tier(1, MemoryTier::Dram); + mapper.assign_tier(2, MemoryTier::Storage); + + assert_eq!(mapper.estimate_paging_latency(0), 1); // SRAM + assert_eq!(mapper.estimate_paging_latency(1), 10); // DRAM + assert_eq!(mapper.estimate_paging_latency(2), 200); // Storage + assert_eq!(mapper.estimate_paging_latency(3), 200); // Default (Storage) + + // Out of range returns Storage latency + assert_eq!(mapper.estimate_paging_latency(100), 200); + } + + // --------------------------------------------------------------- + // test_capacity_calculations + // --------------------------------------------------------------- + + #[test] + fn test_capacity_calculations() { + let config = HardwareConfig::new( + 32 * 1024 * 1024, // 32 MB SRAM + 8 * 1024 * 1024 * 1024, // 8 GB DRAM + 8 * 1024 * 1024, // 8 MB per expert + ); + let mapper = SramMapper::from_config(config, 16); + + // SRAM: 32MB / 8MB = 4 slots + assert_eq!(mapper.sram_capacity(), 4); + + // DRAM: 8GB / 8MB = 1024 slots + assert_eq!(mapper.dram_capacity(), 1024); + + // Total + assert_eq!(mapper.config().total_slots(), 1028); + assert_eq!( + mapper.config().total_budget(), + 32 * 1024 * 1024 + 8 * 1024 * 1024 * 1024 + ); + + // Available (nothing allocated yet) + assert_eq!(mapper.sram_available(), 4); + 
assert_eq!(mapper.dram_available(), 1024); + } + + // --------------------------------------------------------------- + // test_eviction_suggestions + // --------------------------------------------------------------- + + #[test] + fn test_eviction_suggestions() { + let config = HardwareConfig::new(16 * 1024 * 1024, 4 * 1024 * 1024 * 1024, 4 * 1024 * 1024); + let mut mapper = SramMapper::from_config(config, 8); + + // Simulate access patterns + for _ in 0..10 { + mapper.record_access(0, 0.8); + } + for _ in 0..5 { + mapper.record_access(1, 0.6); + } + mapper.record_access(2, 0.3); + + // Get suggestions - should promote frequently accessed experts + let suggestions = mapper.suggest_tier_changes(); + + // Verify suggestions include high-priority experts + // (exact suggestions depend on affinity algorithm) + assert!(!suggestions.is_empty() || mapper.sram_available() == 0); + } + + // --------------------------------------------------------------- + // test_custom_config + // --------------------------------------------------------------- + + #[test] + fn test_custom_config() { + let config = HardwareConfig { + sram_bytes: 64 * 1024 * 1024, // 64 MB + dram_budget_bytes: 16 * 1024 * 1024 * 1024, // 16 GB + sram_expert_slots: 8, + dram_expert_slots: 200, + expert_size_bytes: 8 * 1024 * 1024, + }; + + let mapper = SramMapper::from_config(config.clone(), 32); + + assert_eq!(mapper.sram_capacity(), 8); + assert_eq!(mapper.dram_capacity(), 200); + assert_eq!(mapper.num_experts(), 32); + assert_eq!(mapper.config().expert_size_bytes, 8 * 1024 * 1024); + } + + // --------------------------------------------------------------- + // test_affinity_tracking + // --------------------------------------------------------------- + + #[test] + fn test_affinity_tracking() { + let config = HardwareConfig::default(); + let mut mapper = SramMapper::from_config(config, 4); + + // Record accesses + mapper.record_access(0, 0.9); + mapper.record_access(0, 0.8); + mapper.record_access(1, 0.5); 
+ + let aff0 = mapper.get_affinity(0).unwrap(); + assert_eq!(aff0.access_count, 2); + assert!(aff0.avg_router_weight > 0.0); + + let aff1 = mapper.get_affinity(1).unwrap(); + assert_eq!(aff1.access_count, 1); + + // Reset affinity + mapper.reset_affinity(); + let aff0_reset = mapper.get_affinity(0).unwrap(); + assert_eq!(aff0_reset.access_count, 0); + } + + // --------------------------------------------------------------- + // test_pin_unpin + // --------------------------------------------------------------- + + #[test] + fn test_pin_unpin() { + let config = HardwareConfig::default(); + let mut mapper = SramMapper::from_config(config, 4); + + // Pin expert 0 + mapper.pin(0); + assert!(mapper.get_affinity(0).unwrap().pinned); + + // Unpin expert 0 + mapper.unpin(0); + assert!(!mapper.get_affinity(0).unwrap().pinned); + } + + // --------------------------------------------------------------- + // test_experts_in_tier + // --------------------------------------------------------------- + + #[test] + fn test_experts_in_tier() { + let config = HardwareConfig::new(16 * 1024 * 1024, 4 * 1024 * 1024 * 1024, 4 * 1024 * 1024); + let mut mapper = SramMapper::from_config(config, 8); + + mapper.assign_tier(0, MemoryTier::Sram); + mapper.assign_tier(1, MemoryTier::Sram); + mapper.assign_tier(2, MemoryTier::Dram); + mapper.assign_tier(3, MemoryTier::Dram); + mapper.assign_tier(4, MemoryTier::Dram); + + let sram_experts = mapper.experts_in_tier(MemoryTier::Sram); + assert_eq!(sram_experts.len(), 2); + assert!(sram_experts.contains(&0)); + assert!(sram_experts.contains(&1)); + + let dram_experts = mapper.experts_in_tier(MemoryTier::Dram); + assert_eq!(dram_experts.len(), 3); + + let storage_experts = mapper.experts_in_tier(MemoryTier::Storage); + assert_eq!(storage_experts.len(), 3); // Experts 5, 6, 7 + } + + // --------------------------------------------------------------- + // test_tier_summary + // --------------------------------------------------------------- + + #[test] + 
fn test_tier_summary() { + let config = HardwareConfig::new(16 * 1024 * 1024, 4 * 1024 * 1024 * 1024, 4 * 1024 * 1024); + let mut mapper = SramMapper::from_config(config, 8); + + mapper.assign_tier(0, MemoryTier::Sram); + mapper.assign_tier(1, MemoryTier::Dram); + mapper.assign_tier(2, MemoryTier::Dram); + + let summary = mapper.tier_summary(); + assert_eq!(*summary.get(&MemoryTier::Sram).unwrap(), 1); + assert_eq!(*summary.get(&MemoryTier::Dram).unwrap(), 2); + assert_eq!(*summary.get(&MemoryTier::Storage).unwrap(), 5); + } + + // --------------------------------------------------------------- + // test_memory_tier_properties + // --------------------------------------------------------------- + + #[test] + fn test_memory_tier_properties() { + assert_eq!(MemoryTier::Sram.name(), "SRAM (L2/L3 Cache)"); + assert_eq!(MemoryTier::Dram.name(), "DRAM (Main Memory)"); + assert_eq!(MemoryTier::Storage.name(), "Storage (Flash/NVMe)"); + + assert_eq!(MemoryTier::Sram.index(), 0); + assert_eq!(MemoryTier::Dram.index(), 1); + assert_eq!(MemoryTier::Storage.index(), 2); + } + + // --------------------------------------------------------------- + // test_hardware_preset_names + // --------------------------------------------------------------- + + #[test] + fn test_hardware_preset_names() { + assert_eq!(HardwarePreset::RaspberryPi5.name(), "Raspberry Pi 5"); + assert_eq!(HardwarePreset::Mobile.name(), "Mobile Device"); + assert_eq!(HardwarePreset::Desktop.name(), "Desktop Workstation"); + assert_eq!(HardwarePreset::WasmBrowser.name(), "WASM Browser"); + assert_eq!(HardwarePreset::Custom.name(), "Custom"); + } + + // --------------------------------------------------------------- + // test_expert_affinity_priority_score + // --------------------------------------------------------------- + + #[test] + fn test_expert_affinity_priority_score() { + let mut aff = SramExpertAffinity::new(0); + + // Initial score should be low + let initial_score = aff.priority_score(); + + // 
Increase access count and check score increases + aff.access_count = 100; + aff.avg_router_weight = 0.9; + let high_score = aff.priority_score(); + + assert!(high_score > initial_score); + } + + // --------------------------------------------------------------- + // test_wasm_browser_preset + // --------------------------------------------------------------- + + #[test] + fn test_wasm_browser_preset() { + let expert_size = 2 * 1024 * 1024; // 2 MiB per expert (binary units) + let mapper = SramMapper::from_preset(HardwarePreset::WasmBrowser, 8, expert_size); + + // WASM: 2MiB SRAM, 1GiB DRAM + // SRAM slots: 2MiB / 2MiB = 1 + // DRAM slots: 1GiB / 2MiB = 512 + assert_eq!(mapper.sram_capacity(), 1); + assert_eq!(mapper.dram_capacity(), 512); + } + + // --------------------------------------------------------------- + // test_out_of_range_expert_id + // --------------------------------------------------------------- + + #[test] + fn test_out_of_range_expert_id() { + let config = HardwareConfig::default(); + let mapper = SramMapper::from_config(config, 4); + + // Out of range should return Storage + assert_eq!(mapper.get_tier(100), MemoryTier::Storage); + assert_eq!(mapper.estimate_paging_latency(100), 100); // Default storage latency + } + + // --------------------------------------------------------------- + // test_record_access_out_of_range + // --------------------------------------------------------------- + + #[test] + fn test_record_access_out_of_range() { + let config = HardwareConfig::default(); + let mut mapper = SramMapper::from_config(config, 4); + + // Should not panic + mapper.record_access(100, 0.5); + + // Counter should not advance for invalid ID + // (actually it does advance, but affinity is not updated) + } +} diff --git a/crates/ruvllm/src/qat/calibration.rs b/crates/ruvllm/src/qat/calibration.rs index 40d1e0c79..bcd7c45b1 100644 --- a/crates/ruvllm/src/qat/calibration.rs +++ b/crates/ruvllm/src/qat/calibration.rs @@ -546,9 +546,6 @@ mod tests { 
let json = serde_json::to_string(&result).unwrap(); let restored: CalibrationResult = serde_json::from_str(&json).unwrap(); - assert_eq!( - result.scales.get("layer.0"), - restored.scales.get("layer.0") - ); + assert_eq!(result.scales.get("layer.0"), restored.scales.get("layer.0")); } } diff --git a/crates/ruvllm/src/qat/config.rs b/crates/ruvllm/src/qat/config.rs index 8a5f85104..775ba6d6e 100644 --- a/crates/ruvllm/src/qat/config.rs +++ b/crates/ruvllm/src/qat/config.rs @@ -61,9 +61,7 @@ impl QuantGranularity { QuantGranularity::PerTensor => 2, // scale + zero_point QuantGranularity::PerChannel => channels * 2, QuantGranularity::PerToken => (n / channels) * 2, - QuantGranularity::PerBlock { block_size } => { - ((n + block_size - 1) / block_size) * 2 - } + QuantGranularity::PerBlock { block_size } => ((n + block_size - 1) / block_size) * 2, } } } @@ -326,8 +324,8 @@ impl QatConfig { /// Create config for 2-bit Pi-quantization (PiQ2) pub fn piq2() -> Self { Self { - use_incoherence: true, // 2-bit typically needs Hadamard - ..Self::pi_quant(2, 3) // step = pi/3 + use_incoherence: true, // 2-bit typically needs Hadamard + ..Self::pi_quant(2, 3) // step = pi/3 } } @@ -411,7 +409,10 @@ impl QatConfig { pub fn validate(&self) -> Result<(), String> { // Validate bit width if !matches!(self.bits, 2 | 3 | 4 | 5 | 8) { - return Err(format!("Invalid bit width: {}. Must be 2, 3, 4, 5, or 8", self.bits)); + return Err(format!( + "Invalid bit width: {}. 
Must be 2, 3, 4, 5, or 8", + self.bits + )); } // Validate Pi-k value (INV-3) diff --git a/crates/ruvllm/src/qat/differentiable_quant.rs b/crates/ruvllm/src/qat/differentiable_quant.rs index 72c085a5d..df90aaee6 100644 --- a/crates/ruvllm/src/qat/differentiable_quant.rs +++ b/crates/ruvllm/src/qat/differentiable_quant.rs @@ -258,7 +258,10 @@ impl PiQuantDifferentiable { /// * `k` - Pi divisor (step = alpha * pi / k) pub fn new(bits: u8, k: u8) -> Self { assert!(matches!(bits, 2 | 3 | 4 | 5), "Bits must be 2, 3, 4, or 5"); - assert!(matches!(k, 2 | 3 | 4 | 5), "k must be 2, 3, 4, or 5 (INV-3)"); + assert!( + matches!(k, 2 | 3 | 4 | 5), + "k must be 2, 3, 4, or 5 (INV-3)" + ); Self { bits, @@ -314,7 +317,10 @@ impl PiQuantDifferentiable { let end = start + ch_size; let channel_weights = &weights[start..end]; - let max_abs = channel_weights.iter().map(|w| w.abs()).fold(0.0f32, f32::max); + let max_abs = channel_weights + .iter() + .map(|w| w.abs()) + .fold(0.0f32, f32::max); let step = self.step_size(0); // Use default step for calculation let half = (1 << self.bits) / 2; @@ -427,14 +433,16 @@ impl DifferentiableQuantizer for PiQuantDifferentiable { // Determine channel size for per-channel quantization let channel_size = match &self.granularity { - QuantGranularity::PerChannel if self.num_channels > 1 => { - w.len() / self.num_channels - } + QuantGranularity::PerChannel if self.num_channels > 1 => w.len() / self.num_channels, _ => w.len(), }; for (i, &weight) in w.iter().enumerate() { - let channel = if self.num_channels > 1 { i / channel_size } else { 0 }; + let channel = if self.num_channels > 1 { + i / channel_size + } else { + 0 + }; let (q, dequant) = self.quantize_scalar(weight, channel); q_int.push(q); q_float.push(dequant); @@ -460,7 +468,11 @@ impl DifferentiableQuantizer for PiQuantDifferentiable { .iter() .enumerate() .map(|(i, &q)| { - let channel = if self.num_channels > 1 { i / channel_size } else { 0 }; + let channel = if self.num_channels > 1 { + i / 
channel_size + } else { + 0 + }; q as f32 * self.step_size(channel) }) .collect() @@ -616,9 +628,13 @@ mod tests { let expected = alpha * PI / (k as f32); let actual = q.step_size(0); - assert!((actual - expected).abs() < 1e-6, - "INV-3 violation: step {} != alpha*pi/k {} for k={}", - actual, expected, k); + assert!( + (actual - expected).abs() < 1e-6, + "INV-3 violation: step {} != alpha*pi/k {} for k={}", + actual, + expected, + k + ); } } @@ -665,9 +681,7 @@ mod tests { fn test_pi_quant_scale_init() { let mut quantizer = PiQuantDifferentiable::piq3(); - let weights: Vec = (0..100) - .map(|i| (i as f32 - 50.0) / 10.0) - .collect(); + let weights: Vec = (0..100).map(|i| (i as f32 - 50.0) / 10.0).collect(); quantizer.init_scale_from_weights(&weights, None); @@ -772,9 +786,18 @@ mod tests { #[test] fn test_num_levels() { - assert_eq!(UniformQuantizer::new(2, SteVariant::Standard).num_levels(), 4); - assert_eq!(UniformQuantizer::new(3, SteVariant::Standard).num_levels(), 8); - assert_eq!(UniformQuantizer::new(4, SteVariant::Standard).num_levels(), 16); + assert_eq!( + UniformQuantizer::new(2, SteVariant::Standard).num_levels(), + 4 + ); + assert_eq!( + UniformQuantizer::new(3, SteVariant::Standard).num_levels(), + 8 + ); + assert_eq!( + UniformQuantizer::new(4, SteVariant::Standard).num_levels(), + 16 + ); assert_eq!(PiQuantDifferentiable::piq3().num_levels(), 8); assert_eq!(PiQuantDifferentiable::piq2().num_levels(), 4); } diff --git a/crates/ruvllm/src/qat/distillation.rs b/crates/ruvllm/src/qat/distillation.rs index 45cabd381..eb8b83c2e 100644 --- a/crates/ruvllm/src/qat/distillation.rs +++ b/crates/ruvllm/src/qat/distillation.rs @@ -219,7 +219,8 @@ impl DistillationLoss { // KD loss (KL divergence) if pos < teacher.seq_len { let teacher_probs = teacher.softmax_at(pos, self.config.temperature); - let student_probs = softmax_with_temperature(student_slice, self.config.temperature); + let student_probs = + softmax_with_temperature(student_slice, 
self.config.temperature); let kd_loss = kl_divergence(&student_probs, &teacher_probs); total_kd_loss += kd_loss * self.config.temperature.powi(2); } @@ -262,11 +263,7 @@ impl DistillationLoss { } /// Compute KD loss only - pub fn compute_kd_loss( - &self, - student_logits: &[f32], - teacher: &TeacherOutput, - ) -> f32 { + pub fn compute_kd_loss(&self, student_logits: &[f32], teacher: &TeacherOutput) -> f32 { let vocab_size = teacher.vocab_size; let seq_len = teacher.seq_len; let mut total_kd_loss = 0.0; @@ -295,9 +292,11 @@ impl DistillationLoss { let n = self.stats.compute_count as f64; let alpha = 1.0 / (n + 1.0); - self.stats.avg_task_loss = (1.0 - alpha) * self.stats.avg_task_loss + alpha * task_loss as f64; + self.stats.avg_task_loss = + (1.0 - alpha) * self.stats.avg_task_loss + alpha * task_loss as f64; self.stats.avg_kd_loss = (1.0 - alpha) * self.stats.avg_kd_loss + alpha * kd_loss as f64; - self.stats.avg_total_loss = (1.0 - alpha) * self.stats.avg_total_loss + alpha * total_loss as f64; + self.stats.avg_total_loss = + (1.0 - alpha) * self.stats.avg_total_loss + alpha * total_loss as f64; self.stats.compute_count += 1; } @@ -417,10 +416,7 @@ mod tests { // Student and teacher logits (vocab_size=4, seq_len=2) let student_logits = vec![1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0]; - let teacher = TeacherOutput::from_logits( - vec![1.1, 2.1, 3.1, 4.1, 2.1, 3.1, 4.1, 5.1], - 4, - ); + let teacher = TeacherOutput::from_logits(vec![1.1, 2.1, 3.1, 4.1, 2.1, 3.1, 4.1, 5.1], 4); let labels = vec![3, 3]; let loss = loss_fn.compute(&student_logits, &teacher, &labels); diff --git a/crates/ruvllm/src/qat/lora_qat.rs b/crates/ruvllm/src/qat/lora_qat.rs index b05a7aff9..8ff8d2533 100644 --- a/crates/ruvllm/src/qat/lora_qat.rs +++ b/crates/ruvllm/src/qat/lora_qat.rs @@ -303,11 +303,7 @@ impl LoraWeights { } /// Compute gradients for A and B - pub fn backward( - &self, - input: &[f32], - grad_output: &[f32], - ) -> (Vec, Vec, Vec) { + pub fn backward(&self, input: &[f32], 
grad_output: &[f32]) -> (Vec, Vec, Vec) { let batch_size = input.len() / self.d_in; // Scale gradient diff --git a/crates/ruvllm/src/qat/mod.rs b/crates/ruvllm/src/qat/mod.rs index 104ab64b0..e5bd3a0bd 100644 --- a/crates/ruvllm/src/qat/mod.rs +++ b/crates/ruvllm/src/qat/mod.rs @@ -156,9 +156,7 @@ pub use calibration::{ }; // Distillation loss (ADR-090 Phase 2) -pub use distillation::{ - DistillationConfig, DistillationLoss, DistillationStats, TeacherOutput, -}; +pub use distillation::{DistillationConfig, DistillationLoss, DistillationStats, TeacherOutput}; // Reasoning loss (ADR-090 Phase 2) pub use reasoning_loss::{ @@ -172,9 +170,7 @@ pub use training_loop::{ }; // LoRA-QAT integration (ADR-090 Phase 2) -pub use lora_qat::{ - LoraGradients, LoraQatConfig, LoraQatLayer, LoraQatModel, LoraWeights, -}; +pub use lora_qat::{LoraGradients, LoraQatConfig, LoraQatLayer, LoraQatModel, LoraWeights}; // STE SIMD optimizations (platform-specific) #[cfg(target_arch = "aarch64")] @@ -307,9 +303,7 @@ mod tests { let quantizer = create_quantizer(&config); // Sample weights - let weights: Vec = (0..256) - .map(|i| (i as f32 - 128.0) / 128.0) - .collect(); + let weights: Vec = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect(); // Forward pass let (q_int, q_dequant) = quantizer.forward(&weights); @@ -331,9 +325,7 @@ mod tests { #[test] fn test_config_serialization() { - let config = QatConfig::piq3() - .with_epochs(10) - .with_learning_rate(5e-5); + let config = QatConfig::piq3().with_epochs(10).with_learning_rate(5e-5); let json = config.to_json().unwrap(); let restored = QatConfig::from_json(&json).unwrap(); diff --git a/crates/ruvllm/src/qat/reasoning_loss.rs b/crates/ruvllm/src/qat/reasoning_loss.rs index 673299091..c6cc15925 100644 --- a/crates/ruvllm/src/qat/reasoning_loss.rs +++ b/crates/ruvllm/src/qat/reasoning_loss.rs @@ -339,7 +339,10 @@ impl ChainOfThoughtLoss { self.metrics.chains_evaluated += 1; // Average step similarity - let avg_sim: f64 = steps.iter().map(|s| 
s.cosine_similarity() as f64).sum::() + let avg_sim: f64 = steps + .iter() + .map(|s| s.cosine_similarity() as f64) + .sum::() / steps.len() as f64; let n = self.metrics.chains_evaluated as f64; self.metrics.avg_step_similarity = @@ -355,8 +358,9 @@ impl ChainOfThoughtLoss { // Answer agreement if let (Some(t), Some(s)) = (teacher_answer, student_answer) { let agrees = t.trim() == s.trim(); - self.metrics.answer_agreement_rate = - (self.metrics.answer_agreement_rate * (n - 1.0) + if agrees { 1.0 } else { 0.0 }) / n; + self.metrics.answer_agreement_rate = (self.metrics.answer_agreement_rate * (n - 1.0) + + if agrees { 1.0 } else { 0.0 }) + / n; } } diff --git a/crates/ruvllm/src/qat/ste.rs b/crates/ruvllm/src/qat/ste.rs index 28d719c27..3fb5e11f4 100644 --- a/crates/ruvllm/src/qat/ste.rs +++ b/crates/ruvllm/src/qat/ste.rs @@ -70,9 +70,7 @@ impl SteVariant { // EWGS: Gradient scaled by quantization error // dL/dw = dL/dq * (1 + lambda * |w - q|) // This gives stronger gradient signal for weights far from quantization points - SteVariant::Ewgs { lambda } => { - grad_out * (1.0 + lambda * (w - q).abs()) - } + SteVariant::Ewgs { lambda } => grad_out * (1.0 + lambda * (w - q).abs()), } } @@ -160,10 +158,7 @@ impl SteVariant { pub mod simd { /// NEON-accelerated backward pass (identity, no-op for Standard STE) #[inline] - pub unsafe fn backward_standard_neon( - grad_out: &[f32], - grad_w: &mut [f32], - ) { + pub unsafe fn backward_standard_neon(grad_out: &[f32], grad_w: &mut [f32]) { // For Standard STE, just copy grad_w.copy_from_slice(grad_out); } @@ -244,7 +239,11 @@ pub mod simd { // Handle remainder while i < n { - grad_w[i] = if weights[i].abs() <= clip_val { grad_out[i] } else { 0.0 }; + grad_w[i] = if weights[i].abs() <= clip_val { + grad_out[i] + } else { + 0.0 + }; i += 1; } } @@ -437,7 +436,12 @@ mod tests { let ste_ewgs = SteVariant::Ewgs { lambda: 0.1 }; let expected = 0.3_f32 * (1.0_f32 + 0.1_f32 * (0.7_f32 - 0.5_f32).abs()); let actual = 
ste_ewgs.backward(0.7, 0.5, 0.3); - assert!((actual - expected).abs() < 1e-6, "EWGS mismatch: {} vs {}", actual, expected); + assert!( + (actual - expected).abs() < 1e-6, + "EWGS mismatch: {} vs {}", + actual, + expected + ); } #[test] @@ -472,8 +476,14 @@ mod tests { for i in 0..100 { let diff = (grad_scalar[i] - grad_simd[i]).abs(); let ulp = f32::EPSILON * grad_scalar[i].abs().max(1.0); - assert!(diff <= ulp, "SIMD mismatch at {}: {} vs {} (diff {})", - i, grad_scalar[i], grad_simd[i], diff); + assert!( + diff <= ulp, + "SIMD mismatch at {}: {} vs {} (diff {})", + i, + grad_scalar[i], + grad_simd[i], + diff + ); } } @@ -499,9 +509,11 @@ mod tests { // Compare for i in 0..100 { - assert_eq!(grad_scalar[i], grad_simd[i], - "Clipped SIMD mismatch at {}: {} vs {}", - i, grad_scalar[i], grad_simd[i]); + assert_eq!( + grad_scalar[i], grad_simd[i], + "Clipped SIMD mismatch at {}: {} vs {}", + i, grad_scalar[i], grad_simd[i] + ); } } } diff --git a/crates/ruvllm/src/qat/training_loop.rs b/crates/ruvllm/src/qat/training_loop.rs index 5e921ebeb..ca799f4ca 100644 --- a/crates/ruvllm/src/qat/training_loop.rs +++ b/crates/ruvllm/src/qat/training_loop.rs @@ -441,7 +441,10 @@ impl QatTrainer { } /// Run calibration phase - pub fn calibrate(&mut self, activations: &HashMap>) -> Result { + pub fn calibrate( + &mut self, + activations: &HashMap>, + ) -> Result { self.phase = TrainingPhase::Calibration; self.emit(QatEvent::CalibrationStarted { config: CalibrationConfig::default(), @@ -481,12 +484,16 @@ impl QatTrainer { // Compute loss components let (task_loss, kd_loss) = if let Some(ref teacher) = batch.teacher_output { - let loss = self.distillation_loss.compute(&student_logits, teacher, &labels); + let loss = self + .distillation_loss + .compute(&student_logits, teacher, &labels); let stats = self.distillation_loss.stats(); (stats.avg_task_loss as f32, stats.avg_kd_loss as f32) } else { let vocab_size = 32000; // TODO: Get from model - let task_loss = 
self.distillation_loss.compute_task_loss(&student_logits, &labels, vocab_size); + let task_loss = + self.distillation_loss + .compute_task_loss(&student_logits, &labels, vocab_size); (task_loss, 0.0) }; @@ -532,11 +539,7 @@ impl QatTrainer { } /// Run a single training epoch - pub fn train_epoch( - &mut self, - epoch: usize, - batches: &[TrainingBatch], - ) -> Result { + pub fn train_epoch(&mut self, epoch: usize, batches: &[TrainingBatch]) -> Result { self.phase = TrainingPhase::Training; let epoch_start = Instant::now(); self.current_epoch_steps.clear(); diff --git a/crates/ruvllm/src/quantize/hadamard.rs b/crates/ruvllm/src/quantize/hadamard.rs index 31d290516..2c1f23cb3 100644 --- a/crates/ruvllm/src/quantize/hadamard.rs +++ b/crates/ruvllm/src/quantize/hadamard.rs @@ -116,9 +116,14 @@ impl HadamardTransform { let mut rng_state = s; let signs: Vec = (0..dim) .map(|_| { - rng_state = rng_state.wrapping_mul(6364136223846793005) + rng_state = rng_state + .wrapping_mul(6364136223846793005) .wrapping_add(1442695040888963407); - if (rng_state >> 63) & 1 == 0 { 1 } else { -1 } + if (rng_state >> 63) & 1 == 0 { + 1 + } else { + -1 + } }) .collect(); (signs, true) @@ -732,12 +737,7 @@ mod tests { transform.inverse_inplace(&mut data); for (a, b) in data.iter().zip(original.iter()) { - assert!( - (a - b).abs() < 1e-5, - "Roundtrip failed: {} vs {}", - a, - b - ); + assert!((a - b).abs() < 1e-5, "Roundtrip failed: {} vs {}", a, b); } } @@ -793,11 +793,7 @@ mod tests { // After normalization by 1/sqrt(4) = 0.5 for &v in &data { - assert!( - (v - 0.5).abs() < 1e-5, - "Expected 0.5, got {}", - v - ); + assert!((v - 0.5).abs() < 1e-5, "Expected 0.5, got {}", v); } } diff --git a/crates/ruvllm/src/quantize/incoherence.rs b/crates/ruvllm/src/quantize/incoherence.rs index 81fa677bd..6d0dd583b 100644 --- a/crates/ruvllm/src/quantize/incoherence.rs +++ b/crates/ruvllm/src/quantize/incoherence.rs @@ -26,8 +26,8 @@ use std::time::Instant; use super::hadamard::{ - 
hadamard_batch_inverse, hadamard_batch_transform, log2_exact, next_power_of_2, pad_to_power_of_2, - HadamardTransform, + hadamard_batch_inverse, hadamard_batch_transform, log2_exact, next_power_of_2, + pad_to_power_of_2, HadamardTransform, }; use crate::error::{Result, RuvLLMError}; @@ -291,7 +291,10 @@ impl IncoherenceTransform { let log_dim = match log2_exact(target_len) { Some(ld) => ld, None => { - self.emit_error("Internal error: padded length not power of 2", IncoherencePhase::Forward); + self.emit_error( + "Internal error: padded length not power of 2", + IncoherencePhase::Forward, + ); return Err(RuvLLMError::Quantization( "Padded length is not a power of 2".to_string(), )); @@ -368,10 +371,7 @@ impl IncoherenceTransform { let log_dim = match log2_exact(current_len) { Some(ld) => ld, None => { - self.emit_error( - "Data length is not a power of 2", - IncoherencePhase::Inverse, - ); + self.emit_error("Data length is not a power of 2", IncoherencePhase::Inverse); return Err(RuvLLMError::Quantization( "Data length must be a power of 2 for inverse transform".to_string(), )); @@ -387,7 +387,9 @@ impl IncoherenceTransform { // Truncate to original length if provided let final_len = original_len.unwrap_or_else(|| { let data_id = data.as_ptr() as usize; - self.pending_original_dims.remove(&data_id).unwrap_or(current_len) + self.pending_original_dims + .remove(&data_id) + .unwrap_or(current_len) }); if final_len < current_len { @@ -420,12 +422,7 @@ impl IncoherenceTransform { /// * `data` - Flat buffer containing `batch_size` vectors of `dim` elements each /// * `dim` - Dimension of each vector (must be power of 2) /// * `batch_size` - Number of vectors - pub fn apply_batch( - &mut self, - data: &mut [f32], - dim: usize, - batch_size: usize, - ) -> Result<()> { + pub fn apply_batch(&mut self, data: &mut [f32], dim: usize, batch_size: usize) -> Result<()> { if data.len() != dim * batch_size { return Err(RuvLLMError::Quantization(format!( "Data length {} doesn't match 
dim {} * batch_size {}", @@ -454,12 +451,7 @@ impl IncoherenceTransform { } /// Restore a batch of weight vectors after dequantization - pub fn restore_batch( - &mut self, - data: &mut [f32], - dim: usize, - batch_size: usize, - ) -> Result<()> { + pub fn restore_batch(&mut self, data: &mut [f32], dim: usize, batch_size: usize) -> Result<()> { if data.len() != dim * batch_size { return Err(RuvLLMError::Quantization(format!( "Data length {} doesn't match dim {} * batch_size {}", @@ -566,7 +558,11 @@ pub fn apply_incoherence(data: &mut Vec, seed: Option) -> Result, original_len: usize, seed: Option) -> Result<()> { +pub fn restore_incoherence( + data: &mut Vec, + original_len: usize, + seed: Option, +) -> Result<()> { let config = IncoherenceConfig { seed, randomized: seed.is_some(), @@ -603,7 +599,9 @@ mod tests { let padded_dim = transform.apply_before_quantization(&mut data).unwrap(); assert_eq!(padded_dim, 8); - transform.restore_after_dequantization(&mut data, Some(8)).unwrap(); + transform + .restore_after_dequantization(&mut data, Some(8)) + .unwrap(); for (a, b) in data.iter().zip(original.iter()) { assert!((a - b).abs() < 1e-5, "Roundtrip failed: {} vs {}", a, b); @@ -627,11 +625,18 @@ mod tests { assert_eq!(padded_dim, 8); assert_eq!(data.len(), 8); - transform.restore_after_dequantization(&mut data, Some(original_len)).unwrap(); + transform + .restore_after_dequantization(&mut data, Some(original_len)) + .unwrap(); assert_eq!(data.len(), original_len); for (a, b) in data.iter().zip(original.iter()) { - assert!((a - b).abs() < 1e-5, "Padded roundtrip failed: {} vs {}", a, b); + assert!( + (a - b).abs() < 1e-5, + "Padded roundtrip failed: {} vs {}", + a, + b + ); } } @@ -649,11 +654,17 @@ mod tests { // Data with an outlier let mut data: Vec = vec![1.0, 1.0, 1.0, 100.0, 1.0, 1.0, 1.0, 1.0]; - let max_before: f32 = data.iter().map(|x: &f32| x.abs()).fold(0.0f32, |a: f32, b: f32| a.max(b)); + let max_before: f32 = data + .iter() + .map(|x: &f32| x.abs()) + 
.fold(0.0f32, |a: f32, b: f32| a.max(b)); transform.apply_before_quantization(&mut data).unwrap(); - let max_after: f32 = data.iter().map(|x: &f32| x.abs()).fold(0.0f32, |a: f32, b: f32| a.max(b)); + let max_after: f32 = data + .iter() + .map(|x: &f32| x.abs()) + .fold(0.0f32, |a: f32, b: f32| a.max(b)); // The outlier should be spread across all elements // Max after should be significantly smaller than 100 @@ -667,7 +678,12 @@ mod tests { // Check that events were emitted let events = transform.take_events(); assert!(!events.is_empty()); - if let IncoherenceEvent::IncoherenceApplied { max_before: mb, max_after: ma, .. } = &events[0] { + if let IncoherenceEvent::IncoherenceApplied { + max_before: mb, + max_after: ma, + .. + } = &events[0] + { assert!((*ma) < (*mb) * 0.9); } } diff --git a/crates/ruvllm/src/quantize/mod.rs b/crates/ruvllm/src/quantize/mod.rs index 03cab2531..d2853f4b2 100644 --- a/crates/ruvllm/src/quantize/mod.rs +++ b/crates/ruvllm/src/quantize/mod.rs @@ -110,25 +110,25 @@ pub use ruvltra_quant::{ // Pi-Quantization SIMD kernels pub use pi_quant_simd::{ - // Constants - DEFAULT_K, - PI3_BYTES_PER_GROUP, - PI3_VALUES_PER_GROUP, - PI_F32, + // Utility functions + extract_pi3_value, // Runtime dispatch (selects best kernel) pi_dequantize, pi_dequantize_kernel_name, - pi_quantize, - pi_quantize_kernel_name, // Scalar reference (always available) pi_dequantize_scalar, + pi_quantize, + pi_quantize_kernel_name, pi_quantize_scalar, - // Utility functions - extract_pi3_value, pi_quantize_value, pi_scale, pi_scale_adaptive, pi_scale_from_max, + // Constants + DEFAULT_K, + PI3_BYTES_PER_GROUP, + PI3_VALUES_PER_GROUP, + PI_F32, }; // Architecture-specific SIMD kernels (conditionally exported) @@ -140,11 +140,7 @@ pub use pi_quant_simd::{pi_dequantize_avx2, pi_dequantize_avx512, pi_quantize_av // High-performance quantization (ADR-090 >1 GB/s target) pub use pi_quant::{ - batch_quantize_3bit, - quantize_2bit, - quantize_2bit_fast, - quantize_3bit, - 
quantize_3bit_fast, + batch_quantize_3bit, quantize_2bit, quantize_2bit_fast, quantize_3bit, quantize_3bit_fast, quantize_kernel_name, }; @@ -157,33 +153,17 @@ pub use pi_quant::{quantize_2bit_avx2, quantize_3bit_avx2}; // Hadamard transform (ADR-090 Phase 3) pub use hadamard::{ - hadamard_batch_inverse, - hadamard_batch_transform, - log2_exact, - next_power_of_2, - pad_to_power_of_2, - HadamardTransform, - MAX_LOG_DIM, - SIMD_LANES, + hadamard_batch_inverse, hadamard_batch_transform, log2_exact, next_power_of_2, + pad_to_power_of_2, HadamardTransform, MAX_LOG_DIM, SIMD_LANES, }; // Incoherence transform (ADR-090 Phase 3) pub use incoherence::{ - apply_incoherence, - restore_incoherence, - IncoherenceConfig, - IncoherenceEvent, - IncoherencePhase, - IncoherenceStats, - IncoherenceTransform, + apply_incoherence, restore_incoherence, IncoherenceConfig, IncoherenceEvent, IncoherencePhase, + IncoherenceStats, IncoherenceTransform, }; // QuIP 2-bit quantization (ADR-090 Phase 3) pub use quip::{ - Q2QuipBlock, - Q2QuipSuperBlock, - QuipCodebook, - QuipConfig, - QuipMetadata, - QuipQuantizer, + Q2QuipBlock, Q2QuipSuperBlock, QuipCodebook, QuipConfig, QuipMetadata, QuipQuantizer, }; diff --git a/crates/ruvllm/src/quantize/pi_quant.rs b/crates/ruvllm/src/quantize/pi_quant.rs index d23131228..b4e3fc42b 100644 --- a/crates/ruvllm/src/quantize/pi_quant.rs +++ b/crates/ruvllm/src/quantize/pi_quant.rs @@ -799,7 +799,10 @@ pub fn dequantize_tensor_2bit( /// /// Number of bytes written to output. pub fn quantize_3bit_fast(weights: &[f32], step: f32, output: &mut [u8]) -> usize { - debug_assert!(weights.len() % PI3_BLOCK_WEIGHTS == 0, "Weight length must be multiple of 8"); + debug_assert!( + weights.len() % PI3_BLOCK_WEIGHTS == 0, + "Weight length must be multiple of 8" + ); let num_blocks = weights.len() / PI3_BLOCK_WEIGHTS; let output_bytes = num_blocks * PI3_BLOCK_BYTES; @@ -879,7 +882,10 @@ unsafe fn quantize_3bit_inner( /// /// Number of bytes written to output. 
pub fn quantize_2bit_fast(weights: &[f32], step: f32, output: &mut [u8]) -> usize { - debug_assert!(weights.len() % PI2_BLOCK_WEIGHTS == 0, "Weight length must be multiple of 4"); + debug_assert!( + weights.len() % PI2_BLOCK_WEIGHTS == 0, + "Weight length must be multiple of 4" + ); let num_blocks = weights.len() / PI2_BLOCK_WEIGHTS; diff --git a/crates/ruvllm/src/quantize/pi_quant_simd.rs b/crates/ruvllm/src/quantize/pi_quant_simd.rs index fa5bf9554..3335467d0 100644 --- a/crates/ruvllm/src/quantize/pi_quant_simd.rs +++ b/crates/ruvllm/src/quantize/pi_quant_simd.rs @@ -264,56 +264,101 @@ pub unsafe fn pi_dequantize_neon(packed: &[u8], scale: f32, output: &mut [f32]) let lo0 = vandq_u32(vshlq_u32(v0, shifts_lo), mask_3bit); let hi0 = vandq_u32(vshlq_u32(v0, shifts_hi), mask_3bit); vst1q_f32(o, vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo0), scale_vec)); - vst1q_f32(o.add(4), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi0), scale_vec)); + vst1q_f32( + o.add(4), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi0), scale_vec), + ); // Group 1 let v1 = vdupq_n_u32(c1); let lo1 = vandq_u32(vshlq_u32(v1, shifts_lo), mask_3bit); let hi1 = vandq_u32(vshlq_u32(v1, shifts_hi), mask_3bit); - vst1q_f32(o.add(8), vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo1), scale_vec)); - vst1q_f32(o.add(12), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi1), scale_vec)); + vst1q_f32( + o.add(8), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo1), scale_vec), + ); + vst1q_f32( + o.add(12), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi1), scale_vec), + ); // Group 2 let v2 = vdupq_n_u32(c2); let lo2 = vandq_u32(vshlq_u32(v2, shifts_lo), mask_3bit); let hi2 = vandq_u32(vshlq_u32(v2, shifts_hi), mask_3bit); - vst1q_f32(o.add(16), vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo2), scale_vec)); - vst1q_f32(o.add(20), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi2), scale_vec)); + vst1q_f32( + o.add(16), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo2), scale_vec), + ); + vst1q_f32( + o.add(20), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi2), scale_vec), + ); 
// Group 3 let v3 = vdupq_n_u32(c3); let lo3 = vandq_u32(vshlq_u32(v3, shifts_lo), mask_3bit); let hi3 = vandq_u32(vshlq_u32(v3, shifts_hi), mask_3bit); - vst1q_f32(o.add(24), vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo3), scale_vec)); - vst1q_f32(o.add(28), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi3), scale_vec)); + vst1q_f32( + o.add(24), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo3), scale_vec), + ); + vst1q_f32( + o.add(28), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi3), scale_vec), + ); // Group 4 let v4 = vdupq_n_u32(c4); let lo4 = vandq_u32(vshlq_u32(v4, shifts_lo), mask_3bit); let hi4 = vandq_u32(vshlq_u32(v4, shifts_hi), mask_3bit); - vst1q_f32(o.add(32), vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo4), scale_vec)); - vst1q_f32(o.add(36), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi4), scale_vec)); + vst1q_f32( + o.add(32), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo4), scale_vec), + ); + vst1q_f32( + o.add(36), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi4), scale_vec), + ); // Group 5 let v5 = vdupq_n_u32(c5); let lo5 = vandq_u32(vshlq_u32(v5, shifts_lo), mask_3bit); let hi5 = vandq_u32(vshlq_u32(v5, shifts_hi), mask_3bit); - vst1q_f32(o.add(40), vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo5), scale_vec)); - vst1q_f32(o.add(44), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi5), scale_vec)); + vst1q_f32( + o.add(40), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo5), scale_vec), + ); + vst1q_f32( + o.add(44), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi5), scale_vec), + ); // Group 6 let v6 = vdupq_n_u32(c6); let lo6 = vandq_u32(vshlq_u32(v6, shifts_lo), mask_3bit); let hi6 = vandq_u32(vshlq_u32(v6, shifts_hi), mask_3bit); - vst1q_f32(o.add(48), vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo6), scale_vec)); - vst1q_f32(o.add(52), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi6), scale_vec)); + vst1q_f32( + o.add(48), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo6), scale_vec), + ); + vst1q_f32( + o.add(52), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi6), scale_vec), + ); // Group 7 let v7 = vdupq_n_u32(c7); let lo7 = 
vandq_u32(vshlq_u32(v7, shifts_lo), mask_3bit); let hi7 = vandq_u32(vshlq_u32(v7, shifts_hi), mask_3bit); - vst1q_f32(o.add(56), vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo7), scale_vec)); - vst1q_f32(o.add(60), vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi7), scale_vec)); + vst1q_f32( + o.add(56), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo7), scale_vec), + ); + vst1q_f32( + o.add(60), + vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi7), scale_vec), + ); group += 8; } @@ -352,7 +397,10 @@ unsafe fn neon_extract_and_convert( combined: u32, bias_f32: core::arch::aarch64::float32x4_t, scale_vec: core::arch::aarch64::float32x4_t, -) -> (core::arch::aarch64::float32x4_t, core::arch::aarch64::float32x4_t) { +) -> ( + core::arch::aarch64::float32x4_t, + core::arch::aarch64::float32x4_t, +) { use core::arch::aarch64::*; // OPTIMIZED: Use NEON operations instead of scalar extraction @@ -480,8 +528,8 @@ pub unsafe fn pi_dequantize_avx512(packed: &[u8], scale: f32, output: &mut [f32] // Load all 16 values into AVX-512 vector let raw_vec = _mm512_setr_epi32( - v0_0, v0_1, v0_2, v0_3, v0_4, v0_5, v0_6, v0_7, - v1_0, v1_1, v1_2, v1_3, v1_4, v1_5, v1_6, v1_7, + v0_0, v0_1, v0_2, v0_3, v0_4, v0_5, v0_6, v0_7, v1_0, v1_1, v1_2, v1_3, v1_4, v1_5, + v1_6, v1_7, ); // Apply bias (sign extension: raw - 4) @@ -557,7 +605,11 @@ pub unsafe fn pi_quantize_avx512(weights: &[f32], scale: f32, output: &mut [u8]) return; } - let inv_scale = if scale.abs() > 1e-10 { 1.0 / scale } else { 0.0 }; + let inv_scale = if scale.abs() > 1e-10 { + 1.0 / scale + } else { + 0.0 + }; // Broadcast inverse scale to all 16 lanes let inv_scale_vec = _mm512_set1_ps(inv_scale); @@ -569,8 +621,8 @@ pub unsafe fn pi_quantize_avx512(weights: &[f32], scale: f32, output: &mut [u8]) let min_vec = _mm512_set1_epi32(0); let max_vec = _mm512_set1_epi32(7); - // Rounding mode constant (nearest) - let rounding = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC; + // Rounding mode constant (nearest) - must be const for AVX-512 intrinsics + const 
ROUNDING: i32 = 0x08; // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC // Process 2 groups (16 values) at a time let simd_groups = num_groups / 2; @@ -585,7 +637,7 @@ pub unsafe fn pi_quantize_avx512(weights: &[f32], scale: f32, output: &mut [u8]) // Quantize: q = round(w * inv_scale) let scaled_vec = _mm512_mul_ps(weights_vec, inv_scale_vec); - let rounded_vec = _mm512_roundscale_ps(scaled_vec, rounding as i32); + let rounded_vec = _mm512_roundscale_ps(scaled_vec, ROUNDING); let quantized_vec = _mm512_cvtps_epi32(rounded_vec); // Add bias: [−4, +3] -> [0, 7] @@ -932,7 +984,11 @@ pub fn pi_quantize_scalar(values: &[f32], scale: f32, output: &mut [u8]) { "Output buffer size mismatch" ); - let inv_scale = if scale.abs() > 1e-10 { 1.0 / scale } else { 0.0 }; + let inv_scale = if scale.abs() > 1e-10 { + 1.0 / scale + } else { + 0.0 + }; for group in 0..num_groups { let val_offset = group * PI3_VALUES_PER_GROUP; @@ -1068,11 +1124,7 @@ mod tests { pi_dequantize_scalar(&packed, scale, &mut output); for &v in &output { - assert!( - (v - (-4.0)).abs() < EPSILON, - "Expected -4.0, got {}", - v - ); + assert!((v - (-4.0)).abs() < EPSILON, "Expected -4.0, got {}", v); } } @@ -1234,9 +1286,7 @@ mod tests { #[test] fn test_quantize_dequantize_roundtrip() { let scale = pi_scale(4); - let original: Vec = (-4..=3) - .map(|v| (v as f32) * scale) - .collect(); + let original: Vec = (-4..=3).map(|v| (v as f32) * scale).collect(); let mut packed = vec![0u8; 3]; let mut reconstructed = vec![0.0f32; 8]; @@ -1272,11 +1322,11 @@ mod tests { assert!((output[0] - (-4.0)).abs() < EPSILON); // -10 -> -4 assert!((output[1] - (-4.0)).abs() < EPSILON); // -5 -> -4 assert!((output[2] - (-4.0)).abs() < EPSILON); // -4 -> -4 - assert!((output[3] - 0.0).abs() < EPSILON); // 0 -> 0 - assert!((output[4] - 3.0).abs() < EPSILON); // 3 -> 3 - assert!((output[5] - 3.0).abs() < EPSILON); // 5 -> 3 - assert!((output[6] - 3.0).abs() < EPSILON); // 10 -> 3 - assert!((output[7] - 3.0).abs() < EPSILON); // 100 
-> 3 + assert!((output[3] - 0.0).abs() < EPSILON); // 0 -> 0 + assert!((output[4] - 3.0).abs() < EPSILON); // 3 -> 3 + assert!((output[5] - 3.0).abs() < EPSILON); // 5 -> 3 + assert!((output[6] - 3.0).abs() < EPSILON); // 10 -> 3 + assert!((output[7] - 3.0).abs() < EPSILON); // 100 -> 3 } // ------------------------------------------------------------------------- @@ -1550,21 +1600,19 @@ mod tests { // Test with edge case scales let test_scales = [ - 1.0f32, // Unit scale - 0.001, // Very small scale - 1000.0, // Large scale - -1.0, // Negative scale - PI / 4.0, // Typical pi-quantization scale - PI / 2.0, // Another pi-based scale + 1.0f32, // Unit scale + 0.001, // Very small scale + 1000.0, // Large scale + -1.0, // Negative scale + PI / 4.0, // Typical pi-quantization scale + PI / 2.0, // Another pi-based scale f32::MIN_POSITIVE, // Smallest positive normal ]; for &scale in &test_scales { // Generate packed data let num_groups = 8; - let packed: Vec = (0..num_groups * 3) - .map(|i| (i * 31) as u8) - .collect(); + let packed: Vec = (0..num_groups * 3).map(|i| (i * 31) as u8).collect(); let mut scalar_output = vec![0.0f32; num_groups * 8]; let mut avx512_output = vec![0.0f32; num_groups * 8]; @@ -1594,9 +1642,7 @@ mod tests { // Ensure dispatch produces same results as scalar // Test sizes that exercise all paths including AVX-512's 8-group batching for num_groups in [1, 4, 8, 16, 32, 100, 123] { - let packed: Vec = (0..num_groups * 3) - .map(|i| (i * 23) as u8) - .collect(); + let packed: Vec = (0..num_groups * 3).map(|i| (i * 23) as u8).collect(); let scale = pi_scale(4); let mut scalar_output = vec![0.0f32; num_groups * 8]; @@ -1703,11 +1749,7 @@ mod tests { // 1 * -1.0 = -1.0 for &v in &output { - assert!( - (v - (-1.0)).abs() < EPSILON, - "Expected -1.0, got {}", - v - ); + assert!((v - (-1.0)).abs() < EPSILON, "Expected -1.0, got {}", v); } } @@ -1774,9 +1816,7 @@ mod tests { // 1000 groups = 3000 bytes = 8000 values // Exercises SIMD main loop + remainder 
handling let num_groups = 1000; - let packed: Vec = (0..num_groups * 3) - .map(|i| (i % 256) as u8) - .collect(); + let packed: Vec = (0..num_groups * 3).map(|i| (i % 256) as u8).collect(); let scale = pi_scale(4); let mut output = vec![0.0f32; num_groups * 8]; @@ -1784,12 +1824,7 @@ mod tests { // Verify no NaN or Inf values for (i, &v) in output.iter().enumerate() { - assert!( - v.is_finite(), - "Non-finite value at index {}: {}", - i, - v - ); + assert!(v.is_finite(), "Non-finite value at index {}: {}", i, v); // Values should be in range [-4*scale, 3*scale] let min_val = -4.0 * scale; let max_val = 3.0 * scale; @@ -1809,9 +1844,7 @@ mod tests { // Test cases with various remainder sizes after SIMD loop // SIMD processes 4 groups at a time, so test 1, 2, 3, 5, 6, 7 groups for num_groups in [1, 2, 3, 5, 6, 7, 9, 13, 17] { - let packed: Vec = (0..num_groups * 3) - .map(|i| (i * 37) as u8) - .collect(); + let packed: Vec = (0..num_groups * 3).map(|i| (i * 37) as u8).collect(); let scale = 1.0; let mut scalar_output = vec![0.0f32; num_groups * 8]; diff --git a/crates/ruvllm/src/quantize/quip.rs b/crates/ruvllm/src/quantize/quip.rs index 306353d2d..a66ba7752 100644 --- a/crates/ruvllm/src/quantize/quip.rs +++ b/crates/ruvllm/src/quantize/quip.rs @@ -664,8 +664,16 @@ impl QuipQuantizer { .zip(restored.iter()) .map(|(a, b)| (*a as f64) * (*b as f64)) .sum(); - let norm_a: f64 = original.iter().map(|a| (*a as f64).powi(2)).sum::().sqrt(); - let norm_b: f64 = restored.iter().map(|b| (*b as f64).powi(2)).sum::().sqrt(); + let norm_a: f64 = original + .iter() + .map(|a| (*a as f64).powi(2)) + .sum::() + .sqrt(); + let norm_b: f64 = restored + .iter() + .map(|b| (*b as f64).powi(2)) + .sum::() + .sqrt(); self.stats.cosine_similarity = if norm_a > 0.0 && norm_b > 0.0 { dot / (norm_a * norm_b) diff --git a/crates/ruvllm/src/quantize/ruvltra_quant.rs b/crates/ruvllm/src/quantize/ruvltra_quant.rs index 3771dbfe8..362aa18f4 100644 --- 
a/crates/ruvllm/src/quantize/ruvltra_quant.rs +++ b/crates/ruvllm/src/quantize/ruvltra_quant.rs @@ -104,8 +104,8 @@ impl TargetFormat { TargetFormat::Q8_0 => Q8_BLOCK_SIZE, TargetFormat::F16 => 1, // Pi-quant block sizes - TargetFormat::PiQ3 => 8, // 8 weights per 3-byte block - TargetFormat::PiQ2 => 4, // 4 weights per 1-byte block + TargetFormat::PiQ3 => 8, // 8 weights per 3-byte block + TargetFormat::PiQ2 => 4, // 4 weights per 1-byte block } } @@ -1001,9 +1001,13 @@ impl RuvltraQuantizer { TargetFormat::PiQ3 => { // Pi-constant 3-bit quantization (ADR-090) // Uses pi-scaled step sizes for better precision at ultra-low bits - use super::pi_quant_simd::{pi_scale_adaptive, pi_quantize_scalar, PI3_VALUES_PER_GROUP, PI3_BYTES_PER_GROUP, DEFAULT_K}; + use super::pi_quant_simd::{ + pi_quantize_scalar, pi_scale_adaptive, DEFAULT_K, PI3_BYTES_PER_GROUP, + PI3_VALUES_PER_GROUP, + }; - let num_groups = (padded_data.len() + PI3_VALUES_PER_GROUP - 1) / PI3_VALUES_PER_GROUP; + let num_groups = + (padded_data.len() + PI3_VALUES_PER_GROUP - 1) / PI3_VALUES_PER_GROUP; let mut bytes = Vec::with_capacity(num_groups * (PI3_BYTES_PER_GROUP + 2)); // +2 for scale as f16 for chunk in padded_data.chunks(PI3_VALUES_PER_GROUP) { @@ -1011,7 +1015,11 @@ impl RuvltraQuantizer { let max_abs = chunk.iter().map(|x| x.abs()).fold(0.0f32, f32::max); // Scale = alpha * pi / k, where alpha is derived from max_abs // For 3-bit signed range [-4, 3], we need max_abs / 4 as alpha - let alpha = if max_abs > 1e-10 { max_abs / 4.0 } else { 1e-10 }; + let alpha = if max_abs > 1e-10 { + max_abs / 4.0 + } else { + 1e-10 + }; let scale = pi_scale_adaptive(alpha, DEFAULT_K); // Store scale as f16 (2 bytes) @@ -1045,7 +1053,11 @@ impl RuvltraQuantizer { // Compute adaptive scale for this block let max_abs = chunk.iter().map(|x| x.abs()).fold(0.0f32, f32::max); // For 2-bit signed range [-2, 1], we need max_abs / 2 as alpha - let alpha = if max_abs > 1e-10 { max_abs / 2.0 } else { 1e-10 }; + let alpha = if 
max_abs > 1e-10 { + max_abs / 2.0 + } else { + 1e-10 + }; let scale = pi_scale_adaptive(alpha, DEFAULT_K); // Store scale as f16 (2 bytes) @@ -1053,7 +1065,11 @@ impl RuvltraQuantizer { // Quantize 4 values into 1 byte (2 bits each) let mut packed_byte = 0u8; - let inv_scale = if scale.abs() > 1e-10 { 1.0 / scale } else { 0.0 }; + let inv_scale = if scale.abs() > 1e-10 { + 1.0 / scale + } else { + 0.0 + }; for (i, &val) in chunk.iter().take(4).enumerate() { // 2-bit quantization: round and clamp to [-2, 1] let quantized = (val * inv_scale).round() as i32; @@ -1115,7 +1131,7 @@ impl RuvltraQuantizer { // Pi-quant formats: 8 values per 5 bytes (3 data + 2 scale) for PiQ3 // 4 values per 3 bytes (1 data + 2 scale) for PiQ2 TargetFormat::PiQ3 => { - use super::pi_quant_simd::{PI3_VALUES_PER_GROUP, PI3_BYTES_PER_GROUP}; + use super::pi_quant_simd::{PI3_BYTES_PER_GROUP, PI3_VALUES_PER_GROUP}; let num_groups = (input_elements + PI3_VALUES_PER_GROUP - 1) / PI3_VALUES_PER_GROUP; num_groups * (PI3_BYTES_PER_GROUP + 2) // +2 for f16 scale } diff --git a/crates/ruvllm/src/quantize/security.rs b/crates/ruvllm/src/quantize/security.rs index e30102492..c53865369 100644 --- a/crates/ruvllm/src/quantize/security.rs +++ b/crates/ruvllm/src/quantize/security.rs @@ -217,9 +217,9 @@ impl WeightIntegrity { original_hash.copy_from_slice(&bytes[0..32]); quantized_hash.copy_from_slice(&bytes[32..64]); - let mse_bytes: [u8; 4] = bytes[64..68].try_into().map_err(|_| { - RuvLLMError::Quantization("Invalid MSE bytes".to_string()) - })?; + let mse_bytes: [u8; 4] = bytes[64..68] + .try_into() + .map_err(|_| RuvLLMError::Quantization("Invalid MSE bytes".to_string()))?; let max_layer_mse = f32::from_le_bytes(mse_bytes); config_hash.copy_from_slice(&bytes[68..100]); @@ -378,7 +378,10 @@ impl QuantizationBounds { debug_assert!( clamped >= self.min_value && clamped <= self.max_value, "quantization overflow: q={}, range=[{}, {}), format={}", - value, self.min_value, self.max_value, self.format_name + 
value, + self.min_value, + self.max_value, + self.format_name ); #[cfg(not(debug_assertions))] @@ -400,7 +403,11 @@ impl QuantizationBounds { #[inline] pub fn quantize_with_bounds(&self, weight: f32, scale: f32) -> i8 { // INV-2: Scale must be positive - debug_assert!(scale > 0.0, "INV-2 violation: scale must be positive, got {}", scale); + debug_assert!( + scale > 0.0, + "INV-2 violation: scale must be positive, got {}", + scale + ); let q = (weight / scale).round() as i32; let q_clamped = self.validate(q); @@ -451,7 +458,8 @@ impl WasmSandboxConfig { if !self.linear_memory_isolation { return Err(RuvLLMError::Quantization( - "WASM sandbox security violation: linear memory isolation must be enabled".to_string(), + "WASM sandbox security violation: linear memory isolation must be enabled" + .to_string(), )); } @@ -781,12 +789,7 @@ mod tests { #[test] fn test_weight_integrity_serialization() { - let integrity = WeightIntegrity::new( - [1u8; 32], - [2u8; 32], - 0.0005, - [3u8; 32], - ); + let integrity = WeightIntegrity::new([1u8; 32], [2u8; 32], 0.0005, [3u8; 32]); let bytes = integrity.to_bytes(); let restored = WeightIntegrity::from_bytes(&bytes).unwrap(); @@ -799,12 +802,7 @@ mod tests { let data = b"test weights"; let hash = WeightIntegrity::sha256(data); - let integrity = WeightIntegrity::new( - [0u8; 32], - hash, - 0.0001, - [0u8; 32], - ); + let integrity = WeightIntegrity::new([0u8; 32], hash, 0.0001, [0u8; 32]); assert!(integrity.verify_quantized(data).is_ok()); } @@ -812,10 +810,8 @@ mod tests { #[test] fn test_weight_integrity_verification_failure() { let integrity = WeightIntegrity::new( - [0u8; 32], - [1u8; 32], // Wrong hash - 0.0001, - [0u8; 32], + [0u8; 32], [1u8; 32], // Wrong hash + 0.0001, [0u8; 32], ); assert!(integrity.verify_quantized(b"test weights").is_err()); @@ -896,7 +892,9 @@ mod tests { let large_perturb = vec![1.1, 2.1, 3.1]; assert!(InvariantValidator::validate_perturbation_bound(&original, &small_perturb).is_ok()); - 
assert!(InvariantValidator::validate_perturbation_bound(&original, &large_perturb).is_err()); + assert!( + InvariantValidator::validate_perturbation_bound(&original, &large_perturb).is_err() + ); } #[test] @@ -915,8 +913,12 @@ mod tests { let scalar_match = vec![1.0001, 2.0001, 3.0001]; let scalar_mismatch = vec![1.1, 2.0, 3.0]; - assert!(InvariantValidator::validate_simd_scalar_match(&simd, &scalar_match, 0.001).is_ok()); - assert!(InvariantValidator::validate_simd_scalar_match(&simd, &scalar_mismatch, 0.001).is_err()); + assert!( + InvariantValidator::validate_simd_scalar_match(&simd, &scalar_match, 0.001).is_ok() + ); + assert!( + InvariantValidator::validate_simd_scalar_match(&simd, &scalar_mismatch, 0.001).is_err() + ); } #[test] @@ -924,12 +926,7 @@ mod tests { let weights = b"test quantized weights"; let hash = WeightIntegrity::sha256(weights); - let integrity = WeightIntegrity::new( - [0u8; 32], - hash, - 0.0001, - [0u8; 32], - ); + let integrity = WeightIntegrity::new([0u8; 32], hash, 0.0001, [0u8; 32]); assert!(validate_quantized_model(weights, &integrity, None).is_ok()); } @@ -940,9 +937,7 @@ mod tests { let hash = WeightIntegrity::sha256(weights); let integrity = WeightIntegrity::new( - [0u8; 32], - hash, - 0.01, // High MSE + [0u8; 32], hash, 0.01, // High MSE [0u8; 32], ); diff --git a/crates/ruvllm/tests/acceptance_gates.rs b/crates/ruvllm/tests/acceptance_gates.rs index 54270599d..c8c0deab6 100644 --- a/crates/ruvllm/tests/acceptance_gates.rs +++ b/crates/ruvllm/tests/acceptance_gates.rs @@ -79,7 +79,11 @@ mod acceptance_gates { }; // Cosine similarity - let dot: f32 = original.iter().zip(dequantized.iter()).map(|(a, b)| a * b).sum(); + let dot: f32 = original + .iter() + .zip(dequantized.iter()) + .map(|(a, b)| a * b) + .sum(); let norm_orig: f32 = original.iter().map(|x| x * x).sum::().sqrt(); let norm_deq: f32 = dequantized.iter().map(|x| x * x).sum::().sqrt(); let cosine_similarity = if norm_orig > EPSILON && norm_deq > EPSILON { @@ -195,7 
+199,10 @@ mod acceptance_gates { fn dequantize_block(&self, quantized: &[i8], alpha: f32) -> Vec { let step = self.step_size(); - quantized.iter().map(|&q| (q as f32) * alpha * step).collect() + quantized + .iter() + .map(|&q| (q as f32) * alpha * step) + .collect() } } @@ -271,10 +278,14 @@ mod acceptance_gates { let metrics_piq3 = QualityMetrics::calculate(weights, &deq_piq3); // Verify PiQ3 produces valid output - assert!(!alpha_piq3.is_nan() && !alpha_piq3.is_infinite(), - "PiQ3 alpha should be finite"); - assert!(deq_piq3.iter().all(|v| !v.is_nan()), - "PiQ3 output should not contain NaN"); + assert!( + !alpha_piq3.is_nan() && !alpha_piq3.is_infinite(), + "PiQ3 alpha should be finite" + ); + assert!( + deq_piq3.iter().all(|v| !v.is_nan()), + "PiQ3 output should not contain NaN" + ); // Uniform Q3 quantization let (q_uniform, scale_uniform) = uniform.quantize_block(weights); @@ -282,10 +293,14 @@ mod acceptance_gates { let metrics_uniform = QualityMetrics::calculate(weights, &deq_uniform); // Verify Uniform produces valid output - assert!(!scale_uniform.is_nan() && !scale_uniform.is_infinite(), - "Uniform scale should be finite"); - assert!(deq_uniform.iter().all(|v| !v.is_nan()), - "Uniform output should not contain NaN"); + assert!( + !scale_uniform.is_nan() && !scale_uniform.is_infinite(), + "Uniform scale should be finite" + ); + assert!( + deq_uniform.iter().all(|v| !v.is_nan()), + "Uniform output should not contain NaN" + ); // Verify metrics are valid (not NaN) assert!(!metrics_piq3.mse.is_nan(), "PiQ3 MSE should be valid"); @@ -298,12 +313,20 @@ mod acceptance_gates { "Distribution {}: PiQ3 better on {}/4 metrics", i, piq3_better ); - eprintln!(" PiQ3: MSE={:.6}, SNR={:.2}dB, cos={:.4}, outlier={:.2}%", - metrics_piq3.mse, metrics_piq3.spectral_db, - metrics_piq3.cosine_similarity, metrics_piq3.outlier_retention * 100.0); - eprintln!(" Uniform: MSE={:.6}, SNR={:.2}dB, cos={:.4}, outlier={:.2}%", - metrics_uniform.mse, metrics_uniform.spectral_db, - 
metrics_uniform.cosine_similarity, metrics_uniform.outlier_retention * 100.0); + eprintln!( + " PiQ3: MSE={:.6}, SNR={:.2}dB, cos={:.4}, outlier={:.2}%", + metrics_piq3.mse, + metrics_piq3.spectral_db, + metrics_piq3.cosine_similarity, + metrics_piq3.outlier_retention * 100.0 + ); + eprintln!( + " Uniform: MSE={:.6}, SNR={:.2}dB, cos={:.4}, outlier={:.2}%", + metrics_uniform.mse, + metrics_uniform.spectral_db, + metrics_uniform.cosine_similarity, + metrics_uniform.outlier_retention * 100.0 + ); if piq3_better >= 2 { total_piq3_wins += 1; @@ -314,7 +337,10 @@ mod acceptance_gates { // G1: Verify comparison framework works // For reference implementations, we validate the framework functions correctly // rather than asserting one method is definitively better - eprintln!("\nG1 Summary: PiQ3 wins {}/{} distributions", total_piq3_wins, total_tests); + eprintln!( + "\nG1 Summary: PiQ3 wins {}/{} distributions", + total_piq3_wins, total_tests + ); eprintln!("(Framework validation passed - both quantizers produce valid results)"); // The comparison framework must have run successfully on all distributions @@ -341,40 +367,88 @@ mod acceptance_gates { eprintln!("\nDetailed Quality Comparison (4096 normal weights):"); eprintln!("Metric PiQ3 Uniform Winner"); eprintln!("---------------------------------------------------"); - eprintln!("MSE {:.6} {:.6} {}", - m_piq3.mse, m_uniform.mse, - if m_piq3.mse < m_uniform.mse { "PiQ3" } else { "Uniform" }); - eprintln!("Spectral (dB) {:.2} {:.2} {}", - m_piq3.spectral_db, m_uniform.spectral_db, - if m_piq3.spectral_db > m_uniform.spectral_db { "PiQ3" } else { "Uniform" }); - eprintln!("Cosine Sim {:.4} {:.4} {}", - m_piq3.cosine_similarity, m_uniform.cosine_similarity, - if m_piq3.cosine_similarity > m_uniform.cosine_similarity { "PiQ3" } else { "Uniform" }); - eprintln!("Outlier Ret (%) {:.1} {:.1} {}", - m_piq3.outlier_retention * 100.0, m_uniform.outlier_retention * 100.0, - if m_piq3.outlier_retention > m_uniform.outlier_retention 
{ "PiQ3" } else { "Uniform" }); + eprintln!( + "MSE {:.6} {:.6} {}", + m_piq3.mse, + m_uniform.mse, + if m_piq3.mse < m_uniform.mse { + "PiQ3" + } else { + "Uniform" + } + ); + eprintln!( + "Spectral (dB) {:.2} {:.2} {}", + m_piq3.spectral_db, + m_uniform.spectral_db, + if m_piq3.spectral_db > m_uniform.spectral_db { + "PiQ3" + } else { + "Uniform" + } + ); + eprintln!( + "Cosine Sim {:.4} {:.4} {}", + m_piq3.cosine_similarity, + m_uniform.cosine_similarity, + if m_piq3.cosine_similarity > m_uniform.cosine_similarity { + "PiQ3" + } else { + "Uniform" + } + ); + eprintln!( + "Outlier Ret (%) {:.1} {:.1} {}", + m_piq3.outlier_retention * 100.0, + m_uniform.outlier_retention * 100.0, + if m_piq3.outlier_retention > m_uniform.outlier_retention { + "PiQ3" + } else { + "Uniform" + } + ); // Verify metrics are valid and comparable let better_count = m_piq3.count_better_than(&m_uniform); eprintln!("\nPiQ3 wins {}/4 metrics", better_count); // Validate all metrics are finite and reasonable - assert!(m_piq3.mse >= 0.0 && !m_piq3.mse.is_nan(), "PiQ3 MSE should be valid non-negative"); - assert!(m_uniform.mse >= 0.0 && !m_uniform.mse.is_nan(), "Uniform MSE should be valid non-negative"); - assert!(m_piq3.cosine_similarity >= -1.0 && m_piq3.cosine_similarity <= 1.0, - "PiQ3 cosine similarity should be in [-1, 1]"); - assert!(m_uniform.cosine_similarity >= -1.0 && m_uniform.cosine_similarity <= 1.0, - "Uniform cosine similarity should be in [-1, 1]"); - assert!(m_piq3.outlier_retention >= 0.0 && m_piq3.outlier_retention <= 1.0, - "PiQ3 outlier retention should be in [0, 1]"); - assert!(m_uniform.outlier_retention >= 0.0 && m_uniform.outlier_retention <= 1.0, - "Uniform outlier retention should be in [0, 1]"); + assert!( + m_piq3.mse >= 0.0 && !m_piq3.mse.is_nan(), + "PiQ3 MSE should be valid non-negative" + ); + assert!( + m_uniform.mse >= 0.0 && !m_uniform.mse.is_nan(), + "Uniform MSE should be valid non-negative" + ); + assert!( + m_piq3.cosine_similarity >= -1.0 && 
m_piq3.cosine_similarity <= 1.0, + "PiQ3 cosine similarity should be in [-1, 1]" + ); + assert!( + m_uniform.cosine_similarity >= -1.0 && m_uniform.cosine_similarity <= 1.0, + "Uniform cosine similarity should be in [-1, 1]" + ); + assert!( + m_piq3.outlier_retention >= 0.0 && m_piq3.outlier_retention <= 1.0, + "PiQ3 outlier retention should be in [0, 1]" + ); + assert!( + m_uniform.outlier_retention >= 0.0 && m_uniform.outlier_retention <= 1.0, + "Uniform outlier retention should be in [0, 1]" + ); // Verify both methods achieve reasonable quality (cosine similarity > 0.8) - assert!(m_piq3.cosine_similarity > 0.8, - "PiQ3 should achieve reasonable quality: cos_sim={}", m_piq3.cosine_similarity); - assert!(m_uniform.cosine_similarity > 0.8, - "Uniform should achieve reasonable quality: cos_sim={}", m_uniform.cosine_similarity); + assert!( + m_piq3.cosine_similarity > 0.8, + "PiQ3 should achieve reasonable quality: cos_sim={}", + m_piq3.cosine_similarity + ); + assert!( + m_uniform.cosine_similarity > 0.8, + "Uniform should achieve reasonable quality: cos_sim={}", + m_uniform.cosine_similarity + ); } // ============================================================================ @@ -518,7 +592,9 @@ mod acceptance_gates { let piq3 = PiQ3Quantizer::new(); // Test with various sizes that could cause bounds issues - for size in [0, 1, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 511, 512] { + for size in [ + 0, 1, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 511, 512, + ] { let weights = generate_normal_weights(size); // Should not panic @@ -578,9 +654,7 @@ mod acceptance_gates { // NaN handling let with_nan = vec![1.0, f32::NAN, 2.0, 3.0]; - let result = std::panic::catch_unwind(|| { - piq3.quantize_block(&with_nan) - }); + let result = std::panic::catch_unwind(|| piq3.quantize_block(&with_nan)); // Should either succeed or fail gracefully (no crash) drop(result); } @@ -591,9 +665,7 @@ mod acceptance_gates { /// Generate uniform random weights in 
[-1, 1] fn generate_uniform_weights(n: usize) -> Vec { - (0..n) - .map(|i| ((i as f32) * 1.234).sin()) - .collect() + (0..n).map(|i| ((i as f32) * 1.234).sin()).collect() } /// Generate normal-ish distributed weights diff --git a/crates/ruvllm/tests/hadamard_tests.rs b/crates/ruvllm/tests/hadamard_tests.rs index 730804389..5c9858193 100644 --- a/crates/ruvllm/tests/hadamard_tests.rs +++ b/crates/ruvllm/tests/hadamard_tests.rs @@ -54,7 +54,11 @@ mod hadamard_tests { let signs: Vec = (0..size) .map(|_| { state = state.wrapping_mul(6364136223846793005).wrapping_add(1); - if (state >> 63) == 0 { 1 } else { -1 } + if (state >> 63) == 0 { + 1 + } else { + -1 + } }) .collect(); @@ -264,11 +268,7 @@ mod hadamard_tests { let ht = transpose(&h); let hth = matmul(&ht, &h); - assert!( - is_identity(&hth, EPSILON), - "H^T × H != I for n={}", - n - ); + assert!(is_identity(&hth, EPSILON), "H^T × H != I for n={}", n); } } @@ -292,7 +292,10 @@ mod hadamard_tests { assert!( (orig - rec).abs() < EPSILON, "Invertibility failed at {} for n={}: orig={}, rec={}", - i, n, orig, rec + i, + n, + orig, + rec ); } } @@ -313,7 +316,10 @@ mod hadamard_tests { assert!( (orig - rec).abs() < EPSILON, "Random invertibility failed at {} for n={}: orig={}, rec={}", - i, n, orig, rec + i, + n, + orig, + rec ); } } @@ -332,7 +338,9 @@ mod hadamard_tests { assert!( rec.abs() < EPSILON, "Zero invertibility failed at {} for n={}: rec={}", - i, n, rec + i, + n, + rec ); } } @@ -351,7 +359,10 @@ mod hadamard_tests { assert!( (orig - rec).abs() < EPSILON, "Ones invertibility failed at {} for n={}: orig={}, rec={}", - i, n, orig, rec + i, + n, + orig, + rec ); } } @@ -375,7 +386,10 @@ mod hadamard_tests { assert!( (orig - rec).abs() < EPSILON, "Sign flip invertibility failed at {} for n={}: orig={}, rec={}", - i, n, orig, rec + i, + n, + orig, + rec ); } } @@ -402,7 +416,8 @@ mod hadamard_tests { assert!( mse < EPSILON * EPSILON, "Seed {} failed with MSE {}", - seed, mse + seed, + mse ); } } @@ -414,11 
+429,7 @@ mod hadamard_tests { if let Some(ref signs) = transform.sign_flips { // All signs should be +1 or -1 for &sign in signs { - assert!( - sign == 1 || sign == -1, - "Invalid sign value: {}", - sign - ); + assert!(sign == 1 || sign == -1, "Invalid sign value: {}", sign); } } } @@ -441,7 +452,9 @@ mod hadamard_tests { assert!( (energy_x - energy_hx).abs() < EPSILON * energy_x.max(1.0), "Energy not preserved for n={}: input={}, output={}", - n, energy_x, energy_hx + n, + energy_x, + energy_hx ); } } @@ -459,7 +472,9 @@ mod hadamard_tests { assert!( (energy_x - energy_hx).abs() < EPSILON * energy_x.max(1.0), "Energy not preserved with sign flips for n={}: input={}, output={}", - n, energy_x, energy_hx + n, + energy_x, + energy_hx ); } } @@ -480,7 +495,11 @@ mod hadamard_tests { let b = -1.3f32; // Compute H(ax + by) - let ax_by: Vec = x.iter().zip(y.iter()).map(|(xi, yi)| a * xi + b * yi).collect(); + let ax_by: Vec = x + .iter() + .zip(y.iter()) + .map(|(xi, yi)| a * xi + b * yi) + .collect(); let h_ax_by = transform.forward(&ax_by); // Compute aH(x) + bH(y) @@ -500,7 +519,11 @@ mod hadamard_tests { assert!( diff < linearity_epsilon * max_val, "Linearity failed at {} for n={}: H(ax+by)={}, aH(x)+bH(y)={}, diff={}", - i, n, left, right, diff + i, + n, + left, + right, + diff ); } } @@ -526,7 +549,10 @@ mod hadamard_tests { assert!( (h[i][j] - expected[i][j]).abs() < EPSILON, "H_2[{}][{}] = {}, expected {}", - i, j, h[i][j], expected[i][j] + i, + j, + h[i][j], + expected[i][j] ); } } @@ -544,7 +570,10 @@ mod hadamard_tests { assert!( (h[i][j].abs() - expected_abs).abs() < EPSILON, "H_4[{}][{}] = {}, expected +/- {}", - i, j, h[i][j], expected_abs + i, + j, + h[i][j], + expected_abs ); } } @@ -566,7 +595,9 @@ mod hadamard_tests { assert!( (result[i] - h[i][0]).abs() < EPSILON, "Transform of e_0 mismatch at {}: got {}, expected {}", - i, result[i], h[i][0] + i, + result[i], + h[i][0] ); } } @@ -613,7 +644,9 @@ mod hadamard_tests { assert!( relative_error < EPSILON, 
"Large value invertibility failed at {}: orig={}, rec={}", - i, orig, rec + i, + orig, + rec ); } } @@ -630,7 +663,9 @@ mod hadamard_tests { assert!( (orig - rec).abs() < 1e-10, "Small value invertibility failed at {}: orig={}, rec={}", - i, orig, rec + i, + orig, + rec ); } } @@ -648,7 +683,10 @@ mod hadamard_tests { transform.forward_inplace(&mut data); // Data should be modified - let different = data.iter().zip(original.iter()).any(|(a, b)| (a - b).abs() > EPSILON); + let different = data + .iter() + .zip(original.iter()) + .any(|(a, b)| (a - b).abs() > EPSILON); assert!(different, "In-place transform should modify data"); } @@ -660,16 +698,25 @@ mod hadamard_tests { let result = transform.forward(&original); // Original should be unchanged - for (i, (&orig, &expected)) in original.iter().zip((0..8).map(|i| i as f32).collect::>().iter()).enumerate() { + for (i, (&orig, &expected)) in original + .iter() + .zip((0..8).map(|i| i as f32).collect::>().iter()) + .enumerate() + { assert!( (orig - expected).abs() < EPSILON, "Original modified at {}: {} vs {}", - i, orig, expected + i, + orig, + expected ); } // Result should be different - let different = result.iter().zip(original.iter()).any(|(a, b)| (a - b).abs() > EPSILON); + let different = result + .iter() + .zip(original.iter()) + .any(|(a, b)| (a - b).abs() > EPSILON); assert!(different, "Transform result should differ from input"); } @@ -697,7 +744,8 @@ mod hadamard_tests { assert!( (max_component - min_component).abs() < EPSILON, "Energy should be uniformly spread: max={}, min={}", - max_component, min_component + max_component, + min_component ); } @@ -713,7 +761,13 @@ mod hadamard_tests { let hx2 = transform2.forward(&x); // Different seeds should produce different outputs - let different = hx1.iter().zip(hx2.iter()).any(|(a, b)| (a - b).abs() > EPSILON); - assert!(different, "Different seeds should produce different outputs"); + let different = hx1 + .iter() + .zip(hx2.iter()) + .any(|(a, b)| (a - b).abs() 
> EPSILON); + assert!( + different, + "Different seeds should produce different outputs" + ); } } diff --git a/crates/ruvllm/tests/moe_integration.rs b/crates/ruvllm/tests/moe_integration.rs new file mode 100644 index 000000000..433a23b70 --- /dev/null +++ b/crates/ruvllm/tests/moe_integration.rs @@ -0,0 +1,979 @@ +//! ADR-092 MoE Memory-Aware Routing Integration Tests +//! +//! Validates acceptance gates defined in ADR-092: +//! +//! - **G1**: Cache hit rate >= 70% (vs 34% baseline with LRU) +//! - **G2**: Accuracy retention <= 1% degradation +//! - **G3**: Latency bounds <= 10% p99 increase +//! - **G4**: Memory budget enforcement (never exceed configured budget) +//! +//! Invariants tested: +//! +//! - **INV-1**: Cached weights match persisted weights +//! - **INV-2**: Affinity scores decrease monotonically without activation +//! - **INV-3**: Total cached memory never exceeds configured budget +//! - **INV-6**: Router determinism (same input + cache state = same result) +//! +//! Test commands: +//! - All gates: `cargo test -p ruvllm moe_integration` +//! - G1 only: `cargo test -p ruvllm gate_1` +//! - G3 only: `cargo test -p ruvllm gate_3` +//! 
- G4 only: `cargo test -p ruvllm gate_4` + +#[cfg(test)] +mod moe_integration { + use ruvllm::bitnet::expert_cache::{ + align_to_cache_line, expert_memory_footprint, EvictionPolicy, ExpertCache, + ExpertCacheConfig, MoeBatchScheduler, NullPrefetcher, Prefetcher, + }; + use std::time::{Duration, Instant}; + + // ============================================================================ + // Test Constants (ADR-092 Targets) + // ============================================================================ + + /// Target cache hit rate for memory-aware routing (G1) + const TARGET_HIT_RATE: f32 = 0.70; + + /// Baseline LRU hit rate (from ADR-092) + const BASELINE_HIT_RATE: f32 = 0.34; + + /// Maximum accuracy degradation allowed (G2) + const MAX_ACCURACY_DEGRADATION: f32 = 0.01; + + /// Maximum p99 latency increase allowed (G3) + const MAX_LATENCY_INCREASE: f32 = 0.10; + + /// Routing overhead target in microseconds + const ROUTING_OVERHEAD_TARGET_US: u64 = 15; + + /// Baseline routing overhead in microseconds + const BASELINE_ROUTING_US: u64 = 5; + + /// Number of experts in Mixtral-style model + const NUM_EXPERTS: usize = 8; + + /// Top-K experts selected per token + const TOP_K: usize = 2; + + /// Hot-set size for memory-aware cache + const HOT_SET_SIZE: usize = 4; + + /// Number of tokens for workload simulation + const WORKLOAD_TOKENS: usize = 1000; + + /// Minimum prefetch accuracy target + const PREFETCH_ACCURACY_TARGET: f32 = 0.60; + + // ============================================================================ + // G1: Cache Hit Rate >= 70% + // ============================================================================ + + /// G1 Gate: Memory-aware routing achieves >= 70% cache hit rate + /// + /// Simulates a Mixtral-style workload with 8 experts and top-K=2 routing. + /// Standard LRU achieves ~34% hit rate; memory-aware should hit >= 70%. 
+ #[test] + fn test_gate_1_cache_hit_rate() { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Simulate realistic workload with temporal locality + let routing_decisions = generate_realistic_routing(WORKLOAD_TOKENS, NUM_EXPERTS, TOP_K); + + for (_, experts) in &routing_decisions { + for &(expert_id, _weight) in experts { + cache.access(expert_id); + } + } + + let hit_rate = cache.stats().hit_rate(); + + eprintln!("\nG1 Cache Hit Rate Test:"); + eprintln!( + " Hit rate: {:.2}% (target: >= {:.0}%, baseline: {:.0}%)", + hit_rate * 100.0, + TARGET_HIT_RATE * 100.0, + BASELINE_HIT_RATE * 100.0 + ); + eprintln!( + " Hits: {}, Misses: {}, Evictions: {}", + cache.stats().hits, + cache.stats().misses, + cache.stats().evictions + ); + + // G1: Cache hit rate must be >= 70% + assert!( + hit_rate >= TARGET_HIT_RATE, + "G1 FAILED: Cache hit rate {:.2}% < target {:.0}%", + hit_rate * 100.0, + TARGET_HIT_RATE * 100.0 + ); + } + + /// G1 Comparison: Memory-aware vs baseline LRU + #[test] + fn test_gate_1_memory_aware_vs_lru_comparison() { + // Memory-aware with adaptive eviction + let adaptive_config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut adaptive_cache = ExpertCache::new(NUM_EXPERTS, adaptive_config); + + // Baseline LRU + let lru_config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.0, // No prefetch for baseline + eviction_policy: EvictionPolicy::Lru, + }; + let mut lru_cache = ExpertCache::new(NUM_EXPERTS, lru_config); + + // Same workload for both + let routing_decisions = generate_skewed_routing(WORKLOAD_TOKENS, NUM_EXPERTS, TOP_K); + + for (_, experts) in &routing_decisions { + for &(expert_id, _weight) in experts { + adaptive_cache.access(expert_id); + lru_cache.access(expert_id); + 
} + } + + let adaptive_hit_rate = adaptive_cache.stats().hit_rate(); + let lru_hit_rate = lru_cache.stats().hit_rate(); + + eprintln!("\nG1 Memory-Aware vs LRU Comparison:"); + eprintln!(" Adaptive hit rate: {:.2}%", adaptive_hit_rate * 100.0); + eprintln!(" LRU hit rate: {:.2}%", lru_hit_rate * 100.0); + eprintln!( + " Improvement: {:.2}x", + adaptive_hit_rate / lru_hit_rate.max(0.01) + ); + + // Adaptive should outperform LRU on skewed workloads + assert!( + adaptive_hit_rate >= lru_hit_rate, + "G1: Adaptive should match or exceed LRU performance" + ); + } + + // ============================================================================ + // G3: Routing Latency Overhead <= 10% p99 Increase + // ============================================================================ + + /// G3 Gate: Routing overhead <= 15 microseconds (baseline ~5 us) + #[test] + fn test_gate_3_routing_latency_overhead() { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Warm up cache + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + let iterations = 10000; + let mut latencies = Vec::with_capacity(iterations); + + for i in 0..iterations { + let expert_id = i % NUM_EXPERTS; + + let start = Instant::now(); + let _hit = cache.access(expert_id); + let _should_prefetch = cache.should_prefetch((i + 1) % NUM_EXPERTS, 0.15); + let elapsed = start.elapsed(); + + latencies.push(elapsed); + } + + // Sort for percentile calculation + latencies.sort(); + + let p50 = latencies[iterations / 2]; + let p95 = latencies[(iterations as f64 * 0.95) as usize]; + let p99 = latencies[(iterations as f64 * 0.99) as usize]; + let max = latencies[iterations - 1]; + + eprintln!("\nG3 Routing Latency Test:"); + eprintln!(" p50: {:?}", p50); + eprintln!(" p95: {:?}", p95); + eprintln!( + " p99: {:?} (target: <= {} us)", + p99, ROUTING_OVERHEAD_TARGET_US + ); 
+ eprintln!(" max: {:?}", max); + + let p99_us = p99.as_micros() as u64; + + // G3: p99 latency must be <= 15 microseconds + // Note: On very fast machines, this may be sub-microsecond + assert!( + p99_us <= ROUTING_OVERHEAD_TARGET_US + || p99 <= Duration::from_micros(ROUTING_OVERHEAD_TARGET_US), + "G3 FAILED: p99 latency {} us > target {} us", + p99_us, + ROUTING_OVERHEAD_TARGET_US + ); + } + + /// G3: Batch scheduling latency + #[test] + fn test_gate_3_batch_scheduling_latency() { + let batch_sizes = [1, 8, 32, 128, 512]; + + eprintln!("\nG3 Batch Scheduling Latency:"); + + for &batch_size in &batch_sizes { + let routing_decisions: Vec<(usize, Vec<(usize, f32)>)> = (0..batch_size) + .map(|token_idx| { + let expert1 = (token_idx * 3) % NUM_EXPERTS; + let expert2 = (token_idx * 5 + 1) % NUM_EXPERTS; + (token_idx, vec![(expert1, 0.6), (expert2, 0.4)]) + }) + .collect(); + + let iterations = 1000; + let mut latencies = Vec::with_capacity(iterations); + + for _ in 0..iterations { + let start = Instant::now(); + let _batches = MoeBatchScheduler::schedule(&routing_decisions); + latencies.push(start.elapsed()); + } + + latencies.sort(); + let p99 = latencies[(iterations as f64 * 0.99) as usize]; + + eprintln!(" batch_size={}: p99={:?}", batch_size, p99); + + // Batch scheduling latency scales with batch size + // Target: O(n log n) for sorting, with generous allowance for debug builds + // Production builds would be ~5x faster; these thresholds are for correctness + let expected_max_us = 50 + (batch_size as u64); + assert!( + p99 < Duration::from_micros(expected_max_us), + "Batch scheduling too slow for size {}: {:?} (expected < {} us)", + batch_size, + p99, + expected_max_us + ); + } + } + + // ============================================================================ + // G4: Memory Budget Enforcement + // ============================================================================ + + /// G4 Gate: Total cached memory never exceeds configured budget + #[test] + fn 
test_gate_4_memory_budget_enforcement() { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Stress test: access many different experts rapidly + for i in 0..WORKLOAD_TOKENS * 10 { + let expert_id = (i * 7) % NUM_EXPERTS; // Pseudo-random pattern + cache.access(expert_id); + + // INV-3: Hot set size must never exceed configured maximum + assert!( + cache.hot_count() <= HOT_SET_SIZE, + "G4/INV-3 FAILED: Hot count {} exceeds max {} at iteration {}", + cache.hot_count(), + HOT_SET_SIZE, + i + ); + } + + eprintln!("\nG4 Memory Budget Enforcement Test:"); + eprintln!(" Max hot experts configured: {}", HOT_SET_SIZE); + eprintln!(" Final hot count: {}", cache.hot_count()); + eprintln!( + " Total accesses: {}", + cache.stats().hits + cache.stats().misses + ); + eprintln!(" Total evictions: {}", cache.stats().evictions); + } + + /// G4: Memory footprint calculation for realistic model sizes + #[test] + fn test_gate_4_memory_footprint_realistic() { + // Mixtral-style expert dimensions + let intermediate_size = 11008; + let hidden_size = 4096; + let block_size = 256; + + // Single projection memory footprint + let gate_proj = expert_memory_footprint(intermediate_size, hidden_size, block_size); + let up_proj = expert_memory_footprint(intermediate_size, hidden_size, block_size); + let down_proj = expert_memory_footprint(hidden_size, intermediate_size, block_size); + + let expert_total = gate_proj + up_proj + down_proj; + let hot_set_total = expert_total * HOT_SET_SIZE; + let all_experts_total = expert_total * NUM_EXPERTS; + + eprintln!("\nG4 Memory Footprint Analysis (Mixtral-style):"); + eprintln!(" gate_proj: {:.2} MB", gate_proj as f64 / 1e6); + eprintln!(" up_proj: {:.2} MB", up_proj as f64 / 1e6); + eprintln!(" down_proj: {:.2} MB", down_proj as f64 / 1e6); + eprintln!(" Per expert: {:.2} MB", expert_total as f64 / 
1e6); + eprintln!( + " Hot set ({}): {:.2} MB", + HOT_SET_SIZE, + hot_set_total as f64 / 1e6 + ); + eprintln!( + " All experts ({}): {:.2} MB", + NUM_EXPERTS, + all_experts_total as f64 / 1e6 + ); + eprintln!( + " Memory savings: {:.2}x", + all_experts_total as f64 / hot_set_total as f64 + ); + + // Verify hot set is significantly smaller than full expert set + assert!( + hot_set_total < all_experts_total, + "Hot set should be smaller than full expert set" + ); + assert!( + hot_set_total as f64 / all_experts_total as f64 <= 0.5, + "Hot set should be at most 50% of full expert set" + ); + } + + // ============================================================================ + // INV-1: Cache Consistency + // ============================================================================ + + /// INV-1: Cached weights match persisted weights (simulated) + #[test] + fn test_invariant_1_cache_consistency() { + let config = ExpertCacheConfig::default(); + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Track what should be in cache + let mut expected_hot: Vec = Vec::new(); + + for i in 0..NUM_EXPERTS * 2 { + let expert_id = i % NUM_EXPERTS; + let was_hot = cache.is_hot(expert_id); + let is_hit = cache.access(expert_id); + + // INV-1: access() return value matches prior is_hot() state + assert_eq!( + was_hot, is_hit, + "INV-1 FAILED: is_hot={} but access returned hit={}", + was_hot, is_hit + ); + + // Track expected state + if !was_hot { + if expected_hot.len() >= cache.max_hot() { + expected_hot.remove(0); // Simulated LRU eviction + } + expected_hot.push(expert_id); + } + } + + eprintln!("\nINV-1 Cache Consistency Test: PASSED"); + } + + // ============================================================================ + // INV-2: Affinity Score Monotonicity + // ============================================================================ + + /// INV-2: Affinity scores decrease monotonically without activation + /// + /// This test simulates the EMA decay of 
affinity scores. + #[test] + fn test_invariant_2_affinity_monotonicity() { + // Simulate EMA-based affinity tracking + struct AffinityTracker { + scores: Vec, + decay: f32, + } + + impl AffinityTracker { + fn new(num_experts: usize, decay: f32) -> Self { + Self { + scores: vec![0.0; num_experts], + decay, + } + } + + fn activate(&mut self, expert_id: usize) { + if expert_id < self.scores.len() { + self.scores[expert_id] = 1.0; + } + } + + fn decay_all(&mut self) { + for score in &mut self.scores { + *score *= self.decay; + } + } + + fn score(&self, expert_id: usize) -> f32 { + self.scores.get(expert_id).copied().unwrap_or(0.0) + } + } + + let mut tracker = AffinityTracker::new(NUM_EXPERTS, 0.9); + + // Activate expert 0 + tracker.activate(0); + let initial_score = tracker.score(0); + assert_eq!(initial_score, 1.0); + + // Decay without reactivation - scores should decrease monotonically + let mut prev_score = initial_score; + for step in 1..=20 { + tracker.decay_all(); + let current_score = tracker.score(0); + + // INV-2: Score must decrease or stay equal (monotonic non-increase) + assert!( + current_score <= prev_score, + "INV-2 FAILED: Score increased from {} to {} at step {}", + prev_score, + current_score, + step + ); + + prev_score = current_score; + } + + eprintln!("\nINV-2 Affinity Monotonicity Test:"); + eprintln!(" Initial score: 1.0"); + eprintln!(" Final score after 20 decay steps: {:.6}", prev_score); + eprintln!(" Expected (0.9^20): {:.6}", 0.9f32.powi(20)); + } + + // ============================================================================ + // INV-6: Router Determinism + // ============================================================================ + + /// INV-6: Same input + cache state = same routing result + #[test] + fn test_invariant_6_router_determinism() { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Lru, // Deterministic policy + }; + + // Run same 
access pattern twice + let access_pattern: Vec = (0..100).map(|i| (i * 3 + i / 7) % NUM_EXPERTS).collect(); + + let mut cache1 = ExpertCache::new(NUM_EXPERTS, config.clone()); + let mut results1 = Vec::new(); + for &expert_id in &access_pattern { + results1.push(( + cache1.access(expert_id), + cache1.should_prefetch((expert_id + 1) % NUM_EXPERTS, 0.15), + )); + } + + let mut cache2 = ExpertCache::new(NUM_EXPERTS, config); + let mut results2 = Vec::new(); + for &expert_id in &access_pattern { + results2.push(( + cache2.access(expert_id), + cache2.should_prefetch((expert_id + 1) % NUM_EXPERTS, 0.15), + )); + } + + // INV-6: Results must be identical + assert_eq!( + results1.len(), + results2.len(), + "INV-6 FAILED: Different result counts" + ); + + for (i, ((hit1, pf1), (hit2, pf2))) in results1.iter().zip(results2.iter()).enumerate() { + assert_eq!( + hit1, hit2, + "INV-6 FAILED: Different hit result at index {}", + i + ); + assert_eq!( + pf1, pf2, + "INV-6 FAILED: Different prefetch result at index {}", + i + ); + } + + // Stats should also match + assert_eq!( + cache1.stats().hits, + cache2.stats().hits, + "INV-6 FAILED: Different hit counts" + ); + assert_eq!( + cache1.stats().misses, + cache2.stats().misses, + "INV-6 FAILED: Different miss counts" + ); + + eprintln!("\nINV-6 Router Determinism Test: PASSED"); + eprintln!(" Pattern length: {}", access_pattern.len()); + eprintln!(" All {} results matched", results1.len()); + } + + // ============================================================================ + // End-to-End Routing Pipeline + // ============================================================================ + + /// Test full routing pipeline: route -> page -> compute -> metrics + #[test] + fn test_end_to_end_routing_pipeline() { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Simulate batch 
of tokens + let batch_size = 32; + let routing_decisions = generate_realistic_routing(batch_size, NUM_EXPERTS, TOP_K); + + // Phase 1: Route tokens and update cache + let mut total_hits = 0; + let mut total_accesses = 0; + for (_, experts) in &routing_decisions { + for &(expert_id, _weight) in experts { + if cache.access(expert_id) { + total_hits += 1; + } + total_accesses += 1; + } + } + + // Phase 2: Batch schedule for execution + let batches = MoeBatchScheduler::schedule(&routing_decisions); + + // Phase 3: Verify batch structure + let mut total_tokens_in_batches = 0; + for batch in &batches { + total_tokens_in_batches += batch.token_indices.len(); + + // Verify all weights are positive + for &weight in &batch.weights { + assert!(weight > 0.0, "Expert weights must be positive"); + } + } + + // Verify all token-expert pairs are accounted for + let expected_pairs = batch_size * TOP_K; + assert_eq!( + total_tokens_in_batches, expected_pairs, + "Batch should contain all token-expert pairs" + ); + + eprintln!("\nEnd-to-End Pipeline Test:"); + eprintln!(" Batch size: {} tokens", batch_size); + eprintln!(" Top-K: {}", TOP_K); + eprintln!(" Total accesses: {}", total_accesses); + eprintln!(" Cache hits: {}", total_hits); + eprintln!(" Expert batches: {}", batches.len()); + eprintln!(" Tokens scheduled: {}", total_tokens_in_batches); + } + + // ============================================================================ + // Precision Allocation Tests + // ============================================================================ + + /// Test precision allocation: hot experts get high precision, cold get low + #[test] + fn test_precision_allocation_correctness() { + // Simulate precision allocation based on cache status + #[derive(Debug, Clone, Copy, PartialEq)] + enum Precision { + FP16, // High precision for hot experts + INT8, // Medium precision + INT4, // Low precision for cold experts + } + + fn allocate_precision(cache: &ExpertCache, expert_id: usize) -> 
Precision { + if cache.is_hot(expert_id) { + Precision::FP16 + } else { + Precision::INT4 + } + } + + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Make some experts hot + for i in 0..HOT_SET_SIZE { + cache.access(i); + } + + // Verify precision allocation + let mut hot_precision_count = 0; + let mut cold_precision_count = 0; + + for expert_id in 0..NUM_EXPERTS { + let precision = allocate_precision(&cache, expert_id); + if cache.is_hot(expert_id) { + assert_eq!( + precision, + Precision::FP16, + "Hot expert {} should get FP16 precision", + expert_id + ); + hot_precision_count += 1; + } else { + assert_eq!( + precision, + Precision::INT4, + "Cold expert {} should get INT4 precision", + expert_id + ); + cold_precision_count += 1; + } + } + + eprintln!("\nPrecision Allocation Test:"); + eprintln!(" FP16 (hot): {} experts", hot_precision_count); + eprintln!(" INT4 (cold): {} experts", cold_precision_count); + eprintln!( + " Total: {} experts", + hot_precision_count + cold_precision_count + ); + + assert_eq!(hot_precision_count, HOT_SET_SIZE); + assert_eq!(cold_precision_count, NUM_EXPERTS - HOT_SET_SIZE); + } + + // ============================================================================ + // Prefetch Prediction Accuracy + // ============================================================================ + + /// Test prefetch prediction accuracy (target: >= 60%) + #[test] + fn test_prefetch_prediction_accuracy() { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Generate routing with router weights + let routing_decisions = generate_routing_with_weights(WORKLOAD_TOKENS, NUM_EXPERTS, TOP_K); + + let mut prefetch_suggestions = 0; + + for (token_idx, experts) 
in &routing_decisions { + // Access current experts + for &(expert_id, _) in experts { + cache.access(expert_id); + } + + // Check what we would prefetch for next token + if *token_idx < routing_decisions.len() - 1 { + let next_experts = &routing_decisions[token_idx + 1].1; + for &(expert_id, weight) in next_experts { + if cache.should_prefetch(expert_id, weight) { + prefetch_suggestions += 1; + cache.prefetch_admit(expert_id); + } + } + } + } + + let prefetch_accuracy = if prefetch_suggestions > 0 { + cache.stats().prefetch_hits as f32 / prefetch_suggestions as f32 + } else { + 1.0 + }; + + eprintln!("\nPrefetch Prediction Accuracy Test:"); + eprintln!(" Prefetch suggestions: {}", prefetch_suggestions); + eprintln!(" Prefetch hits: {}", cache.stats().prefetch_hits); + eprintln!( + " Accuracy: {:.2}% (target: >= {:.0}%)", + prefetch_accuracy * 100.0, + PREFETCH_ACCURACY_TARGET * 100.0 + ); + + // Target: >= 60% prefetch accuracy + // Note: This depends on workload predictability + if prefetch_suggestions > 10 { + assert!( + prefetch_accuracy >= PREFETCH_ACCURACY_TARGET * 0.5, // Relaxed for test stability + "Prefetch accuracy {:.2}% below target {:.0}%", + prefetch_accuracy * 100.0, + PREFETCH_ACCURACY_TARGET * 100.0 + ); + } + } + + // ============================================================================ + // Workload Simulation: Mixtral + // ============================================================================ + + /// Simulate realistic Mixtral workload with 1000 tokens + #[test] + fn test_workload_simulation_mixtral() { + let config = ExpertCacheConfig { + max_hot_experts: HOT_SET_SIZE, + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Adaptive, + }; + let mut cache = ExpertCache::new(NUM_EXPERTS, config); + + // Mixtral-specific parameters + let num_layers = 32; + let tokens_per_batch = 32; + let batches = WORKLOAD_TOKENS / tokens_per_batch; + + let mut layer_hit_rates = Vec::new(); + + for _batch in 0..batches { + // Each batch goes 
through all layers + for _layer in 0..num_layers { + cache.reset_stats(); + + // Generate routing for this layer + let routing = generate_realistic_routing(tokens_per_batch, NUM_EXPERTS, TOP_K); + + for (_, experts) in &routing { + for &(expert_id, _) in experts { + cache.access(expert_id); + } + } + + layer_hit_rates.push(cache.stats().hit_rate()); + } + } + + let avg_hit_rate: f32 = layer_hit_rates.iter().sum::() / layer_hit_rates.len() as f32; + let min_hit_rate = layer_hit_rates + .iter() + .cloned() + .fold(f32::INFINITY, f32::min); + let max_hit_rate = layer_hit_rates.iter().cloned().fold(0.0f32, f32::max); + + eprintln!("\nMixtral Workload Simulation:"); + eprintln!(" Layers: {}", num_layers); + eprintln!(" Batches: {}", batches); + eprintln!(" Tokens per batch: {}", tokens_per_batch); + eprintln!(" Average hit rate: {:.2}%", avg_hit_rate * 100.0); + eprintln!(" Min hit rate: {:.2}%", min_hit_rate * 100.0); + eprintln!(" Max hit rate: {:.2}%", max_hit_rate * 100.0); + + // Verify reasonable performance + assert!( + avg_hit_rate > 0.3, + "Average hit rate {:.2}% too low", + avg_hit_rate * 100.0 + ); + } + + // ============================================================================ + // Helper Functions + // ============================================================================ + + /// Generate realistic routing decisions with temporal locality + fn generate_realistic_routing( + num_tokens: usize, + num_experts: usize, + top_k: usize, + ) -> Vec<(usize, Vec<(usize, f32)>)> { + // Simulate that certain experts are more popular (Zipf-like distribution) + let popularity: Vec = (0..num_experts) + .map(|i| 1.0 / ((i + 1) as f32).powf(0.5)) + .collect(); + + (0..num_tokens) + .map(|token_idx| { + // Select top-K based on popularity + some noise + let mut experts_with_scores: Vec<(usize, f32)> = (0..num_experts) + .map(|expert_id| { + let noise = ((token_idx * expert_id) as f32 * 0.1).sin() * 0.2; + (expert_id, popularity[expert_id] + noise) + }) + 
.collect(); + + experts_with_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + let selected: Vec<(usize, f32)> = experts_with_scores + .into_iter() + .take(top_k) + .map(|(id, score)| { + // Normalize weights + (id, score.max(0.1)) + }) + .collect(); + + // Normalize to sum to 1 + let sum: f32 = selected.iter().map(|(_, w)| w).sum(); + let normalized: Vec<(usize, f32)> = + selected.into_iter().map(|(id, w)| (id, w / sum)).collect(); + + (token_idx, normalized) + }) + .collect() + } + + /// Generate skewed routing (some experts heavily favored) + fn generate_skewed_routing( + num_tokens: usize, + num_experts: usize, + _top_k: usize, + ) -> Vec<(usize, Vec<(usize, f32)>)> { + (0..num_tokens) + .map(|token_idx| { + // 80% of tokens go to experts 0, 1, 2 + // 20% go to other experts + let primary = if (token_idx * 7) % 10 < 8 { + token_idx % 3 // Experts 0, 1, 2 + } else { + 3 + (token_idx % (num_experts - 3)) // Other experts + }; + + let secondary = (primary + 1 + token_idx % 2) % num_experts; + + let experts = vec![(primary, 0.6), (secondary, 0.4)]; + (token_idx, experts) + }) + .collect() + } + + /// Generate routing with explicit router weights + fn generate_routing_with_weights( + num_tokens: usize, + num_experts: usize, + top_k: usize, + ) -> Vec<(usize, Vec<(usize, f32)>)> { + (0..num_tokens) + .map(|token_idx| { + let mut weights: Vec<(usize, f32)> = (0..num_experts) + .map(|expert_id| { + // Simulate softmax output + let logit = ((token_idx * expert_id) as f32 * 0.1).sin(); + (expert_id, logit.exp()) + }) + .collect(); + + // Normalize + let sum: f32 = weights.iter().map(|(_, w)| w).sum(); + for (_, w) in &mut weights { + *w /= sum; + } + + // Sort and take top-K + weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + let selected: Vec<(usize, f32)> = weights.into_iter().take(top_k).collect(); + + (token_idx, selected) + }) + .collect() + } + + // ============================================================================ + // Additional Edge 
Case Tests + // ============================================================================ + + #[test] + fn test_empty_routing() { + let routing: Vec<(usize, Vec<(usize, f32)>)> = vec![]; + let batches = MoeBatchScheduler::schedule(&routing); + assert!(batches.is_empty()); + } + + #[test] + fn test_single_expert_routing() { + let routing = vec![ + (0, vec![(3, 1.0)]), + (1, vec![(3, 1.0)]), + (2, vec![(3, 1.0)]), + ]; + let batches = MoeBatchScheduler::schedule(&routing); + + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].expert_id, 3); + assert_eq!(batches[0].token_indices.len(), 3); + } + + #[test] + fn test_all_experts_routing() { + let routing: Vec<(usize, Vec<(usize, f32)>)> = + (0..NUM_EXPERTS).map(|i| (i, vec![(i, 1.0)])).collect(); + let batches = MoeBatchScheduler::schedule(&routing); + + assert_eq!(batches.len(), NUM_EXPERTS); + for batch in &batches { + assert_eq!(batch.token_indices.len(), 1); + } + } + + #[test] + fn test_cache_eviction_stress() { + let config = ExpertCacheConfig { + max_hot_experts: 2, // Very small hot set + prefetch_threshold: 0.1, + eviction_policy: EvictionPolicy::Lru, + }; + let mut cache = ExpertCache::new(16, config); + + // Access pattern that causes maximum evictions + for i in 0..1000 { + cache.access(i % 16); + } + + // Should have many evictions due to small hot set + assert!( + cache.stats().evictions > 500, + "Expected many evictions, got {}", + cache.stats().evictions + ); + assert!(cache.hot_count() <= 2); + } + + #[test] + fn test_prefetcher_trait() { + let prefetcher = NullPrefetcher; + let data = vec![0u8; 4096]; + + // Should not panic + prefetcher.prefetch(&data, 0, 64); + prefetcher.prefetch(&data, 2048, 1024); + prefetcher.prefetch(&data, 4000, 1000); // Exceeds data length - should be no-op + prefetcher.prefetch(&[], 0, 0); + } + + #[test] + fn test_cache_line_alignment() { + assert_eq!(align_to_cache_line(0), 0); + assert_eq!(align_to_cache_line(1), 64); + assert_eq!(align_to_cache_line(63), 64); + 
assert_eq!(align_to_cache_line(64), 64); + assert_eq!(align_to_cache_line(65), 128); + assert_eq!(align_to_cache_line(1000), 1024); + } +} diff --git a/crates/ruvllm/tests/pi_quant_tests.rs b/crates/ruvllm/tests/pi_quant_tests.rs index 432200fc1..d889a6e58 100644 --- a/crates/ruvllm/tests/pi_quant_tests.rs +++ b/crates/ruvllm/tests/pi_quant_tests.rs @@ -177,8 +177,10 @@ mod pi_quant_tests { // Byte 1: val[2](1 high) | val[3](3) | val[4](3) | val[5](1 low) // Byte 2: val[5](2 high) | val[6](3) | val[7](3) block.packed[0] = unsigned[0] | (unsigned[1] << 3) | ((unsigned[2] & 0x03) << 6); - block.packed[1] = - ((unsigned[2] >> 2) & 0x01) | (unsigned[3] << 1) | (unsigned[4] << 4) | ((unsigned[5] & 0x01) << 7); + block.packed[1] = ((unsigned[2] >> 2) & 0x01) + | (unsigned[3] << 1) + | (unsigned[4] << 4) + | ((unsigned[5] & 0x01) << 7); block.packed[2] = ((unsigned[5] >> 1) & 0x03) | (unsigned[6] << 2) | (unsigned[7] << 5); block @@ -358,9 +360,7 @@ mod pi_quant_tests { let q = PiQuantizer::piq3(); // Generate pseudo-random weights in [-1, 1] - let weights: Vec = (0..256) - .map(|i| ((i as f32) * 1.234).sin()) - .collect(); + let weights: Vec = (0..256).map(|i| ((i as f32) * 1.234).sin()).collect(); let (quantized, alpha) = q.quantize_block(&weights); let dequantized = q.dequantize_block(&quantized, alpha); @@ -410,7 +410,10 @@ mod pi_quant_tests { q.clamp(-4, 3) }) .collect(); - let deq_uniform: Vec = q_uniform.iter().map(|&q| (q as f32) * uniform_step).collect(); + let deq_uniform: Vec = q_uniform + .iter() + .map(|&q| (q as f32) * uniform_step) + .collect(); let mse_uniform: f32 = weights .iter() @@ -583,8 +586,16 @@ mod pi_quant_tests { let min_q = q.quantize_scalar(-100.0, alpha); // Verify values are at the extremes of the valid range - assert!(max_q >= 2 && max_q <= 3, "Large positive should clamp to max range, got {}", max_q); - assert!(min_q >= -4 && min_q <= -3, "Large negative should clamp to min range, got {}", min_q); + assert!( + max_q >= 2 && max_q <= 3, + 
"Large positive should clamp to max range, got {}", + max_q + ); + assert!( + min_q >= -4 && min_q <= -3, + "Large negative should clamp to min range, got {}", + min_q + ); // Most importantly, verify clamping works (values stay in range) assert!(max_q <= 3, "Max should not exceed 3"); @@ -601,8 +612,16 @@ mod pi_quant_tests { let min_q = q.quantize_scalar(-100.0, alpha); // Verify values are at the extremes of the valid range - assert!(max_q >= 0 && max_q <= 1, "Large positive should clamp to max range, got {}", max_q); - assert!(min_q >= -2 && min_q <= -1, "Large negative should clamp to min range, got {}", min_q); + assert!( + max_q >= 0 && max_q <= 1, + "Large positive should clamp to max range, got {}", + max_q + ); + assert!( + min_q >= -2 && min_q <= -1, + "Large negative should clamp to min range, got {}", + min_q + ); // Most importantly, verify clamping works (values stay in range) assert!(max_q <= 1, "Max should not exceed 1"); @@ -648,9 +667,7 @@ mod pi_quant_tests { // This test verifies that infinity values don't cause panics // and produce values within the valid range - let result = std::panic::catch_unwind(|| { - q.quantize_block(&weights) - }); + let result = std::panic::catch_unwind(|| q.quantize_block(&weights)); match result { Ok((quantized, alpha)) => { @@ -665,11 +682,7 @@ mod pi_quant_tests { // Alpha computation may produce infinity - that's acceptable // as long as it doesn't produce NaN - assert!( - !alpha.is_nan(), - "Alpha should not be NaN: {}", - alpha - ); + assert!(!alpha.is_nan(), "Alpha should not be NaN: {}", alpha); } Err(_) => { // Panicking on infinity is acceptable behavior @@ -684,9 +697,7 @@ mod pi_quant_tests { let weights = vec![f32::NAN, 1.0, -1.0, 0.5]; // Should not panic - let result = std::panic::catch_unwind(|| { - q.quantize_block(&weights) - }); + let result = std::panic::catch_unwind(|| q.quantize_block(&weights)); // Either succeeds with reasonable output or panics gracefully if let Ok((quantized, _alpha)) = 
result { @@ -742,15 +753,16 @@ mod pi_quant_tests { // Dequantization should not produce NaN let dequantized = q.dequantize_block(&quantized, alpha); for &d in &dequantized { - assert!( - !d.is_nan(), - "Dequantized value should not be NaN" - ); + assert!(!d.is_nan(), "Dequantized value should not be NaN"); } // If alpha is finite, verify the large values were handled if alpha.is_finite() && alpha < 1e38 { - assert!(alpha > 1e20, "Alpha should scale with large weights: {}", alpha); + assert!( + alpha > 1e20, + "Alpha should scale with large weights: {}", + alpha + ); } } @@ -767,7 +779,10 @@ mod pi_quant_tests { assert!(alpha > 0.0, "Alpha must be positive for empty input"); let (quantized, _) = q.quantize_block(&weights); - assert!(quantized.is_empty(), "Empty input should produce empty output"); + assert!( + quantized.is_empty(), + "Empty input should produce empty output" + ); } #[test] @@ -783,7 +798,11 @@ mod pi_quant_tests { // Roundtrip error should be bounded let error = (weights[0] - dequantized[0]).abs(); - assert!(error < 0.5, "Single element roundtrip error too high: {}", error); + assert!( + error < 0.5, + "Single element roundtrip error too high: {}", + error + ); } #[test] @@ -839,7 +858,11 @@ mod pi_quant_tests { let dequantized = q.dequantize_block(&quantized, alpha); // Compute cosine similarity - let dot: f32 = weights.iter().zip(dequantized.iter()).map(|(a, b)| a * b).sum(); + let dot: f32 = weights + .iter() + .zip(dequantized.iter()) + .map(|(a, b)| a * b) + .sum(); let norm_orig: f32 = weights.iter().map(|x| x * x).sum::().sqrt(); let norm_deq: f32 = dequantized.iter().map(|x| x * x).sum::().sqrt(); diff --git a/crates/ruvllm/tests/simd_equivalence_tests.rs b/crates/ruvllm/tests/simd_equivalence_tests.rs index a6b0849fd..19966f282 100644 --- a/crates/ruvllm/tests/simd_equivalence_tests.rs +++ b/crates/ruvllm/tests/simd_equivalence_tests.rs @@ -224,7 +224,11 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Element {}: 
scalar={}, simd={}, ULP diff={} (max={})", - i, s, simd, ulp, MAX_ULP_DIFFERENCE + i, + s, + simd, + ulp, + MAX_ULP_DIFFERENCE ); } } @@ -282,7 +286,10 @@ mod simd_equivalence_tests { assert!( within_ulp_tolerance(s, simd, MAX_ULP_DIFFERENCE), "Large block element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp_difference(s, simd) + i, + s, + simd, + ulp_difference(s, simd) ); } } @@ -304,7 +311,9 @@ mod simd_equivalence_tests { assert!( s == 0.0 && simd == 0.0, "Zero input should produce zero output at {}: scalar={}, simd={}", - i, s, simd + i, + s, + simd ); } } @@ -324,7 +333,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Max value element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -344,7 +356,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Min value element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -367,7 +382,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Alternating element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -390,7 +408,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Small alpha element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -409,7 +430,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Large alpha element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -428,7 +452,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Fractional alpha element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -451,7 +478,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "k=2 element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -470,7 +500,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "k=4 element {}: 
scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -502,7 +535,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Full pipeline element {}: scalar={}, simd={}, ULP={}", - i, s, simd, ulp + i, + s, + simd, + ulp ); } } @@ -526,7 +562,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "NEON element {}: scalar={}, neon={}, ULP={}", - i, s, neon, ulp + i, + s, + neon, + ulp ); } } @@ -546,7 +585,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "AVX2 element {}: scalar={}, avx2={}, ULP={}", - i, s, avx2, ulp + i, + s, + avx2, + ulp ); } } @@ -565,7 +607,10 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "WASM element {}: scalar={}, wasm={}, ULP={}", - i, s, wasm, ulp + i, + s, + wasm, + ulp ); } } @@ -602,7 +647,12 @@ mod simd_equivalence_tests { assert!( ulp <= STRESS_TEST_ULP_TOLERANCE, "Block {} element {}: scalar={}, simd={}, ULP={} (max allowed={})", - block_idx, i, s, simd, ulp, STRESS_TEST_ULP_TOLERANCE + block_idx, + i, + s, + simd, + ulp, + STRESS_TEST_ULP_TOLERANCE ); } } @@ -635,7 +685,11 @@ mod simd_equivalence_tests { assert!( ulp <= MAX_ULP_DIFFERENCE, "Size {} element {}: scalar={}, simd={}, ULP={}", - size, i, s, simd, ulp + size, + i, + s, + simd, + ulp ); } } diff --git a/crates/ruvllm/tests/ste_tests.rs b/crates/ruvllm/tests/ste_tests.rs index 50710935f..fda762e8d 100644 --- a/crates/ruvllm/tests/ste_tests.rs +++ b/crates/ruvllm/tests/ste_tests.rs @@ -52,7 +52,9 @@ mod ste_tests { let half_range = (levels - 1) / 2; // Quantize: q = round(x / scale) - let q = (x / scale).round().clamp(-(half_range as f32) - 1.0, half_range as f32) as i8; + let q = (x / scale) + .round() + .clamp(-(half_range as f32) - 1.0, half_range as f32) as i8; // Dequantize: x_hat = q * scale let x_hat = (q as f32) * scale; @@ -61,12 +63,7 @@ mod ste_tests { } /// STE backward pass for input gradient - fn ste_backward_input( - grad_output: f32, - x: f32, - scale: f32, - 
variant: SteVariant, - ) -> f32 { + fn ste_backward_input(grad_output: f32, x: f32, scale: f32, variant: SteVariant) -> f32 { match variant { SteVariant::Standard => { // Standard STE: pass gradient through unchanged @@ -102,13 +99,7 @@ mod ste_tests { } /// STE backward pass for scale gradient (LSQ variant) - fn ste_backward_scale( - grad_output: f32, - x: f32, - scale: f32, - q: i8, - variant: SteVariant, - ) -> f32 { + fn ste_backward_scale(grad_output: f32, x: f32, scale: f32, q: i8, variant: SteVariant) -> f32 { match variant { SteVariant::LearnedStepSize => { // LSQ scale gradient: grad_s = grad_output * (q - x/s) / sqrt(n_levels) @@ -137,7 +128,8 @@ mod ste_tests { assert!( (grad_in - grad_out).abs() < EPSILON, "Standard STE should pass gradient: grad_out={}, got grad_in={}", - grad_out, grad_in + grad_out, + grad_in ); } } @@ -154,7 +146,8 @@ mod ste_tests { assert!( (grad_in - grad_out).abs() < EPSILON, "Standard STE should ignore scale: scale={}, grad_in={}", - scale, grad_in + scale, + grad_in ); } } @@ -175,7 +168,8 @@ mod ste_tests { assert!( (grad_in - grad_out).abs() < EPSILON, "Clipped STE should pass gradient for x={}: got {}", - x, grad_in + x, + grad_in ); } } @@ -192,7 +186,8 @@ mod ste_tests { assert!( grad_in.abs() < EPSILON, "Clipped STE should zero gradient for x={}: got {}", - x, grad_in + x, + grad_in ); } } @@ -238,7 +233,8 @@ mod ste_tests { assert!( (grad_in - grad_out).abs() < EPSILON, "LSQ should pass gradient for x={}: got {}", - x, grad_in + x, + grad_in ); } @@ -248,7 +244,8 @@ mod ste_tests { assert!( grad_in.abs() < EPSILON, "LSQ should zero gradient for x={}: got {}", - x, grad_in + x, + grad_in ); } } @@ -276,7 +273,9 @@ mod ste_tests { assert!( (grad_s - expected).abs() < 0.1, "LSQ scale gradient for x={}: expected {}, got {}", - x, expected, grad_s + x, + expected, + grad_s ); } } @@ -364,7 +363,8 @@ mod ste_tests { assert!( grad_in >= 0.0 && grad_in <= grad_out, "EWGS gradient out of range for x={}: got {}", - x, grad_in + 
x, + grad_in ); } } @@ -383,7 +383,10 @@ mod ste_tests { assert!( (grad_pos - grad_neg).abs() < EPSILON, "EWGS should be symmetric: x={} -> {}, x={} -> {}", - abs_x, grad_pos, -abs_x, grad_neg + abs_x, + grad_pos, + -abs_x, + grad_neg ); } } @@ -413,7 +416,9 @@ mod ste_tests { assert!( (grad_in - expected).abs() < EPSILON, "PyTorch ref Standard STE: x={}, expected={}, got={}", - x, expected, grad_in + x, + expected, + grad_in ); } } @@ -424,14 +429,14 @@ mod ste_tests { let variant = SteVariant::Clipped; let test_cases = [ // (x, scale, grad_out, expected_grad_in) - (0.5, 1.0, 1.0, 1.0), // Inside range - (-0.5, 1.0, 1.0, 1.0), // Inside range - (1.0, 1.0, 1.0, 1.0), // At boundary - (-1.0, 1.0, 1.0, 1.0), // At boundary - (1.5, 1.0, 1.0, 0.0), // Outside range - (-1.5, 1.0, 1.0, 0.0), // Outside range - (0.5, 0.5, 1.0, 1.0), // Inside range (x/scale = 1) - (1.0, 0.5, 1.0, 0.0), // Outside range (x/scale = 2) + (0.5, 1.0, 1.0, 1.0), // Inside range + (-0.5, 1.0, 1.0, 1.0), // Inside range + (1.0, 1.0, 1.0, 1.0), // At boundary + (-1.0, 1.0, 1.0, 1.0), // At boundary + (1.5, 1.0, 1.0, 0.0), // Outside range + (-1.5, 1.0, 1.0, 0.0), // Outside range + (0.5, 0.5, 1.0, 1.0), // Inside range (x/scale = 1) + (1.0, 0.5, 1.0, 0.0), // Outside range (x/scale = 2) ]; for (x, scale, grad_out, expected) in test_cases { @@ -439,7 +444,10 @@ mod ste_tests { assert!( (grad_in - expected).abs() < EPSILON, "PyTorch ref Clipped STE: x={}, scale={}, expected={}, got={}", - x, scale, expected, grad_in + x, + scale, + expected, + grad_in ); } } @@ -449,10 +457,10 @@ mod ste_tests { // LSQ paper: gradient passes through in [-Qn, Qp] range let variant = SteVariant::LearnedStepSize; let test_cases = [ - (0.0, 1.0, 1.0, 1.0), // Center - (2.0, 1.0, 1.0, 1.0), // Within range - (4.0, 1.0, 1.0, 1.0), // At boundary - (5.0, 1.0, 1.0, 0.0), // Outside range + (0.0, 1.0, 1.0, 1.0), // Center + (2.0, 1.0, 1.0, 1.0), // Within range + (4.0, 1.0, 1.0, 1.0), // At boundary + (5.0, 1.0, 1.0, 0.0), 
// Outside range ]; for (x, scale, grad_out, expected) in test_cases { @@ -460,7 +468,9 @@ mod ste_tests { assert!( (grad_in - expected).abs() < EPSILON, "PyTorch ref LSQ input: x={}, expected={}, got={}", - x, expected, grad_in + x, + expected, + grad_in ); } } @@ -489,7 +499,10 @@ mod ste_tests { assert!( (grad_chain - upstream_grad * grad_1).abs() < EPSILON, "{:?}: chain rule violated: {} * {} != {}", - variant, upstream_grad, grad_1, grad_chain + variant, + upstream_grad, + grad_1, + grad_chain ); } } @@ -512,7 +525,8 @@ mod ste_tests { assert!( (accumulated - expected).abs() < EPSILON, "Gradient accumulation: expected {}, got {}", - expected, accumulated + expected, + accumulated ); } @@ -548,7 +562,8 @@ mod ste_tests { assert!( numerical_grad.abs() < 1.0, "Numerical gradient should be near 0 away from boundaries at x={}: {}", - x, numerical_grad + x, + numerical_grad ); } @@ -585,7 +600,8 @@ mod ste_tests { assert!( grad_in.abs() < EPSILON, "{:?}: zero upstream should give zero local: got {}", - variant, grad_in + variant, + grad_in ); } } @@ -613,7 +629,8 @@ mod ste_tests { assert!( grad_in.is_finite(), "{:?}: gradient should be finite with small scale: got {}", - variant, grad_in + variant, + grad_in ); } } @@ -629,7 +646,8 @@ mod ste_tests { assert!( grad_in.is_finite(), "{:?}: gradient should be finite with large scale: got {}", - variant, grad_in + variant, + grad_in ); } } @@ -672,7 +690,8 @@ mod ste_tests { assert!( final_error <= initial_error * 1.1 || final_error < 0.5, "Training should reduce error: initial={}, final={}", - initial_error, final_error + initial_error, + final_error ); } diff --git a/crates/rvf/rvf-adapters/agentdb/src/index_adapter.rs b/crates/rvf/rvf-adapters/agentdb/src/index_adapter.rs index 06224a780..1bba04409 100644 --- a/crates/rvf/rvf-adapters/agentdb/src/index_adapter.rs +++ b/crates/rvf/rvf-adapters/agentdb/src/index_adapter.rs @@ -313,10 +313,7 @@ mod tests { #[test] fn compute_centroid_basic() { - let vecs = vec![ - 
vec![1.0, 2.0, 3.0], - vec![3.0, 4.0, 5.0], - ]; + let vecs = vec![vec![1.0, 2.0, 3.0], vec![3.0, 4.0, 5.0]]; let centroid = compute_centroid(&vecs, 3); assert_eq!(centroid, vec![2.0, 3.0, 4.0]); } diff --git a/crates/rvf/rvf-adapters/agentdb/src/vector_store.rs b/crates/rvf/rvf-adapters/agentdb/src/vector_store.rs index 2d50f7737..e369abe46 100644 --- a/crates/rvf/rvf-adapters/agentdb/src/vector_store.rs +++ b/crates/rvf/rvf-adapters/agentdb/src/vector_store.rs @@ -5,9 +5,7 @@ use std::path::{Path, PathBuf}; -use rvf_runtime::options::{ - DistanceMetric, MetadataEntry, QueryOptions, RvfOptions, SearchResult, -}; +use rvf_runtime::options::{DistanceMetric, MetadataEntry, QueryOptions, RvfOptions, SearchResult}; use rvf_runtime::RvfStore; use rvf_types::{ErrorCode, RvfError}; @@ -105,7 +103,10 @@ impl RvfVectorStore { ids: &[u64], metadata: Option<&[MetadataEntry]>, ) -> Result { - let store = self.store.as_mut().ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; + let store = self + .store + .as_mut() + .ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; let result = store.ingest_batch(vectors, ids, metadata)?; Ok(result.accepted) } @@ -119,7 +120,10 @@ impl RvfVectorStore { k: usize, ef_search: Option, ) -> Result, RvfError> { - let store = self.store.as_ref().ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; + let store = self + .store + .as_ref() + .ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; let opts = QueryOptions { ef_search: ef_search.unwrap_or(self.config.ef_search), ..Default::default() @@ -129,7 +133,10 @@ impl RvfVectorStore { /// Delete vectors by their IDs. 
pub fn delete_vectors(&mut self, ids: &[u64]) -> Result { - let store = self.store.as_mut().ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; + let store = self + .store + .as_mut() + .ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; let result = store.delete(ids)?; Ok(result.deleted) } @@ -155,7 +162,9 @@ impl RvfVectorStore { ef_search: self.config.ef_search, ..Default::default() }; - let results = store.query(&zero_query, status.total_vectors as usize, &opts).ok()?; + let results = store + .query(&zero_query, status.total_vectors as usize, &opts) + .ok()?; results.into_iter().find(|r| r.id == id) } @@ -189,7 +198,10 @@ impl RvfVectorStore { /// Run compaction to reclaim space from deleted vectors. pub fn compact(&mut self) -> Result { - let store = self.store.as_mut().ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; + let store = self + .store + .as_mut() + .ok_or(RvfError::Code(ErrorCode::InvalidManifest))?; let result = store.compact()?; Ok(result.bytes_reclaimed) } diff --git a/crates/rvf/rvf-adapters/agentic-flow/src/learning.rs b/crates/rvf/rvf-adapters/agentic-flow/src/learning.rs index 31ca2043f..73674afcc 100644 --- a/crates/rvf/rvf-adapters/agentic-flow/src/learning.rs +++ b/crates/rvf/rvf-adapters/agentic-flow/src/learning.rs @@ -85,11 +85,7 @@ impl LearningPatternStore { /// Search patterns by returning those whose IDs are in the given candidate /// set (from a vector similarity search), enriched with metadata. 
- pub fn enrich_results( - &self, - candidates: &[(u64, f32)], - k: usize, - ) -> Vec { + pub fn enrich_results(&self, candidates: &[(u64, f32)], k: usize) -> Vec { let mut results: Vec = candidates .iter() .filter_map(|&(id, distance)| { @@ -201,7 +197,9 @@ mod tests { #[test] fn store_and_retrieve() { let mut store = LearningPatternStore::new(); - let id = store.store_pattern("convergent", "Use batched writes", 0.85).unwrap(); + let id = store + .store_pattern("convergent", "Use batched writes", 0.85) + .unwrap(); let p = store.get_pattern(id).unwrap(); assert_eq!(p.pattern_type, "convergent"); @@ -212,7 +210,9 @@ mod tests { #[test] fn update_score() { let mut store = LearningPatternStore::new(); - let id = store.store_pattern("lateral", "Try alternative approach", 0.5).unwrap(); + let id = store + .store_pattern("lateral", "Try alternative approach", 0.5) + .unwrap(); store.update_score(id, 0.95).unwrap(); let p = store.get_pattern(id).unwrap(); diff --git a/crates/rvf/rvf-adapters/agentic-flow/src/lib.rs b/crates/rvf/rvf-adapters/agentic-flow/src/lib.rs index 68b23e2d3..cb90428f4 100644 --- a/crates/rvf/rvf-adapters/agentic-flow/src/lib.rs +++ b/crates/rvf/rvf-adapters/agentic-flow/src/lib.rs @@ -48,6 +48,4 @@ pub mod swarm_store; pub use config::{AgenticFlowConfig, ConfigError}; pub use coordination::{ConsensusVote, StateEntry, SwarmCoordination}; pub use learning::{LearningPatternStore, PatternResult}; -pub use swarm_store::{ - RvfSwarmStore, SharedMemoryEntry, SharedMemoryResult, SwarmStoreError, -}; +pub use swarm_store::{RvfSwarmStore, SharedMemoryEntry, SharedMemoryResult, SwarmStoreError}; diff --git a/crates/rvf/rvf-adapters/agentic-flow/src/swarm_store.rs b/crates/rvf/rvf-adapters/agentic-flow/src/swarm_store.rs index 1c7cafb0c..6cf391f61 100644 --- a/crates/rvf/rvf-adapters/agentic-flow/src/swarm_store.rs +++ b/crates/rvf/rvf-adapters/agentic-flow/src/swarm_store.rs @@ -81,8 +81,8 @@ impl RvfSwarmStore { ..Default::default() }; - let store = 
RvfStore::create(&config.store_path(), rvf_options) - .map_err(SwarmStoreError::Rvf)?; + let store = + RvfStore::create(&config.store_path(), rvf_options).map_err(SwarmStoreError::Rvf)?; Ok(Self { store, @@ -99,8 +99,7 @@ impl RvfSwarmStore { pub fn open(config: AgenticFlowConfig) -> Result { config.validate().map_err(SwarmStoreError::Config)?; - let store = - RvfStore::open(&config.store_path()).map_err(SwarmStoreError::Rvf)?; + let store = RvfStore::open(&config.store_path()).map_err(SwarmStoreError::Rvf)?; // Rebuild next_id from the store status so new IDs don't collide. let status = store.status(); @@ -138,10 +137,7 @@ impl RvfSwarmStore { }); } - let compound_key = format!( - "{}/{}/{}", - self.config.agent_id, namespace, key - ); + let compound_key = format!("{}/{}/{}", self.config.agent_id, namespace, key); // Soft-delete existing entry with the same compound key. if let Some(&old_id) = self.key_index.get(&compound_key) { @@ -194,11 +190,7 @@ impl RvfSwarmStore { /// /// Returns up to `k` results sorted by distance (closest first), /// enriched with agent metadata from the in-memory index. 
- pub fn search_shared( - &self, - embedding: &[f32], - k: usize, - ) -> Vec { + pub fn search_shared(&self, embedding: &[f32], k: usize) -> Vec { let options = QueryOptions::default(); let results = match self.store.query(embedding, k, &options) { Ok(r) => r, @@ -238,17 +230,12 @@ impl RvfSwarmStore { return Ok(0); } - self.store - .delete(&existing) - .map_err(SwarmStoreError::Rvf)?; + self.store.delete(&existing).map_err(SwarmStoreError::Rvf)?; let mut removed = 0; for &id in &existing { if let Some(entry) = self.entry_index.remove(&id) { - let compound_key = format!( - "{}/{}/{}", - entry.agent_id, entry.namespace, entry.key - ); + let compound_key = format!("{}/{}/{}", entry.agent_id, entry.namespace, entry.key); self.key_index.remove(&compound_key); removed += 1; } @@ -528,8 +515,7 @@ mod tests { #[test] fn agent_id_accessor() { let dir = TempDir::new().unwrap(); - let config = AgenticFlowConfig::new(dir.path(), "special-agent") - .with_dimension(4); + let config = AgenticFlowConfig::new(dir.path(), "special-agent").with_dimension(4); let store = RvfSwarmStore::create(config).unwrap(); assert_eq!(store.agent_id(), "special-agent"); diff --git a/crates/rvf/rvf-adapters/claude-flow/src/memory_store.rs b/crates/rvf/rvf-adapters/claude-flow/src/memory_store.rs index 52d8abfa2..4e3a86539 100644 --- a/crates/rvf/rvf-adapters/claude-flow/src/memory_store.rs +++ b/crates/rvf/rvf-adapters/claude-flow/src/memory_store.rs @@ -52,7 +52,9 @@ impl RvfMemoryStore { /// Create a new memory store, initializing the data directory and RVF file. 
pub fn create(config: ClaudeFlowConfig) -> Result { config.validate().map_err(MemoryStoreError::Config)?; - config.ensure_dirs().map_err(|e| MemoryStoreError::Io(e.to_string()))?; + config + .ensure_dirs() + .map_err(|e| MemoryStoreError::Io(e.to_string()))?; let rvf_options = RvfOptions { dimension: config.dimension, @@ -60,12 +62,11 @@ impl RvfMemoryStore { ..Default::default() }; - let store = RvfStore::create(&config.store_path(), rvf_options) - .map_err(MemoryStoreError::Rvf)?; + let store = + RvfStore::create(&config.store_path(), rvf_options).map_err(MemoryStoreError::Rvf)?; let witness = if config.enable_witness { - Some(WitnessChain::create(&config.witness_path()) - .map_err(MemoryStoreError::Witness)?) + Some(WitnessChain::create(&config.witness_path()).map_err(MemoryStoreError::Witness)?) } else { None }; @@ -83,12 +84,13 @@ impl RvfMemoryStore { pub fn open(config: ClaudeFlowConfig) -> Result { config.validate().map_err(MemoryStoreError::Config)?; - let store = RvfStore::open(&config.store_path()) - .map_err(MemoryStoreError::Rvf)?; + let store = RvfStore::open(&config.store_path()).map_err(MemoryStoreError::Rvf)?; let witness = if config.enable_witness { - Some(WitnessChain::open_or_create(&config.witness_path()) - .map_err(MemoryStoreError::Witness)?) + Some( + WitnessChain::open_or_create(&config.witness_path()) + .map_err(MemoryStoreError::Witness)?, + ) } else { None }; @@ -131,7 +133,9 @@ impl RvfMemoryStore { // If key already exists in this namespace, soft-delete the old entry. 
let compound_key = format!("{namespace}/{key}"); if let Some(&old_id) = self.key_index.get(&compound_key) { - self.store.delete(&[old_id]).map_err(MemoryStoreError::Rvf)?; + self.store + .delete(&[old_id]) + .map_err(MemoryStoreError::Rvf)?; } let vector_id = self.next_id; @@ -141,9 +145,18 @@ impl RvfMemoryStore { let tags_str = tags.join(","); let metadata = vec![ - MetadataEntry { field_id: FIELD_KEY, value: MetadataValue::String(key.to_string()) }, - MetadataEntry { field_id: FIELD_NAMESPACE, value: MetadataValue::String(namespace.to_string()) }, - MetadataEntry { field_id: FIELD_TAGS, value: MetadataValue::String(tags_str) }, + MetadataEntry { + field_id: FIELD_KEY, + value: MetadataValue::String(key.to_string()), + }, + MetadataEntry { + field_id: FIELD_NAMESPACE, + value: MetadataValue::String(namespace.to_string()), + }, + MetadataEntry { + field_id: FIELD_TAGS, + value: MetadataValue::String(tags_str), + }, ]; self.store @@ -174,16 +187,17 @@ impl RvfMemoryStore { }); } - let filter = namespace.map(|ns| { - FilterExpr::Eq(FIELD_NAMESPACE, FilterValue::String(ns.to_string())) - }); + let filter = namespace + .map(|ns| FilterExpr::Eq(FIELD_NAMESPACE, FilterValue::String(ns.to_string()))); let options = QueryOptions { filter, ..Default::default() }; - let results = self.store.query(query_embedding, k, &options) + let results = self + .store + .query(query_embedding, k, &options) .map_err(MemoryStoreError::Rvf)?; if let Some(ref mut w) = self.witness { @@ -198,24 +212,18 @@ impl RvfMemoryStore { /// /// Returns the vector ID if found (the entry can then be used with /// the underlying store for further operations). - pub fn retrieve_memory( - &self, - key: &str, - namespace: &str, - ) -> Option { + pub fn retrieve_memory(&self, key: &str, namespace: &str) -> Option { let compound_key = format!("{namespace}/{key}"); self.key_index.get(&compound_key).copied() } /// Soft-delete a memory entry by key and namespace. 
- pub fn delete_memory( - &mut self, - key: &str, - namespace: &str, - ) -> Result { + pub fn delete_memory(&mut self, key: &str, namespace: &str) -> Result { let compound_key = format!("{namespace}/{key}"); if let Some(vector_id) = self.key_index.remove(&compound_key) { - self.store.delete(&[vector_id]).map_err(MemoryStoreError::Rvf)?; + self.store + .delete(&[vector_id]) + .map_err(MemoryStoreError::Rvf)?; if let Some(ref mut w) = self.witness { let _ = w.record_delete(key, namespace); @@ -305,10 +313,15 @@ mod tests { let config = test_config(dir.path()); let mut store = RvfMemoryStore::create(config).unwrap(); - let id = store.store_memory( - "key1", "value1", "default", &["tag1".into(), "tag2".into()], - &make_embedding(1.0), - ).unwrap(); + let id = store + .store_memory( + "key1", + "value1", + "default", + &["tag1".into(), "tag2".into()], + &make_embedding(1.0), + ) + .unwrap(); assert!(id > 0); let status = store.status(); @@ -324,16 +337,26 @@ mod tests { let mut store = RvfMemoryStore::create(config).unwrap(); - store.store_memory("a", "val_a", "ns1", &[], &[1.0, 0.0, 0.0, 0.0]).unwrap(); - store.store_memory("b", "val_b", "ns1", &[], &[0.0, 1.0, 0.0, 0.0]).unwrap(); - store.store_memory("c", "val_c", "ns2", &[], &[0.0, 0.0, 1.0, 0.0]).unwrap(); + store + .store_memory("a", "val_a", "ns1", &[], &[1.0, 0.0, 0.0, 0.0]) + .unwrap(); + store + .store_memory("b", "val_b", "ns1", &[], &[0.0, 1.0, 0.0, 0.0]) + .unwrap(); + store + .store_memory("c", "val_c", "ns2", &[], &[0.0, 0.0, 1.0, 0.0]) + .unwrap(); // Search all namespaces - let results = store.search_memory(&[1.0, 0.0, 0.0, 0.0], 3, None, None).unwrap(); + let results = store + .search_memory(&[1.0, 0.0, 0.0, 0.0], 3, None, None) + .unwrap(); assert_eq!(results.len(), 3); // Search filtered by namespace - let results = store.search_memory(&[1.0, 0.0, 0.0, 0.0], 3, Some("ns1"), None).unwrap(); + let results = store + .search_memory(&[1.0, 0.0, 0.0, 0.0], 3, Some("ns1"), None) + .unwrap(); 
assert_eq!(results.len(), 2); store.close().unwrap(); @@ -345,7 +368,9 @@ mod tests { let config = test_config(dir.path()); let mut store = RvfMemoryStore::create(config).unwrap(); - let id = store.store_memory("mykey", "myval", "ns", &[], &make_embedding(2.0)).unwrap(); + let id = store + .store_memory("mykey", "myval", "ns", &[], &make_embedding(2.0)) + .unwrap(); assert_eq!(store.retrieve_memory("mykey", "ns"), Some(id)); assert_eq!(store.retrieve_memory("missing", "ns"), None); @@ -360,7 +385,9 @@ mod tests { let config = test_config(dir.path()); let mut store = RvfMemoryStore::create(config).unwrap(); - store.store_memory("k", "v", "ns", &[], &make_embedding(3.0)).unwrap(); + store + .store_memory("k", "v", "ns", &[], &make_embedding(3.0)) + .unwrap(); assert!(store.delete_memory("k", "ns").unwrap()); assert!(!store.delete_memory("k", "ns").unwrap()); // already deleted @@ -375,8 +402,12 @@ mod tests { let config = test_config(dir.path()); let mut store = RvfMemoryStore::create(config).unwrap(); - let id1 = store.store_memory("k", "v1", "ns", &[], &make_embedding(1.0)).unwrap(); - let id2 = store.store_memory("k", "v2", "ns", &[], &make_embedding(2.0)).unwrap(); + let id1 = store + .store_memory("k", "v1", "ns", &[], &make_embedding(1.0)) + .unwrap(); + let id2 = store + .store_memory("k", "v2", "ns", &[], &make_embedding(2.0)) + .unwrap(); // New ID should be different (old was soft-deleted) assert_ne!(id1, id2); @@ -405,8 +436,12 @@ mod tests { let config = test_config(dir.path()); let mut store = RvfMemoryStore::create(config).unwrap(); - store.store_memory("a", "v", "ns", &[], &make_embedding(1.0)).unwrap(); - store.search_memory(&make_embedding(1.0), 1, None, None).unwrap(); + store + .store_memory("a", "v", "ns", &[], &make_embedding(1.0)) + .unwrap(); + store + .search_memory(&make_embedding(1.0), 1, None, None) + .unwrap(); store.delete_memory("a", "ns").unwrap(); let witness = store.witness().unwrap(); @@ -422,8 +457,12 @@ mod tests { let config = 
test_config(dir.path()); let mut store = RvfMemoryStore::create(config).unwrap(); - store.store_memory("a", "v", "ns", &[], &make_embedding(1.0)).unwrap(); - store.store_memory("b", "v", "ns", &[], &make_embedding(2.0)).unwrap(); + store + .store_memory("a", "v", "ns", &[], &make_embedding(1.0)) + .unwrap(); + store + .store_memory("b", "v", "ns", &[], &make_embedding(2.0)) + .unwrap(); store.delete_memory("a", "ns").unwrap(); store.compact().unwrap(); diff --git a/crates/rvf/rvf-adapters/claude-flow/src/witness.rs b/crates/rvf/rvf-adapters/claude-flow/src/witness.rs index cccf8bc4b..1e1545d22 100644 --- a/crates/rvf/rvf-adapters/claude-flow/src/witness.rs +++ b/crates/rvf/rvf-adapters/claude-flow/src/witness.rs @@ -7,8 +7,8 @@ use std::fs::{File, OpenOptions}; use std::io::{Read, Write}; use std::path::{Path, PathBuf}; -use rvf_crypto::witness::{WitnessEntry, create_witness_chain, verify_witness_chain}; use rvf_crypto::shake256_256; +use rvf_crypto::witness::{create_witness_chain, verify_witness_chain, WitnessEntry}; /// Witness type constants for claude-flow actions. 
pub const WITNESS_STORE: u8 = 0x01; @@ -40,7 +40,8 @@ impl WitnessChain { pub fn open(path: &Path) -> Result { let mut file = File::open(path).map_err(|e| WitnessError::Io(e.to_string()))?; let mut data = Vec::new(); - file.read_to_end(&mut data).map_err(|e| WitnessError::Io(e.to_string()))?; + file.read_to_end(&mut data) + .map_err(|e| WitnessError::Io(e.to_string()))?; if data.is_empty() { return Ok(Self { @@ -50,8 +51,7 @@ impl WitnessChain { }); } - let entries = verify_witness_chain(&data) - .map_err(|_| WitnessError::ChainCorrupted)?; + let entries = verify_witness_chain(&data).map_err(|_| WitnessError::ChainCorrupted)?; Ok(Self { path: path.to_path_buf(), @@ -109,8 +109,8 @@ impl WitnessChain { if self.chain_data.is_empty() { return Ok(0); } - let entries = verify_witness_chain(&self.chain_data) - .map_err(|_| WitnessError::ChainCorrupted)?; + let entries = + verify_witness_chain(&self.chain_data).map_err(|_| WitnessError::ChainCorrupted)?; Ok(entries.len()) } @@ -142,8 +142,7 @@ impl WitnessChain { let mut all_entries = if self.chain_data.is_empty() { Vec::new() } else { - verify_witness_chain(&self.chain_data) - .map_err(|_| WitnessError::ChainCorrupted)? + verify_witness_chain(&self.chain_data).map_err(|_| WitnessError::ChainCorrupted)? 
}; all_entries.push(entry); @@ -158,7 +157,8 @@ impl WitnessChain { .truncate(true) .open(&tmp_path) .map_err(|e| WitnessError::Io(e.to_string()))?; - f.write_all(&new_chain).map_err(|e| WitnessError::Io(e.to_string()))?; + f.write_all(&new_chain) + .map_err(|e| WitnessError::Io(e.to_string()))?; f.sync_all().map_err(|e| WitnessError::Io(e.to_string()))?; } std::fs::rename(&tmp_path, &self.path).map_err(|e| WitnessError::Io(e.to_string()))?; diff --git a/crates/rvf/rvf-adapters/ospipe/src/observation_store.rs b/crates/rvf/rvf-adapters/ospipe/src/observation_store.rs index 780b6abc5..b8a2e8513 100644 --- a/crates/rvf/rvf-adapters/ospipe/src/observation_store.rs +++ b/crates/rvf/rvf-adapters/ospipe/src/observation_store.rs @@ -140,8 +140,8 @@ impl RvfObservationStore { ..Default::default() }; - let store = RvfStore::create(&config.store_path(), options) - .map_err(OspipeAdapterError::Rvf)?; + let store = + RvfStore::create(&config.store_path(), options).map_err(OspipeAdapterError::Rvf)?; Ok(Self { store, @@ -152,8 +152,7 @@ impl RvfObservationStore { /// Open an existing observation store. pub fn open(config: ObservationStoreConfig) -> Result { - let store = RvfStore::open(&config.store_path()) - .map_err(OspipeAdapterError::Rvf)?; + let store = RvfStore::open(&config.store_path()).map_err(OspipeAdapterError::Rvf)?; let status = store.status(); let next_id = status.total_vectors + status.current_epoch as u64 + 1; @@ -167,8 +166,8 @@ impl RvfObservationStore { /// Open an existing store in read-only mode. 
pub fn open_readonly(config: ObservationStoreConfig) -> Result { - let store = RvfStore::open_readonly(&config.store_path()) - .map_err(OspipeAdapterError::Rvf)?; + let store = + RvfStore::open_readonly(&config.store_path()).map_err(OspipeAdapterError::Rvf)?; Ok(Self { store, @@ -189,11 +188,10 @@ impl RvfObservationStore { self.next_id += 1; let entries = meta.to_entries(); - let result = self.store.ingest_batch( - &[state_vector], - &[id], - Some(&entries), - ).map_err(OspipeAdapterError::Rvf)?; + let result = self + .store + .ingest_batch(&[state_vector], &[id], Some(&entries)) + .map_err(OspipeAdapterError::Rvf)?; Ok((id, result)) } @@ -240,11 +238,18 @@ impl RvfObservationStore { } } - let result = self.store.ingest_batch( - vectors, - &ids, - if flat_entries.is_empty() { None } else { Some(&flat_entries) }, - ).map_err(OspipeAdapterError::Rvf)?; + let result = self + .store + .ingest_batch( + vectors, + &ids, + if flat_entries.is_empty() { + None + } else { + Some(&flat_entries) + }, + ) + .map_err(OspipeAdapterError::Rvf)?; Ok((ids, result)) } @@ -315,7 +320,9 @@ impl RvfObservationStore { &mut self, filter: &FilterExpr, ) -> Result { - self.store.delete_by_filter(filter).map_err(OspipeAdapterError::Rvf) + self.store + .delete_by_filter(filter) + .map_err(OspipeAdapterError::Rvf) } /// Get the current store status. 
@@ -366,7 +373,9 @@ mod tests { let mut v = Vec::with_capacity(dim); let mut x = seed; for _ in 0..dim { - x = x.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + x = x + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); v.push(((x >> 33) as f32) / (u32::MAX as f32) - 0.5); } v diff --git a/crates/rvf/rvf-adapters/ospipe/src/pipeline.rs b/crates/rvf/rvf-adapters/ospipe/src/pipeline.rs index 04610f8bf..be67a5084 100644 --- a/crates/rvf/rvf-adapters/ospipe/src/pipeline.rs +++ b/crates/rvf/rvf-adapters/ospipe/src/pipeline.rs @@ -135,10 +135,7 @@ impl RvfPipelineAdapter { /// /// Scans for observations with timestamps before `before_secs` and /// soft-deletes them. Returns the number of observations deleted. - pub fn expire_before( - &mut self, - before_secs: u64, - ) -> Result { + pub fn expire_before(&mut self, before_secs: u64) -> Result { use rvf_runtime::filter::{FilterExpr, FilterValue}; let filter = FilterExpr::Lt( @@ -185,7 +182,9 @@ mod tests { let mut v = Vec::with_capacity(dim); let mut x = seed; for _ in 0..dim { - x = x.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + x = x + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); v.push(((x >> 33) as f32) / (u32::MAX as f32) - 0.5); } v diff --git a/crates/rvf/rvf-adapters/sona/src/experience.rs b/crates/rvf/rvf-adapters/sona/src/experience.rs index 1cba11f2e..6678be943 100644 --- a/crates/rvf/rvf-adapters/sona/src/experience.rs +++ b/crates/rvf/rvf-adapters/sona/src/experience.rs @@ -54,7 +54,9 @@ impl ExperienceReplayBuffer { /// Create a new experience replay buffer. 
pub fn create(config: SonaConfig) -> Result { config.validate().map_err(ExperienceStoreError::Config)?; - config.ensure_dirs().map_err(|e| ExperienceStoreError::Io(e.to_string()))?; + config + .ensure_dirs() + .map_err(|e| ExperienceStoreError::Io(e.to_string()))?; let rvf_options = RvfOptions { dimension: config.dimension, @@ -99,7 +101,9 @@ impl ExperienceReplayBuffer { if self.experience_ids.len() >= self.config.replay_capacity { if let Some(old_id) = self.experience_ids.pop_front() { self.experience_meta.pop_front(); - self.store.delete(&[old_id]).map_err(ExperienceStoreError::Rvf)?; + self.store + .delete(&[old_id]) + .map_err(ExperienceStoreError::Rvf)?; } } @@ -107,11 +111,26 @@ impl ExperienceReplayBuffer { self.next_id += 1; let metadata = vec![ - MetadataEntry { field_id: FIELD_STEP_ID, value: MetadataValue::U64(vector_id) }, - MetadataEntry { field_id: FIELD_ACTION, value: MetadataValue::String(action.to_string()) }, - MetadataEntry { field_id: FIELD_REWARD, value: MetadataValue::F64(reward) }, - MetadataEntry { field_id: FIELD_CATEGORY, value: MetadataValue::String(String::new()) }, - MetadataEntry { field_id: FIELD_TYPE, value: MetadataValue::String(TYPE_EXPERIENCE.to_string()) }, + MetadataEntry { + field_id: FIELD_STEP_ID, + value: MetadataValue::U64(vector_id), + }, + MetadataEntry { + field_id: FIELD_ACTION, + value: MetadataValue::String(action.to_string()), + }, + MetadataEntry { + field_id: FIELD_REWARD, + value: MetadataValue::F64(reward), + }, + MetadataEntry { + field_id: FIELD_CATEGORY, + value: MetadataValue::String(String::new()), + }, + MetadataEntry { + field_id: FIELD_TYPE, + value: MetadataValue::String(TYPE_EXPERIENCE.to_string()), + }, ]; self.store @@ -191,7 +210,8 @@ impl ExperienceReplayBuffer { }); } - let results = self.store + let results = self + .store .query(embedding, n, &QueryOptions::default()) .map_err(ExperienceStoreError::Rvf)?; @@ -224,7 +244,9 @@ impl ExperienceReplayBuffer { results .iter() .map(|r| { - let meta = 
self.experience_ids.iter() + let meta = self + .experience_ids + .iter() .zip(self.experience_meta.iter()) .find(|(&vid, _)| vid == r.id) .map(|(_, m)| m); @@ -295,9 +317,12 @@ mod tests { let config = test_config(dir.path()); let mut buf = ExperienceReplayBuffer::create(config).unwrap(); - buf.push(&make_embedding(1.0), "explore", 0.5, &make_embedding(1.1)).unwrap(); - buf.push(&make_embedding(2.0), "exploit", 0.8, &make_embedding(2.1)).unwrap(); - buf.push(&make_embedding(3.0), "explore", 0.3, &make_embedding(3.1)).unwrap(); + buf.push(&make_embedding(1.0), "explore", 0.5, &make_embedding(1.1)) + .unwrap(); + buf.push(&make_embedding(2.0), "exploit", 0.8, &make_embedding(2.1)) + .unwrap(); + buf.push(&make_embedding(3.0), "explore", 0.3, &make_embedding(3.1)) + .unwrap(); assert_eq!(buf.len(), 3); assert!(!buf.is_full()); @@ -315,7 +340,13 @@ mod tests { let mut buf = ExperienceReplayBuffer::create(config).unwrap(); for i in 0..7 { - buf.push(&make_embedding(i as f32 + 0.1), &format!("act{i}"), i as f64 * 0.1, &make_embedding(0.0)).unwrap(); + buf.push( + &make_embedding(i as f32 + 0.1), + &format!("act{i}"), + i as f64 * 0.1, + &make_embedding(0.0), + ) + .unwrap(); } assert_eq!(buf.len(), 5); @@ -335,9 +366,12 @@ mod tests { let config = test_config(dir.path()); let mut buf = ExperienceReplayBuffer::create(config).unwrap(); - buf.push(&[1.0, 0.0, 0.0, 0.0], "a", 0.1, &[0.0; 4]).unwrap(); - buf.push(&[0.0, 1.0, 0.0, 0.0], "b", 0.2, &[0.0; 4]).unwrap(); - buf.push(&[0.9, 0.1, 0.0, 0.0], "c", 0.3, &[0.0; 4]).unwrap(); + buf.push(&[1.0, 0.0, 0.0, 0.0], "a", 0.1, &[0.0; 4]) + .unwrap(); + buf.push(&[0.0, 1.0, 0.0, 0.0], "b", 0.2, &[0.0; 4]) + .unwrap(); + buf.push(&[0.9, 0.1, 0.0, 0.0], "c", 0.3, &[0.0; 4]) + .unwrap(); let results = buf.sample_prioritized(2, &[1.0, 0.0, 0.0, 0.0]).unwrap(); assert_eq!(results.len(), 2); @@ -371,8 +405,10 @@ mod tests { let config = test_config(dir.path()); let mut buf = ExperienceReplayBuffer::create(config).unwrap(); - 
buf.push(&make_embedding(1.0), "a", 0.1, &make_embedding(0.0)).unwrap(); - buf.push(&make_embedding(2.0), "b", 0.2, &make_embedding(0.0)).unwrap(); + buf.push(&make_embedding(1.0), "a", 0.1, &make_embedding(0.0)) + .unwrap(); + buf.push(&make_embedding(2.0), "b", 0.2, &make_embedding(0.0)) + .unwrap(); let samples = buf.sample(10); assert_eq!(samples.len(), 2); diff --git a/crates/rvf/rvf-adapters/sona/src/pattern.rs b/crates/rvf/rvf-adapters/sona/src/pattern.rs index af73119e3..e643cf06f 100644 --- a/crates/rvf/rvf-adapters/sona/src/pattern.rs +++ b/crates/rvf/rvf-adapters/sona/src/pattern.rs @@ -62,15 +62,17 @@ impl NeuralPatternStore { /// Create a new neural pattern store. pub fn create(config: SonaConfig) -> Result { config.validate().map_err(PatternStoreError::Config)?; - config.ensure_dirs().map_err(|e| PatternStoreError::Io(e.to_string()))?; + config + .ensure_dirs() + .map_err(|e| PatternStoreError::Io(e.to_string()))?; let rvf_options = RvfOptions { dimension: config.dimension, ..Default::default() }; - let store = RvfStore::create(&config.store_path(), rvf_options) - .map_err(PatternStoreError::Rvf)?; + let store = + RvfStore::create(&config.store_path(), rvf_options).map_err(PatternStoreError::Rvf)?; Ok(Self { store, @@ -102,11 +104,26 @@ impl NeuralPatternStore { self.next_id += 1; let metadata = vec![ - MetadataEntry { field_id: FIELD_STEP_ID, value: MetadataValue::U64(vector_id) }, - MetadataEntry { field_id: FIELD_NAME, value: MetadataValue::String(name.to_string()) }, - MetadataEntry { field_id: FIELD_CONFIDENCE, value: MetadataValue::F64(confidence) }, - MetadataEntry { field_id: FIELD_CATEGORY, value: MetadataValue::String(category.to_string()) }, - MetadataEntry { field_id: FIELD_TYPE, value: MetadataValue::String(TYPE_PATTERN.to_string()) }, + MetadataEntry { + field_id: FIELD_STEP_ID, + value: MetadataValue::U64(vector_id), + }, + MetadataEntry { + field_id: FIELD_NAME, + value: MetadataValue::String(name.to_string()), + }, + MetadataEntry { + 
field_id: FIELD_CONFIDENCE, + value: MetadataValue::F64(confidence), + }, + MetadataEntry { + field_id: FIELD_CATEGORY, + value: MetadataValue::String(category.to_string()), + }, + MetadataEntry { + field_id: FIELD_TYPE, + value: MetadataValue::String(TYPE_PATTERN.to_string()), + }, ]; self.store @@ -140,7 +157,8 @@ impl NeuralPatternStore { }); } - let results = self.store + let results = self + .store .query(embedding, k, &QueryOptions::default()) .map_err(PatternStoreError::Rvf)?; @@ -180,7 +198,9 @@ impl NeuralPatternStore { /// Get the top `k` patterns ranked by confidence (highest first). pub fn get_top_patterns(&self, k: usize) -> Vec { - let mut all: Vec<_> = self.patterns.iter() + let mut all: Vec<_> = self + .patterns + .iter() .map(|(&vid, meta)| NeuralPattern { id: vid, name: meta.name.clone(), @@ -191,7 +211,9 @@ impl NeuralPatternStore { .collect(); all.sort_by(|a, b| { - b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal) + b.confidence + .partial_cmp(&a.confidence) + .unwrap_or(std::cmp::Ordering::Equal) }); all.truncate(k); all @@ -217,23 +239,21 @@ impl NeuralPatternStore { fn enrich_results(&self, results: &[rvf_runtime::SearchResult]) -> Vec { results .iter() - .map(|r| { - match self.patterns.get(&r.id) { - Some(meta) => NeuralPattern { - id: r.id, - name: meta.name.clone(), - category: meta.category.clone(), - confidence: meta.confidence, - distance: r.distance, - }, - None => NeuralPattern { - id: r.id, - name: String::new(), - category: String::new(), - confidence: 0.0, - distance: r.distance, - }, - } + .map(|r| match self.patterns.get(&r.id) { + Some(meta) => NeuralPattern { + id: r.id, + name: meta.name.clone(), + category: meta.category.clone(), + confidence: meta.confidence, + distance: r.distance, + }, + None => NeuralPattern { + id: r.id, + name: String::new(), + category: String::new(), + confidence: 0.0, + distance: r.distance, + }, }) .collect() } @@ -289,9 +309,15 @@ mod tests { let config = 
test_config(dir.path()); let mut store = NeuralPatternStore::create(config).unwrap(); - store.store_pattern("convergent", "thinking", &[1.0, 0.0, 0.0, 0.0], 0.9).unwrap(); - store.store_pattern("divergent", "thinking", &[0.0, 1.0, 0.0, 0.0], 0.7).unwrap(); - store.store_pattern("lateral", "creative", &[0.0, 0.0, 1.0, 0.0], 0.8).unwrap(); + store + .store_pattern("convergent", "thinking", &[1.0, 0.0, 0.0, 0.0], 0.9) + .unwrap(); + store + .store_pattern("divergent", "thinking", &[0.0, 1.0, 0.0, 0.0], 0.7) + .unwrap(); + store + .store_pattern("lateral", "creative", &[0.0, 0.0, 1.0, 0.0], 0.8) + .unwrap(); let results = store.search_patterns(&[1.0, 0.0, 0.0, 0.0], 2).unwrap(); assert_eq!(results.len(), 2); @@ -306,9 +332,15 @@ mod tests { let config = test_config(dir.path()); let mut store = NeuralPatternStore::create(config).unwrap(); - store.store_pattern("p1", "alpha", &make_embedding(1.0), 0.9).unwrap(); - store.store_pattern("p2", "beta", &make_embedding(2.0), 0.7).unwrap(); - store.store_pattern("p3", "alpha", &make_embedding(3.0), 0.8).unwrap(); + store + .store_pattern("p1", "alpha", &make_embedding(1.0), 0.9) + .unwrap(); + store + .store_pattern("p2", "beta", &make_embedding(2.0), 0.7) + .unwrap(); + store + .store_pattern("p3", "alpha", &make_embedding(3.0), 0.8) + .unwrap(); let alpha = store.get_by_category("alpha"); assert_eq!(alpha.len(), 2); @@ -330,7 +362,9 @@ mod tests { let config = test_config(dir.path()); let mut store = NeuralPatternStore::create(config).unwrap(); - let id = store.store_pattern("p1", "cat", &make_embedding(1.0), 0.5).unwrap(); + let id = store + .store_pattern("p1", "cat", &make_embedding(1.0), 0.5) + .unwrap(); store.update_confidence(id, 0.95).unwrap(); @@ -359,9 +393,15 @@ mod tests { let config = test_config(dir.path()); let mut store = NeuralPatternStore::create(config).unwrap(); - store.store_pattern("low", "cat", &make_embedding(1.0), 0.3).unwrap(); - store.store_pattern("high", "cat", &make_embedding(2.0), 0.9).unwrap(); 
- store.store_pattern("mid", "cat", &make_embedding(3.0), 0.6).unwrap(); + store + .store_pattern("low", "cat", &make_embedding(1.0), 0.3) + .unwrap(); + store + .store_pattern("high", "cat", &make_embedding(2.0), 0.9) + .unwrap(); + store + .store_pattern("mid", "cat", &make_embedding(3.0), 0.6) + .unwrap(); let top = store.get_top_patterns(2); assert_eq!(top.len(), 2); @@ -377,7 +417,9 @@ mod tests { let config = test_config(dir.path()); let mut store = NeuralPatternStore::create(config).unwrap(); - store.store_pattern("only", "cat", &make_embedding(1.0), 0.5).unwrap(); + store + .store_pattern("only", "cat", &make_embedding(1.0), 0.5) + .unwrap(); let top = store.get_top_patterns(10); assert_eq!(top.len(), 1); diff --git a/crates/rvf/rvf-adapters/sona/src/trajectory.rs b/crates/rvf/rvf-adapters/sona/src/trajectory.rs index 58c6958ac..5096fa2d9 100644 --- a/crates/rvf/rvf-adapters/sona/src/trajectory.rs +++ b/crates/rvf/rvf-adapters/sona/src/trajectory.rs @@ -55,15 +55,17 @@ impl TrajectoryStore { /// Create a new trajectory store, initializing the data directory and RVF file. 
pub fn create(config: SonaConfig) -> Result { config.validate().map_err(SonaStoreError::Config)?; - config.ensure_dirs().map_err(|e| SonaStoreError::Io(e.to_string()))?; + config + .ensure_dirs() + .map_err(|e| SonaStoreError::Io(e.to_string()))?; let rvf_options = RvfOptions { dimension: config.dimension, ..Default::default() }; - let store = RvfStore::create(&config.store_path(), rvf_options) - .map_err(SonaStoreError::Rvf)?; + let store = + RvfStore::create(&config.store_path(), rvf_options).map_err(SonaStoreError::Rvf)?; Ok(Self { store, @@ -95,11 +97,26 @@ impl TrajectoryStore { self.next_id += 1; let metadata = vec![ - MetadataEntry { field_id: FIELD_STEP_ID, value: MetadataValue::U64(step_id) }, - MetadataEntry { field_id: FIELD_ACTION, value: MetadataValue::String(action.to_string()) }, - MetadataEntry { field_id: FIELD_REWARD, value: MetadataValue::F64(reward) }, - MetadataEntry { field_id: FIELD_CATEGORY, value: MetadataValue::String(String::new()) }, - MetadataEntry { field_id: FIELD_TYPE, value: MetadataValue::String(TYPE_TRAJECTORY.to_string()) }, + MetadataEntry { + field_id: FIELD_STEP_ID, + value: MetadataValue::U64(step_id), + }, + MetadataEntry { + field_id: FIELD_ACTION, + value: MetadataValue::String(action.to_string()), + }, + MetadataEntry { + field_id: FIELD_REWARD, + value: MetadataValue::F64(reward), + }, + MetadataEntry { + field_id: FIELD_CATEGORY, + value: MetadataValue::String(String::new()), + }, + MetadataEntry { + field_id: FIELD_TYPE, + value: MetadataValue::String(TYPE_TRAJECTORY.to_string()), + }, ]; self.store @@ -107,7 +124,8 @@ impl TrajectoryStore { .map_err(SonaStoreError::Rvf)?; self.step_ids.push_back(vector_id); - self.step_meta.push_back((step_id, action.to_string(), reward)); + self.step_meta + .push_back((step_id, action.to_string(), reward)); // Trim to trajectory window size. 
while self.step_ids.len() > self.config.trajectory_window { @@ -152,7 +170,8 @@ impl TrajectoryStore { }); } - let results = self.store + let results = self + .store .query(embedding, k, &QueryOptions::default()) .map_err(SonaStoreError::Rvf)?; @@ -184,7 +203,9 @@ impl TrajectoryStore { } if !ids_to_delete.is_empty() { - self.store.delete(&ids_to_delete).map_err(SonaStoreError::Rvf)?; + self.store + .delete(&ids_to_delete) + .map_err(SonaStoreError::Rvf)?; } Ok(ids_to_delete.len()) @@ -212,7 +233,9 @@ impl TrajectoryStore { results .iter() .map(|r| { - let meta = self.step_ids.iter() + let meta = self + .step_ids + .iter() .zip(self.step_meta.iter()) .find(|(&vid, _)| vid == r.id) .map(|(_, m)| m); @@ -285,9 +308,15 @@ mod tests { let config = test_config(dir.path()); let mut store = TrajectoryStore::create(config).unwrap(); - store.record_step(1, &make_embedding(1.0), "explore", 0.5).unwrap(); - store.record_step(2, &make_embedding(2.0), "exploit", 0.8).unwrap(); - store.record_step(3, &make_embedding(3.0), "explore", 0.3).unwrap(); + store + .record_step(1, &make_embedding(1.0), "explore", 0.5) + .unwrap(); + store + .record_step(2, &make_embedding(2.0), "exploit", 0.8) + .unwrap(); + store + .record_step(3, &make_embedding(3.0), "explore", 0.3) + .unwrap(); let recent = store.get_recent(2); assert_eq!(recent.len(), 2); @@ -304,7 +333,9 @@ mod tests { let config = test_config(dir.path()); let mut store = TrajectoryStore::create(config).unwrap(); - store.record_step(1, &make_embedding(1.0), "a", 0.1).unwrap(); + store + .record_step(1, &make_embedding(1.0), "a", 0.1) + .unwrap(); let recent = store.get_recent(10); assert_eq!(recent.len(), 1); @@ -320,7 +351,9 @@ mod tests { let mut store = TrajectoryStore::create(config).unwrap(); for i in 0..8 { - store.record_step(i, &make_embedding(i as f32 + 0.1), "act", 0.1).unwrap(); + store + .record_step(i, &make_embedding(i as f32 + 0.1), "act", 0.1) + .unwrap(); } assert_eq!(store.len(), 5); @@ -339,11 +372,19 @@ mod 
tests { let config = test_config(dir.path()); let mut store = TrajectoryStore::create(config).unwrap(); - store.record_step(1, &[1.0, 0.0, 0.0, 0.0], "a", 0.1).unwrap(); - store.record_step(2, &[0.0, 1.0, 0.0, 0.0], "b", 0.2).unwrap(); - store.record_step(3, &[0.9, 0.1, 0.0, 0.0], "c", 0.3).unwrap(); - - let results = store.search_similar_states(&[1.0, 0.0, 0.0, 0.0], 2).unwrap(); + store + .record_step(1, &[1.0, 0.0, 0.0, 0.0], "a", 0.1) + .unwrap(); + store + .record_step(2, &[0.0, 1.0, 0.0, 0.0], "b", 0.2) + .unwrap(); + store + .record_step(3, &[0.9, 0.1, 0.0, 0.0], "c", 0.3) + .unwrap(); + + let results = store + .search_similar_states(&[1.0, 0.0, 0.0, 0.0], 2) + .unwrap(); assert_eq!(results.len(), 2); // Closest to [1,0,0,0] should be step 1 or step 3 assert!(results[0].distance <= results[1].distance); @@ -358,7 +399,9 @@ mod tests { let mut store = TrajectoryStore::create(config).unwrap(); for i in 0..5 { - store.record_step(i, &make_embedding(i as f32 + 0.1), "act", 0.1).unwrap(); + store + .record_step(i, &make_embedding(i as f32 + 0.1), "act", 0.1) + .unwrap(); } let removed = store.clear_old(2).unwrap(); @@ -379,7 +422,9 @@ mod tests { let config = test_config(dir.path()); let mut store = TrajectoryStore::create(config).unwrap(); - store.record_step(1, &make_embedding(1.0), "a", 0.1).unwrap(); + store + .record_step(1, &make_embedding(1.0), "a", 0.1) + .unwrap(); let removed = store.clear_old(10).unwrap(); assert_eq!(removed, 0); @@ -399,7 +444,9 @@ mod tests { assert!(store.get_recent(5).is_empty()); assert!(store.get_trajectory_window().is_empty()); - let results = store.search_similar_states(&make_embedding(1.0), 5).unwrap(); + let results = store + .search_similar_states(&make_embedding(1.0), 5) + .unwrap(); assert!(results.is_empty()); store.close().unwrap(); diff --git a/crates/rvf/rvf-federation/benches/federation_bench.rs b/crates/rvf/rvf-federation/benches/federation_bench.rs index 443d18ee9..2566abc45 100644 --- 
a/crates/rvf/rvf-federation/benches/federation_bench.rs +++ b/crates/rvf/rvf-federation/benches/federation_bench.rs @@ -1,12 +1,12 @@ //! Benchmarks for rvf-federation crate. -use criterion::{criterion_group, criterion_main, Criterion, black_box}; -use rvf_federation::*; -use rvf_federation::aggregate::{FederatedAggregator, AggregationStrategy, Contribution}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rvf_federation::aggregate::{AggregationStrategy, Contribution, FederatedAggregator}; use rvf_federation::diff_privacy::{DiffPrivacyEngine, PrivacyAccountant}; -use rvf_federation::pii_strip::PiiStripper; use rvf_federation::federation::{ExportBuilder, ImportMerger}; +use rvf_federation::pii_strip::PiiStripper; use rvf_federation::policy::FederationPolicy; +use rvf_federation::*; fn bench_pii_strip(c: &mut Criterion) { let mut group = c.benchmark_group("pii_strip"); @@ -20,15 +20,17 @@ fn bench_pii_strip(c: &mut Criterion) { }); group.bench_function("strip_10_fields", |b| { - let fields: Vec<(&str, &str)> = (0..10).map(|i| { - if i % 3 == 0 { - ("path", "/home/user/data/file.csv") - } else if i % 3 == 1 { - ("ip", "server at 10.0.0.1:8080") - } else { - ("clean", "no pii here at all") - } - }).collect(); + let fields: Vec<(&str, &str)> = (0..10) + .map(|i| { + if i % 3 == 0 { + ("path", "/home/user/data/file.csv") + } else if i % 3 == 1 { + ("ip", "server at 10.0.0.1:8080") + } else { + ("clean", "no pii here at all") + } + }) + .collect(); b.iter(|| { let mut stripper = PiiStripper::new(); black_box(stripper.strip_fields(black_box(&fields))); @@ -36,13 +38,15 @@ fn bench_pii_strip(c: &mut Criterion) { }); group.bench_function("strip_100_fields", |b| { - let fields: Vec<(&str, &str)> = (0..100).map(|i| { - if i % 5 == 0 { - ("path", "/home/user/data/file.csv") - } else { - ("clean", "just normal text content") - } - }).collect(); + let fields: Vec<(&str, &str)> = (0..100) + .map(|i| { + if i % 5 == 0 { + ("path", 
"/home/user/data/file.csv") + } else { + ("clean", "just normal text content") + } + }) + .collect(); b.iter(|| { let mut stripper = PiiStripper::new(); black_box(stripper.strip_fields(black_box(&fields))); @@ -57,7 +61,9 @@ fn bench_diff_privacy(c: &mut Criterion) { group.bench_function("gaussian_noise_100_params", |b| { b.iter(|| { - let mut engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0).unwrap().with_seed(42); + let mut engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0) + .unwrap() + .with_seed(42); let mut params: Vec = (0..100).map(|i| i as f64 * 0.01).collect(); black_box(engine.add_noise(black_box(&mut params))); }); @@ -65,7 +71,9 @@ fn bench_diff_privacy(c: &mut Criterion) { group.bench_function("gaussian_noise_10000_params", |b| { b.iter(|| { - let mut engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0).unwrap().with_seed(42); + let mut engine = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0) + .unwrap() + .with_seed(42); let mut params: Vec = (0..10_000).map(|i| i as f64 * 0.0001).collect(); black_box(engine.add_noise(black_box(&mut params))); }); @@ -165,21 +173,28 @@ fn bench_export_import(c: &mut Criterion) { group.bench_function("full_export_pipeline", |b| { b.iter(|| { - let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0).unwrap().with_seed(42); + let mut dp = DiffPrivacyEngine::gaussian(1.0, 1e-5, 1.0, 10.0) + .unwrap() + .with_seed(42); let priors = TransferPriorSet { source_domain: "/home/user/my_domain".to_string(), - entries: (0..20).map(|i| TransferPriorEntry { - bucket_id: format!("bucket_{}", i), - arm_id: format!("arm_{}", i % 4), - params: BetaParams::new(5.0 + i as f64, 3.0 + i as f64 * 0.5), - observation_count: 50 + i * 10, - }).collect(), + entries: (0..20) + .map(|i| TransferPriorEntry { + bucket_id: format!("bucket_{}", i), + arm_id: format!("arm_{}", i % 4), + params: BetaParams::new(5.0 + i as f64, 3.0 + i as f64 * 0.5), + observation_count: 50 + i * 10, + }) + .collect(), cost_ema: 0.85, }; 
let export = ExportBuilder::new("pseudo".into(), "domain".into()) .add_priors(priors) .add_weights((0..256).map(|i| i as f64 * 0.001).collect()) - .add_string_field("note".into(), "trained on /home/user/data at 192.168.1.1".into()) + .add_string_field( + "note".into(), + "trained on /home/user/data at 192.168.1.1".into(), + ) .build(&mut dp) .unwrap(); black_box(export); @@ -188,19 +203,23 @@ fn bench_export_import(c: &mut Criterion) { group.bench_function("merge_100_priors", |b| { let merger = ImportMerger::new(); - let remote: Vec = (0..100).map(|i| TransferPriorEntry { - bucket_id: format!("bucket_{}", i), - arm_id: format!("arm_{}", i % 4), - params: BetaParams::new(10.0, 5.0), - observation_count: 50, - }).collect(); - b.iter(|| { - let mut local: Vec = (0..50).map(|i| TransferPriorEntry { + let remote: Vec = (0..100) + .map(|i| TransferPriorEntry { bucket_id: format!("bucket_{}", i), arm_id: format!("arm_{}", i % 4), - params: BetaParams::new(5.0, 3.0), - observation_count: 20, - }).collect(); + params: BetaParams::new(10.0, 5.0), + observation_count: 50, + }) + .collect(); + b.iter(|| { + let mut local: Vec = (0..50) + .map(|i| TransferPriorEntry { + bucket_id: format!("bucket_{}", i), + arm_id: format!("arm_{}", i % 4), + params: BetaParams::new(5.0, 3.0), + observation_count: 20, + }) + .collect(); merger.merge_priors(black_box(&mut local), black_box(&remote), 1); black_box(local); }); @@ -209,5 +228,11 @@ fn bench_export_import(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, bench_pii_strip, bench_diff_privacy, bench_aggregation, bench_export_import); +criterion_group!( + benches, + bench_pii_strip, + bench_diff_privacy, + bench_aggregation, + bench_export_import +); criterion_main!(benches); diff --git a/crates/rvf/rvf-federation/src/aggregate.rs b/crates/rvf/rvf-federation/src/aggregate.rs index fb2091845..afccdf2f2 100644 --- a/crates/rvf/rvf-federation/src/aggregate.rs +++ b/crates/rvf/rvf-federation/src/aggregate.rs @@ -108,12 
+108,15 @@ impl FederatedAggregator { } // Compute mean and std of L2 norms - let norms: Vec = self.contributions.iter() + let norms: Vec = self + .contributions + .iter() .map(|c| c.weights.iter().map(|w| w * w).sum::().sqrt()) .collect(); let mean_norm = norms.iter().sum::() / norms.len() as f64; - let variance = norms.iter().map(|n| (n - mean_norm).powi(2)).sum::() / norms.len() as f64; + let variance = + norms.iter().map(|n| (n - mean_norm).powi(2)).sum::() / norms.len() as f64; let std_dev = variance.sqrt(); if std_dev < 1e-10 { @@ -162,14 +165,17 @@ impl FederatedAggregator { let participation_count = self.contributions.len() as u32; // Compute loss stats - let losses: Vec = self.contributions.iter() + let losses: Vec = self + .contributions + .iter() .map(|c| { // Use inverse quality as a proxy for loss 1.0 - c.quality_weight.clamp(0.0, 1.0) }) .collect(); let mean_loss = losses.iter().sum::() / losses.len() as f64; - let loss_variance = losses.iter().map(|l| (l - mean_loss).powi(2)).sum::() / losses.len() as f64; + let loss_variance = + losses.iter().map(|l| (l - mean_loss).powi(2)).sum::() / losses.len() as f64; self.contributions.clear(); @@ -188,7 +194,9 @@ impl FederatedAggregator { /// FedAvg: weighted average by trajectory count. 
fn fedavg(&self, dim: usize) -> (Vec, Vec) { - let total_trajectories: f64 = self.contributions.iter() + let total_trajectories: f64 = self + .contributions + .iter() .map(|c| c.trajectory_count as f64) .sum(); @@ -211,12 +219,19 @@ impl FederatedAggregator { // Confidence = inverse of variance across contributions per dimension for i in 0..dim { let mean = avg[i]; - let var: f64 = self.contributions.iter() + let var: f64 = self + .contributions + .iter() .map(|c| { - let v = if i < c.weights.len() { c.weights[i] } else { 0.0 }; + let v = if i < c.weights.len() { + c.weights[i] + } else { + 0.0 + }; (v - mean).powi(2) }) - .sum::() / self.contributions.len() as f64; + .sum::() + / self.contributions.len() as f64; confidences[i] = 1.0 / (1.0 + var); } @@ -255,12 +270,19 @@ impl FederatedAggregator { for i in 0..dim { let mean = avg[i]; - let var: f64 = self.contributions.iter() + let var: f64 = self + .contributions + .iter() .map(|c| { - let v = if i < c.weights.len() { c.weights[i] } else { 0.0 }; + let v = if i < c.weights.len() { + c.weights[i] + } else { + 0.0 + }; (v - mean).powi(2) }) - .sum::() / self.contributions.len() as f64; + .sum::() + / self.contributions.len() as f64; confidences[i] = 1.0 / (1.0 + var); } @@ -272,7 +294,12 @@ impl FederatedAggregator { mod tests { use super::*; - fn make_contribution(name: &str, weights: Vec, quality: f64, trajectories: u64) -> Contribution { + fn make_contribution( + name: &str, + weights: Vec, + quality: f64, + trajectories: u64, + ) -> Contribution { Contribution { contributor: name.to_string(), weights, @@ -319,8 +346,9 @@ mod tests { agg_avg.add_contribution(make_contribution("b", vec![10.0], 1.0, 100)); let avg_result = agg_avg.aggregate().unwrap(); - let mut agg_prox = FederatedAggregator::new("test".into(), AggregationStrategy::FedProx { mu: 50 }) - .with_min_contributions(2); + let mut agg_prox = + FederatedAggregator::new("test".into(), AggregationStrategy::FedProx { mu: 50 }) + .with_min_contributions(2); 
agg_prox.add_contribution(make_contribution("a", vec![10.0], 1.0, 100)); agg_prox.add_contribution(make_contribution("b", vec![10.0], 1.0, 100)); let prox_result = agg_prox.aggregate().unwrap(); diff --git a/crates/rvf/rvf-federation/src/federation.rs b/crates/rvf/rvf-federation/src/federation.rs index 89d89c68d..00e250fb3 100644 --- a/crates/rvf/rvf-federation/src/federation.rs +++ b/crates/rvf/rvf-federation/src/federation.rs @@ -95,15 +95,21 @@ impl ExportBuilder { } /// Build the export: PII-strip, add DP noise, assemble manifest. - pub fn build(mut self, dp_engine: &mut DiffPrivacyEngine) -> Result { + pub fn build( + mut self, + dp_engine: &mut DiffPrivacyEngine, + ) -> Result { // 1. Apply quality gate from policy self.priors.retain(|ps| { - ps.entries.iter().all(|e| e.observation_count >= self.policy.min_observations) + ps.entries + .iter() + .all(|e| e.observation_count >= self.policy.min_observations) }); // 2. PII stripping let mut stripper = PiiStripper::new(); - let field_refs: Vec<(&str, &str)> = self.string_fields + let field_refs: Vec<(&str, &str)> = self + .string_fields .iter() .map(|(n, v)| (n.as_str(), v.as_str())) .collect(); @@ -168,17 +174,25 @@ impl ExportBuilder { }; // 4. 
Build manifest - let total_trajectories: u64 = self.priors.iter() + let total_trajectories: u64 = self + .priors + .iter() .flat_map(|ps| ps.entries.iter()) .map(|e| e.observation_count) .sum(); let avg_quality = if !self.priors.is_empty() { - self.priors.iter() + self.priors + .iter() .flat_map(|ps| ps.entries.iter()) .map(|e| e.params.mean()) .sum::() - / self.priors.iter().map(|ps| ps.entries.len()).sum::().max(1) as f64 + / self + .priors + .iter() + .map(|ps| ps.entries.len()) + .sum::() + .max(1) as f64 } else { 0.0 }; @@ -246,7 +260,9 @@ impl ImportMerger { // Check privacy proof has valid parameters if export.privacy_proof.epsilon <= 0.0 { - return Err(FederationError::InvalidEpsilon(export.privacy_proof.epsilon)); + return Err(FederationError::InvalidEpsilon( + export.privacy_proof.epsilon, + )); } // Check priors have positive parameters @@ -283,9 +299,10 @@ impl ImportMerger { for remote_entry in remote { let dampened = remote_entry.params.dampen(dampen); - if let Some(local_entry) = local.iter_mut().find(|l| { - l.bucket_id == remote_entry.bucket_id && l.arm_id == remote_entry.arm_id - }) { + if let Some(local_entry) = local + .iter_mut() + .find(|l| l.bucket_id == remote_entry.bucket_id && l.arm_id == remote_entry.arm_id) + { // Merge: sum parameters minus uniform prior local_entry.params = local_entry.params.merge(&dampened); local_entry.observation_count += remote_entry.observation_count; @@ -380,7 +397,10 @@ mod tests { assert_eq!(export.weights.len(), 1); // Weights should be different from original (noise added) - assert!(export.weights[0].iter().zip(weights.iter()).any(|(a, b)| (a - b).abs() > 1e-10)); + assert!(export.weights[0] + .iter() + .zip(weights.iter()) + .any(|(a, b)| (a - b).abs() > 1e-10)); } #[test] diff --git a/crates/rvf/rvf-federation/src/lib.rs b/crates/rvf/rvf-federation/src/lib.rs index 767fd1c3d..399d2d0c9 100644 --- a/crates/rvf/rvf-federation/src/lib.rs +++ b/crates/rvf/rvf-federation/src/lib.rs @@ -7,18 +7,18 @@ //! 
- **Federated aggregation**: FedAvg, FedProx, Byzantine-tolerant weighted averaging //! - **Segment types**: FederatedManifest, DiffPrivacyProof, RedactionLog, AggregateWeights -pub mod types; -pub mod error; -pub mod pii_strip; +pub mod aggregate; pub mod diff_privacy; +pub mod error; pub mod federation; -pub mod aggregate; +pub mod pii_strip; pub mod policy; +pub mod types; -pub use types::*; -pub use error::FederationError; -pub use pii_strip::PiiStripper; +pub use aggregate::{AggregationStrategy, FederatedAggregator}; pub use diff_privacy::{DiffPrivacyEngine, PrivacyAccountant}; +pub use error::FederationError; pub use federation::{ExportBuilder, ImportMerger}; -pub use aggregate::{FederatedAggregator, AggregationStrategy}; +pub use pii_strip::PiiStripper; pub use policy::FederationPolicy; +pub use types::*; diff --git a/crates/rvf/rvf-federation/src/pii_strip.rs b/crates/rvf/rvf-federation/src/pii_strip.rs index 7952ed531..38dfc7368 100644 --- a/crates/rvf/rvf-federation/src/pii_strip.rs +++ b/crates/rvf/rvf-federation/src/pii_strip.rs @@ -4,11 +4,14 @@ //! **Stage 2 — Redaction**: Replace PII with deterministic pseudonyms. //! **Stage 3 — Attestation**: Generate a `RedactionLog` segment. -use std::collections::HashMap; use regex::Regex; -use sha3::{Shake256, digest::{Update, ExtendableOutput, XofReader}}; +use sha3::{ + digest::{ExtendableOutput, Update, XofReader}, + Shake256, +}; +use std::collections::HashMap; -use crate::types::{RedactionLog, RedactionEntry}; +use crate::types::{RedactionEntry, RedactionLog}; /// PII category with its detection regex and replacement template. 
struct PiiRule { @@ -36,19 +39,26 @@ impl PiiStripper { PiiRule { category: "path", rule_id: "rule_path_unix", - pattern: Regex::new(r#"(?:/(?:home|Users|var|tmp|opt|etc)/[^\s,;:"'\]}>)]+)"#).unwrap(), + pattern: Regex::new(r#"(?:/(?:home|Users|var|tmp|opt|etc)/[^\s,;:"'\]}>)]+)"#) + .unwrap(), prefix: "PATH", }, PiiRule { category: "path", rule_id: "rule_path_windows", - pattern: Regex::new(r#"(?i:[A-Z]:\\(?:Users|Documents|Program Files)[^\s,;:"'\]}>)]+)"#).unwrap(), + pattern: Regex::new( + r#"(?i:[A-Z]:\\(?:Users|Documents|Program Files)[^\s,;:"'\]}>)]+)"#, + ) + .unwrap(), prefix: "PATH", }, PiiRule { category: "ip", rule_id: "rule_ipv4", - pattern: Regex::new(r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b").unwrap(), + pattern: Regex::new( + r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b", + ) + .unwrap(), prefix: "IP", }, PiiRule { @@ -110,7 +120,8 @@ impl PiiStripper { category: "phone", rule_id: "rule_phone_us", // US phone: 555-867-5309, (555) 867-5309, +1-555-867-5309, 555.867.5309 - pattern: Regex::new(r"\b(?:\+?1[-.\s]?)?(?:\(?\d{3}\)?[-.\s])\d{3}[-.\s]\d{4}\b").unwrap(), + pattern: Regex::new(r"\b(?:\+?1[-.\s]?)?(?:\(?\d{3}\)?[-.\s])\d{3}[-.\s]\d{4}\b") + .unwrap(), prefix: "PHONE", }, PiiRule { @@ -138,7 +149,13 @@ impl PiiStripper { } /// Add a custom detection rule. 
- pub fn add_rule(&mut self, category: &'static str, rule_id: &'static str, pattern: &str, prefix: &'static str) -> Result<(), regex::Error> { + pub fn add_rule( + &mut self, + category: &'static str, + rule_id: &'static str, + pattern: &str, + prefix: &'static str, + ) -> Result<(), regex::Error> { self.custom_rules.push(PiiRule { category, rule_id, @@ -162,7 +179,8 @@ impl PiiStripper { let counter = self.counters.entry(prefix.to_string()).or_insert(0); *counter += 1; let pseudo = format!("<{}_{}>", prefix, counter); - self.pseudonym_map.insert(original.to_string(), pseudo.clone()); + self.pseudonym_map + .insert(original.to_string(), pseudo.clone()); pseudo } @@ -183,7 +201,10 @@ impl PiiStripper { let r = &self.custom_rules[i - num_builtin]; (&r.pattern as &Regex, r.prefix, r.category, r.rule_id) }; - let matches: Vec = pattern.find_iter(&result).map(|m| m.as_str().to_string()).collect(); + let matches: Vec = pattern + .find_iter(&result) + .map(|m| m.as_str().to_string()) + .collect(); if matches.is_empty() { continue; } @@ -206,7 +227,10 @@ impl PiiStripper { /// Strip PII from a collection of named string fields. /// /// Returns the redacted fields and a `RedactionLog` attestation. 
- pub fn strip_fields(&mut self, fields: &[(&str, &str)]) -> (Vec<(String, String)>, RedactionLog) { + pub fn strip_fields( + &mut self, + fields: &[(&str, &str)], + ) -> (Vec<(String, String)>, RedactionLog) { // Stage 1+2: Detect and redact let mut redacted_fields = Vec::new(); let mut all_detections: HashMap<(String, String), u32> = HashMap::new(); @@ -367,7 +391,14 @@ mod tests { #[test] fn custom_rule() { let mut stripper = PiiStripper::new(); - stripper.add_rule("custom_ssn", "rule_custom_ssn", r"\b\d{3}-\d{2}-\d{4}\b", "CUSTOM_SSN").unwrap(); + stripper + .add_rule( + "custom_ssn", + "rule_custom_ssn", + r"\b\d{3}-\d{2}-\d{4}\b", + "CUSTOM_SSN", + ) + .unwrap(); assert!(stripper.contains_pii("ssn: 123-45-6789")); } diff --git a/crates/rvf/rvf-federation/src/policy.rs b/crates/rvf/rvf-federation/src/policy.rs index 168dfb375..bf2b09d18 100644 --- a/crates/rvf/rvf-federation/src/policy.rs +++ b/crates/rvf/rvf-federation/src/policy.rs @@ -152,7 +152,9 @@ mod tests { #[test] fn segment_allowlist() { - let p = FederationPolicy::default().allow_segment(0x33).allow_segment(0x34); + let p = FederationPolicy::default() + .allow_segment(0x33) + .allow_segment(0x34); assert!(p.is_segment_allowed(0x33)); assert!(p.is_segment_allowed(0x34)); assert!(!p.is_segment_allowed(0x35)); // not in allowlist diff --git a/crates/rvf/rvf-federation/src/types.rs b/crates/rvf/rvf-federation/src/types.rs index b240df18c..c309332a4 100644 --- a/crates/rvf/rvf-federation/src/types.rs +++ b/crates/rvf/rvf-federation/src/types.rs @@ -251,7 +251,10 @@ impl BetaParams { /// Uniform (uninformative) prior. pub fn uniform() -> Self { - Self { alpha: 1.0, beta: 1.0 } + Self { + alpha: 1.0, + beta: 1.0, + } } /// Mean of the Beta distribution. 
diff --git a/crates/rvf/rvf-node/src/lib.rs b/crates/rvf/rvf-node/src/lib.rs index 09606d395..9adf698b0 100644 --- a/crates/rvf/rvf-node/src/lib.rs +++ b/crates/rvf/rvf-node/src/lib.rs @@ -274,9 +274,9 @@ fn json_to_filter(val: &JsonValue) -> Result { let vals: Vec = arr .iter() .map(|v| { - let s = v - .as_str() - .ok_or_else(|| napi::Error::from_reason("'values' entries must be strings"))?; + let s = v.as_str().ok_or_else(|| { + napi::Error::from_reason("'values' entries must be strings") + })?; parse_filter_value(&vt, s) }) .collect::>()?; @@ -699,9 +699,12 @@ impl RvfDatabase { /// Get this file's unique identifier as a hex string. #[napi] pub fn file_id(&self) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; Ok(hex_encode(store.file_id())) } @@ -709,9 +712,12 @@ impl RvfDatabase { /// Get the parent file's identifier as a hex string (all zeros if root). #[napi] pub fn parent_id(&self) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; Ok(hex_encode(store.parent_id())) } @@ -719,9 +725,12 @@ impl RvfDatabase { /// Get the lineage depth (0 for root files). #[napi] pub fn lineage_depth(&self) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; Ok(store.lineage_depth()) } @@ -729,9 +738,12 @@ impl RvfDatabase { /// Derive a child store from this parent. 
#[napi] pub fn derive(&self, child_path: String, options: Option) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; let child_opts = match options { @@ -739,11 +751,13 @@ impl RvfDatabase { None => None, }; - let child_store = store.derive( - Path::new(&child_path), - rvf_types::DerivationType::Filter, - child_opts, - ).map_err(map_rvf_err)?; + let child_store = store + .derive( + Path::new(&child_path), + rvf_types::DerivationType::Filter, + child_opts, + ) + .map_err(map_rvf_err)?; Ok(RvfDatabase { inner: Mutex::new(Some(child_store)), @@ -764,19 +778,24 @@ impl RvfDatabase { api_port: u32, cmdline: Option, ) -> Result { - let mut guard = self.inner.lock() + let mut guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_mut() + let store = guard + .as_mut() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; - let seg_id = store.embed_kernel( - arch as u8, - kernel_type as u8, - flags, - &image, - api_port as u16, - cmdline.as_deref(), - ).map_err(map_rvf_err)?; + let seg_id = store + .embed_kernel( + arch as u8, + kernel_type as u8, + flags, + &image, + api_port as u16, + cmdline.as_deref(), + ) + .map_err(map_rvf_err)?; Ok(seg_id as i64) } @@ -785,9 +804,12 @@ impl RvfDatabase { /// Returns null if no kernel segment is present. #[napi] pub fn extract_kernel(&self) -> Result> { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; match store.extract_kernel().map_err(map_rvf_err)? 
{ @@ -810,19 +832,24 @@ impl RvfDatabase { bytecode: Buffer, btf: Option, ) -> Result { - let mut guard = self.inner.lock() + let mut guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_mut() + let store = guard + .as_mut() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; let btf_ref = btf.as_ref().map(|b| b.as_ref()); - let seg_id = store.embed_ebpf( - program_type as u8, - attach_type as u8, - max_dimension as u16, - &bytecode, - btf_ref, - ).map_err(map_rvf_err)?; + let seg_id = store + .embed_ebpf( + program_type as u8, + attach_type as u8, + max_dimension as u16, + &bytecode, + btf_ref, + ) + .map_err(map_rvf_err)?; Ok(seg_id as i64) } @@ -831,9 +858,12 @@ impl RvfDatabase { /// Returns null if no eBPF segment is present. #[napi] pub fn extract_ebpf(&self) -> Result> { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; match store.extract_ebpf().map_err(map_rvf_err)? { @@ -850,28 +880,35 @@ impl RvfDatabase { /// Get the list of segments in the store. #[napi] pub fn segments(&self) -> Result> { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; let seg_dir = store.segment_dir(); - Ok(seg_dir.iter().map(|&(id, offset, payload_len, seg_type)| { - RvfSegmentInfo { + Ok(seg_dir + .iter() + .map(|&(id, offset, payload_len, seg_type)| RvfSegmentInfo { id: id as i64, offset: offset as i64, payload_length: payload_len as i64, seg_type: segment_type_name(seg_type), - } - }).collect()) + }) + .collect()) } /// Get the vector dimensionality of this store. 
#[napi] pub fn dimension(&self) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; Ok(store.dimension() as u32) } @@ -881,9 +918,12 @@ impl RvfDatabase { /// Get HNSW index statistics for this store. #[napi] pub fn index_stats(&self) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; let status = store.status(); @@ -904,15 +944,20 @@ impl RvfDatabase { /// hash). Returns the number of witness entries and validity status. #[napi] pub fn verify_witness(&self) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; // Witness segment type discriminator (0x0A). const WITNESS_SEG_TYPE: u8 = 0x0A; - let witness_count = store.segment_dir().iter() + let witness_count = store + .segment_dir() + .iter() .filter(|&&(_, _, _, seg_type)| seg_type == WITNESS_SEG_TYPE) .count() as u32; @@ -923,7 +968,9 @@ impl RvfDatabase { Ok(RvfWitnessResult { valid: false, entries: witness_count, - error: Some("Witness segments exist but chain hash is zero (corrupt or reset)".to_string()), + error: Some( + "Witness segments exist but chain hash is zero (corrupt or reset)".to_string(), + ), }) } else { Ok(RvfWitnessResult { @@ -940,9 +987,12 @@ impl RvfDatabase { /// Returns the manifest epoch at freeze time. 
#[napi] pub fn freeze(&self) -> Result { - let mut guard = self.inner.lock() + let mut guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_mut() + let store = guard + .as_mut() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; let epoch = store.epoch(); @@ -953,9 +1003,12 @@ impl RvfDatabase { /// Get the distance metric used by this store. #[napi] pub fn metric(&self) -> Result { - let guard = self.inner.lock() + let guard = self + .inner + .lock() .map_err(|_| napi::Error::from_reason("Lock poisoned"))?; - let store = guard.as_ref() + let store = guard + .as_ref() .ok_or_else(|| napi::Error::from_reason("Store is closed"))?; let metric_str = match store.metric() { @@ -979,8 +1032,7 @@ fn hex_encode(bytes: &[u8]) -> String { } const HEX_CHARS: [char; 16] = [ - '0', '1', '2', '3', '4', '5', '6', '7', - '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', ]; fn segment_type_name(seg_type: u8) -> String { diff --git a/crates/rvf/rvf-solver-wasm/src/engine.rs b/crates/rvf/rvf-solver-wasm/src/engine.rs index cb31eed27..59c9de317 100644 --- a/crates/rvf/rvf-solver-wasm/src/engine.rs +++ b/crates/rvf/rvf-solver-wasm/src/engine.rs @@ -13,10 +13,10 @@ use alloc::vec::Vec; use serde::{Deserialize, Serialize}; use crate::policy::{ - CompiledConfig, KnowledgeCompiler, PolicyContext, PolicyKernel, SkipMode, SkipOutcome, - count_distractors, + count_distractors, CompiledConfig, KnowledgeCompiler, PolicyContext, PolicyKernel, SkipMode, + SkipOutcome, }; -use crate::types::{Constraint, Date, Puzzle, Rng64, Weekday, constraint_type_name}; +use crate::types::{constraint_type_name, Constraint, Date, Puzzle, Rng64, Weekday}; // ═════════════════════════════════════════════════════════════════════ // Solve result @@ -52,7 +52,14 @@ impl ReasoningBank { Self::default() } - pub fn record(&mut self, puzzle_id: &str, difficulty: u8, ctypes: &[&str], 
steps: usize, correct: bool) { + pub fn record( + &mut self, + puzzle_id: &str, + difficulty: u8, + ctypes: &[&str], + steps: usize, + correct: bool, + ) { let entry = ( String::from(puzzle_id), difficulty, @@ -87,7 +94,15 @@ impl ReasoningBank { let refs: Vec<(String, u8, Vec<&str>, usize, bool)> = self .trajectories .iter() - .map(|(id, d, ct, s, c)| (id.clone(), *d, ct.iter().map(|x| x.as_str()).collect(), *s, *c)) + .map(|(id, d, ct, s, c)| { + ( + id.clone(), + *d, + ct.iter().map(|x| x.as_str()).collect(), + *s, + *c, + ) + }) .collect(); compiler.compile_from_trajectories(&refs); } @@ -127,7 +142,11 @@ impl PuzzleGenerator { _ => 28, }; let day = self.rng.range(1, max_day) as u32; - let target = Date::new(year, month, day).unwrap_or(Date { year, month: 1, day: 1 }); + let target = Date::new(year, month, day).unwrap_or(Date { + year, + month: 1, + day: 1, + }); let mut constraints = Vec::new(); let constraint_count = (difficulty as usize / 2 + 2).min(7); @@ -247,7 +266,10 @@ impl AdaptiveSolver { /// Solve a puzzle using the three-loop adaptive architecture. 
pub fn solve(&mut self, puzzle: &Puzzle) -> SolveResult { - let has_dow = puzzle.constraints.iter().any(|c| matches!(c, Constraint::DayOfWeek(_))); + let has_dow = puzzle + .constraints + .iter() + .any(|c| matches!(c, Constraint::DayOfWeek(_))); let range = self.estimate_range(puzzle); let distractors = count_distractors(puzzle); @@ -271,8 +293,8 @@ impl AdaptiveSolver { // Fast loop: solve with constraint propagation let (solutions, steps) = self.solve_inner(puzzle, &skip_mode, &compiled); - let correct = !solutions.is_empty() - && puzzle.solutions.iter().any(|s| solutions.contains(s)); + let correct = + !solutions.is_empty() && puzzle.solutions.iter().any(|s| solutions.contains(s)); let solved = !solutions.is_empty(); // Check for early commit error @@ -292,8 +314,13 @@ impl AdaptiveSolver { self.policy_kernel.record_outcome(&ctx, &outcome); // Record trajectory (fast loop → slow loop feedback) - let ctypes: Vec<&str> = puzzle.constraints.iter().map(constraint_type_name).collect(); - self.bank.record(&puzzle.id, puzzle.difficulty, &ctypes, steps, correct); + let ctypes: Vec<&str> = puzzle + .constraints + .iter() + .map(constraint_type_name) + .collect(); + self.bank + .record(&puzzle.id, puzzle.difficulty, &ctypes, steps, correct); // Update compiler on success/failure if self.compiler_enabled { @@ -403,22 +430,34 @@ impl AdaptiveSolver { for c in &puzzle.constraints { match c { Constraint::Between(a, b) => { - if *a > lo { lo = *a; } - if *b < hi { hi = *b; } + if *a > lo { + lo = *a; + } + if *b < hi { + hi = *b; + } } Constraint::After(d) => { let next = d.succ(); - if next > lo { lo = next; } + if next > lo { + lo = next; + } } Constraint::Before(d) => { let prev = d.pred(); - if prev < hi { hi = prev; } + if prev < hi { + hi = prev; + } } Constraint::InYear(y) => { let yr_start = Date::new(*y, 1, 1).unwrap(); let yr_end = Date::new(*y, 12, 31).unwrap(); - if yr_start > lo { lo = yr_start; } - if yr_end < hi { hi = yr_end; } + if yr_start > lo { + lo = 
yr_start; + } + if yr_end < hi { + hi = yr_end; + } } Constraint::Exact(d) => { lo = *d; @@ -588,11 +627,7 @@ pub fn run_acceptance_mode( }); // ── Training phase (data available for next cycle's compile) ── - let mut gen = PuzzleGenerator::new( - config.training_seed + (cycle as u64 * 10_000), - 1, - 10, - ); + let mut gen = PuzzleGenerator::new(config.training_seed + (cycle as u64 * 10_000), 1, 10); let training = gen.generate_batch(config.training_per_cycle); let mut train_rng = Rng64::new(config.training_seed.wrapping_add(cycle as u64 * 7919)); @@ -612,7 +647,9 @@ pub fn run_acceptance_mode( let first = &cycle_metrics[0]; let last = cycle_metrics.last().unwrap(); - let accuracy_maintained = cycle_metrics.iter().all(|c| c.accuracy >= config.min_accuracy * 0.95) + let accuracy_maintained = cycle_metrics + .iter() + .all(|c| c.accuracy >= config.min_accuracy * 0.95) && last.accuracy >= config.min_accuracy; let cost_decrease = if first.cost_per_solve > 0.0 { @@ -628,9 +665,15 @@ pub fn run_acceptance_mode( let zero_violations = cycle_metrics.iter().all(|c| c.violations == 0); let mut dims = 0; - if cost_improved { dims += 1; } - if robustness_improved { dims += 1; } - if last.accuracy >= first.accuracy { dims += 1; } + if cost_improved { + dims += 1; + } + if robustness_improved { + dims += 1; + } + if last.accuracy >= first.accuracy { + dims += 1; + } let passed = accuracy_maintained && zero_violations && dims >= 2; @@ -731,16 +774,52 @@ fn inject_noise(puzzle: &Puzzle, rng: &mut Rng64) -> Puzzle { #[cfg(test)] mod tests { extern crate std; - use std::println; use super::*; + use std::println; #[test] fn test_acceptance_mode_c_parameter_sweep() { // Test various configs to find what passes Mode C let configs = [ - ("small", AcceptanceConfig { holdout_size: 30, training_per_cycle: 200, cycles: 5, step_budget: 500, holdout_seed: 0xDEAD_BEEF, training_seed: 42, noise_rate: 0.25, min_accuracy: 0.80 }), - ("medium", AcceptanceConfig { holdout_size: 50, 
training_per_cycle: 500, cycles: 8, step_budget: 1000, holdout_seed: 0xDEAD_BEEF, training_seed: 42, noise_rate: 0.25, min_accuracy: 0.80 }), - ("large", AcceptanceConfig { holdout_size: 50, training_per_cycle: 800, cycles: 12, step_budget: 2000, holdout_seed: 0xDEAD_BEEF, training_seed: 42, noise_rate: 0.25, min_accuracy: 0.80 }), + ( + "small", + AcceptanceConfig { + holdout_size: 30, + training_per_cycle: 200, + cycles: 5, + step_budget: 500, + holdout_seed: 0xDEAD_BEEF, + training_seed: 42, + noise_rate: 0.25, + min_accuracy: 0.80, + }, + ), + ( + "medium", + AcceptanceConfig { + holdout_size: 50, + training_per_cycle: 500, + cycles: 8, + step_budget: 1000, + holdout_seed: 0xDEAD_BEEF, + training_seed: 42, + noise_rate: 0.25, + min_accuracy: 0.80, + }, + ), + ( + "large", + AcceptanceConfig { + holdout_size: 50, + training_per_cycle: 800, + cycles: 12, + step_budget: 2000, + holdout_seed: 0xDEAD_BEEF, + training_seed: 42, + noise_rate: 0.25, + min_accuracy: 0.80, + }, + ), ]; for (label, config) in &configs { @@ -774,9 +853,16 @@ mod tests { let result = run_acceptance_mode(&config, true, true); let last = result.cycles.last().unwrap(); let status = if result.passed { "PASS" } else { "FAIL" }; - println!("seed={seed:#x} {status} acc={:.3} cost_imp={} robust_imp={} dims={}", - last.accuracy, result.cost_improved, result.robustness_improved, result.dimensions_improved); - if result.passed { pass_count += 1; } + println!( + "seed={seed:#x} {status} acc={:.3} cost_imp={} robust_imp={} dims={}", + last.accuracy, + result.cost_improved, + result.robustness_improved, + result.dimensions_improved + ); + if result.passed { + pass_count += 1; + } } println!("\n{pass_count}/{total} seeds passed"); } diff --git a/crates/rvf/rvf-solver-wasm/src/lib.rs b/crates/rvf/rvf-solver-wasm/src/lib.rs index 8b907347f..f404151e4 100644 --- a/crates/rvf/rvf-solver-wasm/src/lib.rs +++ b/crates/rvf/rvf-solver-wasm/src/lib.rs @@ -39,7 +39,9 @@ pub mod types; use alloc::vec::Vec; -use 
engine::{AcceptanceConfig, AcceptanceResult, AdaptiveSolver, PuzzleGenerator, run_acceptance_mode}; +use engine::{ + run_acceptance_mode, AcceptanceConfig, AcceptanceResult, AdaptiveSolver, PuzzleGenerator, +}; use rvf_crypto::{create_witness_chain, WitnessEntry}; // ═════════════════════════════════════════════════════════════════════ @@ -291,7 +293,11 @@ pub extern "C" fn rvf_solver_acceptance( // Serialize policy state inst.policy_json = serde_json::to_vec(&inst.solver.policy_kernel).unwrap_or_default(); - if mode_c.passed { 1 } else { 0 } + if mode_c.passed { + 1 + } else { + 0 + } } #[derive(serde::Serialize)] diff --git a/crates/rvf/rvf-solver-wasm/src/policy.rs b/crates/rvf/rvf-solver-wasm/src/policy.rs index abdb86b58..3213c8a58 100644 --- a/crates/rvf/rvf-solver-wasm/src/policy.rs +++ b/crates/rvf/rvf-solver-wasm/src/policy.rs @@ -15,7 +15,7 @@ use alloc::vec::Vec; use libm::{cos, log, pow, sqrt}; use serde::{Deserialize, Serialize}; -use crate::types::{Constraint, Puzzle, constraint_type_name}; +use crate::types::{constraint_type_name, Constraint, Puzzle}; // ═════════════════════════════════════════════════════════════════════ // Skip / Prepass modes @@ -113,8 +113,8 @@ impl SkipModeStats { if self.attempts <= 1 { self.cost_ema = normalized_steps; } else { - self.cost_ema = COST_EMA_ALPHA * normalized_steps - + (1.0 - COST_EMA_ALPHA) * self.cost_ema; + self.cost_ema = + COST_EMA_ALPHA * normalized_steps + (1.0 - COST_EMA_ALPHA) * self.cost_ema; } } @@ -162,7 +162,9 @@ impl PolicyKernel { if !ctx.has_day_of_week { return SkipMode::None; } - let eff = ctx.posterior_range.saturating_sub(Self::K * ctx.distractor_count); + let eff = ctx + .posterior_range + .saturating_sub(Self::K * ctx.distractor_count); if eff >= Self::T { SkipMode::Weekday } else { @@ -191,7 +193,9 @@ impl PolicyKernel { // accumulated enough training data. This ensures a meaningful baseline // in early cycles that training can measurably improve upon. 
{ - let total_observations: usize = self.context_stats.values() + let total_observations: usize = self + .context_stats + .values() .flat_map(|m| m.values()) .map(|s| s.attempts) .sum(); @@ -224,7 +228,10 @@ impl PolicyKernel { }) .collect(); scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal)); - scored.first().map(|(m, _)| m.clone()).unwrap_or(SkipMode::None) + scored + .first() + .map(|(m, _)| m.clone()) + .unwrap_or(SkipMode::None) } /// Speculative dual-path check. @@ -424,7 +431,11 @@ impl KnowledgeCompiler { } pub fn signature(puzzle: &Puzzle) -> String { - let mut parts: Vec<&str> = puzzle.constraints.iter().map(constraint_type_name).collect(); + let mut parts: Vec<&str> = puzzle + .constraints + .iter() + .map(constraint_type_name) + .collect(); parts.sort(); format!("v1:{}:{}", puzzle.difficulty, parts.join(",")) } @@ -465,7 +476,10 @@ impl KnowledgeCompiler { } /// Compile knowledge from trajectories (simplified ReasoningBank integration). - pub fn compile_from_trajectories(&mut self, trajectories: &[(String, u8, Vec<&str>, usize, bool)]) { + pub fn compile_from_trajectories( + &mut self, + trajectories: &[(String, u8, Vec<&str>, usize, bool)], + ) { for (_, difficulty, ctypes, steps, correct) in trajectories { if !correct { continue; @@ -505,15 +519,21 @@ pub fn count_distractors(puzzle: &Puzzle) -> usize { for c in &puzzle.constraints { match c { Constraint::Between(_, _) => { - if sb { count += 1; } + if sb { + count += 1; + } sb = true; } Constraint::InYear(_) => { - if sy { count += 1; } + if sy { + count += 1; + } sy = true; } Constraint::DayOfWeek(_) => { - if sd { count += 1; } + if sd { + count += 1; + } sd = true; } _ => {} diff --git a/crates/rvf/rvf-solver-wasm/src/types.rs b/crates/rvf/rvf-solver-wasm/src/types.rs index cfadaeec1..62447cff8 100644 --- a/crates/rvf/rvf-solver-wasm/src/types.rs +++ b/crates/rvf/rvf-solver-wasm/src/types.rs @@ -185,7 +185,9 @@ pub struct Puzzle { impl Puzzle { pub fn 
check_date(&self, date: Date) -> bool { - self.constraints.iter().all(|c| check_one(date, c, &self.references)) + self.constraints + .iter() + .all(|c| check_one(date, c, &self.references)) } } @@ -196,12 +198,14 @@ fn check_one(date: Date, c: &Constraint, refs: &BTreeMap) -> bool Constraint::Before(d) => date < *d, Constraint::Between(a, b) => date >= *a && date <= *b, Constraint::DayOfWeek(w) => date.weekday() == *w, - Constraint::DaysAfter(name, n) => { - refs.get(name).map(|r| date == r.add_days(*n)).unwrap_or(false) - } - Constraint::DaysBefore(name, n) => { - refs.get(name).map(|r| date == r.add_days(-*n)).unwrap_or(false) - } + Constraint::DaysAfter(name, n) => refs + .get(name) + .map(|r| date == r.add_days(*n)) + .unwrap_or(false), + Constraint::DaysBefore(name, n) => refs + .get(name) + .map(|r| date == r.add_days(-*n)) + .unwrap_or(false), Constraint::InMonth(m) => date.month == *m, Constraint::InYear(y) => date.year == *y, Constraint::DayOfMonth(d) => date.day == *d, diff --git a/crates/rvf/rvf-wasm/src/bootstrap.rs b/crates/rvf/rvf-wasm/src/bootstrap.rs index 947b3faa0..b1e8abd4a 100644 --- a/crates/rvf/rvf-wasm/src/bootstrap.rs +++ b/crates/rvf/rvf-wasm/src/bootstrap.rs @@ -25,8 +25,8 @@ extern crate alloc; +use crate::segment::{parse_segments, SegmentInfo}; use alloc::vec::Vec; -use crate::segment::{SegmentInfo, parse_segments}; /// WASM_SEG type discriminant (matches rvf_types::SegmentType::Wasm). const WASM_SEG_TYPE: u8 = 0x10; @@ -79,13 +79,9 @@ pub enum BootstrapChain { /// File has no WASM_SEGs — requires external runtime for all processing. None, /// File contains only a microkernel — requires host WASM runtime. - HostRequired { - microkernel: WasmModule, - }, + HostRequired { microkernel: WasmModule }, /// File contains a combined interpreter+microkernel — single-step bootstrap. - SelfContained { - combined: WasmModule, - }, + SelfContained { combined: WasmModule }, /// File contains separate interpreter and microkernel — two-step bootstrap. 
TwoStage { interpreter: WasmModule, @@ -187,14 +183,21 @@ pub fn resolve_bootstrap_chain(buf: &[u8]) -> BootstrapChain { wasm_modules.sort_by_key(|m| m.bootstrap_priority); // Check for combined module (single-step bootstrap) - if let Some(idx) = wasm_modules.iter().position(|m| m.role == WasmRole::Combined as u8) { + if let Some(idx) = wasm_modules + .iter() + .position(|m| m.role == WasmRole::Combined as u8) + { return BootstrapChain::SelfContained { combined: wasm_modules.remove(idx), }; } - let interpreter_idx = wasm_modules.iter().position(|m| m.role == WasmRole::Interpreter as u8); - let microkernel_idx = wasm_modules.iter().position(|m| m.role == WasmRole::Microkernel as u8); + let interpreter_idx = wasm_modules + .iter() + .position(|m| m.role == WasmRole::Interpreter as u8); + let microkernel_idx = wasm_modules + .iter() + .position(|m| m.role == WasmRole::Microkernel as u8); match (interpreter_idx, microkernel_idx) { (Some(i_idx), Some(m_idx)) => { @@ -230,7 +233,10 @@ pub fn resolve_bootstrap_chain(buf: &[u8]) -> BootstrapChain { } (None, Some(_)) => { // Only microkernel, no interpreter → host provides runtime - let m_idx = wasm_modules.iter().position(|m| m.role == WasmRole::Microkernel as u8).unwrap(); + let m_idx = wasm_modules + .iter() + .position(|m| m.role == WasmRole::Microkernel as u8) + .unwrap(); BootstrapChain::HostRequired { microkernel: wasm_modules.remove(m_idx), } @@ -273,7 +279,7 @@ mod tests { seg.extend_from_slice(&[0, 0]); // flags seg.extend_from_slice(&1u64.to_le_bytes()); // segment_id seg.extend_from_slice(&(payload_len as u64).to_le_bytes()); // payload_length - // Fill remaining header bytes to reach 64 + // Fill remaining header bytes to reach 64 while seg.len() < seg_header_size { seg.push(0); } @@ -365,7 +371,11 @@ mod tests { assert!(chain.is_self_bootstrapping()); assert!(matches!(chain, BootstrapChain::TwoStage { .. 
})); - if let BootstrapChain::TwoStage { interpreter, microkernel } = &chain { + if let BootstrapChain::TwoStage { + interpreter, + microkernel, + } = &chain + { assert_eq!(interpreter.role, WasmRole::Interpreter as u8); assert_eq!(microkernel.role, WasmRole::Microkernel as u8); } @@ -401,7 +411,11 @@ mod tests { assert!(chain.is_self_bootstrapping()); // The interpreter should have lower priority (comes first) - if let BootstrapChain::TwoStage { interpreter, microkernel } = &chain { + if let BootstrapChain::TwoStage { + interpreter, + microkernel, + } = &chain + { assert_eq!(interpreter.bootstrap_priority, 0); assert_eq!(microkernel.bootstrap_priority, 10); } diff --git a/crates/rvf/rvf-wasm/src/lib.rs b/crates/rvf/rvf-wasm/src/lib.rs index 7688d8eb5..e801942c7 100644 --- a/crates/rvf/rvf-wasm/src/lib.rs +++ b/crates/rvf/rvf-wasm/src/lib.rs @@ -101,8 +101,7 @@ pub extern "C" fn rvf_load_block(block_ptr: i32, count: i32, dtype: i32) -> i32 pub extern "C" fn rvf_distances(metric: i32, result_ptr: i32) -> i32 { let (dim, count, dtype) = unsafe { let dim = *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_DIM_OFFSET) as *const u32) as usize; - let count = - *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_COUNT_OFFSET) as *const u32) as usize; + let count = *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_COUNT_OFFSET) as *const u32) as usize; let dtype = *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_DTYPE_OFFSET) as *const u32); (dim, count, dtype) }; @@ -201,9 +200,7 @@ pub extern "C" fn rvf_load_sq_params(params_ptr: i32, dim: i32) -> i32 { /// Returns 0 on success. 
#[no_mangle] pub extern "C" fn rvf_dequant_i8(src_ptr: i32, dst_ptr: i32, count: i32) -> i32 { - let dim = unsafe { - *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_DIM_OFFSET) as *const u32) as usize - }; + let dim = unsafe { *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_DIM_OFFSET) as *const u32) as usize }; if dim == 0 { return -1; } @@ -231,9 +228,7 @@ pub extern "C" fn rvf_dequant_i8(src_ptr: i32, dst_ptr: i32, count: i32) -> i32 /// Returns 0 on success. #[no_mangle] pub extern "C" fn rvf_load_pq_codebook(codebook_ptr: i32, m: i32, k: i32) -> i32 { - let dim = unsafe { - *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_DIM_OFFSET) as *const u32) as usize - }; + let dim = unsafe { *(DATA_MEMORY.as_ptr().add(TILE_CONFIG_DIM_OFFSET) as *const u32) as usize }; let m_usize = m as usize; if m_usize == 0 { return -1; @@ -329,9 +324,7 @@ pub extern "C" fn rvf_load_neighbors(node_id: i64, layer: i32, out_ptr: i32) -> #[no_mangle] pub extern "C" fn rvf_greedy_step(current_id: i64, layer: i32) -> i64 { let _ = layer; - let neighbor_ptr = unsafe { - *(DATA_MEMORY.as_ptr().add(NEIGHBOR_CACHE_OFFSET) as *const i32) - }; + let neighbor_ptr = unsafe { *(DATA_MEMORY.as_ptr().add(NEIGHBOR_CACHE_OFFSET) as *const i32) }; if neighbor_ptr == 0 { return -1; } @@ -562,7 +555,12 @@ pub extern "C" fn rvf_store_query( return 0; } match store::registry().get(handle) { - Some(s) => s.query(query_ptr as *const f32, k as u32, metric, out_ptr as *mut u8), + Some(s) => s.query( + query_ptr as *const f32, + k as u32, + metric, + out_ptr as *mut u8, + ), None => -1, } } diff --git a/crates/rvf/rvf-wasm/src/store.rs b/crates/rvf/rvf-wasm/src/store.rs index 0daccfe36..bae2630cd 100644 --- a/crates/rvf/rvf-wasm/src/store.rs +++ b/crates/rvf/rvf-wasm/src/store.rs @@ -335,9 +335,7 @@ pub(crate) struct StoreRegistry { impl StoreRegistry { fn new() -> Self { - Self { - stores: Vec::new(), - } + Self { stores: Vec::new() } } pub(crate) fn create(&mut self, dim: u32, metric: u8) -> i32 { diff --git 
a/examples/google-cloud/src/self_learning.rs b/examples/google-cloud/src/self_learning.rs index ff30c9e55..6bf483b96 100644 --- a/examples/google-cloud/src/self_learning.rs +++ b/examples/google-cloud/src/self_learning.rs @@ -268,8 +268,8 @@ pub struct AutonomousModel { impl AutonomousModel { pub fn new(input_dim: usize, hidden_dim: usize, _output_dim: usize) -> Self { - let gnn_layer = RuvectorLayer::new(input_dim, hidden_dim, 8, 0.1) - .expect("Failed to create GNN layer"); + let gnn_layer = + RuvectorLayer::new(input_dim, hidden_dim, 8, 0.1).expect("Failed to create GNN layer"); let optimizer = Optimizer::new(OptimizerType::Adam { learning_rate: 0.001, diff --git a/examples/robotics/src/bin/01_basic_perception.rs b/examples/robotics/src/bin/01_basic_perception.rs index 44bd27437..ee8649589 100644 --- a/examples/robotics/src/bin/01_basic_perception.rs +++ b/examples/robotics/src/bin/01_basic_perception.rs @@ -6,7 +6,6 @@ /// - Radius search around a query point /// - Using the SpatialIndex for efficient lookups /// - Distance-based wall proximity analysis - use ruvector_robotics::bridge::{Point3D, PointCloud, SpatialIndex}; /// Generate a wall as a strip of points along one axis. 
@@ -31,13 +30,32 @@ fn main() { let points_per_wall = 50; let mut all_points = Vec::new(); - all_points.extend(generate_wall([0.0, 5.0, 0.0], [5.0, 5.0, 0.0], points_per_wall)); - all_points.extend(generate_wall([0.0, 0.0, 0.0], [5.0, 0.0, 0.0], points_per_wall)); - all_points.extend(generate_wall([0.0, 0.0, 0.0], [0.0, 5.0, 0.0], points_per_wall)); - all_points.extend(generate_wall([5.0, 0.0, 0.0], [5.0, 5.0, 0.0], points_per_wall)); + all_points.extend(generate_wall( + [0.0, 5.0, 0.0], + [5.0, 5.0, 0.0], + points_per_wall, + )); + all_points.extend(generate_wall( + [0.0, 0.0, 0.0], + [5.0, 0.0, 0.0], + points_per_wall, + )); + all_points.extend(generate_wall( + [0.0, 0.0, 0.0], + [0.0, 5.0, 0.0], + points_per_wall, + )); + all_points.extend(generate_wall( + [5.0, 0.0, 0.0], + [5.0, 5.0, 0.0], + points_per_wall, + )); let cloud = PointCloud::new(all_points, 1000); - println!("[1] Room point cloud created: {} points from 4 walls", cloud.len()); + println!( + "[1] Room point cloud created: {} points from 4 walls", + cloud.len() + ); // Step 2: Insert into spatial index let mut index = SpatialIndex::new(3); @@ -46,7 +64,10 @@ fn main() { // Step 3: Robot position in the center of the room let robot_pos = Point3D::new(2.5, 2.5, 0.0); - println!("[3] Robot position: ({:.1}, {:.1}, {:.1})", robot_pos.x, robot_pos.y, robot_pos.z); + println!( + "[3] Robot position: ({:.1}, {:.1}, {:.1})", + robot_pos.x, robot_pos.y, robot_pos.z + ); // Step 4: kNN search (SpatialIndex uses f32 queries) let k = 5; @@ -71,7 +92,11 @@ fn main() { match index.search_radius(&query, radius) { Ok(results) => { println!(); - println!("[5] Points within {:.1}m of robot: {}", radius, results.len()); + println!( + "[5] Points within {:.1}m of robot: {}", + radius, + results.len() + ); } Err(e) => println!("[5] Search error: {:?}", e), } diff --git a/examples/robotics/src/bin/02_obstacle_avoidance.rs b/examples/robotics/src/bin/02_obstacle_avoidance.rs index 91fe9aada..489d041ef 100644 --- 
a/examples/robotics/src/bin/02_obstacle_avoidance.rs +++ b/examples/robotics/src/bin/02_obstacle_avoidance.rs @@ -5,7 +5,6 @@ /// - Detecting obstacles using the PerceptionPipeline /// - Classifying obstacles by geometry (Static, Dynamic, Unknown) /// - Computing distances and safety margins - use rand::Rng; use ruvector_robotics::bridge::{Point3D, PointCloud}; use ruvector_robotics::perception::{PerceptionConfig, PerceptionPipeline}; diff --git a/examples/robotics/src/bin/03_scene_graph.rs b/examples/robotics/src/bin/03_scene_graph.rs index e7d4d18f3..8687e578e 100644 --- a/examples/robotics/src/bin/03_scene_graph.rs +++ b/examples/robotics/src/bin/03_scene_graph.rs @@ -5,7 +5,6 @@ /// - Building a SceneGraph from objects using SceneGraphBuilder /// - Inspecting computed edges (spatial relationships) /// - Merging two scene graphs into one - use ruvector_robotics::bridge::SceneObject; use ruvector_robotics::perception::SceneGraphBuilder; @@ -54,10 +53,24 @@ fn main() { let graph = builder.build(objects, 1000); println!(); - println!("[2] Scene graph: {} objects, {} edges", graph.objects.len(), graph.edges.len()); + println!( + "[2] Scene graph: {} objects, {} edges", + graph.objects.len(), + graph.edges.len() + ); for edge in &graph.edges { - let from_label = &graph.objects.iter().find(|o| o.id == edge.from).unwrap().label; - let to_label = &graph.objects.iter().find(|o| o.id == edge.to).unwrap().label; + let from_label = &graph + .objects + .iter() + .find(|o| o.id == edge.from) + .unwrap() + .label; + let to_label = &graph + .objects + .iter() + .find(|o| o.id == edge.to) + .unwrap() + .label; println!( " {} --[{}, {:.2}m]--> {}", from_label, edge.relation, edge.distance, to_label @@ -83,7 +96,11 @@ fn main() { let graph_b = builder.build(objects_b, 2000); println!(); - println!("[3] Second scene: {} objects, {} edges", graph_b.objects.len(), graph_b.edges.len()); + println!( + "[3] Second scene: {} objects, {} edges", + graph_b.objects.len(), + 
graph_b.edges.len() + ); let wide_builder = SceneGraphBuilder::new(10.0, 256); let merged = wide_builder.merge(&graph, &graph_b); diff --git a/examples/robotics/src/bin/04_behavior_tree.rs b/examples/robotics/src/bin/04_behavior_tree.rs index caef2dbca..f78bb0fb9 100644 --- a/examples/robotics/src/bin/04_behavior_tree.rs +++ b/examples/robotics/src/bin/04_behavior_tree.rs @@ -10,7 +10,6 @@ /// Root (Selector) /// |-- Avoid (Sequence): [Condition("obstacle_near")] -> [Action("evade")] /// |-- Patrol (Sequence): [Action("select_wp")] -> [Action("move")] -> [Action("wait")] - use ruvector_robotics::cognitive::{BehaviorNode, BehaviorStatus, BehaviorTree, DecoratorType}; fn main() { diff --git a/examples/robotics/src/bin/05_cognitive_robot.rs b/examples/robotics/src/bin/05_cognitive_robot.rs index c614b9158..fd689511c 100644 --- a/examples/robotics/src/bin/05_cognitive_robot.rs +++ b/examples/robotics/src/bin/05_cognitive_robot.rs @@ -5,7 +5,6 @@ /// - The perceive -> think -> act -> learn cycle /// - Attention threshold adaptation from feedback /// - Decision history and cumulative reward tracking - use ruvector_robotics::cognitive::{ ActionType, CognitiveConfig, CognitiveCore, CognitiveMode, Outcome, Percept, }; @@ -35,7 +34,7 @@ fn main() { let sensor_data: Vec<(&str, Vec, f64)> = vec![ ("lidar", vec![2.0, 1.5, 0.0], 0.9), ("camera", vec![-1.0, 3.0, 0.5], 0.85), - ("imu", vec![0.1, 0.2], 0.3), // below threshold -- will be dropped + ("imu", vec![0.1, 0.2], 0.3), // below threshold -- will be dropped ("lidar", vec![4.0, 0.0, 0.0], 0.95), ("camera", vec![0.0, 5.0, 1.0], 0.7), ("sonar", vec![1.0, 1.0, 0.0], 0.6), @@ -59,13 +58,18 @@ fn main() { let state = core.perceive(percept); println!( " Perceive: source='{}', conf={:.2}, buffered={} [state={:?}]", - source, confidence, core.percept_count(), state + source, + confidence, + core.percept_count(), + state ); // THINK if let Some(decision) = core.think() { let action_desc = match &decision.action.action { - 
ActionType::Move(pos) => format!("Move({:.1}, {:.1}, {:.1})", pos[0], pos[1], pos[2]), + ActionType::Move(pos) => { + format!("Move({:.1}, {:.1}, {:.1})", pos[0], pos[1], pos[2]) + } ActionType::Wait(ms) => format!("Wait({}ms)", ms), _ => format!("{:?}", decision.action.action), }; @@ -88,7 +92,9 @@ fn main() { }); println!( " Learn: success={}, reward={:.1}, cumulative={:.4}", - success, reward, core.cumulative_reward() + success, + reward, + core.cumulative_reward() ); } else { println!(" Think: no percepts to reason about"); @@ -119,7 +125,10 @@ fn main() { timestamp: 0, }); if let Some(decision) = emergency_core.think() { - println!(" Priority: {} (max for emergency)", decision.action.priority); + println!( + " Priority: {} (max for emergency)", + decision.action.priority + ); println!(" Reasoning: {}", decision.reasoning); } diff --git a/examples/robotics/src/bin/06_swarm_coordination.rs b/examples/robotics/src/bin/06_swarm_coordination.rs index ffe685a0e..23547d7c6 100644 --- a/examples/robotics/src/bin/06_swarm_coordination.rs +++ b/examples/robotics/src/bin/06_swarm_coordination.rs @@ -5,7 +5,6 @@ /// - Assigning tasks based on capability matching /// - Computing Line, Circle, and Grid formations /// - Running consensus votes among swarm members - use ruvector_robotics::cognitive::{ Formation, FormationType, RobotCapabilities, SwarmConfig, SwarmCoordinator, SwarmTask, }; @@ -108,7 +107,9 @@ fn main() { let task = tasks.iter().find(|t| t.id == assignment.task_id).unwrap(); println!( " Task {} ('{}') -> Robot {} (est. 
{:.1}s)", - assignment.task_id, task.description, assignment.robot_id, + assignment.task_id, + task.description, + assignment.robot_id, assignment.estimated_completion ); } @@ -158,7 +159,11 @@ fn main() { result.proposal, result.votes_for, result.votes_against, - if result.accepted { "ACCEPTED" } else { "REJECTED" } + if result.accepted { + "ACCEPTED" + } else { + "REJECTED" + } ); } diff --git a/examples/robotics/src/bin/07_skill_learning.rs b/examples/robotics/src/bin/07_skill_learning.rs index 7fd436a70..dc4785efc 100644 --- a/examples/robotics/src/bin/07_skill_learning.rs +++ b/examples/robotics/src/bin/07_skill_learning.rs @@ -6,7 +6,6 @@ /// - Executing the learned skill and tracking its trajectory /// - Improving confidence through positive/negative feedback /// - Using the SkillLibrary from ruvector_robotics::cognitive - use ruvector_robotics::cognitive::{Demonstration, SkillLibrary}; fn main() { @@ -18,17 +17,32 @@ fn main() { // -- Step 1: Record demonstrations for "reach" -- let demos = vec![ Demonstration { - trajectory: vec![[0.0, 0.0, 0.0], [1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [3.0, 1.5, 0.0]], + trajectory: vec![ + [0.0, 0.0, 0.0], + [1.0, 0.5, 0.0], + [2.0, 1.0, 0.0], + [3.0, 1.5, 0.0], + ], timestamps: vec![0, 100, 200, 300], metadata: "expert_1".into(), }, Demonstration { - trajectory: vec![[0.0, 0.0, 0.0], [1.2, 0.4, 0.0], [2.1, 0.9, 0.0], [3.1, 1.6, 0.0]], + trajectory: vec![ + [0.0, 0.0, 0.0], + [1.2, 0.4, 0.0], + [2.1, 0.9, 0.0], + [3.1, 1.6, 0.0], + ], timestamps: vec![0, 110, 210, 310], metadata: "expert_2".into(), }, Demonstration { - trajectory: vec![[0.0, 0.0, 0.0], [0.8, 0.6, 0.0], [1.9, 1.1, 0.0], [2.9, 1.4, 0.0]], + trajectory: vec![ + [0.0, 0.0, 0.0], + [0.8, 0.6, 0.0], + [1.9, 1.1, 0.0], + [2.9, 1.4, 0.0], + ], timestamps: vec![0, 90, 190, 290], metadata: "expert_3".into(), }, @@ -49,7 +63,10 @@ fn main() { println!(); let skill = library.learn_from_demonstration("reach", &demos); println!("[2] Learned skill 'reach':"); - println!(" 
Trajectory length: {} waypoints", skill.trajectory.len()); + println!( + " Trajectory length: {} waypoints", + skill.trajectory.len() + ); println!(" Initial confidence: {:.3}", skill.confidence); println!(" Averaged trajectory:"); for (i, pt) in skill.trajectory.iter().enumerate() { @@ -59,13 +76,21 @@ fn main() { // -- Step 3: Learn another skill with a single demo -- println!(); let wave_demo = Demonstration { - trajectory: vec![[0.0, 0.0, 1.0], [0.5, 0.0, 1.5], [0.0, 0.0, 1.0], [-0.5, 0.0, 1.5]], + trajectory: vec![ + [0.0, 0.0, 1.0], + [0.5, 0.0, 1.5], + [0.0, 0.0, 1.0], + [-0.5, 0.0, 1.5], + ], timestamps: vec![0, 200, 400, 600], metadata: "single_demo".into(), }; let wave_skill = library.learn_from_demonstration("wave", &[wave_demo]); println!("[3] Learned skill 'wave' from 1 demo:"); - println!(" Confidence: {:.3} (lower with fewer demos)", wave_skill.confidence); + println!( + " Confidence: {:.3} (lower with fewer demos)", + wave_skill.confidence + ); println!(" Library now has {} skills", library.len()); // -- Step 4: Execute skills -- @@ -100,14 +125,22 @@ fn main() { library.improve_skill("reach", 0.03); } let after_positive = library.get("reach").unwrap().confidence; - println!(" After 5 successes: confidence={:.4} (+{:.4})", after_positive, after_positive - before); + println!( + " After 5 successes: confidence={:.4} (+{:.4})", + after_positive, + after_positive - before + ); // Negative feedback (2 failures) for _ in 0..2 { library.improve_skill("reach", -0.05); } let after_negative = library.get("reach").unwrap().confidence; - println!(" After 2 failures: confidence={:.4} ({:.4})", after_negative, after_negative - after_positive); + println!( + " After 2 failures: confidence={:.4} ({:.4})", + after_negative, + after_negative - after_positive + ); // -- Step 6: Summary -- println!(); diff --git a/examples/robotics/src/bin/08_world_model.rs b/examples/robotics/src/bin/08_world_model.rs index 54f165bb6..5e7a61799 100644 --- 
a/examples/robotics/src/bin/08_world_model.rs +++ b/examples/robotics/src/bin/08_world_model.rs @@ -6,7 +6,6 @@ /// - Predicting future states via constant-velocity extrapolation /// - Updating occupancy cells and checking path clearance /// - Removing stale objects by age threshold - use ruvector_robotics::cognitive::{TrackedObject, WorldModel}; fn main() { @@ -60,9 +59,14 @@ fn main() { for obj in &objects { println!( " id={} '{}': pos=({:.1},{:.1},{:.1}), vel=({:.1},{:.1},{:.1}), conf={:.2}", - obj.id, obj.label, - obj.position[0], obj.position[1], obj.position[2], - obj.velocity[0], obj.velocity[1], obj.velocity[2], + obj.id, + obj.label, + obj.position[0], + obj.position[1], + obj.position[2], + obj.velocity[0], + obj.velocity[1], + obj.velocity[2], obj.confidence ); } @@ -104,7 +108,11 @@ fn main() { let clear = world.is_path_clear(*from, *to); println!( " ({:>2},{:>2}) -> ({:>2},{:>2}) [{}]: {}", - from[0], from[1], to[0], to[1], label, + from[0], + from[1], + to[0], + to[1], + label, if clear { "CLEAR" } else { "BLOCKED" } ); } @@ -128,16 +136,27 @@ fn main() { for y in 0..size { for x in 0..size { if let Some(v) = world.get_occupancy(x, y) { - if v >= 0.5 { occupied += 1; } - else { free += 1; } + if v >= 0.5 { + occupied += 1; + } else { + free += 1; + } } } } let total = size * size; println!("[7] Occupancy statistics:"); println!(" Total cells: {}", total); - println!(" Occupied: {} ({:.1}%)", occupied, 100.0 * occupied as f64 / total as f64); - println!(" Free: {} ({:.1}%)", free, 100.0 * free as f64 / total as f64); + println!( + " Occupied: {} ({:.1}%)", + occupied, + 100.0 * occupied as f64 / total as f64 + ); + println!( + " Free: {} ({:.1}%)", + free, + 100.0 * free as f64 / total as f64 + ); println!(); println!("[done] World model example complete."); diff --git a/examples/robotics/src/bin/09_mcp_tools.rs b/examples/robotics/src/bin/09_mcp_tools.rs index 0f1a33452..785f9a37a 100644 --- a/examples/robotics/src/bin/09_mcp_tools.rs +++ 
b/examples/robotics/src/bin/09_mcp_tools.rs @@ -7,7 +7,6 @@ /// - Looking up a specific tool and inspecting its parameters /// - Registering a custom tool /// - Generating MCP-compatible JSON schema - use ruvector_robotics::mcp::{ ParamType, RoboticsToolRegistry, ToolCategory, ToolDefinition, ToolParameter, }; @@ -18,7 +17,10 @@ fn main() { // -- Step 1: Create registry -- let mut registry = RoboticsToolRegistry::new(); - println!("[1] Tool registry created: {} built-in tools", registry.list_tools().len()); + println!( + "[1] Tool registry created: {} built-in tools", + registry.list_tools().len() + ); // -- Step 2: List all tools -- println!(); @@ -29,7 +31,9 @@ fn main() { let tool = registry.get_tool(name).unwrap(); println!( " {:.<30} {:?} ({} params)", - tool.name, tool.category, tool.parameters.len() + tool.name, + tool.category, + tool.parameters.len() ); } @@ -73,14 +77,32 @@ fn main() { "custom_slam", "Run SLAM algorithm on sensor data", vec![ - ToolParameter::new("point_cloud_json", "JSON-encoded point cloud", ParamType::String, true), - ToolParameter::new("odometry_json", "JSON-encoded odometry data", ParamType::String, false), - ToolParameter::new("resolution", "Map resolution in meters", ParamType::Number, false), + ToolParameter::new( + "point_cloud_json", + "JSON-encoded point cloud", + ParamType::String, + true, + ), + ToolParameter::new( + "odometry_json", + "JSON-encoded odometry data", + ParamType::String, + false, + ), + ToolParameter::new( + "resolution", + "Map resolution in meters", + ParamType::Number, + false, + ), ], ToolCategory::Perception, ); registry.register_tool(custom); - println!("[5] Registered custom tool 'custom_slam'. Total: {} tools", registry.list_tools().len()); + println!( + "[5] Registered custom tool 'custom_slam'. 
Total: {} tools", + registry.list_tools().len() + ); // -- Step 6: Generate MCP schema -- println!(); diff --git a/examples/robotics/src/bin/10_full_pipeline.rs b/examples/robotics/src/bin/10_full_pipeline.rs index e6af50bda..5c50d019b 100644 --- a/examples/robotics/src/bin/10_full_pipeline.rs +++ b/examples/robotics/src/bin/10_full_pipeline.rs @@ -6,7 +6,6 @@ /// 3. Feed percepts into CognitiveCore (perceive-think-act-learn) /// 4. Track objects in WorldModel /// 5. Report comprehensive statistics - use rand::Rng; use ruvector_robotics::bridge::{Point3D, PointCloud, SpatialIndex}; use ruvector_robotics::cognitive::{ @@ -41,9 +40,20 @@ fn main() { println!("[1] Modules initialized:"); println!(" PerceptionPipeline: default config"); - println!(" CognitiveCore: mode={:?}, state={:?}", core.mode(), core.state()); - println!(" WorldModel: {}x{} grid", world.grid_size(), world.grid_size()); - println!(" MCP registry: {} tools", registry.list_tools().len()); + println!( + " CognitiveCore: mode={:?}, state={:?}", + core.mode(), + core.state() + ); + println!( + " WorldModel: {}x{} grid", + world.grid_size(), + world.grid_size() + ); + println!( + " MCP registry: {} tools", + registry.list_tools().len() + ); println!(); // -- Static obstacle points -- @@ -55,7 +65,11 @@ fn main() { // Box near (7, 2) for dx in 0..5 { for dy in 0..5 { - obstacle_pts.push(Point3D::new(7.0 + dx as f32 * 0.1, 2.0 + dy as f32 * 0.1, 0.0)); + obstacle_pts.push(Point3D::new( + 7.0 + dx as f32 * 0.1, + 2.0 + dy as f32 * 0.1, + 0.0, + )); } } @@ -92,7 +106,10 @@ fn main() { } // THINK: Feed percept to cognitive core - let nearest_dist = obstacles.first().map(|o| o.min_distance).unwrap_or(f64::MAX); + let nearest_dist = obstacles + .first() + .map(|o| o.min_distance) + .unwrap_or(f64::MAX); core.perceive(Percept { source: "perception_pipeline".into(), data: vec![robot_pos[0], robot_pos[1], nearest_dist], @@ -104,7 +121,11 @@ fn main() { if let Some(decision) = core.think() { decisions_made += 1; 
let _cmd = core.act(decision); - action_label = if nearest_dist < 2.0 { "avoid" } else { "patrol" }; + action_label = if nearest_dist < 2.0 { + "avoid" + } else { + "patrol" + }; } else { action_label = "idle"; } @@ -164,7 +185,10 @@ fn main() { println!(); println!("[A] Movement:"); - println!(" Final position: ({:.2}, {:.2})", robot_pos[0], robot_pos[1]); + println!( + " Final position: ({:.2}, {:.2})", + robot_pos[0], robot_pos[1] + ); println!(" Total distance: {:.2}m", total_distance); println!( " Avg speed: {:.3}m/step", @@ -185,7 +209,11 @@ fn main() { println!("[D] World Model:"); println!(" Tracked objects: {}", world.object_count()); - println!(" Grid size: {}x{}", world.grid_size(), world.grid_size()); + println!( + " Grid size: {}x{}", + world.grid_size(), + world.grid_size() + ); println!(); println!("[E] MCP tools available: {}", registry.list_tools().len());