diff --git a/crates/Cargo.toml b/crates/Cargo.toml index a6c09a05..ed4a0c18 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -11,6 +11,7 @@ members = [ "khive-gate-rego", "khive-fusion", "khive-bm25", + "khive-hnsw", "khive-runtime", "khive-request", "khive-pack-kg", diff --git a/crates/khive-bm25/src/index/mod.rs b/crates/khive-bm25/src/index/mod.rs index 9f16a5c3..ca8c77ff 100644 --- a/crates/khive-bm25/src/index/mod.rs +++ b/crates/khive-bm25/src/index/mod.rs @@ -149,7 +149,9 @@ fn default_postings_epoch() -> u64 { // TODO(port): was generated by `khive_types::transparent_string_newtype!` macro // which does not yet exist in khive-types. Expanded here manually until the macro // lands in khive-types and is re-adopted. -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] +#[derive( + Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize, +)] #[serde(transparent)] pub struct DocumentId(String); @@ -639,7 +641,9 @@ impl Bm25Index { return id; } let id = self.next_internal_id; - self.next_internal_id = self.next_internal_id.checked_add(1) + self.next_internal_id = self + .next_internal_id + .checked_add(1) .expect("internal document ID space exhausted (u32::MAX)"); self.id_to_internal.insert(doc_id.clone(), id); if id as usize >= self.internal_to_id.len() { diff --git a/crates/khive-bm25/src/index/search.rs b/crates/khive-bm25/src/index/search.rs index 510b0c29..a89a9a3c 100644 --- a/crates/khive-bm25/src/index/search.rs +++ b/crates/khive-bm25/src/index/search.rs @@ -1450,11 +1450,11 @@ mod tests_simd_scoring { let tfs_4: [u8; 4] = [0, 0, 0, 0]; let dls_4: [f32; 4] = [10.0, 20.0, 5.0, 1.0]; let result = score_batch_4(&tfs_4, &dls_4, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC); - for i in 0..4 { + for val in &result { assert!( - result[i].abs() < 1e-10, + val.abs() < 1e-10, "tf=0 should produce ~0 score, got {}", - result[i] + val ); } @@ -1467,11 +1467,11 @@ mod 
tests_simd_scoring { let result = unsafe { score_batch_avx2(&tfs_8, &dls_8, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC) }; - for i in 0..8 { + for val in &result { assert!( - result[i].abs() < 1e-10, + val.abs() < 1e-10, "avx2 tf=0 should produce ~0 score, got {}", - result[i] + val ); } } diff --git a/crates/khive-bm25/src/metrics.rs b/crates/khive-bm25/src/metrics.rs index 731a334b..e45fef21 100644 --- a/crates/khive-bm25/src/metrics.rs +++ b/crates/khive-bm25/src/metrics.rs @@ -63,13 +63,20 @@ pub struct RecordingSink { } #[cfg(test)] -impl RecordingSink { - /// Create an empty recording sink. - pub fn new() -> Self { +impl Default for RecordingSink { + fn default() -> Self { Self { events: Mutex::new(Vec::new()), } } +} + +#[cfg(test)] +impl RecordingSink { + /// Create an empty recording sink. + pub fn new() -> Self { + Self::default() + } /// Return a snapshot of all recorded events. pub fn events(&self) -> Vec { diff --git a/crates/khive-bm25/src/tests.rs b/crates/khive-bm25/src/tests.rs index efeb7e7a..dfc28251 100644 --- a/crates/khive-bm25/src/tests.rs +++ b/crates/khive-bm25/src/tests.rs @@ -813,8 +813,8 @@ mod golden_tests { /// Memory budget enforcement tests for BM25. #[cfg(test)] mod memory_budget_tests { - use crate::{Bm25Config, Bm25Index}; use crate::error::{ErrorKind, RetrievalError}; + use crate::{Bm25Config, Bm25Index}; #[test] fn test_no_budget_allows_unlimited_indexing() { @@ -844,9 +844,8 @@ mod memory_budget_tests { format!("doc{i}"), &format!("some content words for document number {i} with extra text"), ); - if result.is_err() { + if let Err(err) = result { rejected = true; - let err = result.unwrap_err(); assert!( matches!(err, RetrievalError::BudgetExceeded { .. 
}), "Expected BudgetExceeded, got: {err:?}" @@ -1019,8 +1018,8 @@ mod memory_budget_tests { #[cfg(test)] mod metrics_tests { - use crate::{Bm25Config, Bm25Index}; use crate::metrics::{names, MetricValue, RecordingSink}; + use crate::{Bm25Config, Bm25Index}; use std::sync::Arc; #[test] diff --git a/crates/khive-bm25/src/tokenizer.rs b/crates/khive-bm25/src/tokenizer.rs index 653bd446..e06160d4 100644 --- a/crates/khive-bm25/src/tokenizer.rs +++ b/crates/khive-bm25/src/tokenizer.rs @@ -247,8 +247,10 @@ mod tests { #[test] fn test_stop_words_disabled() { - let mut tokenizer = SimpleTokenizer::default(); - tokenizer.filter_stop_words = false; + let tokenizer = SimpleTokenizer { + filter_stop_words: false, + ..Default::default() + }; let tokens = tokenizer.tokenize("The Quick, Brown FOX!"); assert_eq!(tokens, vec!["the", "quick", "brown", "fox"]); } diff --git a/crates/khive-fusion/src/tests.rs b/crates/khive-fusion/src/tests.rs index 8fda09b8..beec545c 100644 --- a/crates/khive-fusion/src/tests.rs +++ b/crates/khive-fusion/src/tests.rs @@ -221,19 +221,11 @@ mod property_tests { /// Verifies the `present_gt_absent` property from RRF.lean. 
#[test] fn prop_rrf_more_sources_higher_score() { - let source1: Vec<(String, DeterministicScore)> = vec![( - "doc_common".to_string(), - DeterministicScore::from_f64(0.9), - )]; + let source1: Vec<(String, DeterministicScore)> = + vec![("doc_common".to_string(), DeterministicScore::from_f64(0.9))]; let source2: Vec<(String, DeterministicScore)> = vec![ - ( - "doc_common".to_string(), - DeterministicScore::from_f64(0.9), - ), - ( - "doc_single".to_string(), - DeterministicScore::from_f64(0.8), - ), + ("doc_common".to_string(), DeterministicScore::from_f64(0.9)), + ("doc_single".to_string(), DeterministicScore::from_f64(0.8)), ]; let fused = reciprocal_rank_fusion(vec![source1, source2], 60); diff --git a/crates/khive-hnsw/Cargo.toml b/crates/khive-hnsw/Cargo.toml new file mode 100644 index 00000000..d3edc030 --- /dev/null +++ b/crates/khive-hnsw/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "khive-hnsw" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +description = "HNSW (Hierarchical Navigable Small World) vector index with INT8 quantized two-phase search — formally verified in Lean4" + +[dependencies] +khive-score = { version = "0.2.0", path = "../khive-score" } +khive-types = { version = "0.2.0", path = "../khive-types" } +lattice-embed = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +parking_lot = { workspace = true } +tokio = { workspace = true } +rand = "0.8" +rayon = "1.10" +ulid = "1.1" + +[dev-dependencies] +proptest = "1" + +[features] +checkpoint = [] diff --git a/crates/khive-hnsw/src/alias/drain.rs b/crates/khive-hnsw/src/alias/drain.rs new file mode 100644 index 00000000..e510a882 --- /dev/null +++ b/crates/khive-hnsw/src/alias/drain.rs @@ -0,0 +1,240 @@ +//! 
Reader tracking and drain detection for zero-downtime index swaps. +//! +//! The drain protocol ensures that after an alias swap, the old index is not +//! deallocated until all in-flight queries have completed. This is implemented +//! via an `AtomicU64` reader counter and an RAII `ReaderGuard` that decrements +//! on drop. +//! +//! # Design +//! +//! - Each collection has an associated `AtomicU64` reader count. +//! - `IndexAliasManager::acquire_reader` returns a `ReaderGuard`, which increments the count. +//! - The guard holds an `Arc<HnswIndex>` snapshot, so the index stays alive +//! even if the alias is swapped while the query is in flight. +//! - `drain_and_remove` polls the reader count until it reaches zero or timeout. + +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use super::error::AliasError; +use crate::HnswIndex; + +/// Tracks active readers for a single collection. +/// +/// The counter is shared (via `Arc`) between the alias manager and all +/// outstanding `ReaderGuard`s. When the alias manager wants to drain a +/// collection, it polls this counter. +#[derive(Debug)] +pub(crate) struct ReaderCounter { + count: AtomicU64, +} + +impl ReaderCounter { + pub fn new() -> Self { + Self { + count: AtomicU64::new(0), + } + } + + /// Increment the reader count. Returns the previous value. + #[inline] + pub fn acquire(&self) -> u64 { + self.count.fetch_add(1, Ordering::Acquire) + } + + /// Decrement the reader count. Returns the previous value. + #[inline] + pub fn release(&self) -> u64 { + self.count.fetch_sub(1, Ordering::Release) + } + + /// Get the current reader count. + #[inline] + pub fn load(&self) -> u64 { + self.count.load(Ordering::Acquire) + } +} + +/// RAII guard that holds a snapshot of the index and decrements the reader +/// count on drop. +/// +/// The caller gets `&HnswIndex` access via `Deref`. 
The index is guaranteed +/// to remain alive for the lifetime of this guard, even if the alias is +/// swapped to a different collection in the meantime. +pub struct ReaderGuard { + /// Snapshot of the index at the time the guard was acquired. + index: Arc, + /// Reader counter to decrement on drop. + counter: Arc, +} + +impl ReaderGuard { + /// Create a new reader guard, incrementing the reader counter. + pub(crate) fn new(index: Arc, counter: Arc) -> Self { + counter.acquire(); + Self { index, counter } + } + + /// Get a reference to the index snapshot. + pub fn index(&self) -> &HnswIndex { + &self.index + } +} + +impl Drop for ReaderGuard { + fn drop(&mut self) { + self.counter.release(); + } +} + +impl std::ops::Deref for ReaderGuard { + type Target = HnswIndex; + + fn deref(&self) -> &Self::Target { + &self.index + } +} + +/// Wait for all readers on a counter to finish, polling at `poll_interval`. +/// +/// Returns `Ok(())` when the reader count reaches zero, or `Err(DrainTimeout)` +/// if the timeout is exceeded. 
+pub(crate) async fn drain_readers( + counter: &Arc, + timeout: Duration, + poll_interval: Duration, +) -> Result<(), AliasError> { + let start = Instant::now(); + + loop { + let active = counter.load(); + if active == 0 { + return Ok(()); + } + + let elapsed = start.elapsed(); + if elapsed >= timeout { + return Err(AliasError::DrainTimeout { + elapsed, + timeout, + active_readers: active, + }); + } + + // Sleep for the poll interval (or remaining time, whichever is shorter) + let remaining = timeout - elapsed; + let sleep_dur = poll_interval.min(remaining); + tokio::time::sleep(sleep_dur).await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_reader_counter_acquire_release() { + let counter = ReaderCounter::new(); + assert_eq!(counter.load(), 0); + + counter.acquire(); + assert_eq!(counter.load(), 1); + + counter.acquire(); + assert_eq!(counter.load(), 2); + + counter.release(); + assert_eq!(counter.load(), 1); + + counter.release(); + assert_eq!(counter.load(), 0); + } + + #[test] + fn test_reader_guard_decrements_on_drop() { + let index = Arc::new(HnswIndex::new(4)); + let counter = Arc::new(ReaderCounter::new()); + + { + let _guard = ReaderGuard::new(Arc::clone(&index), Arc::clone(&counter)); + assert_eq!(counter.load(), 1); + + let _guard2 = ReaderGuard::new(Arc::clone(&index), Arc::clone(&counter)); + assert_eq!(counter.load(), 2); + } + // Both guards dropped + assert_eq!(counter.load(), 0); + } + + #[test] + fn test_reader_guard_deref() { + let index = Arc::new(HnswIndex::new(8)); + let counter = Arc::new(ReaderCounter::new()); + let guard = ReaderGuard::new(Arc::clone(&index), Arc::clone(&counter)); + + // Should be able to call HnswIndex methods via Deref + assert_eq!(guard.len(), 0); + assert!(guard.is_empty()); + } + + #[tokio::test] + async fn test_drain_readers_immediate() { + let counter = Arc::new(ReaderCounter::new()); + // No readers -- drain should return immediately + let result = drain_readers( + &counter, + 
Duration::from_millis(100), + Duration::from_millis(10), + ) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_drain_readers_timeout() { + let counter = Arc::new(ReaderCounter::new()); + counter.acquire(); // Simulate an active reader that never finishes + + let result = drain_readers( + &counter, + Duration::from_millis(50), + Duration::from_millis(10), + ) + .await; + + assert!(result.is_err()); + match result.unwrap_err() { + AliasError::DrainTimeout { active_readers, .. } => { + assert_eq!(active_readers, 1); + } + other => panic!("Expected DrainTimeout, got: {other:?}"), + } + + // Clean up + counter.release(); + } + + #[tokio::test] + async fn test_drain_readers_delayed_release() { + let counter = Arc::new(ReaderCounter::new()); + counter.acquire(); + + let counter_clone = Arc::clone(&counter); + + // Spawn a task that releases after 30ms + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(30)).await; + counter_clone.release(); + }); + + // Drain with 200ms timeout -- should succeed after ~30ms + let result = drain_readers( + &counter, + Duration::from_millis(200), + Duration::from_millis(5), + ) + .await; + assert!(result.is_ok()); + } +} diff --git a/crates/khive-hnsw/src/alias/error.rs b/crates/khive-hnsw/src/alias/error.rs new file mode 100644 index 00000000..3d37ca80 --- /dev/null +++ b/crates/khive-hnsw/src/alias/error.rs @@ -0,0 +1,102 @@ +//! Error types for alias operations. +//! +//! These errors cover the alias lifecycle: creation, swap, drain, and validation. +//! They integrate with the parent `RetrievalError` via `From` conversion. + +use std::fmt; +use std::time::Duration; + +/// Errors from alias manager operations. +#[derive(Debug)] +pub enum AliasError { + /// The requested alias name does not exist. + AliasNotFound(String), + + /// The requested collection name does not exist. + CollectionNotFound(String), + + /// A collection with this name already exists. 
+ CollectionAlreadyExists(String), + + /// An alias with this name already exists. + AliasAlreadyExists(String), + + /// The pre-swap validation failed. + ValidationFailed { + /// Human-readable reason. + reason: String, + /// Recall score achieved (if applicable). + recall: Option, + /// Minimum recall required (if applicable). + min_recall: Option, + }, + + /// Drain timed out waiting for active readers to finish. + DrainTimeout { + /// How long we waited. + elapsed: Duration, + /// Configured timeout. + timeout: Duration, + /// Number of readers still active. + active_readers: u64, + }, + + /// An HNSW operation failed during migration. + IndexError(String), +} + +impl fmt::Display for AliasError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::AliasNotFound(name) => write!(f, "alias not found: {name}"), + Self::CollectionNotFound(name) => write!(f, "collection not found: {name}"), + Self::CollectionAlreadyExists(name) => { + write!(f, "collection already exists: {name}") + } + Self::AliasAlreadyExists(name) => write!(f, "alias already exists: {name}"), + Self::ValidationFailed { + reason, + recall, + min_recall, + } => { + write!(f, "validation failed: {reason}")?; + if let (Some(r), Some(min)) = (recall, min_recall) { + write!(f, " (recall={r:.4}, min={min:.4})")?; + } + Ok(()) + } + Self::DrainTimeout { + elapsed, + timeout, + active_readers, + } => { + write!( + f, + "drain timeout after {:.1}s (limit {:.1}s, {active_readers} readers remaining)", + elapsed.as_secs_f64(), + timeout.as_secs_f64() + ) + } + Self::IndexError(msg) => write!(f, "index error: {msg}"), + } + } +} + +impl std::error::Error for AliasError {} + +impl From for crate::error::RetrievalError { + fn from(e: AliasError) -> Self { + match &e { + AliasError::DrainTimeout { .. 
} => { + // Drain timeout is transient -- readers will eventually finish + crate::error::RetrievalError::QueryTimeout { + elapsed_ms: match &e { + AliasError::DrainTimeout { elapsed, .. } => elapsed.as_millis() as u64, + _ => 0, + }, + } + } + _ => crate::error::RetrievalError::Hnsw(e.to_string()), + } + } +} diff --git a/crates/khive-hnsw/src/alias/manager.rs b/crates/khive-hnsw/src/alias/manager.rs new file mode 100644 index 00000000..e55b3625 --- /dev/null +++ b/crates/khive-hnsw/src/alias/manager.rs @@ -0,0 +1,600 @@ +//! Index alias manager for zero-downtime index migration. +//! +//! Provides blue-green deployment semantics for HNSW indexes: readers always +//! get a consistent snapshot, writers build a new index in the background, +//! and the alias swap is atomic (single pointer update under a brief write lock). +//! +//! # Concurrency Model +//! +//! - **Read path**: `parking_lot::RwLock` read guard. `parking_lot` uses +//! adaptive spinning before OS-level blocking, so short critical sections +//! (pointer clone) have near-zero contention overhead. +//! - **Write path**: Brief exclusive lock for the pointer swap only. +//! - **Background build**: Runs on `tokio::task::spawn_blocking`. Does not +//! hold any locks on the alias map. +//! - **Drain**: Async poll with configurable interval and timeout. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use parking_lot::RwLock; + +use super::drain::{drain_readers, ReaderCounter, ReaderGuard}; +use super::error::AliasError; +use super::validation::IndexValidator; +use crate::config::HnswConfig; +use crate::HnswIndex; +use crate::NodeId; + +/// Metadata for a registered collection. +struct Collection { + /// The index, wrapped in Arc for snapshot sharing with readers. + index: Arc, + /// Active reader counter for drain detection. + readers: Arc, +} + +/// Report from a completed migration. 
+#[derive(Debug, Clone)] +pub struct MigrationReport { + /// Name of the old collection that was replaced. + pub old_collection: String, + /// Name of the new collection. + pub new_collection: String, + /// Number of vectors in the old index. + pub old_size: usize, + /// Number of vectors in the new index. + pub new_size: usize, + /// Recall score from validation (currently always `None`: `migrate` does not capture it). + pub recall_score: Option, + /// Wall-clock time for the entire migration (build + validate + swap + drain). + pub total_duration: Duration, + /// Wall-clock time for the index build phase. + pub build_duration: Duration, + /// Wall-clock time for the validate + swap + drain phase. + pub swap_drain_duration: Duration, +} + +/// Manages named collections and aliases for zero-downtime index switching. +/// +/// # Usage +/// +/// ```rust,ignore +/// let manager = IndexAliasManager::new(Duration::from_secs(5)); +/// +/// // Register initial index +/// manager.register_collection("index_v1", initial_index)?; +/// manager.create_alias("active", "index_v1")?; +/// +/// // Search through alias +/// let guard = manager.acquire_reader("active")?; +/// let results = guard.search(&query, 10)?; +/// drop(guard); // releases reader count +/// +/// // Migrate to new index +/// let report = manager.migrate("active", vectors, new_config, None).await?; +/// ``` +pub struct IndexAliasManager { + /// Collection name -> Collection data. + /// Protected by RwLock: reads (search) take shared lock, writes (register/remove) + /// take exclusive lock. + collections: RwLock>, + + /// Alias name -> collection name mapping. + /// Protected by RwLock: reads (resolve alias) take shared lock, writes + /// (create/switch alias) take exclusive lock. + aliases: RwLock>, + + /// Maximum time `drain_and_remove` waits for readers before returning `DrainTimeout` (collection is kept). + drain_timeout: Duration, + + /// Poll interval for drain detection. 
+ drain_poll_interval: Duration, +} + +impl IndexAliasManager { + /// Create a new alias manager with the given drain timeout. + pub fn new(drain_timeout: Duration) -> Self { + Self { + collections: RwLock::new(HashMap::new()), + aliases: RwLock::new(HashMap::new()), + drain_timeout, + drain_poll_interval: Duration::from_millis(10), + } + } + + /// Set the drain poll interval (default: 10ms). + pub fn with_drain_poll_interval(mut self, interval: Duration) -> Self { + self.drain_poll_interval = interval; + self + } + + /// Register a named collection. Fails if the name already exists. + pub fn register_collection(&self, name: &str, index: HnswIndex) -> Result<(), AliasError> { + let mut collections = self.collections.write(); + if collections.contains_key(name) { + return Err(AliasError::CollectionAlreadyExists(name.to_string())); + } + collections.insert( + name.to_string(), + Collection { + index: Arc::new(index), + readers: Arc::new(ReaderCounter::new()), + }, + ); + Ok(()) + } + + /// Create an alias pointing to an existing collection. + /// Fails if the alias already exists or the collection does not exist. + pub fn create_alias(&self, alias: &str, collection: &str) -> Result<(), AliasError> { + // Verify collection exists + { + let collections = self.collections.read(); + if !collections.contains_key(collection) { + return Err(AliasError::CollectionNotFound(collection.to_string())); + } + } + + let mut aliases = self.aliases.write(); + if aliases.contains_key(alias) { + return Err(AliasError::AliasAlreadyExists(alias.to_string())); + } + aliases.insert(alias.to_string(), collection.to_string()); + Ok(()) + } + + /// Acquire a reader guard for the index behind an alias. + /// + /// The returned guard holds an `Arc` snapshot and increments the + /// reader counter. The index is guaranteed to remain alive until the guard + /// is dropped, even if the alias is swapped in the meantime. + /// + /// This is the primary read-path entry point. 
The critical section is + /// minimal: read-lock the alias map, read-lock the collection map, clone + /// the Arc, increment the counter. + pub fn acquire_reader(&self, alias: &str) -> Result { + // Resolve alias -> collection name + let collection_name = { + let aliases = self.aliases.read(); + aliases + .get(alias) + .ok_or_else(|| AliasError::AliasNotFound(alias.to_string()))? + .clone() + }; + + // Get collection and create reader guard + let collections = self.collections.read(); + let collection = collections + .get(&collection_name) + .ok_or_else(|| AliasError::CollectionNotFound(collection_name.clone()))?; + + Ok(ReaderGuard::new( + Arc::clone(&collection.index), + Arc::clone(&collection.readers), + )) + } + + /// Switch an alias to point to a different collection. + /// + /// If a validator is provided, it runs against the target collection's + /// index before the swap. If validation fails, the alias is not changed. + /// + /// Returns the name of the previous collection (for drain purposes). + pub fn switch_alias( + &self, + alias: &str, + new_collection: &str, + validator: Option<&dyn IndexValidator>, + ) -> Result { + // Verify new collection exists and optionally validate + { + let collections = self.collections.read(); + let collection = collections + .get(new_collection) + .ok_or_else(|| AliasError::CollectionNotFound(new_collection.to_string()))?; + + if let Some(v) = validator { + v.validate(&collection.index)?; + } + } + + // Swap the alias (exclusive lock, but the critical section is just + // a HashMap insert -- nanoseconds) + let mut aliases = self.aliases.write(); + let old_collection = aliases + .get(alias) + .ok_or_else(|| AliasError::AliasNotFound(alias.to_string()))? + .clone(); + + aliases.insert(alias.to_string(), new_collection.to_string()); + Ok(old_collection) + } + + /// Wait for all readers on a collection to finish, then remove it. + /// + /// This is typically called after `switch_alias` to clean up the old + /// collection. 
If the drain times out, the collection is NOT removed + /// and the error is returned. + pub async fn drain_and_remove(&self, collection: &str) -> Result<(), AliasError> { + // Extract the reader counter (under read lock -- we just need the Arc) + let counter = { + let collections = self.collections.read(); + let coll = collections + .get(collection) + .ok_or_else(|| AliasError::CollectionNotFound(collection.to_string()))?; + Arc::clone(&coll.readers) + }; + + // Wait for readers to drain (async, no locks held) + drain_readers(&counter, self.drain_timeout, self.drain_poll_interval).await?; + + // Remove the collection (exclusive lock) + let mut collections = self.collections.write(); + collections.remove(collection); + Ok(()) + } + + /// Full migration: build new index, validate, swap alias, drain old. + /// + /// This is the high-level API for embedding model migrations. The build + /// phase runs on `spawn_blocking` to avoid blocking the tokio runtime. + /// + /// # Arguments + /// + /// * `alias` - The alias to migrate (must already exist) + /// * `vectors` - All vectors for the new index + /// * `new_config` - Configuration for the new index + /// * `validator` - Optional pre-swap validation + /// + /// # Returns + /// + /// A `MigrationReport` with timing and size information. + pub async fn migrate( + &self, + alias: &str, + vectors: Vec<(NodeId, Vec)>, + new_config: HnswConfig, + validator: Option>, + ) -> Result { + let total_start = Instant::now(); + + // Resolve current alias to get old collection info + let old_collection_name = { + let aliases = self.aliases.read(); + aliases + .get(alias) + .ok_or_else(|| AliasError::AliasNotFound(alias.to_string()))? 
+ .clone() + }; + + let old_size = { + let collections = self.collections.read(); + collections + .get(&old_collection_name) + .map(|c| c.index.len_live()) + .unwrap_or(0) + }; + + // Generate a unique name for the new collection + let new_collection_name = format!( + "{}_migrated_{}", + old_collection_name, + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() + ); + + // Build new index on a blocking thread + let build_start = Instant::now(); + let new_index = tokio::task::spawn_blocking(move || { + let mut index = HnswIndex::with_config(new_config); + for (id, vec) in vectors { + if let Err(e) = index.insert(id, vec) { + return Err(AliasError::IndexError(e.to_string())); + } + } + Ok(index) + }) + .await + .map_err(|e| AliasError::IndexError(format!("build task panicked: {e}")))? + .map_err(|e| AliasError::IndexError(format!("build failed: {e}")))?; + + let build_duration = build_start.elapsed(); + let new_size = new_index.len_live(); + + // Register the new collection + self.register_collection(&new_collection_name, new_index)?; + + // Validate and swap + let swap_drain_start = Instant::now(); + + // Recall score for the report + let recall_score = None; + + // If we have a validator, run it and capture recall + if let Some(ref v) = validator { + let collections = self.collections.read(); + let coll = collections.get(&new_collection_name).unwrap(); + match v.validate(&coll.index) { + Ok(()) => {} + Err(AliasError::ValidationFailed { + recall, min_recall, .. 
+ }) => { + // Remove the new collection since validation failed + drop(collections); + let mut colls = self.collections.write(); + colls.remove(&new_collection_name); + return Err(AliasError::ValidationFailed { + reason: "recall below threshold".to_string(), + recall, + min_recall, + }); + } + Err(e) => { + // Remove the new collection since validation failed + drop(collections); + let mut colls = self.collections.write(); + colls.remove(&new_collection_name); + return Err(e); + } + } + } + + // Switch the alias + self.switch_alias(alias, &new_collection_name, None)?; + + // Drain and remove old collection + // Note: if drain times out, the old collection stays registered but the + // alias already points to the new one. In-flight readers on the old + // index can still finish, but its memory is retained until a later + // `drain_and_remove` call on the old collection succeeds. + let drain_result = self.drain_and_remove(&old_collection_name).await; + + let swap_drain_duration = swap_drain_start.elapsed(); + let total_duration = total_start.elapsed(); + + // Tolerate drain timeout without failing the migration -- the alias is + // already switched, so new queries go to the new index. + if let Err(AliasError::DrainTimeout { .. }) = &drain_result { + // The old collection remains registered; the caller must retry + // `drain_and_remove` to release it. The report does not record this. + } else if let Err(e) = drain_result { + return Err(e); + } + + Ok(MigrationReport { + old_collection: old_collection_name, + new_collection: new_collection_name, + old_size, + new_size, + recall_score, + total_duration, + build_duration, + swap_drain_duration, + }) + } + + /// Get the number of registered collections. + pub fn collection_count(&self) -> usize { + self.collections.read().len() + } + + /// Get the number of registered aliases. + pub fn alias_count(&self) -> usize { + self.aliases.read().len() + } + + /// Get the collection name that an alias points to. 
+ pub fn resolve_alias(&self, alias: &str) -> Option { + self.aliases.read().get(alias).cloned() + } + + /// Get the active reader count for a collection. + pub fn reader_count(&self, collection: &str) -> Option { + self.collections + .read() + .get(collection) + .map(|c| c.readers.load()) + } +} + +impl std::fmt::Debug for IndexAliasManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let collections = self.collections.read(); + let aliases = self.aliases.read(); + f.debug_struct("IndexAliasManager") + .field("collections", &collections.keys().collect::>()) + .field("aliases", &*aliases) + .field("drain_timeout", &self.drain_timeout) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::HnswIndex; + + fn make_index(dims: usize, count: usize) -> HnswIndex { + let mut index = HnswIndex::new(dims); + for i in 0..count { + let id = NodeId::new([(i & 0xFF) as u8; 16]); + let vec = vec![i as f32; dims]; + index.insert(id, vec).unwrap(); + } + index + } + + #[test] + fn test_register_and_create_alias() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + + let index = make_index(4, 5); + mgr.register_collection("v1", index).unwrap(); + mgr.create_alias("active", "v1").unwrap(); + + assert_eq!(mgr.collection_count(), 1); + assert_eq!(mgr.alias_count(), 1); + assert_eq!(mgr.resolve_alias("active"), Some("v1".to_string())); + } + + #[test] + fn test_register_duplicate_collection() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + mgr.register_collection("v1", make_index(4, 1)).unwrap(); + + let result = mgr.register_collection("v1", make_index(4, 1)); + assert!(matches!( + result, + Err(AliasError::CollectionAlreadyExists(_)) + )); + } + + #[test] + fn test_create_alias_missing_collection() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + let result = mgr.create_alias("active", "nonexistent"); + assert!(matches!(result, Err(AliasError::CollectionNotFound(_)))); + } + + #[test] + fn 
test_create_duplicate_alias() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + mgr.register_collection("v1", make_index(4, 1)).unwrap(); + mgr.create_alias("active", "v1").unwrap(); + + let result = mgr.create_alias("active", "v1"); + assert!(matches!(result, Err(AliasError::AliasAlreadyExists(_)))); + } + + #[test] + fn test_acquire_reader() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + mgr.register_collection("v1", make_index(4, 5)).unwrap(); + mgr.create_alias("active", "v1").unwrap(); + + let guard = mgr.acquire_reader("active").unwrap(); + assert_eq!(guard.len(), 5); + assert_eq!(mgr.reader_count("v1"), Some(1)); + + drop(guard); + assert_eq!(mgr.reader_count("v1"), Some(0)); + } + + #[test] + fn test_acquire_reader_missing_alias() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + let result = mgr.acquire_reader("nonexistent"); + assert!(matches!(result, Err(AliasError::AliasNotFound(_)))); + } + + #[test] + fn test_switch_alias() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + mgr.register_collection("v1", make_index(4, 5)).unwrap(); + mgr.register_collection("v2", make_index(4, 10)).unwrap(); + mgr.create_alias("active", "v1").unwrap(); + + let old = mgr.switch_alias("active", "v2", None).unwrap(); + assert_eq!(old, "v1"); + assert_eq!(mgr.resolve_alias("active"), Some("v2".to_string())); + + // Reader should now get v2 + let guard = mgr.acquire_reader("active").unwrap(); + assert_eq!(guard.len(), 10); + } + + #[test] + fn test_switch_alias_with_failing_validator() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + mgr.register_collection("v1", make_index(4, 5)).unwrap(); + mgr.register_collection("v2", make_index(4, 10)).unwrap(); + mgr.create_alias("active", "v1").unwrap(); + + // Validator that always fails + struct FailValidator; + impl IndexValidator for FailValidator { + fn validate(&self, _: &HnswIndex) -> Result<(), AliasError> { + Err(AliasError::ValidationFailed { + reason: 
"test failure".to_string(), + recall: Some(0.5), + min_recall: Some(0.95), + }) + } + } + + let result = mgr.switch_alias("active", "v2", Some(&FailValidator)); + assert!(matches!(result, Err(AliasError::ValidationFailed { .. }))); + + // Alias should still point to v1 + assert_eq!(mgr.resolve_alias("active"), Some("v1".to_string())); + } + + #[tokio::test] + async fn test_drain_and_remove() { + let mgr = IndexAliasManager::new(Duration::from_secs(1)); + mgr.register_collection("v1", make_index(4, 5)).unwrap(); + + // No readers -- drain should succeed immediately + mgr.drain_and_remove("v1").await.unwrap(); + assert_eq!(mgr.collection_count(), 0); + } + + #[tokio::test] + async fn test_concurrent_read_during_swap() { + let mgr = Arc::new(IndexAliasManager::new(Duration::from_secs(5))); + mgr.register_collection("v1", make_index(4, 5)).unwrap(); + mgr.register_collection("v2", make_index(4, 10)).unwrap(); + mgr.create_alias("active", "v1").unwrap(); + + // Acquire a reader on v1 + let guard = mgr.acquire_reader("active").unwrap(); + assert_eq!(guard.len(), 5); + + // Swap alias to v2 while v1 reader is active + let old = mgr.switch_alias("active", "v2", None).unwrap(); + assert_eq!(old, "v1"); + + // Old reader should still see v1 (5 vectors) + assert_eq!(guard.len(), 5); + + // New reader should see v2 (10 vectors) + let new_guard = mgr.acquire_reader("active").unwrap(); + assert_eq!(new_guard.len(), 10); + + // Drop the old reader + drop(guard); + + // Now drain should succeed for v1 + let mgr_clone = Arc::clone(&mgr); + mgr_clone.drain_and_remove("v1").await.unwrap(); + + // v1 should be gone, v2 should remain + assert_eq!(mgr.collection_count(), 1); + } + + #[tokio::test] + async fn test_migrate() { + let mgr = Arc::new(IndexAliasManager::new(Duration::from_secs(5))); + mgr.register_collection("v1", make_index(4, 5)).unwrap(); + mgr.create_alias("active", "v1").unwrap(); + + // Prepare vectors for the new index + let vectors: Vec<(NodeId, Vec)> = (0..8u8) + 
.map(|i| (NodeId::new([i; 16]), vec![i as f32; 4])) + .collect(); + + let config = HnswConfig::with_dimensions(4); + let report = mgr.migrate("active", vectors, config, None).await.unwrap(); + + assert_eq!(report.old_size, 5); + assert_eq!(report.new_size, 8); + + // The alias should now point to the new collection + let guard = mgr.acquire_reader("active").unwrap(); + assert_eq!(guard.len(), 8); + } +} diff --git a/crates/khive-hnsw/src/alias/mod.rs b/crates/khive-hnsw/src/alias/mod.rs new file mode 100644 index 00000000..0ad6d1d9 --- /dev/null +++ b/crates/khive-hnsw/src/alias/mod.rs @@ -0,0 +1,37 @@ +//! Index alias management for zero-downtime HNSW index migration. +//! +//! This module implements a blue-green deployment pattern for HNSW vector indexes. +//! When switching embedding models (e.g., from BGE-small to mE5-small), every +//! vector must be re-embedded and re-indexed. The alias manager allows this to +//! happen without taking the search service offline: +//! +//! 1. `alias("active")` currently points to `collection("index_v1")` +//! 2. Build `collection("index_v2")` in a background thread +//! 3. Validate the new index (recall@k benchmark) +//! 4. Atomic swap: `alias("active")` now points to `collection("index_v2")` +//! 5. In-flight queries on v1 complete on v1; new queries go to v2 +//! 6. After drain (all v1 readers dropped), deallocate v1 +//! +//! # Concurrency +//! +//! - Read path: `parking_lot::RwLock` read guard (adaptive spinning, no OS block +//! for short critical sections) +//! - Write path: Brief exclusive lock for pointer swap only +//! - Background build: `tokio::task::spawn_blocking`, no locks held +//! - Drain: Async poll via `AtomicU64` reader counter +//! +//! # Module Structure +//! +//! - [`manager`]: `IndexAliasManager` -- the main entry point +//! - [`drain`]: Reader tracking and RAII guard +//! - [`validation`]: Pre-swap index quality validation +//! 
- [`error`]: Error types + +mod drain; +pub mod error; +mod manager; +pub mod validation; + +pub use drain::ReaderGuard; +pub use manager::{IndexAliasManager, MigrationReport}; +pub use validation::{IndexValidator, NoopValidator, RecallValidator}; diff --git a/crates/khive-hnsw/src/alias/validation.rs b/crates/khive-hnsw/src/alias/validation.rs new file mode 100644 index 00000000..012d42f0 --- /dev/null +++ b/crates/khive-hnsw/src/alias/validation.rs @@ -0,0 +1,184 @@ +//! Pre-swap validation for index migrations. +//! +//! Before committing an alias swap, the system can optionally run a validation +//! function to verify that the new index meets quality requirements. The primary +//! validator is `RecallValidator`, which checks recall@k against a set of +//! golden queries with known ground-truth results. + +use super::error::AliasError; +use crate::HnswIndex; +use crate::NodeId; + +/// Trait for validating an index before an alias swap. +/// +/// Implementations should be stateless or cheaply cloneable, as they may be +/// called from within a `spawn_blocking` context. +pub trait IndexValidator: Send + Sync { + /// Validate the new index. Return `Ok(())` to proceed with the swap, + /// or `Err(AliasError::ValidationFailed)` to abort. + fn validate(&self, new_index: &HnswIndex) -> Result<(), AliasError>; +} + +/// Validates a new index by running recall@k against golden queries. +/// +/// Golden queries are `(query_vector, expected_result_ids)` pairs where the +/// expected results are the true nearest neighbors (typically computed via +/// brute-force on the same dataset). +/// +/// # Recall Computation +/// +/// For each golden query, we search the new index for `k` results and compute: +/// +/// ```text +/// recall = |returned ∩ expected| / |expected| +/// ``` +/// +/// The overall recall is the mean across all golden queries. The swap is +/// approved if `mean_recall >= min_recall`. 
+pub struct RecallValidator { + /// Golden queries: `(query_vector, expected_nearest_ids)`. + pub golden_queries: Vec<(Vec, Vec)>, + /// Number of results to retrieve per query. + pub k: usize, + /// Minimum acceptable mean recall (e.g., 0.95 for 95%). + pub min_recall: f32, +} + +impl RecallValidator { + /// Create a new recall validator. + pub fn new(golden_queries: Vec<(Vec, Vec)>, k: usize, min_recall: f32) -> Self { + Self { + golden_queries, + k, + min_recall, + } + } +} + +impl IndexValidator for RecallValidator { + fn validate(&self, new_index: &HnswIndex) -> Result<(), AliasError> { + if self.golden_queries.is_empty() { + return Ok(()); + } + + let mut total_recall = 0.0f64; + let mut query_count = 0usize; + + for (query, expected) in &self.golden_queries { + let results = new_index + .search(query, self.k) + .map_err(|e| AliasError::IndexError(e.to_string()))?; + + let returned_ids: std::collections::HashSet = + results.iter().map(|(id, _)| *id).collect(); + + let hits = expected + .iter() + .filter(|id| returned_ids.contains(id)) + .count(); + + let recall = if expected.is_empty() { + 1.0 + } else { + hits as f64 / expected.len() as f64 + }; + + total_recall += recall; + query_count += 1; + } + + let mean_recall = (total_recall / query_count as f64) as f32; + + if mean_recall < self.min_recall { + return Err(AliasError::ValidationFailed { + reason: format!( + "recall@{} = {mean_recall:.4} < {:.4}", + self.k, self.min_recall + ), + recall: Some(mean_recall), + min_recall: Some(self.min_recall), + }); + } + + Ok(()) + } +} + +/// A validator that always passes. Useful for testing or when validation +/// is not needed. 
+pub struct NoopValidator; + +impl IndexValidator for NoopValidator { + fn validate(&self, _new_index: &HnswIndex) -> Result<(), AliasError> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::HnswIndex; + + fn make_test_index() -> HnswIndex { + let mut index = HnswIndex::new(4); + // Insert 10 vectors + for i in 0..10u8 { + let id = NodeId::new([i; 16]); + let vec = vec![i as f32; 4]; + index.insert(id, vec).unwrap(); + } + index + } + + #[test] + fn test_noop_validator() { + let index = make_test_index(); + let validator = NoopValidator; + assert!(validator.validate(&index).is_ok()); + } + + #[test] + fn test_recall_validator_empty_golden() { + let index = make_test_index(); + let validator = RecallValidator::new(vec![], 5, 0.95); + assert!(validator.validate(&index).is_ok()); + } + + #[test] + fn test_recall_validator_perfect_recall() { + let index = make_test_index(); + + // Search for the actual results first, then use those as golden truth + let query = vec![5.0f32; 4]; + let results = index.search(&query, 3).unwrap(); + let expected: Vec = results.iter().map(|(id, _)| *id).collect(); + assert!(!expected.is_empty(), "search should return results"); + + // Validator with the actual results as golden truth should pass + let validator = RecallValidator::new(vec![(query, expected)], 3, 0.95); + assert!(validator.validate(&index).is_ok()); + } + + #[test] + fn test_recall_validator_fails_low_recall() { + let index = make_test_index(); + + // Expect IDs that don't exist in the index + let query = vec![5.0f32; 4]; + let fake_ids: Vec = (100..110u8).map(|i| NodeId::new([i; 16])).collect(); + + let validator = RecallValidator::new(vec![(query, fake_ids)], 5, 0.95); + let result = validator.validate(&index); + assert!(result.is_err()); + + match result.unwrap_err() { + AliasError::ValidationFailed { + recall, min_recall, .. 
+ } => { + assert_eq!(recall, Some(0.0)); + assert_eq!(min_recall, Some(0.95)); + } + other => panic!("Expected ValidationFailed, got: {other:?}"), + } + } +} diff --git a/crates/khive-hnsw/src/arena/arena.rs b/crates/khive-hnsw/src/arena/arena.rs new file mode 100644 index 00000000..8250b658 --- /dev/null +++ b/crates/khive-hnsw/src/arena/arena.rs @@ -0,0 +1,171 @@ +//! Core bump arena allocator. +//! +//! Pre-allocates a contiguous memory slab and bumps a pointer for each +//! allocation. Reset is O(1) -- just set the bump offset back to zero. +//! +//! # Memory Layout +//! +//! ```text +//! [---- slab (1 MiB default) ----] +//! ^ ^ ^ +//! base offset capacity +//! ``` +//! +//! Each `alloc(count)` bumps `offset` by `count * size_of::()` (with +//! alignment padding). If `offset` would exceed `capacity`, the arena grows +//! by allocating a new, larger slab. +//! +//! # Safety Invariants +//! +//! 1. The slab is a `Vec` owned by the arena. All pointers derived from +//! it are valid as long as the arena is alive and has not been reset or grown. +//! 2. `ArenaVec` and `ArenaBinaryHeap` hold an `&SearchArena` reference, +//! tying their lifetime to the arena. After `reset()`, all prior allocations +//! are logically invalid -- the type system enforces this via lifetimes. +//! 3. Growth invalidates all prior pointers. This is safe because growth only +//! happens during `alloc`, and all live `ArenaVec`/`ArenaBinaryHeap` objects +//! manage their own pointer + length, requesting new allocations as needed +//! via copy-on-grow. + +use std::cell::Cell; + +/// Default arena size: 1 MiB. More than sufficient for ef=256 searches. 
+/// +/// Worst-case per-search memory for ef=256, M=16: +/// +/// - candidates heap: 256 * 12 = 3,072 bytes +/// - results heap: 256 * 12 = 3,072 bytes +/// - batch buffer: 16 * 32 = 512 bytes +/// - result\_buf: 256 * 12 = 3,072 bytes +/// - overhead/alignment: ~1,000 bytes +/// +/// Total: ~10,728 bytes (~10 KiB) +/// +/// 1 MiB gives ~100x headroom. +pub const DEFAULT_ARENA_SIZE: usize = 1 << 20; // 1 MiB + +/// Bump arena allocator for HNSW search operations. +/// +/// All allocations within a search query bump from this arena. Between +/// queries, call `reset()` to reclaim all memory in O(1). +/// +/// The arena uses interior mutability (`Cell`) for the bump offset so that +/// multiple `ArenaVec` instances can allocate from the same `&SearchArena`. +pub struct SearchArena { + /// Backing memory slab. + slab: Cell>, + /// Current bump offset into the slab. + offset: Cell, +} + +impl SearchArena { + /// Create a new arena with the given capacity in bytes. + pub fn new(capacity: usize) -> Self { + let cap = capacity.max(1024); // Minimum 1 KiB + Self { + slab: Cell::new(vec![0u8; cap]), + offset: Cell::new(0), + } + } + + /// Create a new arena with the default 1 MiB capacity. + pub fn with_default_capacity() -> Self { + Self::new(DEFAULT_ARENA_SIZE) + } + + /// Reset the arena in O(1). All prior allocations become invalid. + /// + /// This is the key performance win: no deallocation, no destructors, + /// no zeroing. Just reset the bump pointer. + #[inline] + pub fn reset(&self) { + self.offset.set(0); + } + + /// Current number of bytes allocated from this arena. + #[inline] + pub fn bytes_used(&self) -> usize { + self.offset.get() + } + + /// Total capacity of the arena in bytes. + #[inline] + pub fn capacity(&self) -> usize { + // SAFETY: We take the slab out, read its capacity, and put it back. + // This is safe because we don't keep any references across the take/set. 
+ let slab = self.slab.take(); + let cap = slab.capacity(); + self.slab.set(slab); + cap + } + + /// Allocate `count` elements of type `T` from the arena. + /// + /// Returns a pointer to the allocated memory. The caller is responsible + /// for writing to this memory before reading. + /// + /// # Panics + /// + /// Never panics. If the arena is full, it grows automatically. + /// + /// # Safety + /// + /// The returned pointer is valid until `reset()` is called or the arena + /// is dropped. The caller must not use the pointer after either event. + /// This is enforced by the lifetime parameter on `ArenaVec`. + pub(super) fn alloc(&self, count: usize) -> *mut T { + let size = std::mem::size_of::() * count; + let align = std::mem::align_of::(); + + if size == 0 { + return std::ptr::dangling_mut::(); // ZST: return aligned dangling pointer + } + + let mut current = self.offset.get(); + + // Align up + let aligned = (current + align - 1) & !(align - 1); + let new_offset = aligned + size; + + // Take slab, work with it, put it back + let mut slab = self.slab.take(); + + if new_offset > slab.len() { + // Grow: double or fit, whichever is larger + let new_cap = (slab.len() * 2).max(new_offset).max(slab.len() + size); + slab.resize(new_cap, 0); + // Recompute alignment in case resize moved the buffer + current = self.offset.get(); + let aligned = (current + align - 1) & !(align - 1); + let new_offset = aligned + size; + let ptr = slab.as_mut_ptr().wrapping_add(aligned) as *mut T; + self.offset.set(new_offset); + self.slab.set(slab); + return ptr; + } + + let ptr = slab.as_mut_ptr().wrapping_add(aligned) as *mut T; + self.offset.set(new_offset); + self.slab.set(slab); + ptr + } + + /// Copy `src` slice into the arena and return a mutable pointer to the copy. + /// + /// Useful for bulk-copying data into the arena. Allocates space for + /// `src.len()` elements, copies them in, and returns a pointer to the copy. 
+ #[allow(dead_code)] + pub(super) fn alloc_copy(&self, src: &[T]) -> *mut T { + if src.is_empty() { + return self.alloc::(0); + } + let ptr = self.alloc::(src.len()); + // SAFETY: `ptr` points to freshly allocated arena memory with enough + // space for `src.len()` elements. `src` is a valid slice. No overlap + // because arena memory is freshly bumped. + unsafe { + std::ptr::copy_nonoverlapping(src.as_ptr(), ptr, src.len()); + } + ptr + } +} diff --git a/crates/khive-hnsw/src/arena/arena_heap.rs b/crates/khive-hnsw/src/arena/arena_heap.rs new file mode 100644 index 00000000..53648481 --- /dev/null +++ b/crates/khive-hnsw/src/arena/arena_heap.rs @@ -0,0 +1,149 @@ +//! Arena-backed binary heap. +//! +//! `ArenaBinaryHeap<'a, T>` provides min-heap or max-heap behavior backed +//! by an `ArenaVec`. It implements the same operations as `std::BinaryHeap` +//! (push, pop, peek, len, clear, drain) with identical algorithmic complexity. +//! +//! # Heap Ordering +//! +//! The heap uses `Ord` ordering. For a max-heap (default BinaryHeap behavior), +//! use `T` directly. For a min-heap, wrap elements in `std::cmp::Reverse`. +//! +//! In the HNSW search context: +//! - `candidates`: min-heap via `Reverse<(OrderedF32, usize)>` -- closest first +//! - `results`: max-heap via `(OrderedF32, usize)` -- furthest first (for pruning) + +use super::arena::SearchArena; +use super::arena_vec::ArenaVec; + +/// A binary heap backed by arena allocation. +/// +/// This is a max-heap by default (like `std::BinaryHeap`). For a min-heap, +/// wrap elements in `std::cmp::Reverse`. +/// +/// Elements must be `Copy + Ord`. `Copy` because the backing `ArenaVec` +/// requires it. `Ord` for heap ordering. +pub struct ArenaBinaryHeap<'a, T: Copy + Ord> { + data: ArenaVec<'a, T>, +} + +impl<'a, T: Copy + Ord> ArenaBinaryHeap<'a, T> { + /// Create a new empty heap with the given initial capacity. 
+ #[inline] + pub fn new(arena: &'a SearchArena, capacity: usize) -> Self { + Self { + data: ArenaVec::new(arena, capacity), + } + } + + /// Number of elements in the heap. + #[inline] + pub fn len(&self) -> usize { + self.data.len() + } + + /// Whether the heap is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Peek at the maximum element (for max-heap) without removing it. + #[inline] + pub fn peek(&self) -> Option<&T> { + if self.data.is_empty() { + None + } else { + Some(self.data.get(0)) + } + } + + /// Push an element onto the heap. O(log n). + #[inline] + pub fn push(&mut self, value: T) { + self.data.push(value); + self.sift_up(self.data.len() - 1); + } + + /// Remove and return the maximum element. O(log n). + #[inline] + pub fn pop(&mut self) -> Option { + if self.data.is_empty() { + return None; + } + let len = self.data.len(); + if len == 1 { + return self.data.pop(); + } + // Swap root with last, pop last, sift down root. + let root = *self.data.get(0); + let last = *self.data.get(len - 1); + *self.data.get_mut(0) = last; + self.data.pop(); // removes last element + if !self.data.is_empty() { + self.sift_down(0); + } + Some(root) + } + + /// Clear the heap without deallocating. + #[inline] + pub fn clear(&mut self) { + self.data.clear(); + } + + /// Drain all elements from the heap (in arbitrary order). + /// + /// The returned iterator yields elements in storage order, NOT heap order. + /// If sorted order is needed, collect and sort separately. + pub fn drain(&mut self) -> impl Iterator + '_ { + self.data.drain() + } + + /// Sift element at `pos` up to restore heap property. 
+ #[inline] + fn sift_up(&mut self, mut pos: usize) { + while pos > 0 { + let parent = (pos - 1) / 2; + if *self.data.get(pos) > *self.data.get(parent) { + // Swap child with parent + let child_val = *self.data.get(pos); + let parent_val = *self.data.get(parent); + *self.data.get_mut(pos) = parent_val; + *self.data.get_mut(parent) = child_val; + pos = parent; + } else { + break; + } + } + } + + /// Sift element at `pos` down to restore heap property. + #[inline] + fn sift_down(&mut self, mut pos: usize) { + let len = self.data.len(); + loop { + let left = 2 * pos + 1; + let right = 2 * pos + 2; + let mut largest = pos; + + if left < len && *self.data.get(left) > *self.data.get(largest) { + largest = left; + } + if right < len && *self.data.get(right) > *self.data.get(largest) { + largest = right; + } + + if largest == pos { + break; + } + + // Swap + let pos_val = *self.data.get(pos); + let largest_val = *self.data.get(largest); + *self.data.get_mut(pos) = largest_val; + *self.data.get_mut(largest) = pos_val; + pos = largest; + } + } +} diff --git a/crates/khive-hnsw/src/arena/arena_vec.rs b/crates/khive-hnsw/src/arena/arena_vec.rs new file mode 100644 index 00000000..544c0bbb --- /dev/null +++ b/crates/khive-hnsw/src/arena/arena_vec.rs @@ -0,0 +1,281 @@ +//! Arena-backed growable vector. +//! +//! `ArenaVec<'a, T>` is a vector that allocates from a `SearchArena` instead +//! of the global allocator. It supports push, pop, clear, len, iter, and drain. +//! +//! # Growth Strategy +//! +//! When capacity is exceeded, `ArenaVec` allocates a new (larger) region from +//! the arena and copies existing elements. The old region is "leaked" in the +//! arena -- this is fine because the arena will be reset between queries, +//! reclaiming all memory at once. +//! +//! # Lifetime +//! +//! The `'a` lifetime ties this vec to its arena. The vec cannot outlive the +//! arena, and `reset()` on the arena logically invalidates all vecs. In +//! 
practice, the search code creates vecs at the start of a query and drops +//! them before calling `reset()`. + +use super::arena::SearchArena; + +/// A growable vector backed by arena allocation. +/// +/// Elements must be `Copy` because growth copies elements to a new arena +/// region (no drop semantics needed). +pub struct ArenaVec<'a, T: Copy> { + /// Pointer to the start of the allocated region in the arena. + ptr: *mut T, + /// Number of live elements. + len: usize, + /// Total capacity (in elements) of the current allocation. + cap: usize, + /// Reference to the owning arena (for growth allocation). + arena: &'a SearchArena, +} + +impl<'a, T: Copy> ArenaVec<'a, T> { + /// Create a new empty `ArenaVec` with the given initial capacity. + #[inline] + pub fn new(arena: &'a SearchArena, capacity: usize) -> Self { + let (ptr, cap) = if capacity > 0 { + (arena.alloc::(capacity), capacity) + } else { + (std::ptr::dangling_mut::(), 0) // dangling + }; + Self { + ptr, + len: 0, + cap, + arena, + } + } + + /// Number of elements in the vec. + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Whether the vec is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Push an element. Grows if necessary. + #[inline] + pub fn push(&mut self, value: T) { + if self.len == self.cap { + self.grow(); + } + // SAFETY: We just ensured len < cap, so ptr.add(len) is within bounds. + unsafe { + self.ptr.add(self.len).write(value); + } + self.len += 1; + } + + /// Pop the last element. + #[inline] + pub fn pop(&mut self) -> Option { + if self.len == 0 { + return None; + } + self.len -= 1; + // SAFETY: len was > 0, so ptr.add(len) points to the last written element. + Some(unsafe { self.ptr.add(self.len).read() }) + } + + /// Clear the vec without deallocating (just resets length). + #[inline] + pub fn clear(&mut self) { + self.len = 0; + } + + /// Get a reference to the element at `index`. 
+ /// + /// # Panics + /// + /// Panics if `index >= len`. Uses `assert!` (not `debug_assert!`) so + /// bounds enforcement is never compiled out in release builds (#2530). + #[inline] + pub fn get(&self, index: usize) -> &T { + assert!(index < self.len, "ArenaVec index out of bounds"); + // SAFETY: index < len, and all elements up to len are initialized. + unsafe { &*self.ptr.add(index) } + } + + /// Get an optional reference to the element at `index` (issue #2530). + /// + /// Returns `None` instead of panicking when `index >= len`. Prefer this + /// over `get()` when the caller cannot statically guarantee the index is + /// within bounds. + #[inline] + pub fn try_get(&self, index: usize) -> Option<&T> { + if index < self.len { + // SAFETY: index < len, and all elements up to len are initialized. + Some(unsafe { &*self.ptr.add(index) }) + } else { + None + } + } + + /// Get a mutable reference to the element at `index`. + /// + /// # Panics + /// + /// Panics if `index >= len`. Uses `assert!` (not `debug_assert!`) so + /// bounds enforcement is never compiled out in release builds (#2530). + #[inline] + pub fn get_mut(&mut self, index: usize) -> &mut T { + assert!(index < self.len, "ArenaVec index out of bounds"); + // SAFETY: index < len, and all elements up to len are initialized. + unsafe { &mut *self.ptr.add(index) } + } + + /// Get an optional mutable reference to the element at `index` (issue #2530). + /// + /// Returns `None` instead of panicking when `index >= len`. + #[inline] + pub fn try_get_mut(&mut self, index: usize) -> Option<&mut T> { + if index < self.len { + // SAFETY: index < len, and all elements up to len are initialized. + Some(unsafe { &mut *self.ptr.add(index) }) + } else { + None + } + } + + /// Get an immutable slice of all elements. + #[inline] + pub fn as_slice(&self) -> &[T] { + if self.len == 0 { + return &[]; + } + // SAFETY: ptr points to len initialized elements in the arena. 
+ unsafe { std::slice::from_raw_parts(self.ptr, self.len) } + } + + /// Get a mutable slice of all elements. + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + if self.len == 0 { + return &mut []; + } + // SAFETY: ptr points to len initialized elements in the arena. + unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) } + } + + /// Iterate over elements by reference. + #[inline] + pub fn iter(&self) -> std::slice::Iter<'_, T> { + self.as_slice().iter() + } + + /// Swap-remove the element at `index` (O(1) removal). + #[inline] + pub fn swap_remove(&mut self, index: usize) -> T { + assert!(index < self.len, "ArenaVec swap_remove index out of bounds"); + let last = self.len - 1; + // SAFETY: both index and last are < len. + unsafe { + let val = self.ptr.add(index).read(); + if index != last { + let last_val = self.ptr.add(last).read(); + self.ptr.add(index).write(last_val); + } + self.len -= 1; + val + } + } + + /// Drain all elements, returning an iterator over them. + /// After drain, the vec is empty. + pub fn drain(&mut self) -> ArenaVecDrain<'_, T> { + let len = self.len; + self.len = 0; + ArenaVecDrain { + ptr: self.ptr, + pos: 0, + len, + _marker: std::marker::PhantomData, + } + } + + /// Extend from a slice. + pub fn extend_from_slice(&mut self, slice: &[T]) { + for &item in slice { + self.push(item); + } + } + + /// Sort elements using the provided comparison function. + pub fn sort_by(&mut self, compare: F) + where + F: FnMut(&T, &T) -> std::cmp::Ordering, + { + self.as_mut_slice().sort_by(compare); + } + + /// Grow the allocation by doubling capacity (minimum 8). + /// + /// Allocates a new region from the arena and copies existing elements. + /// The old region is leaked in the arena -- reclaimed on reset(). 
+ fn grow(&mut self) { + let new_cap = if self.cap == 0 { 8 } else { self.cap * 2 }; + let new_ptr = self.arena.alloc::(new_cap); + if self.len > 0 { + // SAFETY: copying len elements from old region (ptr, len elements) + // to new region (new_ptr, new_cap >= len elements). No overlap + // because the arena only bumps forward. + unsafe { + std::ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len); + } + } + // Old region is leaked -- reclaimed on arena reset. + self.ptr = new_ptr; + self.cap = new_cap; + } +} + +/// Index by usize for convenience (read-only). +impl<'a, T: Copy> std::ops::Index for ArenaVec<'a, T> { + type Output = T; + + #[inline] + fn index(&self, index: usize) -> &T { + self.get(index) + } +} + +/// Drain iterator for `ArenaVec`. +pub struct ArenaVecDrain<'a, T: Copy> { + ptr: *mut T, + pos: usize, + len: usize, + _marker: std::marker::PhantomData<&'a T>, +} + +impl Iterator for ArenaVecDrain<'_, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.pos >= self.len { + return None; + } + // SAFETY: pos < len, and all elements were initialized before drain. + let val = unsafe { self.ptr.add(self.pos).read() }; + self.pos += 1; + Some(val) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let remaining = self.len - self.pos; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for ArenaVecDrain<'_, T> {} diff --git a/crates/khive-hnsw/src/arena/mod.rs b/crates/khive-hnsw/src/arena/mod.rs new file mode 100644 index 00000000..c0d13881 --- /dev/null +++ b/crates/khive-hnsw/src/arena/mod.rs @@ -0,0 +1,30 @@ +//! Bump arena allocator for zero-allocation HNSW search. +//! +//! Provides a fixed-size memory arena with O(1) reset between queries. +//! All per-search allocations (candidates heap, results heap, batch buffer, +//! result buffer) bump from this arena instead of the global allocator. +//! +//! # Design +//! +//! The arena pre-allocates a configurable slab (default 1 MiB). Within a +//! 
single search query, allocations bump a pointer forward. Between queries, +//! `reset()` sets the pointer back to zero -- O(1), no deallocation, no +//! destructors, no zeroing. +//! +//! # Thread Safety +//! +//! The arena is `!Send` and `!Sync` by design. For concurrent search, each +//! thread should own its own `SearchArena` (via `thread_local!` or explicit +//! per-thread allocation). + +#[allow(clippy::module_inception)] +mod arena; +mod arena_heap; +mod arena_vec; + +pub use arena::SearchArena; +pub use arena_heap::ArenaBinaryHeap; +pub use arena_vec::ArenaVec; + +#[cfg(test)] +mod tests; diff --git a/crates/khive-hnsw/src/arena/tests.rs b/crates/khive-hnsw/src/arena/tests.rs new file mode 100644 index 00000000..edd84182 --- /dev/null +++ b/crates/khive-hnsw/src/arena/tests.rs @@ -0,0 +1,440 @@ +//! Tests for the search arena allocator. + +use super::*; + +// ========================================================================= +// SearchArena tests +// ========================================================================= + +#[test] +fn test_arena_alloc_and_reset() { + let arena = SearchArena::new(4096); + assert_eq!(arena.bytes_used(), 0); + + // Allocate some memory + let _ptr: *mut u64 = arena.alloc::(10); + assert!(arena.bytes_used() > 0); + let used_after_alloc = arena.bytes_used(); + + // Allocate more + let _ptr2: *mut u32 = arena.alloc::(20); + assert!(arena.bytes_used() > used_after_alloc); + + // Reset -- O(1), reclaims all memory + arena.reset(); + assert_eq!(arena.bytes_used(), 0); + + // Can allocate again after reset + let _ptr3: *mut u64 = arena.alloc::(10); + assert!(arena.bytes_used() > 0); +} + +#[test] +fn test_arena_overflow_grows() { + // Small arena that will need to grow + let arena = SearchArena::new(1024); // minimum size + let initial_cap = arena.capacity(); + + // Allocate more than capacity + let _ptr: *mut u8 = arena.alloc::(2048); + + // Arena should have grown + assert!(arena.capacity() >= 2048); + 
assert!(arena.capacity() > initial_cap); +} + +#[test] +fn test_arena_reset_reuse_cycle() { + let arena = SearchArena::new(4096); + + for _ in 0..100 { + // Simulate a search query: allocate various buffers + let _candidates: *mut (f32, usize) = arena.alloc::<(f32, usize)>(64); + let _results: *mut (f32, usize) = arena.alloc::<(f32, usize)>(64); + let _batch: *mut (usize, usize) = arena.alloc::<(usize, usize)>(32); + + assert!(arena.bytes_used() > 0); + + // Reset between queries + arena.reset(); + assert_eq!(arena.bytes_used(), 0); + } +} + +#[test] +fn test_arena_alignment() { + let arena = SearchArena::new(4096); + + // Allocate a u8 to offset the pointer + let _: *mut u8 = arena.alloc::(1); + + // Allocate a u64 -- should be aligned to 8 bytes + let ptr: *mut u64 = arena.alloc::(1); + assert_eq!(ptr as usize % std::mem::align_of::(), 0); + + // Allocate a u128 -- should be aligned to 16 bytes + let ptr128: *mut u128 = arena.alloc::(1); + assert_eq!(ptr128 as usize % std::mem::align_of::(), 0); +} + +// ========================================================================= +// ArenaVec tests +// ========================================================================= + +#[test] +fn test_arena_vec_push_pop() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 4); + + assert!(vec.is_empty()); + assert_eq!(vec.len(), 0); + + vec.push(10); + vec.push(20); + vec.push(30); + + assert_eq!(vec.len(), 3); + assert!(!vec.is_empty()); + + assert_eq!(vec.pop(), Some(30)); + assert_eq!(vec.pop(), Some(20)); + assert_eq!(vec.pop(), Some(10)); + assert_eq!(vec.pop(), None); + assert!(vec.is_empty()); +} + +#[test] +fn test_arena_vec_growth() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 2); + + // Push beyond initial capacity + for i in 0..100 { + vec.push(i); + } + + assert_eq!(vec.len(), 100); + for i in 0..100 { + assert_eq!(*vec.get(i), i); + } +} + +#[test] +fn test_arena_vec_clear() { + let arena = 
SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 8); + + vec.push(1); + vec.push(2); + vec.push(3); + vec.clear(); + + assert!(vec.is_empty()); + assert_eq!(vec.len(), 0); + + // Can push again after clear + vec.push(4); + assert_eq!(vec.len(), 1); + assert_eq!(*vec.get(0), 4); +} + +#[test] +fn test_arena_vec_iter() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 8); + + vec.push(10); + vec.push(20); + vec.push(30); + + let collected: Vec = vec.iter().copied().collect(); + assert_eq!(collected, vec![10, 20, 30]); +} + +#[test] +fn test_arena_vec_drain() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 8); + + vec.push(1); + vec.push(2); + vec.push(3); + + let drained: Vec = vec.drain().collect(); + assert_eq!(drained, vec![1, 2, 3]); + assert!(vec.is_empty()); +} + +#[test] +fn test_arena_vec_as_slice() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 8); + + vec.push(5); + vec.push(10); + vec.push(15); + + assert_eq!(vec.as_slice(), &[5, 10, 15]); +} + +#[test] +fn test_arena_vec_index() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 8); + + vec.push(100); + vec.push(200); + + assert_eq!(vec[0], 100); + assert_eq!(vec[1], 200); +} + +#[test] +fn test_arena_vec_sort_by() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 8); + + vec.push(30); + vec.push(10); + vec.push(20); + + vec.sort_by(|a, b| a.cmp(b)); + assert_eq!(vec.as_slice(), &[10, 20, 30]); +} + +#[test] +fn test_arena_vec_swap_remove() { + let arena = SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 8); + + vec.push(10); + vec.push(20); + vec.push(30); + + let removed = vec.swap_remove(0); + assert_eq!(removed, 10); + assert_eq!(vec.len(), 2); + // 30 should now be at index 0 (swapped from last) + assert_eq!(*vec.get(0), 30); + assert_eq!(*vec.get(1), 20); +} + +// 
========================================================================= +// ArenaBinaryHeap tests +// ========================================================================= + +#[test] +fn test_arena_heap_max_ordering() { + let arena = SearchArena::new(4096); + let mut heap = ArenaBinaryHeap::new(&arena, 8); + + heap.push(10); + heap.push(30); + heap.push(20); + heap.push(5); + heap.push(25); + + // Max-heap: should pop in descending order + assert_eq!(heap.pop(), Some(30)); + assert_eq!(heap.pop(), Some(25)); + assert_eq!(heap.pop(), Some(20)); + assert_eq!(heap.pop(), Some(10)); + assert_eq!(heap.pop(), Some(5)); + assert_eq!(heap.pop(), None); +} + +#[test] +fn test_arena_heap_min_ordering_via_reverse() { + use std::cmp::Reverse; + + let arena = SearchArena::new(4096); + let mut heap = ArenaBinaryHeap::new(&arena, 8); + + heap.push(Reverse(10)); + heap.push(Reverse(30)); + heap.push(Reverse(20)); + heap.push(Reverse(5)); + heap.push(Reverse(25)); + + // Min-heap via Reverse: should pop in ascending order + assert_eq!(heap.pop(), Some(Reverse(5))); + assert_eq!(heap.pop(), Some(Reverse(10))); + assert_eq!(heap.pop(), Some(Reverse(20))); + assert_eq!(heap.pop(), Some(Reverse(25))); + assert_eq!(heap.pop(), Some(Reverse(30))); + assert_eq!(heap.pop(), None); +} + +#[test] +fn test_arena_heap_peek() { + let arena = SearchArena::new(4096); + let mut heap = ArenaBinaryHeap::new(&arena, 8); + + assert_eq!(heap.peek(), None); + + heap.push(10); + assert_eq!(heap.peek(), Some(&10)); + + heap.push(20); + assert_eq!(heap.peek(), Some(&20)); + + heap.push(5); + assert_eq!(heap.peek(), Some(&20)); // 20 is still max +} + +#[test] +fn test_arena_heap_clear() { + let arena = SearchArena::new(4096); + let mut heap = ArenaBinaryHeap::new(&arena, 8); + + heap.push(1); + heap.push(2); + heap.push(3); + heap.clear(); + + assert!(heap.is_empty()); + assert_eq!(heap.len(), 0); + assert_eq!(heap.peek(), None); +} + +#[test] +fn test_arena_heap_drain() { + let arena = 
SearchArena::new(4096); + let mut heap = ArenaBinaryHeap::new(&arena, 8); + + heap.push(10); + heap.push(20); + heap.push(30); + + let mut drained: Vec = heap.drain().collect(); + drained.sort(); + assert_eq!(drained, vec![10, 20, 30]); + assert!(heap.is_empty()); +} + +// ========================================================================= +// Integration: HNSW-like usage patterns +// ========================================================================= + +#[test] +fn test_hnsw_search_pattern() { + use crate::distance::OrderedF32; + + let arena = SearchArena::new(4096); + + // Simulate the HNSW search pattern: + // candidates = min-heap, results = max-heap + + let mut candidates: ArenaBinaryHeap> = + ArenaBinaryHeap::new(&arena, 64); + let mut results: ArenaBinaryHeap<(OrderedF32, usize)> = ArenaBinaryHeap::new(&arena, 64); + let mut result_buf: ArenaVec<(f32, usize)> = ArenaVec::new(&arena, 64); + + // Insert entry points + candidates.push(std::cmp::Reverse((OrderedF32(0.5), 0))); + results.push((OrderedF32(0.5), 0)); + + candidates.push(std::cmp::Reverse((OrderedF32(0.3), 1))); + results.push((OrderedF32(0.3), 1)); + + candidates.push(std::cmp::Reverse((OrderedF32(0.7), 2))); + results.push((OrderedF32(0.7), 2)); + + // Pop closest candidate (min-heap) + let closest = candidates.pop().unwrap(); + assert_eq!(closest.0 .0 .0, 0.3); // OrderedF32(0.3) + + // Peek worst result (max-heap) + let worst = results.peek().unwrap(); + assert_eq!(worst.0 .0, 0.7); // OrderedF32(0.7) + + // Drain results into result_buf + for (dist, id) in results.drain() { + result_buf.push((dist.0, id)); + } + result_buf.sort_by(|a, b| OrderedF32(a.0).cmp(&OrderedF32(b.0))); + + assert_eq!(result_buf.len(), 3); + assert_eq!(result_buf[0].0, 0.3); + assert_eq!(result_buf[1].0, 0.5); + assert_eq!(result_buf[2].0, 0.7); + + // Reset arena for next query + arena.reset(); + assert_eq!(arena.bytes_used(), 0); +} + +#[test] +fn test_arena_vec_extend_from_slice() { + let arena = 
SearchArena::new(4096); + let mut vec = ArenaVec::new(&arena, 4); + + vec.extend_from_slice(&[1, 2, 3, 4, 5]); + assert_eq!(vec.as_slice(), &[1, 2, 3, 4, 5]); +} + +#[test] +fn test_arena_vec_zero_capacity() { + let arena = SearchArena::new(4096); + let mut vec: ArenaVec = ArenaVec::new(&arena, 0); + + assert!(vec.is_empty()); + vec.push(42); + assert_eq!(vec.len(), 1); + assert_eq!(vec[0], 42); +} + +#[test] +fn test_arena_heap_single_element() { + let arena = SearchArena::new(4096); + let mut heap = ArenaBinaryHeap::new(&arena, 4); + + heap.push(42); + assert_eq!(heap.peek(), Some(&42)); + assert_eq!(heap.pop(), Some(42)); + assert!(heap.is_empty()); +} + +#[test] +fn test_arena_heap_duplicate_values() { + let arena = SearchArena::new(4096); + let mut heap = ArenaBinaryHeap::new(&arena, 8); + + heap.push(5); + heap.push(5); + heap.push(5); + + assert_eq!(heap.pop(), Some(5)); + assert_eq!(heap.pop(), Some(5)); + assert_eq!(heap.pop(), Some(5)); + assert_eq!(heap.pop(), None); +} + +#[test] +fn test_multiple_reset_cycles_with_collections() { + let arena = SearchArena::new(4096); + + for cycle in 0..50 { + let mut vec = ArenaVec::new(&arena, 8); + let mut heap = ArenaBinaryHeap::new(&arena, 8); + + for i in 0..20 { + vec.push(cycle * 100 + i); + heap.push(cycle * 100 + i); + } + + assert_eq!(vec.len(), 20); + assert_eq!(heap.len(), 20); + + // Verify max element + assert_eq!(heap.peek(), Some(&(cycle * 100 + 19))); + + arena.reset(); + } +} + +#[test] +fn test_arena_default_capacity() { + let arena = SearchArena::with_default_capacity(); + assert!(arena.capacity() >= super::arena::DEFAULT_ARENA_SIZE); +} diff --git a/crates/khive-hnsw/src/checkpoint/integration_tests.rs b/crates/khive-hnsw/src/checkpoint/integration_tests.rs new file mode 100644 index 00000000..666382aa --- /dev/null +++ b/crates/khive-hnsw/src/checkpoint/integration_tests.rs @@ -0,0 +1,181 @@ +use super::*; +use khive_fold::{Checkpoint, CheckpointStore, FoldContext, InMemoryCheckpointStore}; +use 
khive_types::Hash32; +use uuid::Uuid; + +fn test_hash() -> Hash32 { + Hash32::from_bytes(*blake3::hash(b"hnsw checkpoint test").as_bytes()) +} + +fn make_id(seed: u8) -> NodeId { + NodeId::new([seed; 16]) +} + +fn sample_snapshot() -> HnswSnapshot { + HnswSnapshot { + vector_count: 0, + total_nodes: 1, + live_nodes: 1, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(make_id(1)), + config: HnswCheckpointConfig { + m: 16, + ef_construction: 200, + metric: "cosine".to_string(), + }, + indexed_ids: vec![make_id(1)], + tombstoned_ids: vec![], + layers: vec![vec![(make_id(1), vec![])]], + vectors: vec![], + } +} + +fn sample_snapshot_with_tombstones() -> HnswSnapshot { + let id1 = make_id(1); + let id2 = make_id(2); + HnswSnapshot { + vector_count: 0, + total_nodes: 2, + live_nodes: 1, + tombstone_count: 1, + max_layer: 0, + entry_point: Some(id1), + config: HnswCheckpointConfig { + m: 16, + ef_construction: 200, + metric: "cosine".to_string(), + }, + indexed_ids: vec![id1, id2], + tombstoned_ids: vec![id2], + layers: vec![vec![(id1, vec![id2]), (id2, vec![id1])]], + vectors: vec![], + } +} + +#[test] +fn create_hnsw_checkpoint() { + let snap = sample_snapshot(); + let checkpoint: HnswCheckpoint = Checkpoint::new( + "hnsw_test:ckpt-1", + snap, + Uuid::new_v4(), + test_hash(), + 100, + FoldContext::new(), + 1, + ); + + assert_eq!(checkpoint.state.total_nodes, 1); + assert_eq!(checkpoint.state.live_nodes, 1); + assert_eq!(checkpoint.entries_processed, 100); + assert_eq!(checkpoint.fold_version, 1); +} + +#[test] +fn create_hnsw_checkpoint_with_tombstones() { + let snap = sample_snapshot_with_tombstones(); + let checkpoint: HnswCheckpoint = Checkpoint::new( + "hnsw_test:ckpt-1", + snap, + Uuid::new_v4(), + test_hash(), + 100, + FoldContext::new(), + 1, + ); + + assert_eq!(checkpoint.state.total_nodes, 2); + assert_eq!(checkpoint.state.live_nodes, 1); + assert_eq!(checkpoint.state.tombstone_count, 1); + assert_eq!(checkpoint.state.tombstoned_ids.len(), 1); +} + 
+#[test] +fn store_and_load_hnsw_checkpoint() { + let store: HnswCheckpointStore = InMemoryCheckpointStore::new(); + let snap = sample_snapshot(); + + let checkpoint: HnswCheckpoint = Checkpoint::new( + "hnsw_idx:ckpt-1", + snap, + Uuid::new_v4(), + test_hash(), + 50, + FoldContext::new(), + 1, + ); + + store.save(&checkpoint).expect("save"); + + let loaded = store + .load("hnsw_idx:ckpt-1") + .expect("load") + .expect("should exist"); + + assert_eq!(loaded.state.total_nodes, 1); + assert_eq!(loaded.state.live_nodes, 1); + assert_eq!(loaded.state.config.m, 16); + assert_eq!(loaded.state.config.metric, "cosine"); + assert_eq!(loaded.entries_processed, 50); +} + +#[test] +fn store_and_load_checkpoint_with_tombstones() { + let store: HnswCheckpointStore = InMemoryCheckpointStore::new(); + let snap = sample_snapshot_with_tombstones(); + + let checkpoint: HnswCheckpoint = Checkpoint::new( + "hnsw_idx:ckpt-tomb", + snap, + Uuid::new_v4(), + test_hash(), + 50, + FoldContext::new(), + 1, + ); + + store.save(&checkpoint).expect("save"); + + let loaded = store + .load("hnsw_idx:ckpt-tomb") + .expect("load") + .expect("should exist"); + + assert_eq!(loaded.state.total_nodes, 2); + assert_eq!(loaded.state.live_nodes, 1); + assert_eq!(loaded.state.tombstone_count, 1); + assert!(loaded.state.verify().is_ok()); +} + +#[test] +fn load_latest_hnsw_checkpoint() { + let store: HnswCheckpointStore = InMemoryCheckpointStore::new(); + + for i in 0..3 { + let mut snap = sample_snapshot(); + snap.total_nodes = (i + 1) * 100; + snap.live_nodes = (i + 1) * 100; + + let checkpoint: HnswCheckpoint = Checkpoint::new( + format!("hnsw_idx:ckpt-{i}"), + snap, + Uuid::new_v4(), + test_hash(), + (i + 1) * 10, + FoldContext::new(), + 1, + ); + store.save(&checkpoint).expect("save"); + std::thread::sleep(std::time::Duration::from_millis(10)); + } + + let latest = store + .load_latest("hnsw_idx") + .expect("load_latest") + .expect("should exist"); + + assert_eq!(latest.state.total_nodes, 300); + 
assert_eq!(latest.state.live_nodes, 300); + assert_eq!(latest.entries_processed, 30); +} diff --git a/crates/khive-hnsw/src/checkpoint/mod.rs b/crates/khive-hnsw/src/checkpoint/mod.rs new file mode 100644 index 00000000..1eb88c50 --- /dev/null +++ b/crates/khive-hnsw/src/checkpoint/mod.rs @@ -0,0 +1,423 @@ +//! HNSW index checkpointing using khive-fold checkpoint system. +//! +//! Provides periodic snapshots of the HNSW index for crash recovery +//! and incremental rebuilds. +//! +//! # Architecture +//! +//! The snapshot types ([`HnswSnapshot`], [`HnswCheckpointConfig`]) are always +//! available and carry no extra dependencies. They are plain serializable data. +//! +//! When the `checkpoint` feature is enabled, this module also provides type +//! aliases that integrate with `khive-fold`'s [`Checkpoint`] and +//! [`InMemoryCheckpointStore`] for a complete checkpoint lifecycle. +//! +//! ```text +//! HnswIndex ──snapshot──> HnswSnapshot ──wrap──> Checkpoint +//! │ +//! CheckpointStore::save(...) +//! ``` +//! +//! # Tombstone Tracking +//! +//! Snapshots track both live and tombstoned nodes to ensure accurate restore: +//! - `total_nodes`: All nodes (live + tombstoned) +//! - `live_nodes`: Non-tombstoned nodes only +//! - `tombstone_count`: Number of tombstoned nodes +//! - `tombstoned_ids`: IDs of tombstoned vectors for restore +//! +//! The invariant `total_nodes == live_nodes + tombstone_count` is enforced +//! via the [`HnswSnapshot::verify`] method. +//! +//! # Determinism +//! +//! All ID lists (`indexed_ids`, `tombstoned_ids`) and layer node entries are +//! stored in sorted order by NodeId bytes to ensure deterministic snapshots +//! across runs. This is critical for: +//! - Reproducible checkpoint hashes +//! - Stable index-based encoding (e.g., tombstone bitsets) +//! - Test reproducibility +//! +//! Use [`HnswSnapshot::canonicalize`] to ensure snapshots are in canonical form +//! 
before serialization, or [`HnswSnapshot::is_canonical`] to verify ordering. + +use std::collections::HashSet; + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::config::DistanceMetric; +use crate::NodeId; + +/// Errors that can occur during snapshot verification. +#[derive(Error, Debug, Clone, PartialEq, Eq)] +pub enum SnapshotError { + /// Node count fields are inconsistent. + #[error( + "inconsistent counts: total_nodes ({total}) != live_nodes ({live}) + tombstone_count ({tombstones})" + )] + InconsistentCounts { + /// Total nodes reported. + total: usize, + /// Live nodes reported. + live: usize, + /// Tombstones reported. + tombstones: usize, + }, + + /// indexed_ids length doesn't match total_nodes. + #[error("indexed_ids count mismatch: expected {expected}, got {actual}")] + IdCountMismatch { + /// Expected count (total_nodes). + expected: usize, + /// Actual indexed_ids length. + actual: usize, + }, + + /// tombstoned_ids length doesn't match tombstone_count. + #[error("tombstoned_ids count mismatch: expected {expected}, got {actual}")] + TombstoneIdCountMismatch { + /// Expected count (tombstone_count). + expected: usize, + /// Actual tombstoned_ids length. + actual: usize, + }, + + /// Tombstoned ID not found in indexed_ids. + #[error("tombstoned id {id:?} not found in indexed_ids")] + TombstoneNotInIndex { + /// The missing tombstone ID. + id: NodeId, + }, +} + +/// Sort NodeIds by their byte representation for deterministic ordering. +/// +/// This ensures consistent ordering across runs regardless of HashMap iteration order, +/// which is critical for reproducible checkpoint hashes and stable index-based encodings. +#[inline] +pub(crate) fn sort_ids(ids: &mut [NodeId]) { + ids.sort_by(|a, b| a.as_bytes().cmp(b.as_bytes())); +} + +/// Helper for serde skip_serializing_if on legacy vector_count field. +fn is_zero(val: &usize) -> bool { + *val == 0 +} + +/// Serializable snapshot of HNSW index state. 
+/// +/// Captures enough information to reconstruct the index without +/// re-indexing all vectors from scratch. +/// +/// # Backward Compatibility +/// +/// This struct maintains backward compatibility with v1 snapshots that only +/// had `vector_count`. When deserializing old snapshots: +/// - `vector_count` is read and used to populate `total_nodes`/`live_nodes` +/// - New tombstone fields default to empty/zero +/// - Missing `vectors` field defaults to empty (old snapshots require external vector supply) +/// +/// Call [`HnswSnapshot::normalize`] after deserialization to ensure consistent state. +/// +/// # Warm Start +/// +/// When `vectors` is non-empty the snapshot is self-contained: call +/// [`HnswIndex::restore_from_snapshot_embedded`] to restore without supplying +/// an external vector map. Snapshots produced by [`HnswIndex::snapshot`] +/// always include the full f32 vector data. +/// +/// The estimated size overhead is `dimensions × 4 bytes × node_count`. +/// For 384-dim embeddings with 10 K nodes this is ~15 MB — well within +/// typical checkpoint budgets. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HnswSnapshot { + /// Legacy field for backward compatibility with v1 snapshots. + /// New code should use `total_nodes` and `live_nodes` instead. + #[serde(default, skip_serializing_if = "is_zero")] + pub vector_count: usize, + + /// Total number of nodes (including tombstones). + #[serde(default)] + pub total_nodes: usize, + + /// Number of live (non-tombstoned) nodes. + #[serde(default)] + pub live_nodes: usize, + + /// Number of tombstoned nodes. + #[serde(default)] + pub tombstone_count: usize, + + /// Maximum layer in the graph. + pub max_layer: usize, + + /// Entry point node ID (if any). + pub entry_point: Option, + + /// Index configuration at checkpoint time. + pub config: HnswCheckpointConfig, + + /// IDs of all indexed vectors (for verification on restore). + /// Sorted by byte representation for deterministic ordering. 
+ pub indexed_ids: Vec, + + /// IDs of tombstoned vectors. + /// Sorted by byte representation for deterministic ordering. + #[serde(default)] + pub tombstoned_ids: Vec, + + /// Graph edges per layer: `layer -> [(node_id, [neighbor_ids])]`. + /// Node entries within each layer are sorted by NodeId bytes. + pub layers: Vec)>>, + + /// Embedded f32 vector data for self-contained warm-start snapshots. + /// + /// Maps each `NodeId` to its raw embedding vector. When non-empty, the + /// snapshot is self-contained and can be restored via + /// [`HnswIndex::restore_from_snapshot_embedded`] without supplying a + /// separate vector map. + /// + /// Defaults to empty for backward compatibility with snapshots that + /// pre-date this field (those require an external vector map). + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub vectors: Vec<(NodeId, Vec)>, +} + +/// Subset of [`super::HnswConfig`] relevant for checkpoint compatibility. +/// +/// Stored as simple values (e.g. `metric` as `String`) so that checkpoints +/// remain deserializable even if the enum representation changes. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct HnswCheckpointConfig { + /// Maximum connections per node per layer (M). + pub m: usize, + /// Size of dynamic candidate list during construction. + pub ef_construction: usize, + /// Distance metric name (e.g. `"cosine"`, `"dot"`, `"euclidean"`). + pub metric: String, +} + +impl HnswCheckpointConfig { + /// Create a checkpoint config from the full [`super::HnswConfig`]. + pub fn from_hnsw_config(config: &super::config::HnswConfig) -> Self { + Self { + m: config.m, + ef_construction: config.ef_construction, + metric: metric_to_string(&config.metric), + } + } +} + +impl HnswSnapshot { + /// Check if this snapshot is compatible with the given config. + /// + /// Two configs are compatible when `m`, `ef_construction`, and `metric` + /// all match. 
Loading a snapshot into an index with incompatible + /// parameters would produce incorrect search results. + pub fn is_compatible(&self, config: &HnswCheckpointConfig) -> bool { + self.config == *config + } + + /// Get the number of live (non-tombstoned) vectors in this snapshot. + /// + /// For backward compatibility, this returns `live_nodes` which represents + /// the same semantic meaning as the legacy `vector_count` (all vectors + /// were "live" before tombstone support). + pub fn len(&self) -> usize { + self.live_nodes + } + + /// Get the total number of nodes (including tombstones). + pub fn total_len(&self) -> usize { + self.total_nodes + } + + /// Get the number of tombstoned nodes. + pub fn tombstone_count(&self) -> usize { + self.tombstone_count + } + + /// Returns `true` if the snapshot contains no live vectors. + pub fn is_empty(&self) -> bool { + self.live_nodes == 0 + } + + /// Normalize the snapshot after deserialization. + /// + /// This handles backward compatibility with v1 snapshots that only + /// had `vector_count`. If `total_nodes` is 0 but `vector_count` > 0 + /// or `indexed_ids` is non-empty, the counts are populated from + /// available data. + /// + /// Call this after deserializing a snapshot of unknown version. 
+ pub fn normalize(&mut self) { + // Handle v1 -> v2 migration + if self.total_nodes == 0 { + if self.vector_count > 0 { + // V1 snapshot with vector_count + self.total_nodes = self.vector_count; + self.live_nodes = self.vector_count; + self.tombstone_count = 0; + } else if !self.indexed_ids.is_empty() { + // Fallback: infer from indexed_ids + self.total_nodes = self.indexed_ids.len(); + self.live_nodes = self.indexed_ids.len() - self.tombstoned_ids.len(); + self.tombstone_count = self.tombstoned_ids.len(); + } + } + + // Ensure tombstone_count matches tombstoned_ids + if self.tombstone_count == 0 && !self.tombstoned_ids.is_empty() { + self.tombstone_count = self.tombstoned_ids.len(); + } + } + + /// Verify internal consistency of the snapshot. + /// + /// Checks: + /// 1. `total_nodes == live_nodes + tombstone_count` + /// 2. `indexed_ids.len() == total_nodes` + /// 3. `tombstoned_ids.len() == tombstone_count` + /// 4. All tombstoned IDs exist in indexed_ids + /// + /// Returns `Ok(())` if all invariants hold, otherwise returns the + /// first error encountered. 
+ pub fn verify(&self) -> Result<(), SnapshotError> { + // Check count consistency + if self.total_nodes != self.live_nodes + self.tombstone_count { + return Err(SnapshotError::InconsistentCounts { + total: self.total_nodes, + live: self.live_nodes, + tombstones: self.tombstone_count, + }); + } + + // Check indexed_ids matches total_nodes + if self.indexed_ids.len() != self.total_nodes { + return Err(SnapshotError::IdCountMismatch { + expected: self.total_nodes, + actual: self.indexed_ids.len(), + }); + } + + // Check tombstoned_ids matches tombstone_count + if self.tombstoned_ids.len() != self.tombstone_count { + return Err(SnapshotError::TombstoneIdCountMismatch { + expected: self.tombstone_count, + actual: self.tombstoned_ids.len(), + }); + } + + // Check all tombstoned IDs are in indexed_ids + if !self.tombstoned_ids.is_empty() { + let indexed_set: HashSet<_> = self.indexed_ids.iter().collect(); + for id in &self.tombstoned_ids { + if !indexed_set.contains(id) { + return Err(SnapshotError::TombstoneNotInIndex { id: *id }); + } + } + } + + Ok(()) + } + + /// Check if indexed_ids, tombstoned_ids, and layers are in canonical sorted order. + /// + /// Canonical order means all ID lists are sorted by their byte representation. + /// This ensures deterministic serialization and stable index-based encodings. 
+ pub fn is_canonical(&self) -> bool { + // Check indexed_ids are sorted + let ids_sorted = self + .indexed_ids + .windows(2) + .all(|w| w[0].as_bytes() <= w[1].as_bytes()); + + if !ids_sorted { + return false; + } + + // Check tombstoned_ids are sorted + let tombstones_sorted = self + .tombstoned_ids + .windows(2) + .all(|w| w[0].as_bytes() <= w[1].as_bytes()); + + if !tombstones_sorted { + return false; + } + + // Check each layer's nodes are sorted by ID + for layer in &self.layers { + let layer_sorted = layer + .windows(2) + .all(|w| w[0].0.as_bytes() <= w[1].0.as_bytes()); + if !layer_sorted { + return false; + } + } + + true + } + + /// Ensure canonical ordering (idempotent). + /// + /// Sorts `indexed_ids`, `tombstoned_ids`, and layer node entries by their + /// byte representation. This should be called before serializing snapshots + /// to ensure deterministic output. + /// + /// # Note + /// + /// Neighbor lists within each node are intentionally not sorted, as their order + /// may reflect proximity/priority from the HNSW algorithm. Only the top-level + /// node ordering within layers is canonicalized. + pub fn canonicalize(&mut self) { + // Sort indexed IDs + sort_ids(&mut self.indexed_ids); + + // Sort tombstoned IDs + sort_ids(&mut self.tombstoned_ids); + + // Sort layer node order (but preserve neighbor list order within each node) + for layer in &mut self.layers { + layer.sort_by(|(a, _), (b, _)| a.as_bytes().cmp(b.as_bytes())); + } + } +} + +/// Convert a [`DistanceMetric`] to its canonical string representation. +pub(crate) fn metric_to_string(metric: &DistanceMetric) -> String { + match metric { + DistanceMetric::Cosine => "cosine".to_string(), + DistanceMetric::Dot => "dot".to_string(), + DistanceMetric::L2 => "euclidean".to_string(), + // Fall back to debug repr for future variants. 
+ other => format!("{:?}", other).to_lowercase(), + } +} + +// ── Feature-gated fold integration ────────────────────────────────────── + +/// Type alias for HNSW checkpoints using the fold checkpoint system. +/// +/// Wraps an [`HnswSnapshot`] in the generic [`khive_fold::Checkpoint`] +/// envelope which adds checkpoint ID, timestamp, and fold context. +#[cfg(feature = "checkpoint")] +pub type HnswCheckpoint = khive_fold::Checkpoint; + +/// Type alias for an in-memory HNSW checkpoint store. +/// +/// Suitable for testing and development. Production deployments should +/// implement [`khive_fold::CheckpointStore`] with durable storage. +#[cfg(feature = "checkpoint")] +pub type HnswCheckpointStore = khive_fold::InMemoryCheckpointStore; + +// ── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +#[path = "tests.rs"] +mod tests; + +#[cfg(all(test, feature = "checkpoint"))] +#[path = "integration_tests.rs"] +mod checkpoint_integration_tests; diff --git a/crates/khive-hnsw/src/checkpoint/tests.rs b/crates/khive-hnsw/src/checkpoint/tests.rs new file mode 100644 index 00000000..470c4cb8 --- /dev/null +++ b/crates/khive-hnsw/src/checkpoint/tests.rs @@ -0,0 +1,692 @@ +#![allow(clippy::field_reassign_with_default)] + +use super::*; + +fn make_id(seed: u8) -> NodeId { + NodeId::new([seed; 16]) +} + +fn sample_config() -> HnswCheckpointConfig { + HnswCheckpointConfig { + m: 16, + ef_construction: 200, + metric: "cosine".to_string(), + } +} + +fn sample_snapshot() -> HnswSnapshot { + let id1 = make_id(1); + let id2 = make_id(2); + HnswSnapshot { + vector_count: 0, // Not used in v2 + total_nodes: 2, + live_nodes: 2, + tombstone_count: 0, + max_layer: 1, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id1, id2], + tombstoned_ids: vec![], + layers: vec![ + // Layer 0: both nodes connected to each other + vec![(id1, vec![id2]), (id2, vec![id1])], + // Layer 1: only entry point + vec![(id1, vec![])], + ], + vectors: 
vec![], + } +} + +fn sample_snapshot_with_tombstones() -> HnswSnapshot { + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 2, + tombstone_count: 1, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id1, id2, id3], + tombstoned_ids: vec![id2], // id2 is tombstoned + layers: vec![vec![ + (id1, vec![id2, id3]), + (id2, vec![id1]), + (id3, vec![id1]), + ]], + vectors: vec![], + } +} + +#[test] +fn snapshot_creation_and_accessors() { + let snap = sample_snapshot(); + assert_eq!(snap.len(), 2); + assert_eq!(snap.total_len(), 2); + assert_eq!(snap.tombstone_count(), 0); + assert!(!snap.is_empty()); + assert_eq!(snap.max_layer, 1); + assert!(snap.entry_point.is_some()); + assert_eq!(snap.indexed_ids.len(), 2); + assert_eq!(snap.layers.len(), 2); +} + +#[test] +fn snapshot_with_tombstones_accessors() { + let snap = sample_snapshot_with_tombstones(); + assert_eq!(snap.len(), 2); // live nodes + assert_eq!(snap.total_len(), 3); // total including tombstones + assert_eq!(snap.tombstone_count(), 1); + assert!(!snap.is_empty()); +} + +#[test] +fn empty_snapshot() { + let snap = HnswSnapshot { + vector_count: 0, + total_nodes: 0, + live_nodes: 0, + tombstone_count: 0, + max_layer: 0, + entry_point: None, + config: sample_config(), + indexed_ids: vec![], + tombstoned_ids: vec![], + layers: vec![], + + vectors: vec![], + }; + assert!(snap.is_empty()); + assert_eq!(snap.len(), 0); + assert_eq!(snap.total_len(), 0); +} + +// ── Verification tests ─────────────────────────────────────────────── + +#[test] +fn verify_valid_snapshot() { + let snap = sample_snapshot(); + assert!(snap.verify().is_ok()); +} + +#[test] +fn verify_valid_snapshot_with_tombstones() { + let snap = sample_snapshot_with_tombstones(); + assert!(snap.verify().is_ok()); +} + +#[test] +fn verify_inconsistent_counts() { + let mut snap = sample_snapshot(); + snap.tombstone_count = 1; // 
Inconsistent: 2 != 2 + 1 + let err = snap.verify().unwrap_err(); + assert!(matches!(err, SnapshotError::InconsistentCounts { .. })); +} + +#[test] +fn verify_id_count_mismatch() { + let mut snap = sample_snapshot(); + snap.total_nodes = 5; // Mismatch with indexed_ids.len() == 2 + snap.live_nodes = 5; + let err = snap.verify().unwrap_err(); + assert!(matches!(err, SnapshotError::IdCountMismatch { .. })); +} + +#[test] +fn verify_tombstone_id_count_mismatch() { + let mut snap = sample_snapshot_with_tombstones(); + snap.tombstone_count = 2; // Mismatch with tombstoned_ids.len() == 1 + snap.live_nodes = 1; // Adjust to keep total consistent + let err = snap.verify().unwrap_err(); + assert!(matches!( + err, + SnapshotError::TombstoneIdCountMismatch { .. } + )); +} + +#[test] +fn verify_tombstone_not_in_index() { + let mut snap = sample_snapshot_with_tombstones(); + snap.tombstoned_ids = vec![make_id(99)]; // ID not in indexed_ids + let err = snap.verify().unwrap_err(); + assert!(matches!(err, SnapshotError::TombstoneNotInIndex { .. 
})); +} + +// ── Normalization tests ────────────────────────────────────────────── + +#[test] +fn normalize_v1_snapshot() { + // Simulate a v1 snapshot with only vector_count + let mut snap = HnswSnapshot { + vector_count: 5, // V1 field + total_nodes: 0, // Will be populated by normalize + live_nodes: 0, + tombstone_count: 0, + max_layer: 0, + entry_point: None, + config: sample_config(), + indexed_ids: vec![make_id(1), make_id(2), make_id(3), make_id(4), make_id(5)], + tombstoned_ids: vec![], + layers: vec![], + + vectors: vec![], + }; + + snap.normalize(); + + assert_eq!(snap.total_nodes, 5); + assert_eq!(snap.live_nodes, 5); + assert_eq!(snap.tombstone_count, 0); +} + +#[test] +fn normalize_infers_from_indexed_ids() { + let mut snap = HnswSnapshot { + vector_count: 0, + total_nodes: 0, + live_nodes: 0, + tombstone_count: 0, + max_layer: 0, + entry_point: None, + config: sample_config(), + indexed_ids: vec![make_id(1), make_id(2), make_id(3)], + tombstoned_ids: vec![make_id(2)], + layers: vec![], + + vectors: vec![], + }; + + snap.normalize(); + + assert_eq!(snap.total_nodes, 3); + assert_eq!(snap.live_nodes, 2); + assert_eq!(snap.tombstone_count, 1); +} + +// ── Serialization tests ────────────────────────────────────────────── + +#[test] +fn serialization_round_trip() { + let snap = sample_snapshot(); + let json = serde_json::to_string(&snap).expect("serialize"); + let mut restored: HnswSnapshot = serde_json::from_str(&json).expect("deserialize"); + restored.normalize(); + + assert_eq!(restored.total_nodes, snap.total_nodes); + assert_eq!(restored.live_nodes, snap.live_nodes); + assert_eq!(restored.tombstone_count, snap.tombstone_count); + assert_eq!(restored.max_layer, snap.max_layer); + assert_eq!(restored.entry_point, snap.entry_point); + assert_eq!(restored.config, snap.config); + assert_eq!(restored.indexed_ids, snap.indexed_ids); + assert_eq!(restored.tombstoned_ids, snap.tombstoned_ids); + assert_eq!(restored.layers.len(), snap.layers.len()); +} + 
+#[test] +fn serialization_round_trip_with_tombstones() { + let snap = sample_snapshot_with_tombstones(); + let json = serde_json::to_string(&snap).expect("serialize"); + let restored: HnswSnapshot = serde_json::from_str(&json).expect("deserialize"); + + assert_eq!(restored.total_nodes, snap.total_nodes); + assert_eq!(restored.live_nodes, snap.live_nodes); + assert_eq!(restored.tombstone_count, snap.tombstone_count); + assert_eq!(restored.tombstoned_ids, snap.tombstoned_ids); + assert!(restored.verify().is_ok()); +} + +#[test] +fn backward_compat_v1_deserialization() { + // JSON from a v1 snapshot (only has vector_count, not new fields) + // EmbeddingId serializes as 32 hex chars (no dashes) + let v1_json = r#"{ + "vector_count": 2, + "max_layer": 0, + "entry_point": null, + "config": {"m": 16, "ef_construction": 200, "metric": "cosine"}, + "indexed_ids": ["01010101010101010101010101010101", "02020202020202020202020202020202"], + "layers": [] + }"#; + + let mut snap: HnswSnapshot = serde_json::from_str(v1_json).expect("deserialize v1"); + snap.normalize(); + + assert_eq!(snap.total_nodes, 2); + assert_eq!(snap.live_nodes, 2); + assert_eq!(snap.tombstone_count, 0); + assert_eq!(snap.len(), 2); +} + +// ── Compatibility tests ────────────────────────────────────────────── + +#[test] +fn compatibility_matching_config() { + let snap = sample_snapshot(); + assert!(snap.is_compatible(&sample_config())); +} + +#[test] +fn compatibility_different_m() { + let snap = sample_snapshot(); + let other = HnswCheckpointConfig { + m: 32, + ..sample_config() + }; + assert!(!snap.is_compatible(&other)); +} + +#[test] +fn compatibility_different_ef() { + let snap = sample_snapshot(); + let other = HnswCheckpointConfig { + ef_construction: 400, + ..sample_config() + }; + assert!(!snap.is_compatible(&other)); +} + +#[test] +fn compatibility_different_metric() { + let snap = sample_snapshot(); + let other = HnswCheckpointConfig { + metric: "euclidean".to_string(), + ..sample_config() + 
}; + assert!(!snap.is_compatible(&other)); +} + +#[test] +fn from_hnsw_config() { + use super::super::config::HnswConfig; + + let hnsw = HnswConfig::default(); + let ckpt = HnswCheckpointConfig::from_hnsw_config(&hnsw); + assert_eq!(ckpt.m, 20); + assert_eq!(ckpt.ef_construction, 200); + assert_eq!(ckpt.metric, "cosine"); +} + +#[test] +fn from_hnsw_config_variants() { + use super::super::config::{DistanceMetric, HnswConfig}; + + let mut config = HnswConfig::default(); + + config.metric = DistanceMetric::Dot; + assert_eq!( + HnswCheckpointConfig::from_hnsw_config(&config).metric, + "dot" + ); + + config.metric = DistanceMetric::L2; + assert_eq!( + HnswCheckpointConfig::from_hnsw_config(&config).metric, + "euclidean" + ); +} + +#[test] +fn metric_to_string_exhaustive() { + assert_eq!(metric_to_string(&DistanceMetric::Cosine), "cosine"); + assert_eq!(metric_to_string(&DistanceMetric::Dot), "dot"); + assert_eq!(metric_to_string(&DistanceMetric::L2), "euclidean"); +} + +// ── Error display tests ────────────────────────────────────────────── + +#[test] +fn snapshot_error_display() { + let err = SnapshotError::InconsistentCounts { + total: 5, + live: 3, + tombstones: 1, + }; + assert!(err.to_string().contains("inconsistent counts")); + + let err = SnapshotError::IdCountMismatch { + expected: 5, + actual: 3, + }; + assert!(err.to_string().contains("indexed_ids count mismatch")); + + let err = SnapshotError::TombstoneIdCountMismatch { + expected: 2, + actual: 1, + }; + assert!(err.to_string().contains("tombstoned_ids count mismatch")); + + let err = SnapshotError::TombstoneNotInIndex { id: make_id(1) }; + assert!(err.to_string().contains("not found in indexed_ids")); +} + +// ── Canonical ordering tests ────────────────────────────────────────── + +#[test] +fn sort_ids_orders_by_bytes() { + let mut ids = vec![make_id(5), make_id(2), make_id(9), make_id(1)]; + sort_ids(&mut ids); + + assert_eq!(ids[0], make_id(1)); + assert_eq!(ids[1], make_id(2)); + assert_eq!(ids[2], 
make_id(5)); + assert_eq!(ids[3], make_id(9)); +} + +#[test] +fn sort_ids_empty_is_noop() { + let mut ids: Vec = vec![]; + sort_ids(&mut ids); + assert!(ids.is_empty()); +} + +#[test] +fn sort_ids_single_element() { + let mut ids = vec![make_id(42)]; + sort_ids(&mut ids); + assert_eq!(ids.len(), 1); + assert_eq!(ids[0], make_id(42)); +} + +#[test] +fn is_canonical_sorted_snapshot() { + // sample_snapshot() has ids in sorted order (1, 2) and layers also sorted + let snap = sample_snapshot(); + assert!(snap.is_canonical()); +} + +#[test] +fn is_canonical_empty_snapshot() { + let snap = HnswSnapshot { + vector_count: 0, + total_nodes: 0, + live_nodes: 0, + tombstone_count: 0, + max_layer: 0, + entry_point: None, + config: sample_config(), + indexed_ids: vec![], + tombstoned_ids: vec![], + layers: vec![], + + vectors: vec![], + }; + assert!(snap.is_canonical()); +} + +#[test] +fn is_canonical_unsorted_indexed_ids() { + let id1 = make_id(1); + let id2 = make_id(2); + let snap = HnswSnapshot { + vector_count: 0, + total_nodes: 2, + live_nodes: 2, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id2, id1], // Reversed order + tombstoned_ids: vec![], + layers: vec![vec![(id1, vec![id2]), (id2, vec![id1])]], + + vectors: vec![], + }; + assert!(!snap.is_canonical()); +} + +#[test] +fn is_canonical_unsorted_tombstoned_ids() { + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + let snap = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 1, + tombstone_count: 2, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id1, id2, id3], // Sorted + tombstoned_ids: vec![id3, id2], // Reversed order + layers: vec![], + + vectors: vec![], + }; + assert!(!snap.is_canonical()); +} + +#[test] +fn is_canonical_unsorted_layer() { + let id1 = make_id(1); + let id2 = make_id(2); + let snap = HnswSnapshot { + vector_count: 0, + total_nodes: 2, + live_nodes: 2, 
+ tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id1, id2], // Sorted + tombstoned_ids: vec![], + layers: vec![vec![(id2, vec![id1]), (id1, vec![id2])]], // Reversed order + + vectors: vec![], + }; + assert!(!snap.is_canonical()); +} + +#[test] +fn canonicalize_sorts_indexed_ids() { + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + let mut snap = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 3, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id3, id1, id2], // Unsorted + tombstoned_ids: vec![], + layers: vec![], + + vectors: vec![], + }; + + snap.canonicalize(); + + assert_eq!(snap.indexed_ids, vec![id1, id2, id3]); + assert!(snap.is_canonical()); +} + +#[test] +fn canonicalize_sorts_tombstoned_ids() { + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + let mut snap = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 1, + tombstone_count: 2, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id1, id2, id3], + tombstoned_ids: vec![id3, id2], // Unsorted + layers: vec![], + + vectors: vec![], + }; + + snap.canonicalize(); + + assert_eq!(snap.tombstoned_ids, vec![id2, id3]); + assert!(snap.is_canonical()); +} + +#[test] +fn canonicalize_sorts_layers() { + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + let mut snap = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 3, + tombstone_count: 0, + max_layer: 1, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id1, id2, id3], + tombstoned_ids: vec![], + layers: vec![ + // Layer 0: nodes in wrong order + vec![ + (id3, vec![id1, id2]), + (id1, vec![id2, id3]), + (id2, vec![id1, id3]), + ], + // Layer 1: also wrong order + vec![(id2, vec![]), (id1, vec![])], + ], + + vectors: vec![], + }; + + snap.canonicalize(); + + // 
Verify layer 0 is sorted by node ID + assert_eq!(snap.layers[0][0].0, id1); + assert_eq!(snap.layers[0][1].0, id2); + assert_eq!(snap.layers[0][2].0, id3); + + // Verify layer 1 is sorted by node ID + assert_eq!(snap.layers[1][0].0, id1); + assert_eq!(snap.layers[1][1].0, id2); + + assert!(snap.is_canonical()); +} + +#[test] +fn canonicalize_preserves_neighbor_order() { + // Neighbor order should NOT be sorted (reflects proximity from HNSW algorithm) + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + let mut snap = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 3, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id1, id2, id3], + tombstoned_ids: vec![], + layers: vec![vec![ + (id1, vec![id3, id2]), // Neighbors intentionally in non-byte-sorted order + ]], + + vectors: vec![], + }; + + snap.canonicalize(); + + // Neighbor list should be unchanged + assert_eq!(snap.layers[0][0].1, vec![id3, id2]); +} + +#[test] +fn canonicalize_is_idempotent() { + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + let mut snap = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 2, + tombstone_count: 1, + max_layer: 1, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id3, id1, id2], // Unsorted + tombstoned_ids: vec![id3], + layers: vec![ + vec![(id3, vec![id1]), (id1, vec![id2]), (id2, vec![id3])], + vec![(id2, vec![]), (id1, vec![])], + ], + + vectors: vec![], + }; + + snap.canonicalize(); + let after_first = snap.clone(); + + snap.canonicalize(); + let after_second = snap.clone(); + + // Both passes should produce identical results + assert_eq!(after_first.indexed_ids, after_second.indexed_ids); + assert_eq!(after_first.tombstoned_ids, after_second.tombstoned_ids); + assert_eq!(after_first.layers.len(), after_second.layers.len()); + for (l1, l2) in after_first.layers.iter().zip(after_second.layers.iter()) { + 
assert_eq!(l1, l2); + } +} + +#[test] +fn canonical_snapshot_serializes_deterministically() { + // Create two snapshots with same data but different initial order + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + + let mut snap1 = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 3, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id3, id1, id2], + tombstoned_ids: vec![], + layers: vec![vec![(id3, vec![id1]), (id1, vec![id2]), (id2, vec![id3])]], + + vectors: vec![], + }; + + let mut snap2 = HnswSnapshot { + vector_count: 0, + total_nodes: 3, + live_nodes: 3, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: sample_config(), + indexed_ids: vec![id2, id3, id1], // Different initial order + tombstoned_ids: vec![], + layers: vec![vec![(id2, vec![id3]), (id3, vec![id1]), (id1, vec![id2])]], + + vectors: vec![], + }; + + snap1.canonicalize(); + snap2.canonicalize(); + + let json1 = serde_json::to_string(&snap1).expect("serialize"); + let json2 = serde_json::to_string(&snap2).expect("serialize"); + + assert_eq!( + json1, json2, + "Canonical snapshots should serialize identically" + ); +} diff --git a/crates/khive-hnsw/src/config.rs b/crates/khive-hnsw/src/config.rs new file mode 100644 index 00000000..64c16a08 --- /dev/null +++ b/crates/khive-hnsw/src/config.rs @@ -0,0 +1,253 @@ +//! HNSW configuration types. +//! +//! See ADR-003 for recommended parameter values. +//! +//! # RETRIEVAL-05: Embedding Key Validation +//! +//! The current implementation uses `EmbeddingId` (from khive-db) as the key type, +//! which provides type-safe, validated embedding identifiers. The validation occurs +//! at ID construction time (in khive-db), not in HnswConfig. +//! +//! **Design decision**: Validation is NOT duplicated in HnswConfig because: +//! 1. `EmbeddingId` is already a newtype that enforces validity +//! 2. 
Double validation would add overhead without security benefit +//! 3. The type system already prevents invalid keys at compile time +//! +//! If custom key types are needed in the future, add a `K: EmbeddingKey` trait +//! bound with validation methods. + +use serde::{Deserialize, Serialize}; + +use crate::error::{Result, RetrievalError}; + +/// Maximum allowed level in the HNSW graph. +/// Prevents unbounded memory allocation from malformed random values. +/// For 1 billion vectors with typical ml, expected max level is ~16-18. +pub const MAX_LEVEL: usize = 64; + +/// Default threshold for triggering a rebuild (10% tombstones). +/// Aligned with ADR-003: Index Management Strategy. +pub const DEFAULT_REBUILD_THRESHOLD: f64 = 0.10; + +// Re-export from canonical location (foundation/types). +// Canonical variants: Cosine, Dot, L2. +// Serde aliases on canonical handle backward compat: "euclidean" -> L2, "dot_product" -> Dot. +pub use khive_types::vector::DistanceMetric; + +/// HNSW configuration parameters per ADR-003. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HnswConfig { + /// Maximum number of connections per node per layer (M). + /// Higher = better recall, more memory, slower build. + /// Recommended: 16 (small), 32 (medium), 64 (large datasets). + pub m: usize, + + /// Maximum connections for layer 0 (typically 2*M). + /// Layer 0 is densest, needs more connections for good recall. + pub m_max0: usize, + + /// Size of dynamic candidate list during construction. + /// Higher = better graph quality, slower build. + /// Recommended: 100-500. + pub ef_construction: usize, + + /// Normalization factor for level generation: 1/ln(M). + /// Controls how quickly layers thin out. + pub ml: f64, + + /// Search ef (dynamic candidate list size during search). + /// Higher = better recall, slower search. + /// Recommended: 50-200. + pub ef_search: usize, + + /// Vector dimensions (must match embedding model). + /// Default: 384 (see the `Default` impl). 
+ pub dimensions: usize, + + /// Distance metric for similarity computation. + pub metric: DistanceMetric, + + /// Threshold for automatic rebuild (tombstone ratio). + /// When tombstones exceed this ratio, rebuild() is recommended. + pub rebuild_threshold: f64, + + /// Seed for reproducible level generation. + /// If None, uses OS entropy (non-deterministic). + /// If Some(seed), uses seeded RNG for reproducible index structure. + #[serde(default)] + pub seed: Option, + + /// Maximum memory budget in bytes for the index. + /// If None, no memory limit is enforced (default). + /// If Some(limit), inserts that would exceed the budget are rejected + /// with `RetrievalError::BudgetExceeded`. Updates to existing entries + /// bypass the budget check. + #[serde(default)] + pub memory_budget: Option, +} + +impl Default for HnswConfig { + /// Creates default configuration per ADR-003. + /// + /// M=20, ef_construction=200, ef_search=80, dimensions=384. + /// M=20 is optimal for k=10 recall at 384d (empirically measured). + /// ef_search=80 sufficient for <100K corpus; 100 was overprovisioned. + fn default() -> Self { + Self { + m: 20, + m_max0: 40, + ef_construction: 200, + ml: 1.0 / (20.0_f64).ln(), + ef_search: 80, + dimensions: 384, + metric: DistanceMetric::Cosine, + rebuild_threshold: DEFAULT_REBUILD_THRESHOLD, + seed: None, + memory_budget: None, + } + } +} + +impl HnswConfig { + /// Validate configuration invariants that must hold for every index. + pub fn validate(&self) -> Result<()> { + if self.dimensions == 0 { + return Err(RetrievalError::Configuration( + "dimensions: HNSW dimensions must be greater than zero".to_string(), + )); + } + Ok(()) + } + + /// Create config with custom dimensions, returning an error for invalid values. + pub fn try_with_dimensions(dimensions: usize) -> Result { + let config = Self { + dimensions, + ..Default::default() + }; + config.validate()?; + Ok(config) + } + + /// Create config with custom dimensions, keeping ADR-003 defaults. 
+ /// + /// # Panics + /// Panics if `dimensions` is 0. + pub fn with_dimensions(dimensions: usize) -> Self { + Self::try_with_dimensions(dimensions).expect("HNSW dimensions must be > 0") + } + + /// Create config for high recall (slower build, better search). + pub fn high_recall() -> Self { + Self { + m: 32, + m_max0: 64, + ef_construction: 400, + ef_search: 200, + ..Default::default() + } + } + + /// Create config for fast build (faster build, lower recall). + pub fn fast_build() -> Self { + Self { + m: 12, + m_max0: 24, + ef_construction: 100, + ef_search: 50, + ..Default::default() + } + } + + /// Create config optimized for memory efficiency. + pub fn low_memory() -> Self { + Self { + m: 8, + m_max0: 16, + ef_construction: 80, + ef_search: 40, + ..Default::default() + } + } + + /// Set seed for reproducible level generation. + /// + /// With the same seed and insertion order, the index structure + /// will be identical across runs. + #[must_use] + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = Some(seed); + self + } + + /// Set memory budget in bytes. + /// + /// When set, inserts that would cause the estimated memory usage + /// to exceed this limit are rejected with `BudgetExceeded`. + /// Updates to existing entries bypass the budget check. 
+ #[must_use] + pub fn with_memory_budget(mut self, budget: usize) -> Self { + self.memory_budget = Some(budget); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_defaults() { + let config = HnswConfig::default(); + assert_eq!(config.m, 20); + assert_eq!(config.ef_construction, 200); + assert_eq!(config.ef_search, 80); + assert_eq!(config.dimensions, 384); + } + + #[test] + fn test_config_variants() { + let high = HnswConfig::high_recall(); + assert_eq!(high.m, 32); + assert_eq!(high.ef_construction, 400); + + let fast = HnswConfig::fast_build(); + assert_eq!(fast.m, 12); + + let low = HnswConfig::low_memory(); + assert_eq!(low.m, 8); + } + + #[test] + fn test_with_dimensions() { + let config = HnswConfig::with_dimensions(1536); + assert_eq!(config.dimensions, 1536); + assert_eq!(config.m, 20); // Other defaults preserved + } + + #[test] + fn test_try_with_dimensions_rejects_zero() { + let result = HnswConfig::try_with_dimensions(0); + assert!(result.is_err()); + } + + #[test] + #[should_panic(expected = "HNSW dimensions must be > 0")] + fn test_with_dimensions_rejects_zero() { + HnswConfig::with_dimensions(0); + } + + #[test] + #[should_panic(expected = "HNSW configuration must be valid")] + fn test_index_with_config_rejects_zero_dimensions() { + crate::HnswIndex::with_config(HnswConfig { + dimensions: 0, + ..Default::default() + }); + } + + #[test] + fn test_distance_metric_default() { + assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine); + } +} diff --git a/crates/khive-hnsw/src/distance.rs b/crates/khive-hnsw/src/distance.rs new file mode 100644 index 00000000..9f225736 --- /dev/null +++ b/crates/khive-hnsw/src/distance.rs @@ -0,0 +1,368 @@ +//! Distance computation for HNSW. +//! +//! # Formal Verification +//! +//! This implementation corresponds to the formal proofs in +//! `proofs/Lion/Retrieval/Distance.lean`. Key theorems: +//! +//! ## Metric Axioms (Euclidean) +//! - `euclidean_nonneg`: d(x,y) ≥ 0 +//! 
- `euclidean_self`: d(x,x) = 0 +//! - `euclidean_symm`: d(x,y) = d(y,x) +//! - `euclidean_triangle`: d(x,z) ≤ d(x,y) + d(y,z) +//! +//! ## Cosine Properties +//! - `cosine_range`: -1 ≤ cos(x,y) ≤ 1 for unit vectors +//! - `cosine_not_metric`: cosine does NOT satisfy triangle inequality +//! +//! ## Dot Product +//! - `dot_eq_inner`: bridges to Mathlib inner product space +//! +//! ## Distance-Similarity Conversion +//! - `distanceToSimilarity`: sim = 1/(1+d) for Euclidean +//! - `similarity_nonneg`: similarity ≥ 0 +//! - `similarity_bounded`: 0 ≤ sim ≤ 1 for d ≥ 0 + +use super::config::DistanceMetric; + +/// Compute cosine distance from pre-computed dot product and norms. +/// +/// Clamps the cosine similarity to [-1, 1] before converting to distance, +/// preventing out-of-range values caused by floating-point rounding when +/// vectors are not perfectly unit-normalised (RETRIEVAL-M3). +/// +/// Returns a distance in [0, 2] (0 = identical direction, 2 = opposite). +/// Falls back to 1.0 (orthogonal) when the norm product is non-positive or non-finite, or the dot product is NaN. +#[inline] +pub(crate) fn cosine_distance_from_parts(dot: f32, a_norm: f32, b_norm: f32) -> f32 { + let denom = a_norm * b_norm; + if !denom.is_finite() || denom <= 0.0 { + return 1.0; + } + let cosine = (dot / denom).clamp(-1.0, 1.0); + if cosine.is_finite() { + 1.0 - cosine + } else { + 1.0 + } +} + +/// Compute distance between two vectors. +/// Returns distance (lower = more similar) for heap operations. +/// +/// Uses SIMD-accelerated implementations from lattice-embed (ADR-002). 
+#[inline] +pub fn compute_distance( + a: &[f32], + a_norm: f32, + b: &[f32], + b_norm: f32, + metric: DistanceMetric, +) -> f32 { + if !a_norm.is_finite() || !b_norm.is_finite() { + return 1.0; + } + + match metric { + DistanceMetric::Cosine => { + // ADR-002: lattice-embed is the SIMD foundation layer + // + // **PROOF CORRESPONDENCE**: Lion.Retrieval.Cosine.cosine_sim_bounded + // Cosine similarity is bounded: -1 <= cos(x,y) <= 1 for unit vectors + // + // **PROOF CORRESPONDENCE**: Lion.Retrieval.Cosine.cauchy_schwarz + // Cauchy-Schwarz inequality: |x · y| <= ||x|| * ||y|| + let dot = lattice_embed::simd::dot_product(a, b); + cosine_distance_from_parts(dot, a_norm, b_norm) + } + DistanceMetric::Dot => { + // Negate for min-heap (higher dot = lower distance) + // ADR-002: lattice-embed is the SIMD foundation layer + -lattice_embed::simd::dot_product(a, b) + } + DistanceMetric::L2 => { + // ADR-002: lattice-embed is the SIMD foundation layer + // + // **PROOF CORRESPONDENCE**: Lion.Retrieval.Distance.euclidean_nonneg + // Euclidean distance is non-negative: d(x,y) >= 0 + // + // **PROOF CORRESPONDENCE**: Lion.Retrieval.Distance.euclidean_symm + // Euclidean distance is symmetric: d(x,y) = d(y,x) + // + // **PROOF CORRESPONDENCE**: Lion.Retrieval.Distance.euclidean_triangle + // Triangle inequality: d(x,z) <= d(x,y) + d(y,z) + lattice_embed::simd::euclidean_distance(a, b) + } + _ => { + // DistanceMetric is #[non_exhaustive]; fall back to cosine for + // any future variants until explicitly supported. + let dot = lattice_embed::simd::dot_product(a, b); + cosine_distance_from_parts(dot, a_norm, b_norm) + } + } +} + +/// Ordering distance for HNSW-internal comparisons (graph maintenance, neighbor selection). +/// +/// Returns a value that is **monotone** with the true distance: smaller ordering distance +/// ↔ smaller true distance. Callers must NOT interpret the return value as a true distance +/// (e.g., do not pass it to `distance_to_similarity`). 
Use `compute_distance` for output. +/// +/// Optimization: for L2 metric, returns squared Euclidean distance (skips the final `.sqrt()`). +/// For all other metrics, identical to `compute_distance`. +#[inline] +pub(crate) fn compute_ordering_distance( + a: &[f32], + a_norm: f32, + b: &[f32], + b_norm: f32, + metric: DistanceMetric, +) -> f32 { + match metric { + DistanceMetric::L2 => lattice_embed::simd::squared_euclidean_distance(a, b), + other => compute_distance(a, a_norm, b, b_norm, other), + } +} + +/// Convert distance back to similarity score (higher = more similar). +/// +/// **PROOF CORRESPONDENCE**: Lion.Retrieval.Distance.similarity_mono +/// Similarity conversion is monotonically decreasing in distance: +/// d1 < d2 implies sim(d1) > sim(d2) +#[inline] +pub fn distance_to_similarity(dist: f32, metric: DistanceMetric) -> f32 { + match metric { + DistanceMetric::Cosine => 1.0 - dist, + DistanceMetric::Dot => -dist, + DistanceMetric::L2 => 1.0 / (1.0 + dist), + // Fall back to cosine similarity for future variants. + _ => 1.0 - dist, + } +} + +/// Ordered wrapper for f32 to enable use in BinaryHeap. +/// +/// # NaN Handling +/// +/// NaN values are treated as "infinite distance" (greater than all other values). +/// This ensures deterministic ordering and fail-safe behavior when encountering +/// malformed embeddings. Two NaN values compare as equal. 
+/// +/// # Formal Properties +/// +/// - Reflexive: a.cmp(a) = Equal +/// - Antisymmetric: a.cmp(b) = Less implies b.cmp(a) = Greater +/// - Transitive: a < b and b < c implies a < c +/// - Total: for all a, b: a.cmp(b) is defined +#[derive(Clone, Copy, PartialEq)] +pub struct OrderedF32(pub f32); + +impl Eq for OrderedF32 {} + +impl PartialOrd for OrderedF32 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for OrderedF32 { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Handle NaN: treat as greater than all finite values + // This ensures fail-safe behavior: NaN results get pushed to the end + match (self.0.is_nan(), other.0.is_nan()) { + (true, true) => std::cmp::Ordering::Equal, + (true, false) => std::cmp::Ordering::Greater, + (false, true) => std::cmp::Ordering::Less, + (false, false) => { + // SAFETY: Both values are confirmed non-NaN by the match guards above. + // For non-NaN f32 values (including infinity), partial_cmp is total + // and always returns Some. This is a mathematical invariant of IEEE 754. + self.0 + .partial_cmp(&other.0) + .expect("both values are non-NaN, partial_cmp should succeed") + } + } + } +} + +impl OrderedF32 { + /// Check if the wrapped value is NaN. + #[inline] + #[allow(dead_code)] // TODO(#2640): wire or remove when use case is defined + pub fn is_nan(&self) -> bool { + self.0.is_nan() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ========================================================================= + // RETRIEVAL-M3: Cosine distance clamping tests + // ========================================================================= + + /// Verify that cosine_distance_from_parts clamps correctly when fp rounding + /// produces a dot/norm ratio slightly outside [-1, 1]. 
+ #[test] + fn test_cosine_distance_from_parts_clamping() { + // Normal case: identical direction => distance 0 + assert!((cosine_distance_from_parts(1.0, 1.0, 1.0) - 0.0).abs() < 1e-6); + + // Normal case: opposite direction => distance 2 + assert!((cosine_distance_from_parts(-1.0, 1.0, 1.0) - 2.0).abs() < 1e-6); + + // Rounding artifact: dot slightly > denom (should clamp cosine to 1.0 => dist 0) + let dist = cosine_distance_from_parts(1.0000002, 1.0, 1.0); + assert!( + dist >= 0.0, + "distance must be non-negative after clamping, got {dist}" + ); + assert!( + dist <= 2.0, + "distance must be <= 2 after clamping, got {dist}" + ); + + // Zero norms => fallback 1.0 + assert_eq!(cosine_distance_from_parts(0.0, 0.0, 1.0), 1.0); + assert_eq!(cosine_distance_from_parts(0.0, 1.0, 0.0), 1.0); + + // Infinite denom => fallback 1.0 + assert_eq!( + cosine_distance_from_parts(f32::NAN, f32::INFINITY, 1.0), + 1.0 + ); + } + + #[test] + fn test_cosine_distance() { + let a = vec![1.0, 0.0]; + let b = vec![1.0, 0.0]; + let a_norm = 1.0; + let b_norm = 1.0; + + let dist = compute_distance(&a, a_norm, &b, b_norm, DistanceMetric::Cosine); + assert!(dist.abs() < 0.001); // Same vector = 0 distance + + let c = vec![0.0, 1.0]; + let dist = compute_distance(&a, a_norm, &c, 1.0, DistanceMetric::Cosine); + assert!((dist - 1.0).abs() < 0.001); // Orthogonal = 1 distance + } + + #[test] + fn test_euclidean_distance() { + let a = vec![0.0, 0.0]; + let b = vec![3.0, 4.0]; + + let dist = compute_distance(&a, 0.0, &b, 5.0, DistanceMetric::L2); + assert!((dist - 5.0).abs() < 0.001); + } + + #[test] + fn test_dot_product_distance() { + let a = vec![1.0, 2.0]; + let b = vec![2.0, 3.0]; + + let dist = compute_distance(&a, 0.0, &b, 0.0, DistanceMetric::Dot); + // dot = 1*2 + 2*3 = 8, distance = -8 + assert!((dist - (-8.0)).abs() < 0.001); + } + + #[test] + fn test_distance_to_similarity() { + // Cosine: similarity = 1 - distance + assert!((distance_to_similarity(0.2, DistanceMetric::Cosine) - 
0.8).abs() < 0.001); + + // Dot: similarity = -distance + assert!((distance_to_similarity(-5.0, DistanceMetric::Dot) - 5.0).abs() < 0.001); + + // Euclidean: similarity = 1/(1+distance) + assert!((distance_to_similarity(1.0, DistanceMetric::L2) - 0.5).abs() < 0.001); + } + + #[test] + fn test_ordered_f32() { + let a = OrderedF32(1.0); + let b = OrderedF32(2.0); + assert!(a < b); + assert_eq!(a.cmp(&a), std::cmp::Ordering::Equal); + } + + // ========================================================================= + // RETRIEVAL-02: NaN Handling Tests + // ========================================================================= + + #[test] + fn test_ordered_f32_nan_handling() { + let nan = OrderedF32(f32::NAN); + let finite = OrderedF32(1.0); + let infinity = OrderedF32(f32::INFINITY); + let neg_infinity = OrderedF32(f32::NEG_INFINITY); + + // NaN is greater than all finite values + assert!(nan > finite); + assert!(finite < nan); + + // NaN is greater than infinity + assert!(nan > infinity); + assert!(infinity < nan); + + // NaN is greater than negative infinity + assert!(nan > neg_infinity); + + // Two NaNs are equal + let nan2 = OrderedF32(f32::NAN); + assert_eq!(nan.cmp(&nan2), std::cmp::Ordering::Equal); + } + + #[test] + fn test_ordered_f32_sorting_with_nan() { + // When sorting distances for nearest neighbor search, NaN should end up last + let mut distances = [ + OrderedF32(0.5), + OrderedF32(f32::NAN), + OrderedF32(0.1), + OrderedF32(0.9), + OrderedF32(f32::NAN), + OrderedF32(0.3), + ]; + + // Sort ascending (for min-heap behavior) + distances.sort(); + + // Non-NaN values should be at the front, sorted + assert_eq!(distances[0].0, 0.1); + assert_eq!(distances[1].0, 0.3); + assert_eq!(distances[2].0, 0.5); + assert_eq!(distances[3].0, 0.9); + // NaN values should be at the end + assert!(distances[4].is_nan()); + assert!(distances[5].is_nan()); + } + + #[test] + fn test_ordered_f32_deterministic_ordering() { + // Same values should always produce same 
ordering + for _ in 0..10 { + let values = vec![OrderedF32(0.5), OrderedF32(0.5), OrderedF32(0.1)]; + + let mut sorted = values.clone(); + sorted.sort(); + + assert_eq!(sorted[0].0, 0.1); + assert_eq!(sorted[1].0, 0.5); + assert_eq!(sorted[2].0, 0.5); + } + } + + #[test] + fn test_ordered_f32_infinity() { + let a = OrderedF32(f32::INFINITY); + let b = OrderedF32(f32::NEG_INFINITY); + let c = OrderedF32(0.0); + + assert!(a > c); + assert!(b < c); + assert!(a > b); + } +} diff --git a/crates/khive-hnsw/src/error.rs b/crates/khive-hnsw/src/error.rs new file mode 100644 index 00000000..0bea4e97 --- /dev/null +++ b/crates/khive-hnsw/src/error.rs @@ -0,0 +1,83 @@ +//! Error types for the HNSW crate. + +use thiserror::Error; + +/// Category of error — used to decide retry policy. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorKind { + /// Transient — may succeed on retry. + Transient, + /// Permanent — retrying will not help. + Permanent, +} + +/// Errors that can occur during HNSW operations. +#[derive(Error, Debug)] +pub enum RetrievalError { + /// Vector dimension mismatch. + #[error("dimension mismatch: expected {expected}, got {actual}")] + DimensionMismatch { + /// Expected dimensionality. + expected: usize, + /// Actual dimensionality provided. + actual: usize, + }, + + /// Memory budget exceeded on insert. + #[error("memory budget exceeded: current={current_usage}, item={item_size}, limit={limit}")] + BudgetExceeded { + /// Current memory usage in bytes. + current_usage: usize, + /// Estimated cost of the new item in bytes. + item_size: usize, + /// Configured budget in bytes. + limit: usize, + }, + + /// HNSW index operation error. + #[error("hnsw error: {0}")] + Hnsw(String), + + /// Configuration error. + #[error("configuration error: {0}")] + Configuration(String), + + /// Query timed out. + #[error("query timed out after {elapsed_ms}ms")] + QueryTimeout { + /// How long the query ran before timing out. 
+ elapsed_ms: u64, + }, +} + +impl RetrievalError { + /// Create an HNSW error from any displayable value. + pub fn hnsw(msg: impl std::fmt::Display) -> Self { + Self::Hnsw(msg.to_string()) + } + + /// Create a `BudgetExceeded` error. + pub fn budget_exceeded(current_usage: usize, item_size: usize, limit: usize) -> Self { + Self::BudgetExceeded { + current_usage, + item_size, + limit, + } + } + + /// Return the error kind (Transient vs Permanent). + pub fn kind(&self) -> ErrorKind { + match self { + Self::QueryTimeout { .. } => ErrorKind::Transient, + _ => ErrorKind::Permanent, + } + } + + /// Returns true if retrying the operation might succeed. + pub fn is_retryable(&self) -> bool { + self.kind() == ErrorKind::Transient + } +} + +/// Convenience `Result` alias for HNSW operations. +pub type Result = std::result::Result; diff --git a/crates/khive-hnsw/src/index/build_batch.rs b/crates/khive-hnsw/src/index/build_batch.rs new file mode 100644 index 00000000..22097ca0 --- /dev/null +++ b/crates/khive-hnsw/src/index/build_batch.rs @@ -0,0 +1,240 @@ +//! Batch build for HNSW index. +//! +//! Builds an HNSW index from a batch of vectors using a two-phase approach: +//! +//! 1. **Seed phase** (sequential): Insert sqrt(N) nodes normally to establish +//! the upper-layer graph structure and entry point. +//! +//! 2. **Search phase**: For remaining nodes, find neighbors against the frozen +//! seed graph, then merge results sequentially. + +use super::HnswIndex; +use crate::error::{Result, RetrievalError}; +use crate::node::HnswNode; +use crate::NodeId; +use rayon::prelude::*; + +/// Pre-computed neighbor information for a node to be inserted. +/// +/// Produced during the parallel search phase, consumed during the sequential merge. +struct PrecomputedInsert { + id: NodeId, + vector: Vec, + level: usize, + /// Neighbors per layer: (layer, Vec<(distance, internal_id)>). 
+ layer_candidates: Vec<(usize, Vec<(f32, usize)>)>, +} + +impl HnswIndex { + /// Build an index from a batch of vectors. + /// + /// This is faster than individual insertions for large batches because + /// validation, level assignment, and merging are handled in one pass. + /// + /// # Algorithm + /// + /// 1. **Seed phase**: The first `sqrt(N)` nodes are inserted sequentially. + /// This builds the upper-layer structure that guides all subsequent searches. + /// + /// 2. **Level assignment**: Levels for remaining nodes are pre-generated + /// sequentially to preserve deterministic RNG consumption order. + /// + /// 3. **Search**: Each remaining node searches the current frozen graph for + /// its neighbors. + /// + /// 4. **Sequential merge**: Nodes are inserted with their pre-computed + /// neighbors, adding bidirectional connections and updating the graph. + /// + /// # Quality Notes + /// + /// Because the parallel phase searches a "frozen" graph (the seed nodes), + /// the neighbor quality for non-seed nodes depends only on the seed graph, + /// not on other parallel insertions. This means recall may differ slightly + /// from fully sequential construction. In practice, with sqrt(N) seeds the + /// difference is negligible for typical workloads. + /// + /// # Errors + /// + /// Returns an error if any vector has incorrect dimensions or if the memory + /// budget would be exceeded. 
+ pub fn build_batch(&mut self, items: Vec<(NodeId, Vec)>) -> Result<()> { + if items.is_empty() { + return Ok(()); + } + + // Validate dimensions upfront to avoid partial builds + for (id, vector) in &items { + if vector.len() != self.config.dimensions { + return Err(RetrievalError::DimensionMismatch { + expected: self.config.dimensions, + actual: vector.len(), + }); + } + // Check for duplicates against existing index + if self.id_to_internal.contains_key(id) { + return Err(RetrievalError::hnsw(format!( + "build_batch does not support updates: ID {id:?} already exists" + ))); + } + } + + // Budget check for entire batch + if let Some(limit) = self.config.memory_budget { + let current = self.memory_usage(); + let cost_per_node = self.estimate_insert_cost(); + let total_cost = cost_per_node * items.len(); + if current + total_cost > limit { + return Err(RetrievalError::budget_exceeded(current, total_cost, limit)); + } + } + + let n = items.len(); + + // For very small batches, fall back to sequential insertion + if n <= 32 { + for (id, vector) in items { + self.insert(id, vector)?; + } + return Ok(()); + } + + // Phase 1: Sequential seed insertion of sqrt(N) nodes + // These build the upper-layer graph structure + let seed_count = ((n as f64).sqrt() as usize).max(1); + let (seed_items, remaining_items) = items.split_at(seed_count); + + for (id, vector) in seed_items { + self.insert(*id, vector.clone())?; + } + + // Phase 2: Pre-generate levels for remaining nodes in deterministic RNG order. + let mut pending: Vec<(NodeId, Vec, usize)> = Vec::with_capacity(remaining_items.len()); + for (id, vector) in remaining_items { + let level = self.random_level(); + pending.push((*id, vector.clone(), level)); + } + + // Phase 3: Neighbor search + // The graph is frozen during this phase -- only read-only search_layer is called. 
+ let entry_point = self.entry_point; + let current_max_level = self.max_level; + let config_ef = self.config.ef_construction; + let index = &*self; + + let precomputed: Vec = pending + .into_par_iter() + .map(|(id, vector, level)| { + let norm = vector.iter().map(|x| x * x).sum::().sqrt(); + + // Navigate upper layers to find entry region + let ep = match entry_point { + Some(ep) => ep, + None => { + // Should not happen after seed phase, but handle gracefully + return PrecomputedInsert { + id, + vector, + level, + layer_candidates: Vec::new(), + }; + } + }; + + let mut current_nearest = vec![ep]; + + // Search from top layer down to level + 1 (greedy, ef=1) + for l in (level + 1..=current_max_level).rev() { + let nearest = index.search_layer(&vector, norm, ¤t_nearest, 1, l); + if !nearest.is_empty() { + current_nearest = vec![nearest[0].1]; + } + } + + // Search layers from min(level, max_level) down to 0 + let mut layer_candidates = Vec::new(); + for l in (0..=level.min(current_max_level)).rev() { + let candidates = + index.search_layer(&vector, norm, ¤t_nearest, config_ef, l); + + if !candidates.is_empty() { + current_nearest = vec![candidates[0].1]; + } + + layer_candidates.push((l, candidates)); + } + + PrecomputedInsert { + id, + vector, + level, + layer_candidates, + } + }) + .collect(); + + // Phase 4: Sequential merge -- insert nodes with pre-computed neighbors + for pc in precomputed { + self.insert_with_precomputed(pc)?; + } + + Ok(()) + } + + /// Insert a node using pre-computed neighbor candidates. + /// + /// This performs the same operations as `insert_inner`, but skips the + /// neighbor search phase since candidates were already found in parallel. 
+ fn insert_with_precomputed(&mut self, pc: PrecomputedInsert) -> Result<()> { + let internal_id = self.nodes.len(); + + // Select neighbors from pre-computed candidates + let mut layer_neighbors: Vec<(usize, Vec)> = Vec::new(); + for (l, candidates) in &pc.layer_candidates { + let m = if *l == 0 { + self.config.m_max0 + } else { + self.config.m + }; + let neighbors = self.select_neighbors(candidates, m); + layer_neighbors.push((*l, neighbors)); + } + + // Build the node with neighbor lists + let mut node = HnswNode::new(pc.vector, pc.level); + for (l, neighbors) in &layer_neighbors { + while node.neighbors.len() <= *l { + node.neighbors.push(Vec::new()); + } + node.neighbors[*l] = neighbors.clone(); + } + + // Insert into storage (including quantized arena) + self.quantized.push(&node.vector, node.norm); + self.nodes.push(node); + self.id_to_internal.insert(pc.id, internal_id); + self.internal_to_id.push(pc.id); + + // Add bidirectional connections + for (l, neighbors) in layer_neighbors { + let m = if l == 0 { + self.config.m_max0 + } else { + self.config.m + }; + for neighbor_id in neighbors { + self.connect(neighbor_id, internal_id, l); + // Shrink if over m (m is already m_max0 for layer 0, m for upper layers) + self.shrink_connections(neighbor_id, l, m); + } + } + + // Update entry point if new node is at higher level + if pc.level > self.max_level { + self.entry_point = Some(internal_id); + self.max_level = pc.level; + } + + self.additions_since_rebuild += 1; + Ok(()) + } +} diff --git a/crates/khive-hnsw/src/index/insert.rs b/crates/khive-hnsw/src/index/insert.rs new file mode 100644 index 00000000..8c0766c8 --- /dev/null +++ b/crates/khive-hnsw/src/index/insert.rs @@ -0,0 +1,325 @@ +//! Insert operations for HNSW index. 
+ +use crate::NodeId; +use rand::Rng; + +use super::HnswIndex; +use crate::config::MAX_LEVEL; +use crate::distance::compute_ordering_distance; +use crate::error::{Result, RetrievalError}; +use crate::metrics::{self, MetricEvent, MetricValue}; +use crate::node::HnswNode; + +impl HnswIndex { + /// Insert a vector into the index. + /// + /// If the ID already exists, the vector is updated. + /// Returns an error if dimensions don't match. + /// + /// Emits `hnsw.insert.duration_ms`, `hnsw.insert.count`, and + /// `hnsw.index.size` metrics when a sink is attached. + pub fn insert(&mut self, id: NodeId, vector: Vec) -> Result<()> { + let start = std::time::Instant::now(); + + let result = self.insert_inner(id, vector); + + // Emit metrics regardless of success/failure + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_INSERT_DURATION_MS, + value: MetricValue::Histogram(elapsed), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_INSERT_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_INDEX_SIZE, + value: MetricValue::Gauge(self.len_live() as f64), + labels: vec![], + }, + ); + + result + } + + /// Insert a batch of vectors. Returns IDs that failed along with their errors. + /// + /// Callers can hold the write lock once around this call rather than once + /// per record, keeping the lock window bounded to the batch size. + pub fn insert_many( + &mut self, + items: impl IntoIterator)>, + ) -> Vec<(NodeId, crate::error::RetrievalError)> { + let mut failures = Vec::new(); + for (id, vector) in items { + if let Err(e) = self.insert(id, vector) { + failures.push((id, e)); + } + } + failures + } + + /// Inner insert logic (uninstrumented). 
+ pub(super) fn insert_inner(&mut self, id: NodeId, vector: Vec) -> Result<()> { + if vector.len() != self.config.dimensions { + return Err(RetrievalError::DimensionMismatch { + expected: self.config.dimensions, + actual: vector.len(), + }); + } + + // If updating existing node, just update the vector (bypasses budget) + if let Some(&iid) = self.id_to_internal.get(&id) { + self.nodes[iid].update_vector(vector.clone()); + // Update quantized arena to stay in sync + self.quantized.update(iid, &vector, self.nodes[iid].norm); + // Remove from tombstones if it was marked + if iid < self.tombstones.len() && self.tombstones[iid] { + self.tombstones[iid] = false; + self.tombstone_count -= 1; + } + return Ok(()); + } + + // Budget check before allocating a new node + if let Some(limit) = self.config.memory_budget { + let current = self.memory_usage(); + let cost = self.estimate_insert_cost(); + if current + cost > limit { + return Err(RetrievalError::budget_exceeded(current, cost, limit)); + } + } + + let level = self.random_level(); + let node = HnswNode::new(vector.clone(), level); + let query_norm = node.norm; + + // Assign internal ID = next index in vec + let internal_id = self.nodes.len(); + + // First node + if self.nodes.is_empty() { + self.quantized.push(&vector, node.norm); + self.nodes.push(node); + self.id_to_internal.insert(id, internal_id); + self.internal_to_id.push(id); + self.entry_point = Some(internal_id); + self.max_level = level; + self.additions_since_rebuild += 1; + return Ok(()); + } + + let entry_point = self.entry_point.ok_or_else(|| { + RetrievalError::hnsw("HNSW invariant violated: no entry point despite non-empty index") + })?; + let current_max_level = self.max_level; + + // Search from top layer down to level + 1 + let mut current_nearest = vec![entry_point]; + + for l in (level + 1..=current_max_level).rev() { + let nearest = self.search_layer(&vector, query_norm, ¤t_nearest, 1, l); + if !nearest.is_empty() { + current_nearest = 
vec![nearest[0].1]; + } + } + + // Collect neighbors for all layers (using internal usize IDs) + let mut layer_neighbors: Vec<(usize, Vec)> = Vec::new(); + + for l in (0..=level.min(current_max_level)).rev() { + let candidates = self.search_layer( + &vector, + query_norm, + ¤t_nearest, + self.config.ef_construction, + l, + ); + + let m = if l == 0 { + self.config.m_max0 + } else { + self.config.m + }; + let neighbors = self.select_neighbors(&candidates, m); + + if !candidates.is_empty() { + current_nearest = vec![candidates[0].1]; + } + + layer_neighbors.push((l, neighbors)); + } + + // Insert node first (including quantized arena) + let mut new_node = node; + for (l, neighbors) in &layer_neighbors { + while new_node.neighbors.len() <= *l { + new_node.neighbors.push(Vec::new()); + } + new_node.neighbors[*l] = neighbors.clone(); + } + self.quantized.push(&new_node.vector, new_node.norm); + self.nodes.push(new_node); + self.id_to_internal.insert(id, internal_id); + self.internal_to_id.push(id); + + // Add bidirectional connections + for (l, neighbors) in layer_neighbors { + let m = if l == 0 { + self.config.m_max0 + } else { + self.config.m + }; + for neighbor_id in neighbors { + self.connect(neighbor_id, internal_id, l); + // Shrink if over m (m is already m_max0 for layer 0, m for upper layers) + self.shrink_connections(neighbor_id, l, m); + } + } + + // Update entry point if new node is at higher level + if level > current_max_level { + self.entry_point = Some(internal_id); + self.max_level = level; + } + + self.additions_since_rebuild += 1; + Ok(()) + } + + /// Generate random level for new node (exponential distribution). + /// + /// Uses seeded RNG if `config.seed` was set for reproducible builds. 
+ /// + /// **PROOF CORRESPONDENCE**: Lion.Retrieval.HNSW.level_prob_sums_to_one + /// Level probabilities form a valid distribution: sum_{l=0}^{inf} P(level=l) = 1 + /// + /// **PROOF CORRESPONDENCE**: Lion.Retrieval.HNSW.level_survival_decreasing + /// Survival probability decreases exponentially: P(level >= l) = (1/M)^l + pub(super) fn random_level(&mut self) -> usize { + let r: f64 = self.rng.gen::().max(f64::MIN_POSITIVE); + let level = (-r.ln() * self.config.ml).floor() as usize; + level.min(MAX_LEVEL) + } + + /// Add bidirectional connection using internal IDs. + pub(crate) fn connect(&mut self, from: usize, to: usize, layer: usize) { + let node = &mut self.nodes[from]; + while node.neighbors.len() <= layer { + node.neighbors.push(Vec::new()); + } + if !node.neighbors[layer].contains(&to) { + node.neighbors[layer].push(to); + } + } + + /// Shrink connections if over limit. + pub(crate) fn shrink_connections(&mut self, id: usize, layer: usize, m: usize) { + use crate::distance::OrderedF32; + + // Phase 1: Compute new neighbors (read only) + let new_neighbors = { + let node = &self.nodes[id]; + if layer >= node.neighbors.len() || node.neighbors[layer].len() <= m { + return; + } + + let node_vec = &node.vector; + let node_norm = node.norm; + let neighbor_ids = &node.neighbors[layer]; + + let mut scored: Vec<(f32, usize)> = neighbor_ids + .iter() + .map(|&n_id| { + let n = &self.nodes[n_id]; + ( + compute_ordering_distance( + node_vec, + node_norm, + &n.vector, + n.norm, + self.config.metric, + ), + n_id, + ) + }) + .collect(); + + // Sort by distance, then by external ID for deterministic neighbor selection + scored.sort_by(|a, b| match OrderedF32(a.0).cmp(&OrderedF32(b.0)) { + std::cmp::Ordering::Equal => self.external_id(a.1).cmp(&self.external_id(b.1)), + other => other, + }); + scored + .into_iter() + .take(m) + .map(|(_, id)| id) + .collect::>() + }; + + // Phase 2: Mutate + let node = &mut self.nodes[id]; + if layer < node.neighbors.len() { + 
node.neighbors[layer] = new_neighbors; + } + } + + /// Sort a node's neighbor list by distance to the node. + /// + /// Available for batch operations like post-rebuild optimization. + #[allow(dead_code)] + pub(super) fn sort_neighbors(&mut self, id: usize, layer: usize) { + use crate::distance::OrderedF32; + + // Phase 1: Compute sorted order (read only) + let sorted = { + let node = &self.nodes[id]; + if layer >= node.neighbors.len() || node.neighbors[layer].is_empty() { + return; + } + + let node_vec = &node.vector; + let node_norm = node.norm; + + let mut scored: Vec<(f32, usize)> = node.neighbors[layer] + .iter() + .map(|&n_id| { + let dist = { + let n = &self.nodes[n_id]; + compute_ordering_distance( + node_vec, + node_norm, + &n.vector, + n.norm, + self.config.metric, + ) + }; + (dist, n_id) + }) + .collect(); + + scored.sort_by(|a, b| match OrderedF32(a.0).cmp(&OrderedF32(b.0)) { + std::cmp::Ordering::Equal => self.external_id(a.1).cmp(&self.external_id(b.1)), + other => other, + }); + scored.into_iter().map(|(_, id)| id).collect::>() + }; + + // Phase 2: Mutate + let node = &mut self.nodes[id]; + if layer < node.neighbors.len() { + node.neighbors[layer] = sorted; + } + } +} diff --git a/crates/khive-hnsw/src/index/memory.rs b/crates/khive-hnsw/src/index/memory.rs new file mode 100644 index 00000000..3a6ba8f7 --- /dev/null +++ b/crates/khive-hnsw/src/index/memory.rs @@ -0,0 +1,88 @@ +//! Memory budget operations for HNSW index. + +use super::HnswIndex; + +impl HnswIndex { + /// Get the configured memory budget, if any. + pub fn memory_budget(&self) -> Option { + self.config.memory_budget + } + + /// Set or clear the memory budget at runtime. + /// + /// Pass `Some(bytes)` to enforce a limit, or `None` to remove it. + pub fn set_memory_budget(&mut self, budget: Option) { + self.config.memory_budget = budget; + } + + /// Estimate the current memory usage of the index in bytes. 
+ /// + /// Formula: `nodes * (dims * 4 + node_overhead) + neighbor_entries * 8 + /// + per_layer_overhead + vec_overhead + id_mapping_overhead + tombstone_overhead` + /// + /// This is a conservative estimate. Actual usage may differ due to + /// allocator overhead and alignment. + pub fn memory_usage(&self) -> usize { + let num_nodes = self.nodes.len(); + let dims = self.config.dimensions; + + // Per-node: vector (dims * 4 bytes for f32) + fixed fields + // HnswNode has: vector(Vec overhead 24 + data) + neighbors(Vec overhead 24) + // + max_layer(8) + norm(4) + let node_overhead: usize = 24 + 24 + 8 + 4; // 60 bytes fixed per node + let per_node = dims * 4 + node_overhead; + let nodes_total = num_nodes * per_node; + + // Neighbor entries: each Vec per layer, each entry is 8 bytes (usize) + // Plus Vec overhead (24 bytes) per layer per node + let mut neighbor_entries: usize = 0; + let mut layer_vecs: usize = 0; + for node in &self.nodes { + layer_vecs += node.neighbors.len() * 24; // Vec overhead per layer + for layer in &node.neighbors { + neighbor_entries += layer.len(); + } + } + let neighbors_total = neighbor_entries * 8 + layer_vecs; + + // ID mapping overhead: + // HashMap: ~(num_nodes * 40) for bucket/metadata + // Vec: num_nodes * 16 + let mapping_overhead = num_nodes * 40 + num_nodes * 16; + + // Tombstone bitset overhead: 1 byte per node (Vec) + let tombstone_overhead = self.tombstones.len(); + + // Quantized arena overhead: + // - data: num_nodes * dims * 1 byte (i8) + // - meta: num_nodes * 8 bytes (QuantMeta: scale f32 + norm f32) + let quantized_overhead = num_nodes * dims + num_nodes * 8; + + nodes_total + neighbors_total + mapping_overhead + tombstone_overhead + quantized_overhead + } + + /// Estimate the memory cost of inserting a new vector. + /// + /// This is the incremental cost of one new node, including its vector + /// storage, node metadata, and expected neighbor connections. 
+ pub fn estimate_insert_cost(&self) -> usize { + let dims = self.config.dimensions; + + // Vector data + node fixed overhead + let node_overhead: usize = 24 + 24 + 8 + 4; + let per_node = dims * 4 + node_overhead; + + // Expected neighbors: at least 1 layer with m_max0 neighbors, + // plus Vec overhead. Use m_max0 as conservative estimate for layer 0. + // Neighbors are usize (8 bytes each) + let expected_neighbors = self.config.m_max0 * 8 + 24; + + // ID mapping entry overhead (HashMap entry + Vec entry) + let mapping_entry = 40 + 16; + + // Quantized arena: dims * 1 byte (i8) + 8 bytes (QuantMeta) + let quantized_cost = dims + 8; + + per_node + expected_neighbors + mapping_entry + quantized_cost + } +} diff --git a/crates/khive-hnsw/src/index/mod.rs b/crates/khive-hnsw/src/index/mod.rs new file mode 100644 index 00000000..ceb1b6ac --- /dev/null +++ b/crates/khive-hnsw/src/index/mod.rs @@ -0,0 +1,730 @@ +//! HNSW index implementation. +//! +//! The core index structure with insert, delete, search, and rebuild operations. +//! +//! # Internal ID Scheme +//! +//! Internally, nodes are identified by dense `usize` indices into a `Vec`. +//! The public API uses `NodeId` (128-bit) -- conversion happens at the boundary. +//! This gives O(1) array indexing on the search hot path instead of HashMap probing. + +mod build_batch; +mod insert; +mod memory; +mod neighbors; +mod rebuild; +mod search; + +use std::collections::HashMap; +use std::sync::Arc; + +use crate::NodeId; +use rand::rngs::StdRng; +use rand::SeedableRng; + +use super::config::HnswConfig; +use super::node::HnswNode; +use super::stats::TombstoneStats; +use crate::metrics::MetricsSink; + +// --------------------------------------------------------------------------- +// INT8 quantized arena for fast approximate distance computation +// --------------------------------------------------------------------------- + +/// Per-vector quantization metadata (symmetric quantization). 
+/// +/// Stored alongside the flat `Vec` arena. Each vector's quantized data +/// is at `[internal_id * dims .. (internal_id + 1) * dims]` in the arena. +#[derive(Debug, Clone, Copy)] +pub(crate) struct QuantMeta { + /// Scale factor: `float_value = int8_value / scale`. + /// Symmetric quantization maps `[-max_abs, max_abs]` to `[-127, 127]`. + pub scale: f32, + /// Pre-computed L2 norm of the original f32 vector. + pub norm: f32, +} + +/// INT8 quantized vector arena for HNSW search acceleration. +/// +/// Stores quantized vectors in a flat `Vec` arena with the same ordering +/// as the main `nodes` vector. Used for fast approximate distance computation +/// during the candidate filtering phase of search. +/// +/// # Two-Phase Search Strategy +/// +/// 1. **Phase 1 (INT8)**: Compute approximate distance using quantized vectors. +/// This is ~3x faster than f32 distance computation (11ns vs 34ns for 384d). +/// 2. **Phase 2 (f32)**: For candidates that pass the approximate threshold, +/// compute precise f32 distance for final ranking. +/// +/// This skip pattern avoids f32 distance computation for obviously distant +/// neighbors, providing significant speedup at scale (50K+ vectors). +#[derive(Debug, Clone)] +pub(crate) struct QuantizedArena { + /// Flat INT8 vector data. Vector `i` starts at `i * dims`. + pub data: Vec, + /// Per-vector quantization metadata, indexed by internal ID. + pub meta: Vec, + /// Vector dimensionality (cached for bounds checking). + pub dims: usize, +} + +impl QuantizedArena { + /// Create a new empty quantized arena for the given dimensionality. + fn new(dims: usize) -> Self { + Self { + data: Vec::new(), + meta: Vec::new(), + dims, + } + } + + /// Quantize a float vector and append it to the arena. + /// + /// Uses symmetric quantization: `[-max_abs, max_abs]` -> `[-127, 127]`. + /// Returns the index of the newly added vector (should match the internal ID). 
+ fn push(&mut self, vector: &[f32], norm: f32) -> usize { + debug_assert_eq!(vector.len(), self.dims); + + // Single-pass min/max over finite values + let mut max_abs: f32 = 0.0; + for &v in vector { + if v.is_finite() { + let abs = v.abs(); + if abs > max_abs { + max_abs = abs; + } + } + } + + // Symmetric quantization: scale maps max_abs to 127 + let scale = if max_abs > 1e-10 { + 127.0 / max_abs + } else { + 1.0 // Near-zero vector + }; + + // Quantize and append to flat arena + self.data.reserve(self.dims); + for &v in vector { + let q = if v.is_finite() { + (v * scale).round().clamp(-127.0, 127.0) as i8 + } else { + 0i8 + }; + self.data.push(q); + } + + let idx = self.meta.len(); + self.meta.push(QuantMeta { scale, norm }); + idx + } + + /// Update the quantized vector at the given index. + pub(crate) fn update(&mut self, idx: usize, vector: &[f32], norm: f32) { + debug_assert_eq!(vector.len(), self.dims); + debug_assert!(idx < self.meta.len()); + + let mut max_abs: f32 = 0.0; + for &v in vector { + if v.is_finite() { + let abs = v.abs(); + if abs > max_abs { + max_abs = abs; + } + } + } + + let scale = if max_abs > 1e-10 { + 127.0 / max_abs + } else { + 1.0 + }; + + let offset = idx * self.dims; + for (i, &v) in vector.iter().enumerate() { + self.data[offset + i] = if v.is_finite() { + (v * scale).round().clamp(-127.0, 127.0) as i8 + } else { + 0i8 + }; + } + + self.meta[idx] = QuantMeta { scale, norm }; + } + + /// Get the quantized data slice for a given internal ID. + #[inline] + fn get_data(&self, idx: usize) -> &[i8] { + let offset = idx * self.dims; + &self.data[offset..offset + self.dims] + } + + /// Compute approximate INT8 dot product between two quantized vectors, + /// returning the result in the original f32 scale. + /// + /// Uses SIMD-accelerated INT8 dot product from khive-embed. 
+ #[inline] + #[allow(dead_code)] // Available for Dot metric path (future) + pub fn dot_product_approx(&self, a_idx: usize, b_data: &[i8], b_scale: f32) -> f32 { + let a_data = self.get_data(a_idx); + let a_meta = &self.meta[a_idx]; + let denom = a_meta.scale * b_scale; + if denom == 0.0 || !denom.is_finite() { + return 0.0; + } + int8_dot_product_raw(a_data, b_data) / denom + } + + /// Compute approximate INT8 cosine distance between a stored vector and + /// a query's quantized form. + /// + /// Returns distance (1 - cosine_similarity), comparable to the f32 path. + #[inline] + pub fn cosine_distance_approx( + &self, + idx: usize, + query_i8: &[i8], + query_scale: f32, + query_norm: f32, + ) -> f32 { + let meta = &self.meta[idx]; + let denom_scale = meta.scale * query_scale; + if denom_scale == 0.0 || !denom_scale.is_finite() { + return 1.0; + } + let norm_denom = meta.norm * query_norm; + if norm_denom <= 0.0 || !norm_denom.is_finite() { + return 1.0; + } + let dot = int8_dot_product_raw(self.get_data(idx), query_i8) / denom_scale; + 1.0 - (dot / norm_denom) + } + + /// Clear the arena (used by rebuild/clear). + fn clear(&mut self) { + self.data.clear(); + self.meta.clear(); + } +} + +/// Raw INT8 dot product using SIMD from khive-embed. +/// +/// Zero-allocation path: takes raw `&[i8]` slices and returns the integer +/// dot product as f32 (no scale factor division). The caller handles scaling. +/// +/// Uses the same SIMD backend as `dot_product_i8` (NEON/AVX2/AVX-512 VNNI) +/// but without constructing `QuantizedVector` wrappers. +#[inline] +fn int8_dot_product_raw(a: &[i8], b: &[i8]) -> f32 { + lattice_embed::simd::dot_product_i8_raw(a, b) +} + +/// HNSW vector index with tombstone-based lazy deletion. +/// +/// This is an IN-MEMORY index. Persistence via snapshots is handled separately. +/// All output scores are `DeterministicScore` for cross-platform consistency. 
+/// +/// # Internal ID Scheme +/// +/// Nodes are stored in a dense `Vec` indexed by `usize`. The mappings +/// `id_to_internal` and `internal_to_id` convert between external `NodeId` +/// and internal `usize` at the API boundary. All neighbor lists, entry point, +/// and tombstone tracking use internal `usize` IDs for O(1) lookups. +/// +/// # INT8 Quantized Search (opt-in) +/// +/// When `use_quantized` is true, search uses a two-phase strategy: +/// 1. INT8 approximate distance for candidate filtering (~3x faster) +/// 2. f32 precise distance for final scoring (exact results) +/// +/// Enable via `HnswIndex::set_quantized(true)` or `HnswIndex::with_quantized()`. +pub struct HnswIndex { + /// Configuration. + pub(crate) config: HnswConfig, + + /// Dense node storage indexed by internal usize ID. + pub(crate) nodes: Vec, + + /// External NodeId -> internal usize mapping. + /// Only used at API boundary (insert, search result conversion, delete). + pub(crate) id_to_internal: HashMap, + + /// Internal usize -> external NodeId mapping. + /// Indexed by internal ID for O(1) reverse lookup. + pub(crate) internal_to_id: Vec, + + /// Entry point node (highest layer node). Internal usize ID. + pub(crate) entry_point: Option, + + /// Current maximum layer in the graph. + pub(crate) max_level: usize, + + /// Tombstoned (soft-deleted) internal node IDs. + /// Dense bitset indexed by internal ID for O(1) lookup on the search hot path. + pub(crate) tombstones: Vec, + + /// Count of tombstoned nodes (maintained separately to avoid scanning the Vec). + pub(crate) tombstone_count: usize, + + /// Insertions since last rebuild (for tracking recall degradation). + pub(crate) additions_since_rebuild: usize, + + /// Random number generator for level generation. + /// If config.seed is Some, this is a seeded RNG for reproducibility. + pub(crate) rng: StdRng, + + /// Optional metrics sink for observability. 
+ pub(crate) metrics: Option>, + + /// INT8 quantized vector arena for fast approximate distance. + /// Maintained in parallel with `nodes` -- same internal ID ordering. + pub(crate) quantized: QuantizedArena, + + /// Whether to use INT8 quantized distance for candidate filtering. + /// Default: false. Enable for large indexes (50K+ vectors) where + /// distance computation dominates search time. + pub(crate) use_quantized: bool, +} + +impl Clone for HnswIndex { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + nodes: self.nodes.clone(), + id_to_internal: self.id_to_internal.clone(), + internal_to_id: self.internal_to_id.clone(), + entry_point: self.entry_point, + max_level: self.max_level, + tombstones: self.tombstones.clone(), + tombstone_count: self.tombstone_count, + additions_since_rebuild: self.additions_since_rebuild, + rng: self.rng.clone(), + metrics: self.metrics.clone(), + quantized: self.quantized.clone(), + use_quantized: self.use_quantized, + } + } +} + +impl std::fmt::Debug for HnswIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HnswIndex") + .field("config", &self.config) + .field("num_nodes", &self.nodes.len()) + .field("max_level", &self.max_level) + .field("tombstones", &self.tombstone_count) + .field("additions_since_rebuild", &self.additions_since_rebuild) + .field("use_quantized", &self.use_quantized) + .finish() + } +} + +impl HnswIndex { + /// Create a new HNSW index with default configuration and specified dimensions. + pub fn new(dimensions: usize) -> Self { + Self::with_config(HnswConfig::with_dimensions(dimensions)) + } + + /// Create a new HNSW index with custom configuration. 
+ pub fn with_config(config: HnswConfig) -> Self { + config.validate().expect("HNSW configuration must be valid"); + + // Initialize RNG - seeded if config.seed is Some, otherwise from entropy + let rng = match config.seed { + Some(seed) => StdRng::seed_from_u64(seed), + None => StdRng::from_entropy(), + }; + + let dims = config.dimensions; + Self { + config, + nodes: Vec::new(), + id_to_internal: HashMap::new(), + internal_to_id: Vec::new(), + entry_point: None, + max_level: 0, + tombstones: Vec::new(), + tombstone_count: 0, + additions_since_rebuild: 0, + rng, + metrics: None, + quantized: QuantizedArena::new(dims), + use_quantized: false, + } + } + + /// Get the current configuration. + pub fn config(&self) -> &HnswConfig { + &self.config + } + + /// Attach a metrics sink (builder pattern). + /// + /// The sink receives [`MetricEvent`]s from `search`, `insert`, and `rebuild` + /// operations. Pass an `Arc` to share a single sink across + /// multiple indices. + #[must_use] + pub fn with_metrics(mut self, sink: Arc) -> Self { + self.metrics = Some(sink); + self + } + + /// Set or replace the metrics sink at runtime. + /// + /// Pass `Some(sink)` to enable metrics, or `None` to disable. + pub fn set_metrics(&mut self, sink: Option>) { + self.metrics = sink; + } + + /// Enable INT8 quantized distance for search candidate filtering. + /// + /// When enabled, search uses a two-phase strategy: + /// 1. INT8 approximate distance for candidate screening (~3x faster) + /// 2. f32 precise distance for final ranking (exact results) + /// + /// Recommended for indexes with 50K+ vectors where distance computation + /// dominates search time. For smaller indexes, the overhead of maintaining + /// the quantized arena may not be worthwhile. + /// + /// This is a builder-pattern method. For runtime toggling, use `set_quantized`. 
+ #[must_use] + pub fn with_quantized(mut self) -> Self { + self.use_quantized = true; + self + } + + /// Enable or disable INT8 quantized search at runtime. + /// + /// The quantized arena is always maintained (populated on insert), + /// so toggling this flag has no rebuilding cost. + pub fn set_quantized(&mut self, enabled: bool) { + self.use_quantized = enabled; + } + + /// Check if INT8 quantized search is enabled. + pub fn is_quantized(&self) -> bool { + self.use_quantized + } + + /// Get the number of vectors in the index (including tombstones). + pub fn len(&self) -> usize { + self.nodes.len() + } + + /// Get the number of live (non-tombstoned) vectors. + pub fn len_live(&self) -> usize { + self.nodes.len().saturating_sub(self.tombstone_count) + } + + /// Check if the index is empty. + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + /// Get the vector for an embedding ID, if present. + /// + /// Used by the build swap logic to recover concurrent writes: + /// entries inserted into the live HNSW during a background build + /// need their vectors re-inserted into the new index at swap time. + pub fn get_vector(&self, id: &NodeId) -> Option> { + self.id_to_internal + .get(id) + .map(|&iid| self.nodes[iid].vector.clone()) + } + + /// Get tombstone statistics. + pub fn tombstone_stats(&self) -> TombstoneStats { + let total = self.nodes.len(); + let tombstones = self.tombstone_count; + let live = total.saturating_sub(tombstones); + let ratio = if total > 0 { + tombstones as f64 / total as f64 + } else { + 0.0 + }; + + TombstoneStats { + total_nodes: total, + tombstone_count: tombstones, + live_nodes: live, + ratio, + } + } + + /// Check if a rebuild is recommended based on tombstone ratio. + pub fn needs_rebuild(&self) -> bool { + self.tombstone_stats() + .needs_rebuild_at(self.config.rebuild_threshold) + } + + /// Look up the internal ID for an NodeId, if it exists. 
+ #[inline] + #[allow(dead_code)] // TODO: wire into delta-merge path for cross-index lookups + pub(crate) fn internal_id(&self, id: &NodeId) -> Option { + self.id_to_internal.get(id).copied() + } + + /// Check if an internal ID is tombstoned. O(1) array lookup. + #[inline] + pub(crate) fn is_tombstoned(&self, iid: usize) -> bool { + iid < self.tombstones.len() && self.tombstones[iid] + } + + /// Look up the external NodeId for an internal ID. + #[inline] + pub(crate) fn external_id(&self, iid: usize) -> NodeId { + self.internal_to_id[iid] + } + + /// Create a serializable snapshot of the index state. + /// + /// The snapshot captures: + /// - All indexed vector IDs and their raw f32 embeddings + /// - Graph topology (neighbor connections per layer) + /// - Tombstone information + /// - Configuration for compatibility checking + /// + /// # Self-Contained Warm Start + /// + /// The returned snapshot includes the full vector data in the `vectors` + /// field, making it self-contained for warm-start restores. Use + /// [`restore_from_snapshot_embedded`] to restore directly from the + /// snapshot without supplying a separate vector map. + /// + /// Size estimate: `dimensions × 4 bytes × node_count`. + /// For 384-dim embeddings with 10 K nodes: ~15 MB. 
+ pub fn snapshot(&self) -> super::checkpoint::HnswSnapshot { + use super::checkpoint::{HnswCheckpointConfig, HnswSnapshot}; + + let mut indexed_ids: Vec<_> = self.internal_to_id.clone(); + indexed_ids.sort_by(|a, b| a.as_bytes().cmp(b.as_bytes())); + + let mut tombstoned_ids: Vec<_> = self + .tombstones + .iter() + .enumerate() + .filter(|(_, &is_tomb)| is_tomb) + .map(|(iid, _)| self.external_id(iid)) + .collect(); + tombstoned_ids.sort_by(|a, b| a.as_bytes().cmp(b.as_bytes())); + + // Build layer representation -- convert internal IDs to NodeId + let mut layers = Vec::new(); + for level in 0..=self.max_level { + let mut layer_nodes = Vec::new(); + for (iid, node) in self.nodes.iter().enumerate() { + if level < node.neighbors.len() { + let neighbors: Vec = node.neighbors[level] + .iter() + .map(|&nid| self.external_id(nid)) + .collect(); + layer_nodes.push((self.external_id(iid), neighbors)); + } + } + // Sort for deterministic ordering + layer_nodes.sort_by(|(a, _), (b, _)| a.as_bytes().cmp(b.as_bytes())); + layers.push(layer_nodes); + } + + // Convert entry_point from internal to external + let entry_point_ext = self.entry_point.map(|iid| self.external_id(iid)); + + // Embed full f32 vector data for self-contained warm-start snapshots. + // Sorted by NodeId bytes to match indexed_ids ordering. + let mut vectors: Vec<(NodeId, Vec)> = self + .internal_to_id + .iter() + .zip(self.nodes.iter()) + .map(|(&id, node)| (id, node.vector.clone())) + .collect(); + vectors.sort_by(|(a, _), (b, _)| a.as_bytes().cmp(b.as_bytes())); + + HnswSnapshot { + vector_count: 0, // Legacy field, not used + total_nodes: self.nodes.len(), + live_nodes: self.len_live(), + tombstone_count: self.tombstone_count, + max_layer: self.max_level, + entry_point: entry_point_ext, + config: HnswCheckpointConfig::from_hnsw_config(&self.config), + indexed_ids, + tombstoned_ids, + layers, + vectors, + } + } + + /// Restore index topology from a snapshot using embedded vector data. 
+ /// + /// Convenience wrapper for snapshots produced by [`snapshot`] (which embed + /// the full f32 vectors). No external vector map is required. + /// + /// # Errors + /// + /// Returns an error if: + /// - The snapshot contains no embedded vectors (`vectors` field is empty) + /// - Snapshot config is incompatible with current config + /// - Snapshot verification fails + pub fn restore_from_snapshot_embedded( + &mut self, + snapshot: &super::checkpoint::HnswSnapshot, + ) -> Result<(), crate::error::RetrievalError> { + use crate::error::RetrievalError; + + if snapshot.vectors.is_empty() && !snapshot.indexed_ids.is_empty() { + return Err(RetrievalError::hnsw( + "Snapshot contains no embedded vectors; use restore_from_snapshot with an external vector map", + )); + } + + let vectors: std::collections::HashMap> = + snapshot.vectors.iter().cloned().collect(); + self.restore_from_snapshot(snapshot, &vectors) + } + + /// Restore index topology from a snapshot. + /// + /// This rebuilds the neighbor connections from the snapshot. The caller + /// must supply vector data via the `vectors` map. If the snapshot was + /// produced by [`snapshot`] it already embeds vector data in + /// `snapshot.vectors`; you can use [`restore_from_snapshot_embedded`] + /// instead in that case. + /// + /// When both the snapshot's embedded `vectors` field and the caller-supplied + /// `vectors` map contain an entry for the same `NodeId`, the caller-supplied + /// entry takes precedence (useful for applying incremental updates on top of + /// a base snapshot). + /// + /// # Arguments + /// + /// * `snapshot` - The snapshot to restore from + /// * `vectors` - Map of ID -> vector data for all indexed vectors. + /// May be empty if `snapshot.vectors` is non-empty (self-contained snapshot). 
+ /// + /// # Errors + /// + /// Returns an error if: + /// - Snapshot config is incompatible with current config + /// - Snapshot verification fails + /// - Referenced vectors are missing from both the snapshot and the external map + pub fn restore_from_snapshot( + &mut self, + snapshot: &super::checkpoint::HnswSnapshot, + vectors: &std::collections::HashMap>, + ) -> Result<(), crate::error::RetrievalError> { + use super::checkpoint::HnswCheckpointConfig; + use crate::error::RetrievalError; + + // Verify snapshot integrity + snapshot + .verify() + .map_err(|e| RetrievalError::hnsw(format!("Invalid snapshot: {e}")))?; + + // Check config compatibility + let current_config = HnswCheckpointConfig::from_hnsw_config(&self.config); + if !snapshot.is_compatible(¤t_config) { + return Err(RetrievalError::hnsw(format!( + "Snapshot config incompatible: expected {:?}, got {:?}", + current_config, snapshot.config + ))); + } + + // Build a merged vector lookup: snapshot-embedded vectors are the base, + // caller-supplied entries take precedence (to allow incremental updates). + // + // Priority: caller-supplied > snapshot-embedded. + // Both sources are collected into a single owned map to avoid lifetime + // complications with mixed borrows. 
+ let mut merged_vectors: HashMap> = snapshot + .vectors + .iter() + .map(|(id, v)| (*id, v.clone())) + .collect(); + // Caller-supplied entries override embedded ones + for (id, v) in vectors { + merged_vectors.insert(*id, v.clone()); + } + + // Clear current state + self.nodes.clear(); + self.id_to_internal.clear(); + self.internal_to_id.clear(); + self.tombstones.clear(); + self.tombstone_count = 0; + self.quantized.clear(); + + // First pass: assign internal IDs and build mapping + // We need a consistent ordering, so use indexed_ids order + let mut ext_to_internal: HashMap = HashMap::new(); + for (idx, id) in snapshot.indexed_ids.iter().enumerate() { + ext_to_internal.insert(*id, idx); + } + + // Build nodes with vectors + for id in &snapshot.indexed_ids { + let vector = merged_vectors + .get(id) + .ok_or_else(|| RetrievalError::hnsw(format!("Missing vector for ID {id:?}")))? + .clone(); + + let level = self.calculate_level_for_restore(&snapshot.layers, id); + let node = super::node::HnswNode::new(vector, level); + let iid = self.nodes.len(); + self.quantized.push(&node.vector, node.norm); + self.nodes.push(node); + self.id_to_internal.insert(*id, iid); + self.internal_to_id.push(*id); + } + + // Set max_level and entry_point + self.max_level = snapshot.max_layer; + self.entry_point = snapshot + .entry_point + .and_then(|eid| ext_to_internal.get(&eid).copied()); + + // Restore neighbor connections from layers -- convert NodeId to internal usize + for (level, layer) in snapshot.layers.iter().enumerate() { + for (node_id, neighbors) in layer { + if let Some(&iid) = ext_to_internal.get(node_id) { + if level < self.nodes[iid].neighbors.len() { + self.nodes[iid].neighbors[level] = neighbors + .iter() + .filter_map(|nid| ext_to_internal.get(nid).copied()) + .collect(); + } + } + } + } + + // Restore tombstones -- convert NodeId to internal usize + // Resize tombstones bitset to match node count + self.tombstones.resize(self.nodes.len(), false); + for id in 
&snapshot.tombstoned_ids { + if let Some(&iid) = ext_to_internal.get(id) { + if !self.tombstones[iid] { + self.tombstones[iid] = true; + self.tombstone_count += 1; + } + } + } + + Ok(()) + } + + /// Calculate the level for a node during restore based on snapshot layers. + fn calculate_level_for_restore( + &self, + layers: &[Vec<(NodeId, Vec)>], + id: &NodeId, + ) -> usize { + // Find the highest layer where this node appears + let mut level = 0; + for (l, layer) in layers.iter().enumerate() { + if layer.iter().any(|(node_id, _)| node_id == id) { + level = l; + } + } + level + } +} diff --git a/crates/khive-hnsw/src/index/neighbors.rs b/crates/khive-hnsw/src/index/neighbors.rs new file mode 100644 index 00000000..50480e6e --- /dev/null +++ b/crates/khive-hnsw/src/index/neighbors.rs @@ -0,0 +1,72 @@ +//! Neighbor selection for HNSW index. + +use super::HnswIndex; +use crate::distance::{compute_ordering_distance, OrderedF32}; + +impl HnswIndex { + /// Select neighbors using diversified heuristic (Algorithm 4 from HNSW paper). + /// + /// Takes candidates as (distance, internal_id) pairs. + /// Returns internal IDs of selected neighbors. + /// + /// Candidates are sorted by distance once upfront, then iterated in order. + /// This avoids the O(N^2) cost of repeated min-scan + Vec::remove on the + /// unsorted candidate list. + pub(crate) fn select_neighbors(&self, candidates: &[(f32, usize)], m: usize) -> Vec { + if candidates.len() <= m { + return candidates.iter().map(|(_, id)| *id).collect(); + } + + // Sort candidates by distance (ascending), tie-break by external ID for determinism. + // This replaces the per-iteration O(N) min_by scan with a single O(N log N) sort. 
+ let mut sorted: Vec<(f32, usize)> = candidates.to_vec(); + sorted.sort_by(|a, b| match OrderedF32(a.0).cmp(&OrderedF32(b.0)) { + std::cmp::Ordering::Equal => self.external_id(a.1).cmp(&self.external_id(b.1)), + other => other, + }); + + let mut selected: Vec<(f32, usize)> = Vec::with_capacity(m); + + // Iterate through sorted candidates in distance order, applying diversity check. + for &(dist_to_query, candidate_id) in &sorted { + if selected.len() >= m { + break; + } + + let candidate_node = &self.nodes[candidate_id]; + let candidate_vec = &candidate_node.vector; + let candidate_norm = candidate_node.norm; + + // Check diversity: candidate is closer to query than to any selected neighbor + let is_diverse = selected.iter().all(|(_, sel_id)| { + let sel_node = &self.nodes[*sel_id]; + let dist_to_selected = compute_ordering_distance( + candidate_vec, + candidate_norm, + &sel_node.vector, + sel_node.norm, + self.config.metric, + ); + dist_to_query <= dist_to_selected + }); + + if is_diverse || selected.is_empty() { + selected.push((dist_to_query, candidate_id)); + } + } + + // Fill with closest remaining if the diversity heuristic was too aggressive + if selected.len() < m { + for &(dist, id) in &sorted { + if selected.len() >= m { + break; + } + if !selected.iter().any(|(_, sid)| *sid == id) { + selected.push((dist, id)); + } + } + } + + selected.into_iter().map(|(_, id)| id).collect() + } +} diff --git a/crates/khive-hnsw/src/index/rebuild.rs b/crates/khive-hnsw/src/index/rebuild.rs new file mode 100644 index 00000000..1cba8196 --- /dev/null +++ b/crates/khive-hnsw/src/index/rebuild.rs @@ -0,0 +1,288 @@ +//! Rebuild, delete, and clear operations for HNSW index. + +use crate::NodeId; + +use super::HnswIndex; +use crate::metrics::{self, MetricEvent, MetricValue}; +use crate::stats::RebuildStats; + +impl HnswIndex { + /// Mark a vector for deletion (lazy tombstone). + /// + /// The vector is not physically removed until `rebuild()` is called. 
+ /// Returns true if the vector existed and was marked. + /// + /// If the deleted node is the current entry point, a replacement is + /// found immediately from the node's neighbors (O(M), not O(N)). + pub fn delete(&mut self, id: NodeId) -> bool { + match self.id_to_internal.get(&id) { + Some(&iid) => { + // Grow tombstone bitset if needed + if iid >= self.tombstones.len() { + self.tombstones.resize(iid + 1, false); + } + let was_new = !self.tombstones[iid]; + if was_new { + self.tombstones[iid] = true; + self.tombstone_count += 1; + self.repair_entry_point_after_delete(iid); + } + was_new + } + None => false, + } + } + + /// If `tombstoned_id` is the current entry point, find a non-tombstoned + /// replacement from its neighbors. This is O(M) in the typical case + /// (M = neighbor count per layer) rather than O(N) scanning all nodes. + /// + /// Falls back to an O(N) scan only if ALL neighbors of the entry point + /// across ALL layers are also tombstoned -- an extremely unlikely scenario + /// that would require deleting an entire neighborhood simultaneously. + fn repair_entry_point_after_delete(&mut self, tombstoned_id: usize) { + let current_ep = match self.entry_point { + Some(ep) if ep == tombstoned_id => ep, + _ => return, // Not the entry point, nothing to do + }; + + // Search the tombstoned node's neighbors across all layers (highest first) + // for a live replacement. Prefer higher-layer neighbors since they provide + // better graph coverage for search entry. + let node = &self.nodes[current_ep]; + for layer in (0..node.neighbors.len()).rev() { + for &neighbor_id in &node.neighbors[layer] { + if !self.is_tombstoned(neighbor_id) { + self.entry_point = Some(neighbor_id); + return; + } + } + } + + // Extremely rare fallback: all neighbors are tombstoned too. + // Scan for ANY live node. This is O(N) but should essentially never happen + // in practice -- it requires tombstoning an entire neighborhood. 
+ for iid in 0..self.nodes.len() { + if !self.is_tombstoned(iid) { + self.entry_point = Some(iid); + return; + } + } + + // All nodes are tombstoned + self.entry_point = None; + } + + /// Rebuild the index by removing tombstoned nodes. + /// + /// This physically removes tombstoned nodes and cleans up neighbor references. + /// Call this when `needs_rebuild()` returns true. + /// + /// # RETRIEVAL-11: Entry Point Behavior After Rebuild + /// + /// When rebuild removes the current entry point (because it was tombstoned), + /// a new entry point is selected automatically. The selection algorithm: + /// + /// 1. **Filter**: Consider only non-tombstoned nodes + /// 2. **Select**: Choose the node with the highest `max_layer` value + /// 3. **Tie-break**: If multiple nodes have the same max_layer, selection + /// is deterministic but implementation-defined + /// + /// ## Why Highest Layer? + /// + /// HNSW search starts at the entry point and descends through layers. A node + /// at a higher layer provides better "coverage" of the graph, allowing search + /// to quickly narrow down to the relevant region before descending. + /// + /// ## Edge Cases + /// + /// | Scenario | Behavior | + /// |----------|----------| + /// | All nodes tombstoned | Entry point becomes `None`, searches return empty | + /// | Entry point not tombstoned | Entry point unchanged | + /// | Single node remains | That node becomes entry point | + /// + /// ## Return Value + /// + /// The `entry_point_updated` field in `RebuildStats` indicates whether the + /// entry point was changed during rebuild. + /// Emits `hnsw.rebuild.duration_ms`, `hnsw.rebuild.count`, + /// `hnsw.rebuild.nodes_removed`, and `hnsw.index.size` metrics when a + /// sink is attached. 
+ pub fn rebuild(&mut self) -> RebuildStats { + let start = std::time::Instant::now(); + + let stats = self.rebuild_inner(); + + // Emit metrics + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_REBUILD_DURATION_MS, + value: MetricValue::Histogram(elapsed), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_REBUILD_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_REBUILD_NODES_REMOVED, + value: MetricValue::Gauge(stats.nodes_removed as f64), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_INDEX_SIZE, + value: MetricValue::Gauge(self.len_live() as f64), + labels: vec![], + }, + ); + + stats + } + + /// Inner rebuild logic (uninstrumented). + /// + /// Rebuilds the dense storage by compacting: removes tombstoned nodes and + /// re-assigns internal IDs so the Vec remains dense. 
+ fn rebuild_inner(&mut self) -> RebuildStats { + let nodes_before = self.nodes.len(); + let nodes_removed = self.tombstone_count; + + // Track if entry point needs update + let entry_point_was_tombstone = self + .entry_point + .map(|ep| self.is_tombstoned(ep)) + .unwrap_or(false); + + // Build old_to_new mapping: compact non-tombstoned nodes + let mut old_to_new: Vec> = vec![None; self.nodes.len()]; + let mut new_nodes: Vec = + Vec::with_capacity(self.nodes.len() - nodes_removed); + let mut new_internal_to_id: Vec = + Vec::with_capacity(self.nodes.len() - nodes_removed); + let mut new_id_to_internal = + std::collections::HashMap::with_capacity(self.nodes.len() - nodes_removed); + + let mut new_idx = 0usize; + for (old_idx, mapping) in old_to_new.iter_mut().enumerate() { + if self.is_tombstoned(old_idx) { + // Remove from external mapping + let ext_id = self.internal_to_id[old_idx]; + self.id_to_internal.remove(&ext_id); + continue; + } + *mapping = Some(new_idx); + let ext_id = self.internal_to_id[old_idx]; + new_id_to_internal.insert(ext_id, new_idx); + new_internal_to_id.push(ext_id); + new_idx += 1; + } + + // Clone nodes and remap neighbor IDs + let mut edges_cleaned = 0usize; + for old_idx in 0..self.nodes.len() { + if self.is_tombstoned(old_idx) { + continue; + } + let mut node = self.nodes[old_idx].clone(); + for neighbors in &mut node.neighbors { + let before = neighbors.len(); + // Remap internal IDs and remove references to tombstoned nodes + *neighbors = neighbors + .iter() + .filter_map(|&old_nid| old_to_new[old_nid]) + .collect(); + edges_cleaned += before - neighbors.len(); + } + new_nodes.push(node); + } + + // Update entry point + let entry_point_updated = if entry_point_was_tombstone || self.entry_point.is_none() { + // Find new entry point among surviving nodes + let new_ep = new_nodes + .iter() + .enumerate() + .max_by_key(|(_, n)| n.max_layer) + .map(|(idx, _)| idx); + self.entry_point = new_ep; + entry_point_was_tombstone + } else { + // Remap 
existing entry point + self.entry_point = self.entry_point.and_then(|old_ep| old_to_new[old_ep]); + false + }; + + // Update max_level + self.max_level = new_nodes.iter().map(|n| n.max_layer).max().unwrap_or(0); + + // Swap in compacted state + self.nodes = new_nodes; + self.id_to_internal = new_id_to_internal; + self.internal_to_id = new_internal_to_id; + self.tombstones.clear(); + self.tombstone_count = 0; + self.additions_since_rebuild = 0; + + // Rebuild quantized arena from compacted nodes + self.quantized.clear(); + for node in &self.nodes { + self.quantized.push(&node.vector, node.norm); + } + + RebuildStats { + nodes_before, + nodes_removed, + nodes_after: self.nodes.len(), + edges_cleaned, + entry_point_updated, + } + } + + /// Clear all data from the index. + pub fn clear(&mut self) { + self.nodes.clear(); + self.id_to_internal.clear(); + self.internal_to_id.clear(); + self.tombstones.clear(); + self.tombstone_count = 0; + self.quantized.clear(); + self.entry_point = None; + self.max_level = 0; + self.additions_since_rebuild = 0; + } + + /// Update entry point to the node with the highest max_layer. + #[allow(dead_code)] + pub(super) fn update_entry_point(&mut self) { + let new_entry = self + .nodes + .iter() + .enumerate() + .filter(|(idx, _)| !self.is_tombstoned(*idx)) + .max_by_key(|(_, n)| n.max_layer) + .map(|(idx, _)| idx); + + self.entry_point = new_entry; + self.max_level = self + .nodes + .iter() + .enumerate() + .filter(|(idx, _)| !self.is_tombstoned(*idx)) + .map(|(_, n)| n.max_layer) + .max() + .unwrap_or(0); + } +} diff --git a/crates/khive-hnsw/src/index/search.rs b/crates/khive-hnsw/src/index/search.rs new file mode 100644 index 00000000..0176435f --- /dev/null +++ b/crates/khive-hnsw/src/index/search.rs @@ -0,0 +1,837 @@ +//! Search operations for HNSW index. 
+ +use crate::NodeId; +use khive_score::DeterministicScore; + +use super::HnswIndex; +use crate::config::DistanceMetric; +use crate::distance::{cosine_distance_from_parts, distance_to_similarity, OrderedF32}; +use crate::error::{Result, RetrievalError}; +use crate::metrics::{self, MetricEvent, MetricValue}; +use crate::search_context::HnswSearchContext; + +/// Index size below which exact linear scan beats graph traversal for Cosine/Dot. +/// Empirically measured crossover for dim=384, ef_search=80: HNSW graph ≈ exact scan at ~4K nodes. +/// Using 3K keeps us firmly in the "exact wins" zone while leaving 5K+ to the graph. +const EXACT_SCAN_THRESHOLD: usize = 3_000; + +// --------------------------------------------------------------------------- +// Inlined distance function type +// --------------------------------------------------------------------------- + +/// Distance function signature: (query, query_norm, vector, vector_norm) -> distance. +/// +/// Resolved once per search from `DistanceMetric` so the inner loop avoids a +/// `match` dispatch per neighbor. The compiler can inline the concrete SIMD +/// kernel through the function pointer on most targets. +type DistanceFn = fn(&[f32], f32, &[f32], f32) -> f32; + +/// Resolve the metric enum to a concrete distance function pointer. +/// +/// This is called once at the top of `search_layer_inner_ctx` so the hot loop +/// uses a direct call instead of branching on `DistanceMetric` per neighbor. +#[inline] +fn resolve_distance_fn(metric: DistanceMetric) -> DistanceFn { + match metric { + DistanceMetric::Cosine => |a, a_norm, b, b_norm| { + let dot = lattice_embed::simd::dot_product(a, b); + cosine_distance_from_parts(dot, a_norm, b_norm) + }, + DistanceMetric::Dot => |a, _a_norm, b, _b_norm| -lattice_embed::simd::dot_product(a, b), + DistanceMetric::L2 => { + |a, _a_norm, b, _b_norm| lattice_embed::simd::squared_euclidean_distance(a, b) + } + // Fall back to cosine for future variants. 
/// Decide whether a cached (square-root) vector norm is approximately 1.0.
///
/// The test is done on the squared norm: `|norm² - 1| < 1e-4`. Non-finite
/// norms (NaN, ±infinity) are never treated as unit length.
#[inline]
fn cached_norm_is_unit(norm: f32) -> bool {
    if !norm.is_finite() {
        return false;
    }
    let squared = norm * norm;
    (squared - 1.0).abs() < 1e-4
}
/// Issue a software prefetch hint for the cache line containing `ptr`.
///
/// On aarch64 this emits `PRFM PLDL1KEEP`; on x86_64 it emits `PREFETCHT0`
/// (`_MM_HINT_T0`). On every other architecture the call compiles to a
/// no-op that merely consumes `ptr`.
///
/// The function has no return value and no effect visible to safe code; it
/// exists purely as a performance hint.
///
/// # Soundness
///
/// This is a safe function even though it accepts a raw pointer: the pointer
/// is never dereferenced. It does not need to be valid or aligned --
/// prefetch instructions are advisory and silently ignore bad addresses
/// (including null). The `unsafe` blocks are required only because inline
/// asm and the x86 intrinsic are `unsafe` to call.
#[inline(always)]
fn prefetch_read_data(ptr: *const f32) {
    #[cfg(target_arch = "aarch64")]
    {
        // PRFM PLDL1KEEP: prefetch for load, L1 data cache, temporal
        // SAFETY: `PRFM` is an advisory hint instruction — the hardware silently
        // ignores invalid or unmapped addresses. No validity guarantee on `ptr`
        // is required; the instruction cannot fault. The `nostack` and
        // `preserves_flags` options ensure the asm does not clobber the stack
        // pointer or NZCV flags.
        unsafe {
            core::arch::asm!(
                "prfm pldl1keep, [{x}]",
                x = in(reg) ptr,
                options(nostack, preserves_flags)
            );
        }
    }
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: `_mm_prefetch` is a prefetch hint — the CPU silently ignores
        // invalid or unmapped addresses and the instruction cannot fault.
        // No alignment or validity requirement on `ptr` is imposed by the x86
        // ISA for prefetch instructions.
        unsafe {
            core::arch::x86_64::_mm_prefetch(ptr as *const i8, core::arch::x86_64::_MM_HINT_T0);
        }
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        // No prefetch instruction is emitted here; just mark `ptr` as used.
        let _ = ptr;
    }
}
+ /// + /// **PROOF CORRESPONDENCE**: Lion.Retrieval.HNSW.search_complexity_log + /// Search complexity is O(ef * log_M(N)) where: + /// - ef is the search expansion factor + /// - M is the number of neighbors per node + /// - N is the total number of nodes + pub fn search(&self, query: &[f32], k: usize) -> Result> { + let start = std::time::Instant::now(); + + let result = self.search_inner(query, k); + + // Emit metrics + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_SEARCH_DURATION_MS, + value: MetricValue::Histogram(elapsed), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_SEARCH_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + if let Ok(ref results) = result { + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_SEARCH_RESULTS, + value: MetricValue::Gauge(results.len() as f64), + labels: vec![], + }, + ); + } + + result + } + + /// Search for k nearest neighbors using a pre-allocated search context. + /// + /// This avoids per-query heap allocation by reusing buffers across searches. + /// For maximum throughput in batch/streaming scenarios, create one + /// `HnswSearchContext` and pass it to every search call. + /// + /// Returns results sorted by descending score (most similar first). + /// Tombstoned nodes are automatically filtered. + /// + /// Emits the same metrics as [`search`](Self::search). 
+ pub fn search_with_context( + &self, + query: &[f32], + k: usize, + ctx: &mut HnswSearchContext, + ) -> Result> { + let start = std::time::Instant::now(); + + let result = self.search_inner_with_ctx(query, k, ctx); + + // Emit metrics + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_SEARCH_DURATION_MS, + value: MetricValue::Histogram(elapsed), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_SEARCH_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + if let Ok(ref results) = result { + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::HNSW_SEARCH_RESULTS, + value: MetricValue::Gauge(results.len() as f64), + labels: vec![], + }, + ); + } + + result + } + + /// Inner search logic (uninstrumented), allocating fresh buffers. + fn search_inner(&self, query: &[f32], k: usize) -> Result> { + let ef = self.config.ef_search.max(k); + let mut ctx = HnswSearchContext::new(ef); + self.search_inner_with_ctx(query, k, &mut ctx) + } + + /// Inner search logic using a caller-provided search context. + fn search_inner_with_ctx( + &self, + query: &[f32], + k: usize, + ctx: &mut HnswSearchContext, + ) -> Result> { + if query.len() != self.config.dimensions { + return Err(RetrievalError::DimensionMismatch { + expected: self.config.dimensions, + actual: query.len(), + }); + } + + if self.nodes.is_empty() { + return Ok(Vec::new()); + } + + // H3: exact scan beats graph traversal for small Cosine/Dot indexes. 
+ if self.nodes.len() <= EXACT_SCAN_THRESHOLD + && matches!( + self.config.metric, + DistanceMetric::Cosine | DistanceMetric::Dot + ) + { + return self.exact_scan_top_k(query, k); + } + + let entry_point = match self.entry_point { + Some(ep) => ep, + None => return Ok(Vec::new()), + }; + + // The entry point should always be live because `delete()` calls + // `repair_entry_point_after_delete` to maintain this invariant. + // The check below is a defensive fallback for indexes that were + // constructed before this fix or restored from old snapshots. + let effective_entry = if self.is_tombstoned(entry_point) { + match self.find_live_neighbor(entry_point) { + Some(alt) => alt, + None => return Ok(Vec::new()), // All nodes are tombstoned + } + } else { + entry_point + }; + + self.search_from_entry_with_ctx(query, k, effective_entry, ctx) + } + + /// Search from a specific entry point with tombstone filtering, using pre-allocated context. + fn search_from_entry_with_ctx( + &self, + query: &[f32], + k: usize, + entry_point: usize, + ctx: &mut HnswSearchContext, + ) -> Result> { + let query_norm = query.iter().map(|x| x * x).sum::().sqrt(); + let (effective_k, effective_ef) = self.compute_overscan(k); + + // Search from top layer + let mut current_nearest = vec![entry_point]; + + // Traverse upper layers (greedy search with ef=1) + for l in (1..=self.max_level).rev() { + self.search_layer_inner_ctx(query, query_norm, ¤t_nearest, 1, l, false, ctx); + if !ctx.result_buf.is_empty() { + current_nearest = vec![ctx.result_buf[0].1]; + } + } + + // If final entry point is tombstoned after upper-layer traversal, + // find a live neighbor. This is rare since the entry point invariant + // ensures we start from a live node, but upper-layer greedy search + // can land on a tombstoned node if it was tombstoned after insertion. 
+ let final_entry = current_nearest[0]; + if self.is_tombstoned(final_entry) { + match self.find_live_neighbor(final_entry) { + Some(alt) => current_nearest = vec![alt], + None => return Ok(Vec::new()), // All nodes tombstoned + } + } + + // Search layer 0 with full ef, filtering tombstones. + // effective_ef is already scaled up by the tombstone ratio. + let ef = effective_ef.max(effective_k); + self.search_layer_inner_ctx(query, query_norm, ¤t_nearest, ef, 0, true, ctx); + + // Convert internal IDs to external NodeId with DeterministicScore. + // For L2, the internal search used squared_euclidean_distance for ordering; + // recover the true L2 distance before converting to similarity. + let is_l2 = self.config.metric == DistanceMetric::L2; + let search_results: Vec<(NodeId, DeterministicScore)> = ctx + .result_buf + .iter() + .filter(|(_, iid)| !self.is_tombstoned(*iid)) + .take(k) + .map(|(dist, iid)| { + let true_dist = if is_l2 { dist.max(0.0).sqrt() } else { *dist }; + let similarity = distance_to_similarity(true_dist, self.config.metric); + ( + self.external_id(*iid), + DeterministicScore::from_f32(similarity), + ) + }) + .collect(); + + Ok(search_results) + } + + /// Exact linear scan for small indexes (n <= EXACT_SCAN_THRESHOLD). + /// + /// Uses the batch-4 SIMD dot kernel for throughput. For Cosine and Dot metrics + /// only — the early-exit condition in search_inner_with_ctx enforces this. 
+ fn exact_scan_top_k( + &self, + query: &[f32], + k: usize, + ) -> Result> { + if k == 0 { + return Ok(Vec::new()); + } + + let dot4 = lattice_embed::simd::resolved_dot_product_batch4_kernel(); + let dot1 = lattice_embed::simd::resolved_dot_product_kernel(); + + let query_norm = query.iter().map(|x| x * x).sum::().sqrt(); + let query_is_unit = cached_norm_is_unit(query_norm); + let metric = self.config.metric; + let n = self.nodes.len(); + + let mut scored: Vec<(usize, f32)> = Vec::with_capacity(n); + let mut i = 0usize; + + while i + 4 <= n { + let dots = dot4( + query, + &self.nodes[i].vector, + &self.nodes[i + 1].vector, + &self.nodes[i + 2].vector, + &self.nodes[i + 3].vector, + ); + let norms = [ + self.nodes[i].norm, + self.nodes[i + 1].norm, + self.nodes[i + 2].norm, + self.nodes[i + 3].norm, + ]; + let dists = + hnsw_distance_batch4_from_dots(metric, dots, query_norm, query_is_unit, norms); + for (j, &dist) in dists.iter().enumerate() { + if !self.is_tombstoned(i + j) { + scored.push((i + j, distance_to_similarity(dist, metric))); + } + } + i += 4; + } + while i < n { + if !self.is_tombstoned(i) { + let dot = dot1(query, &self.nodes[i].vector); + let dist = if query_is_unit && cached_norm_is_unit(self.nodes[i].norm) { + 1.0 - dot.clamp(-1.0, 1.0) + } else { + match metric { + DistanceMetric::Cosine => { + cosine_distance_from_parts(dot, query_norm, self.nodes[i].norm) + } + DistanceMetric::Dot => -dot, + _ => unreachable!(), + } + }; + scored.push((i, distance_to_similarity(dist, metric))); + } + i += 1; + } + + if scored.is_empty() { + return Ok(Vec::new()); + } + + let effective_k = k.min(scored.len()); + if scored.len() > effective_k { + scored.select_nth_unstable_by(effective_k - 1, |(_, a), (_, b)| { + b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal) + }); + scored.truncate(effective_k); + } + scored.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); + + Ok(scored + .into_iter() + .map(|(iid, sim)| 
(self.external_id(iid), DeterministicScore::from_f32(sim))) + .collect()) + } + + /// Find a live (non-tombstoned) node by searching neighbors of the given + /// node across all layers. Returns `None` only if every reachable neighbor + /// and every node in the index is tombstoned. + /// + /// Complexity: O(M * L) in the typical case (M = neighbors per layer, + /// L = layers). Falls back to O(N) scan only if all neighbors are dead. + fn find_live_neighbor(&self, node_id: usize) -> Option { + let node = &self.nodes[node_id]; + // Check neighbors from highest layer down (higher layers have better + // graph coverage, so finding a live neighbor there is preferable). + for layer in (0..node.neighbors.len()).rev() { + for &neighbor_id in &node.neighbors[layer] { + if !self.is_tombstoned(neighbor_id) { + return Some(neighbor_id); + } + } + } + + // Extremely rare fallback: all neighbors tombstoned. O(N) scan. + (0..self.nodes.len()).find(|&iid| !self.is_tombstoned(iid)) + } + + /// Compute overscan factors based on tombstone ratio. + /// + /// Returns `(effective_k, effective_ef)` where both are scaled up + /// proportionally to compensate for tombstoned nodes that will be + /// filtered from results. + /// + /// Previously only `k` was scaled, which had no effect when + /// `ef_search >> k` (the common case). Now `ef_search` is also scaled + /// so the beam width expands to compensate for dead nodes encountered + /// during graph traversal. 
+ pub(super) fn compute_overscan(&self, k: usize) -> (usize, usize) { + let stats = self.tombstone_stats(); + if stats.tombstone_count == 0 { + return (k, self.config.ef_search); + } + + let live_ratio = stats.live_nodes as f64 / stats.total_nodes.max(1) as f64; + if live_ratio <= 0.0 { + return (k, self.config.ef_search); + } + + let inv_live = 1.0 / live_ratio; + + let overscan_k = (k as f64 * inv_live).ceil() as usize; + let effective_k = overscan_k.min(k * 4).max(k); + + let overscan_ef = (self.config.ef_search as f64 * inv_live).ceil() as usize; + let effective_ef = overscan_ef + .min(self.config.ef_search * 4) + .max(self.config.ef_search); + + (effective_k, effective_ef) + } + + /// Search a single layer for nearest neighbors (allocates fresh buffers). + /// + /// Used by insert path where a search context is not available. + /// Returns (distance, internal_id) pairs. + pub(crate) fn search_layer( + &self, + query: &[f32], + query_norm: f32, + entry_points: &[usize], + ef: usize, + layer: usize, + ) -> Vec<(f32, usize)> { + self.search_layer_inner(query, query_norm, entry_points, ef, layer, false) + } + + /// Internal search implementation with tombstone filtering option. + /// Allocates fresh buffers -- used by the insert path. + pub(super) fn search_layer_inner( + &self, + query: &[f32], + query_norm: f32, + entry_points: &[usize], + ef: usize, + layer: usize, + filter_tombstones: bool, + ) -> Vec<(f32, usize)> { + let mut ctx = HnswSearchContext::new(ef); + self.search_layer_inner_ctx( + query, + query_norm, + entry_points, + ef, + layer, + filter_tombstones, + &mut ctx, + ); + // Move out of ctx to avoid clone + std::mem::take(&mut ctx.result_buf) + } + + /// Core search implementation using pre-allocated buffers. + /// + /// Results are written into `ctx.result_buf`, sorted by distance ascending + /// with deterministic tie-breaking by external NodeId. + /// + /// This is the hot path. 
    /// Core single-layer beam search using pre-allocated buffers.
    ///
    /// Results are written into `ctx.result_buf`, sorted by distance ascending
    /// with deterministic tie-breaking by external NodeId.
    ///
    /// # Parameters
    ///
    /// - `query` / `query_norm`: the query vector and its norm, supplied by the
    ///   caller so it is computed once per search rather than once per layer.
    /// - `entry_points`: internal IDs the search starts from.
    /// - `ef`: beam width; the result heap is trimmed to at most `ef` entries.
    /// - `layer`: which per-node neighbor list to follow.
    /// - `filter_tombstones`: when `true` (and the index has tombstones),
    ///   tombstoned nodes are kept out of both the candidate and result heaps.
    /// - `ctx`: scratch buffers; cleared on entry and holding the output on exit.
    ///
    /// This is the hot path. Optimizations applied:
    /// - Pre-allocated visited set with O(1) generation-counter clear
    /// - O(1) array indexing via dense usize IDs (no HashMap probing)
    /// - Cached worst-distance to avoid heap peek per neighbor
    /// - Tombstone check skipped entirely when tombstone set is empty
    /// - Early termination when candidate distance exceeds worst result
    /// - Inlined distance dispatch: metric resolved to function pointer once
    /// - Batch neighbor processing with software prefetch pipelining
    /// - (Optional) INT8 quantized pre-filter: skip f32 for obviously distant neighbors
    #[allow(clippy::too_many_arguments)]
    fn search_layer_inner_ctx(
        &self,
        query: &[f32],
        query_norm: f32,
        entry_points: &[usize],
        ef: usize,
        layer: usize,
        filter_tombstones: bool,
        ctx: &mut HnswSearchContext,
    ) {
        // Reset buffers without deallocating
        ctx.clear();
        ctx.ensure_capacity(ef, self.nodes.len());

        // Pre-compute: skip tombstone checks when there are no tombstones.
        let check_tombstones = filter_tombstones && self.tombstone_count > 0;

        // Resolve distance function once -- eliminates per-neighbor match dispatch.
        let distance_fn = resolve_distance_fn(self.config.metric);

        // Resolve batch-4 dot kernel once for Cosine/Dot metrics.
        let metric = self.config.metric;
        let use_dot_batch4 = matches!(metric, DistanceMetric::Cosine | DistanceMetric::Dot);
        let dot4_kernel = lattice_embed::simd::resolved_dot_product_batch4_kernel();
        let query_is_unit = cached_norm_is_unit(query_norm);

        // ---------------------------------------------------------------
        // INT8 quantized pre-filter setup (only for Cosine on layer 0)
        // ---------------------------------------------------------------
        // The quantized path is only used when:
        //   1. use_quantized is enabled
        //   2. We're on layer 0 (densest layer, most distance computations)
        //   3. The metric is Cosine (the only one we have INT8 distance for)
        //   4. The quantized arena is populated
        //
        // For upper layers (ef=1, greedy), the overhead of quantization is
        // not worthwhile since we evaluate very few candidates.
        let use_quant = self.use_quantized
            && layer == 0
            && self.config.metric == DistanceMetric::Cosine
            && !self.quantized.meta.is_empty();

        // Pre-quantize the query vector once if we're using the INT8 path.
        // This avoids re-quantizing per neighbor. Non-finite components are
        // ignored when finding the scale and are quantized to 0.
        let (query_i8, query_scale) = if use_quant {
            let mut max_abs: f32 = 0.0;
            for &v in query {
                if v.is_finite() {
                    let abs = v.abs();
                    if abs > max_abs {
                        max_abs = abs;
                    }
                }
            }
            // Map the largest magnitude to 127; fall back to a scale of 1.0
            // for an (almost) all-zero query to avoid dividing by ~0.
            let scale = if max_abs > 1e-10 {
                127.0 / max_abs
            } else {
                1.0
            };
            let quantized: Vec<i8> = query
                .iter()
                .map(|&v| {
                    if v.is_finite() {
                        (v * scale).round().clamp(-127.0, 127.0) as i8
                    } else {
                        0i8
                    }
                })
                .collect();
            (quantized, scale)
        } else {
            // Placeholders; never read because `use_quant` is false.
            (Vec::new(), 0.0)
        };

        // Mark entry points as visited
        ctx.visited.visit_all(entry_points.iter().copied());

        // Initialize with entry points (always use f32 for entry points).
        // A tombstoned entry point is skipped entirely when filtering, so its
        // neighbors are not explored from here.
        for &ep in entry_points {
            if check_tombstones && self.is_tombstoned(ep) {
                continue;
            }
            let node = &self.nodes[ep];
            let dist = distance_fn(query, query_norm, &node.vector, node.norm);
            ctx.candidates
                .push(std::cmp::Reverse((OrderedF32(dist), ep)));
            ctx.results.push((OrderedF32(dist), ep));
        }

        // Track the worst distance in the result set to avoid heap peek per neighbor.
        // `f32::MAX` when no entry point survived, so nothing is pruned early.
        let mut worst_dist = ctx
            .results
            .peek()
            .map(|(OrderedF32(d), _)| *d)
            .unwrap_or(f32::MAX);

        // Scratch buffer for batching neighbor processing.
        // Each entry: (internal_id, vector_ptr, vector_len, norm).
        let mut batch: Vec<(usize, *const f32, usize, f32)> = Vec::with_capacity(32);

        while let Some(std::cmp::Reverse((OrderedF32(c_dist), c_id))) = ctx.candidates.pop() {
            // Early termination: if the closest candidate is worse than the
            // worst result and we have enough results, we're done.
            if c_dist > worst_dist && ctx.results.len() >= ef {
                break;
            }

            // Explore neighbors -- direct array index, no HashMap lookup.
            // Nodes that do not exist on this layer have no list to follow.
            let node = &self.nodes[c_id];
            if layer < node.neighbors.len() {
                let neighbors = &node.neighbors[layer];

                // Phase 1: Collect unvisited neighbors.
                // O(1) visited check via generation counter, O(1) node access via array index.
                batch.clear();
                for &neighbor_id in neighbors {
                    // `visit` returns true only on the first visit; a neighbor
                    // rejected below stays marked and is not reconsidered.
                    if ctx.visited.visit(neighbor_id) {
                        if check_tombstones && self.is_tombstoned(neighbor_id) {
                            continue;
                        }

                        // INT8 pre-filter: skip neighbors that are clearly worse
                        // than the current worst result. Only active once the
                        // result set is full, so `worst_dist` is a real bound.
                        if use_quant && ctx.results.len() >= ef {
                            let approx_dist = self.quantized.cosine_distance_approx(
                                neighbor_id,
                                &query_i8,
                                query_scale,
                                query_norm,
                            );
                            // Only skip if the approximate distance exceeds the
                            // worst by more than the quantization margin (10%
                            // relative plus 0.01 absolute). The margin is a
                            // heuristic intended to keep true nearest neighbors
                            // from being dropped by quantization error.
                            if approx_dist > worst_dist * 1.1 + 0.01 {
                                continue;
                            }
                        }

                        let neighbor = &self.nodes[neighbor_id];
                        batch.push((
                            neighbor_id,
                            neighbor.vector.as_ptr(),
                            neighbor.vector.len(),
                            neighbor.norm,
                        ));
                    }
                }

                // Phase 2: Compute f32 distances with batch-4 SIMD + prefetch pipelining.
                //
                // For Cosine/Dot: process 4 candidates at once via the batch-4 dot kernel,
                // converting raw dots to HNSW distances (with unit-norm shortcut for cosine).
                // Remainder (< 4) and L2/other metrics use the per-pair distance_fn path.
                //
                // Heap updates are applied in original batch order to preserve HNSW recall
                // and deterministic neighbor ordering.
                if let Some(&(_, ptr, _, _)) = batch.first() {
                    prefetch_read_data(ptr);
                }

                // Index of the next unprocessed batch entry; shared by the
                // batch-4 loop and the scalar remainder loop below.
                let mut bi = 0;

                // Batch-4 fast path (Cosine / Dot metrics only).
                if use_dot_batch4 {
                    while bi + 4 <= batch.len() {
                        // Prefetch the entry 4 slots ahead to hide memory latency.
                        if bi + 4 < batch.len() {
                            prefetch_read_data(batch[bi + 4].1);
                        }

                        let (id0, p0, l0, n0) = batch[bi];
                        let (id1, p1, l1, n1) = batch[bi + 1];
                        let (id2, p2, l2, n2) = batch[bi + 2];
                        let (id3, p3, l3, n3) = batch[bi + 3];

                        // The kernel is only invoked when all four vectors
                        // have the query's length.
                        if l0 == query.len()
                            && l1 == query.len()
                            && l2 == query.len()
                            && l3 == query.len()
                        {
                            // SAFETY: pointers from live &Vec<f32> within this &self borrow.
                            let v0 = unsafe { std::slice::from_raw_parts(p0, l0) };
                            let v1 = unsafe { std::slice::from_raw_parts(p1, l1) };
                            let v2 = unsafe { std::slice::from_raw_parts(p2, l2) };
                            let v3 = unsafe { std::slice::from_raw_parts(p3, l3) };

                            let dots = dot4_kernel(query, v0, v1, v2, v3);
                            let dists = hnsw_distance_batch4_from_dots(
                                metric,
                                dots,
                                query_norm,
                                query_is_unit,
                                [n0, n1, n2, n3],
                            );

                            for (neighbor_id, dist) in [
                                (id0, dists[0]),
                                (id1, dists[1]),
                                (id2, dists[2]),
                                (id3, dists[3]),
                            ] {
                                // Admit unless the result set is full AND this
                                // neighbor is worse than the current worst.
                                if !(ctx.results.len() >= ef && dist > worst_dist) {
                                    ctx.candidates
                                        .push(std::cmp::Reverse((OrderedF32(dist), neighbor_id)));
                                    ctx.results.push((OrderedF32(dist), neighbor_id));
                                    // Trim back to `ef` by evicting the furthest.
                                    if ctx.results.len() > ef {
                                        ctx.results.pop();
                                    }
                                    if let Some(&(OrderedF32(d), _)) = ctx.results.peek() {
                                        worst_dist = d;
                                    }
                                }
                            }
                            bi += 4;
                            continue;
                        }

                        // Dimension mismatch — fall through to scalar remainder.
                        // `bi` is left unchanged, so these four entries are
                        // handled by the loop below.
                        break;
                    }
                }

                // Scalar remainder: 0-3 leftover entries after batch-4,
                // or all entries for L2 / other metrics.
                while bi < batch.len() {
                    let (neighbor_id, vec_ptr, vec_len, norm) = batch[bi];

                    // Prefetch the next vector; for vectors longer than 16
                    // floats, also the data 16 elements in.
                    if bi + 1 < batch.len() {
                        let (_, next_ptr, next_len, _) = batch[bi + 1];
                        prefetch_read_data(next_ptr);
                        if next_len > 16 {
                            prefetch_read_data(next_ptr.wrapping_add(16));
                        }
                    }

                    // SAFETY: vec_ptr and vec_len come from a live `&Vec<f32>`
                    // obtained from `self.nodes[neighbor_id]` within this same `&self`
                    // borrow. The Vec is immutable (&self) and is not mutated between
                    // Phase 1 (pointer capture) and Phase 2 (dereference).
                    let neighbor_vec = unsafe { std::slice::from_raw_parts(vec_ptr, vec_len) };
                    let dist = distance_fn(query, query_norm, neighbor_vec, norm);

                    // Same admission rule as the batch-4 path.
                    if !(ctx.results.len() >= ef && dist > worst_dist) {
                        ctx.candidates
                            .push(std::cmp::Reverse((OrderedF32(dist), neighbor_id)));
                        ctx.results.push((OrderedF32(dist), neighbor_id));
                        if ctx.results.len() > ef {
                            ctx.results.pop();
                        }
                        if let Some(&(OrderedF32(d), _)) = ctx.results.peek() {
                            worst_dist = d;
                        }
                    }
                    bi += 1;
                }
            }
        }

        // Drain results into the scratch buffer and sort.
        // Tie-break by external NodeId for deterministic ordering.
        ctx.result_buf.clear();
        ctx.result_buf
            .extend(ctx.results.drain().map(|(d, iid)| (d.0, iid)));
        ctx.result_buf
            .sort_by(|a, b| match OrderedF32(a.0).cmp(&OrderedF32(b.0)) {
                std::cmp::Ordering::Equal => self.external_id(a.1).cmp(&self.external_id(b.1)),
                other => other,
            });
    }
let mut index = HnswIndex::new(768); +//! +//! // Insert vectors +//! let id = NodeId::new([1u8; 16]); +//! index.insert(id, vec![0.1; 768])?; +//! +//! // Search for k nearest neighbors +//! let results = index.search(&query_vector, 10)?; +//! for (id, score) in results { +//! println!("{}: {}", id, score); +//! } +//! ``` +//! +//! # Algorithm +//! +//! HNSW builds a multi-layer graph where: +//! - Higher layers have fewer nodes (exponentially distributed) +//! - Search starts from top layer and descends +//! - Each layer uses greedy search to find nearest neighbors +//! +//! Reference: Malkov & Yashunin, "Efficient and robust approximate nearest +//! neighbor search using Hierarchical Navigable Small World graphs" (2018) + +pub mod alias; +pub mod arena; +pub mod checkpoint; +mod config; +mod distance; +pub mod error; +mod index; +pub mod metrics; +mod node; +pub(crate) mod search_context; +mod stats; + +#[cfg(test)] +mod tests; + +// Re-export public types +#[cfg(feature = "checkpoint")] +pub use checkpoint::{HnswCheckpoint, HnswCheckpointStore}; +pub use checkpoint::{HnswCheckpointConfig, HnswSnapshot}; +pub use config::{DistanceMetric, HnswConfig}; +pub use index::HnswIndex; +pub use search_context::HnswSearchContext; +pub use stats::{RebuildStats, TombstoneStats}; + +/// 128-bit opaque node identifier for HNSW entries. +/// +/// Serializes as a 32-character lowercase hex string (compatible with the +/// snapshot format used by `HnswSnapshot`). +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NodeId([u8; 16]); + +impl NodeId { + /// Create a `NodeId` from raw bytes. + #[inline] + pub const fn new(bytes: [u8; 16]) -> Self { + Self(bytes) + } + + /// Return the raw byte representation. 
+ #[inline] + pub fn as_bytes(&self) -> &[u8; 16] { + &self.0 + } +} + +impl std::fmt::Debug for NodeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NodeId(")?; + for b in &self.0 { + write!(f, "{b:02x}")?; + } + write!(f, ")") + } +} + +impl std::fmt::Display for NodeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for b in &self.0 { + write!(f, "{b:02x}")?; + } + Ok(()) + } +} + +impl serde::Serialize for NodeId { + fn serialize(&self, s: S) -> Result { + let mut hex = String::with_capacity(32); + for b in &self.0 { + hex.push_str(&format!("{b:02x}")); + } + s.serialize_str(&hex) + } +} + +impl<'de> serde::Deserialize<'de> for NodeId { + fn deserialize>(d: D) -> Result { + let s = String::deserialize(d)?; + if s.len() != 32 { + return Err(serde::de::Error::custom(format!( + "NodeId hex string must be 32 chars, got {}", + s.len() + ))); + } + let mut bytes = [0u8; 16]; + for (i, chunk) in s.as_bytes().chunks(2).enumerate() { + let hi = char::from(chunk[0]) + .to_digit(16) + .ok_or_else(|| serde::de::Error::custom("invalid hex character"))? + as u8; + let lo = char::from(chunk[1]) + .to_digit(16) + .ok_or_else(|| serde::de::Error::custom("invalid hex character"))? + as u8; + bytes[i] = (hi << 4) | lo; + } + Ok(NodeId(bytes)) + } +} diff --git a/crates/khive-hnsw/src/metrics.rs b/crates/khive-hnsw/src/metrics.rs new file mode 100644 index 00000000..47468569 --- /dev/null +++ b/crates/khive-hnsw/src/metrics.rs @@ -0,0 +1,148 @@ +//! Metrics infrastructure for HNSW observability. +//! +//! Consumers attach a `MetricsSink` implementation to an `HnswIndex` to +//! receive structured telemetry from insert, search, and rebuild operations. +//! +//! # Design +//! +//! The trait is object-safe (`Arc`) so a single sink can be +//! shared across multiple index instances. The `emit` helper handles the +//! `None` case (no sink attached) at call sites. 
+ +use std::sync::{Arc, Mutex}; + +// --------------------------------------------------------------------------- +// Metric value types +// --------------------------------------------------------------------------- + +/// A single metric value emitted from an HNSW operation. +#[derive(Debug, Clone, PartialEq)] +pub enum MetricValue { + /// Monotonically increasing counter (e.g., insert count). + Counter(u64), + /// Point-in-time gauge (e.g., index size). + Gauge(f64), + /// Distribution observation (e.g., operation duration in ms). + Histogram(f64), +} + +/// A single metric event emitted from an HNSW operation. +#[derive(Debug, Clone)] +pub struct MetricEvent { + /// Metric name (use the constants in [`names`]). + pub name: &'static str, + /// The metric value. + pub value: MetricValue, + /// Optional key-value label pairs (e.g., `[("metric", "cosine")]`). + pub labels: Vec<(&'static str, String)>, +} + +// --------------------------------------------------------------------------- +// Sink trait +// --------------------------------------------------------------------------- + +/// Receiver for metric events from HNSW operations. +/// +/// Implement this trait to bridge HNSW telemetry to your observability stack +/// (e.g., Prometheus, OpenTelemetry, tracing spans). +/// +/// # Thread Safety +/// +/// The trait requires `Send + Sync` so that `Arc` can be +/// shared across threads. +pub trait MetricsSink: Send + Sync { + /// Handle a metric event. + fn emit(&self, event: MetricEvent); +} + +// --------------------------------------------------------------------------- +// Emit helper +// --------------------------------------------------------------------------- + +/// Emit a metric event to the attached sink, if any. +/// +/// This is the call-site helper used by `HnswIndex` internals. It is a no-op +/// when `sink` is `None`. 
+#[inline] +pub fn emit(sink: &Option>, event: MetricEvent) { + if let Some(s) = sink { + s.emit(event); + } +} + +// --------------------------------------------------------------------------- +// Metric name constants +// --------------------------------------------------------------------------- + +/// Canonical metric name constants. +/// +/// Using `&'static str` constants avoids string formatting on the hot path. +pub mod names { + /// Duration of a single insert operation in milliseconds (Histogram). + pub const HNSW_INSERT_DURATION_MS: &str = "hnsw.insert.duration_ms"; + /// Number of insert operations (Counter). + pub const HNSW_INSERT_COUNT: &str = "hnsw.insert.count"; + /// Current live node count after insert (Gauge). + pub const HNSW_INDEX_SIZE: &str = "hnsw.index.size"; + + /// Duration of a single search operation in milliseconds (Histogram). + pub const HNSW_SEARCH_DURATION_MS: &str = "hnsw.search.duration_ms"; + /// Number of search operations (Counter). + pub const HNSW_SEARCH_COUNT: &str = "hnsw.search.count"; + /// Number of results returned by a search (Gauge). + pub const HNSW_SEARCH_RESULTS: &str = "hnsw.search.results"; + + /// Duration of a rebuild operation in milliseconds (Histogram). + pub const HNSW_REBUILD_DURATION_MS: &str = "hnsw.rebuild.duration_ms"; + /// Number of rebuild operations (Counter). + pub const HNSW_REBUILD_COUNT: &str = "hnsw.rebuild.count"; + /// Number of nodes removed during a rebuild (Gauge). + pub const HNSW_REBUILD_NODES_REMOVED: &str = "hnsw.rebuild.nodes_removed"; +} + +// --------------------------------------------------------------------------- +// Recording sink (test helper) +// --------------------------------------------------------------------------- + +/// A `MetricsSink` that records all events for inspection in tests. +/// +/// Thread-safe: uses an internal `Mutex`. +pub struct RecordingSink { + events: Mutex>, +} + +impl RecordingSink { + /// Create a new, empty recording sink. 
/// One vertex of the HNSW graph.
///
/// Nodes are stored in a dense `Vec` and addressed by an internal `usize` ID;
/// neighbor lists hold those internal IDs so that lookups during search are
/// plain array indexing.
#[derive(Debug, Clone)]
pub(crate) struct HnswNode {
    /// The vector data.
    pub vector: Vec<f32>,
    /// Per-layer adjacency: `neighbors[layer]` lists internal neighbor IDs.
    pub neighbors: Vec<Vec<usize>>,
    /// Highest layer this node participates in.
    pub max_layer: usize,
    /// Cached L2 norm of `vector`, kept in sync by `new` and `update_vector`.
    pub norm: f32,
}

impl HnswNode {
    /// Build a node for `vector` with empty neighbor lists for layers
    /// `0..=max_layer`, caching the vector's norm.
    pub fn new(vector: Vec<f32>, max_layer: usize) -> Self {
        let norm = Self::l2_norm(&vector);
        let neighbors = (0..=max_layer).map(|_| Vec::new()).collect();
        Self {
            vector,
            neighbors,
            max_layer,
            norm,
        }
    }

    /// Replace the stored vector and refresh the cached norm.
    pub fn update_vector(&mut self, vector: Vec<f32>) {
        self.norm = Self::l2_norm(&vector);
        self.vector = vector;
    }

    /// Euclidean (L2) norm: square root of the sum of squared components.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }
}
/// Visited-node tracker backed by a dense array and a generation counter.
///
/// Slot `id` holds the generation in which node `id` was last visited, so a
/// node counts as visited exactly when `markers[id] == generation`. Bumping the
/// generation therefore "clears" the whole set without touching the array.
pub(crate) struct VisitedSet {
    /// Generation of the current search; advanced by every `clear()`.
    generation: u64,
    /// Last-visited generation per internal node ID (0 = never visited).
    markers: Vec<u64>,
}

impl VisitedSet {
    /// Create a set with room for `capacity` node IDs.
    pub fn new(capacity: usize) -> Self {
        Self {
            // Generation 0 is reserved as the "never visited" marker value.
            generation: 1,
            markers: vec![0u64; capacity],
        }
    }

    /// Forget all visits by moving to the next generation.
    ///
    /// If the counter would overflow, the markers are zeroed and the counter
    /// restarts at 1 so stale markers can never match a reused generation.
    #[inline]
    pub fn clear(&mut self) {
        self.generation = match self.generation.checked_add(1) {
            Some(next) => next,
            None => {
                self.markers.fill(0);
                1
            }
        };
    }

    /// Grow the marker array so that `max_id` is a valid index.
    #[inline]
    pub fn ensure_capacity(&mut self, max_id: usize) {
        if self.markers.len() <= max_id {
            self.markers.resize(max_id + 1, 0);
        }
    }

    /// Record a visit to `id`, growing the array if needed.
    ///
    /// Returns `true` when this is the first visit in the current generation
    /// and `false` when `id` was already visited (as `HashSet::insert` does).
    #[inline]
    pub fn visit(&mut self, id: usize) -> bool {
        if self.markers.len() <= id {
            self.markers.resize(id + 1, 0);
        }
        let current = self.generation;
        let slot = &mut self.markers[id];
        let first_visit = *slot != current;
        *slot = current;
        first_visit
    }

    /// Record a visit to every ID yielded by `ids`.
    #[inline]
    pub fn visit_all(&mut self, ids: impl Iterator<Item = usize>) {
        ids.for_each(|id| {
            self.visit(id);
        });
    }
}
+/// +/// Reuse across multiple `search_with_context` calls to amortize allocation cost. +/// The context holds the working buffers for the greedy beam search: +/// +/// - `candidates`: min-heap of nodes to explore (closest first) -- uses internal usize IDs +/// - `results`: max-heap of best results so far (furthest first, for pruning) -- uses internal usize IDs +/// - `visited`: generation-counter visited set indexed by internal usize ID +/// - `result_buf`: scratch buffer for final sorted output (internal usize IDs) +/// +/// # Example +/// +/// ```rust,ignore +/// use khive_retrieval::hnsw::{HnswIndex, HnswSearchContext}; +/// +/// let index = HnswIndex::new(128); +/// // ... insert vectors ... +/// +/// let mut ctx = HnswSearchContext::new(index.config().ef_search); +/// +/// // Reuse ctx across many searches +/// for query in queries { +/// let results = index.search_with_context(&query, 10, &mut ctx)?; +/// // process results... +/// } +/// ``` +pub struct HnswSearchContext { + /// Min-heap: candidates to explore (closest first). Uses internal usize IDs. + pub(crate) candidates: BinaryHeap>, + /// Max-heap: best results so far (furthest first, for pruning). Uses internal usize IDs. + pub(crate) results: BinaryHeap<(OrderedF32, usize)>, + /// Visited node tracking with O(1) operations. + pub(crate) visited: VisitedSet, + /// Scratch buffer for final sorted results (internal usize IDs). + pub(crate) result_buf: Vec<(f32, usize)>, + /// Pre-allocated capacity hint (ef value used to size buffers). + ef_hint: usize, +} + +impl HnswSearchContext { + /// Create a new search context pre-allocated for the given `ef` value. + /// + /// The `ef` parameter should match or exceed the `ef_search` config value + /// of the index you plan to search. 
+ pub fn new(ef: usize) -> Self { + Self { + candidates: BinaryHeap::with_capacity(ef), + results: BinaryHeap::with_capacity(ef), + visited: VisitedSet::new(ef * 4), // Over-allocate to reduce resizes + result_buf: Vec::with_capacity(ef), + ef_hint: ef, + } + } + + /// Clear all buffers without deallocating. + /// + /// Called automatically at the start of each search. You do not need to + /// call this manually. + pub(crate) fn clear(&mut self) { + self.candidates.clear(); + self.results.clear(); + self.visited.clear(); // O(1) generation increment + self.result_buf.clear(); + } + + /// Ensure all buffers are large enough for the given `ef` and node count. + pub(crate) fn ensure_capacity(&mut self, ef: usize, num_nodes: usize) { + if ef > self.ef_hint { + self.result_buf + .reserve(ef.saturating_sub(self.result_buf.capacity())); + self.ef_hint = ef; + } + self.visited.ensure_capacity(num_nodes); + } +} diff --git a/crates/khive-hnsw/src/stats.rs b/crates/khive-hnsw/src/stats.rs new file mode 100644 index 00000000..acd0df36 --- /dev/null +++ b/crates/khive-hnsw/src/stats.rs @@ -0,0 +1,79 @@ +//! HNSW index statistics types. + +use super::config::DEFAULT_REBUILD_THRESHOLD; + +/// Statistics about tombstoned nodes in an HNSW index. +#[derive(Debug, Clone, Copy)] +pub struct TombstoneStats { + /// Total number of nodes in the graph. + pub total_nodes: usize, + /// Number of tombstoned nodes. + pub tombstone_count: usize, + /// Number of live (non-tombstoned) nodes. + pub live_nodes: usize, + /// Ratio of tombstoned to total nodes (0.0 - 1.0). + pub ratio: f64, +} + +impl TombstoneStats { + /// Check if rebuild is needed based on default threshold. + pub fn needs_rebuild(&self) -> bool { + self.needs_rebuild_at(DEFAULT_REBUILD_THRESHOLD) + } + + /// Check if rebuild is needed at a specific threshold. + pub fn needs_rebuild_at(&self, threshold: f64) -> bool { + self.ratio > threshold + } +} + +/// Statistics returned from a rebuild operation. 
+#[derive(Debug, Clone, Copy)] +pub struct RebuildStats { + /// Number of nodes before rebuild. + pub nodes_before: usize, + /// Number of nodes removed (tombstones). + pub nodes_removed: usize, + /// Number of nodes after rebuild. + pub nodes_after: usize, + /// Number of neighbor references cleaned up. + pub edges_cleaned: usize, + /// Whether entry point was updated. + pub entry_point_updated: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tombstone_stats_needs_rebuild() { + let stats = TombstoneStats { + total_nodes: 100, + tombstone_count: 10, + live_nodes: 90, + ratio: 0.10, + }; + assert!(!stats.needs_rebuild()); // 10% == 10% threshold (not strictly greater) + + let stats = TombstoneStats { + total_nodes: 100, + tombstone_count: 20, + live_nodes: 80, + ratio: 0.20, + }; + assert!(stats.needs_rebuild()); // 20% > 10% threshold (ADR-003) + } + + #[test] + fn test_tombstone_stats_custom_threshold() { + let stats = TombstoneStats { + total_nodes: 100, + tombstone_count: 10, + live_nodes: 90, + ratio: 0.10, + }; + assert!(stats.needs_rebuild_at(0.05)); // 10% > 5% + assert!(!stats.needs_rebuild_at(0.15)); // 10% < 15% + } +} diff --git a/crates/khive-hnsw/src/tests.rs b/crates/khive-hnsw/src/tests.rs new file mode 100644 index 00000000..1cd0296f --- /dev/null +++ b/crates/khive-hnsw/src/tests.rs @@ -0,0 +1,1976 @@ +//! Tests for HNSW index. 
+ +#[cfg(test)] +mod unit_tests { + use crate::NodeId; + use crate::{DistanceMetric, HnswConfig, HnswIndex}; + use khive_score::DeterministicScore; + + use std::collections::HashSet; + + fn make_id(seed: u8) -> NodeId { + NodeId::new([seed; 16]) + } + + fn generate_random_vector(dim: usize, seed: u64) -> Vec { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + (0..dim) + .map(|i| { + let mut hasher = DefaultHasher::new(); + (seed, i).hash(&mut hasher); + (hasher.finish() as f32 / u64::MAX as f32) * 2.0 - 1.0 + }) + .collect() + } + + #[test] + fn test_insert_and_search() { + let mut index = HnswIndex::new(3); + + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + + index.insert(id1, vec![1.0, 0.0, 0.0]).expect("insert id1"); + index.insert(id2, vec![0.9, 0.1, 0.0]).expect("insert id2"); + index.insert(id3, vec![0.0, 1.0, 0.0]).expect("insert id3"); + + assert_eq!(index.len(), 3); + + // Search for vector similar to [1.0, 0.0, 0.0] + let results = index.search(&[1.0, 0.0, 0.0], 2).expect("search"); + + assert_eq!(results.len(), 2); + // First result should be id1 (exact match) + assert_eq!(results[0].0, id1); + assert!(results[0].1.to_f64() > 0.99); + } + + #[test] + fn test_update_existing() { + let mut index = HnswIndex::new(3); + let id = make_id(1); + + index.insert(id, vec![1.0, 0.0, 0.0]).expect("insert"); + index.insert(id, vec![0.0, 1.0, 0.0]).expect("update"); + + // Should still be 1 vector + assert_eq!(index.len(), 1); + + // Search should find the updated vector + let results = index.search(&[0.0, 1.0, 0.0], 1).expect("search"); + assert_eq!(results[0].0, id); + assert!(results[0].1.to_f64() > 0.99); + } + + #[test] + fn test_delete_tombstone() { + let mut index = HnswIndex::new(3); + + let id1 = make_id(1); + let id2 = make_id(2); + + index.insert(id1, vec![1.0, 0.0, 0.0]).expect("insert id1"); + index.insert(id2, vec![0.0, 1.0, 0.0]).expect("insert id2"); + + // Delete id1 + 
assert!(index.delete(id1)); + + // Should still have 2 nodes but 1 tombstone + assert_eq!(index.len(), 2); + assert_eq!(index.tombstone_stats().tombstone_count, 1); + + // Search should not return tombstoned node + let results = index.search(&[1.0, 0.0, 0.0], 2).expect("search"); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, id2); + } + + #[test] + fn test_rebuild() { + let mut index = HnswIndex::new(3); + + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + + index.insert(id1, vec![1.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0]).expect("insert"); + index.insert(id3, vec![0.0, 0.0, 1.0]).expect("insert"); + + index.delete(id1); + index.delete(id2); + + let stats = index.rebuild(); + assert_eq!(stats.nodes_before, 3); + assert_eq!(stats.nodes_removed, 2); + assert_eq!(stats.nodes_after, 1); + + // After rebuild, should only have id3 + assert_eq!(index.len(), 1); + assert_eq!(index.tombstone_stats().tombstone_count, 0); + } + + #[test] + fn test_dimension_mismatch() { + let mut index = HnswIndex::new(3); + let id = make_id(1); + + // Wrong dimension should error + let result = index.insert(id, vec![1.0, 0.0]); + assert!(result.is_err()); + + // Insert correct dimension + index.insert(id, vec![1.0, 0.0, 0.0]).expect("insert"); + + // Search with wrong dimension should error + let result = index.search(&[1.0, 0.0], 1); + assert!(result.is_err()); + } + + #[test] + fn test_empty_search() { + let index = HnswIndex::new(3); + let results = index.search(&[1.0, 0.0, 0.0], 10).expect("search empty"); + assert!(results.is_empty()); + } + + #[test] + fn test_dot_product_metric() { + let mut config = HnswConfig::with_dimensions(2); + config.metric = DistanceMetric::Dot; + let mut index = HnswIndex::with_config(config); + + let id1 = make_id(1); + let id2 = make_id(2); + + index.insert(id1, vec![1.0, 0.0]).expect("insert id1"); + index.insert(id2, vec![2.0, 0.0]).expect("insert id2"); + + // For dot product, [2,0] . 
[1,0] = 2 > [1,0] . [1,0] = 1 + let results = index.search(&[1.0, 0.0], 2).expect("search"); + assert_eq!(results[0].0, id2); + } + + #[test] + fn test_euclidean_metric() { + let mut config = HnswConfig::with_dimensions(2); + config.metric = DistanceMetric::L2; + let mut index = HnswIndex::with_config(config); + + let id1 = make_id(1); + let id2 = make_id(2); + + index.insert(id1, vec![1.0, 0.0]).expect("insert id1"); + index.insert(id2, vec![10.0, 0.0]).expect("insert id2"); + + // Euclidean: closer = higher score + let results = index.search(&[0.0, 0.0], 2).expect("search"); + assert_eq!(results[0].0, id1); // Closer to origin + } + + #[test] + fn test_larger_index() { + let mut index = HnswIndex::new(128); + + let n = 500; + let mut ids = Vec::new(); + for i in 0..n { + let id = NodeId::new([ + (i >> 8) as u8, + (i & 0xff) as u8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]); + let vector = generate_random_vector(128, i as u64); + index.insert(id, vector).expect("insert"); + ids.push(id); + } + + assert_eq!(index.len(), n); + + // Search should return k results + let query = generate_random_vector(128, 50); + let results = index.search(&query, 10).expect("search"); + assert_eq!(results.len(), 10); + + // Scores should be sorted descending + for window in results.windows(2) { + assert!(window[0].1 >= window[1].1); + } + } + + #[test] + fn test_recall() { + let mut index = HnswIndex::new(64); + + let n = 500; + let mut vectors: Vec<(NodeId, Vec)> = Vec::new(); + for i in 0..n { + let id = NodeId::new([ + (i >> 8) as u8, + (i & 0xff) as u8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]); + let vector = generate_random_vector(64, i as u64); + index.insert(id, vector.clone()).expect("insert"); + vectors.push((id, vector)); + } + + let k = 10; + let mut total_recall = 0.0; + let num_queries = 10; + + for q in 0..num_queries { + let (query_id, query) = &vectors[q * 50]; + + // Brute force ground truth + let mut 
ground_truth: Vec<(f32, NodeId)> = vectors + .iter() + .map(|(id, v)| { + let dot: f32 = query.iter().zip(v).map(|(a, b)| a * b).sum(); + let q_norm: f32 = query.iter().map(|x| x * x).sum::().sqrt(); + let v_norm: f32 = v.iter().map(|x| x * x).sum::().sqrt(); + let sim = if q_norm > 0.0 && v_norm > 0.0 { + dot / (q_norm * v_norm) + } else { + 0.0 + }; + (1.0 - sim, *id) + }) + .collect(); + ground_truth.sort_by(|a, b| a.0.total_cmp(&b.0)); + let truth_ids: HashSet = + ground_truth.iter().take(k).map(|(_, id)| *id).collect(); + + // HNSW search + let results = index.search(query, k).expect("search"); + let result_ids: HashSet = results.iter().map(|(id, _)| *id).collect(); + + let recall = + truth_ids.intersection(&result_ids).count() as f32 / truth_ids.len() as f32; + total_recall += recall; + + // First result should be the query itself + assert_eq!( + results[0].0, *query_id, + "Query {q} should return itself as first result" + ); + } + + let avg_recall = total_recall / num_queries as f32; + assert!( + avg_recall > 0.8, + "Average recall {avg_recall:.2} should be > 0.8" + ); + } + + #[test] + fn test_config_variants() { + // Test each config variant builds successfully + for config in [ + HnswConfig::default(), + HnswConfig::high_recall(), + HnswConfig::fast_build(), + HnswConfig::low_memory(), + ] { + let mut index = HnswIndex::with_config(HnswConfig { + dimensions: 32, + ..config + }); + + for i in 0..100 { + let id = NodeId::new([i as u8; 16]); + let vector = generate_random_vector(32, i as u64); + index.insert(id, vector).expect("insert"); + } + assert_eq!(index.len(), 100); + } + } + + #[test] + fn test_tombstone_stats() { + let mut index = HnswIndex::new(3); + + for i in 0..10 { + let id = make_id(i); + index.insert(id, vec![i as f32, 0.0, 0.0]).expect("insert"); + } + + // Tombstone 3 nodes + for i in 0..3 { + index.delete(make_id(i)); + } + + let stats = index.tombstone_stats(); + assert_eq!(stats.total_nodes, 10); + assert_eq!(stats.tombstone_count, 3); + 
assert_eq!(stats.live_nodes, 7); + assert!((stats.ratio - 0.3).abs() < 0.01); + } + + #[test] + fn test_needs_rebuild_threshold() { + let mut config = HnswConfig::with_dimensions(3); + config.rebuild_threshold = 0.2; + let mut index = HnswIndex::with_config(config); + + for i in 0..10 { + let id = make_id(i); + index.insert(id, vec![i as f32, 0.0, 0.0]).expect("insert"); + } + + // 1 tombstone = 10%, under 20% threshold + index.delete(make_id(0)); + assert!(!index.needs_rebuild()); + + // 3 tombstones = 30%, over 20% threshold + index.delete(make_id(1)); + index.delete(make_id(2)); + assert!(index.needs_rebuild()); + } + + #[test] + fn test_deterministic_score_output() { + let mut index = HnswIndex::new(3); + + let id = make_id(1); + index.insert(id, vec![1.0, 0.0, 0.0]).expect("insert"); + + let results = index.search(&[1.0, 0.0, 0.0], 1).expect("search"); + + // Score should be DeterministicScore, not f32 + let score = results[0].1; + assert!(score.to_f64() > 0.99); + assert!(score.to_f64() <= 1.0); + + // Scores are comparable + assert!(score > DeterministicScore::from_f64(0.5)); + } + + #[test] + fn test_clear() { + let mut index = HnswIndex::new(3); + + index + .insert(make_id(1), vec![1.0, 0.0, 0.0]) + .expect("insert"); + index + .insert(make_id(2), vec![0.0, 1.0, 0.0]) + .expect("insert"); + index.delete(make_id(1)); + + assert_eq!(index.len(), 2); + assert_eq!(index.tombstone_stats().tombstone_count, 1); + + index.clear(); + + assert_eq!(index.len(), 0); + assert_eq!(index.tombstone_stats().tombstone_count, 0); + assert!(index.is_empty()); + } + + #[test] + fn test_seeded_rng_reproducibility() { + // Two indexes with the same seed should produce identical structure + let config1 = HnswConfig::with_dimensions(32).with_seed(42); + let config2 = HnswConfig::with_dimensions(32).with_seed(42); + + let mut index1 = HnswIndex::with_config(config1); + let mut index2 = HnswIndex::with_config(config2); + + // Insert same vectors in same order + for i in 0..50 { + 
let id = NodeId::new([i as u8; 16]); + let vector = generate_random_vector(32, i as u64); + index1.insert(id, vector.clone()).expect("insert"); + index2.insert(id, vector).expect("insert"); + } + + // Search should return identical results + let query = generate_random_vector(32, 999); + let results1 = index1.search(&query, 10).expect("search"); + let results2 = index2.search(&query, 10).expect("search"); + + assert_eq!(results1.len(), results2.len()); + for (r1, r2) in results1.iter().zip(results2.iter()) { + assert_eq!(r1.0, r2.0, "Same seed should produce identical results"); + assert_eq!(r1.1, r2.1, "Scores should match exactly"); + } + } + + #[test] + fn test_different_seeds_different_structure() { + // Two indexes with different seeds should (likely) produce different structures + let config1 = HnswConfig::with_dimensions(32).with_seed(42); + let config2 = HnswConfig::with_dimensions(32).with_seed(123); + + let mut index1 = HnswIndex::with_config(config1); + let mut index2 = HnswIndex::with_config(config2); + + // Insert same vectors in same order + for i in 0..100 { + let id = NodeId::new([i as u8; 16]); + let vector = generate_random_vector(32, i as u64); + index1.insert(id, vector.clone()).expect("insert"); + index2.insert(id, vector).expect("insert"); + } + + // Get max_level for each - with different seeds, the random levels + // assigned to nodes will differ, leading to different max_level values + // (statistically likely to differ with 100 insertions) + // Note: They might still be the same by chance, but the internal + // structure (which nodes are at which level) will differ + + // At minimum, both should be valid indexes + assert_eq!(index1.len(), 100); + assert_eq!(index2.len(), 100); + } +} + +#[cfg(test)] +mod memory_budget_tests { + use crate::error::{ErrorKind, RetrievalError}; + use crate::NodeId; + use crate::{HnswConfig, HnswIndex}; + + fn make_id(seed: u8) -> NodeId { + NodeId::new([seed; 16]) + } + + #[test] + fn 
test_no_budget_allows_unlimited_inserts() { + // Without a budget, inserts always succeed + let mut index = HnswIndex::new(4); + for i in 0..50 { + let id = make_id(i); + index + .insert(id, vec![i as f32, 0.0, 0.0, 0.0]) + .expect("insert should succeed without budget"); + } + assert_eq!(index.len(), 50); + } + + #[test] + fn test_budget_blocks_insert_when_exceeded() { + // Set a very tight budget that allows only a few nodes + let config = HnswConfig::with_dimensions(4).with_memory_budget(2_000); + let mut index = HnswIndex::with_config(config); + + // First insert should succeed (index starts empty) + index + .insert(make_id(1), vec![1.0, 0.0, 0.0, 0.0]) + .expect("first insert should succeed"); + + // Keep inserting until we hit the budget + let mut rejected = false; + for i in 2..=100u8 { + let result = index.insert(make_id(i), vec![i as f32, 0.0, 0.0, 0.0]); + if let Err(err) = result { + rejected = true; + assert!( + matches!(err, RetrievalError::BudgetExceeded { .. }), + "Expected BudgetExceeded, got: {err:?}" + ); + assert_eq!(err.kind(), ErrorKind::Permanent); + assert!(!err.is_retryable()); + break; + } + } + assert!(rejected, "Budget should have rejected an insert"); + } + + #[test] + fn test_budget_update_bypasses_check() { + // Set a budget, fill it up, then update an existing entry + let config = HnswConfig::with_dimensions(4).with_memory_budget(2_000); + let mut index = HnswIndex::with_config(config); + + let id1 = make_id(1); + index + .insert(id1, vec![1.0, 0.0, 0.0, 0.0]) + .expect("first insert"); + + // Fill until budget hit + for i in 2..=100u8 { + if index + .insert(make_id(i), vec![i as f32, 0.0, 0.0, 0.0]) + .is_err() + { + break; + } + } + + // Updating an existing entry should always succeed (bypass budget) + index + .insert(id1, vec![9.0, 9.0, 9.0, 9.0]) + .expect("update existing should bypass budget"); + } + + #[test] + fn test_memory_usage_increases_with_inserts() { + let mut index = HnswIndex::new(8); + + let before = 
index.memory_usage(); + assert_eq!(before, 0, "Empty index should have zero usage"); + + index.insert(make_id(1), vec![1.0; 8]).expect("insert"); + let after_one = index.memory_usage(); + assert!(after_one > 0, "Usage should increase after insert"); + + index.insert(make_id(2), vec![2.0; 8]).expect("insert"); + let after_two = index.memory_usage(); + assert!( + after_two > after_one, + "Usage should increase with more inserts" + ); + } + + #[test] + fn test_estimate_insert_cost_is_positive() { + let index = HnswIndex::new(128); + let cost = index.estimate_insert_cost(); + assert!(cost > 0, "Insert cost should be positive"); + // For 128 dims: 128*4 = 512 bytes for vector alone + assert!(cost >= 512, "Cost should include at least the vector data"); + } + + #[test] + fn test_memory_budget_getter_setter() { + let mut index = HnswIndex::new(4); + + // Default: no budget + assert_eq!(index.memory_budget(), None); + + // Set budget via runtime setter + index.set_memory_budget(Some(10_000)); + assert_eq!(index.memory_budget(), Some(10_000)); + + // Clear budget + index.set_memory_budget(None); + assert_eq!(index.memory_budget(), None); + } + + #[test] + fn test_budget_from_config() { + let config = HnswConfig::with_dimensions(4).with_memory_budget(5_000); + let index = HnswIndex::with_config(config); + assert_eq!(index.memory_budget(), Some(5_000)); + } + + #[test] + fn test_budget_exceeded_error_details() { + let config = HnswConfig::with_dimensions(4).with_memory_budget(1); + let mut index = HnswIndex::with_config(config); + + // Budget of 1 byte is too small for any insert + let result = index.insert(make_id(1), vec![1.0, 0.0, 0.0, 0.0]); + assert!(result.is_err()); + + let err = result.unwrap_err(); + match err { + RetrievalError::BudgetExceeded { + current_usage, + item_size, + limit, + } => { + assert_eq!(current_usage, 0, "Empty index"); + assert!(item_size > 0, "Item should have non-zero cost"); + assert_eq!(limit, 1, "Limit should match config"); + 
assert!(current_usage + item_size > limit, "Should genuinely exceed"); + } + other => panic!("Expected BudgetExceeded, got: {other:?}"), + } + } + + #[test] + fn test_search_unaffected_by_budget() { + // Budget is only checked on insert, never on search + let config = HnswConfig::with_dimensions(3).with_memory_budget(100_000); + let mut index = HnswIndex::with_config(config); + + index + .insert(make_id(1), vec![1.0, 0.0, 0.0]) + .expect("insert"); + index + .insert(make_id(2), vec![0.0, 1.0, 0.0]) + .expect("insert"); + + // Search should work regardless of budget + let results = index.search(&[1.0, 0.0, 0.0], 2).expect("search"); + assert_eq!(results.len(), 2); + } +} + +#[cfg(test)] +mod proptests { + use crate::HnswIndex; + use crate::NodeId; + + use proptest::prelude::*; + + const DIM: usize = 32; + + fn seeded_vector(seed: u64) -> Vec { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + (0..DIM) + .map(|i| { + let mut hasher = DefaultHasher::new(); + (seed, i).hash(&mut hasher); + (hasher.finish() as f32 / u64::MAX as f32) * 2.0 - 1.0 + }) + .collect() + } + + proptest! { + /// Property: search() returns exactly k results when k <= num_vectors + #[test] + fn search_returns_k_results( + k in 1usize..=10, + num_vectors in 10usize..=100 + ) { + prop_assume!(k <= num_vectors); + + let mut index = HnswIndex::new(DIM); + + for i in 0..num_vectors { + let id = NodeId::new([ + (i >> 8) as u8, + (i & 0xff) as u8, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]); + let vector = seeded_vector(i as u64); + index.insert(id, vector).expect("insert"); + } + + prop_assert_eq!(index.len(), num_vectors); + + let query = seeded_vector(999); + let results = index.search(&query, k).expect("search"); + + prop_assert_eq!( + results.len(), + k, + "Expected {} results but got {}", + k, + results.len() + ); + + // All scores should be finite (DeterministicScore is i64 fixed-point; check via to_f64). 
+ for (_, score) in &results { + prop_assert!( + score.to_f64().is_finite(), + "Score should be finite" + ); + } + + // Results sorted descending by score + for window in results.windows(2) { + prop_assert!( + window[0].1 >= window[1].1, + "Results should be sorted by score descending" + ); + } + } + + /// Property: search() returns min(k, num_vectors) when k > num_vectors + #[test] + fn search_returns_all_when_k_exceeds_count( + num_vectors in 1usize..=20, + k_excess in 1usize..=30 + ) { + let k = num_vectors + k_excess; + + let mut index = HnswIndex::new(DIM); + + for i in 0..num_vectors { + let id = NodeId::new([i as u8; 16]); + let vector = seeded_vector(i as u64); + index.insert(id, vector).expect("insert"); + } + + let query = seeded_vector(999); + let results = index.search(&query, k).expect("search"); + + prop_assert_eq!( + results.len(), + num_vectors, + "Expected {} results (all vectors) but got {}", + num_vectors, + results.len() + ); + } + + /// Property: empty index returns empty results + #[test] + fn search_empty_index_returns_empty(k in 1usize..=100) { + let index = HnswIndex::new(DIM); + + let query = seeded_vector(0); + let results = index.search(&query, k).expect("search"); + + prop_assert!( + results.is_empty(), + "Empty index should return empty results" + ); + } + } +} + +#[cfg(test)] +mod metrics_tests { + use crate::metrics::{names, MetricValue, RecordingSink}; + use crate::HnswIndex; + use crate::NodeId; + + use std::sync::Arc; + + fn make_id(seed: u8) -> NodeId { + NodeId::new([seed; 16]) + } + + #[test] + fn insert_emits_metrics() { + let sink = Arc::new(RecordingSink::new()); + let mut index = HnswIndex::new(3).with_metrics(sink.clone()); + + index.insert(make_id(1), vec![1.0, 0.0, 0.0]).unwrap(); + + let events = sink.events(); + let event_names: Vec<&str> = events.iter().map(|e| e.name).collect(); + + assert!( + event_names.contains(&names::HNSW_INSERT_DURATION_MS), + "Missing insert duration metric" + ); + assert!( + 
event_names.contains(&names::HNSW_INSERT_COUNT), + "Missing insert count metric" + ); + assert!( + event_names.contains(&names::HNSW_INDEX_SIZE), + "Missing index size metric" + ); + + // Index size should be 1 after first insert + let size_event = events + .iter() + .find(|e| e.name == names::HNSW_INDEX_SIZE) + .unwrap(); + assert_eq!(size_event.value, MetricValue::Gauge(1.0)); + } + + #[test] + fn search_emits_metrics() { + let sink = Arc::new(RecordingSink::new()); + let mut index = HnswIndex::new(3).with_metrics(sink.clone()); + + index.insert(make_id(1), vec![1.0, 0.0, 0.0]).unwrap(); + index.insert(make_id(2), vec![0.0, 1.0, 0.0]).unwrap(); + + // Clear insert metrics + sink.clear(); + + let results = index.search(&[1.0, 0.0, 0.0], 2).unwrap(); + + let events = sink.events(); + let event_names: Vec<&str> = events.iter().map(|e| e.name).collect(); + + assert!( + event_names.contains(&names::HNSW_SEARCH_DURATION_MS), + "Missing search duration metric" + ); + assert!( + event_names.contains(&names::HNSW_SEARCH_COUNT), + "Missing search count metric" + ); + assert!( + event_names.contains(&names::HNSW_SEARCH_RESULTS), + "Missing search results metric" + ); + + // Results count should match actual results + let results_event = events + .iter() + .find(|e| e.name == names::HNSW_SEARCH_RESULTS) + .unwrap(); + assert_eq!( + results_event.value, + MetricValue::Gauge(results.len() as f64) + ); + } + + #[test] + fn rebuild_emits_metrics() { + let sink = Arc::new(RecordingSink::new()); + let mut index = HnswIndex::new(3).with_metrics(sink.clone()); + + let id1 = make_id(1); + let id2 = make_id(2); + index.insert(id1, vec![1.0, 0.0, 0.0]).unwrap(); + index.insert(id2, vec![0.0, 1.0, 0.0]).unwrap(); + + // Tombstone one node + index.delete(id1); + + // Clear prior metrics + sink.clear(); + + let stats = index.rebuild(); + + let events = sink.events(); + let event_names: Vec<&str> = events.iter().map(|e| e.name).collect(); + + assert!( + 
event_names.contains(&names::HNSW_REBUILD_DURATION_MS), + "Missing rebuild duration metric" + ); + assert!( + event_names.contains(&names::HNSW_REBUILD_COUNT), + "Missing rebuild count metric" + ); + assert!( + event_names.contains(&names::HNSW_REBUILD_NODES_REMOVED), + "Missing rebuild nodes_removed metric" + ); + assert!( + event_names.contains(&names::HNSW_INDEX_SIZE), + "Missing index size metric after rebuild" + ); + + // nodes_removed should be 1 + let removed_event = events + .iter() + .find(|e| e.name == names::HNSW_REBUILD_NODES_REMOVED) + .unwrap(); + assert_eq!(removed_event.value, MetricValue::Gauge(1.0)); + assert_eq!(stats.nodes_removed, 1); + } + + #[test] + fn no_metrics_without_sink() { + // Ensure no panic when metrics is None (default) + let mut index = HnswIndex::new(3); + index.insert(make_id(1), vec![1.0, 0.0, 0.0]).unwrap(); + let _ = index.search(&[1.0, 0.0, 0.0], 1).unwrap(); + index.rebuild(); + } + + #[test] + fn set_metrics_at_runtime() { + let mut index = HnswIndex::new(3); + index.insert(make_id(1), vec![1.0, 0.0, 0.0]).unwrap(); + + // Attach sink mid-lifecycle + let sink = Arc::new(RecordingSink::new()); + index.set_metrics(Some(sink.clone())); + + index.insert(make_id(2), vec![0.0, 1.0, 0.0]).unwrap(); + + // Should have metrics from the second insert only + assert!(!sink.is_empty()); + + // Detach + index.set_metrics(None); + sink.clear(); + + index.insert(make_id(3), vec![0.0, 0.0, 1.0]).unwrap(); + assert!(sink.is_empty(), "No events after detaching sink"); + } + + #[test] + fn search_on_empty_index_still_emits() { + let sink = Arc::new(RecordingSink::new()); + let index = HnswIndex::new(3).with_metrics(sink.clone()); + + let results = index.search(&[1.0, 0.0, 0.0], 5).unwrap(); + assert!(results.is_empty()); + + // Should still emit duration and count (even for empty index) + let events = sink.events(); + let event_names: Vec<&str> = events.iter().map(|e| e.name).collect(); + 
assert!(event_names.contains(&names::HNSW_SEARCH_DURATION_MS)); + assert!(event_names.contains(&names::HNSW_SEARCH_COUNT)); + } + + #[test] + fn insert_duration_is_nonnegative() { + let sink = Arc::new(RecordingSink::new()); + let mut index = HnswIndex::new(3).with_metrics(sink.clone()); + + index.insert(make_id(1), vec![1.0, 0.0, 0.0]).unwrap(); + + let duration_event = sink + .events() + .into_iter() + .find(|e| e.name == names::HNSW_INSERT_DURATION_MS) + .unwrap(); + + match duration_event.value { + MetricValue::Histogram(ms) => assert!(ms >= 0.0, "Duration must be >= 0"), + other => panic!("Expected Histogram, got {other:?}"), + } + } +} + +#[cfg(test)] +mod search_context_tests { + use crate::search_context::HnswSearchContext; + use crate::NodeId; + use crate::{DistanceMetric, HnswConfig, HnswIndex}; + + fn make_id(seed: u8) -> NodeId { + NodeId::new([seed; 16]) + } + + fn generate_random_vector(dim: usize, seed: u64) -> Vec { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + (0..dim) + .map(|i| { + let mut hasher = DefaultHasher::new(); + (seed, i).hash(&mut hasher); + (hasher.finish() as f32 / u64::MAX as f32) * 2.0 - 1.0 + }) + .collect() + } + + #[test] + fn search_with_context_matches_search() { + // Build a non-trivial index + let config = HnswConfig::with_dimensions(64).with_seed(42); + let mut index = HnswIndex::with_config(config); + + for i in 0..200u16 { + let id = NodeId::new([ + (i >> 8) as u8, + (i & 0xff) as u8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]); + let vector = generate_random_vector(64, i as u64); + index.insert(id, vector).expect("insert"); + } + + let mut ctx = HnswSearchContext::new(index.config().ef_search); + + // Run multiple queries and verify results are identical + for q_seed in [0u64, 50, 100, 999, 12345] { + let query = generate_random_vector(64, q_seed); + + let results_normal = index.search(&query, 10).expect("search"); + let results_ctx = index + 
.search_with_context(&query, 10, &mut ctx) + .expect("search_with_context"); + + assert_eq!( + results_normal.len(), + results_ctx.len(), + "Result count should match for query seed {q_seed}" + ); + for (i, (r_normal, r_ctx)) in results_normal.iter().zip(results_ctx.iter()).enumerate() + { + assert_eq!( + r_normal.0, r_ctx.0, + "ID mismatch at position {i} for query seed {q_seed}" + ); + assert_eq!( + r_normal.1, r_ctx.1, + "Score mismatch at position {i} for query seed {q_seed}" + ); + } + } + } + + #[test] + fn context_reuse_across_many_searches() { + let config = HnswConfig::with_dimensions(32).with_seed(42); + let mut index = HnswIndex::with_config(config); + + for i in 0..100u16 { + let id = NodeId::new([i as u8; 16]); + let vector = generate_random_vector(32, i as u64); + index.insert(id, vector).expect("insert"); + } + + let mut ctx = HnswSearchContext::new(index.config().ef_search); + + // Run 50 searches reusing the same context + for q in 0..50u64 { + let query = generate_random_vector(32, q * 7); + let results = index + .search_with_context(&query, 5, &mut ctx) + .expect("search_with_context"); + assert_eq!(results.len(), 5, "Should return 5 results on iteration {q}"); + + // Verify sorted descending by score + for window in results.windows(2) { + assert!( + window[0].1 >= window[1].1, + "Results should be sorted descending" + ); + } + } + } + + #[test] + fn context_works_with_tombstones() { + let config = HnswConfig::with_dimensions(3).with_seed(42); + let mut index = HnswIndex::with_config(config); + + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + + index.insert(id1, vec![1.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.9, 0.1, 0.0]).expect("insert"); + index.insert(id3, vec![0.0, 1.0, 0.0]).expect("insert"); + + // Delete id1 + index.delete(id1); + + let mut ctx = HnswSearchContext::new(index.config().ef_search); + + let results_normal = index.search(&[1.0, 0.0, 0.0], 3).expect("search"); + let results_ctx = index 
+ .search_with_context(&[1.0, 0.0, 0.0], 3, &mut ctx) + .expect("search_with_context"); + + // Should not include tombstoned id1 + assert_eq!(results_normal.len(), results_ctx.len()); + for (r_normal, r_ctx) in results_normal.iter().zip(results_ctx.iter()) { + assert_eq!(r_normal.0, r_ctx.0); + assert_eq!(r_normal.1, r_ctx.1); + assert_ne!(r_normal.0, id1, "Tombstoned id1 should not appear"); + } + } + + #[test] + fn context_with_empty_index() { + let index = HnswIndex::new(3); + let mut ctx = HnswSearchContext::new(80); + + let results = index + .search_with_context(&[1.0, 0.0, 0.0], 10, &mut ctx) + .expect("search empty"); + assert!(results.is_empty()); + } + + #[test] + fn context_dimension_mismatch() { + let mut index = HnswIndex::new(3); + index + .insert(make_id(1), vec![1.0, 0.0, 0.0]) + .expect("insert"); + + let mut ctx = HnswSearchContext::new(80); + let result = index.search_with_context(&[1.0, 0.0], 1, &mut ctx); + assert!(result.is_err()); + } + + #[test] + fn context_works_with_all_metrics() { + for metric in [ + DistanceMetric::Cosine, + DistanceMetric::L2, + DistanceMetric::Dot, + ] { + let mut config = HnswConfig::with_dimensions(4); + config.metric = metric; + config.seed = Some(42); + let mut index = HnswIndex::with_config(config); + + for i in 0..50u8 { + let id = make_id(i); + let vector = generate_random_vector(4, i as u64); + index.insert(id, vector).expect("insert"); + } + + let query = generate_random_vector(4, 999); + let mut ctx = HnswSearchContext::new(index.config().ef_search); + + let results_normal = index.search(&query, 5).expect("search"); + let results_ctx = index + .search_with_context(&query, 5, &mut ctx) + .expect("search_with_context"); + + assert_eq!( + results_normal.len(), + results_ctx.len(), + "Result count mismatch for {metric:?}" + ); + for (r_normal, r_ctx) in results_normal.iter().zip(results_ctx.iter()) { + assert_eq!(r_normal.0, r_ctx.0, "ID mismatch for {metric:?}"); + assert_eq!(r_normal.1, r_ctx.1, "Score mismatch 
for {metric:?}"); + } + } + } + + // ========================================================================= + // build_batch (parallel HNSW construction) + // ========================================================================= + + use std::collections::HashSet; + + #[test] + fn test_build_batch_empty() { + let mut index = HnswIndex::new(3); + index + .build_batch(vec![]) + .expect("empty batch should succeed"); + assert!(index.is_empty()); + } + + #[test] + fn test_build_batch_single_item() { + let mut index = HnswIndex::new(3); + let id = make_id(1); + index + .build_batch(vec![(id, vec![1.0, 0.0, 0.0])]) + .expect("single-item batch"); + assert_eq!(index.len(), 1); + + let results = index.search(&[1.0, 0.0, 0.0], 1).expect("search"); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, id); + } + + #[test] + fn test_build_batch_small_falls_back_to_sequential() { + // <= 32 items should use sequential fallback + let dim = 8; + let mut index = HnswIndex::new(dim); + let items: Vec<_> = (0..20u8) + .map(|i| (make_id(i), generate_random_vector(dim, i as u64))) + .collect(); + + index.build_batch(items).expect("small batch"); + assert_eq!(index.len(), 20); + + // Verify search works + let query = generate_random_vector(dim, 999); + let results = index.search(&query, 5).expect("search"); + assert_eq!(results.len(), 5); + } + + #[test] + fn test_build_batch_parallel_correctness() { + // Large enough to trigger parallel path (> 32) + let dim = 16; + let n = 200; + let config = HnswConfig { + dimensions: dim, + seed: Some(42), + ef_construction: 400, + ..HnswConfig::default() + }; + + let items: Vec<_> = (0..n) + .map(|i| { + let mut id_bytes = [0u8; 16]; + id_bytes[0] = (i & 0xFF) as u8; + id_bytes[1] = ((i >> 8) & 0xFF) as u8; + let id = NodeId::new(id_bytes); + (id, generate_random_vector(dim, i as u64)) + }) + .collect(); + + let mut index = HnswIndex::with_config(config); + index.build_batch(items.clone()).expect("parallel batch"); + + 
assert_eq!(index.len(), n); + + // Nearly all IDs should find themselves as nearest neighbor. + // The batch path searches a frozen seed graph, so a small number of + // vectors may not find themselves as #1 due to the approximate nature + // of the parallel build. Require >= 95% self-recall. + let mut self_found = 0usize; + for (id, vec) in &items { + let results = index.search(vec, 1).expect("search"); + assert!(!results.is_empty(), "should find at least 1 result"); + if results[0].0 == *id { + self_found += 1; + } + } + let self_recall = self_found as f64 / n as f64; + assert!( + self_recall >= 0.95, + "Self-recall too low: {:.1}% ({}/{}) (expected >= 95%)", + self_recall * 100.0, + self_found, + n + ); + } + + #[test] + fn test_build_batch_recall_quality() { + // Compare recall: build_batch vs sequential insert + let dim = 32; + let n = 500; + + let items: Vec<_> = (0..n) + .map(|i| { + let mut id_bytes = [0u8; 16]; + id_bytes[0] = (i & 0xFF) as u8; + id_bytes[1] = ((i >> 8) & 0xFF) as u8; + let id = NodeId::new(id_bytes); + (id, generate_random_vector(dim, i as u64)) + }) + .collect(); + + // Build with batch + let config_batch = HnswConfig { + dimensions: dim, + seed: Some(42), + ..HnswConfig::default() + }; + let mut batch_index = HnswIndex::with_config(config_batch); + batch_index.build_batch(items.clone()).expect("batch build"); + + // Build sequentially + let config_seq = HnswConfig { + dimensions: dim, + seed: Some(42), + ..HnswConfig::default() + }; + let mut seq_index = HnswIndex::with_config(config_seq); + for (id, vec) in &items { + seq_index + .insert(*id, vec.clone()) + .expect("sequential insert"); + } + + // Both should have same size + assert_eq!(batch_index.len(), seq_index.len()); + + // Compare recall on random queries + let k = 10; + let num_queries = 20; + let mut total_overlap = 0; + let total_possible = num_queries * k; + + for q in 0..num_queries { + let query = generate_random_vector(dim, 10_000 + q as u64); + + let batch_results = 
batch_index.search(&query, k).expect("batch search"); + let seq_results = seq_index.search(&query, k).expect("seq search"); + + let batch_ids: HashSet<_> = batch_results.iter().map(|(id, _)| *id).collect(); + let seq_ids: HashSet<_> = seq_results.iter().map(|(id, _)| *id).collect(); + + total_overlap += batch_ids.intersection(&seq_ids).count(); + } + + // Expect at least 55% overlap between batch and sequential builds. + // The parallel build searches a frozen seed graph with the correct degree + // cap (m_max0 for layer 0, m for upper layers), so graph topology differs + // from sequential builds. 55% is conservative for this comparison. + let recall_overlap = total_overlap as f64 / total_possible as f64; + assert!( + recall_overlap >= 0.55, + "Recall overlap too low: {:.1}% (expected >= 55%)", + recall_overlap * 100.0 + ); + } + + #[test] + fn test_build_batch_dimension_mismatch() { + let mut index = HnswIndex::new(3); + let result = index.build_batch(vec![(make_id(1), vec![1.0, 0.0])]); + assert!(result.is_err()); + } + + #[test] + fn test_build_batch_rejects_duplicate_ids() { + let mut index = HnswIndex::new(3); + index + .insert(make_id(1), vec![1.0, 0.0, 0.0]) + .expect("pre-insert"); + + let result = index.build_batch(vec![(make_id(1), vec![0.0, 1.0, 0.0])]); + assert!(result.is_err(), "should reject duplicate ID"); + } + + #[test] + fn test_build_batch_search_after_build() { + // Verify that search still works correctly after batch build + let dim = 16; + let n = 100; + + let mut index = HnswIndex::new(dim); + let items: Vec<_> = (0..n) + .map(|i| { + let mut id_bytes = [0u8; 16]; + id_bytes[0] = (i & 0xFF) as u8; + let id = NodeId::new(id_bytes); + (id, generate_random_vector(dim, i as u64)) + }) + .collect(); + + index.build_batch(items.clone()).expect("batch build"); + + // Search for the first vector -- should find itself + let results = index.search(&items[0].1, 1).expect("search"); + assert_eq!(results[0].0, items[0].0); + assert!( + 
results[0].1.to_f64() > 0.99, + "exact match should have high score" + ); + } + + #[test] + fn test_build_batch_with_delete_after() { + // Build batch, then delete, then search + let dim = 8; + let n = 50; + + let mut index = HnswIndex::new(dim); + let items: Vec<_> = (0..n) + .map(|i| { + let mut id_bytes = [0u8; 16]; + id_bytes[0] = (i & 0xFF) as u8; + let id = NodeId::new(id_bytes); + (id, generate_random_vector(dim, i as u64)) + }) + .collect(); + + index.build_batch(items.clone()).expect("batch build"); + + // Delete first 10 items + for (id, _) in &items[..10] { + assert!(index.delete(*id)); + } + + // Search should not return deleted items + let query = generate_random_vector(dim, 999); + let results = index.search(&query, 20).expect("search"); + + let deleted_ids: HashSet<_> = items[..10].iter().map(|(id, _)| *id).collect(); + for (id, _) in &results { + assert!( + !deleted_ids.contains(id), + "deleted ID should not appear in results" + ); + } + } + + // ========================================================================= + // INT8 Quantized Search Tests + // ========================================================================= + + #[test] + fn test_quantized_search_identical_results_small() { + let config = HnswConfig { + dimensions: 32, + seed: Some(42), + ..Default::default() + }; + let mut index = HnswIndex::with_config(config); + + for i in 0..50u8 { + let vec = generate_random_vector(32, i as u64); + index.insert(make_id(i), vec).unwrap(); + } + + let query = generate_random_vector(32, 999); + let results_f32 = index.search(&query, 10).unwrap(); + + index.set_quantized(true); + assert!(index.is_quantized()); + let results_quant = index.search(&query, 10).unwrap(); + + assert_eq!(results_f32.len(), results_quant.len()); + for (f32_result, quant_result) in results_f32.iter().zip(results_quant.iter()) { + assert_eq!(f32_result.0, quant_result.0, "ID mismatch"); + assert_eq!(f32_result.1, quant_result.1, "Score mismatch"); + } + } + + #[test] + 
fn test_quantized_search_identical_results_medium() { + let config = HnswConfig { + dimensions: 128, + seed: Some(42), + ef_search: 50, + ..Default::default() + }; + let mut index = HnswIndex::with_config(config); + + for i in 0..200u64 { + let vec = generate_random_vector(128, i); + let id_bytes: [u8; 16] = { + let mut b = [0u8; 16]; + b[..8].copy_from_slice(&i.to_le_bytes()); + b + }; + index.insert(NodeId::new(id_bytes), vec).unwrap(); + } + + for q_seed in 1000..1010u64 { + let query = generate_random_vector(128, q_seed); + + let results_f32 = index.search(&query, 10).unwrap(); + + index.set_quantized(true); + let results_quant = index.search(&query, 10).unwrap(); + index.set_quantized(false); + + assert_eq!( + results_f32.len(), + results_quant.len(), + "Result count mismatch for query seed {q_seed}" + ); + for (f32_r, quant_r) in results_f32.iter().zip(results_quant.iter()) { + assert_eq!(f32_r.0, quant_r.0, "ID mismatch for query seed {q_seed}"); + // Allow a small tolerance for f32 FP rounding: batch-4 SIMD kernels + // use a different FMA accumulation order than the scalar pair kernel, + // producing differences of ~1e-7 (within single-precision epsilon for + // 128-dim vectors). The neighbor IDs above confirm same recall; this + // only checks that scores are within acceptable precision. 
+ let diff = (f32_r.1.to_f64() - quant_r.1.to_f64()).abs(); + assert!( + diff < 1e-5, + "Score mismatch for query seed {q_seed}: {} vs {} (diff={diff:.2e})", + f32_r.1.to_f64(), + quant_r.1.to_f64() + ); + } + } + } + + #[test] + fn test_quantized_builder_pattern() { + let index = HnswIndex::new(64).with_quantized(); + assert!(index.is_quantized()); + } + + #[test] + fn test_quantized_runtime_toggle() { + let mut index = HnswIndex::new(64); + assert!(!index.is_quantized()); + + index.set_quantized(true); + assert!(index.is_quantized()); + + index.set_quantized(false); + assert!(!index.is_quantized()); + } + + #[test] + fn test_quantized_arena_survives_update() { + let mut index = HnswIndex::new(3); + index.set_quantized(true); + + let id = make_id(1); + index.insert(id, vec![1.0, 0.0, 0.0]).unwrap(); + index.insert(id, vec![0.0, 1.0, 0.0]).unwrap(); + + let results = index.search(&[0.0, 1.0, 0.0], 1).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, id); + assert!(results[0].1.to_f64() > 0.99); + } + + #[test] + fn test_quantized_arena_survives_rebuild() { + let mut index = HnswIndex::new(3); + index.set_quantized(true); + + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + + index.insert(id1, vec![1.0, 0.0, 0.0]).unwrap(); + index.insert(id2, vec![0.0, 1.0, 0.0]).unwrap(); + index.insert(id3, vec![0.0, 0.0, 1.0]).unwrap(); + + index.delete(id2); + index.rebuild(); + + let results = index.search(&[1.0, 0.0, 0.0], 2).unwrap(); + assert_eq!(results.len(), 2); + assert_eq!(results[0].0, id1); + assert!(results.iter().all(|(id, _)| *id != id2)); + } + + #[test] + fn test_quantized_arena_survives_clear() { + let mut index = HnswIndex::new(3); + index.set_quantized(true); + + index.insert(make_id(1), vec![1.0, 0.0, 0.0]).unwrap(); + index.clear(); + + index.insert(make_id(2), vec![0.0, 1.0, 0.0]).unwrap(); + let results = index.search(&[0.0, 1.0, 0.0], 1).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, 
make_id(2)); + } + + #[test] + fn test_quantized_empty_index() { + let index = HnswIndex::new(3).with_quantized(); + let results = index.search(&[1.0, 0.0, 0.0], 5).unwrap(); + assert!(results.is_empty()); + } + + #[test] + fn test_quantized_only_affects_cosine() { + for metric in [DistanceMetric::Dot, DistanceMetric::L2] { + let config = HnswConfig { + dimensions: 32, + metric, + seed: Some(42), + ..Default::default() + }; + let mut index = HnswIndex::with_config(config); + + for i in 0..30u8 { + let vec = generate_random_vector(32, i as u64); + index.insert(make_id(i), vec).unwrap(); + } + + let query = generate_random_vector(32, 999); + let results_f32 = index.search(&query, 5).unwrap(); + + index.set_quantized(true); + let results_quant = index.search(&query, 5).unwrap(); + + assert_eq!( + results_f32.len(), + results_quant.len(), + "Result count mismatch for {metric:?}" + ); + for (a, b) in results_f32.iter().zip(results_quant.iter()) { + assert_eq!(a.0, b.0, "ID mismatch for {metric:?}"); + assert_eq!(a.1, b.1, "Score mismatch for {metric:?}"); + } + } + } + + #[test] + fn test_quantization_error_bounded() { + let dim = 384; + for seed in 0..20u64 { + let vec = generate_random_vector(dim, seed); + let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt(); + + let mut max_abs: f32 = 0.0; + for &v in &vec { + let abs = v.abs(); + if abs > max_abs { + max_abs = abs; + } + } + let scale = if max_abs > 1e-10 { + 127.0 / max_abs + } else { + 1.0 + }; + let quantized: Vec<i8> = vec + .iter() + .map(|&v| (v * scale).round().clamp(-127.0, 127.0) as i8) + .collect(); + + let dequantized: Vec<f32> = quantized.iter().map(|&v| v as f32 / scale).collect(); + + let dot: f32 = vec.iter().zip(dequantized.iter()).map(|(a, b)| a * b).sum(); + let dq_norm: f32 = dequantized.iter().map(|x| x * x).sum::<f32>().sqrt(); + let cos_sim = if norm > 0.0 && dq_norm > 0.0 { + dot / (norm * dq_norm) + } else { + 1.0 + }; + + assert!( + cos_sim > 0.95, + "Quantization error too high: cosine_sim={cos_sim}
for seed={seed}" + ); + assert!( + cos_sim > 0.99, + "Expected high fidelity for 384d: cosine_sim={cos_sim}" + ); + } + } +} + +// ============================================================================= +// Snapshot / restore tests (Issue #2161) +// ============================================================================= + +#[cfg(test)] +mod snapshot_tests { + use crate::NodeId; + use crate::{HnswConfig, HnswIndex}; + use std::collections::HashMap; + + fn make_id(seed: u8) -> NodeId { + NodeId::new([seed; 16]) + } + + fn generate_random_vector(dim: usize, seed: u64) -> Vec<f32> { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + (0..dim) + .map(|i| { + let mut hasher = DefaultHasher::new(); + (seed, i).hash(&mut hasher); + (hasher.finish() as f32 / u64::MAX as f32) * 2.0 - 1.0 + }) + .collect() + } + + /// Snapshot from a populated index includes vector data. + #[test] + fn snapshot_includes_vector_data() { + let mut index = HnswIndex::new(4); + + let id1 = make_id(1); + let id2 = make_id(2); + let vec1 = vec![1.0, 0.0, 0.0, 0.0]; + let vec2 = vec![0.0, 1.0, 0.0, 0.0]; + + index.insert(id1, vec1.clone()).expect("insert"); + index.insert(id2, vec2.clone()).expect("insert"); + + let snap = index.snapshot(); + + assert_eq!(snap.vectors.len(), 2, "snapshot should contain 2 vectors"); + + // Vectors are sorted by NodeId bytes — look up by id + let vec_map: HashMap<NodeId, &Vec<f32>> = + snap.vectors.iter().map(|(id, v)| (*id, v)).collect(); + + assert_eq!( + vec_map.get(&id1).copied(), + Some(&vec1), + "id1 vector should match" + ); + assert_eq!( + vec_map.get(&id2).copied(), + Some(&vec2), + "id2 vector should match" + ); + } + + /// Empty index snapshot has no vectors.
+ #[test] + fn snapshot_empty_index_has_no_vectors() { + let index = HnswIndex::new(4); + let snap = index.snapshot(); + assert!( + snap.vectors.is_empty(), + "empty index snapshot has no vectors" + ); + } + + /// Self-contained round-trip: snapshot() → restore_from_snapshot_embedded(). + #[test] + fn snapshot_restore_embedded_round_trip() { + let config = HnswConfig::with_dimensions(8).with_seed(42); + let mut original = HnswIndex::with_config(config.clone()); + + let ids: Vec<NodeId> = (0..20u8).map(make_id).collect(); + let vecs: Vec<Vec<f32>> = (0..20u64).map(|i| generate_random_vector(8, i)).collect(); + + for (id, vec) in ids.iter().zip(vecs.iter()) { + original.insert(*id, vec.clone()).expect("insert"); + } + + // Take a snapshot and restore into a fresh index + let snap = original.snapshot(); + assert_eq!( + snap.vectors.len(), + 20, + "snapshot should embed all 20 vectors" + ); + + let mut restored = HnswIndex::with_config(config); + restored + .restore_from_snapshot_embedded(&snap) + .expect("restore embedded"); + + assert_eq!(restored.len(), 20, "restored index should have 20 nodes"); + + // Search results should match the original + let query = generate_random_vector(8, 999); + let results_orig = original.search(&query, 5).expect("search original"); + let results_rest = restored.search(&query, 5).expect("search restored"); + + assert_eq!( + results_orig.len(), + results_rest.len(), + "result count should match" + ); + for (r_orig, r_rest) in results_orig.iter().zip(results_rest.iter()) { + assert_eq!(r_orig.0, r_rest.0, "result IDs should match"); + } + } + + /// Tombstoned nodes are preserved through embedded snapshot round-trip.
+ #[test] + fn snapshot_restore_embedded_preserves_tombstones() { + let config = HnswConfig::with_dimensions(4).with_seed(42); + let mut index = HnswIndex::with_config(config.clone()); + + let id1 = make_id(1); + let id2 = make_id(2); + let id3 = make_id(3); + + index.insert(id1, vec![1.0, 0.0, 0.0, 0.0]).expect("insert"); + index.insert(id2, vec![0.0, 1.0, 0.0, 0.0]).expect("insert"); + index.insert(id3, vec![0.0, 0.0, 1.0, 0.0]).expect("insert"); + + // Tombstone id2 + assert!(index.delete(id2)); + + let snap = index.snapshot(); + assert_eq!(snap.total_nodes, 3); + assert_eq!(snap.tombstone_count, 1); + assert_eq!(snap.vectors.len(), 3, "all 3 vectors including tombstone"); + + let mut restored = HnswIndex::with_config(config); + restored + .restore_from_snapshot_embedded(&snap) + .expect("restore"); + + assert_eq!(restored.len(), 3, "total nodes preserved"); + assert_eq!( + restored.tombstone_stats().tombstone_count, + 1, + "tombstone count preserved" + ); + + // Search should not return id2 + let results = restored.search(&[0.0, 1.0, 0.0, 0.0], 3).expect("search"); + let result_ids: Vec<NodeId> = results.iter().map(|(id, _)| *id).collect(); + assert!( + !result_ids.contains(&id2), + "tombstoned id2 should not appear in results" + ); + } + + /// Snapshot serialization includes vectors and round-trips via JSON.
+ #[test] + fn snapshot_serialization_includes_vectors() { + let mut index = HnswIndex::new(4); + + let id1 = make_id(1); + index.insert(id1, vec![1.0, 2.0, 3.0, 4.0]).expect("insert"); + + let snap = index.snapshot(); + assert!(!snap.vectors.is_empty(), "vectors should be in snapshot"); + + let json = serde_json::to_string(&snap).expect("serialize"); + assert!( + json.contains("vectors"), + "serialized JSON should contain vectors field" + ); + + let restored_snap: crate::checkpoint::HnswSnapshot = + serde_json::from_str(&json).expect("deserialize"); + assert_eq!( + restored_snap.vectors.len(), + 1, + "deserialized snapshot should have 1 vector" + ); + assert_eq!( + restored_snap.vectors[0].0, id1, + "vector id should be preserved" + ); + assert_eq!( + restored_snap.vectors[0].1, + vec![1.0, 2.0, 3.0, 4.0], + "vector data should be preserved" + ); + } + + /// Old snapshots (no vectors field) still deserialize correctly. + #[test] + fn backward_compat_snapshot_without_vectors() { + // Simulate a snapshot from before the vectors field was added + let old_json = r#"{ + "total_nodes": 1, + "live_nodes": 1, + "tombstone_count": 0, + "max_layer": 0, + "entry_point": "01010101010101010101010101010101", + "config": {"m": 16, "ef_construction": 200, "metric": "cosine"}, + "indexed_ids": ["01010101010101010101010101010101"], + "tombstoned_ids": [], + "layers": [] + }"#; + + let snap: crate::checkpoint::HnswSnapshot = + serde_json::from_str(old_json).expect("deserialize old snapshot"); + + assert_eq!(snap.total_nodes, 1); + assert!( + snap.vectors.is_empty(), + "old snapshot deserialized with empty vectors" + ); + assert!(snap.verify().is_ok(), "old snapshot should verify"); + } + + /// restore_from_snapshot_embedded fails when snapshot has no vectors. 
+ #[test] + fn restore_embedded_fails_without_vectors() { + let config = HnswConfig::with_dimensions(4); + let mut index = HnswIndex::with_config(config); + + let id1 = make_id(1); + let snap = crate::checkpoint::HnswSnapshot { + vector_count: 0, + total_nodes: 1, + live_nodes: 1, + tombstone_count: 0, + max_layer: 0, + entry_point: Some(id1), + config: crate::checkpoint::HnswCheckpointConfig { + m: 20, + ef_construction: 200, + metric: "cosine".to_string(), + }, + indexed_ids: vec![id1], + tombstoned_ids: vec![], + layers: vec![vec![(id1, vec![])]], + vectors: vec![], // Intentionally empty + }; + + let result = index.restore_from_snapshot_embedded(&snap); + assert!( + result.is_err(), + "should fail when snapshot has no embedded vectors" + ); + } + + /// restore_from_snapshot with external map takes priority over embedded vectors. + #[test] + fn restore_external_overrides_embedded_vectors() { + let config = HnswConfig::with_dimensions(4).with_seed(42); + let mut source = HnswIndex::with_config(config.clone()); + + let id1 = make_id(1); + source + .insert(id1, vec![1.0, 0.0, 0.0, 0.0]) + .expect("insert"); + + let snap = source.snapshot(); + + // Provide a different vector for id1 via the external map + let updated_vector = vec![0.0, 0.0, 0.0, 1.0]; + let external: HashMap<NodeId, Vec<f32>> = + [(id1, updated_vector.clone())].into_iter().collect(); + + let mut restored = HnswIndex::with_config(config); + restored + .restore_from_snapshot(&snap, &external) + .expect("restore with external"); + + // The restored index should use the external (override) vector + let retrieved = restored.get_vector(&id1).expect("get vector"); + assert_eq!( + retrieved, updated_vector, + "external vector should override embedded" + ); + } + + /// Large snapshot round-trip preserves search quality.
+ #[test] + fn snapshot_restore_preserves_search_quality() { + let config = HnswConfig::with_dimensions(32).with_seed(42); + let mut original = HnswIndex::with_config(config.clone()); + + let n = 200usize; + for i in 0..n { + let id = NodeId::new({ + let mut b = [0u8; 16]; + b[0] = (i & 0xff) as u8; + b[1] = (i >> 8) as u8; + b + }); + let vec = generate_random_vector(32, i as u64); + original.insert(id, vec).expect("insert"); + } + + let snap = original.snapshot(); + assert_eq!(snap.vectors.len(), n, "all vectors embedded"); + + let mut restored = HnswIndex::with_config(config); + restored + .restore_from_snapshot_embedded(&snap) + .expect("restore"); + + // Search quality: >= 80% recall@10 across 10 queries + let k = 10; + let mut total_recall = 0.0; + for q in 0..10 { + let query = generate_random_vector(32, 10_000 + q); + + let results_orig: std::collections::HashSet<NodeId> = original + .search(&query, k) + .expect("orig search") + .into_iter() + .map(|(id, _)| id) + .collect(); + let results_rest: std::collections::HashSet<NodeId> = restored + .search(&query, k) + .expect("rest search") + .into_iter() + .map(|(id, _)| id) + .collect(); + + let overlap = results_orig.intersection(&results_rest).count(); + total_recall += overlap as f32 / k as f32; + } + + let avg_recall = total_recall / 10.0; + assert!( + avg_recall >= 0.8, + "restored index recall {avg_recall:.2} should be >= 0.8" + ); + } +} diff --git a/crates/khive-runtime/src/lib.rs b/crates/khive-runtime/src/lib.rs index 4e39ffda..7857a22b 100644 --- a/crates/khive-runtime/src/lib.rs +++ b/crates/khive-runtime/src/lib.rs @@ -26,9 +26,9 @@ pub mod fusion; pub mod graph_traversal; pub mod objectives; pub mod operations; -pub mod registry; pub mod pack; pub mod portability; +pub mod registry; pub mod retrieval; pub mod runtime; @@ -45,11 +45,11 @@ pub use objectives::{ VectorSimilarityObjective, }; pub use operations::{NoteSearchHit, QueryResult, Resolved}; -pub use registry::{ObjectiveRegistry, RegisteredObjective}; pub
use pack::{ DispatchHook, KindHook, PackFactory, PackRegistration, PackRegistry, PackRuntime, VerbRegistry, VerbRegistryBuilder, }; pub use portability::{ImportSummary, KgArchive}; +pub use registry::{ObjectiveRegistry, RegisteredObjective}; pub use retrieval::{SearchHit, SearchSource}; pub use runtime::{parse_pack_list, KhiveRuntime, RuntimeConfig}; diff --git a/crates/khive-runtime/src/registry.rs b/crates/khive-runtime/src/registry.rs index cfab6f1c..fe700da1 100644 --- a/crates/khive-runtime/src/registry.rs +++ b/crates/khive-runtime/src/registry.rs @@ -9,7 +9,9 @@ use std::sync::Arc; use parking_lot::RwLock; -use khive_fold::objective::{Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection}; +use khive_fold::objective::{ + Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection, +}; /// A type-erased objective wrapper. pub struct RegisteredObjective { @@ -288,7 +290,9 @@ mod tests { let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 * 2.0); registry.register("double", Box::new(obj)); - let score = registry.score("double", &5, &ObjectiveContext::new()).unwrap(); + let score = registry + .score("double", &5, &ObjectiveContext::new()) + .unwrap(); assert!((score - 10.0).abs() < 1e-12); } @@ -330,7 +334,8 @@ mod tests { let reg = Arc::clone(&registry); s.spawn(move || { let name = format!("obj_{i}"); - let obj = objective_fn(move |n: &i32, _ctx: &ObjectiveContext| *n as f64 + i as f64); + let obj = + objective_fn(move |n: &i32, _ctx: &ObjectiveContext| *n as f64 + i as f64); reg.register(name.clone(), Box::new(obj)); assert!(reg.contains(&name));