diff --git a/crates/khive-pack-memory/Cargo.toml b/crates/khive-pack-memory/Cargo.toml index e1a60e7a..d01a040b 100644 --- a/crates/khive-pack-memory/Cargo.toml +++ b/crates/khive-pack-memory/Cargo.toml @@ -13,6 +13,7 @@ description = "Memory verb pack — remember/recall semantics with decay-aware r [dependencies] khive-types = { version = "0.2.2", path = "../khive-types", features = ["serde"] } khive-runtime = { version = "0.2.2", path = "../khive-runtime" } +khive-retrieval = { version = "0.2.2", path = "../khive-retrieval" } khive-pack-brain = { version = "0.2.2", path = "../khive-pack-brain" } inventory = { workspace = true } khive-storage = { version = "0.2.2", path = "../khive-storage" } diff --git a/crates/khive-pack-memory/src/handlers.rs b/crates/khive-pack-memory/src/handlers.rs index 6667a7f8..8346977f 100644 --- a/crates/khive-pack-memory/src/handlers.rs +++ b/crates/khive-pack-memory/src/handlers.rs @@ -4,8 +4,13 @@ use serde::Deserialize; use serde_json::{json, Value}; use uuid::Uuid; -use khive_runtime::fusion::fuse_with_strategy; -use khive_runtime::{NamespaceToken, RuntimeError, SearchHit, SearchSource, VerbRegistry}; +use khive_retrieval::{ + fuse_search_results, FusionStrategy as RetrievalFusionStrategy, HybridConfig, +}; +use khive_runtime::{ + FusionStrategy as RuntimeFusionStrategy, NamespaceToken, RuntimeError, SearchHit, SearchSource, + VerbRegistry, +}; use khive_storage::types::{ TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit, VectorSearchRequest, @@ -32,6 +37,19 @@ fn validate_memory_type(mt: &str) -> Result<(), RuntimeError> { } } +fn parse_fusion_strategy_str(s: &str) -> Result { + match s { + "rrf" => Ok(RuntimeFusionStrategy::Rrf { k: 60 }), + "weighted" => Ok(RuntimeFusionStrategy::Weighted { + weights: vec![0.3, 0.7], + }), + "union" => Ok(RuntimeFusionStrategy::Union), + other => Err(RuntimeError::InvalidInput(format!( + "invalid fusion_strategy {other:?}: must be one of \"rrf\", \"weighted\", 
\"union\"" + ))), + } +} + #[derive(Deserialize)] struct RememberParams { content: String, @@ -53,6 +71,9 @@ struct RecallParams { min_score: Option, min_salience: Option, config: Option, + top_k: Option, + fusion_strategy: Option, + score_floor: Option, } impl RecallParams { @@ -138,6 +159,49 @@ fn search_source_label(source: SearchSource) -> &'static str { } } +#[derive(Default)] +struct CandidateMeta { + in_text: bool, + in_vector: bool, + title: Option, + snippet: Option, +} + +fn to_retrieval_fusion_strategy(strategy: &RuntimeFusionStrategy) -> RetrievalFusionStrategy { + match strategy { + RuntimeFusionStrategy::Rrf { k } => RetrievalFusionStrategy::Rrf { k: *k }, + RuntimeFusionStrategy::Weighted { .. } => RetrievalFusionStrategy::Weighted { + weights: Vec::new(), + }, + RuntimeFusionStrategy::Union => RetrievalFusionStrategy::Union, + RuntimeFusionStrategy::VectorOnly => RetrievalFusionStrategy::VectorOnly, + } +} + +fn retrieval_hybrid_config(strategy: &RuntimeFusionStrategy, limit: usize) -> HybridConfig { + let mut config = HybridConfig::new(limit) + .with_pool_size(limit) + .with_fusion_strategy(to_retrieval_fusion_strategy(strategy)); + + if let RuntimeFusionStrategy::Weighted { weights } = strategy { + // Runtime weighted fusion uses [text, vector]. HybridConfig uses keyword/vector. + // Preserve arbitrary positive scales — do not clamp via with_weights(). 
+ config.keyword_weight = weights.first().copied().unwrap_or(0.0).max(0.0); + config.vector_weight = weights.get(1).copied().unwrap_or(0.0).max(0.0); + } + + config +} + +fn source_from_meta(meta: &CandidateMeta) -> SearchSource { + match (meta.in_vector, meta.in_text) { + (true, true) => SearchSource::Both, + (true, false) => SearchSource::Vector, + (false, true) => SearchSource::Text, + (false, false) => SearchSource::Text, + } +} + fn fuse_candidates( text_hits: Vec, vector_hits: Vec, @@ -145,15 +209,68 @@ fn fuse_candidates( cfg: &RecallConfig, limit: usize, ) -> Vec { - let text: Vec = text_hits + let mut meta = HashMap::::new(); + + let text_source: Vec<_> = text_hits .into_iter() .filter(|h| memory_ids.contains(&h.subject_id)) + .map(|h| { + let TextSearchHit { + subject_id, + score, + title, + snippet, + .. + } = h; + let entry = meta.entry(subject_id).or_default(); + entry.in_text = true; + if entry.title.is_none() { + entry.title = title; + } + if entry.snippet.is_none() { + entry.snippet = snippet; + } + (subject_id, score) + }) .collect(); - let vec: Vec = vector_hits + + let vector_source: Vec<_> = vector_hits .into_iter() .filter(|h| memory_ids.contains(&h.subject_id)) + .map(|h| { + let entry = meta.entry(h.subject_id).or_default(); + entry.in_vector = true; + (h.subject_id, h.score) + }) .collect(); - fuse_with_strategy(text, vec, &cfg.fuse_strategy, limit) + + let vector_only = matches!(&cfg.fuse_strategy, RuntimeFusionStrategy::VectorOnly); + let sources = if vector_only { + vec![vector_source] + } else { + // HybridConfig weighted convention: vector first, keyword second. 
+ vec![vector_source, text_source] + }; + + let retrieval_cfg = retrieval_hybrid_config(&cfg.fuse_strategy, limit); + fuse_search_results(sources, &retrieval_cfg) + .into_iter() + .map(|(id, score)| { + let m = meta.remove(&id).unwrap_or_default(); + let (source, title, snippet) = if vector_only { + (SearchSource::Vector, None, None) + } else { + (source_from_meta(&m), m.title, m.snippet) + }; + SearchHit { + entity_id: id, + score, + source, + title, + snippet, + } + }) + .collect() } impl MemoryPack { @@ -335,10 +452,35 @@ impl MemoryPack { validate_memory_type(mt)?; } - let cfg = p.effective_config(self.active_config()); + if let Some(ref fs) = p.fusion_strategy { + parse_fusion_strategy_str(fs)?; + } + + let mut cfg = p.effective_config(self.active_config()); + if let Some(ref fs) = p.fusion_strategy { + let mut new_strategy = parse_fusion_strategy_str(fs)?; + // "weighted" in the request means "use weighted fusion" — the actual + // weight values come from pack config, not the request (ADR-033 §6.1). 
+ if let ( + RuntimeFusionStrategy::Weighted { + weights: ref mut new_w, + }, + RuntimeFusionStrategy::Weighted { + weights: ref existing_w, + }, + ) = (&mut new_strategy, &cfg.fuse_strategy) + { + *new_w = existing_w.clone(); + } + cfg.fuse_strategy = new_strategy; + } cfg.validate()?; - let limit = p.limit.unwrap_or(10).min(100); + let limit = if let Some(k) = p.top_k { + u32::try_from(k.min(100)).unwrap_or(100) + } else { + p.limit.unwrap_or(10).min(100) + }; let candidate_limit = recall_candidate_count(&cfg, limit); let candidates = self .collect_recall_candidates(&p.query, token, candidate_limit) @@ -392,6 +534,11 @@ impl MemoryPack { if final_score < cfg.min_score { continue; } + if let Some(floor) = p.score_floor { + if final_score < floor as f64 { + continue; + } + } ranked.push((id, final_score, breakdown, note)); } @@ -661,6 +808,9 @@ mod tests { min_score: None, min_salience: None, config: None, + top_k: None, + fusion_strategy: None, + score_floor: None, }; let cfg = p.effective_config(RecallConfig::default()); assert!((cfg.relevance_weight - 0.70).abs() < 1e-12); @@ -677,6 +827,9 @@ mod tests { min_score: Some(0.5), min_salience: Some(0.3), config: None, + top_k: None, + fusion_strategy: None, + score_floor: None, }; let cfg = p.effective_config(RecallConfig::default()); assert!((cfg.min_score - 0.5).abs() < 1e-12); @@ -695,6 +848,9 @@ mod tests { relevance_weight: 0.50, ..RecallConfig::default() }), + top_k: None, + fusion_strategy: None, + score_floor: None, }; let cfg = p.effective_config(RecallConfig::default()); assert!((cfg.relevance_weight - 0.50).abs() < 1e-12); @@ -702,6 +858,139 @@ mod tests { assert!((cfg.min_score - 0.1).abs() < 1e-12); } + #[test] + fn test_weighted_strategy_preserves_pack_weights() { + use khive_runtime::FusionStrategy as RuntimeFusionStrategy; + + // Pack config has custom weighted weights [0.8, 0.2] + let base = RecallConfig { + fuse_strategy: RuntimeFusionStrategy::Weighted { + weights: vec![0.8, 0.2], + }, + 
..RecallConfig::default() + }; + + // Request overrides to "weighted" — must preserve [0.8, 0.2], not replace with [0.3, 0.7] + let p = RecallParams { + query: "test".to_string(), + limit: None, + memory_type: None, + min_score: None, + min_salience: None, + config: None, + top_k: None, + fusion_strategy: Some("weighted".to_string()), + score_floor: None, + }; + + let mut cfg = p.effective_config(base); + if let Some(ref fs) = p.fusion_strategy { + let mut new_strategy = parse_fusion_strategy_str(fs).unwrap(); + if let ( + RuntimeFusionStrategy::Weighted { + weights: ref mut new_w, + }, + RuntimeFusionStrategy::Weighted { + weights: ref existing_w, + }, + ) = (&mut new_strategy, &cfg.fuse_strategy) + { + *new_w = existing_w.clone(); + } + cfg.fuse_strategy = new_strategy; + } + + match cfg.fuse_strategy { + RuntimeFusionStrategy::Weighted { weights } => { + assert_eq!( + weights, + vec![0.8, 0.2], + "fusion_strategy=weighted must preserve pack weights [0.8, 0.2], not override with [0.3, 0.7]" + ); + } + other => panic!("expected Weighted strategy, got {other:?}"), + } + } + + #[test] + fn fusion_strategy_change_produces_observable_ordering_difference() { + // Codex Medium 2 (PR #406): prove the fusion_strategy knob actually + // affects fusion output, not just validation. Uses a deterministic fixture + // where rank-based (RRF) and score-based (Weighted) fusion must rank + // differently. 
+ use khive_runtime::FusionStrategy as RuntimeFusionStrategy; + use khive_storage::types::{TextSearchHit, VectorSearchHit}; + use std::collections::HashSet; + use uuid::Uuid; + + let id_a = Uuid::from_u128(0xAAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA); + let id_b = Uuid::from_u128(0xBBBB_BBBB_BBBB_BBBB_BBBB_BBBB_BBBB_BBBB); + let id_c = Uuid::from_u128(0xCCCC_CCCC_CCCC_CCCC_CCCC_CCCC_CCCC_CCCC); + + let text_hits = vec![ + TextSearchHit { + subject_id: id_a, + score: 0.9_f64.into(), + rank: 1, + title: None, + snippet: None, + }, + TextSearchHit { + subject_id: id_b, + score: 0.5_f64.into(), + rank: 2, + title: None, + snippet: None, + }, + ]; + let vector_hits = vec![ + VectorSearchHit { + subject_id: id_c, + score: 0.95_f64.into(), + rank: 1, + }, + VectorSearchHit { + subject_id: id_a, + score: 0.3_f64.into(), + rank: 2, + }, + ]; + let memory_ids: HashSet = [id_a, id_b, id_c].into_iter().collect(); + + let cfg_rrf = RecallConfig { + fuse_strategy: RuntimeFusionStrategy::Rrf { k: 60 }, + ..RecallConfig::default() + }; + let rrf_results = fuse_candidates( + text_hits.clone(), + vector_hits.clone(), + &memory_ids, + &cfg_rrf, + 10, + ); + let rrf_order: Vec = rrf_results.iter().map(|h| h.entity_id).collect(); + + let cfg_weighted = RecallConfig { + fuse_strategy: RuntimeFusionStrategy::Weighted { + weights: vec![0.1, 0.9], + }, + ..RecallConfig::default() + }; + let weighted_results = + fuse_candidates(text_hits, vector_hits, &memory_ids, &cfg_weighted, 10); + let weighted_order: Vec = weighted_results.iter().map(|h| h.entity_id).collect(); + + // RRF on this fixture: id_a in both sources gets highest combined rank score; + // id_c (vector rank 1) and id_b (text rank 2) tied around 0.0161-0.0164. + // Weighted [0.1, 0.9]: id_c dominates (0.95 * 0.9 = 0.855); id_a drops + // (0.9 * 0.1 + 0.3 * 0.9 = 0.36); id_b last (0.5 * 0.1 = 0.05). + // The orderings MUST differ — this is the discriminating assertion. 
+ assert_ne!( + rrf_order, weighted_order, + "fusion_strategy change must affect ordering; RRF and Weighted produced identical: {rrf_order:?}" + ); + } + #[test] fn compute_score_default_config_reproduces_legacy() { let cfg = RecallConfig::default(); diff --git a/crates/khive-pack-memory/tests/integration.rs b/crates/khive-pack-memory/tests/integration.rs index 946856c7..60b467b0 100644 --- a/crates/khive-pack-memory/tests/integration.rs +++ b/crates/khive-pack-memory/tests/integration.rs @@ -657,6 +657,123 @@ async fn test_recall_fuse_source_field_is_plain_string() { ); } +/// Verifies that recall.fuse routes through khive_retrieval::fuse_search_results +/// by injecting a non-default fusion config (Rrf k=1) and asserting the fused +/// score matches the RRF k=1 formula: 1/(k + rank) = 1/(1 + 1) = 0.5. +/// +/// Under default k=60 the score would be 1/61 ≈ 0.0164. The large gap (0.5 vs +/// 0.0164) is the discriminator: if the adapter did not pass k=1 through to +/// khive_retrieval::HybridConfig, the score would not be 0.5. +#[tokio::test] +async fn test_recall_fuse_rrf_k1_uses_retrieval_adapter() { + let rt = make_runtime(); + let registry = make_registry(rt); + + registry + .dispatch( + "remember", + json!({ "content": "retrieval adapter rrf k1 probe memory" }), + ) + .await + .expect("remember"); + + let result = registry + .dispatch( + "recall.fuse", + json!({ + "query": "retrieval adapter rrf k1 probe", + "config": { + "fuse_strategy": { "rrf": { "k": 1 } } + } + }), + ) + .await + .expect("recall.fuse with Rrf k=1"); + + let fused = result["fused_candidates"].as_array().expect("fused array"); + assert!( + !fused.is_empty(), + "recall.fuse must return at least one candidate" + ); + + let score = fused[0]["fused_score"] + .as_f64() + .expect("fused_score is f64"); + // Rank 1 in a single text source with k=1: RRF = 1/(1+1) = 0.5. + // If k=60 were used instead, score ≈ 0.0164 — the gap proves the adapter works. 
+ let expected = 0.5_f64; + assert!( + (score - expected).abs() < 1e-6, + "RRF k=1, rank 1 → fused_score must be 0.5; got {score:.6} \ + (≈0.0164 means the adapter passed k=60 instead of k=1)" + ); +} + +/// Regression: after wiring khive-retrieval into fuse_candidates, the recall.fuse +/// response shape must be unchanged — top-level strategy + candidate_limit, and +/// per-candidate note_id + fused_score + source must all be present. Full recall +/// fields (content, salience) must remain absent. +#[tokio::test] +async fn test_recall_fuse_shape_preserved_after_retrieval_wiring() { + let rt = make_runtime(); + let registry = make_registry(rt); + + registry + .dispatch( + "remember", + json!({ "content": "shape regression check after retrieval wiring" }), + ) + .await + .expect("remember"); + + let result = registry + .dispatch( + "recall.fuse", + json!({ "query": "shape regression retrieval wiring" }), + ) + .await + .expect("recall.fuse"); + + // Top-level shape + assert!( + result.get("strategy").is_some(), + "strategy field must be present in recall.fuse response" + ); + assert!( + result["candidate_limit"].as_u64().is_some(), + "candidate_limit must be a non-negative integer" + ); + + let fused = result["fused_candidates"] + .as_array() + .expect("fused_candidates array"); + assert!(!fused.is_empty(), "fused_candidates must be non-empty"); + + let c = &fused[0]; + assert!( + c["note_id"].as_str().is_some(), + "note_id must be a string UUID" + ); + assert!( + c["fused_score"].as_f64().is_some(), + "fused_score must be a float" + ); + let source = c["source"].as_str().expect("source must be a plain string"); + assert!( + matches!(source, "text" | "vector" | "both"), + "source must be a plain label, got {source:?}" + ); + // Full recall fields must not leak into fuse output + assert!( + c.get("content").is_none(), + "content must be absent from recall.fuse output" + ); + assert!( + c.get("salience").is_none(), + "salience must be absent from recall.fuse output" + 
); +} + /// When include_breakdown is true, breakdown.total() must equal the hit's composite score. #[tokio::test] async fn test_recall_breakdown_total_matches_composite_score() { @@ -884,3 +1001,249 @@ async fn test_pack_tunable_apply_config_affects_recall_score() { "under relevance_weight=1.0 with rrf=1.0 → score=1.0; got {total2}" ); } + +// ── ADR-033 §6 knob tests ────────────────────────────────────────────────── + +#[tokio::test] +async fn test_recall_default_identity() { + let rt = make_runtime(); + let registry = make_registry(rt.clone()); + + // Create multiple memories so the identity comparison is meaningful + // (single-hit fixtures can't distinguish ordering changes). + for content in [ + "the mitochondria is the powerhouse of the cell", + "ribosomes synthesize proteins in the cell", + "the nucleus contains the cell's DNA", + "lysosomes digest cellular waste in the cell", + ] { + registry + .dispatch("remember", json!({ "content": content, "importance": 0.8 })) + .await + .expect("remember succeeds"); + } + + // Baseline recall with no knobs — query a term present in all 4 memories + let base = registry + .dispatch("recall", json!({ "query": "cell" })) + .await + .expect("baseline recall succeeds"); + let base_hits = base.as_array().expect("array"); + assert!( + base_hits.len() >= 2, + "baseline must return at least two hits to make ordering meaningful, got {}", + base_hits.len() + ); + + // Same call with all three knobs explicitly set to null — must be byte-identical + let knobless = registry + .dispatch( + "recall", + json!({ + "query": "cell", + "top_k": null, + "fusion_strategy": null, + "score_floor": null, + }), + ) + .await + .expect("recall with all knobs null succeeds"); + let knobless_hits = knobless.as_array().expect("array"); + + assert_eq!( + base_hits.len(), + knobless_hits.len(), + "null knobs must not change result count" + ); + + // Full ordering identity: each hit's note_id AND fused_score must match + // position-by-position. 
This catches a regression where a null knob silently + // shifts the ranking or rescaling. + for (i, (b, k)) in base_hits.iter().zip(knobless_hits.iter()).enumerate() { + assert_eq!( + b["note_id"].as_str(), + k["note_id"].as_str(), + "null knobs altered note_id at position {i}" + ); + // Scores must round-trip; allow tiny float jitter + let bs = b["score"].as_f64().unwrap_or(0.0); + let ks = k["score"].as_f64().unwrap_or(0.0); + assert!( + (bs - ks).abs() < 1e-9, + "null knobs altered score at position {i}: baseline={bs} knobless={ks}" + ); + } +} + +#[tokio::test] +async fn test_recall_top_k_override() { + let rt = make_runtime(); + let registry = make_registry(rt.clone()); + + // Create several distinct memories to ensure the pool is large enough + for i in 0..5 { + registry + .dispatch( + "remember", + json!({ + "content": format!("rust ownership memory safety concept {i}"), + "importance": 0.7 + }), + ) + .await + .expect("remember succeeds"); + } + + // Recall with top_k=2 — must not return more than 2 results + let result = registry + .dispatch( + "recall", + json!({ "query": "rust ownership memory safety", "top_k": 2 }), + ) + .await + .expect("recall with top_k=2 succeeds"); + let hits = result.as_array().expect("array"); + assert!( + hits.len() <= 2, + "top_k=2 must return at most 2 results, got {}", + hits.len() + ); + + // top_k=1 must return at most 1 + let result1 = registry + .dispatch( + "recall", + json!({ "query": "rust ownership memory safety", "top_k": 1 }), + ) + .await + .expect("recall with top_k=1 succeeds"); + let hits1 = result1.as_array().expect("array"); + assert!( + hits1.len() <= 1, + "top_k=1 must return at most 1 result, got {}", + hits1.len() + ); +} + +#[tokio::test] +async fn test_recall_fusion_strategy_override() { + let rt = make_runtime(); + let registry = make_registry(rt.clone()); + + registry + .dispatch( + "remember", + json!({ + "content": "gradient descent optimization machine learning", + "importance": 0.8 + }), + ) + 
.await + .expect("remember succeeds"); + + // Each valid strategy must succeed and return an array + for strategy in &["rrf", "weighted", "union"] { + let result = registry + .dispatch( + "recall", + json!({ + "query": "gradient descent optimization", + "fusion_strategy": strategy + }), + ) + .await + .unwrap_or_else(|e| panic!("recall with fusion_strategy={strategy:?} failed: {e}")); + assert!( + result.is_array(), + "fusion_strategy={strategy:?} must return an array, got {result}" + ); + } + + // Invalid strategy must return an error + let err = registry + .dispatch( + "recall", + json!({ + "query": "gradient descent optimization", + "fusion_strategy": "bogus" + }), + ) + .await; + assert!(err.is_err(), "invalid fusion_strategy must return an error"); + let msg = err.unwrap_err().to_string(); + assert!( + msg.contains("rrf") && msg.contains("weighted") && msg.contains("union"), + "error message must list valid strategies, got: {msg}" + ); +} + +#[tokio::test] +async fn test_recall_score_floor() { + let rt = make_runtime(); + let registry = make_registry(rt.clone()); + + registry + .dispatch( + "remember", + json!({ + "content": "backpropagation neural network training algorithm", + "importance": 0.6 + }), + ) + .await + .expect("remember succeeds"); + + // Baseline: no floor — get result count + let base = registry + .dispatch( + "recall", + json!({ "query": "backpropagation neural network" }), + ) + .await + .expect("baseline recall succeeds"); + let base_count = base.as_array().expect("array").len(); + + // score_floor=0.99 must not return MORE results than baseline + let floored = registry + .dispatch( + "recall", + json!({ + "query": "backpropagation neural network", + "score_floor": 0.99 + }), + ) + .await + .expect("recall with score_floor=0.99 succeeds"); + let floored_hits = floored.as_array().expect("array"); + assert!( + floored_hits.len() <= base_count, + "score_floor=0.99 must return ≤ baseline count ({base_count}), got {}", + floored_hits.len() + ); 
+ + // All returned hits must have score >= 0.99 + for hit in floored_hits { + let score = hit["score"].as_f64().expect("score is a number"); + assert!( + score >= 0.99, + "score_floor=0.99: all returned scores must be ≥ 0.99, got {score}" + ); + } + + // score_floor=0.0 must behave same as no floor + let zero_floor = registry + .dispatch( + "recall", + json!({ + "query": "backpropagation neural network", + "score_floor": 0.0 + }), + ) + .await + .expect("recall with score_floor=0.0 succeeds"); + let zero_count = zero_floor.as_array().expect("array").len(); + assert_eq!( + zero_count, base_count, + "score_floor=0.0 must return same count as no floor" + ); +} diff --git a/crates/khive-retrieval/src/graph/tests.rs b/crates/khive-retrieval/src/graph/tests.rs index 639b3efd..92e3e936 100644 --- a/crates/khive-retrieval/src/graph/tests.rs +++ b/crates/khive-retrieval/src/graph/tests.rs @@ -1,6 +1,6 @@ //! Unit tests for graph traversal module. -use super::compat::{test_context, EntityRef, MockLinkStore}; +use super::compat::{test_context, EntityRef, LinkStore, MockLinkStore}; use crate::graph::types::{ Direction, PathNode, TraversalOptions, MAX_TRAVERSAL_DEPTH, MAX_TRAVERSAL_RESULTS, diff --git a/crates/khive-retrieval/src/persist/mod.rs b/crates/khive-retrieval/src/persist/mod.rs index 40d4e678..0893903a 100644 --- a/crates/khive-retrieval/src/persist/mod.rs +++ b/crates/khive-retrieval/src/persist/mod.rs @@ -26,7 +26,7 @@ //! //! ```rust,no_run //! use khive_retrieval::persist::RetrievalPersistence; -//! use khive_retrieval::hnsw::HnswIndex; +//! use khive_retrieval::HnswIndex; //! use rusqlite::Connection; //! use std::sync::Arc; //! 
use tokio::sync::Mutex; diff --git a/crates/khive-retrieval/src/persist/tests.rs b/crates/khive-retrieval/src/persist/tests.rs index 2efdf72d..88d6e84e 100644 --- a/crates/khive-retrieval/src/persist/tests.rs +++ b/crates/khive-retrieval/src/persist/tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::NodeId; use khive_bm25::Bm25Index; use khive_hnsw::HnswIndex; use rusqlite::Connection; diff --git a/crates/khive-retrieval/src/replay/engine_replay.rs b/crates/khive-retrieval/src/replay/engine_replay.rs index d25a85bb..45b8bbc2 100644 --- a/crates/khive-retrieval/src/replay/engine_replay.rs +++ b/crates/khive-retrieval/src/replay/engine_replay.rs @@ -844,11 +844,26 @@ pub mod metrics { #[cfg(test)] mod tests { use super::*; - use khive_db::SqliteStore; fn make_conn() -> Arc> { - let store = SqliteStore::memory().expect("in-memory store"); - store.conn() + let conn = Connection::open_in_memory().expect("open in-memory db"); + conn.execute_batch( + r#" + CREATE TABLE weight_events ( + namespace TEXT NOT NULL, + atom_id TEXT NOT NULL, + delta REAL NOT NULL, + weight_after REAL NOT NULL, + channel TEXT NOT NULL, + eta REAL NOT NULL, + event_id TEXT, + context_id TEXT, + ts INTEGER NOT NULL + ); + "#, + ) + .expect("init replay test schema"); + Arc::new(Mutex::new(conn)) } fn insert_weight_event( diff --git a/crates/khive-retrieval/src/weights/engine_weights.rs b/crates/khive-retrieval/src/weights/engine_weights.rs index 7530767c..0b47a7cc 100644 --- a/crates/khive-retrieval/src/weights/engine_weights.rs +++ b/crates/khive-retrieval/src/weights/engine_weights.rs @@ -298,14 +298,35 @@ pub async fn batch_load_weights( #[cfg(test)] mod tests { use super::*; - use khive_db::SqliteStore; use std::sync::Arc; fn make_conn() -> Arc> { - // Open an in-memory SQLite DB and run migrations so atom_weights and - // weight_events tables exist. 
- let store = SqliteStore::memory().expect("in-memory store"); - store.conn() + let conn = Connection::open_in_memory().expect("open in-memory db"); + conn.execute_batch( + r#" + CREATE TABLE atom_weights ( + namespace TEXT NOT NULL, + atom_id TEXT NOT NULL, + weight REAL NOT NULL, + updated_at INTEGER NOT NULL, + version INTEGER NOT NULL DEFAULT 1, + PRIMARY KEY(namespace, atom_id) + ); + CREATE TABLE weight_events ( + namespace TEXT NOT NULL, + atom_id TEXT NOT NULL, + delta REAL NOT NULL, + weight_after REAL NOT NULL, + channel TEXT NOT NULL, + eta REAL NOT NULL, + event_id TEXT, + context_id TEXT, + ts INTEGER NOT NULL + ); + "#, + ) + .expect("init weight test schema"); + Arc::new(Mutex::new(conn)) } // ------------------------------------------------------------------------- diff --git a/crates/khive-retrieval/tests/fusion_surface.rs b/crates/khive-retrieval/tests/fusion_surface.rs new file mode 100644 index 00000000..29ae15cf --- /dev/null +++ b/crates/khive-retrieval/tests/fusion_surface.rs @@ -0,0 +1,61 @@ +use khive_retrieval::{fuse_search_results, FusionStrategy, HybridConfig}; +use khive_score::DeterministicScore; + +#[test] +fn fuse_search_results_rrf_surface_matches_expected_order() { + // doc_b appears at rank 1 in both vector and keyword — must win under RRF k=60. 
+ let vector = vec![ + ("doc_b", DeterministicScore::from_f64(0.9)), + ("doc_a", DeterministicScore::from_f64(0.8)), + ]; + let keyword = vec![ + ("doc_b", DeterministicScore::from_f64(4.0)), + ("doc_c", DeterministicScore::from_f64(3.0)), + ]; + let config = HybridConfig::new(10) + .with_pool_size(10) + .with_fusion_strategy(FusionStrategy::Rrf { k: 60 }); + + let results = fuse_search_results(vec![vector, keyword], &config); + + assert!(!results.is_empty(), "fusion must return results"); + assert_eq!( + results[0].0, "doc_b", + "doc_b must rank first (appears in both sources)" + ); + + // RRF score for doc_b: 1/(1+60) + 1/(1+60) = 2/61 ≈ 0.03279 + let expected = 2.0 / 61.0; + let actual = results[0].1.to_f64(); + assert!( + (actual - expected).abs() < 1e-6, + "fused score = {actual}, expected ~{expected}" + ); +} + +#[test] +fn fuse_search_results_empty_sources_returns_empty() { + let config = HybridConfig::default(); + let results = fuse_search_results::<&str>(vec![], &config); + assert!(results.is_empty()); +} + +#[test] +fn fuse_search_results_single_source_truncates_to_top_k() { + let source: Vec<_> = (0..20) + .map(|i| { + ( + format!("doc_{i}"), + DeterministicScore::from_f64(1.0 - i as f64 * 0.01), + ) + }) + .collect(); + let config = HybridConfig::new(5); + let results = fuse_search_results(vec![source], &config); + assert_eq!( + results.len(), + 5, + "single-source result must be truncated to top_k=5" + ); + assert_eq!(results[0].0, "doc_0", "highest score must be first"); +} diff --git a/docs/adr/ADR-033-recall-pipeline.md b/docs/adr/ADR-033-recall-pipeline.md index 375856c0..e6075de8 100644 --- a/docs/adr/ADR-033-recall-pipeline.md +++ b/docs/adr/ADR-033-recall-pipeline.md @@ -277,6 +277,43 @@ document its Hoare triple: | **Program** | Stage 1 (`memory.recall_embed`): query → embedding via multi-engine fan-out. Stage 2 (`memory.recall_candidates`): broad recall from FTS5 + vector, `candidate_multiplier × limit` candidates per path. 
Stage 3 (`memory.recall_fuse`): apply `fusion_strategy` (default RRF) to produce fused hits. Stage 4 (`memory.recall_rerank`, ADR-042 §7): run all rerankers whose weight in `reranker_weights` is > 0; each writes its score to `candidate.rerank_scores[name]`. Stage 5 (`memory.recall_score`): apply `ComposePipeline` with `WeightedObjective` over the three base Objectives plus one `RerankerObjective` per active reranker. Stage 6 (select): truncate to `limit`; apply `budget` via `GreedySelector` if set. | | **Postcondition** | Output is a deterministic list of memory notes ordered by composite score, within `limit`. All returned notes are alive (not soft-deleted) and `kind = memory`. Score breakdown is available on request via `memory.recall_score`. | +### 6.1 Per-request knobs (ADR-033 §6 addendum) + +The `recall` verb accepts three optional per-request knobs that override the pack-level +`RecallConfig` for a single call. All knobs are optional; absent or `null` preserves the +current default behavior. + +| Parameter | Type | Default | Semantics | +| ----------------- | ---------------- | ---------------- | --------------------------------------------------------------------------------------------------------------------------------------- | +| `top_k` | `usize` \| null | `limit` or `10` | Maximum number of results to return. Overrides `limit` when set. Capped at `100`. | +| `fusion_strategy` | `string` \| null | `"rrf"` (k=60) | Fusion algorithm for candidate merging. Must be one of `"rrf"`, `"weighted"`, `"union"`. Returns an error for any other value. | +| `score_floor` | `f32` \| null | `0.0` (no floor) | Minimum composite score threshold applied after `compute_score`. Results below this floor are excluded. `0.0` or `null` = no filtering. | + +**`fusion_strategy` details:** + +- `"rrf"` — Reciprocal Rank Fusion with k=60 (default). Robust across query types. +- `"weighted"` — Weighted linear combination. 
+ Text/vector weights come from the pack-level + config (`RecallConfig.fuse_strategy`), not the request. The request cannot override weights. If the effective `RecallConfig.fuse_strategy` is not already weighted, the built-in default weights `[0.3, 0.7]` (text, vector) apply. +- `"union"` — Max-score per candidate ID. Inclusive but may surface low-quality text-only hits. + +**Example request DSL:** + +```json +{ + "query": "attention mechanism in transformers", + "top_k": 5, + "fusion_strategy": "union", + "score_floor": 0.3 +} +``` + +This returns at most 5 results, fused via union strategy, with composite score ≥ 0.3. + +**Interaction with `RecallConfig`:** Per-request knobs have higher precedence than `config` +and pack-level tuning. Resolution order: + +`top_k`/`fusion_strategy`/`score_floor` (request) > `config` object (per-call) > pack active config (tunable) > `RecallConfig::default()`. + ### 7. Calibration protocol To calibrate recall parameters for a deployment: