diff --git a/crates/Cargo.toml b/crates/Cargo.toml index ed4a0c18..7c8ff3c4 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -21,6 +21,7 @@ members = [ "khive-mcp", "khive-vcs", "kkernel", + "khive-retrieval", ] # khive-merge excluded — forward-deployed (ADR-043) but not yet compilable # against restructured khive-vcs. Will be re-added when ADR-043 integrates. diff --git a/crates/khive-fold/src/objective/registry.rs b/crates/khive-fold/src/objective/registry.rs new file mode 100644 index 00000000..4ce97815 --- /dev/null +++ b/crates/khive-fold/src/objective/registry.rs @@ -0,0 +1,275 @@ +//! Objective registry for dynamic dispatch. + +use std::collections::HashMap; +use std::sync::Arc; + +use parking_lot::RwLock; + +use crate::{Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection}; + +/// A type-erased objective wrapper. +pub struct RegisteredObjective { + /// Name of the objective + pub name: String, + /// Description + pub description: Option, + /// The objective implementation + objective: Box>, +} + +impl RegisteredObjective { + /// Create a new registered objective + pub fn new(name: impl Into, objective: Box>) -> Self { + Self { + name: name.into(), + description: None, + objective, + } + } + + /// Add a description + pub fn with_description(mut self, desc: impl Into) -> Self { + self.description = Some(desc.into()); + self + } + + /// Score a candidate + pub fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 { + self.objective.score(candidate, context) + } + + /// Select from candidates + pub fn select<'a>( + &self, + candidates: &'a [T], + context: &ObjectiveContext, + ) -> ObjectiveResult> { + self.objective.select(candidates, context) + } +} + +/// Registry of named objectives. 
+pub struct ObjectiveRegistry { + objectives: RwLock>>>, + default: RwLock>, +} + +impl Default for ObjectiveRegistry { + fn default() -> Self { + Self::new() + } +} + +impl ObjectiveRegistry { + /// Create a new empty registry + pub fn new() -> Self { + Self { + objectives: RwLock::new(HashMap::new()), + default: RwLock::new(None), + } + } + + /// Register an objective. + /// + /// Returns the previously registered objective if one existed with the same name. + pub fn register( + &self, + name: impl Into, + objective: Box>, + ) -> Option>> { + let name = name.into(); + let registered = Arc::new(RegisteredObjective::new(name.clone(), objective)); + + let mut objectives = self.objectives.write(); + objectives.insert(name, registered) + } + + /// Register an objective with description. + /// + /// Returns the previously registered objective if one existed with the same name. + pub fn register_with_desc( + &self, + name: impl Into, + description: impl Into, + objective: Box>, + ) -> Option>> { + let name = name.into(); + let registered = Arc::new( + RegisteredObjective::new(name.clone(), objective).with_description(description), + ); + + let mut objectives = self.objectives.write(); + objectives.insert(name, registered) + } + + /// Set the default objective + pub fn set_default(&self, name: impl Into) -> ObjectiveResult<()> { + let name = name.into(); + + let objectives = self.objectives.read(); + if !objectives.contains_key(&name) { + return Err(ObjectiveError::NotFound(name)); + } + drop(objectives); + + let mut default = self.default.write(); + *default = Some(name); + Ok(()) + } + + /// Get an objective by name + pub fn get(&self, name: &str) -> ObjectiveResult>> { + let objectives = self.objectives.read(); + objectives + .get(name) + .cloned() + .ok_or_else(|| ObjectiveError::NotFound(name.to_string())) + } + + /// Get the default objective + pub fn get_default(&self) -> ObjectiveResult>> { + let default = self.default.read(); + match default.as_ref() { + 
Some(name) => { + let name: String = name.clone(); + drop(default); + self.get(&name) + } + None => Err(ObjectiveError::NotFound("No default set".to_string())), + } + } + + /// List all registered objective names. + /// + /// Returns names in sorted order for deterministic output. + pub fn list(&self) -> Vec { + let objectives = self.objectives.read(); + let mut names: Vec = objectives.keys().cloned().collect(); + names.sort(); + names + } + + /// Check if an objective is registered + pub fn contains(&self, name: &str) -> bool { + let objectives = self.objectives.read(); + objectives.contains_key(name) + } + + /// Score using a named objective + pub fn score( + &self, + name: &str, + candidate: &T, + context: &ObjectiveContext, + ) -> ObjectiveResult { + let objective = self.get(name)?; + Ok(objective.score(candidate, context)) + } + + /// Select using a named objective + pub fn select<'a>( + &self, + name: &str, + candidates: &'a [T], + context: &ObjectiveContext, + ) -> ObjectiveResult> { + let objective = self.get(name)?; + objective.select(candidates, context) + } + + /// Select using the default objective + pub fn select_default<'a>( + &self, + candidates: &'a [T], + context: &ObjectiveContext, + ) -> ObjectiveResult> { + let objective = self.get_default()?; + objective.select(candidates, context) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::objective_fn; + + #[test] + fn test_register_and_get() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + let old = registry.register("max", Box::new(obj)); + + assert!(old.is_none()); + assert!(registry.contains("max")); + assert!(!registry.contains("min")); + } + + #[test] + fn test_register_overwrites() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as 
f64)); + + let old1 = registry.register("test", Box::new(obj1)); + assert!(old1.is_none()); + + let old2 = registry.register("test", Box::new(obj2)); + assert!(old2.is_some()); + + let candidates = vec![1, 5, 3]; + let selection = registry + .select("test", &candidates, &ObjectiveContext::new()) + .unwrap(); + assert_eq!(*selection.item, 1); + } + + #[test] + fn test_select_by_name() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + registry.register("max", Box::new(obj)); + + let candidates = vec![1, 5, 3]; + let selection = registry + .select("max", &candidates, &ObjectiveContext::new()) + .unwrap(); + + assert_eq!(*selection.item, 5); + } + + #[test] + fn test_default_objective() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + registry.register("max", Box::new(obj)); + registry.set_default("max").unwrap(); + + let candidates = vec![1, 5, 3]; + let selection = registry + .select_default(&candidates, &ObjectiveContext::new()) + .unwrap(); + + assert_eq!(*selection.item, 5); + } + + #[test] + fn test_list_objectives_sorted() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64)); + let obj3 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n as f64).abs()); + + registry.register("zebra", Box::new(obj1)); + registry.register("alpha", Box::new(obj2)); + registry.register("middle", Box::new(obj3)); + + let names = registry.list(); + assert_eq!(names.len(), 3); + assert_eq!(names, vec!["alpha", "middle", "zebra"]); + } +} diff --git a/crates/khive-retrieval/Cargo.toml b/crates/khive-retrieval/Cargo.toml new file mode 100644 index 00000000..19a761e2 --- /dev/null +++ b/crates/khive-retrieval/Cargo.toml @@ -0,0 +1,56 @@ +[package] 
+name = "khive-retrieval" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +description = "Hybrid retrieval composer (HNSW + BM25 + fusion + graph + cross-encoder) with deterministic scoring" + +[dependencies] +khive-hnsw = { version = "0.2.0", path = "../khive-hnsw" } +khive-bm25 = { version = "0.2.0", path = "../khive-bm25" } +khive-fusion = { version = "0.2.0", path = "../khive-fusion" } +khive-score = { version = "0.2.0", path = "../khive-score" } +khive-types = { version = "0.2.0", path = "../khive-types" } +khive-fold = { version = "0.2.0", path = "../khive-fold", optional = true } +khive-storage = { version = "0.2.0", path = "../khive-storage", optional = true } +khive-db = { version = "0.2.0", path = "../khive-db" } +khive-gate = { version = "0.2.0", path = "../khive-gate", optional = true } +lattice-embed = { workspace = true } + +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +parking_lot = { workspace = true } +async-trait = { workspace = true } +tokio = { workspace = true } +tokio-util = { version = "0.7", features = ["rt"] } +chrono = { workspace = true } +uuid = { workspace = true } +rusqlite = { version = "0.33", optional = true } +tracing = { workspace = true, optional = true } +rand = { version = "0.8", optional = true } + +[features] +default = [] +# Policy-based access control for search results (uses khive-gate API) +policy = ["khive-gate"] +# HNSW checkpoint integration with khive-fold +# Note: khive_hnsw::HnswCheckpoint/HnswCheckpointStore depend on khive_fold::Checkpoint +# which doesn't exist in the current khive-fold API. Those re-exports are gated out +# until the khive-fold Checkpoint trait is ported. 
+checkpoint = ["khive-fold"] +# SQLite-based persistence for HNSW and BM25 indexes +persist = ["rusqlite", "tracing", "rand"] +# Adapters bridging khive-storage backends (sqlite-vec, FTS5) to retrieval search traits +storage-adapters = ["khive-storage"] +# Native cross-encoder reranking (deferred until khive-inference is ported) +native-rerank = [] +# Native embedding service (delegated to lattice-embed; reserved for future feature-gating) +embed = [] +# Legacy graph traversal module (depends on old EntityRef/LinkStore API; not yet ported) +graph-legacy = [] diff --git a/crates/khive-retrieval/src/adapters/mod.rs b/crates/khive-retrieval/src/adapters/mod.rs new file mode 100644 index 00000000..479e0fd2 --- /dev/null +++ b/crates/khive-retrieval/src/adapters/mod.rs @@ -0,0 +1,456 @@ +//! Adapters bridging `khive-storage-traits` backends to retrieval search traits. +//! +//! The retrieval crate defines [`VectorSearch`] and [`KeywordSearch`] as async +//! traits with an associated `Id` type. The `khive-storage-traits` crate defines +//! [`VectorStore`] and [`TextSearch`] as async persistence traits using `Uuid`. +//! +//! This module provides adapter types that implement the retrieval search traits +//! by delegating to storage-traits backends: +//! +//! - [`StorageVectorSearch`]: wraps `Arc` -> `VectorSearch` +//! - [`StorageKeywordSearch`]: wraps `Arc` -> `KeywordSearch` +//! +//! This makes [`HybridSearcher`] work with persistent backends (sqlite-vec, FTS5) +//! alongside the existing in-memory backends (HNSW, BM25). +//! +//! # Example +//! +//! ```rust,ignore +//! use khive_db::StorageBackend; +//! use khive_retrieval::adapters::{StorageVectorSearch, StorageKeywordSearch}; +//! use khive_retrieval::hybrid::{VectorSearch, KeywordSearch}; +//! +//! let backend = StorageBackend::memory().unwrap(); +//! let vec_store = backend.vectors("model", 384).unwrap(); +//! let text_store = backend.text("docs").unwrap(); +//! +//! 
let vector_search = StorageVectorSearch::new(vec_store); +//! let keyword_search = StorageKeywordSearch::new(text_store); +//! +//! // Both implement the retrieval search traits with Id = Uuid +//! let hits = vector_search.vector_search(&query_embedding, 10).await?; +//! let kw_hits = keyword_search.keyword_search("some query", 10).await?; +//! ``` + +use std::sync::Arc; + +use async_trait::async_trait; +use khive_score::DeterministicScore; +use khive_storage::types::{TextQueryMode, TextSearchRequest, VectorSearchRequest}; +use khive_storage::{TextSearch, VectorStore}; +use uuid::Uuid; + +use crate::error::{Result, RetrievalError}; +use crate::hybrid::{KeywordSearch, VectorSearch}; + +// --------------------------------------------------------------------------- +// Error conversion +// --------------------------------------------------------------------------- + +/// Convert a `StorageError` into a `RetrievalError`. +/// +/// Maps storage-level errors to the closest retrieval error variant: +/// - Vector-related storage errors -> `Hnsw` (vector search context) +/// - Text-related storage errors -> `Bm25` (keyword search context) +/// - Timeout/pool errors -> transient retrieval errors +/// - Everything else -> generic error string +fn storage_err_to_retrieval( + err: khive_storage::StorageError, + context: &'static str, +) -> RetrievalError { + use khive_storage::StorageError; + + match &err { + StorageError::Timeout { .. } => { + // Map to a transient retrieval error + RetrievalError::Hnsw(format!("{context}: {err}")) + } + StorageError::InvalidInput { message, .. 
} => { + RetrievalError::InvalidQuery(format!("{context}: {message}")) + } + _ => { + // Generic mapping -- preserve the full error message + RetrievalError::Hnsw(format!("{context}: {err}")) + } + } +} + +// --------------------------------------------------------------------------- +// StorageVectorSearch +// --------------------------------------------------------------------------- + +/// Adapter implementing [`VectorSearch`] by delegating to a [`VectorStore`]. +/// +/// Wraps an `Arc` (e.g., `SqliteVecStore`) and implements +/// the retrieval `VectorSearch` trait with `Id = Uuid`. +/// +/// The adapter is `Send + Sync` and can be shared across tasks. +pub struct StorageVectorSearch { + store: Arc, +} + +impl StorageVectorSearch { + /// Create a new adapter wrapping the given vector store. + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +#[async_trait] +impl VectorSearch for StorageVectorSearch { + type Id = Uuid; + + async fn vector_search( + &self, + embedding: &[f32], + top_k: usize, + ) -> Result> { + let request = VectorSearchRequest { + query_embedding: embedding.to_vec(), + top_k: top_k as u32, + namespace: None, + kind: None, + }; + + let hits = self + .store + .search(request) + .await + .map_err(|e| storage_err_to_retrieval(e, "vector search"))?; + + Ok(hits + .into_iter() + .map(|hit| (hit.subject_id, hit.score)) + .collect()) + } +} + +// --------------------------------------------------------------------------- +// StorageKeywordSearch +// --------------------------------------------------------------------------- + +/// Adapter implementing [`KeywordSearch`] by delegating to a [`TextSearch`]. +/// +/// Wraps an `Arc` (e.g., `Fts5TextSearch`) and implements +/// the retrieval `KeywordSearch` trait with `Id = Uuid`. +/// +/// Uses `TextQueryMode::Plain` for keyword queries by default. The snippet +/// length is set to 0 since retrieval only needs IDs and scores. 
+pub struct StorageKeywordSearch { + search: Arc, +} + +impl StorageKeywordSearch { + /// Create a new adapter wrapping the given text search backend. + pub fn new(search: Arc) -> Self { + Self { search } + } +} + +#[async_trait] +impl KeywordSearch for StorageKeywordSearch { + type Id = Uuid; + + async fn keyword_search( + &self, + text: &str, + top_k: usize, + ) -> Result> { + let request = TextSearchRequest { + query: text.to_string(), + mode: TextQueryMode::Plain, + filter: None, + top_k: top_k as u32, + snippet_chars: 0, // retrieval only needs IDs + scores + }; + + let hits = self + .search + .search(request) + .await + .map_err(|e| storage_err_to_retrieval(e, "keyword search"))?; + + Ok(hits + .into_iter() + .map(|hit| (hit.subject_id, hit.score)) + .collect()) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use khive_db::StorageBackend; + use khive_storage::types::TextDocument; + use khive_types::SubstrateKind; + + /// Helper: create a memory-backed StorageBackend. 
+ fn test_backend() -> StorageBackend { + StorageBackend::memory().expect("memory backend") + } + + // ----------------------------------------------------------------------- + // StorageVectorSearch tests + // ----------------------------------------------------------------------- + + #[tokio::test] + async fn vector_search_basic_roundtrip() { + let backend = test_backend(); + let store = backend.vectors("test_vs", 3).unwrap(); + + // Insert two vectors + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + store + .insert(id1, SubstrateKind::Entity, "test", vec![1.0, 0.0, 0.0]) + .await + .unwrap(); + store + .insert(id2, SubstrateKind::Entity, "test", vec![0.0, 1.0, 0.0]) + .await + .unwrap(); + + // Wrap in adapter and use VectorSearch trait + let adapter = StorageVectorSearch::new(store); + let hits = adapter.vector_search(&[1.0, 0.0, 0.0], 2).await.unwrap(); + + assert_eq!(hits.len(), 2); + // Closest to [1,0,0] should be id1 + assert_eq!(hits[0].0, id1); + // Score should be high (cosine similarity ~1.0) + assert!(hits[0].1.to_f64() > 0.9); + } + + #[tokio::test] + async fn vector_search_respects_top_k() { + let backend = test_backend(); + let store = backend.vectors("test_topk", 3).unwrap(); + + // Insert 5 vectors + for _ in 0..5 { + store + .insert( + Uuid::new_v4(), + SubstrateKind::Entity, + "test", + vec![1.0, 0.0, 0.0], + ) + .await + .unwrap(); + } + + let adapter = StorageVectorSearch::new(store); + let hits = adapter.vector_search(&[1.0, 0.0, 0.0], 3).await.unwrap(); + + assert_eq!(hits.len(), 3); + } + + #[tokio::test] + async fn vector_search_empty_store() { + let backend = test_backend(); + let store = backend.vectors("test_empty", 3).unwrap(); + + let adapter = StorageVectorSearch::new(store); + let hits = adapter.vector_search(&[1.0, 0.0, 0.0], 5).await.unwrap(); + + assert!(hits.is_empty()); + } + + #[tokio::test] + async fn vector_search_returns_deterministic_scores() { + let backend = test_backend(); + let store = 
backend.vectors("test_det", 3).unwrap(); + + let id = Uuid::new_v4(); + store + .insert(id, SubstrateKind::Entity, "test", vec![1.0, 0.0, 0.0]) + .await + .unwrap(); + + let adapter = StorageVectorSearch::new(store); + + // Run twice -- scores must be identical (deterministic) + let hits1 = adapter.vector_search(&[1.0, 0.0, 0.0], 1).await.unwrap(); + let hits2 = adapter.vector_search(&[1.0, 0.0, 0.0], 1).await.unwrap(); + + assert_eq!(hits1[0].1, hits2[0].1); + } + + // ----------------------------------------------------------------------- + // StorageKeywordSearch tests + // ----------------------------------------------------------------------- + + #[tokio::test] + async fn keyword_search_basic_roundtrip() { + let backend = test_backend(); + let store = backend.text("test_ks").unwrap(); + + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + + store + .upsert_document(TextDocument { + subject_id: id1, + kind: SubstrateKind::Entity, + namespace: "test".to_string(), + title: Some("Rust Programming".to_string()), + body: "Rust is a systems programming language.".to_string(), + tags: vec![], + metadata: None, + updated_at: chrono::Utc::now(), + }) + .await + .unwrap(); + + store + .upsert_document(TextDocument { + subject_id: id2, + kind: SubstrateKind::Entity, + namespace: "test".to_string(), + title: Some("Python Guide".to_string()), + body: "Python is a high-level programming language.".to_string(), + tags: vec![], + metadata: None, + updated_at: chrono::Utc::now(), + }) + .await + .unwrap(); + + // Wrap in adapter and use KeywordSearch trait + let adapter = StorageKeywordSearch::new(store); + let hits = adapter.keyword_search("Rust", 10).await.unwrap(); + + // Should find the Rust document + assert!(!hits.is_empty()); + assert_eq!(hits[0].0, id1); + assert!(hits[0].1.to_f64() > 0.0); + } + + #[tokio::test] + async fn keyword_search_respects_top_k() { + let backend = test_backend(); + let store = backend.text("test_ks_topk").unwrap(); + + // Insert 5 
documents all containing "programming" + for i in 0..5 { + store + .upsert_document(TextDocument { + subject_id: Uuid::new_v4(), + kind: SubstrateKind::Note, + namespace: "test".to_string(), + title: Some(format!("Doc {}", i)), + body: format!("Programming topic number {}.", i), + tags: vec![], + metadata: None, + updated_at: chrono::Utc::now(), + }) + .await + .unwrap(); + } + + let adapter = StorageKeywordSearch::new(store); + let hits = adapter.keyword_search("programming", 3).await.unwrap(); + + assert!(hits.len() <= 3); + } + + #[tokio::test] + async fn keyword_search_empty_store() { + let backend = test_backend(); + let store = backend.text("test_ks_empty").unwrap(); + + let adapter = StorageKeywordSearch::new(store); + let hits = adapter.keyword_search("anything", 5).await.unwrap(); + + assert!(hits.is_empty()); + } + + #[tokio::test] + async fn keyword_search_no_match() { + let backend = test_backend(); + let store = backend.text("test_ks_nomatch").unwrap(); + + store + .upsert_document(TextDocument { + subject_id: Uuid::new_v4(), + kind: SubstrateKind::Entity, + namespace: "test".to_string(), + title: Some("Alpha".to_string()), + body: "Alpha article content.".to_string(), + tags: vec![], + metadata: None, + updated_at: chrono::Utc::now(), + }) + .await + .unwrap(); + + let adapter = StorageKeywordSearch::new(store); + let hits = adapter + .keyword_search("nonexistent_xyz_term", 5) + .await + .unwrap(); + + assert!(hits.is_empty()); + } + + // ----------------------------------------------------------------------- + // Integration: both adapters with fusion + // ----------------------------------------------------------------------- + + #[tokio::test] + async fn adapters_produce_fusible_results() { + use crate::hybrid::{fuse_search_results, HybridConfig}; + + let backend = test_backend(); + let vec_store = backend.vectors("test_fuse", 3).unwrap(); + let text_store = backend.text("test_fuse").unwrap(); + + let id = Uuid::new_v4(); + + // Insert into both 
stores + vec_store + .insert(id, SubstrateKind::Note, "test", vec![1.0, 0.0, 0.0]) + .await + .unwrap(); + text_store + .upsert_document(TextDocument { + subject_id: id, + kind: SubstrateKind::Note, + namespace: "test".to_string(), + title: Some("Test".to_string()), + body: "Test document for fusion.".to_string(), + tags: vec![], + metadata: None, + updated_at: chrono::Utc::now(), + }) + .await + .unwrap(); + + let vec_adapter = StorageVectorSearch::new(vec_store); + let kw_adapter = StorageKeywordSearch::new(text_store); + + let vec_hits = vec_adapter + .vector_search(&[1.0, 0.0, 0.0], 5) + .await + .unwrap(); + let kw_hits = kw_adapter.keyword_search("Test", 5).await.unwrap(); + + // Both should return the same UUID + assert!(!vec_hits.is_empty()); + assert!(!kw_hits.is_empty()); + assert_eq!(vec_hits[0].0, id); + assert_eq!(kw_hits[0].0, id); + + // Fuse the results -- same Id type (Uuid) means fusion works + let config = HybridConfig::new(10); + let fused = fuse_search_results(vec![vec_hits, kw_hits], &config); + + assert!(!fused.is_empty()); + // The single shared UUID should appear in fused results + assert_eq!(fused[0].0, id); + } +} diff --git a/crates/khive-retrieval/src/error.rs b/crates/khive-retrieval/src/error.rs new file mode 100644 index 00000000..d6f5b681 --- /dev/null +++ b/crates/khive-retrieval/src/error.rs @@ -0,0 +1,504 @@ +//! Error types for retrieval operations. +//! +//! Uses khive-db error patterns and integrates with EmbeddingError. +//! +//! # Error Classification (RETRIEVAL-06) +//! +//! Errors are classified into two categories for retry behavior: +//! +//! ## Transient Errors (retryable) +//! +//! These errors may succeed on retry and include: +//! - **Network errors**: Connection timeouts, temporary unavailability +//! - **Resource contention**: Lock conflicts, rate limiting +//! - **External service errors**: Embedding/link store temporary failures +//! +//! Recommended retry strategy: exponential backoff with jitter, max 3 retries. 
+//! +//! ## Permanent Errors (non-retryable) +//! +//! These errors indicate logic/data issues that won't be fixed by retry: +//! - **Validation errors**: Invalid query, dimension mismatch +//! - **Configuration errors**: Bad parameters, missing required fields +//! - **Data integrity errors**: Corrupt index, rebuild required +//! +//! These should be surfaced to the user immediately. +//! +//! # Usage +//! +//! ```rust +//! use khive_retrieval::error::RetrievalError; +//! +//! fn handle_error(err: RetrievalError) { +//! if err.is_transient() { +//! // Retry with backoff +//! println!("Retrying: {}", err); +//! } else { +//! // Surface to user immediately +//! eprintln!("Permanent error: {}", err); +//! } +//! } +//! ``` + +use thiserror::Error; + +/// Error classification for retry behavior. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorKind { + /// Transient error that may succeed on retry (network, contention). + Transient, + /// Permanent error that won't be fixed by retry (validation, config). + Permanent, +} + +/// Errors that can occur during retrieval operations. +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum RetrievalError { + /// Vector index operation failed. + #[error("hnsw error: {0}")] + Hnsw(String), + + /// BM25 index operation failed. + #[error("bm25 error: {0}")] + Bm25(String), + + /// Fusion operation failed. + #[error("fusion error: {0}")] + Fusion(String), + + /// Graph traversal failed. + #[error("graph traversal error: {0}")] + GraphTraversal(String), + + /// Invalid query parameters. + #[error("invalid query: {0}")] + InvalidQuery(String), + + /// Dimension mismatch. + #[error("dimension mismatch: expected {expected}, got {actual}")] + DimensionMismatch { + /// Expected dimensions. + expected: usize, + /// Actual dimensions. + actual: usize, + }, + + /// Configuration error. + #[error("configuration error: {0}")] + Configuration(String), + + /// Embedding store error. 
+ #[error("embedding store: {0}")] + EmbeddingStore(String), + + /// Link store error (for graph operations). + #[error("link store: {0}")] + LinkStore(String), + + /// Index not initialized. + #[error("index not initialized: {0}")] + IndexNotInitialized(String), + + /// Index rebuild required. + #[error("index rebuild required: {reason}")] + RebuildRequired { + /// Why rebuild is needed. + reason: String, + }, + + /// Query timed out before completing. + /// + /// The search operation exceeded the configured timeout duration. + /// This is a transient error: the query may succeed with a longer timeout + /// or fewer results requested. + #[error("query timed out after {elapsed_ms}ms")] + QueryTimeout { + /// Elapsed time in milliseconds before timeout. + elapsed_ms: u64, + }, + + /// Query was cancelled via cancellation token. + /// + /// The search operation was cancelled before completing. + /// This is a transient error: the query may succeed if not cancelled. + #[error("query cancelled")] + QueryCancelled, + + /// Memory budget exceeded. + /// + /// The insert operation would cause the index to exceed its configured + /// memory budget. This is a permanent error: the same insert will always + /// fail unless the budget is raised or existing data is removed. + #[error("memory budget exceeded: current {current_usage} + item {item_size} > limit {limit}")] + BudgetExceeded { + /// Current estimated memory usage in bytes. + current_usage: usize, + /// Estimated size of the item being inserted in bytes. + item_size: usize, + /// Configured memory budget in bytes. + limit: usize, + }, + + /// Reranking operation failed (permanent). + #[error("rerank error: {0}")] + Rerank(String), + // TODO(port-rerank): khive-inference not ported yet; re-enable when available. 
+ // #[cfg(feature = "native-rerank")] + // #[error("inference error: {0}")] + // Inference(#[from] khive_inference::InferenceError), +} + +impl RetrievalError { + /// Get the error classification (transient or permanent). + /// + /// This classification determines retry behavior: + /// - `Transient`: May succeed on retry (network, external services) + /// - `Permanent`: Won't be fixed by retry (validation, config, data) + /// + /// # Error Classification Table + /// + /// | Error Type | Classification | Reason | + /// |------------|---------------|--------| + /// | EmbeddingStore | Transient | External service, may recover | + /// | LinkStore | Transient | External service, may recover | + /// | Hnsw | Permanent | Index algorithm error | + /// | Bm25 | Permanent | Index algorithm error | + /// | Fusion | Permanent | Score combination error | + /// | GraphTraversal | Permanent | Graph algorithm error | + /// | InvalidQuery | Permanent | User input validation | + /// | DimensionMismatch | Permanent | Data incompatibility | + /// | Configuration | Permanent | Setup/config issue | + /// | IndexNotInitialized | Permanent | Missing prerequisite | + /// | RebuildRequired | Permanent | Data integrity issue | + /// | QueryTimeout | Transient | May succeed with longer timeout | + /// | QueryCancelled | Transient | May succeed if not cancelled | + /// | BudgetExceeded | Permanent | Capacity limit, won't auto-resolve | + pub fn kind(&self) -> ErrorKind { + match self { + // Transient: external services that may recover, timeouts, cancellations + RetrievalError::EmbeddingStore(_) + | RetrievalError::LinkStore(_) + | RetrievalError::QueryTimeout { .. } + | RetrievalError::QueryCancelled => ErrorKind::Transient, + + // Permanent: logic, validation, and configuration errors + RetrievalError::Hnsw(_) + | RetrievalError::Bm25(_) + | RetrievalError::Fusion(_) + | RetrievalError::GraphTraversal(_) + | RetrievalError::InvalidQuery(_) + | RetrievalError::DimensionMismatch { .. 
} + | RetrievalError::Configuration(_) + | RetrievalError::IndexNotInitialized(_) + | RetrievalError::RebuildRequired { .. } + | RetrievalError::BudgetExceeded { .. } + | RetrievalError::Rerank(_) => ErrorKind::Permanent, + // TODO(port-rerank): khive-inference not ported yet + // #[cfg(feature = "native-rerank")] + // RetrievalError::Inference(_) => ErrorKind::Permanent, + } + } + + /// Check if this error is transient (may succeed on retry). + /// + /// Transient errors include: + /// - External service failures (embedding store, link store) + /// - Network-related issues + /// - Resource contention + /// + /// # Retry Strategy + /// + /// For transient errors, use exponential backoff with jitter: + /// - Initial delay: 100ms + /// - Max delay: 5s + /// - Max retries: 3 + /// - Jitter: +/- 20% + /// + /// # Example + /// + /// ```rust + /// use khive_retrieval::error::RetrievalError; + /// + /// fn should_retry(err: &RetrievalError) -> bool { + /// err.is_transient() + /// } + /// ``` + #[inline] + pub fn is_transient(&self) -> bool { + self.kind() == ErrorKind::Transient + } + + /// Check if this error is permanent (won't be fixed by retry). + /// + /// Permanent errors should be surfaced to the user immediately + /// without retry attempts. + #[inline] + pub fn is_permanent(&self) -> bool { + self.kind() == ErrorKind::Permanent + } + + /// Check if this error is retryable (alias for `is_transient`). + /// + /// Provided for backward compatibility and semantic clarity. + #[inline] + pub fn is_retryable(&self) -> bool { + self.is_transient() + } + + /// Create a rerank error (permanent). + pub fn rerank(msg: impl Into) -> Self { + Self::Rerank(msg.into()) + } + + /// Create an HNSW error (permanent). + pub fn hnsw(msg: impl Into) -> Self { + Self::Hnsw(msg.into()) + } + + /// Create a BM25 error (permanent). + pub fn bm25(msg: impl Into) -> Self { + Self::Bm25(msg.into()) + } + + /// Create a fusion error (permanent). 
+ pub fn fusion(msg: impl Into) -> Self { + Self::Fusion(msg.into()) + } + + /// Create a graph traversal error (permanent). + pub fn graph_traversal(msg: impl Into) -> Self { + Self::GraphTraversal(msg.into()) + } + + /// Create an invalid query error (permanent). + pub fn invalid_query(msg: impl Into) -> Self { + Self::InvalidQuery(msg.into()) + } + + /// Create a dimension mismatch error (permanent). + pub fn dimension_mismatch(expected: usize, actual: usize) -> Self { + Self::DimensionMismatch { expected, actual } + } + + /// Create a configuration error (permanent). + pub fn configuration(msg: impl Into) -> Self { + Self::Configuration(msg.into()) + } + + /// Create an index not initialized error (permanent). + pub fn index_not_initialized(msg: impl Into) -> Self { + Self::IndexNotInitialized(msg.into()) + } + + /// Create a rebuild required error (permanent). + pub fn rebuild_required(reason: impl Into) -> Self { + Self::RebuildRequired { + reason: reason.into(), + } + } + + /// Create a query timeout error (transient). + pub fn query_timeout(elapsed_ms: u64) -> Self { + Self::QueryTimeout { elapsed_ms } + } + + /// Create a query cancelled error (transient). + pub fn query_cancelled() -> Self { + Self::QueryCancelled + } + + /// Create a budget exceeded error (permanent). + pub fn budget_exceeded(current_usage: usize, item_size: usize, limit: usize) -> Self { + Self::BudgetExceeded { + current_usage, + item_size, + limit, + } + } +} + +/// Result type alias for retrieval operations. 
+pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_display() { + let err = RetrievalError::hnsw("connection failed"); + assert_eq!(err.to_string(), "hnsw error: connection failed"); + } + + #[test] + fn test_dimension_mismatch() { + let err = RetrievalError::dimension_mismatch(768, 512); + assert_eq!(err.to_string(), "dimension mismatch: expected 768, got 512"); + } + + #[test] + fn test_is_retryable() { + // Non-retryable (permanent errors) + assert!(!RetrievalError::hnsw("fail").is_retryable()); + assert!(!RetrievalError::bm25("fail").is_retryable()); + assert!(!RetrievalError::InvalidQuery("bad".into()).is_retryable()); + assert!(!RetrievalError::dimension_mismatch(768, 512).is_retryable()); + } + + // RETRIEVAL-06: Comprehensive error classification tests + + #[test] + fn test_error_kind_transient() { + // EmbeddingStore and LinkStore are transient (external services) + // Note: We can't easily construct these without the actual error types, + // so we test via is_transient/is_permanent methods on constructable errors + } + + #[test] + fn test_error_kind_permanent_all_variants() { + // All internal errors should be permanent + let permanent_errors: Vec = vec![ + RetrievalError::hnsw("index corrupt"), + RetrievalError::bm25("tokenization failed"), + RetrievalError::fusion("incompatible scores"), + RetrievalError::graph_traversal("cycle detected"), + RetrievalError::invalid_query("empty query"), + RetrievalError::dimension_mismatch(768, 512), + RetrievalError::configuration("invalid k1 value"), + RetrievalError::index_not_initialized("HNSW index"), + RetrievalError::rebuild_required("version mismatch"), + RetrievalError::budget_exceeded(1000, 500, 1200), + ]; + + for err in permanent_errors { + assert!(err.is_permanent(), "Expected permanent: {err:?}"); + assert!(!err.is_transient(), "Should not be transient: {err:?}"); + assert_eq!( + err.kind(), + ErrorKind::Permanent, + "Kind mismatch for: {err:?}" 
+ ); + } + } + + #[test] + fn test_is_transient_is_permanent_consistency() { + // is_transient and is_permanent should be mutually exclusive and exhaustive + let test_errors: Vec = vec![ + RetrievalError::hnsw("test"), + RetrievalError::bm25("test"), + RetrievalError::fusion("test"), + RetrievalError::invalid_query("test"), + RetrievalError::dimension_mismatch(1, 2), + RetrievalError::configuration("test"), + RetrievalError::budget_exceeded(100, 50, 120), + ]; + + for err in test_errors { + let transient = err.is_transient(); + let permanent = err.is_permanent(); + + // XOR: exactly one should be true + assert!( + transient ^ permanent, + "Error must be exactly transient OR permanent: {err:?} (transient={transient}, permanent={permanent})" + ); + + // is_retryable should match is_transient + assert_eq!( + err.is_retryable(), + err.is_transient(), + "is_retryable should equal is_transient for: {err:?}" + ); + } + } + + #[test] + fn test_error_constructors_produce_correct_messages() { + assert_eq!(RetrievalError::hnsw("test").to_string(), "hnsw error: test"); + assert_eq!(RetrievalError::bm25("test").to_string(), "bm25 error: test"); + assert_eq!( + RetrievalError::fusion("test").to_string(), + "fusion error: test" + ); + assert_eq!( + RetrievalError::graph_traversal("test").to_string(), + "graph traversal error: test" + ); + assert_eq!( + RetrievalError::invalid_query("test").to_string(), + "invalid query: test" + ); + assert_eq!( + RetrievalError::configuration("test").to_string(), + "configuration error: test" + ); + assert_eq!( + RetrievalError::index_not_initialized("test").to_string(), + "index not initialized: test" + ); + assert_eq!( + RetrievalError::rebuild_required("test").to_string(), + "index rebuild required: test" + ); + assert_eq!( + RetrievalError::budget_exceeded(100, 50, 120).to_string(), + "memory budget exceeded: current 100 + item 50 > limit 120" + ); + } + + #[test] + fn test_error_kind_enum_debug() { + // Verify ErrorKind is Debug-able + 
assert_eq!(format!("{:?}", ErrorKind::Transient), "Transient"); + assert_eq!(format!("{:?}", ErrorKind::Permanent), "Permanent"); + } + + #[test] + fn test_error_kind_equality() { + // Verify ErrorKind implements PartialEq correctly + assert_eq!(ErrorKind::Transient, ErrorKind::Transient); + assert_eq!(ErrorKind::Permanent, ErrorKind::Permanent); + assert_ne!(ErrorKind::Transient, ErrorKind::Permanent); + } + + #[test] + fn test_query_timeout_error() { + let err = RetrievalError::query_timeout(5000); + assert_eq!(err.to_string(), "query timed out after 5000ms"); + assert!(err.is_transient()); + assert!(!err.is_permanent()); + assert!(err.is_retryable()); + assert_eq!(err.kind(), ErrorKind::Transient); + } + + #[test] + fn test_query_cancelled_error() { + let err = RetrievalError::query_cancelled(); + assert_eq!(err.to_string(), "query cancelled"); + assert!(err.is_transient()); + assert!(!err.is_permanent()); + assert!(err.is_retryable()); + assert_eq!(err.kind(), ErrorKind::Transient); + } + + #[test] + fn test_transient_errors_classification() { + // All transient errors should be classified correctly + let transient_errors: Vec = vec![ + RetrievalError::query_timeout(100), + RetrievalError::query_cancelled(), + ]; + + for err in transient_errors { + assert!(err.is_transient(), "Expected transient: {err:?}"); + assert!(!err.is_permanent(), "Should not be permanent: {err:?}"); + assert_eq!( + err.kind(), + ErrorKind::Transient, + "Kind mismatch for: {err:?}" + ); + } + } +} diff --git a/crates/khive-retrieval/src/eval/engine_eval.rs b/crates/khive-retrieval/src/eval/engine_eval.rs new file mode 100644 index 00000000..a6021b21 --- /dev/null +++ b/crates/khive-retrieval/src/eval/engine_eval.rs @@ -0,0 +1,658 @@ +//! Retrieval evaluation types and metrics for the khive compose pipeline. +//! +//! Provides the label taxonomy, graded scoring, and standard information-retrieval +//! metrics needed to measure compose quality against annotated benchmarks. +//! +//! 
# Design +//! +//! Labels follow a 5-level taxonomy (`Decisive` → `AdjacentWrong`) modelled on +//! GPQA-style relevance judgements where topically adjacent but factually wrong +//! sections are explicitly penalised. The `gain` scoring function drives nDCG and +//! `net_evidence` metrics. +//! +//! All metric functions operate on a slice of [`LabeledResult`] in **ranked order** +//! (index 0 = rank 1). Callers are responsible for pre-sorting. + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +// --------------------------------------------------------------------------- +// Label taxonomy +// --------------------------------------------------------------------------- + +/// Five-level relevance label for retrieved sections. +/// +/// Labels are designed for GPQA-style evaluation where *topically adjacent but +/// factually wrong* sections are more harmful than irrelevant ones — they can +/// actively mislead an LLM agent that trusts retrieved context. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RetrievalLabel { + /// Section directly answers or enables answering the query. + Decisive, + /// Section provides useful supporting evidence for the query. + Supporting, + /// Section provides general context but not specific evidence. + Background, + /// Section has no relationship to the query. + Irrelevant, + /// Section is on-topic but contains factually incorrect information that + /// would mislead an LLM agent (the "GPQA failure mode"). + AdjacentWrong, +} + +impl RetrievalLabel { + /// Graded relevance gain used in DCG / net-evidence calculations. + /// + /// `AdjacentWrong` carries a negative gain to penalise retrieval of + /// misleading but plausible-sounding sections. 
+ pub fn gain(self) -> f64 { + match self { + Self::Decisive => 3.0, + Self::Supporting => 2.0, + Self::Background => 0.5, + Self::Irrelevant => 0.0, + Self::AdjacentWrong => -2.0, + } + } + + /// Returns `true` for labels that count as "relevant" in binary recall/precision. + pub fn is_relevant(self) -> bool { + matches!(self, Self::Decisive | Self::Supporting) + } + + /// Returns `true` for labels that count as active distractors. + pub fn is_distractor(self) -> bool { + matches!(self, Self::AdjacentWrong) + } +} + +// --------------------------------------------------------------------------- +// Result type +// --------------------------------------------------------------------------- + +/// A single retrieved section with its ground-truth relevance label. +/// +/// The slice passed to metric functions must be ordered by descending score +/// (rank 1 at index 0). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LabeledResult { + /// Unique identifier of the retrieved section. + pub section_id: Uuid, + /// Retrieval score (higher = more relevant according to the pipeline). + pub score: f64, + /// Ground-truth relevance label assigned by a human or eval pipeline. + pub label: RetrievalLabel, +} + +// --------------------------------------------------------------------------- +// Aggregate metrics struct +// --------------------------------------------------------------------------- + +/// All standard retrieval metrics computed for a single query. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetrievalMetrics { + /// Recall at multiple k values: `[(k, recall_k), ...]`. + pub recall_at_k: Vec<(usize, f64)>, + /// nDCG at k = 10 using graded gains from [`RetrievalLabel::gain`]. + pub ndcg_at_10: f64, + /// Precision at k = 5 (fraction of top-5 that are relevant). + pub precision_at_5: f64, + /// Precision at k = 10 (fraction of top-10 that are relevant). 
+ pub precision_at_10: f64, + /// Fraction of top-10 results that are `AdjacentWrong` distractors. + pub distractor_at_10: f64, + /// Net graded evidence at k = 10: `sum(gain_i / log2(i+1))` for i in 1..=10. + pub net_evidence_at_10: f64, + /// Mean reciprocal rank: `1 / rank` of the first `Decisive` result, or `0.0`. + pub mrr: f64, + /// Optional before/after flip ratio: `wrong→right / right→wrong`. + /// + /// `None` when only a single ranking is available (no before/after pair). + pub flip_ratio: Option, +} + +// --------------------------------------------------------------------------- +// Metric functions +// --------------------------------------------------------------------------- + +/// Recall at k: fraction of all `Decisive | Supporting` items that appear in top-k. +/// +/// Returns `1.0` when there are no relevant items in the full list (vacuously true). +pub fn recall_at_k(results: &[LabeledResult], k: usize) -> f64 { + let total_relevant: usize = results.iter().filter(|r| r.label.is_relevant()).count(); + if total_relevant == 0 { + return 1.0; + } + let k = k.min(results.len()); + let found: usize = results[..k] + .iter() + .filter(|r| r.label.is_relevant()) + .count(); + found as f64 / total_relevant as f64 +} + +/// Precision at k: fraction of top-k results that are `Decisive | Supporting`. +/// +/// Returns `0.0` when k = 0 or results is empty. +pub fn precision_at_k(results: &[LabeledResult], k: usize) -> f64 { + if k == 0 || results.is_empty() { + return 0.0; + } + let k = k.min(results.len()); + let relevant: usize = results[..k] + .iter() + .filter(|r| r.label.is_relevant()) + .count(); + relevant as f64 / k as f64 +} + +/// Distractor at k: fraction of top-k results that are `AdjacentWrong`. +/// +/// Returns `0.0` when k = 0 or results is empty. 
+pub fn distractor_at_k(results: &[LabeledResult], k: usize) -> f64 { + if k == 0 || results.is_empty() { + return 0.0; + } + let k = k.min(results.len()); + let distractors: usize = results[..k] + .iter() + .filter(|r| r.label.is_distractor()) + .count(); + distractors as f64 / k as f64 +} + +/// Net evidence at k: `sum(gain(label_i) / log2(i+2))` for i in 0..k. +/// +/// The discount denominator uses `log2(i+2)` so that rank-1 (i=0) gets +/// `log2(2) = 1.0` — the standard DCG convention. +/// +/// Returns `0.0` when k = 0 or results is empty. +pub fn net_evidence_at_k(results: &[LabeledResult], k: usize) -> f64 { + if k == 0 || results.is_empty() { + return 0.0; + } + let k = k.min(results.len()); + results[..k] + .iter() + .enumerate() + .map(|(i, r)| r.label.gain() / (i as f64 + 2.0).log2()) + .sum() +} + +/// nDCG at k using graded gains from [`RetrievalLabel::gain`]. +/// +/// The ideal ranking places all `Decisive` results first, then `Supporting`, +/// `Background`, `Irrelevant`, and finally `AdjacentWrong`. The ideal DCG is +/// computed from a sorted-by-gain copy of the full result list. +/// +/// Returns `1.0` when the ideal DCG is zero (no positive-gain items). +pub fn ndcg_at_k(results: &[LabeledResult], k: usize) -> f64 { + if k == 0 || results.is_empty() { + return 0.0; + } + let k = k.min(results.len()); + + let dcg = results[..k] + .iter() + .enumerate() + .map(|(i, r)| r.label.gain() / (i as f64 + 2.0).log2()) + .sum::(); + + // Ideal DCG: sort all results by gain descending, take top-k. + let mut gains: Vec = results.iter().map(|r| r.label.gain()).collect(); + gains.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); + + let idcg = gains[..k] + .iter() + .enumerate() + .map(|(i, &g)| g / (i as f64 + 2.0).log2()) + .sum::(); + + if idcg == 0.0 { + // No positive-gain items exist and no negative-gain items; vacuously perfect. 
+ return 1.0; + } + if idcg < 0.0 { + // Only negative-gain items (all distractors); worst possible outcome. + return 0.0; + } + (dcg / idcg).clamp(0.0, 1.0) +} + +/// Mean reciprocal rank: `1.0 / rank` of the first `Decisive` result. +/// +/// Returns `0.0` if no `Decisive` result appears in the list. +pub fn mrr(results: &[LabeledResult]) -> f64 { + for (i, r) in results.iter().enumerate() { + if r.label == RetrievalLabel::Decisive { + return 1.0 / (i as f64 + 1.0); + } + } + 0.0 +} + +/// Compute all standard retrieval metrics at their canonical k values. +/// +/// `recall_at_k` is evaluated at k ∈ {1, 3, 5, 10}. +/// All other metrics use their k = 10 (or full-list for MRR) defaults. +pub fn compute_all(results: &[LabeledResult]) -> RetrievalMetrics { + let recall_at_k_vals = vec![ + (1, recall_at_k(results, 1)), + (3, recall_at_k(results, 3)), + (5, recall_at_k(results, 5)), + (10, recall_at_k(results, 10)), + ]; + + RetrievalMetrics { + recall_at_k: recall_at_k_vals, + ndcg_at_10: ndcg_at_k(results, 10), + precision_at_5: precision_at_k(results, 5), + precision_at_10: precision_at_k(results, 10), + distractor_at_10: distractor_at_k(results, 10), + net_evidence_at_10: net_evidence_at_k(results, 10), + mrr: mrr(results), + flip_ratio: None, + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + // ---- helpers ---- + + fn uuid(n: u64) -> Uuid { + Uuid::from_u64_pair(0, n) + } + + fn make_result(n: u64, label: RetrievalLabel) -> LabeledResult { + LabeledResult { + section_id: uuid(n), + score: 1.0 / (n as f64 + 1.0), + label, + } + } + + // ---- RetrievalLabel ---- + + #[test] + fn label_gain_values() { + assert_eq!(RetrievalLabel::Decisive.gain(), 3.0); + assert_eq!(RetrievalLabel::Supporting.gain(), 2.0); + assert_eq!(RetrievalLabel::Background.gain(), 0.5); + 
assert_eq!(RetrievalLabel::Irrelevant.gain(), 0.0); + assert_eq!(RetrievalLabel::AdjacentWrong.gain(), -2.0); + } + + #[test] + fn label_is_relevant() { + assert!(RetrievalLabel::Decisive.is_relevant()); + assert!(RetrievalLabel::Supporting.is_relevant()); + assert!(!RetrievalLabel::Background.is_relevant()); + assert!(!RetrievalLabel::Irrelevant.is_relevant()); + assert!(!RetrievalLabel::AdjacentWrong.is_relevant()); + } + + #[test] + fn label_is_distractor() { + assert!(RetrievalLabel::AdjacentWrong.is_distractor()); + assert!(!RetrievalLabel::Decisive.is_distractor()); + assert!(!RetrievalLabel::Irrelevant.is_distractor()); + } + + // ---- recall_at_k ---- + + #[test] + fn recall_at_k_all_relevant() { + // 3 decisive results, k = 3 → recall = 1.0 + let results: Vec = (0..3) + .map(|i| make_result(i, RetrievalLabel::Decisive)) + .collect(); + assert!((recall_at_k(&results, 3) - 1.0).abs() < 1e-9); + } + + #[test] + fn recall_at_k_partial() { + // 2 decisive at positions 0,1; 2 irrelevant at 2,3 + let results = vec![ + make_result(0, RetrievalLabel::Decisive), + make_result(1, RetrievalLabel::Decisive), + make_result(2, RetrievalLabel::Irrelevant), + make_result(3, RetrievalLabel::Irrelevant), + ]; + // k=1: 1 of 2 decisive in top-1 → 0.5 + assert!((recall_at_k(&results, 1) - 0.5).abs() < 1e-9); + // k=2: 2 of 2 decisive in top-2 → 1.0 + assert!((recall_at_k(&results, 2) - 1.0).abs() < 1e-9); + } + + #[test] + fn recall_at_k_none_relevant_vacuously_one() { + let results = vec![ + make_result(0, RetrievalLabel::Irrelevant), + make_result(1, RetrievalLabel::Background), + ]; + assert!((recall_at_k(&results, 5) - 1.0).abs() < 1e-9); + } + + #[test] + fn recall_at_k_k_exceeds_length() { + let results = vec![make_result(0, RetrievalLabel::Decisive)]; + // k=100 should be clamped to len=1 + assert!((recall_at_k(&results, 100) - 1.0).abs() < 1e-9); + } + + // ---- precision_at_k ---- + + #[test] + fn precision_at_k_perfect() { + let results: Vec = (0..5) + .map(|i| 
make_result(i, RetrievalLabel::Decisive)) + .collect(); + assert!((precision_at_k(&results, 5) - 1.0).abs() < 1e-9); + } + + #[test] + fn precision_at_k_half_relevant() { + let results = vec![ + make_result(0, RetrievalLabel::Decisive), + make_result(1, RetrievalLabel::Irrelevant), + make_result(2, RetrievalLabel::Supporting), + make_result(3, RetrievalLabel::Irrelevant), + ]; + // top-4: 2 relevant → 0.5 + assert!((precision_at_k(&results, 4) - 0.5).abs() < 1e-9); + } + + #[test] + fn precision_at_k_zero_when_k_zero() { + let results = vec![make_result(0, RetrievalLabel::Decisive)]; + assert_eq!(precision_at_k(&results, 0), 0.0); + } + + #[test] + fn precision_at_k_zero_when_empty() { + assert_eq!(precision_at_k(&[], 5), 0.0); + } + + #[test] + fn precision_at_k_adjacent_wrong_not_counted() { + let results = vec![ + make_result(0, RetrievalLabel::AdjacentWrong), + make_result(1, RetrievalLabel::AdjacentWrong), + ]; + assert_eq!(precision_at_k(&results, 2), 0.0); + } + + // ---- distractor_at_k ---- + + #[test] + fn distractor_at_k_all_wrong() { + let results: Vec = (0..4) + .map(|i| make_result(i, RetrievalLabel::AdjacentWrong)) + .collect(); + assert!((distractor_at_k(&results, 4) - 1.0).abs() < 1e-9); + } + + #[test] + fn distractor_at_k_none_wrong() { + let results = vec![ + make_result(0, RetrievalLabel::Decisive), + make_result(1, RetrievalLabel::Irrelevant), + ]; + assert_eq!(distractor_at_k(&results, 2), 0.0); + } + + #[test] + fn distractor_at_k_mixed() { + // 1 wrong in top-4 → 0.25 + let results = vec![ + make_result(0, RetrievalLabel::Decisive), + make_result(1, RetrievalLabel::AdjacentWrong), + make_result(2, RetrievalLabel::Irrelevant), + make_result(3, RetrievalLabel::Background), + ]; + assert!((distractor_at_k(&results, 4) - 0.25).abs() < 1e-9); + } + + #[test] + fn distractor_at_k_zero_when_k_zero() { + let results = vec![make_result(0, RetrievalLabel::AdjacentWrong)]; + assert_eq!(distractor_at_k(&results, 0), 0.0); + } + + // ---- 
net_evidence_at_k ---- + + #[test] + fn net_evidence_at_k_single_decisive_rank1() { + // Rank-1 Decisive: gain=3.0 / log2(2)=1.0 → 3.0 + let results = vec![make_result(0, RetrievalLabel::Decisive)]; + assert!((net_evidence_at_k(&results, 1) - 3.0).abs() < 1e-9); + } + + #[test] + fn net_evidence_at_k_negative_for_all_wrong() { + // Each AdjacentWrong at rank i contributes -2.0 / log2(i+2) + let results: Vec = (0..3) + .map(|i| make_result(i as u64, RetrievalLabel::AdjacentWrong)) + .collect(); + let score = net_evidence_at_k(&results, 3); + assert!( + score < 0.0, + "all distractors should produce negative net evidence" + ); + } + + #[test] + fn net_evidence_at_k_zero_for_empty() { + assert_eq!(net_evidence_at_k(&[], 5), 0.0); + } + + #[test] + fn net_evidence_at_k_zero_for_k_zero() { + let results = vec![make_result(0, RetrievalLabel::Decisive)]; + assert_eq!(net_evidence_at_k(&results, 0), 0.0); + } + + #[test] + fn net_evidence_at_k_mixed_sums_correctly() { + // rank1=Decisive(3.0/log2(2)=3.0), rank2=Supporting(2.0/log2(3)≈1.2619) + let results = vec![ + make_result(0, RetrievalLabel::Decisive), + make_result(1, RetrievalLabel::Supporting), + ]; + let expected = 3.0 / 2.0_f64.log2() + 2.0 / 3.0_f64.log2(); + let actual = net_evidence_at_k(&results, 2); + assert!( + (actual - expected).abs() < 1e-9, + "expected {expected}, got {actual}" + ); + } + + // ---- ndcg_at_k ---- + + #[test] + fn ndcg_at_k_perfect_ranking() { + // Perfect ranking: Decisive first → nDCG = 1.0 + let results = vec![ + make_result(0, RetrievalLabel::Decisive), + make_result(1, RetrievalLabel::Supporting), + make_result(2, RetrievalLabel::Irrelevant), + ]; + let score = ndcg_at_k(&results, 3); + assert!( + (score - 1.0).abs() < 1e-9, + "perfect ranking should yield nDCG=1.0, got {score}" + ); + } + + #[test] + fn ndcg_at_k_suboptimal_ranking() { + // Irrelevant first, Decisive second → nDCG < 1.0 + let results = vec![ + make_result(0, RetrievalLabel::Irrelevant), + make_result(1, 
RetrievalLabel::Decisive), + ]; + let score = ndcg_at_k(&results, 2); + assert!( + score < 1.0 && score > 0.0, + "suboptimal ranking should yield 0 < nDCG < 1.0, got {score}" + ); + } + + #[test] + fn ndcg_at_k_all_irrelevant_vacuously_one() { + // No positive-gain items → vacuously 1.0 (idcg = 0) + let results = vec![ + make_result(0, RetrievalLabel::Irrelevant), + make_result(1, RetrievalLabel::Irrelevant), + ]; + let score = ndcg_at_k(&results, 2); + assert!((score - 1.0).abs() < 1e-9); + } + + #[test] + fn ndcg_at_k_zero_for_zero_k() { + let results = vec![make_result(0, RetrievalLabel::Decisive)]; + assert_eq!(ndcg_at_k(&results, 0), 0.0); + } + + #[test] + fn ndcg_at_k_clamped_not_above_one() { + // Construct a case that could produce DCG > IDCG due to floating-point; + // verify clamp keeps result ≤ 1.0. + let results: Vec = (0..10) + .map(|i| make_result(i, RetrievalLabel::Decisive)) + .collect(); + let score = ndcg_at_k(&results, 10); + assert!( + score <= 1.0 + 1e-12, + "nDCG must not exceed 1.0, got {score}" + ); + } + + // ---- mrr ---- + + #[test] + fn mrr_decisive_at_rank1() { + let results = vec![ + make_result(0, RetrievalLabel::Decisive), + make_result(1, RetrievalLabel::Irrelevant), + ]; + assert!((mrr(&results) - 1.0).abs() < 1e-9); + } + + #[test] + fn mrr_decisive_at_rank3() { + let results = vec![ + make_result(0, RetrievalLabel::Irrelevant), + make_result(1, RetrievalLabel::Supporting), + make_result(2, RetrievalLabel::Decisive), + ]; + // 1/3 + assert!((mrr(&results) - 1.0 / 3.0).abs() < 1e-9); + } + + #[test] + fn mrr_no_decisive() { + let results = vec![ + make_result(0, RetrievalLabel::Supporting), + make_result(1, RetrievalLabel::Irrelevant), + ]; + assert_eq!(mrr(&results), 0.0); + } + + #[test] + fn mrr_empty() { + assert_eq!(mrr(&[]), 0.0); + } + + // ---- compute_all ---- + + #[test] + fn compute_all_returns_correct_structure() { + let results: Vec = (0..10) + .map(|i| { + let label = if i < 3 { + RetrievalLabel::Decisive + } else { 
+ RetrievalLabel::Irrelevant + }; + make_result(i, label) + }) + .collect(); + let metrics = compute_all(&results); + + // recall_at_k has 4 entries for k ∈ {1,3,5,10} + assert_eq!(metrics.recall_at_k.len(), 4); + assert_eq!(metrics.recall_at_k[0].0, 1); + assert_eq!(metrics.recall_at_k[1].0, 3); + assert_eq!(metrics.recall_at_k[2].0, 5); + assert_eq!(metrics.recall_at_k[3].0, 10); + + // k=3: all 3 decisive in top-3 → recall=1.0 + assert!((metrics.recall_at_k[1].1 - 1.0).abs() < 1e-9); + + // MRR = 1.0 (decisive at rank 1) + assert!((metrics.mrr - 1.0).abs() < 1e-9); + + // flip_ratio is None (no before/after pair provided) + assert!(metrics.flip_ratio.is_none()); + } + + #[test] + fn compute_all_distractor_metric() { + // 5 adjacent-wrong at ranks 1-5, rest irrelevant + let results: Vec = (0..10) + .map(|i| { + let label = if i < 5 { + RetrievalLabel::AdjacentWrong + } else { + RetrievalLabel::Irrelevant + }; + make_result(i, label) + }) + .collect(); + let metrics = compute_all(&results); + // distractor_at_10 = 5/10 = 0.5 + assert!( + (metrics.distractor_at_10 - 0.5).abs() < 1e-9, + "got {}", + metrics.distractor_at_10 + ); + // mrr = 0 (no Decisive) + assert_eq!(metrics.mrr, 0.0); + } + + // ---- serialization round-trip ---- + + #[test] + fn label_serde_roundtrip() { + for label in [ + RetrievalLabel::Decisive, + RetrievalLabel::Supporting, + RetrievalLabel::Background, + RetrievalLabel::Irrelevant, + RetrievalLabel::AdjacentWrong, + ] { + let json = serde_json::to_string(&label).unwrap(); + let back: RetrievalLabel = serde_json::from_str(&json).unwrap(); + assert_eq!(label, back); + } + } + + #[test] + fn metrics_serde_roundtrip() { + let results = vec![make_result(0, RetrievalLabel::Decisive)]; + let m = compute_all(&results); + let json = serde_json::to_string(&m).unwrap(); + let back: RetrievalMetrics = serde_json::from_str(&json).unwrap(); + assert_eq!(back.recall_at_k.len(), 4); + assert!((back.mrr - 1.0).abs() < 1e-9); + } +} diff --git 
a/crates/khive-retrieval/src/eval/mod.rs b/crates/khive-retrieval/src/eval/mod.rs new file mode 100644 index 00000000..4de5a74a --- /dev/null +++ b/crates/khive-retrieval/src/eval/mod.rs @@ -0,0 +1,5 @@ +//! Retrieval evaluation types and metrics. + +pub mod engine_eval; + +pub use engine_eval::*; diff --git a/crates/khive-retrieval/src/graph/bfs.rs b/crates/khive-retrieval/src/graph/bfs.rs new file mode 100644 index 00000000..e4c5b1d1 --- /dev/null +++ b/crates/khive-retrieval/src/graph/bfs.rs @@ -0,0 +1,148 @@ +//! BFS (Breadth-First Search) traversal. +//! +//! # Formal Verification +//! +//! This implementation corresponds to the formal proofs in +//! `proofs/Lion/Retrieval/Graph.lean`. Key theorems: +//! +//! - `bfs_terminates`: BFS always terminates (queue eventually empty) +//! - `bfs_complete`: all reachable vertices are visited +//! - `visited_mono`: visited set grows monotonically +//! - `reachable_trans`: reachability is transitive + +use std::collections::{HashSet, VecDeque}; + +use super::compat::{EntityRef, LinkStore, StorageContext}; + +use crate::error::Result; + +use super::helpers::{get_edge_weight, get_neighbor_entity, get_neighbors, matches_link_type}; +use super::types::{PathNode, TraversalOptions, MAX_TRAVERSAL_DEPTH, MAX_TRAVERSAL_RESULTS}; + +/// Perform BFS traversal from a starting entity. +/// +/// BFS explores nodes level by level, guaranteeing that nodes at depth N +/// are visited before nodes at depth N+1. This makes it ideal for: +/// +/// - Finding all entities within N hops +/// - Social network expansion (friends of friends) +/// - Entity neighborhood exploration +/// +/// # Arguments +/// +/// * `store` - The link store to query +/// * `ctx` - Storage context for namespace isolation +/// * `start` - Starting entity reference +/// * `options` - Traversal options (depth, direction, filters) +/// +/// # Returns +/// +/// Vector of [`PathNode`] in BFS order. The first element is always the start node. 
+/// +/// # Complexity +/// +/// - Time: O(V + E) where V = vertices, E = edges +/// - Space: O(V) for visited set and queue +/// +/// # Example +/// +/// ```ignore +/// let options = TraversalOptions::new(3) +/// .with_direction(Direction::Out) +/// .with_link_types(["KNOWS"]); +/// +/// let nodes = bfs_traverse(&store, &ctx, start_ref, &options).await?; +/// for node in &nodes { +/// println!("Entity {:?} at depth {}", node.entity_id, node.depth); +/// } +/// ``` +/// +/// **PROOF CORRESPONDENCE**: `Lion.Retrieval.Graph.bfs_terminates` +/// Queue shrinks each iteration; visited set prevents re-enqueue; terminates when queue empty. +/// +/// **PROOF CORRESPONDENCE**: `Lion.Retrieval.Graph.bfs_complete` +/// All reachable vertices within max_depth are visited; BFS explores level-by-level. +pub async fn bfs_traverse( + store: &S, + ctx: &StorageContext, + start: EntityRef, + options: &TraversalOptions, +) -> Result> { + let max_depth = options.max_depth.min(MAX_TRAVERSAL_DEPTH); + let limit = options + .limit + .unwrap_or(MAX_TRAVERSAL_RESULTS) + .min(MAX_TRAVERSAL_RESULTS); + let min_weight = options.min_weight.unwrap_or(f64::NEG_INFINITY); + + // **PROOF CORRESPONDENCE**: `Lion.Retrieval.Graph.visited_mono` + // Visited set only grows (insert-only); never shrinks during traversal. + // EntityRef implements Hash + Eq, enabling direct use as HashMap key. 
+ let mut visited: HashSet = HashSet::new(); + let mut results: Vec = Vec::new(); + // Queue: (entity_ref, depth, path_weight) + let mut queue: VecDeque<(EntityRef, usize, f64)> = VecDeque::new(); + + // Start node + visited.insert(start.clone()); + results.push(PathNode::start(start.clone())); + queue.push_back((start, 0, 0.0)); + + while let Some((current, depth, path_weight)) = queue.pop_front() { + // Check depth limit + if depth >= max_depth { + continue; + } + + // Check result limit + if results.len() >= limit { + break; + } + + // Get neighbors based on direction + let links = get_neighbors(store, ctx, ¤t, &options.direction).await?; + + for link in links { + // Filter by link type + if !matches_link_type(&link, &options.link_types) { + continue; + } + + // Get edge weight and filter + let edge_weight = get_edge_weight(&link); + if edge_weight < min_weight { + continue; + } + + // Determine neighbor entity based on direction + let neighbor = get_neighbor_entity(&link, ¤t, &options.direction); + + // Skip if already visited (EntityRef implements Hash + Eq) + if visited.contains(&neighbor) { + continue; + } + + // Mark as visited and add to results + visited.insert(neighbor.clone()); + let new_weight = path_weight + edge_weight; + + let node = PathNode { + entity_id: neighbor.clone(), + depth: depth + 1, + via_link: Some(link), + path_weight: new_weight, + }; + results.push(node); + + // Check limit after adding + if results.len() >= limit { + break; + } + + // Add to queue for further exploration + queue.push_back((neighbor, depth + 1, new_weight)); + } + } + + Ok(results) +} diff --git a/crates/khive-retrieval/src/graph/compat.rs b/crates/khive-retrieval/src/graph/compat.rs new file mode 100644 index 00000000..03df1b09 --- /dev/null +++ b/crates/khive-retrieval/src/graph/compat.rs @@ -0,0 +1,228 @@ +//! Compatibility shims for the legacy graph traversal module. +//! +//! The graph module was written against an older `khive_db` API that exported +//! 
`EntityRef`, `Link`, `LinkStore`, and `StorageContext`. These types no longer +//! exist in `khive_db`. This module provides minimal shims so the graph code +//! compiles under the `graph-legacy` feature until the module is ported to the +//! current `khive_storage::GraphStore` API. + +use std::collections::BTreeMap; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::error::{Result, RetrievalError}; + +// --------------------------------------------------------------------------- +// EntityRef +// --------------------------------------------------------------------------- + +/// A reference to a graph entity. +/// +/// Legacy type — maps to the old `khive_db::EntityRef` API. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(tag = "kind", content = "id", rename_all = "snake_case")] +pub enum EntityRef { + /// An externally-identified entity (string key). + External(String), +} + +// --------------------------------------------------------------------------- +// Link +// --------------------------------------------------------------------------- + +/// An opaque link identifier. +/// +/// Legacy type — shim for the old `khive_db::LinkId`. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct LinkId(u64); + +impl LinkId { + /// The nil / zero link ID. + pub const NIL: Self = Self(0); +} + +/// A directed edge between two entities. +/// +/// Legacy type — maps to the old `khive_db::Link` API. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Link { + /// Opaque link identifier. + pub id: LinkId, + /// Source entity. + pub source: EntityRef, + /// Target entity. + pub target: EntityRef, + /// Relation type (e.g. "contains", "references"). + pub relation: String, + /// Optional edge properties (e.g. `{"weight": 0.9}`). + pub properties: Option>, +} + +impl Link { + /// Create a new link with no properties. 
+ pub fn new( + id: LinkId, + source: EntityRef, + target: EntityRef, + relation: impl Into, + ) -> Self { + Self { + id, + source, + target, + relation: relation.into(), + properties: None, + } + } + + /// Create a new link with serializable properties. + pub fn with_properties( + id: LinkId, + source: EntityRef, + target: EntityRef, + relation: impl Into, + props: serde_json::Value, + ) -> Self { + let properties = props + .as_object() + .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()); + Self { + id, + source, + target, + relation: relation.into(), + properties, + } + } +} + +// --------------------------------------------------------------------------- +// StorageContext +// --------------------------------------------------------------------------- + +/// Context for storage operations (namespace isolation, etc.). +/// +/// Legacy type — maps to the old `khive_db::StorageContext` API. +#[derive(Clone, Debug, Default)] +pub struct StorageContext { + /// Namespace for multi-tenant isolation. + pub namespace: String, +} + +impl StorageContext { + /// Create a new storage context with the given namespace. + pub fn new(namespace: impl Into) -> Self { + Self { + namespace: namespace.into(), + } + } +} + +// --------------------------------------------------------------------------- +// LinkStore +// --------------------------------------------------------------------------- + +/// Trait for querying directed graph edges. +/// +/// Legacy trait — maps to the old `khive_db::LinkStore` API. +#[async_trait] +pub trait LinkStore: Send + Sync { + /// Get all outgoing links from an entity. + async fn outgoing(&self, ctx: &StorageContext, entity: &EntityRef) -> Result>; + + /// Get all incoming links to an entity. + async fn incoming(&self, ctx: &StorageContext, entity: &EntityRef) -> Result>; + + /// Create a link between two entities. 
+ async fn link( + &self, + ctx: &StorageContext, + source: EntityRef, + target: EntityRef, + relation: &str, + properties: Option, + ) -> Result; +} + +// --------------------------------------------------------------------------- +// MockLinkStore (for tests) +// --------------------------------------------------------------------------- + +/// In-memory mock implementation of `LinkStore` for tests. +pub struct MockLinkStore { + links: parking_lot::Mutex>, + next_id: std::sync::atomic::AtomicU64, +} + +impl MockLinkStore { + /// Create a new empty mock store. + pub fn new() -> Self { + Self { + links: parking_lot::Mutex::new(Vec::new()), + next_id: std::sync::atomic::AtomicU64::new(1), + } + } +} + +impl Default for MockLinkStore { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl LinkStore for MockLinkStore { + async fn outgoing(&self, _ctx: &StorageContext, entity: &EntityRef) -> Result> { + let links = self.links.lock(); + Ok(links + .iter() + .filter(|l| &l.source == entity) + .cloned() + .collect()) + } + + async fn incoming(&self, _ctx: &StorageContext, entity: &EntityRef) -> Result> { + let links = self.links.lock(); + Ok(links + .iter() + .filter(|l| &l.target == entity) + .cloned() + .collect()) + } + + async fn link( + &self, + _ctx: &StorageContext, + source: EntityRef, + target: EntityRef, + relation: &str, + properties: Option, + ) -> Result { + let id = self + .next_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let link = if let Some(props) = properties { + Link::with_properties(LinkId(id), source, target, relation, props) + } else { + Link::new(LinkId(id), source, target, relation) + }; + self.links.lock().push(link.clone()); + Ok(link) + } +} + +/// Create a test storage context. 
+pub fn test_context() -> StorageContext { + StorageContext::new("test") +} + +// --------------------------------------------------------------------------- +// Error adapter +// --------------------------------------------------------------------------- + +/// Adapt a `String` error into a `RetrievalError::GraphTraversal`. +#[allow(dead_code)] +pub(crate) fn graph_err(msg: impl std::fmt::Display) -> RetrievalError { + RetrievalError::GraphTraversal(msg.to_string()) +} diff --git a/crates/khive-retrieval/src/graph/dfs.rs b/crates/khive-retrieval/src/graph/dfs.rs new file mode 100644 index 00000000..5bed7156 --- /dev/null +++ b/crates/khive-retrieval/src/graph/dfs.rs @@ -0,0 +1,135 @@ +//! DFS (Depth-First Search) traversal. +//! +//! # Formal Verification +//! +//! This implementation corresponds to the formal proofs in +//! `proofs/Lion/Retrieval/Graph.lean`. Key theorems: +//! +//! - `dfs_terminates_bound`: DFS bounded by |V| vertices +//! - `visited_mono`: visited set grows monotonically +//! - `reachable_trans`: reachability is transitive + +use std::collections::HashSet; + +use super::compat::{EntityRef, Link, LinkStore, StorageContext}; + +use crate::error::Result; + +use super::helpers::{get_edge_weight, get_neighbor_entity, get_neighbors, matches_link_type}; +use super::types::{PathNode, TraversalOptions, MAX_TRAVERSAL_DEPTH, MAX_TRAVERSAL_RESULTS}; + +/// Perform DFS traversal from a starting entity. +/// +/// DFS explores as far as possible along each branch before backtracking. +/// This makes it ideal for: +/// +/// - Deep chain exploration +/// - Path existence checking +/// - Exhaustive graph exploration with limited results +/// +/// # Arguments +/// +/// * `store` - The link store to query +/// * `ctx` - Storage context for namespace isolation +/// * `start` - Starting entity reference +/// * `options` - Traversal options (depth, direction, filters) +/// +/// # Returns +/// +/// Vector of [`PathNode`] in DFS pre-order (parent before children). 
+/// +/// # Complexity +/// +/// - Time: O(V + E) where V = vertices, E = edges +/// - Space: O(V) for visited set + O(h) stack where h = max depth +/// +/// # Example +/// +/// ```ignore +/// let options = TraversalOptions::new(5) +/// .with_direction(Direction::Out); +/// +/// let nodes = dfs_traverse(&store, &ctx, start_ref, &options).await?; +/// ``` +/// +/// **PROOF CORRESPONDENCE**: `Lion.Retrieval.Graph.dfs_terminates_bound` +/// Each vertex visited at most once; |visited| bounded by |V|; stack pops exceed pushes eventually. +pub async fn dfs_traverse( + store: &S, + ctx: &StorageContext, + start: EntityRef, + options: &TraversalOptions, +) -> Result> { + let max_depth = options.max_depth.min(MAX_TRAVERSAL_DEPTH); + let limit = options + .limit + .unwrap_or(MAX_TRAVERSAL_RESULTS) + .min(MAX_TRAVERSAL_RESULTS); + let min_weight = options.min_weight.unwrap_or(f64::NEG_INFINITY); + + // **PROOF CORRESPONDENCE**: `Lion.Retrieval.Graph.visited_mono` + // Visited set only grows (insert-only); never shrinks during traversal. + // EntityRef implements Hash + Eq, enabling direct use as HashMap key. 
+ let mut visited: HashSet = HashSet::new(); + let mut results: Vec = Vec::new(); + + // Stack: (entity_ref, depth, path_weight, via_link) + let mut stack: Vec<(EntityRef, usize, f64, Option)> = Vec::new(); + stack.push((start, 0, 0.0, None)); + + while let Some((current, depth, path_weight, via_link)) = stack.pop() { + // Skip if already visited (EntityRef implements Hash + Eq) + if visited.contains(¤t) { + continue; + } + + // Mark as visited and add to results + visited.insert(current.clone()); + results.push(PathNode { + entity_id: current.clone(), + depth, + via_link, + path_weight, + }); + + // Check result limit + if results.len() >= limit { + break; + } + + // Check depth limit before exploring children + if depth >= max_depth { + continue; + } + + // Get neighbors and push to stack (reverse order for consistent traversal) + let links = get_neighbors(store, ctx, ¤t, &options.direction).await?; + + // Push in reverse order so first neighbor is processed first + for link in links.into_iter().rev() { + // Filter by link type + if !matches_link_type(&link, &options.link_types) { + continue; + } + + // Get edge weight and filter + let edge_weight = get_edge_weight(&link); + if edge_weight < min_weight { + continue; + } + + // Determine neighbor entity + let neighbor = get_neighbor_entity(&link, ¤t, &options.direction); + + // Skip if already visited (EntityRef implements Hash + Eq) + if visited.contains(&neighbor) { + continue; + } + + let new_weight = path_weight + edge_weight; + stack.push((neighbor, depth + 1, new_weight, Some(link))); + } + } + + Ok(results) +} diff --git a/crates/khive-retrieval/src/graph/helpers.rs b/crates/khive-retrieval/src/graph/helpers.rs new file mode 100644 index 00000000..4227a93c --- /dev/null +++ b/crates/khive-retrieval/src/graph/helpers.rs @@ -0,0 +1,278 @@ +//! Helper functions for graph traversal. 
+ +#[cfg(test)] +use super::compat::LinkId; +use super::compat::{EntityRef, Link, LinkStore, StorageContext}; +use khive_score::DeterministicScore; + +use crate::error::{Result, RetrievalError}; + +use super::types::Direction; + +/// Extract edge weight from link properties. +/// +/// Returns the `weight` property if present, otherwise defaults to 1.0. +pub fn get_edge_weight(link: &Link) -> f64 { + link.properties + .as_ref() + .and_then(|props| props.get("weight")) + .and_then(|v| v.as_f64()) + .unwrap_or(1.0) +} + +/// Check if a link matches the type filter. +/// +/// Returns `true` if: +/// - The filter is `None` (all types match) +/// - The link's relation is in the filter list +pub fn matches_link_type(link: &Link, filter: &Option>) -> bool { + match filter { + None => true, + Some(types) => types.iter().any(|t| t == &link.relation), + } +} + +/// Get neighbor links based on direction. +/// +/// # Arguments +/// +/// * `store` - The link store to query +/// * `ctx` - Storage context for namespace isolation +/// * `entity` - The entity to get neighbors for +/// * `direction` - Which direction to follow edges +/// +/// # Returns +/// +/// Vector of links in the specified direction(s). 
+pub async fn get_neighbors( + store: &S, + ctx: &StorageContext, + entity: &EntityRef, + direction: &Direction, +) -> Result> { + let links = + match direction { + Direction::Out => store + .outgoing(ctx, entity) + .await + .map_err(|e| RetrievalError::GraphTraversal(format!("link store error: {e}"))), + Direction::In => store + .incoming(ctx, entity) + .await + .map_err(|e| RetrievalError::GraphTraversal(format!("link store error: {e}"))), + Direction::Both => { + let mut out = store.outgoing(ctx, entity).await.map_err(|e| { + RetrievalError::GraphTraversal(format!("link store error: {e}")) + })?; + let incoming = store.incoming(ctx, entity).await.map_err(|e| { + RetrievalError::GraphTraversal(format!("link store error: {e}")) + })?; + out.extend(incoming); + Ok(out) + } + }; + + links +} + +/// Convert graph depth to proximity score. +/// +/// Closer nodes (lower depth) get higher scores. This enables fusion with +/// vector and keyword search results via RRF. +/// +/// # Arguments +/// +/// * `depth` - Distance from the start node (0 = start node itself) +/// * `max_depth` - Maximum traversal depth configured +/// +/// # Returns +/// +/// A `DeterministicScore` in range [0.0, 1.0]: +/// - depth=0 → 1.0 (at start node) +/// - depth=max_depth → 0.0 (maximum distance) +/// +/// # Edge Cases +/// +/// When `max_depth = 0`: +/// - depth=0 → 1.0 (only start node is reachable) +/// - depth>0 → 0.0 (should not occur, but handled safely) +/// +/// # Proof Correspondence +/// +/// This function maintains the invariant: +/// - `proximity_nonneg`: Result is always >= 0 +/// - `proximity_bounded`: Result is always <= 1.0 +/// - `proximity_mono`: Higher depth → lower score (monotonically decreasing) +pub fn proximity_score(depth: usize, max_depth: usize) -> DeterministicScore { + // Guard against division by zero + if max_depth == 0 { + // At max_depth=0, only the start node (depth=0) is reachable + return DeterministicScore::from_f64(if depth == 0 { 1.0 } else { 0.0 }); + } 
+ // Closer = higher score (inverse relationship) + let proximity = 1.0 - (depth as f64 / max_depth as f64); + DeterministicScore::from_f64(proximity) +} + +/// Get the neighbor entity from a link based on traversal direction and current node. +/// +/// # Arguments +/// +/// * `link` - The link to extract neighbor from +/// * `current` - The current entity we're traversing from +/// * `direction` - The traversal direction +/// +/// # Returns +/// +/// The entity at the "other end" of the link relative to the traversal direction. +pub fn get_neighbor_entity(link: &Link, current: &EntityRef, direction: &Direction) -> EntityRef { + match direction { + Direction::Out => link.target.clone(), + Direction::In => link.source.clone(), + Direction::Both => { + // In bidirectional mode, return the "other end" of the link + if &link.source == current { + link.target.clone() + } else { + link.source.clone() + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_matches_link_type() { + let link = Link::new( + LinkId::NIL, + EntityRef::External("a".to_string()), + EntityRef::External("b".to_string()), + "contains", + ); + + // No filter matches all + assert!(matches_link_type(&link, &None)); + + // Matching type + assert!(matches_link_type( + &link, + &Some(vec!["contains".to_string()]) + )); + + // Non-matching type + assert!(!matches_link_type( + &link, + &Some(vec!["references".to_string()]) + )); + + // Multiple types, one matches + assert!(matches_link_type( + &link, + &Some(vec!["references".to_string(), "contains".to_string()]) + )); + } + + #[test] + fn test_get_edge_weight() { + // No properties = default weight 1.0 + let link = Link::new( + LinkId::NIL, + EntityRef::External("a".to_string()), + EntityRef::External("b".to_string()), + "test", + ); + assert_eq!(get_edge_weight(&link), 1.0); + + // With weight property + let link_with_weight = Link::with_properties( + LinkId::NIL, + EntityRef::External("a".to_string()), + 
EntityRef::External("b".to_string()), + "test", + serde_json::json!({"weight": 2.5}), + ); + assert_eq!(get_edge_weight(&link_with_weight), 2.5); + } + + #[test] + fn test_get_neighbor_entity() { + let source = EntityRef::External("source".to_string()); + let target = EntityRef::External("target".to_string()); + let link = Link::new(LinkId::NIL, source.clone(), target.clone(), "test"); + + // Outgoing: return target + assert_eq!(get_neighbor_entity(&link, &source, &Direction::Out), target); + + // Incoming: return source + assert_eq!(get_neighbor_entity(&link, &target, &Direction::In), source); + + // Both from source: return target (other end) + assert_eq!( + get_neighbor_entity(&link, &source, &Direction::Both), + target + ); + + // Both from target: return source (other end) + assert_eq!( + get_neighbor_entity(&link, &target, &Direction::Both), + source + ); + } + + #[test] + fn test_proximity_score_normal() { + // At start node (depth=0) + let score = proximity_score(0, 5); + assert!((score.to_f64() - 1.0).abs() < f64::EPSILON); + + // At max depth + let score = proximity_score(5, 5); + assert!((score.to_f64() - 0.0).abs() < f64::EPSILON); + + // Midway + let score = proximity_score(2, 4); + assert!((score.to_f64() - 0.5).abs() < f64::EPSILON); + } + + #[test] + fn test_proximity_score_max_depth_zero() { + // Edge case: max_depth = 0, depth = 0 (only valid case) + let score = proximity_score(0, 0); + assert!((score.to_f64() - 1.0).abs() < f64::EPSILON); + + // Edge case: max_depth = 0, depth > 0 (should not occur, but handled safely) + let score = proximity_score(1, 0); + assert!((score.to_f64() - 0.0).abs() < f64::EPSILON); + } + + #[test] + fn test_proximity_score_monotonic() { + // Scores should decrease as depth increases + let max_depth = 10; + let mut prev_score = f64::MAX; + + for depth in 0..=max_depth { + let score = proximity_score(depth, max_depth).to_f64(); + assert!( + score <= prev_score, + "Score should be monotonically decreasing" + ); + 
prev_score = score; + } + } + + #[test] + fn test_proximity_score_bounded() { + // All scores should be in [0.0, 1.0] + for max_depth in [0, 1, 5, 10, 100] { + for depth in 0..=max_depth { + let score = proximity_score(depth, max_depth).to_f64(); + assert!(score >= 0.0, "Score should be >= 0"); + assert!(score <= 1.0, "Score should be <= 1"); + } + } + } +} diff --git a/crates/khive-retrieval/src/graph/mod.rs b/crates/khive-retrieval/src/graph/mod.rs new file mode 100644 index 00000000..7105ca75 --- /dev/null +++ b/crates/khive-retrieval/src/graph/mod.rs @@ -0,0 +1,99 @@ +//! Graph traversal algorithms for relationship-aware retrieval. +//! +//! This module provides BFS, DFS, and shortest path algorithms for exploring +//! the knowledge graph. All algorithms operate on the `LinkStore` trait from +//! khive-db, enabling relationship-aware retrieval pipelines. +//! +//! # Algorithm Selection Guide +//! +//! | Use Case | Algorithm | Function | +//! |----------|-----------|----------| +//! | Explore neighbors | BFS | [`bfs_traverse`] | +//! | Find shortest path | Bidirectional BFS | [`find_shortest_path`] | +//! | Deep exploration | DFS | [`dfs_traverse`] | +//! +//! # Architecture (ADR-004) +//! +//! ```text +//! khive-db khive-retrieval +//! +-----------------+ +----------------------+ +//! | LinkStore trait | <--- | Traversal algorithms | +//! | EntityRef, Link | | PathNode, Direction | +//! | StorageContext | | TraversalOptions | +//! +-----------------+ +----------------------+ +//! ``` +//! +//! # RETRIEVAL-09: Audit Logging for Graph Operations +//! +//! **Current state**: Graph traversal algorithms do NOT emit audit logs. +//! +//! **Design decision**: Audit logging is the responsibility of the caller +//! (typically khive-api or middleware layer), not the retrieval algorithms. +//! This keeps the traversal code focused and testable. +//! +//! **What callers should log**: +//! +//! | Event | Context to Capture | +//! |-------|-------------------| +//! 
| Traversal start | start_node, direction, max_depth, link_types | +//! | Traversal complete | nodes_visited, paths_found, duration_ms | +//! | Depth limit hit | node_at_limit, depth | +//! | Result limit hit | total_candidates, returned_count | +//! +//! **Future work**: If audit logging moves into the retrieval layer, add +//! a `TraversalObserver` trait for pluggable logging without coupling to +//! a specific logging framework. +//! +//! # Safety Limits +//! +//! All algorithms enforce safety limits to prevent runaway traversals: +//! - [`MAX_TRAVERSAL_DEPTH`]: Maximum hops from start (20) +//! - [`MAX_TRAVERSAL_RESULTS`]: Maximum nodes returned (10,000) +//! +//! # Example +//! +//! ```ignore +//! use khive_retrieval::graph::{bfs_traverse, find_shortest_path, TraversalOptions, Direction}; +//! use khive_db::{LinkStore, StorageContext}; +//! +//! // BFS exploration +//! let options = TraversalOptions::new(3) +//! .with_direction(Direction::Out) +//! .with_link_types(["contains", "references"]); +//! +//! let neighbors = bfs_traverse(&store, &ctx, start_ref, &options).await?; +//! +//! // Find shortest path +//! if let Some(path) = find_shortest_path(&store, &ctx, from, to, 10).await? { +//! println!("Path length: {} hops", path.len() - 1); +//! } +//! ``` +//! +//! See [ADR-004](../docs/ADR-004-graph-traversal.md) for algorithm specification. + +mod bfs; +mod compat; +mod dfs; +/// Helper functions for graph traversal (proximity scoring, neighbor extraction, etc.). 
+pub mod helpers; +mod shortest; +mod types; + +#[cfg(test)] +mod tests; + +// Re-export compat types (legacy graph API shims) +pub use compat::{test_context, EntityRef, Link, LinkId, LinkStore, MockLinkStore, StorageContext}; + +// Re-export public types +pub use types::{ + Direction, PathNode, TraversalOptions, MAX_TRAVERSAL_DEPTH, MAX_TRAVERSAL_RESULTS, +}; + +// Re-export direction variants for convenience +pub use types::Direction::{Both, In, Out}; + +// Re-export traversal algorithms +pub use bfs::bfs_traverse; +pub use dfs::dfs_traverse; +pub use shortest::find_shortest_path; diff --git a/crates/khive-retrieval/src/graph/shortest.rs b/crates/khive-retrieval/src/graph/shortest.rs new file mode 100644 index 00000000..80ff756d --- /dev/null +++ b/crates/khive-retrieval/src/graph/shortest.rs @@ -0,0 +1,266 @@ +//! Shortest path algorithm using bidirectional BFS. + +use std::collections::{HashMap, VecDeque}; + +use super::compat::{EntityRef, Link, LinkStore, StorageContext}; + +use crate::error::{Result, RetrievalError}; + +use super::types::{PathNode, MAX_TRAVERSAL_DEPTH}; + +/// Find the shortest path between two entities using bidirectional BFS. +/// +/// Bidirectional BFS searches from both start and end simultaneously, +/// meeting in the middle. This reduces search space from O(b^d) to O(b^(d/2)) +/// where b = branching factor and d = path depth. 
+/// +/// # Arguments +/// +/// * `store` - The link store to query +/// * `ctx` - Storage context for namespace isolation +/// * `from` - Starting entity reference +/// * `to` - Target entity reference +/// * `max_depth` - Maximum path length (clamped to [`MAX_TRAVERSAL_DEPTH`]) +/// +/// # Returns +/// +/// - `Some(Vec)` - Path from source to target (inclusive) +/// - `path[0]` is the start node (via_link = None) +/// - `path[i].via_link` is the edge from `path[i-1]` to `path[i]` +/// - `None` - No path exists within max_depth +/// +/// # Complexity +/// +/// - Time: O(b^(d/2)) vs O(b^d) for standard BFS +/// - Space: O(b^(d/2)) for both frontiers +/// +/// # Example +/// +/// ```ignore +/// let path = find_shortest_path(&store, &ctx, alice_ref, bob_ref, 5).await?; +/// if let Some(path) = path { +/// println!("Found path of {} hops", path.len() - 1); +/// for node in &path { +/// if let Some(link) = &node.via_link { +/// println!(" via {} to {:?}", link.relation, node.entity_id); +/// } +/// } +/// } +/// ``` +pub async fn find_shortest_path( + store: &S, + ctx: &StorageContext, + from: EntityRef, + to: EntityRef, + max_depth: usize, +) -> Result>> { + // Clamp max_depth to prevent excessive search + let max_depth = max_depth.min(MAX_TRAVERSAL_DEPTH); + + // Same node = trivial path (EntityRef implements Eq) + if from == to { + return Ok(Some(vec![PathNode::start(from)])); + } + + // Forward search state: entity -> (depth, parent_entity, link to this node) + // EntityRef implements Hash + Eq, enabling direct use as HashMap key. 
+ let mut forward_visited: HashMap, Option)> = + HashMap::new(); + let mut forward_queue: VecDeque = VecDeque::new(); + forward_visited.insert(from.clone(), (0, None, None)); + forward_queue.push_back(from.clone()); + + // Backward search state: entity -> (depth, child_entity, link from this node) + let mut backward_visited: HashMap, Option)> = + HashMap::new(); + let mut backward_queue: VecDeque = VecDeque::new(); + backward_visited.insert(to.clone(), (0, None, None)); + backward_queue.push_back(to.clone()); + + let mut best_meeting: Option<(EntityRef, usize)> = None; // (node, total_dist) + let mut current_depth = 0; + + // Alternate between forward and backward expansion. + // Process entire BFS levels before checking for a meeting point so we + // find the meeting node with the smallest total distance, not just the + // first one encountered (which depends on HashMap iteration order). + while !forward_queue.is_empty() || !backward_queue.is_empty() { + if current_depth > max_depth { + break; + } + + // Expand forward frontier (following outgoing edges) + let forward_level_size = forward_queue.len(); + for _ in 0..forward_level_size { + if let Some(current) = forward_queue.pop_front() { + let outgoing = store.outgoing(ctx, ¤t).await.map_err(|e| { + RetrievalError::GraphTraversal(format!("link store error: {e}")) + })?; + + for link in outgoing { + let neighbor = link.target.clone(); + + if !forward_visited.contains_key(&neighbor) { + let fwd_dist = current_depth + 1; + // Store: link goes from current to neighbor + forward_visited.insert( + neighbor.clone(), + (fwd_dist, Some(current.clone()), Some(link)), + ); + forward_queue.push_back(neighbor.clone()); + + // Check if we've met the backward search + if let Some((bwd_dist, _, _)) = backward_visited.get(&neighbor) { + let total = fwd_dist + bwd_dist; + if best_meeting.as_ref().is_none_or(|&(_, best)| total < best) { + best_meeting = Some((neighbor, total)); + } + } + } + } + } + } + + // If we found a meeting 
point during forward expansion, the best + // meeting at this depth is optimal -- no need to expand backward. + if best_meeting.is_some() { + break; + } + + // Expand backward frontier (following incoming edges) + let backward_level_size = backward_queue.len(); + for _ in 0..backward_level_size { + if let Some(current) = backward_queue.pop_front() { + let incoming = store.incoming(ctx, ¤t).await.map_err(|e| { + RetrievalError::GraphTraversal(format!("link store error: {e}")) + })?; + + for link in incoming { + // For incoming: link goes from neighbor to current + let neighbor = link.source.clone(); + + if !backward_visited.contains_key(&neighbor) { + let bwd_dist = current_depth + 1; + // Store: link goes from neighbor to current (for path reconstruction) + backward_visited.insert( + neighbor.clone(), + (bwd_dist, Some(current.clone()), Some(link)), + ); + backward_queue.push_back(neighbor.clone()); + + // Check if we've met the forward search + if let Some((fwd_dist, _, _)) = forward_visited.get(&neighbor) { + let total = fwd_dist + bwd_dist; + if best_meeting.as_ref().is_none_or(|&(_, best)| total < best) { + best_meeting = Some((neighbor, total)); + } + } + } + } + } + } + + // After processing both frontiers at this depth, check for meeting + if best_meeting.is_some() { + break; + } + + current_depth += 1; + } + + // Reconstruct path if found + match best_meeting { + Some((mid, _total_dist)) => { + let path = reconstruct_path(&forward_visited, &backward_visited, &mid); + Ok(Some(path)) + } + None => Ok(None), + } +} + +/// Reconstruct the path from forward and backward visited maps. 
+fn reconstruct_path( + forward_visited: &HashMap, Option)>, + backward_visited: &HashMap, Option)>, + meeting_point: &EntityRef, +) -> Vec { + // Build forward part: start -> meeting_point + let mut forward_entities: Vec = Vec::new(); + let mut forward_links: Vec> = Vec::new(); + let mut current = meeting_point.clone(); + + // Walk backwards from meeting point to start + while let Some((_, parent, link)) = forward_visited.get(¤t) { + forward_entities.push(current.clone()); + forward_links.push(link.clone()); + match parent { + Some(p) => current = p.clone(), + None => break, + } + } + + // Reverse to get start -> meeting_point order + forward_entities.reverse(); + forward_links.reverse(); + + // Build backward part: meeting_point -> end + let mut backward_entities: Vec = Vec::new(); + let mut backward_links: Vec> = Vec::new(); + + // Start from meeting point, walk towards 'to' + if let Some((_, Some(child), link)) = backward_visited.get(meeting_point) { + backward_links.push(link.clone()); + current = child.clone(); + + while let Some((_, next_child, link)) = backward_visited.get(¤t) { + backward_entities.push(current.clone()); + match next_child { + Some(nc) => { + backward_links.push(link.clone()); + current = nc.clone(); + } + None => break, + } + } + // Defensive: if the while loop exited because backward_visited + // lacked an entry for `current` (shouldn't happen in a consistent + // graph, but guards against any map skew), include `current` so + // the target node is never silently dropped. 
+ if backward_entities.last().map_or(true, |e| e != ¤t) { + backward_entities.push(current.clone()); + } + } + + // Combine into final path + let mut path: Vec = Vec::new(); + + // Add forward nodes + for (i, entity) in forward_entities.iter().enumerate() { + let link = if i == 0 { + None // Start node has no inbound edge + } else { + forward_links.get(i).cloned().flatten() + }; + + path.push(PathNode { + entity_id: entity.clone(), + depth: i, + via_link: link, + path_weight: i as f64, + }); + } + + // Add backward nodes (these come after meeting point) + let base_depth = path.len(); + for (i, entity) in backward_entities.iter().enumerate() { + let link = backward_links.get(i).cloned().flatten(); + path.push(PathNode { + entity_id: entity.clone(), + depth: base_depth + i, + via_link: link, + path_weight: (base_depth + i) as f64, + }); + } + + path +} diff --git a/crates/khive-retrieval/src/graph/tests.rs b/crates/khive-retrieval/src/graph/tests.rs new file mode 100644 index 00000000..639b3efd --- /dev/null +++ b/crates/khive-retrieval/src/graph/tests.rs @@ -0,0 +1,134 @@ +//! Unit tests for graph traversal module. 
+ +use super::compat::{test_context, EntityRef, MockLinkStore}; + +use crate::graph::types::{ + Direction, PathNode, TraversalOptions, MAX_TRAVERSAL_DEPTH, MAX_TRAVERSAL_RESULTS, +}; + +#[test] +fn test_traversal_options_default() { + let opts = TraversalOptions::default(); + assert_eq!(opts.max_depth, 3); + assert_eq!(opts.direction, Direction::Out); + assert!(opts.link_types.is_none()); + assert_eq!(opts.limit, Some(MAX_TRAVERSAL_RESULTS)); +} + +#[test] +fn test_traversal_options_builder() { + let opts = TraversalOptions::new(5) + .with_direction(Direction::Both) + .with_link_types(["contains", "references"]) + .with_limit(100) + .with_min_weight(0.5); + + assert_eq!(opts.max_depth, 5); + assert_eq!(opts.direction, Direction::Both); + assert_eq!( + opts.link_types, + Some(vec!["contains".to_string(), "references".to_string()]) + ); + assert_eq!(opts.limit, Some(100)); + assert_eq!(opts.min_weight, Some(0.5)); +} + +#[test] +fn test_traversal_options_clamping() { + // Depth clamping + let opts = TraversalOptions::new(100); + assert_eq!(opts.max_depth, MAX_TRAVERSAL_DEPTH); + + // Limit clamping + let opts = TraversalOptions::new(3).with_limit(100_000); + assert_eq!(opts.limit, Some(MAX_TRAVERSAL_RESULTS)); +} + +#[test] +fn test_path_node_start() { + let entity = EntityRef::External("test".to_string()); + let node = PathNode::start(entity.clone()); + + assert_eq!(node.entity_id, entity); + assert_eq!(node.depth, 0); + assert!(node.via_link.is_none()); + assert_eq!(node.path_weight, 0.0); +} + +#[test] +fn test_direction_default() { + let dir = Direction::default(); + assert_eq!(dir, Direction::Out); +} + +#[test] +fn test_safety_constants() { + // Verify safety constants are reasonable + assert_eq!(MAX_TRAVERSAL_DEPTH, 20); + assert_eq!(MAX_TRAVERSAL_RESULTS, 10_000); +} + +#[tokio::test] +async fn shortest_path_includes_target_node() { + // Graph: A → B → C. Verify path is [A, B, C] — all three nodes including target C. 
+ let store = MockLinkStore::new(); + let ctx = test_context(); + + let a = EntityRef::External("A".to_string()); + let b = EntityRef::External("B".to_string()); + let c = EntityRef::External("C".to_string()); + + store + .link( + &ctx, + a.clone(), + b.clone(), + "edge", + None::, + ) + .await + .unwrap(); + store + .link(&ctx, b.clone(), c.clone(), "edge", None) + .await + .unwrap(); + + let path = super::shortest::find_shortest_path(&store, &ctx, a.clone(), c.clone(), 5) + .await + .unwrap() + .expect("path exists"); + + assert_eq!(path.len(), 3, "path should contain 3 nodes: A, B, C"); + assert_eq!(path[0].entity_id, a, "first node is start (A)"); + assert_eq!(path[2].entity_id, c, "last node is target (C)"); +} + +#[tokio::test] +async fn shortest_path_direct_edge_includes_target() { + // Graph: A → B (direct). Path should be [A, B], not just [A]. + let store = MockLinkStore::new(); + let ctx = test_context(); + + let a = EntityRef::External("X".to_string()); + let b = EntityRef::External("Y".to_string()); + + store + .link( + &ctx, + a.clone(), + b.clone(), + "edge", + None::, + ) + .await + .unwrap(); + + let path = super::shortest::find_shortest_path(&store, &ctx, a.clone(), b.clone(), 5) + .await + .unwrap() + .expect("path exists"); + + assert_eq!(path.len(), 2, "path should contain 2 nodes: X, Y"); + assert_eq!(path[0].entity_id, a); + assert_eq!(path[1].entity_id, b, "target node must be in path"); +} diff --git a/crates/khive-retrieval/src/graph/types.rs b/crates/khive-retrieval/src/graph/types.rs new file mode 100644 index 00000000..cf5b8159 --- /dev/null +++ b/crates/khive-retrieval/src/graph/types.rs @@ -0,0 +1,208 @@ +//! Graph traversal types. + +use super::compat::{EntityRef, Link}; +use serde::{Deserialize, Serialize}; + +/// Maximum traversal depth to prevent stack overflow and runaway queries. +pub const MAX_TRAVERSAL_DEPTH: usize = 20; + +/// Maximum results per traversal to prevent memory exhaustion. 
+pub const MAX_TRAVERSAL_RESULTS: usize = 10_000; + +/// Direction of edge traversal. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Direction { + /// Follow outgoing edges (source -> target). + #[default] + #[serde(alias = "Out")] + Out, + /// Follow incoming edges (target <- source). + #[serde(alias = "In")] + In, + /// Follow edges in both directions. + #[serde(alias = "Both")] + Both, +} + +/// A node in a traversal path. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PathNode { + /// The entity at this position in the path. + pub entity_id: EntityRef, + /// Depth from the start node (0 = start node). + pub depth: usize, + /// The link that led to this node (None for start node). + pub via_link: Option, + /// Cumulative path weight (sum of edge weights). + pub path_weight: f64, +} + +impl PathNode { + /// Create a new path node for the start position. + pub fn start(entity_id: EntityRef) -> Self { + Self { + entity_id, + depth: 0, + via_link: None, + path_weight: 0.0, + } + } + + /// Create a path node from an outgoing link. + pub fn from_outgoing_link(link: Link, depth: usize, path_weight: f64) -> Self { + Self { + entity_id: link.target.clone(), + depth, + via_link: Some(link), + path_weight, + } + } + + /// Create a path node from an incoming link. + pub fn from_incoming_link(link: Link, depth: usize, path_weight: f64) -> Self { + Self { + entity_id: link.source.clone(), + depth, + via_link: Some(link), + path_weight, + } + } +} + +/// Options for graph traversal operations. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TraversalOptions { + /// Maximum depth to traverse (clamped to [`MAX_TRAVERSAL_DEPTH`]). + pub max_depth: usize, + /// Maximum number of nodes to return (clamped to [`MAX_TRAVERSAL_RESULTS`]). + pub limit: Option, + /// Direction to follow edges. + pub direction: Direction, + /// Filter by link relation types (None = all types). 
+ pub link_types: Option>, + /// Minimum edge weight to consider (for weighted traversal). + pub min_weight: Option, +} + +impl Default for TraversalOptions { + fn default() -> Self { + Self { + max_depth: 3, + limit: Some(MAX_TRAVERSAL_RESULTS), + direction: Direction::Out, + link_types: None, + min_weight: None, + } + } +} + +impl TraversalOptions { + /// Create new options with specified max depth. + pub fn new(max_depth: usize) -> Self { + Self { + max_depth: max_depth.min(MAX_TRAVERSAL_DEPTH), + limit: Some(MAX_TRAVERSAL_RESULTS), + ..Default::default() + } + } + + /// Set the maximum traversal depth. + #[must_use] + pub fn with_max_depth(mut self, depth: usize) -> Self { + self.max_depth = depth.min(MAX_TRAVERSAL_DEPTH); + self + } + + /// Set traversal direction. + #[must_use] + pub fn with_direction(mut self, direction: Direction) -> Self { + self.direction = direction; + self + } + + /// Filter to specific link relation types. + #[must_use] + pub fn with_link_types(mut self, types: impl IntoIterator>) -> Self { + self.link_types = Some(types.into_iter().map(Into::into).collect()); + self + } + + /// Set maximum number of results. + #[must_use] + pub fn with_limit(mut self, limit: usize) -> Self { + self.limit = Some(limit.min(MAX_TRAVERSAL_RESULTS)); + self + } + + /// Set minimum edge weight threshold. 
+ #[must_use] + pub fn with_min_weight(mut self, weight: f64) -> Self { + self.min_weight = Some(weight); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_traversal_options_default() { + let opts = TraversalOptions::default(); + assert_eq!(opts.max_depth, 3); + assert_eq!(opts.direction, Direction::Out); + assert!(opts.link_types.is_none()); + assert_eq!(opts.limit, Some(MAX_TRAVERSAL_RESULTS)); + } + + #[test] + fn test_traversal_options_builder() { + let opts = TraversalOptions::new(5) + .with_direction(Direction::Both) + .with_link_types(["contains", "references"]) + .with_limit(100) + .with_min_weight(0.5); + + assert_eq!(opts.max_depth, 5); + assert_eq!(opts.direction, Direction::Both); + assert_eq!( + opts.link_types, + Some(vec!["contains".to_string(), "references".to_string()]) + ); + assert_eq!(opts.limit, Some(100)); + assert_eq!(opts.min_weight, Some(0.5)); + } + + #[test] + fn test_traversal_options_clamping() { + let opts = TraversalOptions::new(100); + assert_eq!(opts.max_depth, MAX_TRAVERSAL_DEPTH); + + let opts = TraversalOptions::new(3).with_limit(100_000); + assert_eq!(opts.limit, Some(MAX_TRAVERSAL_RESULTS)); + } + + #[test] + fn test_path_node_start() { + let entity = EntityRef::External("test".to_string()); + let node = PathNode::start(entity.clone()); + + assert_eq!(node.entity_id, entity); + assert_eq!(node.depth, 0); + assert!(node.via_link.is_none()); + assert_eq!(node.path_weight, 0.0); + } + + #[test] + fn test_direction_default() { + let dir = Direction::default(); + assert_eq!(dir, Direction::Out); + } + + #[test] + fn test_safety_constants() { + assert_eq!(MAX_TRAVERSAL_DEPTH, 20); + assert_eq!(MAX_TRAVERSAL_RESULTS, 10_000); + } +} diff --git a/crates/khive-retrieval/src/hybrid/config.rs b/crates/khive-retrieval/src/hybrid/config.rs new file mode 100644 index 00000000..febac2bf --- /dev/null +++ b/crates/khive-retrieval/src/hybrid/config.rs @@ -0,0 +1,260 @@ +//! Hybrid search configuration types. 
+ +use std::time::Duration; + +use khive_score::DeterministicScore; +use serde::{Deserialize, Serialize}; + +use khive_fusion::FusionStrategy; + +/// Default candidate pool multiplier over top_k. +pub const DEFAULT_POOL_MULTIPLIER: usize = 5; + +/// Query for hybrid search. +/// +/// Combines text for keyword search and optional embedding for vector search. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Query { + /// Text for keyword search (required). + pub text: String, + + /// Pre-computed embedding for vector search (optional). + /// + /// If None, vector search is skipped or caller must provide. + pub embedding: Option>, + + /// Optional filters to apply post-retrieval. + pub filters: Option, +} + +impl Query { + /// Create a new query with text only (keyword search). + pub fn text(text: impl Into) -> Self { + Self { + text: text.into(), + embedding: None, + filters: None, + } + } + + /// Create a query with both text and embedding (hybrid search). + pub fn hybrid(text: impl Into, embedding: Vec) -> Self { + Self { + text: text.into(), + embedding: Some(embedding), + filters: None, + } + } + + /// Add filters to the query. + #[must_use] + pub fn with_filters(mut self, filters: serde_json::Value) -> Self { + self.filters = Some(filters); + self + } + + /// Check if this query supports vector search. + pub fn has_embedding(&self) -> bool { + self.embedding.is_some() + } +} + +/// Configuration for hybrid search. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HybridConfig { + /// Fusion strategy to use (default: RRF with k=60). + pub fusion_strategy: FusionStrategy, + + /// Number of results to return. + pub top_k: usize, + + /// Candidates to fetch from each retriever before fusion. + /// + /// Should be >= 5 * top_k for quality fusion. + pub candidate_pool_size: usize, + + /// Minimum score threshold (post-fusion). + pub min_score: Option, + + /// Weight for vector search results (0.0 to 1.0). 
+ /// + /// Only used when fusion_strategy is Weighted. + pub vector_weight: f64, + + /// Weight for keyword search results (0.0 to 1.0). + /// + /// Only used when fusion_strategy is Weighted. + pub keyword_weight: f64, + + /// Optional timeout for the entire search operation. + /// + /// If set, the search will be cancelled if it exceeds this duration, + /// returning [`RetrievalError::QueryTimeout`]. + /// If None, no timeout is applied. + #[serde( + default, + skip_serializing_if = "Option::is_none", + with = "crate::timeout::serde_opt_duration" + )] + pub timeout: Option, +} + +impl Default for HybridConfig { + fn default() -> Self { + Self { + fusion_strategy: FusionStrategy::rrf(), + top_k: 10, + candidate_pool_size: 50, // 5 * top_k + min_score: None, + vector_weight: 0.7, + keyword_weight: 0.3, + timeout: None, + } + } +} + +impl HybridConfig { + /// Create a new config with specified top_k. + pub fn new(top_k: usize) -> Self { + Self { + top_k, + candidate_pool_size: top_k * DEFAULT_POOL_MULTIPLIER, + ..Default::default() + } + } + + /// Set the fusion strategy. + #[must_use] + pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self { + self.fusion_strategy = strategy; + self + } + + /// Set the candidate pool size. + #[must_use] + pub fn with_pool_size(mut self, size: usize) -> Self { + self.candidate_pool_size = size; + self + } + + /// Set the minimum score threshold. + #[must_use] + pub fn with_min_score(mut self, score: DeterministicScore) -> Self { + self.min_score = Some(score); + self + } + + /// Set weights for weighted fusion. + /// + /// Weights are clamped to [0.0, 1.0]. + #[must_use] + pub fn with_weights(mut self, vector: f64, keyword: f64) -> Self { + self.vector_weight = vector.clamp(0.0, 1.0); + self.keyword_weight = keyword.clamp(0.0, 1.0); + self + } + + /// Set the search timeout. + /// + /// If the search operation exceeds this duration, it will return + /// [`RetrievalError::QueryTimeout`]. 
+ #[must_use] + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// Get normalized weights that sum to 1.0. + /// + /// If both weights are zero, returns equal weights (0.5, 0.5). + pub fn normalized_weights(&self) -> (f64, f64) { + let sum = self.vector_weight + self.keyword_weight; + if sum <= 0.0 { + (0.5, 0.5) + } else { + (self.vector_weight / sum, self.keyword_weight / sum) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_query_text_only() { + let q = Query::text("hello world"); + assert_eq!(q.text, "hello world"); + assert!(q.embedding.is_none()); + assert!(!q.has_embedding()); + } + + #[test] + fn test_query_hybrid() { + let embedding = vec![0.1, 0.2, 0.3]; + let q = Query::hybrid("hello", embedding.clone()); + assert_eq!(q.text, "hello"); + assert_eq!(q.embedding, Some(embedding)); + assert!(q.has_embedding()); + } + + #[test] + fn test_query_with_filters() { + let q = Query::text("test").with_filters(serde_json::json!({"type": "memory"})); + assert!(q.filters.is_some()); + } + + #[test] + fn test_hybrid_config_default() { + let config = HybridConfig::default(); + assert_eq!(config.top_k, 10); + assert_eq!(config.candidate_pool_size, 50); + assert!(matches!( + config.fusion_strategy, + FusionStrategy::Rrf { k: 60 } + )); + assert!(config.min_score.is_none()); + } + + #[test] + fn test_hybrid_config_new() { + let config = HybridConfig::new(20); + assert_eq!(config.top_k, 20); + assert_eq!(config.candidate_pool_size, 100); // 20 * 5 + } + + #[test] + fn test_hybrid_config_builder() { + let config = HybridConfig::new(10) + .with_fusion_strategy(FusionStrategy::union()) + .with_pool_size(200) + .with_weights(0.6, 0.4); + + assert_eq!(config.top_k, 10); + assert_eq!(config.candidate_pool_size, 200); + assert!(matches!(config.fusion_strategy, FusionStrategy::Union)); + assert_eq!(config.vector_weight, 0.6); + assert_eq!(config.keyword_weight, 0.4); + } + + #[test] + fn 
test_normalized_weights() { + let config = HybridConfig::default(); + let (v, k) = config.normalized_weights(); + assert!((v - 0.7).abs() < 0.01); + assert!((k - 0.3).abs() < 0.01); + + // Zero weights -> equal + let config = HybridConfig::default().with_weights(0.0, 0.0); + let (v, k) = config.normalized_weights(); + assert!((v - 0.5).abs() < 0.01); + assert!((k - 0.5).abs() < 0.01); + } + + #[test] + fn test_weight_clamping() { + let config = HybridConfig::default().with_weights(1.5, -0.5); + assert_eq!(config.vector_weight, 1.0); + assert_eq!(config.keyword_weight, 0.0); + } +} diff --git a/crates/khive-retrieval/src/hybrid/cross_encoder.rs b/crates/khive-retrieval/src/hybrid/cross_encoder.rs new file mode 100644 index 00000000..ddc42343 --- /dev/null +++ b/crates/khive-retrieval/src/hybrid/cross_encoder.rs @@ -0,0 +1,291 @@ +//! Native cross-encoder reranking via `khive-inference`. +//! +//! Provides `NativeCrossEncoderReranker` which implements `Reranker`. +//! Document texts are fetched by `RerankDocumentResolver` so the existing +//! `Reranker` trait (which only carries IDs and scores) does not need to change. + +use std::marker::PhantomData; +use std::sync::Arc; + +use async_trait::async_trait; +use khive_score::DeterministicScore; + +use crate::error::{Result, RetrievalError}; +use crate::hybrid::searcher::Reranker; + +/// Resolve document texts for a set of candidate IDs. +/// +/// Implementors fetch raw document text from whatever backing store is +/// available. A missing document (e.g. deleted after indexing) should be +/// returned as `None`. +#[async_trait] +pub trait RerankDocumentResolver: Send + Sync +where + Id: Send + Sync + 'static, +{ + /// Fetch document bodies for `ids` in input order. + /// + /// The returned `Vec` must be the same length as `ids`. A missing document + /// is represented as `None`; the reranker will return an error in that case. 
+ async fn resolve_documents(&self, ids: &[Id]) -> Result>>; +} + +/// Synchronous cross-encoder scorer abstraction (for testability). +pub trait CrossEncoderScorer: Send + Sync { + /// Score a query against a batch of documents; returns one value per document. + fn score_batch(&self, query: &str, documents: &[&str]) -> Vec; +} + +// TODO(port-rerank): khive-inference not ported yet; CrossEncoderModel impl disabled. +// impl CrossEncoderScorer for khive_inference::CrossEncoderModel { ... } + +/// Reranker that scores candidates with a native cross-encoder model. +/// +/// The generic parameter `S` is the scorer implementation (defaults to no external dep +/// in this OSS build; use a concrete scorer by passing one explicitly). +/// Tests substitute a lightweight fake scorer. +pub struct NativeCrossEncoderReranker +where + Id: Clone + Send + Sync + 'static, + R: RerankDocumentResolver, + S: CrossEncoderScorer, +{ + model: Arc, + resolver: Arc, + _id: PhantomData Id>, +} + +impl NativeCrossEncoderReranker +where + Id: Clone + Send + Sync + 'static, + R: RerankDocumentResolver, + S: CrossEncoderScorer, +{ + /// Construct from an existing scorer and resolver. + pub fn new(model: Arc, resolver: Arc) -> Self { + Self { + model, + resolver, + _id: PhantomData, + } + } +} + +// TODO(port-rerank): from_directory constructor requires khive-inference::CrossEncoderModel. +// Re-enable once khive-inference is ported. +// impl NativeCrossEncoderReranker { ... 
} + +#[async_trait] +impl Reranker for NativeCrossEncoderReranker +where + Id: Clone + Send + Sync + 'static, + R: RerankDocumentResolver, + S: CrossEncoderScorer, +{ + async fn rerank( + &self, + query: &str, + results: Vec<(Id, DeterministicScore)>, + top_k: usize, + ) -> Result> { + if top_k == 0 || results.is_empty() { + return Ok(Vec::new()); + } + + let ids: Vec = results.iter().map(|(id, _)| id.clone()).collect(); + let resolved = self.resolver.resolve_documents(&ids).await?; + if resolved.len() != results.len() { + return Err(RetrievalError::rerank(format!( + "resolver returned {} documents for {} candidates", + resolved.len(), + results.len() + ))); + } + + let mut documents: Vec = Vec::with_capacity(resolved.len()); + for (idx, opt) in resolved.into_iter().enumerate() { + let text = opt.ok_or_else(|| { + RetrievalError::rerank(format!( + "missing document text for rerank candidate at index {idx}" + )) + })?; + documents.push(text); + } + + let document_refs: Vec<&str> = documents.iter().map(String::as_str).collect(); + let scores = self.model.score_batch(query, &document_refs); + if scores.len() != results.len() { + return Err(RetrievalError::rerank(format!( + "model returned {} scores for {} candidates", + scores.len(), + results.len() + ))); + } + + let mut scored: Vec<(usize, Id, f32)> = results + .into_iter() + .zip(scores) + .enumerate() + .map(|(idx, ((id, _), score))| (idx, id, score)) + .collect(); + + scored.sort_by(|a, b| b.2.total_cmp(&a.2).then_with(|| a.0.cmp(&b.0))); + + Ok(scored + .into_iter() + .take(top_k) + .map(|(_, id, score)| (id, DeterministicScore::from_f64(score as f64))) + .collect()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct FakeScorer { + scores: Vec, + } + + impl CrossEncoderScorer for FakeScorer { + fn score_batch(&self, _query: &str, _documents: &[&str]) -> Vec { + self.scores.clone() + } + } + + struct FakeResolver { + documents: Vec>, + } + + #[async_trait] + impl RerankDocumentResolver for 
FakeResolver { + async fn resolve_documents(&self, _ids: &[u32]) -> Result>> { + Ok(self.documents.clone()) + } + } + + fn make_reranker( + scores: Vec, + documents: Vec>, + ) -> NativeCrossEncoderReranker { + NativeCrossEncoderReranker::new( + Arc::new(FakeScorer { scores }), + Arc::new(FakeResolver { documents }), + ) + } + + #[tokio::test] + async fn test_top_k_zero_returns_empty() { + let reranker = make_reranker(vec![0.9, 0.1], vec![Some("a".into()), Some("b".into())]); + let results = vec![(1u32, DeterministicScore::from_f64(0.5))]; + let out = reranker.rerank("q", results, 0).await.unwrap(); + assert!(out.is_empty()); + } + + #[tokio::test] + async fn test_empty_input_returns_empty() { + let reranker = make_reranker(vec![], vec![]); + let out = reranker.rerank("q", vec![], 5).await.unwrap(); + assert!(out.is_empty()); + } + + #[tokio::test] + async fn test_descending_sort() { + let reranker = make_reranker( + vec![0.1, 0.9, 0.5], + vec![Some("a".into()), Some("b".into()), Some("c".into())], + ); + let results = vec![ + (1u32, DeterministicScore::from_f64(0.3)), + (2u32, DeterministicScore::from_f64(0.3)), + (3u32, DeterministicScore::from_f64(0.3)), + ]; + let out = reranker.rerank("q", results, 3).await.unwrap(); + assert_eq!(out[0].0, 2u32); // score 0.9 + assert_eq!(out[1].0, 3u32); // score 0.5 + assert_eq!(out[2].0, 1u32); // score 0.1 + } + + #[tokio::test] + async fn test_tie_preserves_original_order() { + let reranker = make_reranker( + vec![0.5, 0.5, 0.5], + vec![Some("a".into()), Some("b".into()), Some("c".into())], + ); + let results = vec![ + (10u32, DeterministicScore::from_f64(0.3)), + (20u32, DeterministicScore::from_f64(0.3)), + (30u32, DeterministicScore::from_f64(0.3)), + ]; + let out = reranker.rerank("q", results, 3).await.unwrap(); + assert_eq!(out[0].0, 10u32); + assert_eq!(out[1].0, 20u32); + assert_eq!(out[2].0, 30u32); + } + + #[tokio::test] + async fn test_missing_document_returns_error() { + let reranker = make_reranker(vec![0.5], 
vec![None]); + let results = vec![(1u32, DeterministicScore::from_f64(0.5))]; + let err = reranker.rerank("q", results, 1).await.unwrap_err(); + assert!(matches!(err, RetrievalError::Rerank(_))); + } + + #[tokio::test] + async fn test_resolver_length_mismatch_returns_error() { + struct BadResolver; + + #[async_trait] + impl RerankDocumentResolver for BadResolver { + async fn resolve_documents(&self, _ids: &[u32]) -> Result>> { + Ok(vec![]) // wrong length + } + } + + let reranker = NativeCrossEncoderReranker::new( + Arc::new(FakeScorer { scores: vec![0.5] }), + Arc::new(BadResolver), + ); + let results = vec![(1u32, DeterministicScore::from_f64(0.5))]; + let err = reranker.rerank("q", results, 1).await.unwrap_err(); + assert!(matches!(err, RetrievalError::Rerank(_))); + } + + #[tokio::test] + async fn test_top_k_limits_output() { + let reranker = make_reranker( + vec![0.9, 0.8, 0.7], + vec![Some("a".into()), Some("b".into()), Some("c".into())], + ); + let results = vec![ + (1u32, DeterministicScore::from_f64(0.3)), + (2u32, DeterministicScore::from_f64(0.3)), + (3u32, DeterministicScore::from_f64(0.3)), + ]; + let out = reranker.rerank("q", results, 2).await.unwrap(); + assert_eq!(out.len(), 2); + } + + #[tokio::test] + async fn test_top_k_larger_than_results_returns_all() { + // top_k=10 with only 2 candidates — should return all 2, sorted by score + let reranker = make_reranker(vec![0.1, 0.9], vec![Some("a".into()), Some("b".into())]); + let results = vec![ + (1u32, DeterministicScore::from_f64(0.5)), + (2u32, DeterministicScore::from_f64(0.3)), + ]; + let out = reranker.rerank("q", results, 10).await.unwrap(); + assert_eq!(out.len(), 2); + assert_eq!(out[0].0, 2u32); // score 0.9 + assert_eq!(out[1].0, 1u32); // score 0.1 + } + + #[tokio::test] + async fn test_single_result_passes_through() { + let reranker = make_reranker(vec![0.75], vec![Some("only doc".into())]); + let results = vec![(42u32, DeterministicScore::from_f64(0.5))]; + let out = 
reranker.rerank("q", results, 1).await.unwrap(); + assert_eq!(out.len(), 1); + assert_eq!(out[0].0, 42u32); + } +} diff --git a/crates/khive-retrieval/src/hybrid/dual_index.rs b/crates/khive-retrieval/src/hybrid/dual_index.rs new file mode 100644 index 00000000..19350892 --- /dev/null +++ b/crates/khive-retrieval/src/hybrid/dual_index.rs @@ -0,0 +1,524 @@ +//! Dual-index query routing for embedding model migration. +//! +//! During an embedding model migration, both old and new embedding indexes coexist. +//! The retrieval system must query both indexes and merge results to maintain search +//! quality throughout the transition. This module provides a routing layer that +//! decides which indexes to query and fuses the results using the fusion system. +//! +//! # Migration Lifecycle +//! +//! ```text +//! Phase 1 (start): Both → query old + new, fuse with RRF +//! Phase 2 (mid): Weighted → prefer new, still query old +//! Phase 3 (end): PrimaryOnly → new index fully populated +//! ``` +//! +//! # Example +//! +//! ```rust +//! use khive_retrieval::hybrid::dual_index::{DualIndexConfig, DualIndexRouter, DualIndexStrategy}; +//! use khive_score::DeterministicScore; +//! +//! // During migration: query both indexes, fuse with RRF +//! let config = DualIndexConfig::default(); +//! let router = DualIndexRouter::::new(config); +//! +//! assert!(router.should_query_primary(None)); +//! assert!(router.should_query_legacy(None)); +//! +//! // Merge results from both indexes +//! let primary = vec![ +//! ("doc_a".to_string(), DeterministicScore::from_f64(0.9)), +//! ("doc_b".to_string(), DeterministicScore::from_f64(0.8)), +//! ]; +//! let legacy = vec![ +//! ("doc_b".to_string(), DeterministicScore::from_f64(0.95)), +//! ("doc_c".to_string(), DeterministicScore::from_f64(0.7)), +//! ]; +//! +//! let merged = router.merge_results(primary, legacy, 10); +//! // doc_b appears in both, gets highest RRF score +//! assert_eq!(merged[0].0, "doc_b"); +//! 
``` + +use std::hash::Hash; + +use khive_score::DeterministicScore; + +use khive_fusion::{fuse, FusionStrategy}; + +/// Strategy for routing queries during dual-index operation. +/// +/// Controls which indexes are queried and how results are combined. +#[derive(Debug, Clone, PartialEq)] +pub enum DualIndexStrategy { + /// Query both indexes, fuse results (default during migration). + /// + /// Uses the specified fusion strategy to combine results from both + /// the primary (new) and legacy (old) indexes. + Both { + /// Fusion strategy for combining results from both indexes. + fusion: FusionStrategy, + }, + + /// Query only the primary (new) index. + /// + /// Use after migration is complete and all documents have been + /// re-embedded with the new model. + PrimaryOnly, + + /// Query only the legacy (old) index. + /// + /// Use as a fallback if the new index has issues. + LegacyOnly, + + /// Weighted preference: primary gets `primary_weight`, legacy gets `1 - primary_weight`. + /// + /// Useful during mid-migration when the new index covers most documents + /// but the old index still has better coverage for some. + Weighted { + /// Weight for primary index results, in range [0.0, 1.0]. + /// Legacy index weight is computed as `1.0 - primary_weight`. + primary_weight: f64, + }, +} + +impl Default for DualIndexStrategy { + fn default() -> Self { + DualIndexStrategy::Both { + fusion: FusionStrategy::rrf(), + } + } +} + +/// Configuration for dual-index query routing. +#[derive(Debug, Clone)] +pub struct DualIndexConfig { + /// Routing strategy. + pub strategy: DualIndexStrategy, + + /// Candidate pool multiplier for each index. + /// + /// Each index fetches `top_k * pool_multiplier` candidates before fusion. + /// Default: 3. + pub pool_multiplier: usize, + + /// Minimum migration progress to auto-switch to `PrimaryOnly`, in range [0.0, 1.0]. + /// + /// When `migration_progress >= auto_switch_threshold`, the router automatically + /// skips the legacy index. 
Set to `None` to disable auto-switching. + pub auto_switch_threshold: Option, +} + +impl Default for DualIndexConfig { + fn default() -> Self { + Self { + strategy: DualIndexStrategy::default(), + pool_multiplier: 3, + auto_switch_threshold: None, + } + } +} + +impl DualIndexConfig { + /// Create a config with a specific strategy. + pub fn with_strategy(mut self, strategy: DualIndexStrategy) -> Self { + self.strategy = strategy; + self + } + + /// Set the candidate pool multiplier. + pub fn with_pool_multiplier(mut self, multiplier: usize) -> Self { + self.pool_multiplier = multiplier.max(1); + self + } + + /// Set the auto-switch threshold for migration progress. + pub fn with_auto_switch_threshold(mut self, threshold: f64) -> Self { + self.auto_switch_threshold = Some(threshold.clamp(0.0, 1.0)); + self + } +} + +/// Routes queries between primary (new) and legacy (old) vector indexes. +/// +/// During embedding model migration, this router ensures search quality +/// by querying both indexes and fusing results. It is generic over the +/// document ID type, matching the [`VectorSearch`](crate::hybrid::VectorSearch) trait. +/// +/// # Type Parameters +/// +/// * `Id` - The identifier type for search results. Must implement `Eq + Hash + Clone + Ord` +/// for deterministic fusion (same bounds as [`fuse`]). +pub struct DualIndexRouter { + config: DualIndexConfig, + _marker: std::marker::PhantomData, +} + +impl DualIndexRouter +where + Id: Eq + Hash + Clone + Ord, +{ + /// Create a new dual-index router with the given configuration. + pub fn new(config: DualIndexConfig) -> Self { + Self { + config, + _marker: std::marker::PhantomData, + } + } + + /// Determine whether the primary (new) index should be queried. + /// + /// Returns `false` only for [`DualIndexStrategy::LegacyOnly`]. 
+ pub fn should_query_primary(&self, _migration_progress: Option) -> bool { + !matches!(self.config.strategy, DualIndexStrategy::LegacyOnly) + } + + /// Determine whether the legacy (old) index should be queried. + /// + /// Returns `false` for [`DualIndexStrategy::PrimaryOnly`], and also returns + /// `false` when migration progress exceeds the auto-switch threshold. + pub fn should_query_legacy(&self, migration_progress: Option) -> bool { + match &self.config.strategy { + DualIndexStrategy::PrimaryOnly => false, + DualIndexStrategy::LegacyOnly => true, + DualIndexStrategy::Both { .. } | DualIndexStrategy::Weighted { .. } => { + // Auto-switch: if migration is nearly complete, skip legacy + if let (Some(threshold), Some(progress)) = + (self.config.auto_switch_threshold, migration_progress) + { + progress < threshold + } else { + true + } + } + } + } + + /// Get the candidate pool size for each index. + /// + /// Returns `top_k * pool_multiplier`. + pub fn pool_size(&self, top_k: usize) -> usize { + top_k * self.config.pool_multiplier + } + + /// Merge results from primary and legacy indexes. 
+ /// + /// Applies the configured strategy to combine results: + /// - `PrimaryOnly`: returns primary results (truncated) + /// - `LegacyOnly`: returns legacy results (truncated) + /// - `Both`: fuses both result sets using the configured fusion strategy + /// - `Weighted`: fuses with per-index weights + /// + /// # Arguments + /// + /// * `primary_results` - Results from the new embedding index + /// * `legacy_results` - Results from the old embedding index + /// * `top_k` - Number of results to return after merging + pub fn merge_results( + &self, + primary_results: Vec<(Id, DeterministicScore)>, + legacy_results: Vec<(Id, DeterministicScore)>, + top_k: usize, + ) -> Vec<(Id, DeterministicScore)> { + match &self.config.strategy { + DualIndexStrategy::PrimaryOnly => { + let mut results = primary_results; + results.truncate(top_k); + results + } + DualIndexStrategy::LegacyOnly => { + let mut results = legacy_results; + results.truncate(top_k); + results + } + DualIndexStrategy::Both { fusion } => { + let sources = vec![primary_results, legacy_results]; + fuse(sources, fusion, top_k) + } + DualIndexStrategy::Weighted { primary_weight } => { + let w = primary_weight.clamp(0.0, 1.0); + let strategy = FusionStrategy::weighted(vec![w, 1.0 - w]); + let sources = vec![primary_results, legacy_results]; + fuse(sources, &strategy, top_k) + } + } + } + + /// Get a reference to the current routing strategy. + pub fn strategy(&self) -> &DualIndexStrategy { + &self.config.strategy + } + + /// Update the routing strategy (e.g., when migration completes). + pub fn set_strategy(&mut self, strategy: DualIndexStrategy) { + self.config.strategy = strategy; + } + + /// Get a reference to the full configuration. + pub fn config(&self) -> &DualIndexConfig { + &self.config + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper to build scored result lists from (id, f64) pairs. 
+ fn make_results(items: Vec<(&str, f64)>) -> Vec<(String, DeterministicScore)> { + items + .into_iter() + .map(|(id, score)| (id.to_string(), DeterministicScore::from_f64(score))) + .collect() + } + + // -- Strategy routing tests -- + + #[test] + fn test_primary_only_queries_only_primary() { + let config = DualIndexConfig::default().with_strategy(DualIndexStrategy::PrimaryOnly); + let router = DualIndexRouter::::new(config); + + assert!(router.should_query_primary(None)); + assert!(!router.should_query_legacy(None)); + } + + #[test] + fn test_legacy_only_queries_only_legacy() { + let config = DualIndexConfig::default().with_strategy(DualIndexStrategy::LegacyOnly); + let router = DualIndexRouter::::new(config); + + assert!(!router.should_query_primary(None)); + assert!(router.should_query_legacy(None)); + } + + #[test] + fn test_both_queries_both_indexes() { + let config = DualIndexConfig::default(); // default is Both { rrf } + let router = DualIndexRouter::::new(config); + + assert!(router.should_query_primary(None)); + assert!(router.should_query_legacy(None)); + } + + #[test] + fn test_weighted_queries_both_indexes() { + let config = DualIndexConfig::default().with_strategy(DualIndexStrategy::Weighted { + primary_weight: 0.8, + }); + let router = DualIndexRouter::::new(config); + + assert!(router.should_query_primary(None)); + assert!(router.should_query_legacy(None)); + } + + // -- Auto-switch threshold tests -- + + #[test] + fn test_auto_switch_skips_legacy_when_threshold_exceeded() { + let config = DualIndexConfig::default().with_auto_switch_threshold(0.95); + let router = DualIndexRouter::::new(config); + + // Migration at 90% - below threshold, still query legacy + assert!(router.should_query_legacy(Some(0.90))); + + // Migration at 95% - at threshold, skip legacy (progress >= threshold) + assert!(!router.should_query_legacy(Some(0.95))); + + // Migration at 99% - above threshold, skip legacy + assert!(!router.should_query_legacy(Some(0.99))); + } + + 
#[test] + fn test_auto_switch_no_threshold_always_queries_legacy() { + let config = DualIndexConfig::default(); // no auto_switch_threshold + let router = DualIndexRouter::::new(config); + + // Even with 100% progress, queries legacy without threshold + assert!(router.should_query_legacy(Some(1.0))); + } + + #[test] + fn test_auto_switch_no_progress_queries_legacy() { + let config = DualIndexConfig::default().with_auto_switch_threshold(0.95); + let router = DualIndexRouter::::new(config); + + // No progress info provided - query legacy to be safe + assert!(router.should_query_legacy(None)); + } + + // -- Pool size tests -- + + #[test] + fn test_pool_size_calculation() { + let config = DualIndexConfig::default(); // pool_multiplier = 3 + let router = DualIndexRouter::::new(config); + + assert_eq!(router.pool_size(10), 30); + assert_eq!(router.pool_size(1), 3); + assert_eq!(router.pool_size(0), 0); + } + + #[test] + fn test_pool_size_custom_multiplier() { + let config = DualIndexConfig::default().with_pool_multiplier(5); + let router = DualIndexRouter::::new(config); + + assert_eq!(router.pool_size(10), 50); + } + + // -- Merge results tests -- + + #[test] + fn test_merge_primary_only_returns_primary() { + let config = DualIndexConfig::default().with_strategy(DualIndexStrategy::PrimaryOnly); + let router = DualIndexRouter::::new(config); + + let primary = make_results(vec![("a", 0.9), ("b", 0.8)]); + let legacy = make_results(vec![("c", 0.95), ("d", 0.85)]); + + let merged = router.merge_results(primary, legacy, 10); + assert_eq!(merged.len(), 2); + assert_eq!(merged[0].0, "a"); + assert_eq!(merged[1].0, "b"); + } + + #[test] + fn test_merge_legacy_only_returns_legacy() { + let config = DualIndexConfig::default().with_strategy(DualIndexStrategy::LegacyOnly); + let router = DualIndexRouter::::new(config); + + let primary = make_results(vec![("a", 0.9), ("b", 0.8)]); + let legacy = make_results(vec![("c", 0.95), ("d", 0.85)]); + + let merged = 
router.merge_results(primary, legacy, 10); + assert_eq!(merged.len(), 2); + assert_eq!(merged[0].0, "c"); + assert_eq!(merged[1].0, "d"); + } + + #[test] + fn test_merge_both_fuses_with_rrf() { + let config = DualIndexConfig::default(); // Both { Rrf { k: 60 } } + let router = DualIndexRouter::::new(config); + + let primary = make_results(vec![("a", 0.9), ("b", 0.8)]); + let legacy = make_results(vec![("b", 0.95), ("c", 0.7)]); + + let merged = router.merge_results(primary, legacy, 10); + + // "b" appears in both sources, should get highest RRF score + assert_eq!(merged[0].0, "b"); + // All three unique IDs should be present + assert_eq!(merged.len(), 3); + } + + #[test] + fn test_merge_weighted_applies_weights() { + let config = DualIndexConfig::default().with_strategy(DualIndexStrategy::Weighted { + primary_weight: 0.8, + }); + let router = DualIndexRouter::::new(config); + + let primary = make_results(vec![("a", 0.9), ("b", 0.5)]); + let legacy = make_results(vec![("b", 0.9), ("c", 0.5)]); + + let merged = router.merge_results(primary, legacy, 10); + + // All three unique IDs should appear + let ids: Vec<&str> = merged.iter().map(|(id, _)| id.as_str()).collect(); + assert!(ids.contains(&"a")); + assert!(ids.contains(&"b")); + assert!(ids.contains(&"c")); + } + + #[test] + fn test_merge_respects_top_k() { + let config = DualIndexConfig::default(); + let router = DualIndexRouter::::new(config); + + let primary = make_results(vec![("a", 0.9), ("b", 0.8), ("c", 0.7)]); + let legacy = make_results(vec![("d", 0.95), ("e", 0.85), ("f", 0.75)]); + + let merged = router.merge_results(primary, legacy, 2); + assert_eq!(merged.len(), 2); + } + + #[test] + fn test_merge_empty_sources() { + let config = DualIndexConfig::default(); + let router = DualIndexRouter::::new(config); + + let merged = router.merge_results(vec![], vec![], 10); + assert!(merged.is_empty()); + } + + // -- Strategy mutation tests -- + + #[test] + fn test_set_strategy_updates_routing() { + let config = 
DualIndexConfig::default(); + let mut router = DualIndexRouter::::new(config); + + assert!(matches!(router.strategy(), DualIndexStrategy::Both { .. })); + + router.set_strategy(DualIndexStrategy::PrimaryOnly); + assert!(matches!(router.strategy(), DualIndexStrategy::PrimaryOnly)); + } + + // -- Config builder tests -- + + #[test] + fn test_config_default() { + let config = DualIndexConfig::default(); + assert!(matches!(config.strategy, DualIndexStrategy::Both { .. })); + assert_eq!(config.pool_multiplier, 3); + assert!(config.auto_switch_threshold.is_none()); + } + + #[test] + fn test_config_builder_chain() { + let config = DualIndexConfig::default() + .with_strategy(DualIndexStrategy::Weighted { + primary_weight: 0.7, + }) + .with_pool_multiplier(5) + .with_auto_switch_threshold(0.95); + + assert!(matches!( + config.strategy, + DualIndexStrategy::Weighted { primary_weight } if (primary_weight - 0.7).abs() < f64::EPSILON + )); + assert_eq!(config.pool_multiplier, 5); + assert!((config.auto_switch_threshold.unwrap() - 0.95).abs() < f64::EPSILON); + } + + #[test] + fn test_pool_multiplier_min_enforced() { + let config = DualIndexConfig::default().with_pool_multiplier(0); + assert_eq!(config.pool_multiplier, 1); + } + + #[test] + fn test_auto_switch_threshold_clamped() { + let config = DualIndexConfig::default().with_auto_switch_threshold(1.5); + assert!((config.auto_switch_threshold.unwrap() - 1.0).abs() < f64::EPSILON); + + let config = DualIndexConfig::default().with_auto_switch_threshold(-0.5); + assert!((config.auto_switch_threshold.unwrap() - 0.0).abs() < f64::EPSILON); + } + + // -- Default strategy tests -- + + #[test] + fn test_default_strategy_is_both_rrf() { + let strategy = DualIndexStrategy::default(); + assert_eq!( + strategy, + DualIndexStrategy::Both { + fusion: FusionStrategy::rrf() + } + ); + } +} diff --git a/crates/khive-retrieval/src/hybrid/mod.rs b/crates/khive-retrieval/src/hybrid/mod.rs new file mode 100644 index 00000000..8e28d4ce --- 
/dev/null +++ b/crates/khive-retrieval/src/hybrid/mod.rs @@ -0,0 +1,83 @@ +//! Unified hybrid search interface. +//! +//! Combines HNSW vector search, BM25 keyword search, and graph traversal +//! into a single query interface with configurable fusion strategies. +//! +//! # Architecture (ADR-002) +//! +//! ```text +//! Query ──┬── [Vector Search] ── HNSW ── Vec<(Id, Distance)> +//! │ │ +//! │ distance → similarity +//! │ │ +//! │ Vec<(Id, DeterministicScore)> +//! │ │ +//! └── [Keyword Search] ── BM25 ── Vec<(Id, BM25Score)> +//! │ +//! normalize → DeterministicScore +//! │ +//! Vec<(Id, DeterministicScore)> +//! │ +//! ┌─────────────┴─────────────┐ +//! │ reciprocal_rank_fusion │ +//! │ k=60 (standard) │ +//! └─────────────┬─────────────┘ +//! │ +//! Vec<(Id, DeterministicScore)> +//! ``` +//! +//! # Trait Hierarchy +//! +//! ```text +//! VectorSearch ──┐ +//! ├── HybridSearcher +//! KeywordSearch ─┘ +//! +//! Reranker (standalone, generic over Id) +//! ``` +//! +//! Each trait can be implemented independently: +//! - [`VectorSearch`]: Embedding-based nearest-neighbor search (e.g., HNSW) +//! - [`KeywordSearch`]: Text-based retrieval (e.g., BM25) +//! - [`HybridSearcher`]: Combined search requiring both vector + keyword +//! - [`Reranker`]: Post-retrieval reranking (e.g., cross-encoder) +//! +//! # Fusion Strategies +//! +//! - **RRF (Reciprocal Rank Fusion)**: Default and recommended. Uses only ranks, +//! making it robust to score distribution differences. +//! - **Weighted**: Linear combination of scores with configurable weights. +//! - **Union**: Takes the maximum score per ID across sources. +//! +//! # Example +//! +//! ```rust,ignore +//! use khive_retrieval::hybrid::{ +//! HybridConfig, HybridSearcher, VectorSearch, KeywordSearch, Query, fuse_search_results, +//! }; +//! use khive_score::DeterministicScore; +//! +//! // Create your own searcher implementing VectorSearch + KeywordSearch + HybridSearcher +//! 
// Then use fuse_search_results to combine vector and keyword results +//! +//! let vector_results = vec![("doc1".to_string(), DeterministicScore::from_f64(0.9))]; +//! let keyword_results = vec![("doc1".to_string(), DeterministicScore::from_f64(0.85))]; +//! +//! let config = HybridConfig::new(10); +//! let fused = fuse_search_results(vec![vector_results, keyword_results], &config); +//! ``` +//! +//! See [ADR-002](../docs/ADR-002-hybrid-search.md) for algorithm specification. + +mod config; +#[cfg(feature = "native-rerank")] +mod cross_encoder; +pub mod dual_index; +mod searcher; + +// Re-export public types +pub use config::{HybridConfig, Query, DEFAULT_POOL_MULTIPLIER}; +#[cfg(feature = "native-rerank")] +pub use cross_encoder::{CrossEncoderScorer, NativeCrossEncoderReranker, RerankDocumentResolver}; +pub use dual_index::{DualIndexConfig, DualIndexRouter, DualIndexStrategy}; +pub use searcher::{fuse_search_results, HybridSearcher, KeywordSearch, Reranker, VectorSearch}; diff --git a/crates/khive-retrieval/src/hybrid/searcher.rs b/crates/khive-retrieval/src/hybrid/searcher.rs new file mode 100644 index 00000000..d24b12bd --- /dev/null +++ b/crates/khive-retrieval/src/hybrid/searcher.rs @@ -0,0 +1,359 @@ +//! Granular search traits and hybrid search implementation. +//! +//! # Trait Hierarchy +//! +//! ```text +//! VectorSearch ──┐ +//! ├── HybridSearcher +//! KeywordSearch ─┘ +//! +//! Reranker (standalone, generic over Id) +//! ``` +//! +//! Each trait can be implemented independently, enabling: +//! - Vector-only search (e.g., HNSW index) +//! - Keyword-only search (e.g., BM25 index) +//! - Full hybrid search (combining both with fusion) +//! - Reranking as a separate, composable concern + +use std::hash::Hash; + +use async_trait::async_trait; +use khive_score::DeterministicScore; + +use crate::error::Result; +use khive_fusion::{fuse, FusionStrategy}; + +use super::config::{HybridConfig, Query}; + +/// Trait for vector similarity search. 
+/// +/// Implementors provide embedding-based nearest-neighbor search +/// (e.g., HNSW, flat scan, IVF). +/// +/// # Associated Types +/// +/// * `Id` - The identifier type for documents/results. Requires `Ord` for +/// deterministic tie-breaking when scores are equal. +/// +/// # Example +/// +/// ```rust,ignore +/// use khive_retrieval::hybrid::VectorSearch; +/// +/// struct MyVectorIndex { /* ... */ } +/// +/// #[async_trait::async_trait] +/// impl VectorSearch for MyVectorIndex { +/// type Id = String; +/// +/// async fn vector_search(&self, embedding: &[f32], top_k: usize) +/// -> khive_retrieval::Result> +/// { +/// // Your HNSW/ANN implementation here +/// todo!() +/// } +/// } +/// ``` +#[async_trait] +pub trait VectorSearch: Send + Sync { + /// The ID type for search results. + /// `Ord` is required for deterministic tie-breaking when scores are equal. + type Id: Eq + Hash + Clone + Ord + Send + Sync; + + /// Perform vector-only search. + /// + /// # Arguments + /// + /// * `embedding` - Query embedding vector + /// * `top_k` - Number of results to return + /// + /// # Returns + /// + /// Vector of (Id, DeterministicScore) pairs sorted by similarity descending. + async fn vector_search( + &self, + embedding: &[f32], + top_k: usize, + ) -> Result>; +} + +/// Trait for keyword-based search. +/// +/// Implementors provide text-based retrieval (e.g., BM25, TF-IDF). +/// +/// # Associated Types +/// +/// * `Id` - The identifier type for documents/results. Requires `Ord` for +/// deterministic tie-breaking when scores are equal. +/// +/// # Example +/// +/// ```rust,ignore +/// use khive_retrieval::hybrid::KeywordSearch; +/// +/// struct MyBm25Index { /* ... 
*/ } +/// +/// #[async_trait::async_trait] +/// impl KeywordSearch for MyBm25Index { +/// type Id = String; +/// +/// async fn keyword_search(&self, text: &str, top_k: usize) +/// -> khive_retrieval::Result> +/// { +/// // Your BM25 implementation here +/// todo!() +/// } +/// } +/// ``` +#[async_trait] +pub trait KeywordSearch: Send + Sync { + /// The ID type for search results. + /// `Ord` is required for deterministic tie-breaking when scores are equal. + type Id: Eq + Hash + Clone + Ord + Send + Sync; + + /// Perform keyword-only search (BM25). + /// + /// # Arguments + /// + /// * `text` - Query text + /// * `top_k` - Number of results to return + /// + /// # Returns + /// + /// Vector of (Id, DeterministicScore) pairs sorted by BM25 score descending. + async fn keyword_search( + &self, + text: &str, + top_k: usize, + ) -> Result>; +} + +/// Trait for hybrid search operations. +/// +/// Combines vector similarity search (HNSW) with keyword search (BM25) +/// using configurable fusion strategies. +/// +/// # Supertrait Constraint +/// +/// Requires both [`VectorSearch`] and [`KeywordSearch`] to be implemented +/// with the **same `Id` type**, enforced by the +/// `KeywordSearch::Id>` bound. +/// +/// # Example +/// +/// ```rust,ignore +/// use khive_retrieval::hybrid::{HybridSearcher, VectorSearch, KeywordSearch}; +/// +/// struct MyHybridIndex { /* ... 
*/ } +/// +/// // Implement VectorSearch and KeywordSearch first, then HybridSearcher +/// #[async_trait::async_trait] +/// impl HybridSearcher for MyHybridIndex { +/// async fn hybrid_search(&self, query: &Query, config: &HybridConfig) +/// -> Result> +/// { +/// let mut sources = Vec::new(); +/// if let Some(emb) = &query.embedding { +/// sources.push(self.vector_search(emb, config.candidate_pool_size).await?); +/// } +/// sources.push(self.keyword_search(&query.text, config.candidate_pool_size).await?); +/// Ok(fuse_search_results(sources, config)) +/// } +/// } +/// ``` +#[async_trait] +pub trait HybridSearcher: VectorSearch + KeywordSearch::Id> { + /// Perform hybrid search combining vector and keyword retrieval. + /// + /// # Arguments + /// + /// * `query` - The search query (text + optional embedding) + /// * `config` - Hybrid search configuration + /// + /// # Returns + /// + /// Vector of (Id, DeterministicScore) pairs sorted by fused score descending. + async fn hybrid_search( + &self, + query: &Query, + config: &HybridConfig, + ) -> Result::Id, DeterministicScore)>>; +} + +/// Trait for reranking search results. +/// +/// Separates the reranking concern from search, enabling: +/// - Cross-encoder neural reranking +/// - LLM-based reranking +/// - Custom scoring adjustments +/// +/// The `Id` type is a generic parameter rather than an associated type, +/// allowing a single reranker to work with different ID types. 
+/// +/// # Example +/// +/// ```rust,ignore +/// use khive_retrieval::hybrid::Reranker; +/// +/// struct CrossEncoderReranker { /* model handle */ } +/// +/// #[async_trait::async_trait] +/// impl Reranker for CrossEncoderReranker { +/// async fn rerank( +/// &self, +/// query: &str, +/// results: Vec<(String, DeterministicScore)>, +/// top_k: usize, +/// ) -> Result> { +/// // Score each (query, document) pair with cross-encoder +/// // Sort by new scores and truncate to top_k +/// todo!() +/// } +/// } +/// ``` +#[async_trait] +pub trait Reranker: Send + Sync { + /// Rerank search results using additional signals. + /// + /// # Arguments + /// + /// * `query` - The original query text for relevance scoring + /// * `results` - Pre-ranked results to reorder + /// * `top_k` - Number of results to return after reranking + /// + /// # Returns + /// + /// Reranked vector of (Id, DeterministicScore) pairs, truncated to `top_k`. + async fn rerank( + &self, + query: &str, + results: Vec<(Id, DeterministicScore)>, + top_k: usize, + ) -> Result>; +} + +/// Helper function to perform fusion on search results. +/// +/// This can be used by implementors of [`HybridSearcher`] to fuse results +/// from their [`VectorSearch`] and [`KeywordSearch`] implementations. +/// +/// `Ord` is required for deterministic tie-breaking when scores are equal. +pub fn fuse_search_results( + sources: Vec>, + config: &HybridConfig, +) -> Vec<(Id, DeterministicScore)> { + if sources.is_empty() { + return Vec::new(); + } + + if sources.len() == 1 { + let mut results = sources.into_iter().next().unwrap(); + if let Some(min_score) = config.min_score { + results.retain(|(_, score)| *score >= min_score); + } + results.truncate(config.top_k); + return results; + } + + // Determine fusion strategy + let strategy = match &config.fusion_strategy { + FusionStrategy::Weighted { .. 
} => { + // Use configured weights — constrained to exactly 2 sources (vector + keyword) + debug_assert_eq!( + sources.len(), + 2, + "Weighted fusion expects exactly 2 sources" + ); + let (v, k) = config.normalized_weights(); + FusionStrategy::weighted(vec![v, k]) + } + other => other.clone(), + }; + + // Fuse results + let mut fused = fuse(sources, &strategy, config.top_k); + + // Apply minimum score filter + if let Some(min_score) = config.min_score { + fused.retain(|(_, score)| *score >= min_score); + } + + fused +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fuse_empty_sources() { + let sources: Vec> = vec![]; + let config = HybridConfig::default(); + let results = fuse_search_results(sources, &config); + assert!(results.is_empty()); + } + + #[test] + fn test_fuse_single_source() { + let sources = vec![vec![ + ("a".to_string(), DeterministicScore::from_f64(0.9)), + ("b".to_string(), DeterministicScore::from_f64(0.8)), + ]]; + let config = HybridConfig::new(10); + let results = fuse_search_results(sources, &config); + + assert_eq!(results.len(), 2); + assert_eq!(results[0].0, "a"); + } + + #[test] + fn test_fuse_multiple_sources_rrf() { + let source1 = vec![ + ("a".to_string(), DeterministicScore::from_f64(0.9)), + ("b".to_string(), DeterministicScore::from_f64(0.8)), + ]; + let source2 = vec![ + ("b".to_string(), DeterministicScore::from_f64(0.95)), + ("c".to_string(), DeterministicScore::from_f64(0.7)), + ]; + + let config = HybridConfig::new(10); + let results = fuse_search_results(vec![source1, source2], &config); + + assert_eq!(results.len(), 3); + // b appears in both, should have highest RRF score + assert_eq!(results[0].0, "b"); + } + + #[test] + fn test_fuse_with_min_score() { + let sources = vec![vec![ + ("a".to_string(), DeterministicScore::from_f64(0.9)), + ("b".to_string(), DeterministicScore::from_f64(0.1)), + ]]; + + let config = HybridConfig::new(10).with_min_score(DeterministicScore::from_f64(0.5)); + let results = 
fuse_search_results(sources, &config); + + // b should be filtered out (RRF score ~0.016 < 0.5) + // Actually RRF scores are very small, let's use a lower threshold + assert!(!results.is_empty()); + } + + #[test] + fn test_fuse_top_k_limit() { + let sources = vec![vec![ + ("a".to_string(), DeterministicScore::from_f64(0.9)), + ("b".to_string(), DeterministicScore::from_f64(0.8)), + ("c".to_string(), DeterministicScore::from_f64(0.7)), + ("d".to_string(), DeterministicScore::from_f64(0.6)), + ("e".to_string(), DeterministicScore::from_f64(0.5)), + ]]; + + let config = HybridConfig::new(3); + let results = fuse_search_results(sources, &config); + + assert_eq!(results.len(), 3); + } +} diff --git a/crates/khive-retrieval/src/lib.rs b/crates/khive-retrieval/src/lib.rs new file mode 100644 index 00000000..60e61287 --- /dev/null +++ b/crates/khive-retrieval/src/lib.rs @@ -0,0 +1,190 @@ +#![allow(clippy::uninlined_format_args)] +#![allow(clippy::field_reassign_with_default)] +#![allow(clippy::approx_constant)] +// Note: field_reassign_with_default is needed for some internal tests + +//! Hybrid search and ranking with deterministic scoring for khive. +//! +//! This crate provides: +//! - HNSW vector search with `DeterministicScore` output +//! - BM25 keyword search for exact matches +//! - Reciprocal Rank Fusion (RRF) for hybrid search +//! - Graph traversal for relationship-aware retrieval +//! +//! # Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ khive-retrieval │ +//! │ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ │ +//! │ │ hnsw/ │ │ bm25/ │ │ graph/ │ │ fusion/ │ │ +//! │ │ (vector) │ │ (keyword) │ │(traversal)│ │ (RRF) │ │ +//! │ └───────────┘ └───────────┘ └───────────┘ └───────────┘ │ +//! │ │ │ +//! │ ▼ │ +//! │ ┌───────────────┐ │ +//! │ │ hybrid/ │ │ +//! │ │ (unified) │ │ +//! │ └───────────────┘ │ +//! │ │ +//! │ Inputs: Query + optional embedding + optional start nodes │ +//! 
│ Outputs: Vec<(Id, DeterministicScore)> │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Design Principles +//! +//! ## Deterministic Scoring (ADR-002) +//! +//! All scores use `DeterministicScore` from `khive-score` for: +//! - Cross-platform identical rankings (x86_64, ARM64, WASM) +//! - `Ord` implementation (sortable, usable in BTreeSet) +//! - `Hash` implementation (cacheable) +//! +//! ## Index Management (ADR-003) +//! +//! - HNSW: Hierarchical Navigable Small World graphs for ANN search +//! - BM25: Okapi BM25 for keyword relevance +//! - Both support incremental updates with periodic rebuild +//! +//! ## Graph Traversal (ADR-004) +//! +//! - BFS for level-by-level exploration +//! - DFS for deep path exploration +//! - Bidirectional BFS for shortest path +//! +//! ## ID Types and Bridging +//! +//! Each retrieval module uses a different ID type: +//! +//! | Module | ID Type | Backing | +//! |--------|---------|---------| +//! | HNSW | [`EmbeddingId`] | 128-bit (ULID, from khive-types) | +//! | BM25 | [`DocumentId`] | Newtype over `String` | +//! | Graph | `EntityRef` | Enum (from khive-db) | +//! | Fusion | Generic `Id` | `Eq + Hash + Clone + Ord` | +//! +//! The [`fusion::fuse`] function is generic over the ID type, so hybrid +//! search that combines results from different modules requires a common +//! representation. Bridging strategies: +//! +//! 1. **String-based**: Convert all IDs to `String` before fusion. +//! 2. **DocumentId-based**: Convert `EmbeddingId` to `DocumentId` via +//! `DocumentId::new(embedding_id.to_string())`. +//! 3. **Application-level mapping**: Maintain a bidirectional lookup table +//! between ID types in the application layer. +//! +//! See [`DocumentId`] for details on the newtype and conversion traits. +//! +//! # Quick Start +//! +//! ```rust,ignore +//! use khive_retrieval::{VectorSearch, KeywordSearch, HybridSearcher, Query, HybridConfig}; +//! +//! 
// Implement granular traits independently: +//! // - VectorSearch for embedding-based search (HNSW) +//! // - KeywordSearch for text-based search (BM25) +//! // - HybridSearcher for combined search (requires both) +//! // - Reranker for post-retrieval reranking (standalone) +//! +//! // Example: keyword-only search +//! let results = searcher.keyword_search("distributed systems", 10).await?; +//! +//! // Example: hybrid search (vector + keyword with fusion) +//! let query = Query::hybrid("distributed systems", embedding_vec); +//! let config = HybridConfig::new(10); +//! let results = searcher.hybrid_search(&query, &config).await?; +//! +//! for (id, score) in results { +//! println!("{}: {}", id, score); +//! } +//! ``` + +#![warn(missing_docs)] +#![warn(clippy::all)] + +#[cfg(feature = "storage-adapters")] +pub mod adapters; +pub mod error; +pub mod eval; +// graph module depends on EntityRef/LinkStore/StorageContext from old monolith khive-db API; +// gated until ported to current khive-storage GraphStore trait. 
+#[cfg(feature = "graph-legacy")] +pub mod graph; +pub mod hybrid; +pub mod metrics; +#[cfg(feature = "persist")] +pub mod persist; +pub mod policy; +pub mod query_ir; +#[cfg(feature = "persist")] +pub mod replay; +pub mod search_config; +pub mod timeout; +#[cfg(feature = "persist")] +pub mod weights; + +// Re-export adapter types +#[cfg(feature = "storage-adapters")] +pub use adapters::{StorageKeywordSearch, StorageVectorSearch}; + +// Re-export core types +pub use error::{ErrorKind, Result, RetrievalError}; + +// Re-export types from sibling crates (now separate crates) +#[cfg(feature = "graph-legacy")] +pub use graph::{ + bfs_traverse, dfs_traverse, find_shortest_path, Direction, PathNode, TraversalOptions, + MAX_TRAVERSAL_DEPTH, MAX_TRAVERSAL_RESULTS, +}; +pub use khive_bm25::{Bm25Config, Bm25Index, Bm25Stats, DocumentId, SearchContext}; +pub use khive_fusion::{ + fuse, normalize_weights, reciprocal_rank_fusion, weighted_fusion, weights_are_normalized, + FusionStrategy, DEFAULT_RRF_K, +}; +pub use khive_hnsw::{ + DistanceMetric, HnswCheckpointConfig, HnswConfig, HnswIndex, HnswSearchContext, HnswSnapshot, + NodeId, RebuildStats, TombstoneStats, +}; +// TODO(port-checkpoint): HnswCheckpoint/HnswCheckpointStore depend on khive_fold::Checkpoint +// which doesn't exist in the current khive-fold API. Re-enable when ported. 
+// #[cfg(feature = "checkpoint")] +// pub use khive_hnsw::{HnswCheckpoint, HnswCheckpointStore}; +pub use hybrid::{ + fuse_search_results, DualIndexConfig, DualIndexRouter, DualIndexStrategy, HybridConfig, + HybridSearcher, KeywordSearch, Query, Reranker, VectorSearch, +}; +// TODO(port-rerank): native cross-encoder reranking deferred; khive-inference not ported yet +// #[cfg(feature = "native-rerank")] +// pub use hybrid::{CrossEncoderScorer, NativeCrossEncoderReranker, RerankDocumentResolver}; +pub use metrics::{MetricEvent, MetricValue, MetricsSink, NoopSink, RecordingSink}; +#[cfg(feature = "persist")] +pub use persist::{ + PersistError, PersistenceStats, RetrievalPersistence, ShadowMetrics, ShadowValidationConfig, + ShadowValidationResult, +}; +pub use policy::{filter_by_policy, filter_by_predicate, ClearanceLevel, SearchPolicy}; +pub use query_ir::{FilterPredicate, FuseStrategy, QueryNode, RerankMethod}; +pub use search_config::SearchConfig; +pub use timeout::{ + search_with_cancellation, search_with_deadline, search_with_optional_timeout, + search_with_timeout, +}; + +/// Re-exports from `lattice-embed` for app-layer access. +/// +/// Apps should use these re-exports instead of depending on `lattice-embed` directly. +/// This maintains the layer boundary: apps -> platform (retrieval) -> foundation (embed). +/// +/// Core types (`EmbeddingModel`, `EmbeddingService`, `EmbedError`) are always available. +/// Native model implementations (`NativeEmbeddingService`, etc.) require the `embed` feature. +pub mod embed { + // Core types and traits (always available, no feature gate needed) + /// Result alias for embedding operations. 
+ pub use lattice_embed::Result as EmbedResult; + pub use lattice_embed::{EmbedError, EmbeddingModel, EmbeddingService}; + + // Native model implementations (pure Rust lattice-embed via "embed" feature) + #[cfg(feature = "embed")] + pub use lattice_embed::{CachedEmbeddingService, NativeEmbeddingService}; +} diff --git a/crates/khive-retrieval/src/metrics.rs b/crates/khive-retrieval/src/metrics.rs new file mode 100644 index 00000000..0a074c8a --- /dev/null +++ b/crates/khive-retrieval/src/metrics.rs @@ -0,0 +1,353 @@ +//! Observability hooks for retrieval indices. +//! +//! Provides a lightweight, trait-based metrics abstraction that avoids coupling +//! the retrieval crate to any specific observability stack (Prometheus, OpenTelemetry, +//! etc.). Callers inject a [`MetricsSink`] implementation and the indices emit +//! well-known [`MetricEvent`]s during their operations. +//! +//! # Design Rationale +//! +//! - **Trait-based sink** rather than a global registry keeps the library +//! dependency-free and testable. The [`NoopSink`] compiles to zero overhead +//! when no observability is needed. +//! - **`Arc`** allows sharing one sink across multiple indices +//! without lifetime gymnastics. +//! - **Well-known metric names** are `&'static str` constants so dashboards +//! can be built once and never break on typos. +//! +//! # Quick Start +//! +//! ```rust,ignore +//! use std::sync::Arc; +//! use khive_retrieval::metrics::{MetricsSink, NoopSink, RecordingSink}; +//! use khive_retrieval::HnswIndex; +//! +//! // Production: no-op (zero overhead) +//! let mut idx = HnswIndex::new(128); +//! +//! // Testing: capture events +//! let sink = Arc::new(RecordingSink::new()); +//! let mut idx = HnswIndex::new(128).with_metrics(sink.clone()); +//! // ... perform operations ... +//! let events = sink.events(); +//! assert!(!events.is_empty()); +//! 
``` + +use std::fmt; +use std::sync::{Arc, Mutex}; + +// --------------------------------------------------------------------------- +// Core types +// --------------------------------------------------------------------------- + +/// A single metric observation. +#[derive(Debug, Clone)] +pub struct MetricEvent { + /// Well-known metric name (use constants from [`names`]). + pub name: &'static str, + /// Observed value. + pub value: MetricValue, + /// Dimensional labels for grouping / filtering. + pub labels: Vec<(&'static str, String)>, +} + +/// Metric value kinds. +#[derive(Debug, Clone, PartialEq)] +pub enum MetricValue { + /// Monotonically increasing count. + Counter(u64), + /// Point-in-time measurement (can go up or down). + Gauge(f64), + /// Duration or distribution sample (typically seconds or milliseconds). + Histogram(f64), +} + +impl fmt::Display for MetricValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MetricValue::Counter(v) => write!(f, "counter({v})"), + MetricValue::Gauge(v) => write!(f, "gauge({v})"), + MetricValue::Histogram(v) => write!(f, "histogram({v})"), + } + } +} + +// --------------------------------------------------------------------------- +// Sink trait +// --------------------------------------------------------------------------- + +/// Receiver of metric events emitted by retrieval indices. +/// +/// Implementors bridge to their observability stack (Prometheus counters, +/// OTel meters, StatsD, etc.). The trait is `Send + Sync` so a single +/// `Arc` can be shared across threads. +pub trait MetricsSink: Send + Sync + fmt::Debug { + /// Record a single metric event. + fn record(&self, event: MetricEvent); +} + +// --------------------------------------------------------------------------- +// Built-in sinks +// --------------------------------------------------------------------------- + +/// Sink that silently discards every event. 
+/// +/// This is the implicit default when no metrics are configured. +/// All calls compile down to a no-op. +#[derive(Debug, Clone, Copy, Default)] +pub struct NoopSink; + +impl MetricsSink for NoopSink { + #[inline] + fn record(&self, _event: MetricEvent) { + // intentionally empty + } +} + +/// Thread-safe recording sink for tests. +/// +/// Collects every [`MetricEvent`] into an internal `Vec` guarded by a +/// `Mutex`. Use [`events()`](Self::events) to snapshot the recorded events +/// and [`clear()`](Self::clear) to reset. +/// +/// # Example +/// +/// ```rust,ignore +/// use std::sync::Arc; +/// use khive_retrieval::metrics::RecordingSink; +/// +/// let sink = Arc::new(RecordingSink::new()); +/// // ... pass to index ... +/// let events = sink.events(); +/// assert!(events.iter().any(|e| e.name == "hnsw.search.duration_ms")); +/// ``` +#[derive(Debug, Default)] +pub struct RecordingSink { + events: Mutex>, +} + +impl RecordingSink { + /// Create a new, empty recording sink. + pub fn new() -> Self { + Self { + events: Mutex::new(Vec::new()), + } + } + + /// Return a snapshot of all recorded events. + /// + /// Returns an empty vec if the mutex is poisoned (indicates a prior panic). + pub fn events(&self) -> Vec { + self.events + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .clone() + } + + /// Discard all recorded events. + /// + /// Silently skips clearing if the mutex is poisoned. + pub fn clear(&self) { + if let Ok(mut guard) = self.events.lock() { + guard.clear(); + } + } + + /// Return the number of recorded events. + /// + /// Returns 0 if the mutex is poisoned. + pub fn len(&self) -> usize { + self.events + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .len() + } + + /// Check if no events have been recorded. 
+ pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl MetricsSink for RecordingSink { + fn record(&self, event: MetricEvent) { + if let Ok(mut guard) = self.events.lock() { + guard.push(event); + } + } +} + +// --------------------------------------------------------------------------- +// Well-known metric names +// --------------------------------------------------------------------------- + +/// Well-known metric name constants. +/// +/// Using constants prevents typos and allows dashboards to be built once. +/// Names follow the `{subsystem}.{operation}.{measurement}` convention. +pub mod names { + // -- HNSW -- + + /// Duration of a single HNSW search in milliseconds. + pub const HNSW_SEARCH_DURATION_MS: &str = "hnsw.search.duration_ms"; + /// Number of HNSW search operations completed. + pub const HNSW_SEARCH_COUNT: &str = "hnsw.search.count"; + /// Number of results returned by an HNSW search. + pub const HNSW_SEARCH_RESULTS: &str = "hnsw.search.results"; + + /// Duration of a single HNSW insert in milliseconds. + pub const HNSW_INSERT_DURATION_MS: &str = "hnsw.insert.duration_ms"; + /// Number of HNSW insert operations completed. + pub const HNSW_INSERT_COUNT: &str = "hnsw.insert.count"; + + /// Duration of an HNSW rebuild in milliseconds. + pub const HNSW_REBUILD_DURATION_MS: &str = "hnsw.rebuild.duration_ms"; + /// Number of HNSW rebuild operations completed. + pub const HNSW_REBUILD_COUNT: &str = "hnsw.rebuild.count"; + /// Number of nodes removed during a rebuild. + pub const HNSW_REBUILD_NODES_REMOVED: &str = "hnsw.rebuild.nodes_removed"; + + /// Current number of live vectors in the HNSW index. + pub const HNSW_INDEX_SIZE: &str = "hnsw.index.size"; + + // -- BM25 -- + + /// Duration of a single BM25 search in milliseconds. + pub const BM25_SEARCH_DURATION_MS: &str = "bm25.search.duration_ms"; + /// Number of BM25 search operations completed. 
+ pub const BM25_SEARCH_COUNT: &str = "bm25.search.count"; + /// Number of results returned by a BM25 search. + pub const BM25_SEARCH_RESULTS: &str = "bm25.search.results"; + + /// Duration of a single BM25 index_document call in milliseconds. + pub const BM25_INDEX_DURATION_MS: &str = "bm25.index_document.duration_ms"; + /// Number of BM25 index_document operations completed. + pub const BM25_INDEX_COUNT: &str = "bm25.index_document.count"; + + /// Current number of documents in the BM25 index. + pub const BM25_INDEX_SIZE: &str = "bm25.index.size"; +} + +// --------------------------------------------------------------------------- +// Helper: emit to optional sink +// --------------------------------------------------------------------------- + +/// Convenience function to emit a metric event to an optional sink. +/// +/// This avoids repeating `if let Some(sink) = &self.metrics { ... }` in +/// every instrumented method. +#[inline] +pub fn emit(sink: &Option>, event: MetricEvent) { + if let Some(s) = sink { + s.record(event); + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +#[allow(clippy::approx_constant)] +mod tests { + use super::*; + + #[test] + fn noop_sink_does_not_panic() { + let sink = NoopSink; + sink.record(MetricEvent { + name: names::HNSW_SEARCH_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }); + } + + #[test] + fn recording_sink_captures_events() { + let sink = RecordingSink::new(); + assert!(sink.is_empty()); + + sink.record(MetricEvent { + name: names::HNSW_SEARCH_DURATION_MS, + value: MetricValue::Histogram(1.5), + labels: vec![("k", "10".to_string())], + }); + sink.record(MetricEvent { + name: names::HNSW_SEARCH_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }); + + assert_eq!(sink.len(), 2); + assert!(!sink.is_empty()); + + let events = sink.events(); + 
assert_eq!(events.len(), 2); + assert_eq!(events[0].name, names::HNSW_SEARCH_DURATION_MS); + assert_eq!(events[1].name, names::HNSW_SEARCH_COUNT); + } + + #[test] + fn recording_sink_clear() { + let sink = RecordingSink::new(); + sink.record(MetricEvent { + name: names::HNSW_INSERT_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }); + assert_eq!(sink.len(), 1); + + sink.clear(); + assert!(sink.is_empty()); + } + + #[test] + fn metric_value_display() { + assert_eq!(MetricValue::Counter(42).to_string(), "counter(42)"); + assert_eq!(MetricValue::Gauge(3.14).to_string(), "gauge(3.14)"); + assert_eq!(MetricValue::Histogram(1.5).to_string(), "histogram(1.5)"); + } + + #[test] + fn emit_helper_with_none() { + // Should not panic + emit( + &None, + MetricEvent { + name: names::HNSW_SEARCH_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + } + + #[test] + fn emit_helper_with_some() { + let sink = Arc::new(RecordingSink::new()); + let opt: Option> = Some(sink.clone()); + + emit( + &opt, + MetricEvent { + name: names::BM25_SEARCH_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + + assert_eq!(sink.len(), 1); + } + + #[test] + fn recording_sink_is_send_sync() { + fn assert_send_sync() {} + assert_send_sync::(); + } + + #[test] + fn metrics_sink_is_object_safe() { + // Prove we can construct an Arc + let _: Arc = Arc::new(NoopSink); + let _: Arc = Arc::new(RecordingSink::new()); + } +} diff --git a/crates/khive-retrieval/src/persist/bm25.rs b/crates/khive-retrieval/src/persist/bm25.rs new file mode 100644 index 00000000..b6043dc3 --- /dev/null +++ b/crates/khive-retrieval/src/persist/bm25.rs @@ -0,0 +1,112 @@ +//! BM25-specific persistence methods. + +use khive_bm25::Bm25Index; + +use super::shadow::{log_validation_result, should_sample}; +use super::{ + PersistError, RetrievalPersistence, ShadowMetrics, ShadowValidationConfig, + ShadowValidationResult, +}; + +impl RetrievalPersistence { + /// Persist a BM25 index to SQLite. 
+ /// + /// The entire index is serialized (it already has Serde derives). + pub async fn persist_bm25_index(&self, index: &Bm25Index) -> Result<(), PersistError> { + self.persist_snapshot("bm25", index).await + } + + /// Load the latest BM25 index from SQLite. + /// + /// Returns `None` if no snapshot exists for this namespace. + /// Rebuilds the fast-path `doc_lengths_vec` from the deserialized HashMap. + pub async fn load_bm25_index(&self) -> Result, PersistError> { + let mut index = self.load_snapshot::("bm25").await?; + if let Some(ref mut idx) = index { + idx.ensure_doc_lengths_vec(); + } + Ok(index) + } + + /// Persist a BM25 index with optional shadow validation. + /// + /// If shadow validation is enabled, the index is immediately loaded + /// back and compared to verify integrity. Discrepancies are logged but + /// do not block the persist operation. + pub async fn persist_bm25_with_validation( + &self, + index: &Bm25Index, + config: &ShadowValidationConfig, + ) -> Result, PersistError> { + // Always persist first + self.persist_bm25_index(index).await?; + + // Skip validation if disabled or not sampled + if !config.enabled || !should_sample(config.sample_rate) { + return Ok(None); + } + + // Capture expected metrics + let expected = ShadowMetrics { + item_count: index.doc_count(), + tombstone_count: 0, // BM25 doesn't have tombstones + snapshot_size: 0, + }; + + // Perform shadow validation + let result = self.validate_bm25_snapshot(expected).await; + + // Log result (non-blocking) + log_validation_result(&result); + + Ok(Some(result)) + } + + /// Validate a BM25 snapshot by loading it back and comparing metrics. 
+ pub(crate) async fn validate_bm25_snapshot( + &self, + expected: ShadowMetrics, + ) -> ShadowValidationResult { + let mut result = ShadowValidationResult { + passed: false, + index_type: "bm25".to_string(), + expected: expected.clone(), + actual: None, + discrepancies: Vec::new(), + }; + + // Try to load the snapshot back + match self.load_bm25_index().await { + Ok(Some(index)) => { + let actual = ShadowMetrics { + item_count: index.doc_count(), + tombstone_count: 0, + snapshot_size: 0, + }; + + // Compare metrics + if actual.item_count != expected.item_count { + result.discrepancies.push(format!( + "doc_count mismatch: expected {}, got {}", + expected.item_count, actual.item_count + )); + } + + result.actual = Some(actual); + result.passed = result.discrepancies.is_empty(); + } + Ok(None) => { + result + .discrepancies + .push("index not found after persist".to_string()); + } + Err(e) => { + result + .discrepancies + .push(format!("failed to load index: {e}")); + } + } + + result + } +} diff --git a/crates/khive-retrieval/src/persist/hnsw.rs b/crates/khive-retrieval/src/persist/hnsw.rs new file mode 100644 index 00000000..d325f6a7 --- /dev/null +++ b/crates/khive-retrieval/src/persist/hnsw.rs @@ -0,0 +1,127 @@ +//! HNSW-specific persistence methods. + +use khive_hnsw::HnswIndex; +use khive_hnsw::HnswSnapshot; + +use super::shadow::{log_validation_result, should_sample}; +use super::{ + PersistError, RetrievalPersistence, ShadowMetrics, ShadowValidationConfig, + ShadowValidationResult, +}; + +impl RetrievalPersistence { + /// Persist an HNSW index snapshot to SQLite. + /// + /// Creates a snapshot of the index and stores it as a serialized BLOB. + pub async fn persist_hnsw_snapshot(&self, index: &HnswIndex) -> Result<(), PersistError> { + let snapshot = index.snapshot(); + self.persist_snapshot("hnsw", &snapshot).await + } + + /// Load the latest HNSW snapshot from SQLite. + /// + /// Returns `None` if no snapshot exists for this namespace. 
+ pub async fn load_hnsw_snapshot(&self) -> Result, PersistError> { + self.load_snapshot::("hnsw").await + } + + /// Persist an HNSW snapshot with optional shadow validation. + /// + /// If shadow validation is enabled, the snapshot is immediately loaded + /// back and compared to verify integrity. Discrepancies are logged but + /// do not block the persist operation; `Ok(None)` means validation was skipped. + pub async fn persist_hnsw_with_validation( + &self, + index: &HnswIndex, + config: &ShadowValidationConfig, + ) -> Result, PersistError> { + // Always persist first; a persist failure returns early via `?` + self.persist_hnsw_snapshot(index).await?; + + // Skip validation if disabled or not sampled + if !config.enabled || !should_sample(config.sample_rate) { + return Ok(None); + } + + // Capture expected metrics + let expected = ShadowMetrics { + item_count: index.len(), + tombstone_count: index.tombstone_stats().tombstone_count, + snapshot_size: 0, // Not measured: validation compares item and tombstone counts only + }; + + // Perform shadow validation + let result = self.validate_hnsw_snapshot(expected).await; + + // Log result (non-blocking) + log_validation_result(&result); + + Ok(Some(result)) + } + + /// Validate an HNSW snapshot by loading it back and comparing metrics.
+ pub(crate) async fn validate_hnsw_snapshot( + &self, + expected: ShadowMetrics, + ) -> ShadowValidationResult { + let mut result = ShadowValidationResult { + passed: false, + index_type: "hnsw".to_string(), + expected: expected.clone(), + actual: None, + discrepancies: Vec::new(), + }; + + // Try to load the snapshot back + match self.load_hnsw_snapshot().await { + Ok(Some(snapshot)) => { + // Issue #867: Deep verification using HnswSnapshot::verify() + // This checks internal consistency beyond just count comparison: + // - Count consistency: total_nodes == live_nodes + tombstone_count + // - ID count integrity: indexed_ids.len() == total_nodes + // - Tombstone containment: all tombstoned IDs exist in indexed_ids + if let Err(e) = snapshot.verify() { + result + .discrepancies + .push(format!("Snapshot verification failed: {e}")); + } + + let actual = ShadowMetrics { + item_count: snapshot.total_nodes, + tombstone_count: snapshot.tombstone_count, + snapshot_size: 0, // Not easily available without re-serializing + }; + + // Compare metrics + if actual.item_count != expected.item_count { + result.discrepancies.push(format!( + "item_count mismatch: expected {}, got {}", + expected.item_count, actual.item_count + )); + } + + if actual.tombstone_count != expected.tombstone_count { + result.discrepancies.push(format!( + "tombstone_count mismatch: expected {}, got {}", + expected.tombstone_count, actual.tombstone_count + )); + } + + result.actual = Some(actual); + result.passed = result.discrepancies.is_empty(); + } + Ok(None) => { + result + .discrepancies + .push("snapshot not found after persist".to_string()); + } + Err(e) => { + result + .discrepancies + .push(format!("failed to load snapshot: {e}")); + } + } + + result + } +} diff --git a/crates/khive-retrieval/src/persist/mod.rs b/crates/khive-retrieval/src/persist/mod.rs new file mode 100644 index 00000000..40d4e678 --- /dev/null +++ b/crates/khive-retrieval/src/persist/mod.rs @@ -0,0 +1,318 @@ +//! 
Retrieval index persistence using SQLite. +//! +//! This module provides SQLite-based persistence for HNSW and BM25 indexes, +//! following the write-through pattern established in khive-engine: +//! +//! 1. Persist snapshots to SQLite (point of no return) +//! 2. Rebuild in-memory indexes on cold start +//! +//! # Architecture +//! +//! ```text +//! HnswIndex ──snapshot──> HnswSnapshot ──serialize──> SQLite BLOB +//! │ +//! HnswIndex <──restore───────────────────────────────────┘ +//! ``` +//! +//! # Feature Flag +//! +//! This module requires the `persist` feature flag: +//! +//! ```toml +//! khive-retrieval = { path = "../khive-retrieval", features = ["persist"] } +//! ``` +//! +//! # Example +//! +//! ```rust,no_run +//! use khive_retrieval::persist::RetrievalPersistence; +//! use khive_retrieval::hnsw::HnswIndex; +//! use rusqlite::Connection; +//! use std::sync::Arc; +//! use tokio::sync::Mutex; +//! +//! async fn example() -> Result<(), Box> { +//! // Open a file-based SQLite connection +//! let conn = Connection::open("retrieval.db")?; +//! let conn = Arc::new(Mutex::new(conn)); +//! +//! let persist = RetrievalPersistence::new(conn, "default"); +//! +//! // Initialize schema before use +//! persist.init_schema().await?; +//! +//! // Persist an HNSW index +//! let index = HnswIndex::new(384); +//! persist.persist_hnsw_snapshot(&index).await?; +//! +//! // Restore on cold start +//! if let Some(snapshot) = persist.load_hnsw_snapshot().await? { +//! // Rebuild index from snapshot +//! } +//! Ok(()) +//! } +//! ``` + +use std::sync::Arc; + +use rusqlite::Connection; +use serde::{de::DeserializeOwned, Serialize}; +use thiserror::Error; +use tokio::sync::Mutex; + +mod bm25; +mod hnsw; +mod shadow; + +#[cfg(test)] +mod tests; + +pub use shadow::{ShadowMetrics, ShadowValidationConfig, ShadowValidationResult}; + +/// Errors that can occur during retrieval persistence operations. +#[derive(Error, Debug)] +pub enum PersistError { + /// SQLite operation failed. 
+ #[error("SQLite error: {0}")] + Sqlite(#[from] rusqlite::Error), + + /// Serialization failed. + #[error("Serialization error: {0}")] + Serialize(String), + + /// Deserialization failed. + #[error("Deserialization error: {0}")] + Deserialize(String), + + /// A `spawn_blocking` task failed to join (the variant the persist methods in this module construct). + #[error("Task join error: {0}")] + TaskJoin(String), + + /// Snapshot verification failed. + #[error("Snapshot verification failed: {0}")] + SnapshotVerification(String), + + /// Validation error (e.g. empty namespace, out-of-range parameter). + #[error("Validation error: {0}")] + Validation(String), + + /// Task join error from spawn_blocking (distinct from `TaskJoin`; not constructed by the persist methods in this module). + #[error("Blocking task failed: {0}")] + BlockingJoin(String), + + /// JoinError from tokio spawn_blocking (auto-converted via `#[from]`). + #[error("Tokio join error: {0}")] + Join(#[from] tokio::task::JoinError), + + /// Internal error (generic, for ported engine code). + #[error("Internal error: {0}")] + Internal(String), + + /// Embedding error (for ported engine code). + #[error("Embedding error: {0}")] + Embedding(String), + + /// Retrieval error (for ported engine code). + #[error("Retrieval error: {0}")] + Retrieval(String), +} + +/// Retrieval index persistence using SQLite. +/// +/// Provides methods to persist and restore HNSW and BM25 index snapshots +/// to/from SQLite. Uses the write-through pattern from khive-engine. +pub struct RetrievalPersistence { + /// SQLite connection (thread-safe via async mutex). + pub(crate) conn: Arc>, + /// Namespace for multi-tenancy. + /// Uses Arc for O(1) cloning in async spawn contexts. + pub(crate) namespace: Arc, +} + +impl RetrievalPersistence { + /// Create a new persistence layer. + /// + /// # Arguments + /// + /// * `conn` - Arc-wrapped SQLite connection + /// * `namespace` - Namespace for multi-tenancy isolation + pub fn new(conn: Arc>, namespace: impl Into) -> Self { + Self { + conn, + namespace: Arc::from(namespace.into()), + } + } + + /// Initialize the persistence schema.
+ /// + /// Creates the `retrieval_snapshots` table and its namespace index if they don't exist. + pub async fn init_schema(&self) -> Result<(), PersistError> { + let conn = self.conn.clone(); + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS retrieval_snapshots ( + namespace TEXT NOT NULL, + index_type TEXT NOT NULL, + snapshot BLOB NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (namespace, index_type) + ); + + CREATE INDEX IF NOT EXISTS idx_retrieval_snapshots_namespace + ON retrieval_snapshots(namespace); + "#, + )?; + Ok(()) + }) + .await + .map_err(|e| PersistError::TaskJoin(e.to_string()))? + } + + /// Generic snapshot persistence: JSON-encodes `snapshot` and replaces the single row for `(namespace, index_type)`, stamping `created_at` in Unix microseconds. + pub(crate) async fn persist_snapshot( + &self, + index_type: &str, + snapshot: &T, + ) -> Result<(), PersistError> { + let data = + serde_json::to_vec(snapshot).map_err(|e| PersistError::Serialize(e.to_string()))?; + + let conn = self.conn.clone(); + let namespace = self.namespace.clone(); + let index_type = index_type.to_string(); + + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + conn.execute( + r#" + INSERT OR REPLACE INTO retrieval_snapshots + (namespace, index_type, snapshot, created_at) + VALUES + (?1, ?2, ?3, ?4) + "#, + rusqlite::params![ + &*namespace, + index_type, + data, + chrono::Utc::now().timestamp_micros() + ], + )?; + Ok(()) + }) + .await + .map_err(|e| PersistError::TaskJoin(e.to_string()))? + } + + /// Generic snapshot loading: yields `Ok(None)` when no row exists for `(namespace, index_type)`.
+ pub(crate) async fn load_snapshot( + &self, + index_type: &str, + ) -> Result, PersistError> { + let conn = self.conn.clone(); + let namespace = self.namespace.clone(); + let index_type = index_type.to_string(); + + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + let mut stmt = conn.prepare( + r#" + SELECT snapshot FROM retrieval_snapshots + WHERE namespace = ?1 AND index_type = ?2 + "#, + )?; + + let result: Option> = match stmt + .query_row(rusqlite::params![&*namespace, index_type], |row| row.get(0)) + { + Ok(data) => Some(data), + Err(rusqlite::Error::QueryReturnedNoRows) => None, + Err(e) => return Err(PersistError::Sqlite(e)), + }; + + match result { + Some(data) => { + let snapshot: T = serde_json::from_slice(&data) + .map_err(|e| PersistError::Deserialize(e.to_string()))?; + Ok(Some(snapshot)) + } + None => Ok(None), + } + }) + .await + .map_err(|e| PersistError::TaskJoin(e.to_string()))? + } + + /// Delete all snapshots for this namespace. + pub async fn clear(&self) -> Result<(), PersistError> { + let conn = self.conn.clone(); + let namespace = self.namespace.clone(); + + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + conn.execute( + "DELETE FROM retrieval_snapshots WHERE namespace = ?1", + rusqlite::params![&*namespace], + )?; + Ok(()) + }) + .await + .map_err(|e| PersistError::TaskJoin(e.to_string()))? + } + + /// Get persistence statistics. + pub async fn stats(&self) -> Result { + let conn = self.conn.clone(); + let namespace = self.namespace.clone(); + + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + let mut stmt = conn.prepare( + r#" + SELECT index_type, length(snapshot), created_at + FROM retrieval_snapshots + WHERE namespace = ?1 + "#, + )?; + + let mut stats = PersistenceStats::default(); + let mut rows = stmt.query(rusqlite::params![&*namespace])?; + + while let Some(row) = rows.next()? 
{ + let index_type: String = row.get(0)?; + let size: i64 = row.get(1)?; + let created_at: i64 = row.get(2)?; + + match index_type.as_str() { + "hnsw" => { + stats.hnsw_snapshot_size = size as usize; + stats.hnsw_snapshot_at = Some(created_at); + } + "bm25" => { + stats.bm25_snapshot_size = size as usize; + stats.bm25_snapshot_at = Some(created_at); + } + _ => {} + } + } + + Ok(stats) + }) + .await + .map_err(|e| PersistError::TaskJoin(e.to_string()))? + } +} + +/// Statistics about persisted snapshots. +#[derive(Debug, Default, Clone)] +pub struct PersistenceStats { + /// Size of HNSW snapshot in bytes (0 if none is stored). + pub hnsw_snapshot_size: usize, + /// Raw `created_at` of the HNSW snapshot row (Unix microseconds when written by `persist_snapshot`). + pub hnsw_snapshot_at: Option, + /// Size of BM25 snapshot in bytes (0 if none is stored). + pub bm25_snapshot_size: usize, + /// Raw `created_at` of the BM25 snapshot row (Unix microseconds when written by `persist_snapshot`). + pub bm25_snapshot_at: Option, +} diff --git a/crates/khive-retrieval/src/persist/shadow.rs b/crates/khive-retrieval/src/persist/shadow.rs new file mode 100644 index 00000000..a7814ec7 --- /dev/null +++ b/crates/khive-retrieval/src/persist/shadow.rs @@ -0,0 +1,105 @@ +//! Shadow validation types and helpers for persistence integrity checking. +//! +//! Shadow validation (Issue #628) verifies persisted snapshots can be correctly +//! restored without blocking production operations. Discrepancies are logged only. + +use rand::Rng; + +// --------------------------------------------------------------------------- +// Shadow Validation (Issue #628) +// --------------------------------------------------------------------------- + +/// Configuration for shadow validation. +/// +/// Shadow validation verifies persisted snapshots can be correctly restored +/// without blocking production operations. Discrepancies are logged only. +#[derive(Debug, Clone)] +pub struct ShadowValidationConfig { + /// Whether shadow validation is enabled. + pub enabled: bool, + /// Sample rate for validation (0.0 to 1.0).
+ /// Set to 1.0 to validate every persist operation. + pub sample_rate: f64, +} + +impl Default for ShadowValidationConfig { + fn default() -> Self { + Self { + enabled: false, + sample_rate: 0.1, // 10% sample rate by default + } + } +} + +impl ShadowValidationConfig { + /// Enable shadow validation with full coverage. + pub fn enabled() -> Self { + Self { + enabled: true, + sample_rate: 1.0, + } + } + + /// Enable shadow validation with a specific sample rate, clamped to `[0.0, 1.0]`. + pub fn with_sample_rate(rate: f64) -> Self { + Self { + enabled: true, + sample_rate: rate.clamp(0.0, 1.0), + } + } +} + +/// Result of shadow validation. +#[derive(Debug, Clone)] +pub struct ShadowValidationResult { + /// Whether validation passed. + pub passed: bool, + /// Index type that was validated. + pub index_type: String, + /// Expected metrics from the original index. + pub expected: ShadowMetrics, + /// Actual metrics from the restored snapshot; `None` if it could not be loaded back. + pub actual: Option, + /// Discrepancies found (empty if validation passed). + pub discrepancies: Vec, +} + +/// Metrics captured for shadow validation comparison. +#[derive(Debug, Clone, Default)] +pub struct ShadowMetrics { + /// Total number of items in the index. + pub item_count: usize, + /// Number of tombstoned/deleted items (HNSW only). + pub tombstone_count: usize, + /// Snapshot size in bytes (the BM25/HNSW validation paths currently always set this to 0). + pub snapshot_size: usize, +} + +/// Determine whether to sample this operation: rates >= 1.0 always sample, rates <= 0.0 never do, otherwise a random draw decides. +pub(crate) fn should_sample(rate: f64) -> bool { + if rate >= 1.0 { + return true; + } + if rate <= 0.0 { + return false; + } + rand::thread_rng().gen::() < rate +} + +/// Log the validation result (logging-only, non-blocking). +/// +/// This function logs discrepancies but never blocks or returns errors. +/// Failed validations are reported through `tracing::warn!`; passing +/// results emit nothing.
+pub(crate) fn log_validation_result(result: &ShadowValidationResult) { + // Only log failures - successful validations are silent by default + // to avoid log noise. The result is still returned to callers who + // may want to record metrics or take other actions. + if !result.passed { + tracing::warn!( + index_type = %result.index_type, + discrepancies = ?result.discrepancies, + "Shadow validation failed" + ); + } +} diff --git a/crates/khive-retrieval/src/persist/tests.rs b/crates/khive-retrieval/src/persist/tests.rs new file mode 100644 index 00000000..2efdf72d --- /dev/null +++ b/crates/khive-retrieval/src/persist/tests.rs @@ -0,0 +1,1214 @@ +use super::*; +use khive_bm25::Bm25Index; +use khive_hnsw::HnswIndex; +use rusqlite::Connection; +use std::sync::Arc; +use tokio::sync::Mutex; + +async fn setup_test_persistence() -> RetrievalPersistence { + let conn = Connection::open_in_memory().expect("open in-memory db"); + conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL;") + .expect("set pragmas"); + let persist = RetrievalPersistence::new(Arc::new(Mutex::new(conn)), "test"); + persist.init_schema().await.expect("init schema"); + persist +} + +#[tokio::test] +async fn test_persist_and_load_bm25() { + let persist = setup_test_persistence().await; + + // Create and persist a BM25 index + let mut index = Bm25Index::default(); + index + .index_document("doc1", "hello world") + .expect("index doc"); + index + .index_document("doc2", "goodbye world") + .expect("index doc"); + + persist.persist_bm25_index(&index).await.expect("persist"); + + // Load and verify + let loaded = persist.load_bm25_index().await.expect("load"); + assert!(loaded.is_some()); + let loaded = loaded.unwrap(); + assert_eq!(loaded.doc_count(), 2); +} + +#[tokio::test] +async fn test_persist_and_load_hnsw() { + let persist = setup_test_persistence().await; + + // Create and persist an HNSW index with some vectors + let mut index = HnswIndex::new(4); // 4 dimensions + + // Insert a 
few vectors + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + let id3 = NodeId::new([3; 16]); + + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0, 0.0]).expect("insert"); + index.insert(id3, vec![0.0, 0.0, 1.0, 0.0]).expect("insert"); + + assert_eq!(index.len(), 3); + + // Persist the snapshot + persist + .persist_hnsw_snapshot(&index) + .await + .expect("persist"); + + // Load and verify the snapshot + let loaded = persist.load_hnsw_snapshot().await.expect("load"); + assert!(loaded.is_some()); + let snapshot = loaded.unwrap(); + + // Verify snapshot contains correct metadata + assert_eq!(snapshot.total_nodes, 3); + assert_eq!(snapshot.live_nodes, 3); + assert_eq!(snapshot.tombstone_count, 0); + assert_eq!(snapshot.indexed_ids.len(), 3); + assert!(snapshot.indexed_ids.contains(&id1)); + assert!(snapshot.indexed_ids.contains(&id2)); + assert!(snapshot.indexed_ids.contains(&id3)); +} + +#[tokio::test] +async fn test_stats() { + let persist = setup_test_persistence().await; + + // Initially empty + let stats = persist.stats().await.expect("stats"); + assert_eq!(stats.hnsw_snapshot_size, 0); + assert_eq!(stats.bm25_snapshot_size, 0); + + // Persist BM25 + let index = Bm25Index::default(); + persist.persist_bm25_index(&index).await.expect("persist"); + + // Check stats + let stats = persist.stats().await.expect("stats"); + assert!(stats.bm25_snapshot_size > 0); + assert!(stats.bm25_snapshot_at.is_some()); +} + +#[tokio::test] +async fn test_clear() { + let persist = setup_test_persistence().await; + + // Persist something + let index = Bm25Index::default(); + persist.persist_bm25_index(&index).await.expect("persist"); + + // Clear + persist.clear().await.expect("clear"); + + // Should be gone + let loaded = persist.load_bm25_index().await.expect("load"); + assert!(loaded.is_none()); +} + +// -- Shadow validation tests -- + +#[tokio::test] +async fn test_shadow_validation_config_default() { + let 
config = ShadowValidationConfig::default(); + assert!(!config.enabled); + assert!((config.sample_rate - 0.1).abs() < f64::EPSILON); +} + +#[tokio::test] +async fn test_shadow_validation_config_enabled() { + let config = ShadowValidationConfig::enabled(); + assert!(config.enabled); + assert!((config.sample_rate - 1.0).abs() < f64::EPSILON); +} + +#[tokio::test] +async fn test_shadow_validation_config_sample_rate() { + let config = ShadowValidationConfig::with_sample_rate(0.5); + assert!(config.enabled); + assert!((config.sample_rate - 0.5).abs() < f64::EPSILON); + + // Test clamping + let config = ShadowValidationConfig::with_sample_rate(1.5); + assert!((config.sample_rate - 1.0).abs() < f64::EPSILON); + + let config = ShadowValidationConfig::with_sample_rate(-0.5); + assert!((config.sample_rate - 0.0).abs() < f64::EPSILON); +} + +#[tokio::test] +async fn test_bm25_shadow_validation_passes() { + let persist = setup_test_persistence().await; + let config = ShadowValidationConfig::enabled(); + + // Create and persist a BM25 index with validation + let mut index = Bm25Index::default(); + index + .index_document("doc1", "hello world") + .expect("index doc"); + index + .index_document("doc2", "goodbye world") + .expect("index doc"); + + let result = persist + .persist_bm25_with_validation(&index, &config) + .await + .expect("persist with validation"); + + assert!(result.is_some()); + let validation = result.unwrap(); + assert!( + validation.passed, + "validation should pass: {:?}", + validation.discrepancies + ); + assert_eq!(validation.index_type, "bm25"); + assert_eq!(validation.expected.item_count, 2); + assert!(validation.discrepancies.is_empty()); +} + +#[tokio::test] +async fn test_shadow_validation_skipped_when_disabled() { + let persist = setup_test_persistence().await; + let config = ShadowValidationConfig::default(); // disabled + + let index = Bm25Index::default(); + let result = persist + .persist_bm25_with_validation(&index, &config) + .await + 
.expect("persist"); + + // Validation should be skipped + assert!(result.is_none()); + + // But the persist should still work + let loaded = persist.load_bm25_index().await.expect("load"); + assert!(loaded.is_some()); +} + +#[tokio::test] +async fn test_should_sample() { + use super::shadow::should_sample; + + // Always sample at 1.0 + assert!(should_sample(1.0)); + assert!(should_sample(1.5)); // clamped to 1.0 + + // Never sample at 0.0 + assert!(!should_sample(0.0)); + assert!(!should_sample(-0.5)); // clamped to 0.0 +} + +// -- Issue #865: HNSW shadow validation test -- + +#[tokio::test] +async fn test_hnsw_shadow_validation_passes() { + let persist = setup_test_persistence().await; + let config = ShadowValidationConfig::enabled(); + + // Create an HNSW index with vectors + let mut index = HnswIndex::new(4); + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0, 0.0]).expect("insert"); + + let result = persist + .persist_hnsw_with_validation(&index, &config) + .await + .expect("persist with validation"); + + assert!(result.is_some()); + let validation = result.unwrap(); + assert!( + validation.passed, + "validation should pass: {:?}", + validation.discrepancies + ); + assert_eq!(validation.index_type, "hnsw"); + assert_eq!(validation.expected.item_count, 2); + assert!(validation.discrepancies.is_empty()); +} + +#[tokio::test] +async fn test_hnsw_shadow_validation_with_tombstones() { + let persist = setup_test_persistence().await; + let config = ShadowValidationConfig::enabled(); + + // Create an HNSW index with vectors and tombstones + let mut index = HnswIndex::new(4); + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + let id3 = NodeId::new([3; 16]); + + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0, 0.0]).expect("insert"); + index.insert(id3, vec![0.0, 0.0, 1.0, 
0.0]).expect("insert"); + index.delete(id2); // Tombstone id2 + + let result = persist + .persist_hnsw_with_validation(&index, &config) + .await + .expect("persist with validation"); + + assert!(result.is_some()); + let validation = result.unwrap(); + assert!( + validation.passed, + "validation should pass with tombstones: {:?}", + validation.discrepancies + ); + assert_eq!(validation.expected.item_count, 3); // total_nodes including tombstones + assert_eq!(validation.expected.tombstone_count, 1); +} + +// -- Issue #866: Namespace isolation test -- + +#[tokio::test] +async fn test_namespace_isolation() { + let conn = Connection::open_in_memory().expect("open in-memory db"); + conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL;") + .expect("set pragmas"); + let conn = Arc::new(Mutex::new(conn)); + + // Create two persistence layers with different namespaces + let persist_ns1 = RetrievalPersistence::new(conn.clone(), "namespace1"); + let persist_ns2 = RetrievalPersistence::new(conn.clone(), "namespace2"); + + // Initialize schema (only needed once since they share the connection) + persist_ns1.init_schema().await.expect("init schema"); + + // Persist different data to each namespace + let mut index1 = Bm25Index::default(); + index1 + .index_document("doc1", "namespace one content") + .expect("index"); + + let mut index2 = Bm25Index::default(); + index2 + .index_document("doc2", "namespace two content") + .expect("index"); + index2 + .index_document("doc3", "more namespace two") + .expect("index"); + + persist_ns1 + .persist_bm25_index(&index1) + .await + .expect("persist ns1"); + persist_ns2 + .persist_bm25_index(&index2) + .await + .expect("persist ns2"); + + // Verify each namespace loads its own data + let loaded1 = persist_ns1.load_bm25_index().await.expect("load ns1"); + let loaded2 = persist_ns2.load_bm25_index().await.expect("load ns2"); + + assert!(loaded1.is_some()); + assert!(loaded2.is_some()); + assert_eq!(loaded1.unwrap().doc_count(), 
1); + assert_eq!(loaded2.unwrap().doc_count(), 2); + + // Clear one namespace and verify the other is unaffected + persist_ns1.clear().await.expect("clear ns1"); + + let loaded1_after = persist_ns1 + .load_bm25_index() + .await + .expect("load ns1 after clear"); + let loaded2_after = persist_ns2 + .load_bm25_index() + .await + .expect("load ns2 after clear"); + + assert!(loaded1_after.is_none(), "ns1 should be cleared"); + assert!(loaded2_after.is_some(), "ns2 should still exist"); + assert_eq!(loaded2_after.unwrap().doc_count(), 2); +} + +// -- Issue #868: Corrupted data handling tests -- + +#[tokio::test] +async fn test_corrupted_bm25_data_returns_error() { + let persist = setup_test_persistence().await; + + // Manually insert corrupted JSON + { + let conn = persist.conn.clone(); + let namespace = "test".to_string(); + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + conn.execute( + r#" + INSERT OR REPLACE INTO retrieval_snapshots + (namespace, index_type, snapshot, created_at) + VALUES + (?1, 'bm25', ?2, strftime('%s', 'now')) + "#, + rusqlite::params![namespace, b"not valid json {{{{"], + ) + .expect("insert corrupted"); + }) + .await + .expect("spawn"); + } + + // Attempt to load should return an error + let result = persist.load_bm25_index().await; + assert!(result.is_err(), "loading corrupted data should fail"); + let err = result.unwrap_err(); + assert!(matches!(err, PersistError::Deserialize(_))); +} + +#[tokio::test] +async fn test_corrupted_hnsw_data_returns_error() { + let persist = setup_test_persistence().await; + + // Manually insert corrupted JSON + { + let conn = persist.conn.clone(); + let namespace = "test".to_string(); + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + conn.execute( + r#" + INSERT OR REPLACE INTO retrieval_snapshots + (namespace, index_type, snapshot, created_at) + VALUES + (?1, 'hnsw', ?2, strftime('%s', 'now')) + "#, + rusqlite::params![namespace, b"truncated json 
{\"total_nodes\":"], + ) + .expect("insert corrupted"); + }) + .await + .expect("spawn"); + } + + // Attempt to load should return an error + let result = persist.load_hnsw_snapshot().await; + assert!(result.is_err(), "loading corrupted HNSW data should fail"); + let err = result.unwrap_err(); + assert!(matches!(err, PersistError::Deserialize(_))); +} + +// -- Issue #868: Additional corrupted data handling tests -- + +#[tokio::test] +async fn test_valid_json_wrong_schema_bm25() { + let persist = setup_test_persistence().await; + + // Insert valid JSON but wrong schema (missing required fields) + { + let conn = persist.conn.clone(); + let namespace = "test".to_string(); + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + // Valid JSON but wrong structure for Bm25Index + let wrong_schema = br#"{"some_field": "value", "other": 123}"#; + conn.execute( + r#" + INSERT OR REPLACE INTO retrieval_snapshots + (namespace, index_type, snapshot, created_at) + VALUES + (?1, 'bm25', ?2, strftime('%s', 'now')) + "#, + rusqlite::params![namespace, wrong_schema.as_slice()], + ) + .expect("insert wrong schema"); + }) + .await + .expect("spawn"); + } + + // Attempt to load should return an error (missing required fields) + let result = persist.load_bm25_index().await; + assert!(result.is_err(), "loading wrong schema should fail"); + let err = result.unwrap_err(); + assert!(matches!(err, PersistError::Deserialize(_))); +} + +#[tokio::test] +async fn test_valid_json_wrong_schema_hnsw() { + let persist = setup_test_persistence().await; + + // Insert valid JSON but wrong schema (missing required fields for HnswSnapshot) + { + let conn = persist.conn.clone(); + let namespace = "test".to_string(); + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + // Valid JSON but wrong structure for HnswSnapshot + let wrong_schema = br#"{"total_nodes": 5, "wrong_field": true}"#; + conn.execute( + r#" + INSERT OR REPLACE INTO retrieval_snapshots + 
(namespace, index_type, snapshot, created_at) + VALUES + (?1, 'hnsw', ?2, strftime('%s', 'now')) + "#, + rusqlite::params![namespace, wrong_schema.as_slice()], + ) + .expect("insert wrong schema"); + }) + .await + .expect("spawn"); + } + + // Attempt to load should return an error (missing required fields) + let result = persist.load_hnsw_snapshot().await; + assert!(result.is_err(), "loading wrong schema should fail"); + let err = result.unwrap_err(); + assert!(matches!(err, PersistError::Deserialize(_))); +} + +#[tokio::test] +async fn test_empty_blob_returns_error() { + let persist = setup_test_persistence().await; + + // Insert empty blob + { + let conn = persist.conn.clone(); + let namespace = "test".to_string(); + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + conn.execute( + r#" + INSERT OR REPLACE INTO retrieval_snapshots + (namespace, index_type, snapshot, created_at) + VALUES + (?1, 'bm25', ?2, strftime('%s', 'now')) + "#, + rusqlite::params![namespace, &[] as &[u8]], + ) + .expect("insert empty blob"); + }) + .await + .expect("spawn"); + } + + // Attempt to load should return an error + let result = persist.load_bm25_index().await; + assert!(result.is_err(), "loading empty blob should fail"); + let err = result.unwrap_err(); + assert!(matches!(err, PersistError::Deserialize(_))); +} + +// -- Issue #869: Empty index persistence edge case tests -- + +#[tokio::test] +async fn test_empty_bm25_index_persistence() { + let persist = setup_test_persistence().await; + + // Persist an empty BM25 index + let index = Bm25Index::default(); + assert_eq!(index.doc_count(), 0); + + persist + .persist_bm25_index(&index) + .await + .expect("persist empty"); + + // Load and verify + let loaded = persist.load_bm25_index().await.expect("load"); + assert!(loaded.is_some()); + let loaded = loaded.unwrap(); + assert_eq!(loaded.doc_count(), 0, "empty index should remain empty"); +} + +#[tokio::test] +async fn test_empty_hnsw_index_persistence() { + let 
persist = setup_test_persistence().await; + + // Persist an empty HNSW index + let index = HnswIndex::new(4); + assert_eq!(index.len(), 0); + + persist + .persist_hnsw_snapshot(&index) + .await + .expect("persist empty"); + + // Load and verify + let loaded = persist.load_hnsw_snapshot().await.expect("load"); + assert!(loaded.is_some()); + let snapshot = loaded.unwrap(); + assert_eq!( + snapshot.total_nodes, 0, + "empty index snapshot should have 0 nodes" + ); + assert_eq!(snapshot.live_nodes, 0); + assert!(snapshot.indexed_ids.is_empty()); +} + +#[tokio::test] +async fn test_empty_hnsw_shadow_validation() { + let persist = setup_test_persistence().await; + let config = ShadowValidationConfig::enabled(); + + // Empty HNSW index + let index = HnswIndex::new(4); + + let result = persist + .persist_hnsw_with_validation(&index, &config) + .await + .expect("persist empty with validation"); + + assert!(result.is_some()); + let validation = result.unwrap(); + assert!(validation.passed, "empty index validation should pass"); + assert_eq!(validation.expected.item_count, 0); +} + +// -- Issue #867: Test that verify() is called during shadow validation -- + +#[tokio::test] +async fn test_hnsw_shadow_validation_calls_verify() { + let persist = setup_test_persistence().await; + let config = ShadowValidationConfig::enabled(); + + // Create an HNSW index with vectors + let mut index = HnswIndex::new(4); + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + let id3 = NodeId::new([3; 16]); + + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0, 0.0]).expect("insert"); + index.insert(id3, vec![0.0, 0.0, 1.0, 0.0]).expect("insert"); + index.delete(id2); // Create tombstone + + let result = persist + .persist_hnsw_with_validation(&index, &config) + .await + .expect("persist with validation"); + + // Validation should pass because verify() succeeds on valid snapshot + assert!(result.is_some()); + let validation = 
result.unwrap(); + assert!( + validation.passed, + "valid snapshot should pass verify(): {:?}", + validation.discrepancies + ); + assert_eq!(validation.expected.item_count, 3); + assert_eq!(validation.expected.tombstone_count, 1); +} + +// ========================================================================== +// Issue #1114: HNSW index corruption recovery tests +// ========================================================================== +// +// These tests verify that the persistence layer correctly detects and handles +// various forms of HNSW index corruption, enabling the engine to recover +// by rebuilding from source data. + +/// Helper: insert raw bytes into the HNSW snapshot slot for a persistence instance. +async fn inject_raw_hnsw_snapshot(persist: &RetrievalPersistence, data: &[u8]) { + let conn = persist.conn.clone(); + let namespace = persist.namespace.clone(); + let data = data.to_vec(); + tokio::task::spawn_blocking(move || { + let conn = conn.blocking_lock(); + conn.execute( + r#" + INSERT OR REPLACE INTO retrieval_snapshots + (namespace, index_type, snapshot, created_at) + VALUES + (?1, 'hnsw', ?2, strftime('%s', 'now')) + "#, + rusqlite::params![&*namespace, data], + ) + .expect("inject raw snapshot"); + }) + .await + .expect("spawn"); +} + +/// Helper: build a valid HNSW index with some vectors and persist it. +async fn build_and_persist_hnsw(persist: &RetrievalPersistence) -> HnswIndex { + let mut index = HnswIndex::new(4); + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + let id3 = NodeId::new([3; 16]); + + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0, 0.0]).expect("insert"); + index.insert(id3, vec![0.0, 0.0, 1.0, 0.0]).expect("insert"); + + persist + .persist_hnsw_snapshot(&index) + .await + .expect("persist"); + index +} + +// -- Test: Truncated HNSW snapshot file -- +// +// Scenario: The snapshot BLOB in SQLite is truncated (e.g., write was interrupted). 
+// Expected: load_hnsw_snapshot returns a Deserialize error, not a panic or corrupt data. + +#[tokio::test] +async fn test_truncated_hnsw_snapshot_detected() { + let persist = setup_test_persistence().await; + + // First persist a valid snapshot so we have realistic JSON to truncate + build_and_persist_hnsw(&persist).await; + + // Load the valid snapshot and get its serialized form + let valid_snapshot = persist + .load_hnsw_snapshot() + .await + .expect("load valid") + .expect("snapshot exists"); + + let valid_json = serde_json::to_vec(&valid_snapshot).expect("serialize"); + assert!(valid_json.len() > 20, "valid JSON should be non-trivial"); + + // Truncate at various points to simulate interrupted writes + for truncate_at in [1, 10, valid_json.len() / 4, valid_json.len() / 2] { + let truncated = &valid_json[..truncate_at]; + inject_raw_hnsw_snapshot(&persist, truncated).await; + + let result = persist.load_hnsw_snapshot().await; + assert!( + result.is_err(), + "truncated snapshot (at byte {truncate_at}) should fail to load" + ); + let err = result.unwrap_err(); + assert!( + matches!(err, PersistError::Deserialize(_)), + "should be a Deserialize error, got: {err:?}" + ); + } +} + +// -- Test: Corrupted bytes in HNSW snapshot -- +// +// Scenario: Random byte corruption in the snapshot BLOB (e.g., disk bit flip). +// Expected: Deserialization fails or snapshot verify() catches inconsistency. 
+ +#[tokio::test] +async fn test_corrupted_bytes_in_hnsw_snapshot_detected() { + let persist = setup_test_persistence().await; + + // Build and persist a valid snapshot + build_and_persist_hnsw(&persist).await; + + let valid_snapshot = persist + .load_hnsw_snapshot() + .await + .expect("load valid") + .expect("snapshot exists"); + + let mut corrupted_json = serde_json::to_vec(&valid_snapshot).expect("serialize"); + + // Corrupt bytes in the middle of the JSON (likely to break structure) + let mid = corrupted_json.len() / 2; + for i in mid..mid.saturating_add(10).min(corrupted_json.len()) { + corrupted_json[i] = 0xFF; + } + + inject_raw_hnsw_snapshot(&persist, &corrupted_json).await; + + let result = persist.load_hnsw_snapshot().await; + + // The corrupted JSON should either fail to deserialize or produce + // a snapshot that fails verification. Either outcome is acceptable + // as long as we don't silently return corrupt data. + match result { + Err(PersistError::Deserialize(_)) => { + // Good: deserialization caught it + } + Ok(Some(snapshot)) => { + // If it deserialized, verify() should catch the inconsistency + // (corrupted counts, missing IDs, etc.) + let verify_result = snapshot.verify(); + // Even if verify passes (unlikely with random corruption), we accept it + // because the snapshot's data fields would be garbled. The key invariant + // is that we don't panic or produce silently wrong results. + let _ = verify_result; + } + Ok(None) => { + panic!("snapshot was injected, should not return None"); + } + Err(other) => { + panic!("unexpected error variant: {other:?}"); + } + } +} + +// -- Test: Missing HNSW snapshot (no row in SQLite) -- +// +// Scenario: The snapshot row doesn't exist (e.g., first boot, or snapshot was +// deleted/cleared). Engine should detect this and rebuild from source. 
+ +#[tokio::test] +async fn test_missing_hnsw_snapshot_returns_none() { + let persist = setup_test_persistence().await; + + // No snapshot has been persisted yet + let result = persist + .load_hnsw_snapshot() + .await + .expect("load should not error"); + assert!( + result.is_none(), + "missing snapshot should return None, not error" + ); +} + +#[tokio::test] +async fn test_missing_hnsw_snapshot_after_clear_returns_none() { + let persist = setup_test_persistence().await; + + // Persist a valid snapshot + build_and_persist_hnsw(&persist).await; + + // Verify it exists + let loaded = persist.load_hnsw_snapshot().await.expect("load"); + assert!(loaded.is_some(), "snapshot should exist before clear"); + + // Clear all snapshots (simulating data loss / recovery scenario) + persist.clear().await.expect("clear"); + + // Now loading should return None + let after_clear = persist + .load_hnsw_snapshot() + .await + .expect("load after clear"); + assert!( + after_clear.is_none(), + "snapshot should be None after clear, enabling rebuild from source" + ); +} + +// -- Test: HNSW snapshot with internally inconsistent state -- +// +// Scenario: Snapshot deserializes successfully but has corrupted internal state +// (e.g., total_nodes doesn't match indexed_ids count). This simulates +// a partial write or in-memory corruption before serialization. 
+ +#[tokio::test] +async fn test_inconsistent_hnsw_snapshot_detected_by_verify() { + use khive_hnsw::{HnswCheckpointConfig, HnswSnapshot}; + + let persist = setup_test_persistence().await; + + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + + // Create a snapshot where total_nodes doesn't match indexed_ids.len() + let bad_snapshot = HnswSnapshot { + vector_count: 0, + total_nodes: 5, // WRONG: says 5 but only 2 IDs + live_nodes: 5, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: HnswCheckpointConfig { + m: 16, + ef_construction: 200, + metric: "cosine".to_string(), + }, + indexed_ids: vec![id1, id2], // Only 2 IDs + tombstoned_ids: vec![], + layers: vec![vec![(id1, vec![id2]), (id2, vec![id1])]], + + vectors: vec![], + }; + + // Persist it (persistence layer doesn't validate, just serializes) + let data = serde_json::to_vec(&bad_snapshot).expect("serialize"); + inject_raw_hnsw_snapshot(&persist, &data).await; + + // Load succeeds (it's valid JSON with correct schema) + let loaded = persist + .load_hnsw_snapshot() + .await + .expect("load should succeed for valid JSON"); + assert!(loaded.is_some(), "snapshot should load"); + + let snapshot = loaded.unwrap(); + + // But verify() detects the inconsistency + let verify_result = snapshot.verify(); + assert!( + verify_result.is_err(), + "verify should catch total_nodes != indexed_ids.len()" + ); +} + +#[tokio::test] +async fn test_tombstone_inconsistency_detected_by_verify() { + use khive_hnsw::{HnswCheckpointConfig, HnswSnapshot}; + + let persist = setup_test_persistence().await; + + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + let id3 = NodeId::new([3; 16]); + let id_phantom = NodeId::new([99; 16]); + + // Snapshot claims id_phantom is tombstoned but it's not in indexed_ids + let bad_snapshot = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 2, + tombstone_count: 1, + max_layer: 0, + entry_point: Some(id1), + config: 
HnswCheckpointConfig { + m: 16, + ef_construction: 200, + metric: "cosine".to_string(), + }, + indexed_ids: vec![id1, id2, id3], + tombstoned_ids: vec![id_phantom], // NOT in indexed_ids + layers: vec![], + + vectors: vec![], + }; + + let data = serde_json::to_vec(&bad_snapshot).expect("serialize"); + inject_raw_hnsw_snapshot(&persist, &data).await; + + let loaded = persist + .load_hnsw_snapshot() + .await + .expect("load") + .expect("snapshot exists"); + + let verify_result = loaded.verify(); + assert!( + verify_result.is_err(), + "verify should catch tombstoned ID not in indexed_ids" + ); +} + +// -- Test: Shadow validation catches corrupted snapshot state -- +// +// Scenario: Snapshot is persisted correctly, then corrupted in-place in SQLite. +// Shadow validation (read-back) should detect the corruption. + +#[tokio::test] +async fn test_shadow_validation_detects_corruption() { + let persist = setup_test_persistence().await; + + // Build and persist valid index + let mut index = HnswIndex::new(4); + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0, 0.0]).expect("insert"); + + persist + .persist_hnsw_snapshot(&index) + .await + .expect("persist"); + + // Now corrupt the stored data in-place + inject_raw_hnsw_snapshot(&persist, b"not valid json at all {{{").await; + + // Shadow validation should detect the corruption + let expected = ShadowMetrics { + item_count: 2, + tombstone_count: 0, + snapshot_size: 0, + }; + + let result = persist.validate_hnsw_snapshot(expected).await; + assert!( + !result.passed, + "shadow validation should fail on corrupted data" + ); + assert!( + !result.discrepancies.is_empty(), + "should report discrepancies" + ); +} + +// -- Test: Full recovery workflow -- +// +// Scenario: Snapshot is corrupted. Engine detects via load failure, clears the +// corrupt entry, and rebuilds from source vectors. 
After rebuild, +// the new snapshot is valid. + +#[tokio::test] +async fn test_full_recovery_workflow_corrupt_then_rebuild() { + let persist = setup_test_persistence().await; + + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + let id3 = NodeId::new([3; 16]); + + let vectors = vec![ + (id1, vec![1.0, 0.0, 0.0, 0.0]), + (id2, vec![0.0, 1.0, 0.0, 0.0]), + (id3, vec![0.0, 0.0, 1.0, 0.0]), + ]; + + // Step 1: Build and persist a valid index + { + let mut index = HnswIndex::new(4); + for (id, vec) in &vectors { + index.insert(*id, vec.clone()).expect("insert"); + } + persist + .persist_hnsw_snapshot(&index) + .await + .expect("persist"); + } + + // Step 2: Corrupt the snapshot + inject_raw_hnsw_snapshot(&persist, b"corrupted snapshot data").await; + + // Step 3: Attempt to load -- should fail + let load_result = persist.load_hnsw_snapshot().await; + assert!( + load_result.is_err(), + "loading corrupted snapshot should fail" + ); + + // Step 4: Recovery -- clear corrupt data + persist.clear().await.expect("clear corrupted data"); + + // Step 5: Verify cleared + let after_clear = persist + .load_hnsw_snapshot() + .await + .expect("load after clear"); + assert!(after_clear.is_none(), "snapshot should be gone after clear"); + + // Step 6: Rebuild index from source vectors + let mut rebuilt_index = HnswIndex::new(4); + for (id, vec) in &vectors { + rebuilt_index.insert(*id, vec.clone()).expect("re-insert"); + } + + assert_eq!( + rebuilt_index.len(), + 3, + "rebuilt index should have 3 vectors" + ); + + // Step 7: Persist the rebuilt index + persist + .persist_hnsw_snapshot(&rebuilt_index) + .await + .expect("persist rebuilt"); + + // Step 8: Verify the new snapshot is valid + let new_snapshot = persist + .load_hnsw_snapshot() + .await + .expect("load rebuilt") + .expect("snapshot exists"); + + assert_eq!(new_snapshot.total_nodes, 3); + assert_eq!(new_snapshot.live_nodes, 3); + assert!( + new_snapshot.verify().is_ok(), + "rebuilt snapshot should pass 
verification" + ); +} + +// -- Test: Recovery from inconsistent snapshot via verify-then-rebuild -- +// +// Scenario: Snapshot loads but fails verify(). Engine should detect this and +// trigger rebuild rather than using the corrupt topology. + +#[tokio::test] +async fn test_recovery_from_inconsistent_snapshot_via_verify() { + use khive_hnsw::{HnswCheckpointConfig, HnswSnapshot}; + + let persist = setup_test_persistence().await; + + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + let id3 = NodeId::new([3; 16]); + + // Inject a snapshot with mismatched tombstone counts + let bad_snapshot = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 1, + tombstone_count: 2, // Claims 2 tombstones + max_layer: 0, + entry_point: Some(id1), + config: HnswCheckpointConfig { + m: 16, + ef_construction: 200, + metric: "cosine".to_string(), + }, + indexed_ids: vec![id1, id2, id3], + tombstoned_ids: vec![id2], // Only 1 tombstone ID (mismatch!) + layers: vec![], + + vectors: vec![], + }; + + let data = serde_json::to_vec(&bad_snapshot).expect("serialize"); + inject_raw_hnsw_snapshot(&persist, &data).await; + + // Load succeeds + let loaded = persist + .load_hnsw_snapshot() + .await + .expect("load") + .expect("snapshot exists"); + + // But verify catches the corruption + let verify_err = loaded.verify().unwrap_err(); + let err_msg = verify_err.to_string(); + assert!( + err_msg.contains("tombstoned_ids count mismatch"), + "should report tombstone count mismatch, got: {err_msg}" + ); + + // Recovery: clear and rebuild + persist.clear().await.expect("clear"); + + let mut rebuilt = HnswIndex::new(4); + rebuilt + .insert(id1, vec![1.0, 0.0, 0.0, 0.0]) + .expect("insert"); + rebuilt + .insert(id2, vec![0.0, 1.0, 0.0, 0.0]) + .expect("insert"); + rebuilt + .insert(id3, vec![0.0, 0.0, 1.0, 0.0]) + .expect("insert"); + + persist + .persist_hnsw_snapshot(&rebuilt) + .await + .expect("persist rebuilt"); + + let new_snapshot = persist + .load_hnsw_snapshot() + 
.await + .expect("load") + .expect("snapshot exists"); + assert!( + new_snapshot.verify().is_ok(), + "rebuilt snapshot should be valid" + ); +} + +// -- Test: Restore from snapshot detects corrupt snapshot -- +// +// Scenario: An index tries to restore_from_snapshot with a corrupt snapshot. +// The restore should fail with an error, not silently use bad data. + +#[tokio::test] +async fn test_restore_from_corrupt_snapshot_fails() { + use khive_hnsw::{HnswCheckpointConfig, HnswSnapshot}; + + let id1 = NodeId::new([1; 16]); + let id2 = NodeId::new([2; 16]); + + let mut index = HnswIndex::new(4); + + // Create a corrupt snapshot (total_nodes mismatch) + let bad_snapshot = HnswSnapshot { + vector_count: 0, + total_nodes: 10, // WRONG + live_nodes: 10, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: HnswCheckpointConfig { + m: 16, + ef_construction: 200, + metric: "cosine".to_string(), + }, + indexed_ids: vec![id1, id2], // Only 2 + tombstoned_ids: vec![], + layers: vec![], + + vectors: vec![], + }; + + let vectors: std::collections::HashMap> = [ + (id1, vec![1.0, 0.0, 0.0, 0.0]), + (id2, vec![0.0, 1.0, 0.0, 0.0]), + ] + .into_iter() + .collect(); + + let result = index.restore_from_snapshot(&bad_snapshot, &vectors); + assert!( + result.is_err(), + "restore_from_snapshot should reject corrupt snapshot" + ); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Invalid snapshot"), + "error should mention invalid snapshot, got: {err_msg}" + ); +} + +// -- Test: Binary garbage as HNSW snapshot -- +// +// Scenario: Random non-JSON binary data in the snapshot slot (e.g., disk corruption +// that overwrites the entire BLOB). 
+ +#[tokio::test] +async fn test_binary_garbage_hnsw_snapshot_detected() { + let persist = setup_test_persistence().await; + + // Insert pure binary garbage + let garbage: Vec = (0..256).map(|i| i as u8).collect(); + inject_raw_hnsw_snapshot(&persist, &garbage).await; + + let result = persist.load_hnsw_snapshot().await; + assert!(result.is_err(), "binary garbage should fail to deserialize"); + let err = result.unwrap_err(); + assert!( + matches!(err, PersistError::Deserialize(_)), + "should be Deserialize error, got: {err:?}" + ); +} + +// -- Test: Overwrite corrupt snapshot with valid one -- +// +// Scenario: After detecting corruption, persisting a new valid snapshot should +// overwrite the corrupt data (INSERT OR REPLACE behavior). + +#[tokio::test] +async fn test_overwrite_corrupt_snapshot_with_valid() { + let persist = setup_test_persistence().await; + + // Inject corrupt data + inject_raw_hnsw_snapshot(&persist, b"this is not valid json").await; + + // Verify it's corrupt + assert!(persist.load_hnsw_snapshot().await.is_err()); + + // Now persist a valid index (should overwrite the corrupt entry) + let mut index = HnswIndex::new(4); + let id1 = NodeId::new([1; 16]); + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + + persist + .persist_hnsw_snapshot(&index) + .await + .expect("persist should overwrite corrupt entry"); + + // Loading should now succeed + let loaded = persist + .load_hnsw_snapshot() + .await + .expect("load should succeed after overwrite") + .expect("snapshot should exist"); + + assert_eq!(loaded.total_nodes, 1); + assert!(loaded.verify().is_ok()); +} diff --git a/crates/khive-retrieval/src/policy.rs b/crates/khive-retrieval/src/policy.rs new file mode 100644 index 00000000..668100cc --- /dev/null +++ b/crates/khive-retrieval/src/policy.rs @@ -0,0 +1,344 @@ +//! Policy integration for access-controlled retrieval. +//! +//! # RETRIEVAL-03: Policy Integration +//! +//! 
This module provides policy-based filtering of search results, ensuring +//! that callers only see documents they are authorized to access. +//! +//! # Architecture +//! +//! ```text +//! Query -> Retrieval -> Policy Filter -> Results +//! | +//! v +//! PolicyEngine +//! ``` +//! +//! # Example +//! +//! ```ignore +//! use khive_retrieval::policy::{SearchPolicy, filter_by_policy}; +//! +//! let policy = SearchPolicy::new(ClearanceLevel::Internal); +//! let filtered = filter_by_policy(results, &policy, |id| get_doc_clearance(id)); +//! ``` + +use khive_score::DeterministicScore; +use std::hash::Hash; + +#[cfg(feature = "policy")] +use khive_gate::GateContext as PolicyContext; + +/// Clearance level for documents. +/// +/// Higher values indicate more restricted access. +/// This is a simple hierarchical model; more complex ABAC can be built on top. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] +pub enum ClearanceLevel { + /// Public documents, accessible to all. + #[default] + Public = 0, + /// Internal documents, accessible to authenticated users. + Internal = 1, + /// Confidential documents, restricted access. + Confidential = 2, + /// Secret documents, highly restricted. + Secret = 3, +} + +impl ClearanceLevel { + /// Check if this clearance level can access a document with the given level. + /// + /// A caller can access a document if their clearance is >= the document's level. + #[inline] + pub fn can_access(&self, document_level: ClearanceLevel) -> bool { + *self >= document_level + } +} + +/// Policy context for search operations. +/// +/// Encapsulates the caller's clearance level and optional policy engine +/// for more complex access control decisions. +#[derive(Debug, Clone)] +pub struct SearchPolicy { + /// Caller's clearance level (simple hierarchical model). + pub caller_clearance: ClearanceLevel, + + /// Optional policy engine for complex ABAC decisions. 
+ #[cfg(feature = "policy")] + pub policy_context: Option, +} + +impl SearchPolicy { + /// Create a new search policy with the given clearance level. + pub fn new(caller_clearance: ClearanceLevel) -> Self { + Self { + caller_clearance, + #[cfg(feature = "policy")] + policy_context: None, + } + } + + /// Create a public-level search policy (default). + pub fn public() -> Self { + Self::new(ClearanceLevel::Public) + } + + /// Create an internal-level search policy. + pub fn internal() -> Self { + Self::new(ClearanceLevel::Internal) + } + + /// Create a confidential-level search policy. + pub fn confidential() -> Self { + Self::new(ClearanceLevel::Confidential) + } + + /// Create a secret-level search policy. + pub fn secret() -> Self { + Self::new(ClearanceLevel::Secret) + } + + /// Set the policy context for complex access control. + #[cfg(feature = "policy")] + pub fn with_context(mut self, context: PolicyContext) -> Self { + self.policy_context = Some(context); + self + } + + /// Check if the caller can access a document with the given clearance. + #[inline] + pub fn can_access(&self, document_clearance: ClearanceLevel) -> bool { + self.caller_clearance.can_access(document_clearance) + } +} + +impl Default for SearchPolicy { + fn default() -> Self { + Self::public() + } +} + +/// Filter search results based on policy. +/// +/// # Arguments +/// +/// * `results` - The search results to filter. +/// * `policy` - The search policy to apply. +/// * `get_clearance` - A function that returns the clearance level for a given ID. +/// +/// # Returns +/// +/// A new vector containing only the results the caller is authorized to see. 
+/// +/// # Example +/// +/// ```ignore +/// let policy = SearchPolicy::new(ClearanceLevel::Internal); +/// let filtered = filter_by_policy(results, &policy, |id| { +/// // Look up document clearance from metadata +/// get_document_clearance(id) +/// }); +/// ``` +pub fn filter_by_policy( + results: Vec<(Id, DeterministicScore)>, + policy: &SearchPolicy, + get_clearance: F, +) -> Vec<(Id, DeterministicScore)> +where + Id: Clone, + F: Fn(&Id) -> ClearanceLevel, +{ + results + .into_iter() + .filter(|(id, _)| { + let doc_clearance = get_clearance(id); + policy.can_access(doc_clearance) + }) + .collect() +} + +/// Filter search results using a custom predicate. +/// +/// This is a more flexible version of `filter_by_policy` that allows +/// arbitrary access control logic. +/// +/// # Arguments +/// +/// * `results` - The search results to filter. +/// * `is_accessible` - A predicate that returns true if the caller can access the document. +/// +/// # Returns +/// +/// A new vector containing only the accessible results. 
+pub fn filter_by_predicate( + results: Vec<(Id, DeterministicScore)>, + is_accessible: F, +) -> Vec<(Id, DeterministicScore)> +where + Id: Clone, + F: Fn(&Id) -> bool, +{ + results + .into_iter() + .filter(|(id, _)| is_accessible(id)) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_clearance_level_ordering() { + assert!(ClearanceLevel::Secret > ClearanceLevel::Confidential); + assert!(ClearanceLevel::Confidential > ClearanceLevel::Internal); + assert!(ClearanceLevel::Internal > ClearanceLevel::Public); + } + + #[test] + fn test_clearance_can_access() { + let secret = ClearanceLevel::Secret; + let public = ClearanceLevel::Public; + + // Secret can access everything + assert!(secret.can_access(ClearanceLevel::Secret)); + assert!(secret.can_access(ClearanceLevel::Confidential)); + assert!(secret.can_access(ClearanceLevel::Internal)); + assert!(secret.can_access(ClearanceLevel::Public)); + + // Public can only access public + assert!(public.can_access(ClearanceLevel::Public)); + assert!(!public.can_access(ClearanceLevel::Internal)); + assert!(!public.can_access(ClearanceLevel::Confidential)); + assert!(!public.can_access(ClearanceLevel::Secret)); + } + + #[test] + fn test_search_policy_constructors() { + let policy = SearchPolicy::public(); + assert_eq!(policy.caller_clearance, ClearanceLevel::Public); + + let policy = SearchPolicy::secret(); + assert_eq!(policy.caller_clearance, ClearanceLevel::Secret); + } + + // ========================================================================= + // RETRIEVAL-03: Policy Integration Tests + // ========================================================================= + + #[test] + fn test_filter_by_policy_hides_secret_from_public() { + let results = vec![ + ("doc_public", DeterministicScore::from_f64(0.9)), + ("doc_secret", DeterministicScore::from_f64(0.95)), + ("doc_internal", DeterministicScore::from_f64(0.8)), + ]; + + let policy = SearchPolicy::public(); + + // Clearance lookup function 
+ let get_clearance = |id: &&str| -> ClearanceLevel { + match *id { + "doc_public" => ClearanceLevel::Public, + "doc_internal" => ClearanceLevel::Internal, + "doc_secret" => ClearanceLevel::Secret, + _ => ClearanceLevel::Public, + } + }; + + let filtered = filter_by_policy(results, &policy, get_clearance); + + // Public caller should only see public documents + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].0, "doc_public"); + } + + #[test] + fn test_filter_by_policy_secret_sees_all() { + let results = vec![ + ("doc_public", DeterministicScore::from_f64(0.9)), + ("doc_secret", DeterministicScore::from_f64(0.95)), + ("doc_confidential", DeterministicScore::from_f64(0.8)), + ]; + + let policy = SearchPolicy::secret(); + + let get_clearance = |id: &&str| -> ClearanceLevel { + match *id { + "doc_public" => ClearanceLevel::Public, + "doc_confidential" => ClearanceLevel::Confidential, + "doc_secret" => ClearanceLevel::Secret, + _ => ClearanceLevel::Public, + } + }; + + let filtered = filter_by_policy(results, &policy, get_clearance); + + // Secret caller should see all documents + assert_eq!(filtered.len(), 3); + } + + #[test] + fn test_filter_by_policy_internal_sees_public_and_internal() { + let results = vec![ + ("doc_public", DeterministicScore::from_f64(0.9)), + ("doc_secret", DeterministicScore::from_f64(0.95)), + ("doc_internal", DeterministicScore::from_f64(0.8)), + ]; + + let policy = SearchPolicy::internal(); + + let get_clearance = |id: &&str| -> ClearanceLevel { + match *id { + "doc_public" => ClearanceLevel::Public, + "doc_internal" => ClearanceLevel::Internal, + "doc_secret" => ClearanceLevel::Secret, + _ => ClearanceLevel::Public, + } + }; + + let filtered = filter_by_policy(results, &policy, get_clearance); + + // Internal caller should see public and internal + assert_eq!(filtered.len(), 2); + assert!(filtered.iter().any(|(id, _)| *id == "doc_public")); + assert!(filtered.iter().any(|(id, _)| *id == "doc_internal")); + 
assert!(!filtered.iter().any(|(id, _)| *id == "doc_secret")); + } + + #[test] + fn test_filter_by_policy_preserves_order() { + let results = vec![ + ("doc1", DeterministicScore::from_f64(0.9)), + ("doc2", DeterministicScore::from_f64(0.8)), + ("doc3", DeterministicScore::from_f64(0.7)), + ]; + + let policy = SearchPolicy::public(); + let get_clearance = |_: &&str| ClearanceLevel::Public; + + let filtered = filter_by_policy(results, &policy, get_clearance); + + // Order should be preserved + assert_eq!(filtered[0].0, "doc1"); + assert_eq!(filtered[1].0, "doc2"); + assert_eq!(filtered[2].0, "doc3"); + } + + #[test] + fn test_filter_by_predicate() { + let results = vec![ + ("allowed", DeterministicScore::from_f64(0.9)), + ("denied", DeterministicScore::from_f64(0.8)), + ("allowed2", DeterministicScore::from_f64(0.7)), + ]; + + let filtered = filter_by_predicate(results, |id| id.starts_with("allowed")); + + assert_eq!(filtered.len(), 2); + assert_eq!(filtered[0].0, "allowed"); + assert_eq!(filtered[1].0, "allowed2"); + } +} diff --git a/crates/khive-retrieval/src/query_ir.rs b/crates/khive-retrieval/src/query_ir.rs new file mode 100644 index 00000000..a86ab164 --- /dev/null +++ b/crates/khive-retrieval/src/query_ir.rs @@ -0,0 +1,632 @@ +//! Query Intermediate Representation for the retrieval pipeline. +//! +//! Provides a composable, analyzable tree representation of search queries +//! that can be inspected and optimized before execution. +//! +//! # Motivation +//! +//! The existing [`Query`](crate::hybrid::Query) struct captures *what* to +//! search (text + optional embedding), but not *how* the retrieval pipeline +//! should compose sub-queries, apply filters, or perform fusion. `QueryNode` +//! makes that composition explicit as an IR tree. +//! +//! # Example +//! +//! ```rust +//! use khive_retrieval::query_ir::{QueryNode, FuseStrategy}; +//! +//! // Build a hybrid query: vector + keyword fused with RRF, then top-10 +//! let embedding = vec![0.1_f32; 128]; +//! 
let q = QueryNode::hybrid(embedding, "distributed consensus", 10); +//! +//! assert_eq!(q.leaf_count(), 2); +//! assert_eq!(q.top_k(), 10); +//! assert!(!q.is_empty()); +//! ``` + +use khive_score::DeterministicScore; +use serde::{Deserialize, Serialize}; + +// --------------------------------------------------------------------------- +// Core IR node +// --------------------------------------------------------------------------- + +/// A node in the Query IR tree. +/// +/// Each variant represents a single retrieval operation or combinator. +/// Nodes compose recursively -- a `Fuse` holds children, a `Filter` wraps +/// a single child, and leaf nodes (`Vector`, `Keyword`, `Empty`) terminate +/// the tree. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum QueryNode { + /// Vector similarity search (e.g. HNSW nearest-neighbor). + Vector { + /// Pre-computed query embedding. + embedding: Vec, + /// Number of results to return. + top_k: usize, + /// Optional minimum similarity threshold. + min_score: Option, + }, + + /// Keyword / BM25 text search. + Keyword { + /// Query text. + text: String, + /// Number of results to return. + top_k: usize, + /// Optional minimum relevance threshold. + min_score: Option, + }, + + /// Fuse multiple sub-queries into a single ranked list. + Fuse { + /// Sub-queries to fuse. + children: Vec, + /// Strategy for combining ranked lists. + strategy: FuseStrategy, + /// Number of results after fusion. + top_k: usize, + }, + + /// Filter the results of a sub-query. + Filter { + /// The sub-query whose results are filtered. + child: Box, + /// Predicate to apply. + predicate: FilterPredicate, + }, + + /// Rerank the results of a sub-query. + Rerank { + /// The sub-query whose results are reranked. + child: Box, + /// Reranking method. + method: RerankMethod, + /// Number of results after reranking. + top_k: usize, + }, + + /// An empty query that is guaranteed to produce no results. 
+ /// + /// Useful as the result of constant-folding provably-empty sub-trees. + Empty, +} + +// --------------------------------------------------------------------------- +// Supporting enums +// --------------------------------------------------------------------------- + +/// Fusion strategy for combining sub-query result lists. +/// +/// Mirrors [`FusionStrategy`](crate::fusion::FusionStrategy) at the IR level +/// so that the query plan is self-contained and serialisable without depending +/// on runtime fusion internals. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FuseStrategy { + /// Reciprocal Rank Fusion with smoothing constant `k`. + /// + /// Standard default: k = 60 (Craswell et al., 2009). + Rrf { + /// Smoothing constant. + k: usize, + }, + + /// Weighted linear combination of scores. + /// + /// One weight per child; weights are normalised at execution time. + Weighted { + /// Per-child weights (will be normalised to sum to 1.0). + weights: Vec, + }, + + /// Union with max-score-per-document semantics. + Union, +} + +/// Predicate for post-retrieval filtering. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FilterPredicate { + /// Keep only results whose score meets a minimum threshold. + MinScore(DeterministicScore), + + /// Keep at most `k` results (top-k truncation). + TopK(usize), + + /// Keep results where a metadata field equals a given value. + MetadataEquals { + /// Metadata field name. + field: String, + /// Expected value (JSON). + value: serde_json::Value, + }, + + /// All contained predicates must hold (conjunction). + And(Vec), + + /// At least one contained predicate must hold (disjunction). + Or(Vec), +} + +/// Method for reranking search results. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RerankMethod { + /// Cross-encoder neural reranking (placeholder for future integration). 
+ CrossEncoder { + /// Model identifier. + model: String, + }, + + /// Score-based reranking with custom per-signal weights. + ScoreWeighted { + /// Weights applied to each scoring signal. + weights: Vec, + }, +} + +// --------------------------------------------------------------------------- +// Construction helpers +// --------------------------------------------------------------------------- + +impl QueryNode { + /// Create a vector search leaf node. + /// + /// # Arguments + /// + /// * `embedding` - Pre-computed query embedding. + /// * `top_k` - Number of results to return. + pub fn vector(embedding: Vec, top_k: usize) -> Self { + QueryNode::Vector { + embedding, + top_k, + min_score: None, + } + } + + /// Create a keyword search leaf node. + /// + /// # Arguments + /// + /// * `text` - Query text (converted via `Into`). + /// * `top_k` - Number of results to return. + pub fn keyword(text: impl Into, top_k: usize) -> Self { + QueryNode::Keyword { + text: text.into(), + top_k, + min_score: None, + } + } + + /// Create a hybrid query (vector + keyword with RRF fusion). + /// + /// The two leaf sub-queries each request `top_k * 3` candidates to give + /// the fusion step a sufficiently large candidate pool. + /// + /// # Arguments + /// + /// * `embedding` - Pre-computed query embedding. + /// * `text` - Query text (converted via `Into`). + /// * `top_k` - Number of final results after fusion. + pub fn hybrid(embedding: Vec, text: impl Into, top_k: usize) -> Self { + QueryNode::Fuse { + children: vec![ + QueryNode::vector(embedding, top_k * 3), + QueryNode::keyword(text, top_k * 3), + ], + strategy: FuseStrategy::Rrf { k: 60 }, + top_k, + } + } + + /// Wrap this node with a minimum-score filter. + #[must_use] + pub fn with_min_score(self, min_score: DeterministicScore) -> Self { + QueryNode::Filter { + child: Box::new(self), + predicate: FilterPredicate::MinScore(min_score), + } + } + + /// Wrap this node with a top-k truncation filter. 
+ #[must_use] + pub fn with_top_k(self, k: usize) -> Self { + QueryNode::Filter { + child: Box::new(self), + predicate: FilterPredicate::TopK(k), + } + } + + // ----------------------------------------------------------------------- + // Analysis helpers + // ----------------------------------------------------------------------- + + /// Returns `true` if this query is provably empty (no results possible). + /// + /// A query is provably empty when: + /// - It is the `Empty` variant. + /// - A leaf has `top_k == 0`. + /// - A keyword leaf has empty text. + /// - A fuse node has no children. + /// - A filter/rerank wraps a provably-empty child. + pub fn is_empty(&self) -> bool { + match self { + QueryNode::Empty => true, + QueryNode::Vector { top_k: 0, .. } => true, + QueryNode::Keyword { top_k: 0, .. } => true, + QueryNode::Keyword { text, .. } if text.is_empty() => true, + QueryNode::Fuse { children, .. } if children.is_empty() => true, + QueryNode::Filter { child, .. } => child.is_empty(), + QueryNode::Rerank { child, .. } => child.is_empty(), + _ => false, + } + } + + /// Count the total number of leaf search operations in the tree. + /// + /// `Vector` and `Keyword` nodes each count as 1. `Empty` counts as 0. + /// Combinators recurse into their children. + pub fn leaf_count(&self) -> usize { + match self { + QueryNode::Vector { .. } | QueryNode::Keyword { .. } => 1, + QueryNode::Fuse { children, .. } => children.iter().map(|c| c.leaf_count()).sum(), + QueryNode::Filter { child, .. } | QueryNode::Rerank { child, .. } => child.leaf_count(), + QueryNode::Empty => 0, + } + } + + /// Return the effective `top_k` requested by this node. + /// + /// For `Filter` nodes with a `TopK` predicate, the predicate's value is + /// returned. Otherwise the child's `top_k` propagates upward. + pub fn top_k(&self) -> usize { + match self { + QueryNode::Vector { top_k, .. } => *top_k, + QueryNode::Keyword { top_k, .. } => *top_k, + QueryNode::Fuse { top_k, .. 
} => *top_k, + QueryNode::Filter { child, predicate } => match predicate { + FilterPredicate::TopK(k) => *k, + _ => child.top_k(), + }, + QueryNode::Rerank { top_k, .. } => *top_k, + QueryNode::Empty => 0, + } + } + + /// Return the depth of the IR tree (longest root-to-leaf path). + /// + /// Leaf nodes have depth 1. `Empty` has depth 0. + pub fn depth(&self) -> usize { + match self { + QueryNode::Empty => 0, + QueryNode::Vector { .. } | QueryNode::Keyword { .. } => 1, + QueryNode::Fuse { children, .. } => { + 1 + children.iter().map(|c| c.depth()).max().unwrap_or(0) + } + QueryNode::Filter { child, .. } | QueryNode::Rerank { child, .. } => 1 + child.depth(), + } + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +#[allow(clippy::uninlined_format_args)] +mod tests { + use super::*; + + // -- Construction ------------------------------------------------------- + + #[test] + fn test_vector_construction() { + let emb = vec![0.1, 0.2, 0.3]; + let node = QueryNode::vector(emb.clone(), 10); + match &node { + QueryNode::Vector { + embedding, + top_k, + min_score, + } => { + assert_eq!(embedding, &emb); + assert_eq!(*top_k, 10); + assert!(min_score.is_none()); + } + other => panic!("expected Vector, got {:?}", other), + } + } + + #[test] + fn test_keyword_construction() { + let node = QueryNode::keyword("hello world", 5); + match &node { + QueryNode::Keyword { + text, + top_k, + min_score, + } => { + assert_eq!(text, "hello world"); + assert_eq!(*top_k, 5); + assert!(min_score.is_none()); + } + other => panic!("expected Keyword, got {:?}", other), + } + } + + #[test] + fn test_hybrid_construction() { + let emb = vec![0.1_f32; 128]; + let node = QueryNode::hybrid(emb, "distributed consensus", 10); + match &node { + QueryNode::Fuse { + children, + strategy, + top_k, + } => { + assert_eq!(children.len(), 2); + assert_eq!(*top_k, 10); 
/// Reranking a provably-empty child must itself be reported as empty,
/// regardless of the rerank method or the requested `top_k`.
#[test]
fn test_rerank_of_empty_is_empty() {
    let method = RerankMethod::ScoreWeighted { weights: vec![1.0] };
    let reranked = QueryNode::Rerank {
        child: Box::new(QueryNode::Empty),
        method,
        top_k: 10,
    };
    assert!(reranked.is_empty());
}
/// Leaf counting recurses through nested fusion nodes.
#[test]
fn test_leaf_count_nested() {
    // The hybrid branch contributes two leaves (vector + keyword) ...
    let hybrid_branch = QueryNode::hybrid(vec![1.0], "inner", 10);
    // ... and the sibling keyword search contributes a third.
    let keyword_branch = QueryNode::keyword("outer", 10);

    let tree = QueryNode::Fuse {
        children: vec![hybrid_branch, keyword_branch],
        strategy: FuseStrategy::Union,
        top_k: 10,
    };

    assert_eq!(tree.leaf_count(), 3);
}
/// Each chained builder call adds exactly one `Filter` level above the leaf.
#[test]
fn test_depth_chained_filters() {
    let q = QueryNode::keyword("q", 10)
        .with_min_score(DeterministicScore::from_f64(0.5))
        .with_top_k(5);
    // Filter(TopK) -> Filter(MinScore) -> Keyword: 2 wrappers + 1 leaf = depth 3
    assert_eq!(q.depth(), 3);
}
/// A metadata-equality filter survives a JSON round-trip with its child
/// sub-query intact.
#[test]
fn test_serde_roundtrip_filter_metadata() {
    let predicate = FilterPredicate::MetadataEquals {
        field: "type".to_string(),
        value: serde_json::json!("memory"),
    };
    let original = QueryNode::Filter {
        child: Box::new(QueryNode::keyword("docs", 10)),
        predicate,
    };

    let encoded = serde_json::to_string(&original).expect("serialize");
    let decoded: QueryNode = serde_json::from_str(&encoded).expect("deserialize");

    assert_eq!(decoded.leaf_count(), 1);
}
node = QueryNode::Filter { + child: Box::new(QueryNode::keyword("test", 10)), + predicate: pred, + }; + let json = serde_json::to_string(&node).expect("serialize"); + let back: QueryNode = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back.leaf_count(), 1); + } +} diff --git a/crates/khive-retrieval/src/replay/engine_replay.rs b/crates/khive-retrieval/src/replay/engine_replay.rs new file mode 100644 index 00000000..d25a85bb --- /dev/null +++ b/crates/khive-retrieval/src/replay/engine_replay.rs @@ -0,0 +1,1027 @@ +//! Temporal replay APIs — Three Observables Feedback Loop (Phase 3). +//! +//! Provides four primitives for diffing past vs. present weight state: +//! +//! | Function | Purpose | +//! |-----------------------|------------------------------------------------------------| +//! | [`weights_as_of`] | Reconstruct weight snapshot at a past timestamp | +//! | [`replay`] | Re-run vector search with historical or live weights | +//! | [`diff`] | Jaccard + rank-delta report between two temporal replays | +//! | [`rank_history`] | Weight change timeline for a single atom | +//! | [`regression_check`] | Re-run a stored compose event against current weights | +//! +//! # Design +//! +//! The weight_events table is the ground-truth log. No external baseline is +//! needed — the log IS the reference. Temporal replay reconstructs past weight +//! state by selecting the latest `weight_events` row per (lambda_id, atom_id) +//! with `ts ≤ at_time`. +//! +//! Ranking is performed by multiplying raw vector similarity scores by per-atom +//! weights, then returning the top-K atom IDs in descending score order. +//! +//! # Drift Metrics (submodule) +//! +//! [`metrics::jaccard_stability_7d`] — rolling 7-day median Jaccard from +//! regression_check over stored compose events. +//! +//! [`metrics::atom_rank_variance`] — variance of an atom's rank position across +//! all compose events where it appeared in top_atoms. +//! +//! 
[`metrics::adjustment_rate_per_day`] — count of weight_events rows per day, +//! useful for detecting runaway adjustment patterns. + +// The `engine` feature is a future integration point (EmbeddedEngine not yet ported). +// Silence the cfg warning — the feature gate is intentionally undeclared so it never activates. +#![allow(unexpected_cfgs)] + +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use chrono::{DateTime, NaiveDate, Utc}; +use parking_lot::Mutex; +#[cfg(feature = "engine")] +use rusqlite::OptionalExtension as _; +use rusqlite::{params, Connection}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::persist::PersistError as EngineError; +use crate::weights::WEIGHT_FLOOR; +// TODO(port-engine): EmbeddedEngine not yet in khive-retrieval scope; stub for compilation. +#[allow(dead_code)] +type EmbeddedEngine = (); + +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- + +/// Per-atom weight change record in chronological order. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RankHistoryPoint { + /// Timestamp of the adjustment (UTC). + pub ts: DateTime, + /// Weight after this adjustment was applied. + pub weight_after: f32, + /// Raw delta that was applied. + pub delta: f32, + /// Channel that emitted this adjustment (`ambient`, `explicit`, `ground_truth`). + pub channel: String, + /// Optional context identifier carried by the caller. + pub context_id: Option, + /// Optional brain_events UUID that triggered this adjustment. + pub event_id: Option, +} + +/// Diff report comparing two temporal rank lists for the same query. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiffReport { + /// Jaccard similarity: |A ∩ B| / |A ∪ B|. + pub jaccard: f32, + /// Atoms present in the t2 result but absent from t1. 
+ pub added: Vec, + /// Atoms present in the t1 result but absent from t2. + pub dropped: Vec, + /// Per-atom rank change from t1 → t2 (negative = moved up). + pub rank_deltas: Vec<(Uuid, i32)>, + /// Ordered top-K atom IDs at t1. + pub top_k_at_t1: Vec, + /// Ordered top-K atom IDs at t2. + pub top_k_at_t2: Vec, +} + +/// Report comparing a stored compose event's original top_atoms against +/// the same query re-run with current weights. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegressionReport { + /// Brain-events row UUID that was replayed. + pub event_id: Uuid, + /// Query text recorded at compose time (empty string if none). + pub query_text: String, + /// Ordered atom list stored in the original compose event. + pub original_top_atoms: Vec, + /// Ordered atom list from re-running the query with current weights. + pub current_top_atoms: Vec, + /// Jaccard similarity between the two lists. + pub jaccard: f32, + /// Atoms present in current but absent from original. + pub added: Vec, + /// Atoms present in original but absent from current. + pub dropped: Vec, + /// UTC timestamp when the original compose event was recorded. + pub timestamp_original: DateTime, +} + +// --------------------------------------------------------------------------- +// weights_as_of +// --------------------------------------------------------------------------- + +/// Reconstruct the weight state for a lambda at a given point in time. +/// +/// For each (lambda_id, atom_id) pair, selects the latest `weight_events` row +/// with `ts ≤ at_time` and returns `weight_after`. Atoms with no history +/// before `at_time` are absent from the map; callers should treat absence as +/// the implicit default of 1.0. 
+/// +/// # SQL +/// +/// ```sql +/// SELECT atom_id, weight_after +/// FROM ( +/// SELECT atom_id, weight_after, +/// ROW_NUMBER() OVER (PARTITION BY atom_id ORDER BY ts DESC) as rn +/// FROM weight_events +/// WHERE lambda_id = ?1 AND ts <= ?2 +/// ) +/// WHERE rn = 1 +/// ``` +pub async fn weights_as_of( + conn: &Arc>, + namespace: &str, + at_time: DateTime, +) -> Result, EngineError> { + let conn = Arc::clone(conn); + let namespace_str = namespace.to_string(); + let at_time_us = at_time.timestamp_micros(); + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + let mut result = HashMap::new(); + + let mut stmt = conn + .prepare( + "SELECT atom_id, weight_after + FROM ( + SELECT atom_id, weight_after, + ROW_NUMBER() OVER (PARTITION BY atom_id ORDER BY ts DESC) as rn + FROM weight_events + WHERE namespace = ?1 AND ts <= ?2 + ) + WHERE rn = 1", + ) + .map_err(|e| EngineError::Internal(format!("weights_as_of prepare: {e}")))?; + + let mut rows = stmt + .query(params![namespace_str, at_time_us]) + .map_err(|e| EngineError::Internal(format!("weights_as_of query: {e}")))?; + + while let Some(row) = rows + .next() + .map_err(|e| EngineError::Internal(format!("weights_as_of row: {e}")))? + { + let atom_id_str: String = row + .get(0) + .map_err(|e| EngineError::Internal(format!("weights_as_of col 0: {e}")))?; + let weight_after: f64 = row + .get(1) + .map_err(|e| EngineError::Internal(format!("weights_as_of col 1: {e}")))?; + + if let Ok(uuid) = atom_id_str.parse::() { + let clamped = + (weight_after as f32).clamp(WEIGHT_FLOOR, crate::weights::WEIGHT_CEIL); + result.insert(uuid, clamped); + } + } + + Ok(result) + }) + .await + .map_err(|e| EngineError::Internal(format!("weights_as_of join: {e}")))? 
+} + +// --------------------------------------------------------------------------- +// Namespace isolation helper (B1 fix — must come before replay) +// --------------------------------------------------------------------------- + +/// Return the subset of `candidate_ids` whose atoms are owned by `namespace`. +/// +/// Queries `atoms WHERE namespace = ?1 AND id IN (?) AND deleted_at IS NULL`. +/// Preserves no particular order — the caller re-orders by HNSW rank after +/// filtering. +/// +/// The in-memory HNSW snapshot is global (it indexes all atoms regardless of +/// namespace). Without this post-filter, `replay()` would return atoms from +/// any namespace that happen to be semantically close to the query, leaking +/// cross-tenant atom UUIDs to the requesting lambda. +#[allow(dead_code)] // used only when feature = "engine" is active +fn filter_atoms_by_namespace( + conn: &Connection, + namespace: &str, + candidate_ids: &[Uuid], +) -> Result, EngineError> { + if candidate_ids.is_empty() { + return Ok(HashSet::new()); + } + + // Build the IN clause with per-item placeholders. + // SQLITE_SAFE_BIND_LIMIT is 999; candidate_k is at most top_k*4 (≤400 for top_k=100). + let placeholders: Vec = candidate_ids.iter().map(|_| "?".to_string()).collect(); + let sql = format!( + "SELECT id FROM atoms WHERE namespace = ? AND id IN ({}) AND deleted_at IS NULL", + placeholders.join(", ") + ); + + let mut stmt = conn + .prepare(&sql) + .map_err(|e| EngineError::Internal(format!("filter_atoms_by_namespace prepare: {e}")))?; + + // Collect all bind values as Strings so they have a uniform owned type. + // namespace goes first, then the UUID strings for the IN clause. 
// TODO(port-engine): replay, diff, regression_check, load_brain_event, and
// jaccard_stability_7d require EmbeddedEngine which is not yet ported to
// khive-retrieval scope. Gated behind "engine" feature until ported.

/// Re-run a vector search for `query_text` and rank the hits by weighted score.
///
/// Each hit's raw similarity score is multiplied by its atom weight (atoms
/// absent from the weight map default to 1.0). When `at_time` is `Some(t)`,
/// the weights are the historical snapshot from [`weights_as_of`]. When it is
/// `None`, the current weights are loaded via `batch_load_weights`.
///
/// Hits are restricted to atoms owned by `namespace` before weighting.
///
/// Returns at most `top_k` atom IDs in descending weighted-score order. A
/// `top_k` of zero returns an empty list without embedding or searching.
///
/// # Errors
///
/// Returns `EngineError::Embedding` if the query cannot be embedded,
/// `EngineError::Retrieval` if the vector search fails, and propagates errors
/// from the namespace filter and from [`weights_as_of`].
#[cfg(feature = "engine")]
pub async fn replay(
    engine: &EmbeddedEngine,
    namespace: &str,
    query_text: &str,
    at_time: Option<DateTime<Utc>>,
    top_k: usize,
) -> Result<Vec<Uuid>, EngineError> {
    // Nothing can be returned, so skip the embedding and HNSW round-trips.
    if top_k == 0 {
        return Ok(Vec::new());
    }

    // Step 1: embed the query.
    let query_vec = engine
        .embed_query(query_text)
        .await
        .map_err(|e| EngineError::Embedding(format!("replay embed: {e}")))?;

    // Step 2: vector search via HNSW for a broad candidate set.
    // Do this first so we know which atom IDs to load weights for.
    // Saturating: an extreme `top_k` must not overflow the candidate budget.
    let candidate_k = top_k.saturating_mul(4).max(20);
    let raw_results = engine
        .search_by_vector(&query_vec, candidate_k)
        .await
        .map_err(|e| EngineError::Retrieval(format!("replay search: {e}")))?;

    // Step 2b (B1 fix): filter to atoms owned by this lambda's namespace.
    //
    // The HNSW snapshot is global — it contains atoms from every namespace
    // stored in this engine instance. Without this filter, `replay()` would
    // leak cross-tenant atom UUIDs into the ranked result (they default to
    // weight 1.0 when absent from the weight map, potentially outranking the
    // requesting lambda's own down-weighted atoms).
    let raw_results = {
        let conn_guard = engine.store().conn();
        let c = conn_guard.lock();
        let all_candidate_ids: Vec<Uuid> = raw_results.iter().map(|h| h.id).collect();
        let owned = filter_atoms_by_namespace(&c, namespace, &all_candidate_ids)?;
        // Re-filter raw_results (preserving HNSW rank order).
        raw_results
            .into_iter()
            .filter(|h| owned.contains(&h.id))
            .collect::<Vec<_>>()
    };

    // Step 3: resolve weights for the surviving candidate atom IDs.
    let candidate_ids: Vec<Uuid> = raw_results.iter().map(|h| h.id).collect();
    let weights: HashMap<Uuid, f32> = match at_time {
        Some(t) => weights_as_of(&engine.store().conn(), namespace, t).await?,
        None => {
            // NOTE(review): a failed weight load silently degrades to an empty
            // map, i.e. every atom is weighted 1.0 — confirm this best-effort
            // fallback is intended rather than an error to surface.
            crate::weights::batch_load_weights(&engine.store().conn(), namespace, &candidate_ids)
                .await
                .unwrap_or_default()
        }
    };

    // Step 4: apply weight multiplier.
    let mut scored: Vec<(Uuid, f32)> = raw_results
        .into_iter()
        .map(|hit| {
            let w = weights.get(&hit.id).copied().unwrap_or(1.0_f32);
            (hit.id, hit.score * w)
        })
        .collect();

    // Step 5: sort descending and truncate. The sort is stable, so hits with
    // equal (or incomparable) scores keep their HNSW order.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(top_k);

    Ok(scored.into_iter().map(|(id, _)| id).collect())
}
+#[allow(dead_code)] // used only when feature = "engine" is active +fn compute_diff_report(top_k_at_t1: Vec, top_k_at_t2: Vec) -> DiffReport { + use std::collections::HashSet; + + let set_t1: HashSet = top_k_at_t1.iter().copied().collect(); + let set_t2: HashSet = top_k_at_t2.iter().copied().collect(); + + let intersection_size = set_t1.intersection(&set_t2).count(); + let union_size = set_t1.union(&set_t2).count(); + let jaccard = if union_size == 0 { + 1.0_f32 + } else { + intersection_size as f32 / union_size as f32 + }; + + let added: Vec = set_t2.difference(&set_t1).copied().collect(); + let dropped: Vec = set_t1.difference(&set_t2).copied().collect(); + + // Build rank maps (0-indexed). + let rank_t1: HashMap = top_k_at_t1 + .iter() + .enumerate() + .map(|(i, &id)| (id, i)) + .collect(); + let rank_t2: HashMap = top_k_at_t2 + .iter() + .enumerate() + .map(|(i, &id)| (id, i)) + .collect(); + + // Rank deltas only for atoms present in both. + let rank_deltas: Vec<(Uuid, i32)> = set_t1 + .intersection(&set_t2) + .filter_map(|&id| { + let r1 = *rank_t1.get(&id)?; + let r2 = *rank_t2.get(&id)?; + Some((id, r2 as i32 - r1 as i32)) + }) + .collect(); + + DiffReport { + jaccard, + added, + dropped, + rank_deltas, + top_k_at_t1, + top_k_at_t2, + } +} + +// --------------------------------------------------------------------------- +// diff — engine-dependent, gated +// --------------------------------------------------------------------------- + +/// Compute the diff between two temporal replays of the same query. 
+#[cfg(feature = "engine")] +pub async fn diff( + engine: &EmbeddedEngine, + namespace: &str, + query_text: &str, + t1: DateTime, + t2: DateTime, + top_k: usize, +) -> Result { + let (top_k_at_t1, top_k_at_t2) = tokio::try_join!( + replay(engine, namespace, query_text, Some(t1), top_k), + replay(engine, namespace, query_text, Some(t2), top_k), + )?; + Ok(compute_diff_report(top_k_at_t1, top_k_at_t2)) +} + +// --------------------------------------------------------------------------- +// rank_history +// --------------------------------------------------------------------------- + +/// Return the full weight-change history for a single (namespace, atom_id) pair +/// in ascending timestamp order. +/// +/// Useful for answering "why did this atom's rank change?" — each row captures +/// the delta, resulting weight, channel, and optional originating context/event. +pub async fn rank_history( + conn: &Arc>, + namespace: &str, + atom_id: Uuid, +) -> Result, EngineError> { + let conn = Arc::clone(conn); + let namespace_str = namespace.to_string(); + let atom_id_str = atom_id.to_string(); + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + let mut stmt = conn + .prepare( + "SELECT ts, weight_after, delta, channel, context_id, event_id + FROM weight_events + WHERE namespace = ?1 AND atom_id = ?2 + ORDER BY ts ASC", + ) + .map_err(|e| EngineError::Internal(format!("rank_history prepare: {e}")))?; + + let rows = stmt + .query_map(params![namespace_str, atom_id_str], |row| { + let ts_us: i64 = row.get(0)?; + let weight_after: f64 = row.get(1)?; + let delta: f64 = row.get(2)?; + let channel: String = row.get(3)?; + let context_id: Option = row.get(4)?; + let event_id_str: Option = row.get(5)?; + Ok(( + ts_us, + weight_after, + delta, + channel, + context_id, + event_id_str, + )) + }) + .map_err(|e| EngineError::Internal(format!("rank_history query: {e}")))?; + + let mut points = Vec::new(); + for row in rows { + let (ts_us, weight_after, delta, channel, 
context_id, event_id_str) = + row.map_err(|e| EngineError::Internal(format!("rank_history row: {e}")))?; + + let ts = DateTime::from_timestamp_micros(ts_us).unwrap_or_else(Utc::now); + + let event_id = event_id_str.and_then(|s| s.parse::().ok()); + + points.push(RankHistoryPoint { + ts, + weight_after: weight_after as f32, + delta: delta as f32, + channel, + context_id, + event_id, + }); + } + + Ok(points) + }) + .await + .map_err(|e| EngineError::Internal(format!("rank_history join: {e}")))? +} + +// --------------------------------------------------------------------------- +// regression_check — engine-dependent, gated +// --------------------------------------------------------------------------- + +/// Re-run the query from a stored compose event against current weights. +#[cfg(feature = "engine")] +pub async fn regression_check( + engine: &EmbeddedEngine, + event_id: Uuid, +) -> Result { + // Step 1: load brain_events row. + // load_brain_event now returns InvalidData on malformed payload (B4 fix) + // and includes the stored embedding_model for validation (B5 fix). + let (query_text, original_top_atoms, namespace, created_at_us, stored_model) = + load_brain_event(engine, event_id).await?; + + // Step 2 (B5 fix): validate embedding model compatibility. + // + // If the stored row recorded an embedding_model AND it differs from the + // engine's current model, the query re-embedding would produce a vector in + // a different space, making the resulting Jaccard meaningless. We surface + // this as a distinct error so callers can skip or re-embed rather than + // silently reporting catastrophic drift. + // + // Legacy rows (stored_model == None, i.e. pre-Phase-2 events) are accepted + // with a warning — we cannot validate compatibility but also cannot reject + // all historical data. 
+ if let Some(ref stored) = stored_model { + let current = engine.embedding_model(); + if stored != current { + return Err(EngineError::IncompatibleEmbeddingModel { + stored: stored.clone(), + current: current.to_string(), + }); + } + } else { + tracing::warn!( + event_id = %event_id, + "regression_check: brain_events row has no embedding_model (legacy row); \ + proceeding without model compatibility check" + ); + } + + // Step 3: replay with current weights. + let current_top_atoms = replay( + engine, + &namespace, + &query_text, + None, // current weights + original_top_atoms.len().max(10), + ) + .await?; + + // Step 4: compute Jaccard. + let report = compute_diff_report(original_top_atoms.clone(), current_top_atoms.clone()); + + let timestamp_original = + DateTime::from_timestamp_micros(created_at_us).unwrap_or_else(Utc::now); + + Ok(RegressionReport { + event_id, + query_text, + original_top_atoms, + current_top_atoms, + jaccard: report.jaccard, + added: report.added, + dropped: report.dropped, + timestamp_original, + }) +} + +/// Load a brain_events row and extract replay inputs. (engine-gated) +#[cfg(feature = "engine")] +async fn load_brain_event( + engine: &EmbeddedEngine, + event_id: Uuid, +) -> Result<(String, Vec, String, i64, Option), EngineError> { + // Use the legacy conn() path (Arc>) which is Send + Clone. + let conn = engine.store().conn(); + let event_id_str = event_id.to_string(); + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + let guard = &*conn; + + // Select query_text, payload (top_atoms lives in payload JSON), + // actor_id (our namespace proxy), created_at, and the embedding_model + // column added in migration v27. 
+ let result = guard + .query_row( + "SELECT query_text, payload, actor_id, created_at, embedding_model + FROM brain_events WHERE id = ?1", + params![event_id_str.clone()], + |row| { + let query_text: Option = row.get(0)?; + let payload_str: String = row.get(1)?; + let actor_id: Option = row.get(2)?; + let created_at: i64 = row.get(3)?; + let embedding_model: Option = row.get(4)?; + Ok(( + query_text, + payload_str, + actor_id, + created_at, + embedding_model, + )) + }, + ) + .optional() + .map_err(|e| EngineError::Internal(format!("load_brain_event query: {e}")))?; + + let (query_text_opt, payload_str, actor_id_opt, created_at, stored_model) = result + .ok_or_else(|| { + EngineError::NotFound(format!("brain_events row not found: {event_id}")) + })?; + + let query_text = query_text_opt.unwrap_or_default(); + + // B4 fix: propagate JSON parse errors instead of silently substituting + // `{}`, which would cause `top_atoms` to be empty and `regression_check` + // to report false 100% drift. + let payload: serde_json::Value = serde_json::from_str(&payload_str).map_err(|e| { + EngineError::InvalidData(format!( + "brain_events row {event_id} has unparseable payload JSON: {e}" + )) + })?; + + // top_atoms in payload is an array of UUID strings. + let top_atoms: Vec = payload + .get("top_atoms") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().and_then(|s| s.parse::().ok())) + .collect() + }) + .unwrap_or_default(); + + // B4 fix (continued): an empty top_atoms list is also invalid data — + // it would produce a trivially-true jaccard=0 without indicating real drift. + if top_atoms.is_empty() { + return Err(EngineError::InvalidData(format!( + "brain_events row {event_id} has missing or empty top_atoms in payload" + ))); + } + + // namespace from payload field (most reliable) or actor_id column. + // Note: stored payload uses key "lambda_id" (legacy; kept for DB compat). 
+ let namespace = payload + .get("lambda_id") + .or_else(|| payload.get("namespace")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or(actor_id_opt) + .unwrap_or_default(); + + Ok((query_text, top_atoms, namespace, created_at, stored_model)) + }) + .await + .map_err(|e| EngineError::Internal(format!("load_brain_event join: {e}")))? +} + +// --------------------------------------------------------------------------- +// Drift Metrics +// --------------------------------------------------------------------------- + +/// Drift metrics for the Three Observables feedback loop. +pub mod metrics { + use super::*; + + /// M1: Rolling 7-day median Jaccard stability. (engine-gated) + #[cfg(feature = "engine")] + pub async fn jaccard_stability_7d( + engine: &EmbeddedEngine, + namespace: &str, + ) -> Result { + let conn = engine.store().conn(); + // brain_events.payload stores namespace under legacy key "lambda_id" (#2536). + // The JSON key cannot be renamed without a data migration; the column name + // was already `namespace` in v25 when the table was created. + let namespace_str = namespace.to_string(); + + // Collect event IDs from the last 7 days where actor_id matches namespace. 
+ let event_ids: Vec = { + let conn = Arc::clone(&conn); + tokio::task::spawn_blocking(move || { + let c = conn.lock(); + let cutoff_us = (Utc::now() - chrono::Duration::days(7)).timestamp_micros(); + let mut stmt = c + .prepare( + "SELECT id FROM brain_events + WHERE kind = 'ComposeEvent' + AND created_at >= ?1 + AND json_extract(payload, '$.lambda_id') = ?2 + ORDER BY created_at DESC", + ) + .map_err(|e| { + EngineError::Internal(format!("jaccard_stability_7d prepare: {e}")) + })?; + + let rows = stmt + .query_map(params![cutoff_us, namespace_str], |row| { + row.get::<_, String>(0) + }) + .map_err(|e| { + EngineError::Internal(format!("jaccard_stability_7d query: {e}")) + })?; + + let ids: Vec = rows + .filter_map(|r| r.ok()) + .filter_map(|s| s.parse::().ok()) + .collect(); + Ok::, EngineError>(ids) + }) + .await + .map_err(|e| EngineError::Internal(format!("jaccard_stability_7d join: {e}")))?? + }; + + if event_ids.is_empty() { + return Ok(1.0); + } + + // Run regression_check on each event; collect Jaccard values. + let mut jaccards: Vec = Vec::new(); + for eid in event_ids { + match regression_check(engine, eid).await { + Ok(report) => jaccards.push(report.jaccard), + Err(_) => { + // Non-fatal: skip events that fail to replay (e.g., empty query). + continue; + } + } + } + + if jaccards.is_empty() { + return Ok(1.0); + } + + // Median (sort + mid point). + jaccards.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let mid = jaccards.len() / 2; + let median = if jaccards.len() % 2 == 0 { + (jaccards[mid - 1] + jaccards[mid]) / 2.0 + } else { + jaccards[mid] + }; + + Ok(median) + } + + /// M2: Rank variance for an atom across all compose events where it appeared. + /// + /// High variance = context-sensitive atom; low variance = reliably ranked. + /// Variance is computed over the 0-indexed rank positions in `top_atoms` + /// arrays stored in `brain_events.payload`. 
+ /// + /// Returns 0.0 when the atom has appeared in fewer than 2 events. + pub async fn atom_rank_variance( + conn: &Arc>, + namespace: &str, + atom_id: Uuid, + ) -> Result { + let conn = Arc::clone(conn); + let atom_id_str = atom_id.to_string(); + let namespace_str = namespace.to_string(); + + tokio::task::spawn_blocking(move || { + let c = conn.lock(); + let mut stmt = c + .prepare( + "SELECT payload FROM brain_events + WHERE kind = 'ComposeEvent' + AND json_extract(payload, '$.lambda_id') = ?1", + ) + .map_err(|e| EngineError::Internal(format!("atom_rank_variance prepare: {e}")))?; + + let rows = stmt + .query_map(params![namespace_str], |row| row.get::<_, String>(0)) + .map_err(|e| EngineError::Internal(format!("atom_rank_variance query: {e}")))?; + + let mut ranks: Vec = Vec::new(); + for row in rows.filter_map(|r| r.ok()) { + let payload: serde_json::Value = + serde_json::from_str(&row).unwrap_or(serde_json::json!({})); + if let Some(top_atoms) = payload.get("top_atoms").and_then(|v| v.as_array()) { + if let Some(pos) = top_atoms + .iter() + .position(|v| v.as_str() == Some(&atom_id_str)) + { + ranks.push(pos as f32); + } + } + } + + if ranks.len() < 2 { + return Ok(0.0_f32); + } + + let mean = ranks.iter().sum::() / ranks.len() as f32; + let variance = + ranks.iter().map(|r| (r - mean).powi(2)).sum::() / ranks.len() as f32; + Ok(variance) + }) + .await + .map_err(|e| EngineError::Internal(format!("atom_rank_variance join: {e}")))? + } + + /// M3: Count of weight_events per calendar day over the last `days` days. + /// + /// A sudden spike in adjustment rate signals a potential runaway feedback loop. + /// Returns a vec of `(NaiveDate, count)` sorted by date ascending. 
+ pub async fn adjustment_rate_per_day( + conn: &Arc>, + namespace: &str, + days: u32, + ) -> Result, EngineError> { + let conn = Arc::clone(conn); + let namespace_str = namespace.to_string(); + let days_i64 = days as i64; + + tokio::task::spawn_blocking(move || { + let c = conn.lock(); + let cutoff_us = (Utc::now() - chrono::Duration::days(days_i64)).timestamp_micros(); + + let mut stmt = c + .prepare( + // SQLite: integer division gives day bucket (micros / 86_400_000_000). + "SELECT ts / 86400000000 AS day_bucket, COUNT(*) as cnt + FROM weight_events + WHERE namespace = ?1 AND ts >= ?2 + GROUP BY day_bucket + ORDER BY day_bucket ASC", + ) + .map_err(|e| { + EngineError::Internal(format!("adjustment_rate_per_day prepare: {e}")) + })?; + + let rows = stmt + .query_map(params![namespace_str, cutoff_us], |row| { + let day_bucket: i64 = row.get(0)?; + let cnt: i64 = row.get(1)?; + Ok((day_bucket, cnt as u64)) + }) + .map_err(|e| { + EngineError::Internal(format!("adjustment_rate_per_day query: {e}")) + })?; + + let mut result = Vec::new(); + for row in rows.filter_map(|r| r.ok()) { + let (day_bucket, cnt) = row; + // day_bucket = days since Unix epoch. + // NaiveDate::from_num_days_from_ce expects days from year 1, so offset. + // Unix epoch (1970-01-01) = day 719_163 in from_num_days_from_ce. + const UNIX_EPOCH_CE_DAYS: i32 = 719_163; + let date = + NaiveDate::from_num_days_from_ce_opt(UNIX_EPOCH_CE_DAYS + day_bucket as i32) + .unwrap_or(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()); + result.push((date, cnt)); + } + + Ok(result) + }) + .await + .map_err(|e| EngineError::Internal(format!("adjustment_rate_per_day join: {e}")))? 
+ } +} + +// --------------------------------------------------------------------------- +// Unit tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use khive_db::SqliteStore; + + fn make_conn() -> Arc> { + let store = SqliteStore::memory().expect("in-memory store"); + store.conn() + } + + fn insert_weight_event( + conn: &Arc>, + namespace: &str, + atom_id: &str, + weight_after: f32, + ts_us: i64, + ) { + let c = conn.lock(); + c.execute( + "INSERT INTO weight_events (namespace, atom_id, delta, weight_after, channel, eta, ts) + VALUES (?1, ?2, 0.1, ?3, 'explicit', 0.1, ?4)", + params![namespace, atom_id, weight_after as f64, ts_us], + ) + .expect("insert weight_event"); + } + + #[tokio::test] + async fn test_weights_as_of_returns_snapshot_at_time() { + let conn = make_conn(); + let lambda = "lambda:test"; + let atom = Uuid::new_v4(); + let atom_str = atom.to_string(); + + // t0: weight 1.5 + let t0_us: i64 = 1_000_000_000; + insert_weight_event(&conn, lambda, &atom_str, 1.5, t0_us); + + // t1: weight 2.5 (later) + let t1_us: i64 = 2_000_000_000; + insert_weight_event(&conn, lambda, &atom_str, 2.5, t1_us); + + // Query at t0 + 1: should see 1.5. + let at_t0 = DateTime::from_timestamp_micros(t0_us + 1).unwrap(); + let snapshot = weights_as_of(&conn, lambda, at_t0) + .await + .expect("weights_as_of"); + let w = *snapshot.get(&atom).expect("atom must be in snapshot"); + assert!((w - 1.5).abs() < 0.01, "expected 1.5 at t0, got {w}"); + + // Query at t1 + 1: should see 2.5. 
+ let at_t1 = DateTime::from_timestamp_micros(t1_us + 1).unwrap(); + let snapshot2 = weights_as_of(&conn, lambda, at_t1) + .await + .expect("weights_as_of at t1"); + let w2 = *snapshot2 + .get(&atom) + .expect("atom must be in snapshot at t1"); + assert!((w2 - 2.5).abs() < 0.01, "expected 2.5 at t1, got {w2}"); + } + + #[tokio::test] + async fn test_weights_as_of_before_any_event_is_empty() { + let conn = make_conn(); + let lambda = "lambda:test"; + let atom = Uuid::new_v4(); + let atom_str = atom.to_string(); + + let t1_us: i64 = 2_000_000_000; + insert_weight_event(&conn, lambda, &atom_str, 2.0, t1_us); + + // Query before t1: no rows. + let before = DateTime::from_timestamp_micros(t1_us - 1).unwrap(); + let snapshot = weights_as_of(&conn, lambda, before) + .await + .expect("weights_as_of"); + assert!( + snapshot.is_empty(), + "snapshot before any event should be empty" + ); + } + + #[tokio::test] + async fn test_rank_history_returns_ordered_events() { + let conn = make_conn(); + let lambda = "lambda:rank_hist"; + let atom = Uuid::new_v4(); + let atom_str = atom.to_string(); + + insert_weight_event(&conn, lambda, &atom_str, 1.2, 1_000); + insert_weight_event(&conn, lambda, &atom_str, 1.4, 2_000); + insert_weight_event(&conn, lambda, &atom_str, 1.1, 3_000); + + let history = rank_history(&conn, lambda, atom) + .await + .expect("rank_history"); + + assert_eq!(history.len(), 3, "expected 3 history points"); + // Verify ascending timestamp order. + assert!(history[0].ts <= history[1].ts); + assert!(history[1].ts <= history[2].ts); + // Verify weights. 
+ assert!((history[0].weight_after - 1.2).abs() < 0.01); + assert!((history[1].weight_after - 1.4).abs() < 0.01); + assert!((history[2].weight_after - 1.1).abs() < 0.01); + } + + #[test] + fn test_compute_diff_report_jaccard() { + let t1 = vec![ + Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(), + Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap(), + Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap(), + ]; + let t2 = vec![ + Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap(), + Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap(), + Uuid::parse_str("00000000-0000-0000-0000-000000000004").unwrap(), + ]; + + let report = compute_diff_report(t1, t2); + + // |intersection| = 2 ({002, 003}), |union| = 4 + assert!( + (report.jaccard - 0.5).abs() < 0.01, + "jaccard={}", + report.jaccard + ); + assert_eq!(report.added.len(), 1, "one atom added"); + assert_eq!(report.dropped.len(), 1, "one atom dropped"); + } + + #[test] + fn test_compute_diff_report_identical() { + let ids: Vec = (1..=3) + .map(|i| Uuid::parse_str(&format!("00000000-0000-0000-0000-{:012}", i)).unwrap()) + .collect(); + + let report = compute_diff_report(ids.clone(), ids); + assert!((report.jaccard - 1.0).abs() < 0.001); + assert!(report.added.is_empty()); + assert!(report.dropped.is_empty()); + } + + #[test] + fn test_compute_diff_report_disjoint() { + let t1 = vec![Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap()]; + let t2 = vec![Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap()]; + + let report = compute_diff_report(t1, t2); + assert!((report.jaccard - 0.0).abs() < 0.001); + assert_eq!(report.added.len(), 1); + assert_eq!(report.dropped.len(), 1); + } + + #[tokio::test] + async fn test_adjustment_rate_per_day() { + let conn = make_conn(); + let lambda = "lambda:rate_test"; + let atom = Uuid::new_v4(); + let atom_str = atom.to_string(); + + // Insert 3 events: 2 "today" and 1 "yesterday". 
+ let now_us = Utc::now().timestamp_micros(); + let yesterday_us = now_us - 86_400_000_001_i64; // slightly over 24h ago + + insert_weight_event(&conn, lambda, &atom_str, 1.1, now_us - 100); + insert_weight_event(&conn, lambda, &atom_str, 1.2, now_us - 50); + insert_weight_event(&conn, lambda, &atom_str, 1.3, yesterday_us); + + let rates = metrics::adjustment_rate_per_day(&conn, lambda, 7) + .await + .expect("adjustment_rate_per_day"); + + // At least 2 buckets (today and yesterday within 7 days). + assert!( + rates.len() >= 1, + "expected at least 1 day bucket, got {:?}", + rates + ); + // Sum of all counts should be 3. + let total: u64 = rates.iter().map(|(_, c)| c).sum(); + assert_eq!(total, 3, "expected 3 total events"); + } +} diff --git a/crates/khive-retrieval/src/replay/mod.rs b/crates/khive-retrieval/src/replay/mod.rs new file mode 100644 index 00000000..0c90700b --- /dev/null +++ b/crates/khive-retrieval/src/replay/mod.rs @@ -0,0 +1,5 @@ +//! Temporal replay APIs for retrieval weight analysis. + +pub mod engine_replay; + +pub use engine_replay::*; diff --git a/crates/khive-retrieval/src/search_config.rs b/crates/khive-retrieval/src/search_config.rs new file mode 100644 index 00000000..3dd3de84 --- /dev/null +++ b/crates/khive-retrieval/src/search_config.rs @@ -0,0 +1,253 @@ +//! Tunable hybrid search configuration — Brain Phase 7 substrate. +//! +//! `SearchConfig` is a per-call configuration that controls how vector and +//! keyword results are retrieved and fused. It is the public API surface for +//! `recall()` and compose's internal search phase. +//! +//! # Defaults (backward-compatible) +//! +//! `SearchConfig::default()` is designed to produce **identical results** to +//! the pre-Phase-7 hardcoded search behavior: RRF with k=60, top_k=10, +//! no min_score filter. Existing callers that do not supply a `SearchConfig` +//! get the same behavior as before. +//! +//! # Presets +//! +//! | Preset | Strategy | vector_weight | +//! 
|--------|----------|---------------| +//! | `default()` / `hybrid_balanced()` | RRF (k=60) | 0.5 | +//! | `vector_only()` | VectorOnly | 1.0 | +//! | `keyword_only()` | KeywordOnly | 0.0 | +//! +//! # Usage in recall +//! +//! ```rust,ignore +//! let opts = RecallOptions { +//! query: "metal inference kernel".to_string(), +//! search: Some(SearchConfig::vector_only()), +//! ..Default::default() +//! }; +//! service.recall(opts).await?; +//! ``` + +use serde::{Deserialize, Serialize}; + +use khive_fusion::{FusionStrategy, DEFAULT_RRF_K}; + +/// Per-call configuration for hybrid search retrieval and fusion. +/// +/// Added to `RecallOptions` and `ComposeOptions` as `search: Option`. +/// When `None`, callers receive identical behavior to pre-Phase-7 code. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SearchConfig { + /// Maximum number of results to return. + /// + /// Default: 10. + #[serde(default = "default_top_k")] + pub top_k: usize, + + /// Candidate pool multiplier over `top_k`. + /// + /// The retriever fetches `top_k * candidate_pool_multiplier` candidates + /// before fusion and reranking. Higher values improve recall quality at + /// the cost of more computation. + /// + /// Default: 3. + #[serde(default = "default_multiplier")] + pub candidate_pool_multiplier: usize, + + /// Fusion strategy for combining vector and keyword result lists. + /// + /// Default: RRF with k=60. + #[serde(default = "default_fusion")] + pub fusion_strategy: FusionStrategy, + + /// Weight for vector search in weighted fusion (0.0 to 1.0). + /// + /// Only used when `fusion_strategy` is `Weighted`. Keyword weight is + /// implicitly `1.0 - vector_weight`. + /// + /// Default: 0.5 (balanced). + #[serde(default = "default_vector_weight")] + pub vector_weight: f64, + + /// Minimum score threshold. + /// + /// Results with a final score below this value are filtered out. + /// When `None`, no threshold is applied. + /// + /// Default: None. 
+ #[serde(default)] + pub min_score: Option, +} + +fn default_top_k() -> usize { + 10 +} + +fn default_multiplier() -> usize { + 3 +} + +fn default_fusion() -> FusionStrategy { + FusionStrategy::Rrf { k: DEFAULT_RRF_K } +} + +fn default_vector_weight() -> f64 { + 0.5 +} + +impl Default for SearchConfig { + fn default() -> Self { + Self { + top_k: default_top_k(), + candidate_pool_multiplier: default_multiplier(), + fusion_strategy: default_fusion(), + vector_weight: default_vector_weight(), + min_score: None, + } + } +} + +impl SearchConfig { + /// Preset: skip BM25 entirely, return only vector search results. + /// + /// Use when keyword search degrades quality (e.g., short queries, code search). + pub fn vector_only() -> Self { + Self { + top_k: default_top_k(), + candidate_pool_multiplier: 1, + fusion_strategy: FusionStrategy::VectorOnly, + vector_weight: 1.0, + min_score: None, + } + } + + /// Preset: skip HNSW entirely, return only BM25 keyword results. + /// + /// Use for exact-match retrieval (e.g., medication names, identifiers). + pub fn keyword_only() -> Self { + Self { + top_k: default_top_k(), + candidate_pool_multiplier: 1, + fusion_strategy: FusionStrategy::KeywordOnly, + vector_weight: 0.0, + min_score: None, + } + } + + /// Preset: balanced hybrid search using RRF with k=60. + /// + /// Equivalent to `SearchConfig::default()`. Combines vector and keyword + /// results with equal weight using Reciprocal Rank Fusion. + pub fn hybrid_balanced() -> Self { + Self::default() + } + + /// Set a custom top_k. + #[must_use] + pub fn with_top_k(mut self, top_k: usize) -> Self { + self.top_k = top_k; + self + } + + /// Set the candidate pool multiplier. + #[must_use] + pub fn with_candidate_pool_multiplier(mut self, multiplier: usize) -> Self { + self.candidate_pool_multiplier = multiplier; + self + } + + /// Set a minimum score filter. 
+ #[must_use] + pub fn with_min_score(mut self, min: f64) -> Self { + self.min_score = Some(min); + self + } + + /// Compute the candidate pool size from `top_k * candidate_pool_multiplier`. + pub fn candidate_pool_size(&self) -> usize { + self.top_k * self.candidate_pool_multiplier.max(1) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let cfg = SearchConfig::default(); + assert_eq!(cfg.top_k, 10); + assert_eq!(cfg.candidate_pool_multiplier, 3); + assert!((cfg.vector_weight - 0.5).abs() < f64::EPSILON); + assert!(cfg.min_score.is_none()); + assert!(matches!(cfg.fusion_strategy, FusionStrategy::Rrf { k: 60 })); + } + + #[test] + fn test_vector_only_preset() { + let cfg = SearchConfig::vector_only(); + assert!(matches!(cfg.fusion_strategy, FusionStrategy::VectorOnly)); + assert!((cfg.vector_weight - 1.0).abs() < f64::EPSILON); + assert_eq!(cfg.candidate_pool_multiplier, 1); + } + + #[test] + fn test_keyword_only_preset() { + let cfg = SearchConfig::keyword_only(); + assert!(matches!(cfg.fusion_strategy, FusionStrategy::KeywordOnly)); + assert!((cfg.vector_weight - 0.0).abs() < f64::EPSILON); + } + + #[test] + fn test_hybrid_balanced_is_default() { + let balanced = SearchConfig::hybrid_balanced(); + let default = SearchConfig::default(); + assert_eq!(balanced.top_k, default.top_k); + assert_eq!( + balanced.candidate_pool_multiplier, + default.candidate_pool_multiplier + ); + assert!((balanced.vector_weight - default.vector_weight).abs() < f64::EPSILON); + } + + #[test] + fn test_candidate_pool_size() { + let cfg = SearchConfig::default(); + assert_eq!(cfg.candidate_pool_size(), 30); // 10 * 3 + + let cfg = SearchConfig::vector_only().with_top_k(5); + assert_eq!(cfg.candidate_pool_size(), 5); // 5 * 1 + } + + #[test] + fn test_builder_methods() { + let cfg = SearchConfig::default() + .with_top_k(20) + .with_candidate_pool_multiplier(5) + .with_min_score(0.3); + assert_eq!(cfg.top_k, 20); + 
assert_eq!(cfg.candidate_pool_multiplier, 5); + assert_eq!(cfg.min_score, Some(0.3)); + assert_eq!(cfg.candidate_pool_size(), 100); + } + + #[test] + fn test_serde_roundtrip() { + let cfg = SearchConfig { + top_k: 15, + candidate_pool_multiplier: 4, + fusion_strategy: FusionStrategy::Weighted { + weights: vec![0.7, 0.3], + }, + vector_weight: 0.7, + min_score: Some(0.1), + }; + let json = serde_json::to_string(&cfg).unwrap(); + let back: SearchConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(back.top_k, 15); + assert_eq!(back.candidate_pool_multiplier, 4); + assert_eq!(back.min_score, Some(0.1)); + } +} diff --git a/crates/khive-retrieval/src/timeout.rs b/crates/khive-retrieval/src/timeout.rs new file mode 100644 index 00000000..1f3ba20b --- /dev/null +++ b/crates/khive-retrieval/src/timeout.rs @@ -0,0 +1,435 @@ +//! Timeout and cancellation support for search operations. +//! +//! Provides utilities for wrapping search futures with timeout and cancellation +//! semantics. Uses `tokio::time::timeout` for deadline enforcement and +//! `tokio_util::sync::CancellationToken` for cooperative cancellation. +//! +//! # Design +//! +//! Timeout and cancellation are applied at the search entry points (hybrid search, +//! graph traversal) rather than at every internal function call. This keeps the +//! internal algorithms clean while providing operational safety at the boundaries. +//! +//! # Usage +//! +//! ```rust,ignore +//! use std::time::Duration; +//! use khive_retrieval::timeout::{search_with_timeout, search_with_cancellation}; +//! use tokio_util::sync::CancellationToken; +//! +//! // Timeout: cancel if search takes longer than 5 seconds +//! let results = search_with_timeout( +//! searcher.hybrid_search(&query, &config), +//! Duration::from_secs(5), +//! ).await?; +//! +//! // Cancellation: cancel via token (e.g., from a request handler) +//! let token = CancellationToken::new(); +//! let results = search_with_cancellation( +//! 
searcher.hybrid_search(&query, &config), +//! token.clone(), +//! ).await?; +//! +//! // From another task: +//! token.cancel(); +//! ``` +//! +//! See also: [`HybridConfig::timeout`] for declarative timeout configuration. + +use std::future::Future; +use std::time::Duration; + +use tokio_util::sync::CancellationToken; + +use crate::error::{Result, RetrievalError}; + +/// Execute a search future with a timeout. +/// +/// Wraps the given future with `tokio::time::timeout`. If the future does not +/// complete within the specified duration, returns [`RetrievalError::QueryTimeout`]. +/// +/// # Arguments +/// +/// * `future` - The search operation to execute +/// * `duration` - Maximum time to wait for completion +/// +/// # Returns +/// +/// The search result if completed within the timeout, or `QueryTimeout` error. +/// +/// # Example +/// +/// ```rust,ignore +/// use std::time::Duration; +/// use khive_retrieval::timeout::search_with_timeout; +/// +/// let results = search_with_timeout( +/// searcher.hybrid_search(&query, &config), +/// Duration::from_secs(5), +/// ).await?; +/// ``` +pub async fn search_with_timeout(future: F, duration: Duration) -> Result +where + F: Future>, +{ + match tokio::time::timeout(duration, future).await { + Ok(result) => result, + Err(_elapsed) => Err(RetrievalError::QueryTimeout { + elapsed_ms: duration.as_millis() as u64, + }), + } +} + +/// Execute a search future with an optional timeout. +/// +/// If `timeout` is `Some`, wraps the future with [`search_with_timeout`]. +/// If `None`, executes the future directly without timeout. +/// +/// This is a convenience function for use with [`HybridConfig::timeout`]. +/// +/// # Arguments +/// +/// * `future` - The search operation to execute +/// * `timeout` - Optional maximum time to wait +/// +/// # Returns +/// +/// The search result, or `QueryTimeout` if the timeout elapsed. 
+pub async fn search_with_optional_timeout(future: F, timeout: Option) -> Result +where + F: Future>, +{ + match timeout { + Some(duration) => search_with_timeout(future, duration).await, + None => future.await, + } +} + +/// Execute a search future with a cancellation token. +/// +/// Uses `tokio::select!` to race the search future against the cancellation token. +/// If the token is cancelled before the search completes, returns +/// [`RetrievalError::QueryCancelled`]. +/// +/// # Arguments +/// +/// * `future` - The search operation to execute +/// * `token` - Cancellation token to observe +/// +/// # Returns +/// +/// The search result if completed before cancellation, or `QueryCancelled` error. +/// +/// # Example +/// +/// ```rust,ignore +/// use tokio_util::sync::CancellationToken; +/// use khive_retrieval::timeout::search_with_cancellation; +/// +/// let token = CancellationToken::new(); +/// let token_clone = token.clone(); +/// +/// // Spawn a task that cancels after 1 second +/// tokio::spawn(async move { +/// tokio::time::sleep(Duration::from_secs(1)).await; +/// token_clone.cancel(); +/// }); +/// +/// let results = search_with_cancellation( +/// searcher.hybrid_search(&query, &config), +/// token, +/// ).await?; +/// ``` +pub async fn search_with_cancellation(future: F, token: CancellationToken) -> Result +where + F: Future>, +{ + tokio::select! { + result = future => result, + _ = token.cancelled() => Err(RetrievalError::QueryCancelled), + } +} + +/// Execute a search future with both timeout and optional cancellation. +/// +/// Combines timeout and cancellation into a single wrapper. 
The search will +/// be terminated if either: +/// - The timeout duration elapses (`QueryTimeout`) +/// - The cancellation token is triggered (`QueryCancelled`) +/// - The search completes normally +/// +/// # Arguments +/// +/// * `future` - The search operation to execute +/// * `timeout` - Optional maximum time to wait +/// * `cancel` - Optional cancellation token to observe +/// +/// # Returns +/// +/// The search result, or an appropriate error if timed out or cancelled. +pub async fn search_with_deadline( + future: F, + timeout: Option, + cancel: Option, +) -> Result +where + F: Future>, +{ + match (timeout, cancel) { + (Some(duration), Some(token)) => { + tokio::select! { + result = tokio::time::timeout(duration, future) => { + match result { + Ok(inner) => inner, + Err(_elapsed) => Err(RetrievalError::QueryTimeout { + elapsed_ms: duration.as_millis() as u64, + }), + } + } + _ = token.cancelled() => Err(RetrievalError::QueryCancelled), + } + } + (Some(duration), None) => search_with_timeout(future, duration).await, + (None, Some(token)) => search_with_cancellation(future, token).await, + (None, None) => future.await, + } +} + +/// Serde support for `Option` as milliseconds. +/// +/// Serializes `Duration` as `u64` milliseconds for JSON compatibility. +pub(crate) mod serde_opt_duration { + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + use std::time::Duration; + + /// Intermediate representation for serde. + #[derive(Serialize, Deserialize)] + struct DurationMs(u64); + + /// Serialize `Option` as optional milliseconds. + pub fn serialize(value: &Option, serializer: S) -> Result + where + S: Serializer, + { + match value { + Some(d) => DurationMs(d.as_millis() as u64).serialize(serializer), + None => serializer.serialize_none(), + } + } + + /// Deserialize `Option` from optional milliseconds. 
+ pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let opt: Option = Option::deserialize(deserializer)?; + Ok(opt.map(|ms| Duration::from_millis(ms.0))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[tokio::test] + async fn test_search_with_timeout_completes() { + // A future that completes immediately + let future = async { Ok::<_, RetrievalError>(vec![1, 2, 3]) }; + let result = search_with_timeout(future, Duration::from_secs(5)).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), vec![1, 2, 3]); + } + + #[tokio::test] + async fn test_search_with_timeout_expires() { + // A future that takes too long + let future = async { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok::<_, RetrievalError>(vec![1, 2, 3]) + }; + let result = search_with_timeout(future, Duration::from_millis(50)).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, RetrievalError::QueryTimeout { .. 
})); + assert!(err.is_transient()); + } + + #[tokio::test] + async fn test_search_with_timeout_propagates_error() { + // A future that fails with a different error + let future = async { Err::, _>(RetrievalError::invalid_query("bad query")) }; + let result = search_with_timeout(future, Duration::from_secs(5)).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + RetrievalError::InvalidQuery(_) + )); + } + + #[tokio::test] + async fn test_search_with_optional_timeout_none() { + // No timeout means direct execution + let future = async { Ok::<_, RetrievalError>(42) }; + let result = search_with_optional_timeout(future, None).await; + assert_eq!(result.unwrap(), 42); + } + + #[tokio::test] + async fn test_search_with_optional_timeout_some() { + // With timeout, same as search_with_timeout + let future = async { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok::<_, RetrievalError>(42) + }; + let result = search_with_optional_timeout(future, Some(Duration::from_millis(50))).await; + assert!(matches!( + result.unwrap_err(), + RetrievalError::QueryTimeout { .. 
} + )); + } + + #[tokio::test] + async fn test_search_with_cancellation_completes() { + let token = CancellationToken::new(); + let future = async { Ok::<_, RetrievalError>(vec![1, 2, 3]) }; + let result = search_with_cancellation(future, token).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_search_with_cancellation_cancelled() { + let token = CancellationToken::new(); + let token_clone = token.clone(); + + // Cancel immediately + token_clone.cancel(); + + let future = async { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok::<_, RetrievalError>(vec![1, 2, 3]) + }; + let result = search_with_cancellation(future, token).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, RetrievalError::QueryCancelled)); + assert!(err.is_transient()); + } + + #[tokio::test] + async fn test_search_with_cancellation_delayed() { + let token = CancellationToken::new(); + let token_clone = token.clone(); + + // Cancel after a short delay + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(20)).await; + token_clone.cancel(); + }); + + let future = async { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok::<_, RetrievalError>(vec![1, 2, 3]) + }; + let result = search_with_cancellation(future, token).await; + assert!(matches!( + result.unwrap_err(), + RetrievalError::QueryCancelled + )); + } + + #[tokio::test] + async fn test_search_with_deadline_timeout_and_cancel() { + let token = CancellationToken::new(); + + // Timeout fires first (50ms vs 10s sleep) + let future = async { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok::<_, RetrievalError>(42) + }; + let result = + search_with_deadline(future, Some(Duration::from_millis(50)), Some(token)).await; + assert!(matches!( + result.unwrap_err(), + RetrievalError::QueryTimeout { .. 
} + )); + } + + #[tokio::test] + async fn test_search_with_deadline_cancel_fires_first() { + let token = CancellationToken::new(); + let token_clone = token.clone(); + + // Cancel immediately, timeout is long + token_clone.cancel(); + + let future = async { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok::<_, RetrievalError>(42) + }; + let result = search_with_deadline(future, Some(Duration::from_secs(60)), Some(token)).await; + assert!(matches!( + result.unwrap_err(), + RetrievalError::QueryCancelled + )); + } + + #[tokio::test] + async fn test_search_with_deadline_neither() { + // No timeout, no cancellation: direct execution + let future = async { Ok::<_, RetrievalError>(42) }; + let result = search_with_deadline(future, None, None).await; + assert_eq!(result.unwrap(), 42); + } + + #[tokio::test] + async fn test_timeout_error_display() { + let err = RetrievalError::query_timeout(5000); + assert_eq!(err.to_string(), "query timed out after 5000ms"); + } + + #[tokio::test] + async fn test_cancelled_error_display() { + let err = RetrievalError::query_cancelled(); + assert_eq!(err.to_string(), "query cancelled"); + } + + #[tokio::test] + async fn test_timeout_error_is_transient() { + assert!(RetrievalError::query_timeout(100).is_transient()); + assert!(RetrievalError::query_cancelled().is_transient()); + assert!(!RetrievalError::query_timeout(100).is_permanent()); + assert!(!RetrievalError::query_cancelled().is_permanent()); + } + + #[test] + fn test_serde_opt_duration_roundtrip() { + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct TestConfig { + #[serde( + default, + skip_serializing_if = "Option::is_none", + with = "super::serde_opt_duration" + )] + timeout: Option, + } + + // With timeout + let config = TestConfig { + timeout: Some(Duration::from_millis(5000)), + }; + let json = serde_json::to_string(&config).unwrap(); + assert!(json.contains("5000")); + let restored: TestConfig = 
serde_json::from_str(&json).unwrap(); + assert_eq!(restored.timeout, Some(Duration::from_millis(5000))); + + // Without timeout + let config = TestConfig { timeout: None }; + let json = serde_json::to_string(&config).unwrap(); + assert!(!json.contains("timeout")); + let restored: TestConfig = serde_json::from_str("{}").unwrap(); + assert_eq!(restored.timeout, None); + } +} diff --git a/crates/khive-retrieval/src/weights/engine_weights.rs b/crates/khive-retrieval/src/weights/engine_weights.rs new file mode 100644 index 00000000..7530767c --- /dev/null +++ b/crates/khive-retrieval/src/weights/engine_weights.rs @@ -0,0 +1,561 @@ +//! Unified weight store — Three Observables Feedback Loop (Phase 2.A). +//! +//! This module provides the core EMA-update + audit-log primitives backing +//! all three feedback channels: +//! +//! | Channel | η | Signal source | +//! |-------------|-------|--------------------------------------| +//! | Ambient | 0.003 | Every recall / compose operation | +//! | Explicit | 0.10 | note.create quality score | +//! | GroundTruth | 0.50 | Atlas eval / CLI manual trigger | +//! +//! # Weight semantics +//! +//! Weights live in `atom_weights(namespace, atom_id)` and are bounded to +//! `[WEIGHT_FLOOR, WEIGHT_CEIL]` = [0.1, 5.0]. Missing rows are treated as +//! implicit 1.0 by callers; this module never inserts rows on first read. +//! +//! # EMA formula +//! +//! ```text +//! new_weight = clamp(old_weight * (1 - η) + delta, WEIGHT_FLOOR, WEIGHT_CEIL) +//! ``` +//! +//! `old_weight` defaults to 1.0 when no row exists yet. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use parking_lot::Mutex; +use rusqlite::{params, Connection, OptionalExtension}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::persist::PersistError; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// Lower bound for atom weights — prevents a weight from reaching zero. +pub const WEIGHT_FLOOR: f32 = 0.1; + +/// Upper bound for atom weights — prevents runaway boosting. +pub const WEIGHT_CEIL: f32 = 5.0; + +// --------------------------------------------------------------------------- +// WeightChannel +// --------------------------------------------------------------------------- + +/// Three feedback channels each with a distinct learning rate η. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WeightChannel { + /// η = 0.003 — every recall / compose invocation. + Ambient, + /// η = 0.10 — voluntary quality signal via note.create / note.correct. + Explicit, + /// η = 0.50 — Atlas eval or manual CLI trigger. + GroundTruth, +} + +impl WeightChannel { + /// Learning rate η for this channel. + pub fn eta(self) -> f32 { + match self { + Self::Ambient => 0.003, + Self::Explicit => 0.10, + Self::GroundTruth => 0.50, + } + } + + /// Canonical snake_case string stored in `weight_events.channel`. + pub fn as_str(self) -> &'static str { + match self { + Self::Ambient => "ambient", + Self::Explicit => "explicit", + Self::GroundTruth => "ground_truth", + } + } +} + +// --------------------------------------------------------------------------- +// apply_weight_delta +// --------------------------------------------------------------------------- + +/// Apply an EMA weight update to one `(namespace, atom_id)` pair and append +/// an audit row to `weight_events`. 
+/// +/// # Algorithm +/// +/// ```text +/// old = atom_weights[namespace, atom_id].weight (default 1.0 if missing) +/// new = clamp(old * (1 − η) + delta, WEIGHT_FLOOR, WEIGHT_CEIL) +/// ``` +/// +/// Both the `atom_weights` upsert and the `weight_events` insert execute inside +/// a single `BEGIN IMMEDIATE` transaction so they are atomically consistent. +/// +/// # Returns +/// +/// `(new_weight, weight_event_row_id)` on success. +pub async fn apply_weight_delta( + conn: &Arc>, + namespace: &str, + atom_id: Uuid, + delta: f32, + channel: WeightChannel, + event_id: Option, + context_id: Option<&str>, +) -> Result<(f32, i64), PersistError> { + apply_weight_delta_with_eta( + conn, + namespace, + atom_id, + delta, + channel, + channel.eta(), + event_id, + context_id, + ) + .await +} + +/// Variant of [`apply_weight_delta`] that accepts a runtime-overridden `eta`. +/// +/// Use this when the caller loads η from runtime config (e.g., atlas's +/// `knowledge.toml` override of Channel C's default 0.50). Same algorithm +/// and transactional guarantees as [`apply_weight_delta`]. 
+pub async fn apply_weight_delta_with_eta( + conn: &Arc>, + namespace: &str, + atom_id: Uuid, + delta: f32, + channel: WeightChannel, + eta: f32, + event_id: Option, + context_id: Option<&str>, +) -> Result<(f32, i64), PersistError> { + if namespace.is_empty() { + tracing::warn!( + atom_id = %atom_id, + channel = %channel.as_str(), + "apply_weight_delta called with empty namespace — rejecting to avoid dead-namespace pollution" + ); + return Err(PersistError::Validation( + "namespace must not be empty".to_string(), + )); + } + if !(0.0..=1.0).contains(&eta) { + return Err(PersistError::Validation(format!( + "eta must be in [0.0, 1.0], got {eta}" + ))); + } + let conn = Arc::clone(conn); + let namespace_str = namespace.to_string(); + let atom_id_str = atom_id.to_string(); + let channel_str = channel.as_str(); + let event_id_str = event_id.map(|u| u.to_string()); + let context_id = context_id.map(|s| s.to_string()); + let now_us = chrono::Utc::now().timestamp_micros(); + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + + let tx = + rusqlite::Transaction::new_unchecked(&conn, rusqlite::TransactionBehavior::Immediate)?; + + // Read current weight (default 1.0 if row absent). + let old_weight: f32 = tx + .query_row( + "SELECT weight FROM atom_weights WHERE namespace = ?1 AND atom_id = ?2", + params![namespace_str, atom_id_str], + |row| row.get::<_, f64>(0), + ) + .optional() + .map_err(PersistError::from)? + .unwrap_or(1.0_f64) as f32; + + // EMA update + clamp. + let new_weight = (old_weight * (1.0 - eta) + delta).clamp(WEIGHT_FLOOR, WEIGHT_CEIL); + + // Upsert atom_weights — increment version on each write. 
+ tx.execute( + "INSERT INTO atom_weights (namespace, atom_id, weight, updated_at, version) + VALUES (?1, ?2, ?3, ?4, 1) + ON CONFLICT(namespace, atom_id) DO UPDATE SET + weight = excluded.weight, + updated_at = excluded.updated_at, + version = version + 1", + params![namespace_str, atom_id_str, new_weight as f64, now_us], + )?; + + // Append weight_events audit row. + tx.execute( + "INSERT INTO weight_events + (namespace, atom_id, delta, weight_after, channel, eta, event_id, context_id, ts) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + params![ + namespace_str, + atom_id_str, + delta as f64, + new_weight as f64, + channel_str, + eta as f64, + event_id_str, + context_id, + now_us, + ], + )?; + + let row_id = tx.last_insert_rowid(); + tx.commit()?; + + Ok((new_weight, row_id)) + }) + .await? +} + +// --------------------------------------------------------------------------- +// batch_load_weights +// --------------------------------------------------------------------------- + +/// Batch-load current weights for a slice of atom IDs under one lambda. +/// +/// Only rows that exist in `atom_weights` are returned. Missing atoms are +/// **not** inserted; callers should treat absent entries as implicit 1.0. +/// +/// Uses a single SQL query with a dynamic `IN (...)` clause. The batch is +/// chunked when `atom_ids` exceeds the SQLite 999-bind-param ceiling. +pub async fn batch_load_weights( + conn: &Arc>, + namespace: &str, + atom_ids: &[Uuid], +) -> Result, PersistError> { + if atom_ids.is_empty() { + return Ok(HashMap::new()); + } + + let conn = Arc::clone(conn); + let namespace_str = namespace.to_string(); + // Convert UUIDs to strings once, then move into the blocking closure. + let id_strs: Vec = atom_ids.iter().map(|u| u.to_string()).collect(); + + tokio::task::spawn_blocking(move || { + let conn = conn.lock(); + let mut result = HashMap::with_capacity(id_strs.len()); + + // Chunk to stay within the SQLite 999-parameter limit. 
+ // Each chunk uses 1 param (namespace) + N params (atom_ids) = N+1 total. + const CHUNK_SIZE: usize = 998; + + for chunk in id_strs.chunks(CHUNK_SIZE) { + let placeholders = chunk + .iter() + .enumerate() + .map(|(i, _)| format!("?{}", i + 2)) + .collect::>() + .join(", "); + + let sql = format!( + "SELECT atom_id, weight FROM atom_weights \ + WHERE namespace = ?1 AND atom_id IN ({placeholders})" + ); + + let mut stmt = conn.prepare(&sql).map_err(PersistError::from)?; + + let mut param_values: Vec = Vec::with_capacity(chunk.len() + 1); + param_values.push(rusqlite::types::Value::Text(namespace_str.clone())); + for s in chunk { + param_values.push(rusqlite::types::Value::Text(s.clone())); + } + + let mut rows = stmt + .query(rusqlite::params_from_iter(param_values)) + .map_err(PersistError::from)?; + + while let Some(row) = rows.next().map_err(PersistError::from)? { + let aid: String = row.get(0).map_err(PersistError::from)?; + let w: f64 = row.get(1).map_err(PersistError::from)?; + if let Ok(uuid) = aid.parse::() { + // Clamp on read — symmetric with write-side invariant. Protects compose + // from weight=0 rows introduced by manual SQL or future schema drift. + let clamped = (w as f32).clamp(WEIGHT_FLOOR, WEIGHT_CEIL); + result.insert(uuid, clamped); + } + } + } + + Ok(result) + }) + .await? +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use khive_db::SqliteStore; + use std::sync::Arc; + + fn make_conn() -> Arc> { + // Open an in-memory SQLite DB and run migrations so atom_weights and + // weight_events tables exist. 
+ let store = SqliteStore::memory().expect("in-memory store"); + store.conn() + } + + // ------------------------------------------------------------------------- + // Test 1 — ambient channel drives weight above 1.0 over 5 ticks + // ------------------------------------------------------------------------- + #[tokio::test] + async fn test_apply_weight_delta_ambient_channel() { + let conn = make_conn(); + let lambda = "lambda:test"; + let atom = Uuid::new_v4(); + let delta = 0.01_f32; // positive ambient nudge + + let mut last_weight = 1.0_f32; + for _ in 0..5 { + let (w, _row_id) = apply_weight_delta( + &conn, + lambda, + atom, + delta, + WeightChannel::Ambient, + None, + None, + ) + .await + .expect("apply_weight_delta should succeed"); + last_weight = w; + } + + assert!( + last_weight > 1.0, + "weight should rise above 1.0 with positive delta, got {last_weight}" + ); + assert!( + last_weight < WEIGHT_CEIL, + "weight should not reach ceiling after 5 ticks" + ); + + // Verify 5 audit rows were written. + let map = batch_load_weights(&conn, lambda, &[atom]) + .await + .expect("batch_load_weights"); + assert!(map.contains_key(&atom), "weight row should exist"); + + // Count weight_events rows directly. + let count: i64 = { + let c = conn.lock(); + c.query_row( + "SELECT COUNT(*) FROM weight_events WHERE namespace = ?1 AND atom_id = ?2 AND channel = 'ambient'", + params![lambda, atom.to_string()], + |r| r.get(0), + ) + .unwrap() + }; + assert_eq!(count, 5, "expected 5 ambient weight_events rows"); + } + + // ------------------------------------------------------------------------- + // Test 2 — ceiling clamp + // ------------------------------------------------------------------------- + #[tokio::test] + async fn test_apply_weight_delta_clamps_at_ceiling() { + let conn = make_conn(); + let lambda = "lambda:test"; + let atom = Uuid::new_v4(); + + // Repeatedly push large positive deltas. 
+ for _ in 0..100 { + apply_weight_delta( + &conn, + lambda, + atom, + 5.0, // huge delta + WeightChannel::GroundTruth, + None, + None, + ) + .await + .expect("apply_weight_delta should succeed"); + } + + let map = batch_load_weights(&conn, lambda, &[atom]) + .await + .expect("batch_load_weights"); + let w = *map.get(&atom).expect("atom weight must exist"); + assert_eq!( + w, WEIGHT_CEIL, + "weight should be clamped at WEIGHT_CEIL={WEIGHT_CEIL}, got {w}" + ); + } + + // ------------------------------------------------------------------------- + // Test 3 — namespace isolation + // ------------------------------------------------------------------------- + #[tokio::test] + async fn test_apply_weight_delta_namespace_isolation() { + let conn = make_conn(); + let lambda_a = "lambda:a"; + let lambda_b = "lambda:b"; + let atom = Uuid::new_v4(); + + apply_weight_delta( + &conn, + lambda_a, + atom, + 0.5, + WeightChannel::Explicit, + None, + None, + ) + .await + .expect("apply for lambda:a"); + + // lambda:b should see nothing. + let map_b = batch_load_weights(&conn, lambda_b, &[atom]) + .await + .expect("batch_load for lambda:b"); + assert!( + !map_b.contains_key(&atom), + "lambda:b should not see lambda:a's weight" + ); + + // lambda:a should see the written weight. + let map_a = batch_load_weights(&conn, lambda_a, &[atom]) + .await + .expect("batch_load for lambda:a"); + assert!( + map_a.contains_key(&atom), + "lambda:a should see its own weight" + ); + } + + // ------------------------------------------------------------------------- + // Test 4 — missing atoms are absent (not 1.0 rows) from batch_load result + // ------------------------------------------------------------------------- + #[tokio::test] + async fn test_batch_load_weights_missing_atoms_default() { + let conn = make_conn(); + let lambda = "lambda:test"; + let atom_a = Uuid::new_v4(); + let atom_b = Uuid::new_v4(); + let atom_c = Uuid::new_v4(); + + // Write only atom_a. 
+ apply_weight_delta( + &conn, + lambda, + atom_a, + 0.3, + WeightChannel::Explicit, + None, + None, + ) + .await + .expect("apply for atom_a"); + + let map = batch_load_weights(&conn, lambda, &[atom_a, atom_b, atom_c]) + .await + .expect("batch_load"); + + assert!(map.contains_key(&atom_a), "atom_a should be present"); + assert!( + !map.contains_key(&atom_b), + "atom_b should be absent (caller treats as 1.0)" + ); + assert!( + !map.contains_key(&atom_c), + "atom_c should be absent (caller treats as 1.0)" + ); + + // atom_a weight should be non-default (was boosted). + let w_a = *map.get(&atom_a).unwrap(); + assert!(w_a != 1.0_f32, "atom_a weight should differ from default"); + } + + // ------------------------------------------------------------------------- + // Test 5 — negative delta writes a weight_events row (B4 regression guard) + // ------------------------------------------------------------------------- + /// Verifies that apply_weight_delta writes to weight_events even when the + /// delta is negative, guarding against the B4 Channel-A decay skip bug. + #[tokio::test] + async fn test_channel_a_applies_on_negative_delta() { + let conn = make_conn(); + let lambda = "lambda:test"; + let atom = Uuid::new_v4(); + + // First boost the atom so it is above floor. + apply_weight_delta( + &conn, + lambda, + atom, + 0.5, + WeightChannel::GroundTruth, + None, + None, + ) + .await + .expect("initial boost"); + + // Now apply a negative ambient delta (simulates decay). + let (w_after, _) = apply_weight_delta( + &conn, + lambda, + atom, + -0.1, + WeightChannel::Ambient, + None, + Some("decay_test"), + ) + .await + .expect("negative delta must succeed"); + + // Weight must be below the post-boost value (started ~1.25, decay should lower it). 
+ assert!( + w_after < 1.5, + "weight should have decayed below post-boost value, got {w_after}" + ); + assert!( + w_after >= WEIGHT_FLOOR, + "weight must not go below WEIGHT_FLOOR, got {w_after}" + ); + + // Confirm a weight_events row was written for the negative delta. + let count: i64 = { + let c = conn.lock(); + c.query_row( + "SELECT COUNT(*) FROM weight_events \ + WHERE namespace = ?1 AND atom_id = ?2 AND delta < 0", + params![lambda, atom.to_string()], + |r| r.get(0), + ) + .unwrap() + }; + assert_eq!( + count, 1, + "expected 1 weight_event row with negative delta, got {count}" + ); + } + + // ------------------------------------------------------------------------- + // Test 6 — empty namespace returns Validation error (F2 guard) + // ------------------------------------------------------------------------- + #[tokio::test] + async fn test_apply_weight_delta_rejects_empty_namespace() { + let conn = make_conn(); + let atom = Uuid::new_v4(); + let result = + apply_weight_delta(&conn, "", atom, 0.1, WeightChannel::Ambient, None, None).await; + assert!( + matches!(result, Err(PersistError::Validation(_))), + "expected Validation error for empty namespace, got {result:?}" + ); + } +} diff --git a/crates/khive-retrieval/src/weights/mod.rs b/crates/khive-retrieval/src/weights/mod.rs new file mode 100644 index 00000000..6e3b834d --- /dev/null +++ b/crates/khive-retrieval/src/weights/mod.rs @@ -0,0 +1,5 @@ +//! Unified weight store for the Three Observables Feedback Loop. + +pub mod engine_weights; + +pub use engine_weights::*;