diff --git a/crates/Cargo.toml b/crates/Cargo.toml index 28a348f0..a6c09a05 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -10,6 +10,7 @@ members = [ "khive-gate", "khive-gate-rego", "khive-fusion", + "khive-bm25", "khive-runtime", "khive-request", "khive-pack-kg", diff --git a/crates/khive-bm25/Cargo.toml b/crates/khive-bm25/Cargo.toml new file mode 100644 index 00000000..9c915117 --- /dev/null +++ b/crates/khive-bm25/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "khive-bm25" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +description = "BM25 (Okapi BM25) keyword index with deterministic scoring" + +[dependencies] +khive-score = { version = "0.2.0", path = "../khive-score" } +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +parking_lot = { workspace = true } diff --git a/crates/khive-bm25/src/config.rs b/crates/khive-bm25/src/config.rs new file mode 100644 index 00000000..5620900d --- /dev/null +++ b/crates/khive-bm25/src/config.rs @@ -0,0 +1,104 @@ +//! BM25 configuration types. +//! +//! See ADR-003 for recommended parameter values. + +use serde::{Deserialize, Serialize}; + +/// BM25 configuration parameters. +/// +/// Default values (k1=1.2, b=0.75) from ADR-003 work well for most use cases. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Bm25Config { + /// Term saturation parameter. + /// + /// Higher values = diminishing returns for repeated terms. + /// Range: typically 1.2-2.0 + /// Default: 1.2 + pub k1: f64, + + /// Length normalization parameter. + /// + /// - 0 = no length normalization (favor longer docs) + /// - 1 = full length normalization (favor shorter docs) + /// + /// Range: 0.0-1.0, Default: 0.75 + pub b: f64, + + /// Maximum memory budget in bytes for the index. 
+ /// If None, no memory limit is enforced (default). + /// If Some(limit), `index_document()` calls that would exceed the budget + /// are rejected with `RetrievalError::BudgetExceeded`. Re-indexing an + /// existing document bypasses the budget check. + #[serde(default)] + pub memory_budget: Option, +} + +impl Default for Bm25Config { + fn default() -> Self { + Self { + k1: 1.2, + b: 0.75, + memory_budget: None, + } + } +} + +impl Bm25Config { + /// Create a new BM25 configuration. + pub fn new(k1: f64, b: f64) -> Self { + Self { + k1, + b, + memory_budget: None, + } + } + + /// Set memory budget in bytes. + /// + /// When set, `index_document()` calls that would cause the estimated + /// memory usage to exceed this limit are rejected with `BudgetExceeded`. + /// Re-indexing an existing document bypasses the budget check. + #[must_use] + pub fn with_memory_budget(mut self, budget: usize) -> Self { + self.memory_budget = Some(budget); + self + } + + /// Validate configuration parameters. 
+ pub fn validate(&self) -> Result<(), &'static str> { + if self.k1 < 0.0 { + return Err("k1 must be non-negative"); + } + if !(0.0..=1.0).contains(&self.b) { + return Err("b must be in range [0.0, 1.0]"); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_default() { + let config = Bm25Config::default(); + assert!((config.k1 - 1.2).abs() < f64::EPSILON); + assert!((config.b - 0.75).abs() < f64::EPSILON); + } + + #[test] + fn test_config_validation() { + assert!(Bm25Config::new(1.2, 0.75).validate().is_ok()); + assert!(Bm25Config::new(-0.1, 0.75).validate().is_err()); + assert!(Bm25Config::new(1.2, -0.1).validate().is_err()); + assert!(Bm25Config::new(1.2, 1.5).validate().is_err()); + } + + #[test] + fn test_config_custom() { + let config = Bm25Config::new(2.0, 0.5); + assert!((config.k1 - 2.0).abs() < f64::EPSILON); + assert!((config.b - 0.5).abs() < f64::EPSILON); + } +} diff --git a/crates/khive-bm25/src/error.rs b/crates/khive-bm25/src/error.rs new file mode 100644 index 00000000..4a787863 --- /dev/null +++ b/crates/khive-bm25/src/error.rs @@ -0,0 +1,54 @@ +//! Error types for the BM25 index. + +use thiserror::Error; + +/// Classification of errors by recoverability. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorKind { + /// Retrying will not help (e.g., budget exceeded, invalid config). + Permanent, + /// Retrying after a delay may succeed. + Transient, +} + +/// Errors produced by BM25 index operations. +#[derive(Debug, Error)] +pub enum RetrievalError { + /// The memory budget was exceeded when indexing a new document. + #[error( + "memory budget exceeded: current={current_usage}, item_size={item_size}, limit={limit}" + )] + BudgetExceeded { + current_usage: usize, + item_size: usize, + limit: usize, + }, + + /// Invalid BM25 configuration parameters. + #[error("configuration error: {0}")] + Configuration(String), +} + +impl RetrievalError { + /// Construct a `BudgetExceeded` error. 
+ pub fn budget_exceeded(current_usage: usize, item_size: usize, limit: usize) -> Self { + Self::BudgetExceeded { + current_usage, + item_size, + limit, + } + } + + /// Return the [`ErrorKind`] for this error. + pub fn kind(&self) -> ErrorKind { + ErrorKind::Permanent + } + + /// Whether retrying this operation might succeed. + pub fn is_retryable(&self) -> bool { + self.kind() == ErrorKind::Transient + } +} + +/// Convenience `Result` alias for BM25 operations. +pub type Result = std::result::Result; diff --git a/crates/khive-bm25/src/index/bench_wand.rs b/crates/khive-bm25/src/index/bench_wand.rs new file mode 100644 index 00000000..31a5c840 --- /dev/null +++ b/crates/khive-bm25/src/index/bench_wand.rs @@ -0,0 +1,159 @@ +use std::hint::black_box; +use std::time::Instant; + +use super::{Bm25Index, SearchContext}; +use crate::config::Bm25Config; + +#[derive(Clone)] +struct XorShift64 { + state: u64, +} + +impl XorShift64 { + fn new(seed: u64) -> Self { + Self { state: seed.max(1) } + } + + fn next_u64(&mut self) -> u64 { + let mut x = self.state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + self.state = x; + x + } + + fn next_f64(&mut self) -> f64 { + ((self.next_u64() >> 11) as f64) / ((1u64 << 53) as f64) + } + + fn gen_range(&mut self, upper: usize) -> usize { + if upper <= 1 { + 0 + } else { + (self.next_u64() as usize) % upper + } + } +} + +struct ZipfSampler { + cdf: Vec, +} + +impl ZipfSampler { + fn new(vocab_size: usize, exponent: f64) -> Self { + let mut cumulative = Vec::with_capacity(vocab_size); + let mut running = 0.0; + for rank in 1..=vocab_size { + running += 1.0 / (rank as f64).powf(exponent); + cumulative.push(running); + } + for value in &mut cumulative { + *value /= running; + } + Self { cdf: cumulative } + } + + fn sample(&self, rng: &mut XorShift64) -> usize { + let needle = rng.next_f64(); + let idx = self.cdf.partition_point(|value| *value < needle); + idx.min(self.cdf.len().saturating_sub(1)) + } +} + +fn build_vocab(size: usize) -> 
Vec { + (0..size).map(|idx| format!("tok_{idx:04}")).collect() +} + +fn build_index(doc_count: usize, seed: u64) -> (Bm25Index, Vec, ZipfSampler) { + let vocab = build_vocab(2_048); + let zipf = ZipfSampler::new(vocab.len(), 1.07); + let mut rng = XorShift64::new(seed); + let mut index = Bm25Index::new(Bm25Config::default()); + + for doc_idx in 0..doc_count { + let len = 24 + rng.gen_range(40); + let mut text = String::new(); + for token_idx in 0..len { + if token_idx > 0 { + text.push(' '); + } + let token = &vocab[zipf.sample(&mut rng)]; + text.push_str(token); + } + index + .index_document(format!("doc_{doc_idx}"), &text) + .expect("synthetic document should index"); + } + + (index, vocab, zipf) +} + +fn build_queries( + vocab: &[String], + zipf: &ZipfSampler, + rng: &mut XorShift64, + count: usize, + terms_per_query: usize, +) -> Vec { + let mut queries = Vec::with_capacity(count); + for _ in 0..count { + let mut query = String::new(); + for idx in 0..terms_per_query { + if idx > 0 { + query.push(' '); + } + query.push_str(&vocab[zipf.sample(rng)]); + } + queries.push(query); + } + queries +} + +/// Benchmark: WAND vs brute-force on Zipf-distributed corpora. +/// +/// Note: `search_with_context` routes to WAND only when total query postings +/// exceed `SMALL_QUERY_POSTINGS_THRESHOLD` (256). For very rare terms or +/// small corpora, the brute-force path may be taken instead, so speedup +/// numbers should be interpreted accordingly. 
/// Benchmark: WAND vs brute-force on Zipf-distributed corpora.
///
/// For every corpus size and query length, the same 64 queries are run first
/// through `search_brute_force` and then through `search_with_context`, and
/// the wall-clock totals are printed as one table row.
///
/// Note: `search_with_context` routes to WAND only when total query postings
/// exceed `SMALL_QUERY_POSTINGS_THRESHOLD` (256). For very rare terms or
/// small corpora, the brute-force path may be taken instead, so speedup
/// numbers should be interpreted accordingly.
#[test]
#[ignore = "benchmark; run with `cargo test bench_wand -- --ignored --nocapture`"]
fn bench_bm25_wand_vs_bruteforce_zipf_matrix() {
    const CORPUS_SIZES: [usize; 3] = [10_000, 50_000, 100_000];
    const QUERY_LENGTHS: [usize; 3] = [1, 2, 3];

    // Elapsed wall-clock time since `since`, in milliseconds.
    let millis_since = |since: Instant| since.elapsed().as_secs_f64() * 1000.0;

    for doc_count in CORPUS_SIZES {
        let corpus_seed = 0xFACE_FEED ^ doc_count as u64;
        let (index, vocab, zipf) = build_index(doc_count, corpus_seed);

        println!("\nCorpus: {doc_count} docs");
        println!("query_terms | brute_force_ms | bmw_ms | speedup_x");
        println!("------------|----------------|--------|----------");

        for terms_per_query in QUERY_LENGTHS {
            let query_seed = 0xDEAD_BEEF ^ ((doc_count as u64) << terms_per_query);
            let mut rng = XorShift64::new(query_seed);
            let queries = build_queries(&vocab, &zipf, &mut rng, 64, terms_per_query);

            // Brute-force pass.
            let mut brute_ctx = SearchContext::with_capacity(512);
            let brute_timer = Instant::now();
            for query in queries.iter() {
                black_box(index.search_brute_force(query, 10, &mut brute_ctx));
            }
            let brute_ms = millis_since(brute_timer);

            // WAND pass over the same queries, with its own scratch context.
            let mut wand_ctx = SearchContext::with_capacity(512);
            let wand_timer = Instant::now();
            for query in queries.iter() {
                black_box(index.search_with_context(query, 10, &mut wand_ctx));
            }
            let wand_ms = millis_since(wand_timer);

            // Avoid dividing by zero when the WAND pass is too fast to measure.
            let speedup = match wand_ms > 0.0 {
                true => brute_ms / wand_ms,
                false => f64::INFINITY,
            };

            println!("{terms_per_query:>11} | {brute_ms:>14.3} | {wand_ms:>6.3} | {speedup:>8.2}");
        }
    }
}
+ /// If the document already exists, it will be re-indexed (old version removed first, + /// budget check bypassed for re-indexing). + /// + /// # Arguments + /// + /// * `doc_id` - Unique document identifier (accepts `String`, `&str`, or + /// [`DocumentId`] directly via [`Into`]). + /// * `text` - Document text to index + /// + /// # Errors + /// + /// Returns `RetrievalError::BudgetExceeded` if a memory budget is configured + /// and the new document would cause the index to exceed it. Re-indexing an + /// existing document bypasses the budget check. + /// + /// Emits `bm25.index_document.duration_ms`, `bm25.index_document.count`, + /// and `bm25.index.size` metrics when a sink is attached. + pub fn index_document(&mut self, doc_id: impl Into, text: &str) -> Result<()> { + let start = std::time::Instant::now(); + + let result = self.index_document_inner(doc_id, text); + + // Emit metrics + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::BM25_INDEX_DURATION_MS, + value: MetricValue::Histogram(elapsed), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::BM25_INDEX_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::BM25_INDEX_SIZE, + value: MetricValue::Gauge(self.doc_count() as f64), + labels: vec![], + }, + ); + + result + } + + /// Inner `index_document` logic (uninstrumented). 
+ fn index_document_inner(&mut self, doc_id: impl Into, text: &str) -> Result<()> { + let doc_id: DocumentId = doc_id.into(); + // Check if this is a re-index (bypass budget for existing docs) + let is_reindex = self.contains_document(&doc_id); + + // Remove existing document if present + if is_reindex { + self.remove_document(&doc_id); + } + + // Tokenize using instance tokenizer + let tokens = self.tokenizer.tokenize(text); + let doc_length = tokens.len(); + + if doc_length == 0 { + // Don't index empty documents + return Ok(()); + } + + // Budget check for new documents only (re-index bypasses) + if !is_reindex { + if let Some(limit) = self.config.memory_budget { + let current = self.memory_usage(); + let cost = self.estimate_document_cost(text); + if current + cost > limit { + return Err(RetrievalError::budget_exceeded(current, cost, limit)); + } + } + } + + // Get or assign internal u32 ID + let internal_id = self.get_or_assign_internal_id(&doc_id); + + // Count term frequencies + let mut term_freqs: BTreeMap = BTreeMap::new(); + for token in &tokens { + *term_freqs.entry(token.clone()).or_insert(0) += 1; + } + + // Update inverted index with sorted insertion to maintain doc_id order. + // WAND requires posting lists sorted by doc_id for binary-search seeks. + for (term, freq) in &term_freqs { + let postings = self.inverted_index.entry(term.clone()).or_default(); + let insert_at = postings.partition_point_by_doc_id(internal_id); + // Clamp to u8::MAX (255) for compact posting storage. + // BM25's TF saturation means tf>10 is already ~85% of max + // contribution at k1=1.2, so clamping at 255 has negligible + // scoring impact. For very long documents (>255 occurrences of + // a single term), the score will plateau slightly early. + postings.insert(insert_at, internal_id, (*freq).min(255) as u8); + } + + // Populate forward index: doc -> list of its terms (for O(terms) removal). 
+ self.forward_index + .insert(internal_id, term_freqs.keys().cloned().collect()); + + // Update document metadata + self.doc_lengths.insert(internal_id, doc_length); + self.set_doc_length_fast(internal_id, doc_length); + self.total_tokens += doc_length; + + // IDF cache auto-invalidates on the next search when it detects + // that doc_count() has changed. No per-term eviction needed. + + // Block-max metadata is epoch-invalidated (lazy rebuild on next WAND search). + self.invalidate_block_max_after_mutation(); + + Ok(()) + } + + /// Remove a document from the index. + /// + /// Returns true if the document was found and removed, false otherwise. + pub fn remove_document(&mut self, doc_id: &str) -> bool { + // Look up internal ID + let internal_id = match self.id_to_internal.get(doc_id).copied() { + Some(id) => id, + None => return false, + }; + + // Get and remove document length + let doc_length = match self.doc_lengths.remove(&internal_id) { + Some(len) => len, + None => return false, + }; + + // Clear the fast-path vec entries (both usize and f32 mirrors). + let idx = internal_id as usize; + if idx < self.doc_lengths_vec.len() { + self.doc_lengths_vec[idx] = 0; + } + if idx < self.doc_lengths_f32.len() { + self.doc_lengths_f32[idx] = 0.0; + } + + // Update total tokens + self.total_tokens = self.total_tokens.saturating_sub(doc_length); + + // Remove from posting lists using the forward index (O(terms_in_doc) not O(|V|)). + // Falls back to full scan when the forward index is absent (e.g. after deserialization). + if let Some(terms) = self.forward_index.remove(&internal_id) { + for term in &terms { + if let Some(postings) = self.inverted_index.get_mut(term) { + let idx = postings.partition_point_by_doc_id(internal_id); + if idx < postings.len() && postings.doc_ids[idx] == internal_id { + postings.remove(idx); + } + if postings.is_empty() { + self.inverted_index.remove(term); + } + } + } + } else { + // Fallback: forward index not available (deserialized index). 
/// Bytes per posting in the SoA layout: u32 doc_id (4) + u8 term_freq (1) = 5.
/// No alignment padding waste (separate Vecs).
const BYTES_PER_POSTING: usize = size_of::<u32>() + size_of::<u8>();
+ pub fn memory_usage(&self) -> usize { + let mut inverted_index_size: usize = 0; + let mut block_max_size: usize = 0; + + for (term, postings) in &self.inverted_index { + // String key: heap overhead (24) + string data + inverted_index_size += 24 + term.len(); + // SoA PostingList: two Vec overheads (24 each) + + // doc_ids (n * 4 bytes) + term_freqs (n * 1 byte) = 5 bytes/posting + inverted_index_size += 48 + postings.len() * BYTES_PER_POSTING; + + // Block-max metadata sidecar: one Vec per term + let block_size = self.block_size.max(1); + let block_count = postings.len().div_ceil(block_size); + block_max_size += 24 + block_count * size_of::(); + // HashMap entry overhead for per_term map + block_max_size += 32; + } + + // doc_lengths: HashMap + // Each entry: u32 key (4) + usize value (8) + HashMap bucket overhead (~32) + let doc_lengths_size = self.doc_lengths.len() * (4 + size_of::() + 32); + + // ID mapping tables: + // id_to_internal: HashMap -- DocumentId(24 + data) + u32(4) + bucket(32) + let mut id_map_size: usize = 0; + for doc_id in self.id_to_internal.keys() { + id_map_size += 24 + doc_id.len() + 4 + 32; + } + // internal_to_id: Vec> -- vec overhead (24) + each Arc fat-ptr (16) + data + id_map_size += 24; + for doc_id in &self.internal_to_id { + // Arc heap: 16-byte header + string data. Fat-ptr on stack is 16 bytes + // but we count heap cost; the refcount block is approximated as 16 bytes. 
+ id_map_size += 16 + doc_id.len(); + } + + // IDF cache: not counted towards budget (it's a cache, can be cleared) + + // Forward index: HashMap> + // Each entry: u32 key (4) + Vec overhead (24) + bucket (~32) + string data + let mut forward_index_size: usize = self.forward_index.len() * (4 + 24 + 32); + for terms in self.forward_index.values() { + for term in terms { + // Each String: 24 bytes overhead + string data + forward_index_size += 24 + term.len(); + } + } + + // doc_lengths_f32: Vec for SIMD batch scoring + let doc_lengths_f32_size = self.doc_lengths_f32.len() * size_of::() + 24; + + // HashMap overhead for inverted_index itself + let index_map_overhead = self.inverted_index.len() * 64; + + // Fixed overhead: config + tokenizer Arc + total_tokens + RwLocks + epoch + block_size + let fixed_overhead: usize = 192; + + inverted_index_size + + block_max_size + + doc_lengths_size + + doc_lengths_f32_size + + forward_index_size + + id_map_size + + index_map_overhead + + fixed_overhead + } + + /// Estimate the memory cost of indexing a new document. 
+ pub fn estimate_document_cost(&self, text: &str) -> usize { + let tokens = self.tokenizer.tokenize(text); + if tokens.is_empty() { + return 0; + } + + let mut unique_terms: BTreeMap<&str, u32> = BTreeMap::new(); + for token in &tokens { + *unique_terms.entry(token.as_str()).or_insert(0) += 1; + } + + // Cost per unique term in SoA layout: u32 (4) + u8 (1) = 5 bytes + let postings_cost: usize = unique_terms.len() * BYTES_PER_POSTING; + + // New terms that don't exist yet get String key + PostingList overhead + block-max entry + let new_term_cost: usize = unique_terms + .keys() + .filter(|term| !self.inverted_index.contains_key(**term)) + .map(|term| { + // String key (24 + len) + PostingList overhead (48 = 2 Vecs) + HashMap entry (64) + let postings_entry = 24 + term.len() + 48 + 64; + let block_entry = 24 + size_of::() + 32; + postings_entry + block_entry + }) + .sum(); + + // Existing terms may gain an additional block if the posting list crosses a block boundary + let additional_block_cost: usize = unique_terms + .keys() + .filter_map(|term| { + self.inverted_index + .get(*term) + .map(|postings| postings.len()) + }) + .map(|old_len| { + let block_size = self.block_size.max(1); + let before_blocks = old_len.div_ceil(block_size); + let after_blocks = (old_len + 1).div_ceil(block_size); + if after_blocks > before_blocks { + size_of::() + } else { + 0 + } + }) + .sum(); + + // doc_lengths entry: u32(4) + usize(8) + HashMap bucket(32) = 44 + let doc_entry_cost: usize = 4 + size_of::() + 32; + + // Forward index entry: u32 key (4) + Vec overhead (24) + bucket (32) + // + each term String (24 + len) + let forward_index_cost: usize = 4 + + 24 + + 32 + + unique_terms + .keys() + .map(|term| 24 + term.len()) + .sum::(); + + // ID mapping cost: DocumentId in both maps + u32 key + // Assume average doc_id is ~36 bytes (UUID string) + let avg_doc_id_len: usize = 36; + let id_map_cost: usize = (24 + avg_doc_id_len + 4 + 32) // id_to_internal entry + + (24 + avg_doc_id_len); 
// internal_to_id slot + + postings_cost + + new_term_cost + + additional_block_cost + + doc_entry_cost + + forward_index_cost + + id_map_cost + } +} diff --git a/crates/khive-bm25/src/index/mod.rs b/crates/khive-bm25/src/index/mod.rs new file mode 100644 index 00000000..9f16a5c3 --- /dev/null +++ b/crates/khive-bm25/src/index/mod.rs @@ -0,0 +1,1043 @@ +//! BM25 inverted index implementation. +//! +//! # Properties +//! +//! - IDF(t) >= 0 for all terms (with +1 smoothing) +//! - Rarer terms have higher IDF +//! - TF component >= 0 and < k1 + 1 (saturation bound) +//! - Higher term frequency -> higher (but saturating) score +//! - Total BM25 score >= 0 +//! - Length factor = 1 at average doc length (no adjustment) +//! - Long documents penalized, short documents boosted +//! +//! # Floating-Point Considerations (RETRIEVAL-04) +//! +//! This implementation uses `f64` for internal BM25 score calculations due to: +//! - Logarithmic operations in IDF computation (requires floating-point) +//! - Intermediate calculations requiring full precision +//! - Standard practice in IR systems (Lucene, Elasticsearch use f64) +//! +//! ## Cross-Platform Behavior +//! +//! While f64 follows IEEE 754 on all supported platforms, minor variance may occur: +//! - FMA (fused multiply-add) availability differs across CPUs +//! - Compiler optimizations may reorder floating-point operations +//! - Extended precision (x87) on older x86 may affect intermediate results +//! +//! **Mitigation**: Scores are converted to [`DeterministicScore`] at the API boundary +//! (in [`Bm25Index::search`]) which provides: +//! - Canonical representation for storage and comparison +//! - Consistent serialization across platforms +//! - Protection against NaN propagation +//! +//! ## Golden Tests +//! +//! Golden tests in this module verify known expected values using a controlled corpus. +//! These tests use fixed documents and queries with hand-calculated expected scores +//! 
/// IDF cache keyed by document frequency (`df`) rather than term string.
///
/// IDF depends on two inputs: `df` (document frequency of a term) and `N`
/// (total document count). Multiple terms sharing the same `df` produce
/// identical IDF values, so keying by `df` (a `usize`) is both more compact
/// and more correct than keying by term string.
///
/// When `N` changes (any add/remove), the entire cache is invalidated by
/// comparing `cached_doc_count` against the current `doc_count()`. This
/// eliminates the stale-IDF bug where targeted per-term eviction left
/// entries computed with the old `N` in the cache.
#[derive(Debug, Default)]
pub(crate) struct IdfCache {
    /// The `N` (total document count) for which cached values are valid.
    cached_doc_count: AtomicUsize,
    /// Map from document frequency -> precomputed IDF value.
    by_df: RwLock<HashMap<usize, f64>>,
}

impl Clone for IdfCache {
    /// Snapshot the cache into an independent copy.
    ///
    /// If the lock is poisoned the clone starts with an empty map; the values
    /// are only a cache, so they can be recomputed.
    fn clone(&self) -> Self {
        let map_clone = self.by_df.read().map(|m| m.clone()).unwrap_or_default();
        Self {
            cached_doc_count: AtomicUsize::new(self.cached_doc_count.load(AtomicOrdering::Relaxed)),
            by_df: RwLock::new(map_clone),
        }
    }
}
+mod arc_str_vec_serde { + use std::sync::Arc; + + pub fn serialize(v: &[Arc], ser: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + let mut seq = ser.serialize_seq(Some(v.len()))?; + for s in v { + seq.serialize_element(s.as_ref())?; + } + seq.end() + } + + pub fn deserialize<'de, D>(de: D) -> Result>, D::Error> + where + D: serde::Deserializer<'de>, + { + use serde::Deserialize; + let strings: Vec = Vec::deserialize(de)?; + Ok(strings.into_iter().map(|s| Arc::from(s.as_str())).collect()) + } +} + +pub(crate) const DEFAULT_BLOCK_SIZE: usize = 128; +const INITIAL_POSTINGS_EPOCH: u64 = 0; +const STALE_BLOCK_MAX_EPOCH: u64 = u64::MAX; + +fn default_block_size() -> usize { + DEFAULT_BLOCK_SIZE +} + +fn default_postings_epoch() -> u64 { + INITIAL_POSTINGS_EPOCH +} + +/// Typed document identifier for BM25 index operations. +/// +/// Wire format: plain JSON string (serde transparent). +/// +/// # ID Bridging (Hybrid Search) +/// +/// When combining BM25 keyword results with HNSW vector results in hybrid +/// search, the ID types differ: BM25 uses `DocumentId` (string-based) while +/// HNSW uses `EmbeddingId` (128-bit UUID-based). Bridging strategies include: +/// +/// 1. String-based fusion: convert both ID types to `String` before fusion. +/// 2. DocumentId fusion: convert `EmbeddingId` to `DocumentId` via its +/// display representation, then fuse using `DocumentId`. +/// 3. Application-level mapping: maintain a lookup table mapping between +/// `EmbeddingId` and `DocumentId` in the application layer. +// TODO(port): was generated by `khive_types::transparent_string_newtype!` macro +// which does not yet exist in khive-types. Expanded here manually until the macro +// lands in khive-types and is re-adopted. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] +#[serde(transparent)] +pub struct DocumentId(String); + +impl DocumentId { + /// Create a new `DocumentId` from any `Into`. 
+ pub fn new(s: impl Into) -> Self { + Self(s.into()) + } + + /// Return the inner string slice. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Consume `self` and return the inner `String`. + pub fn into_inner(self) -> String { + self.0 + } + + /// Return the length of the underlying string in bytes. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Return `true` if the underlying string is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl std::fmt::Display for DocumentId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl std::ops::Deref for DocumentId { + type Target = str; + + fn deref(&self) -> &str { + &self.0 + } +} + +impl std::borrow::Borrow for DocumentId { + fn borrow(&self) -> &str { + &self.0 + } +} + +impl AsRef for DocumentId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl From for DocumentId { + fn from(s: String) -> Self { + Self(s) + } +} + +impl From<&str> for DocumentId { + fn from(s: &str) -> Self { + Self(s.to_owned()) + } +} + +impl PartialEq for DocumentId { + fn eq(&self, other: &str) -> bool { + self.0 == other + } +} + +impl PartialEq<&str> for DocumentId { + fn eq(&self, other: &&str) -> bool { + self.0 == *other + } +} + +impl PartialEq for DocumentId { + fn eq(&self, other: &String) -> bool { + &self.0 == other + } +} + +/// Structure-of-Arrays posting list for memory-efficient storage. +/// +/// Stores doc_ids (`Vec`) and term_freqs (`Vec`) in separate +/// contiguous arrays, achieving exactly 5 bytes per posting with no +/// alignment padding waste (vs 8 bytes for AoS `struct { u32, u8 }`). +/// +/// At 200K postings this saves ~600 KB; at 1M postings, ~3 MB. +/// +/// Both arrays are always the same length and sorted by doc_id. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub(crate) struct PostingList { + /// Document IDs, sorted ascending for binary-search seeks in WAND. 
+ pub(crate) doc_ids: Vec, + /// Term frequencies, parallel to `doc_ids`. Clamped to u8::MAX (255). + pub(crate) term_freqs: Vec, +} + +impl PostingList { + /// Number of postings in this list. + #[inline] + pub(crate) fn len(&self) -> usize { + self.doc_ids.len() + } + + /// Whether the posting list is empty. + #[inline] + pub(crate) fn is_empty(&self) -> bool { + self.doc_ids.is_empty() + } + + /// Insert a posting at the given position, maintaining sorted order. + #[inline] + pub(crate) fn insert(&mut self, index: usize, doc_id: u32, term_freq: u8) { + self.doc_ids.insert(index, doc_id); + self.term_freqs.insert(index, term_freq); + } + + /// Remove the posting at the given position. + #[inline] + pub(crate) fn remove(&mut self, index: usize) { + self.doc_ids.remove(index); + self.term_freqs.remove(index); + } + + /// Find the insertion point for a doc_id (binary search). + #[inline] + pub(crate) fn partition_point_by_doc_id(&self, target: u32) -> usize { + self.doc_ids.partition_point(|&id| id < target) + } + + /// Memory usage in bytes (actual heap allocation, no padding waste). + #[inline] + #[allow(dead_code)] // TODO: wire into memory diagnostics / health-check endpoint + pub(crate) fn heap_bytes(&self) -> usize { + // Vec capacity * 4 + Vec capacity * 1 + // Use len() as approximation (capacity >= len) + self.doc_ids.len() * 4 + self.term_freqs.len() + } +} + +/// Per-block BM25 upper-bound metadata for a posting list. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct BlockMaxBlock { + /// Smallest document id in the block. + pub(crate) min_doc_id: u32, + /// Largest document id in the block. + pub(crate) max_doc_id: u32, + /// Maximum exact BM25 contribution of this term among postings in the block. + pub(crate) max_score_contribution: f64, + /// Suffix maximum of `max_score_contribution` from this block to the end. + pub(crate) suffix_max_score: f64, +} + +/// Block-max metadata for a term posting list. 
///
/// Holds one [`BlockMaxBlock`] per fixed-size run of the term's postings, in
/// posting order. An empty posting list is represented by the `Default`
/// value, i.e. no blocks.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub(crate) struct TermBlockMaxMeta {
    /// Per-block score bounds. The final block may cover fewer postings than
    /// the block size when the posting count is not a multiple of it.
    pub(crate) blocks: Vec<BlockMaxBlock>,
}
+/// +/// ## Concurrent Read Pattern +/// +/// When wrapped in an external `RwLock`: +/// - Multiple readers can call `search()` concurrently (requires only `&self`) +/// - Writers still need exclusive access for `index_document()`, `remove_document()`, `clear()` +/// +/// The internal `RwLock` on the IDF cache provides fine-grained locking for cache +/// updates during search, avoiding exclusive locking on the entire index. +/// +/// ## Cache Invalidation +/// +/// The IDF cache auto-invalidates when the document count changes: +/// - On `search()`, if `cached_doc_count != doc_count()`, the cache is +/// cleared and rebuilt lazily. +/// - `clear()` resets the cache and `cached_doc_count` to zero. +/// +/// Because IDF depends on both `df` and `N`, keying by `df` alone is +/// sufficient: when `N` changes, the entire cache is invalidated; when +/// `N` is stable, terms with equal `df` share the same IDF value. +/// +/// Block-max metadata is invalidated via an epoch counter: every mutation +/// bumps `postings_epoch`, and the block-max state is lazily rebuilt on +/// the next WAND search if the epochs disagree. +#[derive(Serialize, Deserialize)] +pub struct Bm25Index { + /// Term -> posting list (SoA layout: separate doc_id and term_freq arrays). + /// Posting lists are sorted by doc_id for binary-search seeks in WAND. + pub(crate) inverted_index: HashMap, + + /// Document lengths (in tokens) keyed by internal u32 ID. + /// Kept for serialization compatibility and `doc_count()`. + pub(crate) doc_lengths: HashMap, + + /// Forward map: external DocumentId -> internal u32 ID. + pub(crate) id_to_internal: HashMap, + + /// Reverse map: internal u32 ID -> shared string slice. + /// + /// Uses `Arc` instead of `DocumentId` (which wraps `String`) so that + /// `resolve_internal_id` can hand out a clone in O(1) — an atomic refcount + /// increment — rather than a heap allocation + memcpy of the UUID string. 
+ /// All search hot-path callers only need `AsRef` / `Deref`, + /// which `Arc` satisfies. The serde wire format is identical to the + /// old `DocumentId` representation (both serialize as a bare JSON string). + #[serde(with = "arc_str_vec_serde")] + pub(crate) internal_to_id: Vec>, + + /// Next internal ID to assign. + pub(crate) next_internal_id: u32, + + /// Total token count across all documents. + pub(crate) total_tokens: usize, + + /// Monotonic counter incremented whenever postings or corpus statistics change. + /// Used to lazily invalidate block-max metadata. + #[serde(default = "default_postings_epoch")] + pub(crate) postings_epoch: u64, + + /// Fixed posting-list block size used for block-max metadata. + #[serde(default = "default_block_size")] + pub(crate) block_size: usize, + + /// Lazily rebuilt block-max metadata. + #[serde(skip, default)] + pub(crate) block_max_state: RwLock, + + /// IDF cache keyed by document frequency (`df`), auto-invalidated when + /// `doc_count()` changes. See [`IdfCache`] for design rationale. + #[serde(skip, default)] + pub(crate) idf_cache: IdfCache, + + /// Vec-indexed document lengths for O(1) hot-path access during scoring. + /// Indexed by internal u32 doc_id. Rebuilt from `doc_lengths` on + /// deserialization. This avoids HashMap lookups in the tight scoring loop. + #[serde(skip, default)] + pub(crate) doc_lengths_vec: Vec, + + /// Pre-converted f32 document lengths for SIMD batch scoring. + /// Maintained in parallel with `doc_lengths_vec`. Avoids per-scoring + /// `usize -> f32` conversion in the tight NEON batch loop. + #[serde(skip, default)] + pub(crate) doc_lengths_f32: Vec, + + /// Configuration parameters. + pub(crate) config: Bm25Config, + + /// Tokenizer for text processing. + /// Defaults to SimpleTokenizer. Skip serialization as tokenizers may not be serializable. 
+ #[serde(skip, default = "default_tokenizer")] + pub(crate) tokenizer: BoxedTokenizer, + + /// Forward index: internal doc_id -> list of terms in that document. + /// Enables O(terms_in_doc) removal instead of O(|vocabulary|). + #[serde(skip, default)] + pub(crate) forward_index: HashMap>, + + /// Optional metrics sink for observability. + #[serde(skip)] + pub(crate) metrics: Option>, +} + +impl std::fmt::Debug for Bm25Index { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Bm25Index") + .field("doc_count", &self.doc_lengths.len()) + .field("unique_terms", &self.inverted_index.len()) + .field("total_tokens", &self.total_tokens) + .field("block_size", &self.block_size) + .field("config", &self.config) + .finish() + } +} + +impl Clone for Bm25Index { + fn clone(&self) -> Self { + let block_max_clone = self + .block_max_state + .read() + .map(|state| state.clone()) + .unwrap_or_default(); + + Self { + inverted_index: self.inverted_index.clone(), + doc_lengths: self.doc_lengths.clone(), + id_to_internal: self.id_to_internal.clone(), + // Arc clone = atomic refcount bump, not a String heap copy. + internal_to_id: self.internal_to_id.clone(), + next_internal_id: self.next_internal_id, + total_tokens: self.total_tokens, + postings_epoch: self.postings_epoch, + block_size: self.block_size, + block_max_state: RwLock::new(block_max_clone), + idf_cache: self.idf_cache.clone(), + doc_lengths_vec: self.doc_lengths_vec.clone(), + doc_lengths_f32: self.doc_lengths_f32.clone(), + forward_index: self.forward_index.clone(), + config: self.config.clone(), + tokenizer: self.tokenizer.clone(), + metrics: self.metrics.clone(), + } + } +} + +impl Default for Bm25Index { + fn default() -> Self { + Self::new(Bm25Config::default()) + } +} + +impl Bm25Index { + /// Create a new empty BM25 index with the given configuration. + /// + /// # Panics + /// + /// Panics if config validation fails (k1 < 0 or b outside [0, 1]). 
+ /// Use [`Bm25Index::try_new`] to handle invalid config as an error. + pub fn new(config: Bm25Config) -> Self { + if let Err(e) = config.validate() { + panic!("invalid BM25 config: {e}"); + } + Self { + inverted_index: HashMap::new(), + doc_lengths: HashMap::new(), + id_to_internal: HashMap::new(), + internal_to_id: Vec::new(), + next_internal_id: 0, + total_tokens: 0, + postings_epoch: INITIAL_POSTINGS_EPOCH, + block_size: DEFAULT_BLOCK_SIZE, + block_max_state: RwLock::new(BlockMaxState::default()), + idf_cache: IdfCache::default(), + doc_lengths_vec: Vec::new(), + doc_lengths_f32: Vec::new(), + forward_index: HashMap::new(), + config, + tokenizer: Arc::new(SimpleTokenizer::default()), + metrics: None, + } + } + + /// Non-panicking constructor. Returns `Err(RetrievalError::Configuration(…))` + /// if the config is invalid instead of panicking. + pub fn try_new(config: Bm25Config) -> Result { + config + .validate() + .map_err(|e| RetrievalError::Configuration(format!("invalid BM25 config: {e}")))?; + Ok(Self::new(config)) + } + + /// Create a new BM25 index with a custom tokenizer. + /// + /// # Panics + /// + /// Panics if config validation fails (k1 < 0 or b outside [0, 1]). + pub fn with_tokenizer(config: Bm25Config, tokenizer: BoxedTokenizer) -> Self { + if let Err(e) = config.validate() { + panic!("invalid BM25 config: {e}"); + } + Self { + inverted_index: HashMap::new(), + doc_lengths: HashMap::new(), + id_to_internal: HashMap::new(), + internal_to_id: Vec::new(), + next_internal_id: 0, + total_tokens: 0, + postings_epoch: INITIAL_POSTINGS_EPOCH, + block_size: DEFAULT_BLOCK_SIZE, + block_max_state: RwLock::new(BlockMaxState::default()), + idf_cache: IdfCache::default(), + doc_lengths_vec: Vec::new(), + doc_lengths_f32: Vec::new(), + forward_index: HashMap::new(), + config, + tokenizer, + metrics: None, + } + } + + /// Set the tokenizer. + /// + /// Note: This does not re-tokenize existing documents. 
+ /// Clear and re-index if you need consistent tokenization. + pub fn set_tokenizer(&mut self, tokenizer: BoxedTokenizer) { + self.tokenizer = tokenizer; + } + + /// Get a reference to the current tokenizer. + pub fn tokenizer(&self) -> &BoxedTokenizer { + &self.tokenizer + } + + /// Attach a metrics sink (builder pattern). + /// + /// The sink receives [`MetricEvent`]s from `search` and `index_document` + /// operations. Pass an `Arc` to share a single sink + /// across multiple indices. + #[must_use] + pub fn with_metrics(mut self, sink: Arc) -> Self { + self.metrics = Some(sink); + self + } + + /// Set or replace the metrics sink at runtime. + /// + /// Pass `Some(sink)` to enable metrics, or `None` to disable. + pub fn set_metrics(&mut self, sink: Option>) { + self.metrics = sink; + } + + /// Get the number of indexed documents. + pub fn doc_count(&self) -> usize { + self.doc_lengths.len() + } + + /// Get the average document length (in tokens). + /// + /// Returns 0.0 if no documents are indexed. + pub fn avg_doc_length(&self) -> f64 { + let count = self.doc_count(); + if count == 0 { + 0.0 + } else { + self.total_tokens as f64 / count as f64 + } + } + + /// Check if a document is indexed. + pub fn contains_document(&self, doc_id: &str) -> bool { + self.id_to_internal.contains_key(doc_id) + } + + /// Get or assign an internal u32 ID for a `DocumentId`. + fn get_or_assign_internal_id(&mut self, doc_id: &DocumentId) -> u32 { + if let Some(&id) = self.id_to_internal.get(doc_id) { + return id; + } + let id = self.next_internal_id; + self.next_internal_id = self.next_internal_id.checked_add(1) + .expect("internal document ID space exhausted (u32::MAX)"); + self.id_to_internal.insert(doc_id.clone(), id); + if id as usize >= self.internal_to_id.len() { + // Placeholder: Arc from empty &str. 
+ self.internal_to_id.resize(id as usize + 1, Arc::from("")); + } + // Store as Arc — avoids cloning the full String on every + // search hit; lookup just does an atomic refcount bump. + self.internal_to_id[id as usize] = Arc::from(doc_id.as_str()); + id + } + + /// Resolve an internal u32 ID back to an `Arc`. + /// + /// Returns a cheaply cloneable shared reference. Callers in the search + /// hot path can `Arc::clone` this without any heap allocation. + #[inline] + fn resolve_internal_id(&self, internal_id: u32) -> Option> { + self.internal_to_id + .get(internal_id as usize) + .map(Arc::clone) + } + + /// Get the configuration. + pub fn config(&self) -> &Bm25Config { + &self.config + } + + /// Clear the index, removing all documents. + pub fn clear(&mut self) { + self.inverted_index.clear(); + self.doc_lengths.clear(); + self.doc_lengths_vec.clear(); + self.doc_lengths_f32.clear(); + self.forward_index.clear(); + self.id_to_internal.clear(); + self.internal_to_id.clear(); + self.next_internal_id = 0; + self.total_tokens = 0; + self.postings_epoch = INITIAL_POSTINGS_EPOCH; + self.idf_cache + .cached_doc_count + .store(0, AtomicOrdering::Relaxed); + if let Ok(mut cache) = self.idf_cache.by_df.write() { + cache.clear(); + } + if let Ok(mut block_state) = self.block_max_state.write() { + block_state.built_epoch = STALE_BLOCK_MAX_EPOCH; + block_state.per_term.clear(); + } + } + + /// Update the O(1) doc_lengths_vec for a given internal id. + /// Called on every document insert. + #[inline] + pub(crate) fn set_doc_length_fast(&mut self, internal_id: u32, length: usize) { + let idx = internal_id as usize; + if idx >= self.doc_lengths_vec.len() { + self.doc_lengths_vec.resize(idx + 1, 0); + } + self.doc_lengths_vec[idx] = length; + // Keep f32 mirror in sync for SIMD batch scoring. 
+ if idx >= self.doc_lengths_f32.len() { + self.doc_lengths_f32.resize(idx + 1, 0.0); + } + self.doc_lengths_f32[idx] = length as f32; + } + + /// Look up document length by internal id using the fast Vec path. + /// Falls back to HashMap if Vec is not yet populated (deserialization). + #[inline] + pub(crate) fn doc_length_fast(&self, internal_id: u32) -> usize { + let idx = internal_id as usize; + if idx < self.doc_lengths_vec.len() { + self.doc_lengths_vec[idx] + } else { + self.doc_lengths.get(&internal_id).copied().unwrap_or(0) + } + } + + /// Rebuild `doc_lengths_vec` and `doc_lengths_f32` from `doc_lengths` HashMap. + /// Called after deserialization to populate the fast-path Vecs (see `persist::bm25`). + pub fn ensure_doc_lengths_vec(&mut self) { + if !self.doc_lengths_vec.is_empty() || self.doc_lengths.is_empty() { + return; + } + let max_id = self.doc_lengths.keys().copied().max().unwrap_or(0) as usize; + self.doc_lengths_vec.resize(max_id + 1, 0); + self.doc_lengths_f32.resize(max_id + 1, 0.0); + for (&id, &len) in &self.doc_lengths { + self.doc_lengths_vec[id as usize] = len; + self.doc_lengths_f32[id as usize] = len as f32; + } + } + + /// Get statistics about the index. + pub fn stats(&self) -> Bm25Stats { + Bm25Stats { + doc_count: self.doc_count(), + total_tokens: self.total_tokens, + avg_doc_length: self.avg_doc_length(), + unique_terms: self.inverted_index.len(), + } + } + + /// Check if the IDF cache is empty. + #[cfg(test)] + pub(crate) fn is_idf_cache_empty(&self) -> bool { + self.idf_cache + .by_df + .read() + .map(|cache| cache.is_empty()) + .unwrap_or(true) + } + + /// Invalidate block-max metadata after a corpus mutation. + /// + /// Bumps the postings epoch so that the next WAND search lazily rebuilds + /// block-max metadata. The IDF cache self-invalidates on the next search + /// when it detects that `doc_count()` has changed. 
/// Compute IDF from document frequency using the Robertson-Walker variant.
///
/// Formula: `ln((N - df + 0.5) / (df + 0.5) + 1)`, where `N` is `doc_count`
/// and `df` is `doc_freq`.
///
/// `doc_freq` is clamped to `doc_count` before use. A term cannot occur in
/// more documents than exist, but inconsistent inputs (for example a
/// corrupted serialized index) would otherwise make `N - df` negative and
/// produce a negative IDF. Valid inputs (`df <= N`) are unaffected.
///
/// **PROOF CORRESPONDENCE**: `Lion.Retrieval.BM25.idf_nonneg`
/// With +1 inside ln(), IDF(t) >= 0 for all terms regardless of document frequency.
#[inline]
pub(crate) fn idf_from_doc_freq(doc_freq: usize, doc_count: usize) -> f64 {
    let n = doc_count as f64;
    let df = doc_freq.min(doc_count) as f64;
    ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
}
/// Score one posting's contribution for a single term.
///
/// Evaluates `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`.
/// When `avgdl` is not above `f64::EPSILON` the length ratio is undefined
/// and the contribution is `0.0`.
///
/// **PROOF CORRESPONDENCE**: `Lion.Retrieval.BM25.tf_bounded`
/// TF saturation: tf * (k1 + 1) / (tf + k1 * ...) < k1 + 1 for all tf >= 0.
#[inline]
pub(crate) fn bm25_term_score(
    idf: f64,
    term_freq: u8,
    doc_length: usize,
    avgdl: f64,
    k1: f64,
    b: f64,
) -> f64 {
    if avgdl <= f64::EPSILON {
        return 0.0;
    }

    let tf = f64::from(term_freq);
    // Length normalisation: 1 for an average-length document when b > 0.
    let length_norm = 1.0 - b + b * (doc_length as f64 / avgdl);
    let saturation = (tf * (k1 + 1.0)) / (tf + k1 * length_norm);
    idf * saturation
}
+ #[inline] + pub(crate) fn denom_dl_factor_f32(&self) -> f32 { + self.k1_times_b_over_avgdl as f32 + } + + /// Score a posting with pre-computed constants. + #[inline] + pub(crate) fn score(&self, term_freq: u8, doc_length: usize) -> f64 { + let tf = term_freq as f64; + let numerator = tf * self.k1_plus_1; + let denominator = + tf + self.k1_times_one_minus_b + self.k1_times_b_over_avgdl * (doc_length as f64); + self.idf * (numerator / denominator) + } +} + +fn build_term_block_max_meta( + postings: &PostingList, + doc_lengths: &HashMap, + block_size: usize, + idf: f64, + avgdl: f64, + k1: f64, + b: f64, +) -> TermBlockMaxMeta { + if postings.is_empty() { + return TermBlockMaxMeta::default(); + } + + let n = postings.len(); + let num_blocks = n.div_ceil(block_size); + let mut blocks = Vec::with_capacity(num_blocks); + + for block_idx in 0..num_blocks { + let start = block_idx * block_size; + let end = (start + block_size).min(n); + + let min_doc_id = postings.doc_ids[start]; + let max_doc_id = postings.doc_ids[end - 1]; + + let mut max_score_contribution = 0.0; + for i in start..end { + let doc_id = postings.doc_ids[i]; + let term_freq = postings.term_freqs[i]; + let doc_length = doc_lengths.get(&doc_id).copied().unwrap_or(0); + let score = bm25_term_score(idf, term_freq, doc_length, avgdl, k1, b); + if score > max_score_contribution { + max_score_contribution = score; + } + } + + blocks.push(BlockMaxBlock { + min_doc_id, + max_doc_id, + max_score_contribution, + suffix_max_score: max_score_contribution, + }); + } + + // Compute suffix-max scores (back to front). + let mut suffix_max = 0.0; + for block in blocks.iter_mut().rev() { + if block.max_score_contribution > suffix_max { + suffix_max = block.max_score_contribution; + } + block.suffix_max_score = suffix_max; + } + + TermBlockMaxMeta { blocks } +} + +/// Statistics about a BM25 index. +#[derive(Debug, Clone, Default)] +pub struct Bm25Stats { + /// Number of indexed documents. 
+ pub doc_count: usize, + /// Total token count across all documents. + pub total_tokens: usize, + /// Average document length (in tokens). + pub avg_doc_length: f64, + /// Number of unique terms in the index. + pub unique_terms: usize, +} + +// ── Wire-format fixture: DocumentId ────────────────────────────────────────── +// +// These tests lock the JSON wire representation of `DocumentId` after its +// migration to `transparent_string_newtype!`. Updating this fixture = a +// wire-format migration that requires a PR-level migration plan. +#[cfg(test)] +mod document_id_wire_format { + use super::DocumentId; + + /// Frozen wire-format fixture: DocumentId must serialize as a bare JSON string. + /// + /// Wire format: `"some-document-identifier"` (not `{"0":"..."}` or any other shape). + /// This is enforced by `#[serde(transparent)]` in the macro expansion. + #[test] + fn document_id_serializes_as_plain_string() { + let id = DocumentId::new("some-document-identifier"); + let json = serde_json::to_string(&id).expect("DocumentId serialize"); + assert_eq!( + json, r#""some-document-identifier""#, + "wire format drift detected in DocumentId — must be plain JSON string", + ); + } + + #[test] + fn document_id_roundtrip() { + let id = DocumentId::new("doc_abc_123"); + let json = serde_json::to_string(&id).expect("serialize"); + let back: DocumentId = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back, id, "serde roundtrip must produce identical value"); + } + + #[test] + fn document_id_empty_string_roundtrip() { + let id = DocumentId::new(""); + let json = serde_json::to_string(&id).expect("serialize"); + assert_eq!( + json, r#""""#, + "empty DocumentId must serialize as empty JSON string" + ); + let back: DocumentId = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back, id); + } + + #[test] + fn document_id_unicode_roundtrip() { + let id = DocumentId::new("doc_\u{4e2d}\u{6587}"); + let json = 
serde_json::to_string(&id).expect("serialize"); + let back: DocumentId = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back, id); + } +} diff --git a/crates/khive-bm25/src/index/search.rs b/crates/khive-bm25/src/index/search.rs new file mode 100644 index 00000000..510b0c29 --- /dev/null +++ b/crates/khive-bm25/src/index/search.rs @@ -0,0 +1,1623 @@ +//! Search operations for BM25 index. +//! +//! # SIMD Acceleration +//! +//! The brute-force scoring path uses architecture-specific SIMD to process +//! postings in parallel: +//! +//! - **aarch64 (NEON)**: 4-wide batches using 128-bit NEON registers. +//! - **x86_64 (AVX2)**: 8-wide batches using 256-bit YMM registers, with +//! optional FMA for fused multiply-add in the denominator computation. +//! Detected at runtime via `is_x86_feature_detected!`. +//! - **Scalar fallback**: Used on all other targets or when AVX2 is not +//! available at runtime. +//! +//! The dispatch happens once per term (not per batch) to avoid repeated +//! feature checks in the hot loop. + +use std::cmp::{Ordering, Reverse}; +use std::collections::BinaryHeap; +use std::sync::Arc; + +use khive_score::DeterministicScore; + +use super::{BlockMaxBlock, Bm25Index, Bm25TermScorer, PostingList}; +use crate::metrics::{self, MetricEvent, MetricValue}; + +/// Postings threshold below which the brute-force SIMD scorer is used instead of +/// Block-Max WAND. The brute-force path processes postings sequentially in +/// NEON/scalar batches of 4 with zero cursor/heap overhead, which is faster than +/// WAND for moderate posting counts. WAND's block-skip pruning only wins when the +/// total postings are large enough that it can skip significant portions. +/// +/// Empirically tuned: at ~10K-16K total postings the brute-force SIMD path +/// matches or beats WAND on aarch64 (Apple M-series). Above 16K the WAND +/// block-skip savings overcome its per-cursor overhead. 
const SMALL_QUERY_POSTINGS_THRESHOLD: usize = 16_384;
/// Reserved internal doc id (`u32::MAX`) used as a sentinel.
///
/// NOTE(review): presumably marks an exhausted posting cursor in the WAND
/// path; the use site is not visible in this part of the file, so confirm.
const TERMINATED_DOC: u32 = u32::MAX;
/// Scalar fallback for batch scoring (4-wide).
///
/// Computes the same BM25 formula as `score_batch_neon` using plain scalar
/// f32 arithmetic. Used when no SIMD path is available.
///
/// The denominator is evaluated as `tf + (denom_base + denom_dl_factor * dl)`,
/// the same association the SIMD kernels use. f32 addition is not
/// associative, so summing in a different order could change the last bit
/// of a score depending on which kernel handled the posting.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn score_batch_scalar_4(
    term_freqs: &[u8; 4],
    doc_lengths: &[f32; 4],
    idf: f32,
    k1_plus_1: f32,
    denom_base: f32,
    denom_dl_factor: f32,
) -> [f32; 4] {
    std::array::from_fn(|lane| {
        let tf = f32::from(term_freqs[lane]);
        let num = tf * k1_plus_1;
        let denom = tf + (denom_base + denom_dl_factor * doc_lengths[lane]);
        idf * (num / denom)
    })
}
/// Scalar fallback for batch scoring (8-wide).
///
/// Computes BM25 scores for 8 postings using plain scalar f32 arithmetic.
/// Used on x86_64 when AVX2 is not available at runtime.
///
/// The denominator is evaluated as `tf + (denom_base + denom_dl_factor * dl)`,
/// the same association as `score_batch_avx2`. Because f32 addition is not
/// associative, this lets the fallback produce the same bits as the non-FMA
/// AVX2 kernel.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn score_batch_scalar_8(
    term_freqs: &[u8; 8],
    doc_lengths: &[f32; 8],
    idf: f32,
    k1_plus_1: f32,
    denom_base: f32,
    denom_dl_factor: f32,
) -> [f32; 8] {
    std::array::from_fn(|lane| {
        let tf = f32::from(term_freqs[lane]);
        let num = tf * k1_plus_1;
        let denom = tf + (denom_base + denom_dl_factor * doc_lengths[lane]);
        idf * (num / denom)
    })
}
unsafe fn score_batch_avx2(
    term_freqs: &[u8; 8],
    doc_lengths: &[f32; 8],
    idf: f32,
    k1_plus_1: f32,
    denom_base: f32,
    denom_dl_factor: f32,
) -> [f32; 8] {
    use std::arch::x86_64::*;

    // Term frequencies: one 64-bit load brings in all 8 bytes, which are then
    // zero-extended u8 -> i32 and converted i32 -> f32. The values are at most
    // 255, so the conversion to f32 is exact.
    let tfs_raw = _mm_loadl_epi64(term_freqs.as_ptr() as *const __m128i);
    let tfs_i32 = _mm256_cvtepu8_epi32(tfs_raw);
    let tf = _mm256_cvtepi32_ps(tfs_i32);

    // Document lengths: unaligned load of 8 contiguous f32 values.
    let dl = _mm256_loadu_ps(doc_lengths.as_ptr());

    // Broadcast scalar constants to all 8 lanes.
    let k1p1 = _mm256_set1_ps(k1_plus_1);
    let base = _mm256_set1_ps(denom_base);
    let dl_fac = _mm256_set1_ps(denom_dl_factor);
    let idf_v = _mm256_set1_ps(idf);

    // numerator = tf * k1_plus_1
    let num = _mm256_mul_ps(tf, k1p1);

    // denominator = tf + (denom_base + denom_dl_factor * doc_len)
    // The inner sum is formed first; keep this order, since f32 addition is
    // not associative and a different grouping can change the last bit.
    let dl_term = _mm256_mul_ps(dl_fac, dl);
    let base_plus_dl = _mm256_add_ps(base, dl_term);
    let denom = _mm256_add_ps(tf, base_plus_dl);

    // score = idf * (num / denom), using a full-precision divide.
    let ratio = _mm256_div_ps(num, denom);
    let score = _mm256_mul_ps(idf_v, ratio);

    // Unaligned store of the 8 lanes back into a plain array.
    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), score);
    result
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
// SAFETY: Callers must select this helper only after AVX2+FMA runtime detection.
// Fixed-size array parameters guarantee the eight lanes read by the intrinsics.
unsafe fn score_batch_avx2_fma(
    term_freqs: &[u8; 8],
    doc_lengths: &[f32; 8],
    idf: f32,
    k1_plus_1: f32,
    denom_base: f32,
    denom_dl_factor: f32,
) -> [f32; 8] {
    use std::arch::x86_64::*;

    // Term frequencies: 64-bit load of 8 bytes, zero-extend u8 -> i32, then
    // convert i32 -> f32 (exact for values up to 255).
    let tfs_raw = _mm_loadl_epi64(term_freqs.as_ptr() as *const __m128i);
    let tfs_i32 = _mm256_cvtepu8_epi32(tfs_raw);
    let tf = _mm256_cvtepi32_ps(tfs_i32);

    // Document lengths: unaligned load of 8 contiguous f32 values.
    let dl = _mm256_loadu_ps(doc_lengths.as_ptr());

    // Broadcast scalar constants to all 8 lanes.
    let k1p1 = _mm256_set1_ps(k1_plus_1);
    let base = _mm256_set1_ps(denom_base);
    let dl_fac = _mm256_set1_ps(denom_dl_factor);
    let idf_v = _mm256_set1_ps(idf);

    // numerator = tf * k1_plus_1
    let num = _mm256_mul_ps(tf, k1p1);

    // FMA: denom_dl_factor * doc_len + denom_base (single rounding)
    // This is the only step that differs from the plain AVX2 kernel, which
    // rounds after the multiply and again after the add.
    let base_plus_dl = _mm256_fmadd_ps(dl_fac, dl, base);
    let denom = _mm256_add_ps(tf, base_plus_dl);

    // score = idf * (num / denom), using a full-precision divide.
    let ratio = _mm256_div_ps(num, denom);
    let score = _mm256_mul_ps(idf_v, ratio);

    // Unaligned store of the 8 lanes back into a plain array.
    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), score);
    result
}
+#[cfg(target_arch = "x86_64")] +#[inline] +fn select_score_batch_8() -> ScoreBatch8Fn { + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + score_batch_avx2_fma + } else if is_x86_feature_detected!("avx2") { + score_batch_avx2 + } else { + // Scalar fallback when no AVX2. + |tfs, dls, idf, k1p1, base, dl_fac| score_batch_scalar_8(tfs, dls, idf, k1p1, base, dl_fac) + } +} + +/// Dispatch batch scoring to the appropriate 4-wide implementation. +/// +/// On aarch64 uses NEON SIMD; on other architectures uses scalar f32. +#[inline] +fn score_batch_4( + term_freqs: &[u8; 4], + doc_lengths: &[f32; 4], + idf: f32, + k1_plus_1: f32, + denom_base: f32, + denom_dl_factor: f32, +) -> [f32; 4] { + #[cfg(target_arch = "aarch64")] + { + // SAFETY: We are on aarch64 (checked by cfg). NEON is baseline on all + // AArch64 CPUs (ARMv8-A mandates Advanced SIMD). The input slices are + // [T; 4] arrays so alignment and length are guaranteed. + unsafe { + score_batch_neon( + term_freqs, + doc_lengths, + idf, + k1_plus_1, + denom_base, + denom_dl_factor, + ) + } + } + #[cfg(not(target_arch = "aarch64"))] + { + score_batch_scalar_4( + term_freqs, + doc_lengths, + idf, + k1_plus_1, + denom_base, + denom_dl_factor, + ) + } +} + +#[derive(Debug, Clone, Copy)] +struct HeapEntry { + doc_id: u32, + score: f64, +} + +impl PartialEq for HeapEntry { + fn eq(&self, other: &Self) -> bool { + self.doc_id == other.doc_id && self.score.to_bits() == other.score.to_bits() + } +} + +impl Eq for HeapEntry {} + +impl PartialOrd for HeapEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for HeapEntry { + fn cmp(&self, other: &Self) -> Ordering { + self.score + .total_cmp(&other.score) + .then_with(|| other.doc_id.cmp(&self.doc_id)) + } +} + +#[derive(Debug, Clone, Copy)] +struct ShallowBlockInfo { + max_score: f64, + last_doc: u32, +} + +/// Reusable per-query scratch space for [`Bm25Index::search_with_context`]. 
+/// +/// Every call to [`Bm25Index::search`] allocates a fresh result buffer and +/// heap. Reusing one context across calls avoids that churn. +/// +/// The context is automatically cleared at the start of each search call, so +/// there is no need to call [`clear`](Self::clear) manually between queries. +/// +/// # Example +/// +/// ```rust +/// use khive_bm25::{Bm25Config, Bm25Index, SearchContext}; +/// +/// let mut index = Bm25Index::new(Bm25Config::default()); +/// index.index_document("d1", "quick brown fox").unwrap(); +/// index.index_document("d2", "lazy brown dog").unwrap(); +/// +/// let mut ctx = SearchContext::new(); +/// for query in &["quick fox", "brown dog"] { +/// let results = index.search_with_context(query, 10, &mut ctx); +/// // ctx is cleared internally and reused on the next call +/// } +/// ``` +pub struct SearchContext { + /// Vec-indexed score accumulator for brute-force path. + /// Indexed by internal doc_id. Avoids HashMap overhead for dense ID spaces. + /// Lazily sized on first use per query. + score_vec: Vec, + /// Tracks which doc_ids have non-zero scores in score_vec + /// so we can drain results without scanning the entire Vec. + touched_docs: Vec, + /// Scratch buffer for sorting results before top-k truncation. + results_buf: Vec<(u32, f64)>, + /// Internal top-k min-heap for BMW execution. + heap: BinaryHeap>, +} + +impl SearchContext { + /// Create a new, empty search context. + pub fn new() -> Self { + Self { + score_vec: Vec::new(), + touched_docs: Vec::new(), + results_buf: Vec::new(), + heap: BinaryHeap::new(), + } + } + + /// Create a search context pre-allocated for an expected number of matches. 
+ pub fn with_capacity(estimated_matches: usize) -> Self { + Self { + score_vec: Vec::new(), + touched_docs: Vec::with_capacity(estimated_matches), + results_buf: Vec::with_capacity(estimated_matches), + heap: BinaryHeap::with_capacity(estimated_matches.min(64)), + } + } + + /// Clear all per-query state without releasing heap memory. + /// + /// Called automatically at the start of each + /// [`Bm25Index::search_with_context`] invocation. You only need to call + /// this yourself if you want to shrink the context between unrelated + /// batches of queries. + pub fn clear(&mut self) { + // Reset touched entries in score_vec without zeroing the whole vec. + for &doc_id in &self.touched_docs { + if (doc_id as usize) < self.score_vec.len() { + self.score_vec[doc_id as usize] = 0.0; + } + } + self.touched_docs.clear(); + self.results_buf.clear(); + self.heap.clear(); + } +} + +impl Default for SearchContext { + fn default() -> Self { + Self::new() + } +} + +struct TermCursor<'a> { + postings: &'a PostingList, + blocks: &'a [BlockMaxBlock], + pos: usize, + block_size: usize, + scorer: Bm25TermScorer, +} + +impl<'a> TermCursor<'a> { + #[inline] + fn new( + postings: &'a PostingList, + blocks: &'a [BlockMaxBlock], + block_size: usize, + scorer: Bm25TermScorer, + ) -> Self { + Self { + postings, + blocks, + pos: 0, + block_size, + scorer, + } + } + + #[inline] + fn is_terminated(&self) -> bool { + self.pos >= self.postings.len() + } + + #[inline] + fn doc(&self) -> u32 { + if self.pos < self.postings.doc_ids.len() { + self.postings.doc_ids[self.pos] + } else { + TERMINATED_DOC + } + } + + #[inline] + fn current_doc_id(&self) -> u32 { + self.postings.doc_ids[self.pos] + } + + #[inline] + fn current_term_freq(&self) -> u8 { + self.postings.term_freqs[self.pos] + } + + #[inline] + fn current_block_idx(&self) -> Option { + if self.is_terminated() { + None + } else { + Some(self.pos / self.block_size) + } + } + + #[inline] + fn remaining_max_score(&self) -> f64 { + 
self.current_block_idx() + .and_then(|idx| self.blocks.get(idx)) + .map(|block| block.suffix_max_score) + .unwrap_or(0.0) + } + + #[inline] + fn advance(&mut self) -> u32 { + if !self.is_terminated() { + self.pos += 1; + } + self.doc() + } + + #[inline] + fn seek(&mut self, target_doc: u32) -> u32 { + if self.is_terminated() { + return TERMINATED_DOC; + } + if self.doc() >= target_doc { + return self.doc(); + } + + let rel = self.postings.doc_ids[self.pos..].partition_point(|&id| id < target_doc); + self.pos += rel; + self.doc() + } + + #[inline] + fn shallow_block_info(&self, target_doc: u32) -> Option { + let current_block_idx = self.current_block_idx()?; + let rel = + self.blocks[current_block_idx..].partition_point(|block| block.max_doc_id < target_doc); + let block = self.blocks.get(current_block_idx + rel)?; + Some(ShallowBlockInfo { + max_score: block.max_score_contribution, + last_doc: block.max_doc_id, + }) + } + + #[inline] + fn score_current(&self, index: &Bm25Index) -> f64 { + if self.is_terminated() { + return 0.0; + } + let doc_id = self.current_doc_id(); + let term_freq = self.current_term_freq(); + let doc_length = index.doc_length_fast(doc_id); + self.scorer.score(term_freq, doc_length) + } +} + +impl Bm25Index { + /// Search for documents matching the query. + /// + /// Returns up to `k` documents sorted by BM25 score descending, with + /// deterministic internal-doc-id tie-breaking identical to the brute-force + /// scorer. + /// + /// For small queries (total postings < 256), falls back to exhaustive + /// brute-force scoring. For larger queries, uses the Block-Max WAND + /// algorithm for threshold-based pruning. + /// + /// # Floating-Point Boundary Conversion + /// + /// BM25 scoring uses `f64` internally for precision in logarithmic calculations. 
+ /// At the API boundary (this method), scores are converted to [`DeterministicScore`] + /// which ensures: + /// - Canonical representation for cross-platform consistency + /// - Safe serialization without precision loss + /// - Protection against NaN/Inf propagation + /// + /// See module-level documentation for cross-platform considerations. + /// + /// # Concurrency + /// + /// This method takes `&self` (not `&mut self`) to enable concurrent reads. + /// The internal IDF cache and block-max metadata use interior mutability + /// (`RwLock`) for thread-safe updates. + /// + /// Emits `bm25.search.duration_ms`, `bm25.search.count`, and + /// `bm25.search.results` metrics when a sink is attached. + /// + /// **PROOF CORRESPONDENCE**: `Lion.Retrieval.BM25.bm25_nonneg` + /// Total BM25 score >= 0 for any query and document, since it is a sum of + /// non-negative IDF values multiplied by non-negative TF components. + /// Returns up to `k` (id, score) pairs sorted by BM25 score descending. + /// + /// The `Arc` document IDs are cheaply cloneable shared references into + /// the internal reverse-map, avoiding a heap allocation per result. Callers + /// that need a `DocumentId` can construct one via `DocumentId::new(&*arc)`. + pub fn search(&self, query_text: &str, k: usize) -> Vec<(Arc, DeterministicScore)> { + let mut ctx = SearchContext::new(); + self.search_with_context(query_text, k, &mut ctx) + } + + /// Search for documents matching the query, reusing a [`SearchContext`]. + /// + /// Behaves identically to [`search`](Self::search) but reuses the heap + /// memory inside `ctx` across calls, eliminating allocation churn per query. + /// + /// The context is automatically [`clear`](SearchContext::clear)ed at the + /// start of each call, so callers do not need to reset it manually. 
+ pub fn search_with_context( + &self, + query_text: &str, + k: usize, + ctx: &mut SearchContext, + ) -> Vec<(Arc, DeterministicScore)> { + let start = std::time::Instant::now(); + + let results = self.search_inner(query_text, k, ctx); + + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::BM25_SEARCH_DURATION_MS, + value: MetricValue::Histogram(elapsed), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::BM25_SEARCH_COUNT, + value: MetricValue::Counter(1), + labels: vec![], + }, + ); + metrics::emit( + &self.metrics, + MetricEvent { + name: metrics::names::BM25_SEARCH_RESULTS, + value: MetricValue::Gauge(results.len() as f64), + labels: vec![], + }, + ); + + results + } + + /// Inner search logic (uninstrumented). + /// + /// Routes to brute-force for small queries and to Block-Max WAND for + /// larger ones. + fn search_inner( + &self, + query_text: &str, + k: usize, + ctx: &mut SearchContext, + ) -> Vec<(Arc, DeterministicScore)> { + if k == 0 { + ctx.clear(); + return Vec::new(); + } + + let query_tokens = self.tokenizer.tokenize(query_text); + if query_tokens.is_empty() { + ctx.clear(); + return Vec::new(); + } + + if self.doc_count() == 0 { + ctx.clear(); + return Vec::new(); + } + + let total_query_postings: usize = query_tokens + .iter() + .map(|term| { + self.inverted_index + .get(term) + .map(|postings| postings.len()) + .unwrap_or(0) + }) + .sum(); + + if total_query_postings < SMALL_QUERY_POSTINGS_THRESHOLD { + return self.search_brute_force(query_text, k, ctx); + } + + self.ensure_block_max_metadata(); + let block_state_guard = match self.block_max_state.read() { + Ok(guard) if guard.built_epoch == self.postings_epoch => guard, + _ => return self.search_brute_force(query_text, k, ctx), + }; + + ctx.clear(); + + let doc_count = self.doc_count(); + let avgdl = self.avg_doc_length(); + let mut cursors = 
Vec::with_capacity(query_tokens.len()); + + let k1 = self.config.k1; + let b = self.config.b; + + for term in &query_tokens { + let postings = match self.inverted_index.get(term) { + Some(postings) if !postings.is_empty() => postings, + _ => continue, + }; + let blocks = match block_state_guard.per_term.get(term) { + Some(meta) if !meta.blocks.is_empty() => meta.blocks.as_slice(), + _ => continue, + }; + let idf = self.compute_idf(term, doc_count); + let scorer = Bm25TermScorer::new(idf, k1, b, avgdl); + cursors.push(TermCursor::new(postings, blocks, self.block_size, scorer)); + } + + if cursors.is_empty() { + return Vec::new(); + } + + sort_and_prune_terminated(&mut cursors); + + while let Some((before_pivot_len, pivot_len, pivot_doc)) = + find_pivot_doc(&cursors, current_threshold_score(&ctx.heap, k)) + { + let threshold_score = current_threshold_score(&ctx.heap, k); + let block_upper_bound: f64 = cursors[..pivot_len] + .iter() + .map(|cursor| { + cursor + .shallow_block_info(pivot_doc) + .map(|info| info.max_score) + .unwrap_or(0.0) + }) + .sum(); + + // Keep equality as competitive to preserve exact tie handling. + if block_upper_bound < threshold_score { + advance_one_cursor_past_block(&mut cursors, pivot_len, pivot_doc); + if cursors.is_empty() { + break; + } + continue; + } + + if !align_cursors(&mut cursors, pivot_doc, before_pivot_len) { + if cursors.is_empty() { + break; + } + continue; + } + + let score: f64 = cursors[..pivot_len] + .iter() + .map(|cursor| cursor.score_current(self)) + .sum(); + + maybe_push_top_k( + &mut ctx.heap, + k, + HeapEntry { + doc_id: pivot_doc, + score, + }, + ); + + advance_all_cursors_on_pivot(&mut cursors, pivot_len); + if cursors.is_empty() { + break; + } + } + + heap_to_results(self, ctx) + } + + /// Exact exhaustive scorer retained as fallback and for equivalence tests. + /// + /// The inner loop is SIMD-batched (4-wide NEON on aarch64, scalar f32 + /// on other targets). 
Pre-converted f32 document lengths avoid per-scoring + /// integer-to-float conversion. + /// + /// **PROOF CORRESPONDENCE**: `Lion.Retrieval.BM25.tf_bounded` + /// TF saturation: tf * (k1 + 1) / (tf + k1 * ...) < k1 + 1 for all tf >= 0. + pub(crate) fn search_brute_force( + &self, + query_text: &str, + k: usize, + ctx: &mut SearchContext, + ) -> Vec<(Arc, DeterministicScore)> { + ctx.clear(); + + if k == 0 { + return Vec::new(); + } + + let query_tokens = self.tokenizer.tokenize(query_text); + if query_tokens.is_empty() { + return Vec::new(); + } + + let doc_count = self.doc_count(); + if doc_count == 0 { + return Vec::new(); + } + + // Pre-size the score accumulator to the maximum internal doc_id. + // This eliminates bounds-check branches in the tight SIMD loop. + let max_id = self.next_internal_id as usize; + if ctx.score_vec.len() < max_id { + ctx.score_vec.resize(max_id, 0.0); + } + + let avgdl = self.avg_doc_length(); + let k1 = self.config.k1; + let b = self.config.b; + + // Cache the doc_lengths_f32 slice pointer outside the term loop. + // All internal doc_ids are < max_id which is <= doc_lengths_f32.len() + // (maintained by set_doc_length_fast on every insert). + let dl_f32 = &self.doc_lengths_f32; + let scores_vec = &mut ctx.score_vec; + let touched = &mut ctx.touched_docs; + + for term in &query_tokens { + let postings = match self.inverted_index.get(term) { + Some(postings) => postings, + None => continue, + }; + let idf = self.compute_idf(term, doc_count); + let scorer = Bm25TermScorer::new(idf, k1, b, avgdl); + + // Extract f32 SIMD parameters from the pre-computed scorer. + let simd_idf = scorer.idf_f32(); + let simd_k1p1 = scorer.k1_plus_1_f32(); + let simd_base = scorer.denom_base_f32(); + let simd_dl_fac = scorer.denom_dl_factor_f32(); + + // SoA layout: doc_ids and term_freqs are separate contiguous arrays. 
+ let n = postings.len(); + let doc_ids = &postings.doc_ids; + let tfs_arr = &postings.term_freqs; + + // On x86_64, resolve the best 8-wide scoring function once per term + // (AVX2+FMA > AVX2 > scalar) and process in chunks of 8. On aarch64, + // process in chunks of 4 using NEON. + #[cfg(target_arch = "x86_64")] + { + let score_fn = select_score_batch_8(); + let full_chunks_8 = n / 8; + + for chunk_idx in 0..full_chunks_8 { + let base_idx = chunk_idx * 8; + let tfs: [u8; 8] = [ + tfs_arr[base_idx], + tfs_arr[base_idx + 1], + tfs_arr[base_idx + 2], + tfs_arr[base_idx + 3], + tfs_arr[base_idx + 4], + tfs_arr[base_idx + 5], + tfs_arr[base_idx + 6], + tfs_arr[base_idx + 7], + ]; + let d0 = doc_ids[base_idx] as usize; + let d1 = doc_ids[base_idx + 1] as usize; + let d2 = doc_ids[base_idx + 2] as usize; + let d3 = doc_ids[base_idx + 3] as usize; + let d4 = doc_ids[base_idx + 4] as usize; + let d5 = doc_ids[base_idx + 5] as usize; + let d6 = doc_ids[base_idx + 6] as usize; + let d7 = doc_ids[base_idx + 7] as usize; + let lens = [ + dl_f32[d0], dl_f32[d1], dl_f32[d2], dl_f32[d3], dl_f32[d4], dl_f32[d5], + dl_f32[d6], dl_f32[d7], + ]; + // SAFETY: score_fn is selected based on runtime CPU feature + // detection; each variant's target_feature attribute matches + // what was detected. + let batch_scores = unsafe { + score_fn(&tfs, &lens, simd_idf, simd_k1p1, simd_base, simd_dl_fac) + }; + + // Accumulate all 8 scores. + macro_rules! accum { + ($idx:expr, $d:expr) => { + if scores_vec[$d] == 0.0 { + touched.push(doc_ids[base_idx + $idx]); + } + scores_vec[$d] += batch_scores[$idx] as f64; + }; + } + accum!(0, d0); + accum!(1, d1); + accum!(2, d2); + accum!(3, d3); + accum!(4, d4); + accum!(5, d5); + accum!(6, d6); + accum!(7, d7); + } + + // Process remaining 4-7 postings in a 4-wide batch. 
+ let remainder_start = full_chunks_8 * 8; + let remaining = n - remainder_start; + if remaining >= 4 { + let tfs = [ + tfs_arr[remainder_start], + tfs_arr[remainder_start + 1], + tfs_arr[remainder_start + 2], + tfs_arr[remainder_start + 3], + ]; + let d0 = doc_ids[remainder_start] as usize; + let d1 = doc_ids[remainder_start + 1] as usize; + let d2 = doc_ids[remainder_start + 2] as usize; + let d3 = doc_ids[remainder_start + 3] as usize; + let lens = [dl_f32[d0], dl_f32[d1], dl_f32[d2], dl_f32[d3]]; + let batch_scores = + score_batch_4(&tfs, &lens, simd_idf, simd_k1p1, simd_base, simd_dl_fac); + if scores_vec[d0] == 0.0 { + touched.push(doc_ids[remainder_start]); + } + scores_vec[d0] += batch_scores[0] as f64; + if scores_vec[d1] == 0.0 { + touched.push(doc_ids[remainder_start + 1]); + } + scores_vec[d1] += batch_scores[1] as f64; + if scores_vec[d2] == 0.0 { + touched.push(doc_ids[remainder_start + 2]); + } + scores_vec[d2] += batch_scores[2] as f64; + if scores_vec[d3] == 0.0 { + touched.push(doc_ids[remainder_start + 3]); + } + scores_vec[d3] += batch_scores[3] as f64; + } + let scalar_start = remainder_start + if remaining >= 4 { 4 } else { 0 }; + + // Scalar tail for remaining 0-3 postings. + for i in scalar_start..n { + let doc_id = doc_ids[i]; + let d = doc_id as usize; + let doc_length = self.doc_length_fast(doc_id); + let term_score = scorer.score(tfs_arr[i], doc_length); + if scores_vec[d] == 0.0 { + touched.push(doc_id); + } + scores_vec[d] += term_score; + } + } + + // aarch64 path: 4-wide NEON batching (unchanged from original). 
+ #[cfg(target_arch = "aarch64")] + { + let full_chunks = n / 4; + + for chunk_idx in 0..full_chunks { + let base_idx = chunk_idx * 4; + let tfs = [ + tfs_arr[base_idx], + tfs_arr[base_idx + 1], + tfs_arr[base_idx + 2], + tfs_arr[base_idx + 3], + ]; + let d0 = doc_ids[base_idx] as usize; + let d1 = doc_ids[base_idx + 1] as usize; + let d2 = doc_ids[base_idx + 2] as usize; + let d3 = doc_ids[base_idx + 3] as usize; + let lens = [dl_f32[d0], dl_f32[d1], dl_f32[d2], dl_f32[d3]]; + let batch_scores = + score_batch_4(&tfs, &lens, simd_idf, simd_k1p1, simd_base, simd_dl_fac); + if scores_vec[d0] == 0.0 { + touched.push(doc_ids[base_idx]); + } + scores_vec[d0] += batch_scores[0] as f64; + if scores_vec[d1] == 0.0 { + touched.push(doc_ids[base_idx + 1]); + } + scores_vec[d1] += batch_scores[1] as f64; + if scores_vec[d2] == 0.0 { + touched.push(doc_ids[base_idx + 2]); + } + scores_vec[d2] += batch_scores[2] as f64; + if scores_vec[d3] == 0.0 { + touched.push(doc_ids[base_idx + 3]); + } + scores_vec[d3] += batch_scores[3] as f64; + } + + // Scalar fallback for remaining 0-3 postings. + for i in (full_chunks * 4)..n { + let doc_id = doc_ids[i]; + let d = doc_id as usize; + let doc_length = self.doc_length_fast(doc_id); + let term_score = scorer.score(tfs_arr[i], doc_length); + if scores_vec[d] == 0.0 { + touched.push(doc_id); + } + scores_vec[d] += term_score; + } + } + + // Generic fallback for architectures other than x86_64 and aarch64. + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..n { + let doc_id = doc_ids[i]; + let d = doc_id as usize; + let doc_length = self.doc_length_fast(doc_id); + let term_score = scorer.score(tfs_arr[i], doc_length); + if scores_vec[d] == 0.0 { + touched.push(doc_id); + } + scores_vec[d] += term_score; + } + } + } + + // Drain touched_docs into results buffer. 
+ ctx.results_buf.clear(); + for &doc_id in &ctx.touched_docs { + let score = ctx.score_vec[doc_id as usize]; + if score > 0.0 { + ctx.results_buf.push((doc_id, score)); + } + } + + // Partial sort: if we only need k results from a large set, use + // select_nth_unstable_by to avoid fully sorting all results. + if k < ctx.results_buf.len() { + ctx.results_buf + .select_nth_unstable_by(k, |a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + ctx.results_buf.truncate(k); + } + ctx.results_buf + .sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + + ctx.results_buf + .iter() + .take(k) + .filter_map(|(internal_id, score)| { + // resolve_internal_id returns Arc; Arc::clone is an atomic + // refcount bump — no heap allocation, no memcpy. + let doc_id = self.resolve_internal_id(*internal_id)?; + Some((doc_id, DeterministicScore::from_f64(*score))) + }) + .collect() + } + + /// Compute IDF (Inverse Document Frequency) for a term. + /// + /// Uses the BM25 IDF formula: + /// ```text + /// IDF(qi) = ln((N - n(qi) + 0.5) / (n(qi) + 0.5) + 1) + /// ``` + /// + /// This variant always returns non-negative IDF (Robertson-Walker variant). + /// Uses interior mutability for cache updates to enable concurrent reads. + /// + /// **PROOF CORRESPONDENCE**: `Lion.Retrieval.BM25.idf_nonneg` + /// With +1 inside ln(), IDF(t) >= 0 for all terms regardless of document frequency. + /// + /// **PROOF CORRESPONDENCE**: `Lion.Retrieval.BM25.idf_mono` + /// Rarer terms have higher IDF: n1 < n2 implies IDF(n1) > IDF(n2). + pub(super) fn compute_idf(&self, term: &str, doc_count: usize) -> f64 { + use std::sync::atomic::Ordering as AtomicOrdering; + + // If N changed since the cache was last populated, invalidate everything. 
+ let cached_n = self + .idf_cache + .cached_doc_count + .load(AtomicOrdering::Relaxed); + if cached_n != doc_count { + if let Ok(mut cache) = self.idf_cache.by_df.write() { + // Double-check after acquiring the write lock to avoid races + // where another thread already cleared + updated. + let recheck = self + .idf_cache + .cached_doc_count + .load(AtomicOrdering::Relaxed); + if recheck != doc_count { + cache.clear(); + self.idf_cache + .cached_doc_count + .store(doc_count, AtomicOrdering::Relaxed); + } + } + } + + let doc_freq = self.inverted_index.get(term).map(|p| p.len()).unwrap_or(0); + + // Check cache by df (read lock) + if let Ok(cache) = self.idf_cache.by_df.read() { + if let Some(&cached) = cache.get(&doc_freq) { + return cached; + } + } + + let idf = super::idf_from_doc_freq(doc_freq, doc_count); + + // Cache by df and return (write lock) + if let Ok(mut cache) = self.idf_cache.by_df.write() { + cache.insert(doc_freq, idf); + } + + idf + } +} + +fn heap_to_results( + index: &Bm25Index, + ctx: &mut SearchContext, +) -> Vec<(Arc, DeterministicScore)> { + ctx.results_buf.clear(); + + while let Some(Reverse(entry)) = ctx.heap.pop() { + ctx.results_buf.push((entry.doc_id, entry.score)); + } + + ctx.results_buf + .sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + + ctx.results_buf + .iter() + .filter_map(|(internal_id, score)| { + // resolve_internal_id returns Arc; clone = atomic refcount bump. 
+ let doc_id = index.resolve_internal_id(*internal_id)?; + Some((doc_id, DeterministicScore::from_f64(*score))) + }) + .collect() +} + +fn current_threshold_score(heap: &BinaryHeap>, k: usize) -> f64 { + if heap.len() < k { + 0.0 + } else { + heap.peek().map(|entry| entry.0.score).unwrap_or(0.0) + } +} + +fn maybe_push_top_k(heap: &mut BinaryHeap>, k: usize, candidate: HeapEntry) { + if k == 0 { + return; + } + + if heap.len() < k { + heap.push(Reverse(candidate)); + return; + } + + let should_replace = heap.peek().map(|worst| candidate > worst.0).unwrap_or(true); + if should_replace { + let _ = heap.pop(); + heap.push(Reverse(candidate)); + } +} + +fn find_pivot_doc(cursors: &[TermCursor<'_>], threshold: f64) -> Option<(usize, usize, u32)> { + let mut upper_bound_sum = 0.0; + let mut before_pivot_len = 0usize; + let mut pivot_doc = TERMINATED_DOC; + + while before_pivot_len < cursors.len() { + upper_bound_sum += cursors[before_pivot_len].remaining_max_score(); + if upper_bound_sum >= threshold { + pivot_doc = cursors[before_pivot_len].doc(); + break; + } + before_pivot_len += 1; + } + + if pivot_doc == TERMINATED_DOC { + return None; + } + + let mut pivot_len = before_pivot_len + 1; + while pivot_len < cursors.len() && cursors[pivot_len].doc() == pivot_doc { + pivot_len += 1; + } + + Some((before_pivot_len, pivot_len, pivot_doc)) +} + +fn align_cursors( + cursors: &mut Vec>, + pivot_doc: u32, + before_pivot_len: usize, +) -> bool { + debug_assert_ne!(pivot_doc, TERMINATED_DOC); + + for idx in (0..before_pivot_len).rev() { + let new_doc = cursors[idx].seek(pivot_doc); + if new_doc != pivot_doc { + sort_and_prune_terminated(cursors); + return false; + } + } + + true +} + +fn advance_all_cursors_on_pivot(cursors: &mut Vec>, pivot_len: usize) { + for cursor in &mut cursors[..pivot_len] { + cursor.advance(); + } + sort_and_prune_terminated(cursors); +} + +/// Advance one cursor past the current block when the block-level upper +/// bound is below the threshold. 
+/// +/// Selects the cursor whose current block ends **earliest** (minimum +/// `last_doc`) among the pivot cursors. This minimizes skip distance and +/// is the correct BMW cursor selection strategy -- advancing past the +/// smallest block boundary guarantees forward progress with minimal +/// overshoot. The seek target is that earliest block end + 1, bounded +/// by the smallest doc_id among non-pivot cursors so we do not overshoot +/// documents that other cursors still reference. +fn advance_one_cursor_past_block( + cursors: &mut Vec>, + pivot_len: usize, + pivot_doc: u32, +) { + let mut cursor_to_seek = None; + let mut earliest_block_end = TERMINATED_DOC; + let mut doc_to_seek_after = TERMINATED_DOC; + + for (idx, cursor) in cursors[..pivot_len].iter().enumerate() { + if let Some(info) = cursor.shallow_block_info(pivot_doc) { + if info.last_doc < doc_to_seek_after { + doc_to_seek_after = info.last_doc; + } + // Select the cursor with the earliest block end (minimum last_doc). + // This minimizes skip distance for optimal BMW pruning. + if info.last_doc < earliest_block_end { + earliest_block_end = info.last_doc; + cursor_to_seek = Some(idx); + } + } + } + + if doc_to_seek_after != TERMINATED_DOC { + doc_to_seek_after = doc_to_seek_after.saturating_add(1); + } + + for cursor in &cursors[pivot_len..] { + let doc = cursor.doc(); + if doc < doc_to_seek_after { + doc_to_seek_after = doc; + } + } + + if let Some(idx) = cursor_to_seek { + // Ensure forward progress: if the non-pivot cap reduced doc_to_seek_after + // to at or below the cursor's current position, the seek would be a no-op. + // This can happen when a non-pivot cursor points to a doc_id smaller than + // the block-end target (e.g., a short posting list cursor at doc 3 while + // the chosen cursor is already at doc 150). Force at least +1 advance. 
+ let current_doc = cursors[idx].doc(); + if doc_to_seek_after <= current_doc { + doc_to_seek_after = current_doc.saturating_add(1); + } + cursors[idx].seek(doc_to_seek_after); + } + + sort_and_prune_terminated(cursors); +} + +fn sort_and_prune_terminated(cursors: &mut Vec>) { + cursors.retain(|cursor| !cursor.is_terminated()); + cursors.sort_by_key(|cursor| cursor.doc()); +} + +// --------------------------------------------------------------------------- +// Tests: SIMD batch scoring parity and edge cases +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests_simd_scoring { + use super::*; + + /// Reference scalar BM25 score for a single posting. + fn scalar_bm25(tf: u8, doc_len: f32, idf: f32, k1p1: f32, base: f32, dl_fac: f32) -> f32 { + let tf = tf as f32; + let num = tf * k1p1; + let denom = tf + base + dl_fac * doc_len; + idf * (num / denom) + } + + /// Compute reference scores for an arbitrary-length batch using scalar code. + fn reference_scores( + tfs: &[u8], + dls: &[f32], + idf: f32, + k1p1: f32, + base: f32, + dl_fac: f32, + ) -> Vec { + tfs.iter() + .zip(dls.iter()) + .map(|(&tf, &dl)| scalar_bm25(tf, dl, idf, k1p1, base, dl_fac)) + .collect() + } + + // Test parameters (standard BM25 with k1=1.2, b=0.75, avgdl=10.0) + const TEST_IDF: f32 = 1.5; + const TEST_K1P1: f32 = 2.2; // k1 + 1 = 1.2 + 1 + const TEST_BASE: f32 = 0.3; // k1 * (1 - b) = 1.2 * 0.25 + const TEST_DL_FAC: f32 = 0.09; // k1 * b / avgdl = 1.2 * 0.75 / 10.0 + + // ----------------------------------------------------------------------- + // Test 1: scalar_4 vs reference (parity check) + // ----------------------------------------------------------------------- + + #[test] + fn test_score_batch_4_matches_scalar() { + let tfs: [u8; 4] = [1, 3, 5, 10]; + let dls: [f32; 4] = [8.0, 12.0, 5.0, 20.0]; + + let batch = score_batch_4(&tfs, &dls, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC); + let reference = reference_scores(&tfs, &dls, TEST_IDF, 
TEST_K1P1, TEST_BASE, TEST_DL_FAC); + + for i in 0..4 { + assert!( + (batch[i] - reference[i]).abs() < 1e-6, + "batch_4[{i}] = {}, expected {} (delta {})", + batch[i], + reference[i], + (batch[i] - reference[i]).abs() + ); + } + } + + // ----------------------------------------------------------------------- + // Test 2: x86_64 AVX2 8-wide vs scalar parity + // ----------------------------------------------------------------------- + + #[cfg(target_arch = "x86_64")] + #[test] + fn test_avx2_matches_scalar_basic() { + if !is_x86_feature_detected!("avx2") { + eprintln!("AVX2 not available, skipping test"); + return; + } + + let tfs: [u8; 8] = [1, 2, 3, 5, 8, 13, 21, 34]; + let dls: [f32; 8] = [5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0]; + + // SAFETY: The test returns early unless AVX2 is detected, and the + // fixed-size arrays provide all lanes consumed by the helper. + let avx2_result = + unsafe { score_batch_avx2(&tfs, &dls, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC) }; + let reference = reference_scores(&tfs, &dls, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC); + + for i in 0..8 { + assert!( + (avx2_result[i] - reference[i]).abs() < 1e-6, + "avx2[{i}] = {}, expected {} (delta {})", + avx2_result[i], + reference[i], + (avx2_result[i] - reference[i]).abs() + ); + } + } + + // ----------------------------------------------------------------------- + // Test 3: AVX2+FMA vs scalar (slightly relaxed tolerance due to FMA rounding) + // ----------------------------------------------------------------------- + + #[cfg(target_arch = "x86_64")] + #[test] + fn test_avx2_fma_matches_scalar() { + if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") { + eprintln!("AVX2+FMA not available, skipping test"); + return; + } + + let tfs: [u8; 8] = [0, 1, 127, 255, 42, 7, 99, 200]; + let dls: [f32; 8] = [1.0, 2.0, 100.0, 0.5, 10.0, 50.0, 3.0, 1000.0]; + + // SAFETY: The test returns early unless AVX2+FMA is detected, and the + // fixed-size arrays provide all 
/// Verifies that the runtime-selected 8-wide scoring kernel produces the
/// same scores as the scalar reference, whichever implementation the
/// dispatcher picks on this machine.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_dispatch_score_batch_8() {
    let score_fn = select_score_batch_8();

    let tfs: [u8; 8] = [3, 7, 1, 15, 0, 255, 128, 50];
    let dls: [f32; 8] = [10.0, 5.0, 20.0, 8.0, 100.0, 1.0, 15.0, 30.0];

    // SAFETY: `select_score_batch_8` hands back a target-feature kernel only
    // when the matching CPU feature was detected at runtime; in every other
    // case it returns the scalar implementation.
    let dispatched =
        unsafe { score_fn(&tfs, &dls, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC) };
    let reference = reference_scores(&tfs, &dls, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC);

    for i in 0..8 {
        let (got, want) = (dispatched[i], reference[i]);
        // Relative tolerance plus a small absolute floor for near-zero scores.
        let tol = want.abs() * 1e-5 + 1e-7;
        let delta = (got - want).abs();
        assert!(
            delta < tol,
            "dispatch[{i}] = {}, expected {} (delta {})",
            got,
            want,
            delta
        );
    }
}
/// Edge case: a very large document length must still yield strictly
/// positive scores that agree with the scalar reference.
#[test]
fn test_large_doc_length() {
    let tfs: [u8; 4] = [5, 10, 20, 50];
    let dls: [f32; 4] = [1e6; 4];

    let result = score_batch_4(&tfs, &dls, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC);
    let reference = reference_scores(&tfs, &dls, TEST_IDF, TEST_K1P1, TEST_BASE, TEST_DL_FAC);

    for i in 0..4 {
        let (got, want) = (result[i], reference[i]);
        // Length normalization drives the score toward zero, but it must
        // remain positive and match the scalar computation.
        assert!(got > 0.0, "score should be positive");
        assert!(
            (got - want).abs() < 1e-6,
            "large dl mismatch at [{i}]: {} vs {}",
            got,
            want
        );
    }
}
/// Integration test: index 100 documents so that individual terms have
/// posting lists of length 100, 16, 8, 7 and 1, then search each term and a
/// multi-term query through `search_with_context`.
#[test]
fn test_brute_force_search_various_sizes() {
    use crate::{Bm25Config, Bm25Index};

    // (suffix appended to the document text, number of leading docs that get it)
    //   "alpha"   -> all 100 docs (added unconditionally below)
    //   "beta"    -> 16 docs
    //   "gamma"   ->  8 docs
    //   "delta"   ->  7 docs
    //   "epsilon" ->  1 doc
    let optional_terms: [(&str, usize); 4] =
        [(" beta", 16), (" gamma", 8), (" delta", 7), (" epsilon", 1)];

    let mut index = Bm25Index::new(Bm25Config::default());
    for i in 0..100usize {
        let mut text = format!("alpha doc{i}");
        for &(suffix, doc_limit) in optional_terms.iter() {
            if i < doc_limit {
                text.push_str(suffix);
            }
        }
        index.index_document(format!("doc{i}"), &text).unwrap();
    }

    // One single-term query per posting-list length.
    let mut ctx = SearchContext::new();
    for query in &["alpha", "beta", "gamma", "delta", "epsilon"] {
        let results = index.search_with_context(query, 10, &mut ctx);
        assert!(!results.is_empty(), "query '{query}' should return results");
        // Every returned score must be strictly positive.
        for (doc_id, score) in &results {
            assert!(
                score.to_f64() > 0.0,
                "query '{query}', doc '{doc_id}': score should be positive"
            );
        }
    }

    // A multi-term query exercises score accumulation across terms.
    let results = index.search_with_context("alpha beta gamma", 5, &mut ctx);
    assert!(!results.is_empty());
    // Only doc0..doc7 contain all three terms, so one of them must rank first.
    let top_id: usize = results[0]
        .0
        .strip_prefix("doc")
        .unwrap()
        .parse()
        .unwrap();
    assert!(
        top_id < 8,
        "top result should be a doc with all 3 terms (doc0-doc7), got doc{top_id}"
    );
}
/// Minimal xorshift64 pseudo-random generator.
///
/// Used to build reproducible synthetic corpora and queries: the same seed
/// always yields the same sequence. Not suitable for anything
/// security-sensitive.
#[derive(Clone)]
struct XorShift64 {
    state: u64,
}

impl XorShift64 {
    /// Create a generator from `seed`.
    ///
    /// A zero state would make xorshift emit zeros forever, so a seed of `0`
    /// is remapped to `1`.
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Advance the generator and return the next 64-bit value.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Return a value in `[0, 1)` built from the top 53 bits of the next draw
    /// (53 bits is the width of an `f64` mantissa).
    fn next_f64(&mut self) -> f64 {
        ((self.next_u64() >> 11) as f64) / ((1u64 << 53) as f64)
    }

    /// Return a value in `[0, upper)`.
    ///
    /// For `upper <= 1` this returns `0` without consuming a draw.
    /// The reduction uses a plain modulo, so the result carries a slight
    /// modulo bias; that is acceptable for test data generation.
    fn gen_range(&mut self, upper: usize) -> usize {
        if upper <= 1 {
            return 0;
        }
        // Reduce in u64 *before* narrowing: the remainder is < `upper`, so the
        // final cast is lossless and the sequence is identical on 32-bit and
        // 64-bit targets (narrowing first would discard the high bits on
        // 32-bit platforms and change the generated corpus).
        (self.next_u64() % (upper as u64)) as usize
    }
}
{ + if token_idx > 0 { + text.push(' '); + } + let token = &vocab[zipf.sample(&mut rng)]; + text.push_str(token); + } + index + .index_document(format!("doc_{doc_idx}"), &text) + .expect("synthetic document should index"); + } +} + +fn build_query(vocab: &[String], zipf: &ZipfSampler, rng: &mut XorShift64, terms: usize) -> String { + let mut query = String::new(); + for idx in 0..terms { + if idx > 0 { + query.push(' '); + } + if rng.gen_range(10) == 0 { + query.push_str("missing_term"); + } else { + query.push_str(&vocab[zipf.sample(rng)]); + } + } + query +} + +/// Assert that two result lists are equivalent within floating-point tolerance. +/// +/// Due to floating-point accumulation order differences between brute-force +/// and WAND (cursors are sorted by doc_id, not by query term order), documents +/// with extremely close raw f64 scores may be ordered differently or swapped +/// at the k-th boundary. After `DeterministicScore` quantization, documents +/// that had slightly different raw f64 scores may appear with the same score. +/// +/// This comparator verifies: +/// 1. Same result count +/// 2. Score at each rank matches within a small relative tolerance +/// 3. Documents that differ must have scores within floating-point tolerance +/// of the boundary score (the k-th score) +fn assert_same_results( + expected: &[(Arc, DeterministicScore)], + actual: &[(Arc, DeterministicScore)], + context: &str, +) { + assert_eq!( + expected.len(), + actual.len(), + "result length mismatch for {context}" + ); + + if expected.is_empty() { + return; + } + + // Relative tolerance for score comparison. + // The brute-force path uses f32 SIMD batch scoring (NEON/scalar) then + // accumulates into f64, while the WAND path scores entirely in f64. + // f32 has ~7 decimal digits of precision, so per-term error is ~1e-7. + // With multi-term queries the error accumulates additively, so we use + // 1e-6 to accommodate up to ~10 query terms with comfortable margin. 
+ let rel_tol = 1e-6; + + // Get the boundary score (the score at the last position). + let _boundary_score = expected.last().unwrap().1.to_f64(); + + for (rank, ((expected_doc, expected_score), (actual_doc, actual_score))) in + expected.iter().zip(actual.iter()).enumerate() + { + let exp_s = expected_score.to_f64(); + let act_s = actual_score.to_f64(); + + // Scores at each rank should be very close. + let score_diff = (exp_s - act_s).abs(); + let tol = rel_tol * exp_s.abs().max(1.0); + assert!( + score_diff <= tol, + "score mismatch at rank {rank} for {context}: expected {exp_s} got {act_s} (diff={score_diff})" + ); + + // If documents differ, their scores must be within f32 precision of + // each other. The brute-force path uses f32 SIMD while WAND uses f64, + // so documents with nearly-identical f64 scores may swap positions when + // one path computes a slightly different value due to f32 rounding. + if expected_doc != actual_doc { + let mutual_diff = (exp_s - act_s).abs(); + let mutual_tol = rel_tol * exp_s.abs().max(1.0); + assert!( + mutual_diff <= mutual_tol, + "doc mismatch at rank {rank} with divergent scores for {context}: \ + expected=({expected_doc}, {exp_s}) actual=({actual_doc}, {act_s}) diff={mutual_diff} tol={mutual_tol}" + ); + } + } +} + +#[test] +fn bmw_matches_bruteforce_on_random_zipf_corpora() { + let vocab = build_vocab(512); + let zipf = ZipfSampler::new(vocab.len(), 1.07); + + for (case_idx, &doc_count) in [1_000usize, 2_500, 10_000].iter().enumerate() { + let mut index = Bm25Index::new(Bm25Config::default()); + build_zipf_corpus(&mut index, doc_count, 0xC0FFEE + case_idx as u64); + + let mut rng = XorShift64::new(0xBAD5EED + doc_count as u64); + let mut brute_ctx = SearchContext::with_capacity(256); + let mut wand_ctx = SearchContext::with_capacity(256); + + for query_idx in 0..256 { + let term_count = 1 + rng.gen_range(5); + let query = build_query(&vocab, &zipf, &mut rng, term_count); + let k = [1usize, 3, 5, 10, 
/// With a single indexed document and `k` larger than the corpus, both
/// search paths must agree and return exactly that document.
#[test]
fn bmw_handles_single_document_and_large_k() {
    let mut index = Bm25Index::new(Bm25Config::default());
    index
        .index_document("doc1", "alpha beta gamma alpha")
        .unwrap();

    let query = "alpha gamma";
    let k = 10;

    // Each path gets its own scratch context.
    let mut brute_ctx = SearchContext::new();
    let brute = index.search_brute_force(query, k, &mut brute_ctx);

    let mut wand_ctx = SearchContext::new();
    let wand = index.search_with_context(query, k, &mut wand_ctx);

    assert_same_results(&brute, &wand, "single document / large k");
    assert_eq!(wand.len(), 1);
    assert_eq!(&*wand[0].0, "doc1");
}
/// Regression test for scoring across a posting-list block boundary.
///
/// Indexes two full blocks of documents. The last document of the first block
/// and the first document of the second block repeat the query term 12 times;
/// every other document contains it once. The WAND path must agree with brute
/// force and rank the last document of the first block on top.
#[test]
fn bmw_block_boundary_regression() {
    let mut index = Bm25Index::new(Bm25Config::default());
    // Leading space is intentional: the filler is appended directly after
    // the repeated term.
    let filler = " filler filler filler filler filler filler filler filler";

    for doc_idx in 0..(DEFAULT_BLOCK_SIZE * 2) {
        let straddles_boundary =
            doc_idx == DEFAULT_BLOCK_SIZE - 1 || doc_idx == DEFAULT_BLOCK_SIZE;
        let repeats = if straddles_boundary { 12 } else { 1 };

        let mut text = vec!["boundary"; repeats].join(" ");
        text.push_str(filler);

        index
            .index_document(format!("doc_{doc_idx}"), &text)
            .unwrap();
    }

    let mut brute_ctx = SearchContext::new();
    let mut wand_ctx = SearchContext::new();

    let brute = index.search_brute_force("boundary", 5, &mut brute_ctx);
    let wand = index.search_with_context("boundary", 5, &mut wand_ctx);

    assert_same_results(&brute, &wand, "block boundary regression");

    // Derive the expected id from the block size instead of hard-coding it,
    // so the test keeps targeting the boundary if the block size changes.
    let expected_top = format!("doc_{}", DEFAULT_BLOCK_SIZE - 1);
    assert_eq!(
        wand.first().map(|entry| entry.0.as_ref()),
        Some(expected_top.as_str())
    );
}
.inverted_index + .get("alpha") + .expect("alpha postings should exist"); + + assert!(postings + .doc_ids + .windows(2) + .all(|window| window[0] < window[1])); +} diff --git a/crates/khive-bm25/src/lib.rs b/crates/khive-bm25/src/lib.rs new file mode 100644 index 00000000..5bd8f01b --- /dev/null +++ b/crates/khive-bm25/src/lib.rs @@ -0,0 +1,61 @@ +//! BM25 (Okapi BM25) keyword index. +//! +//! Provides term frequency-based relevance scoring for keyword search. +//! See ADR-003 for configuration (k1=1.2, b=0.75). +//! +//! # BM25 Formula +//! +//! ```text +//! score(D, Q) = Σ IDF(qi) * (f(qi, D) * (k1 + 1)) / (f(qi, D) + k1 * (1 - b + b * |D|/avgdl)) +//! +//! where: +//! - Q = query terms +//! - D = document +//! - f(qi, D) = term frequency of qi in D +//! - |D| = document length +//! - avgdl = average document length +//! - k1 = 1.2 (term saturation) +//! - b = 0.75 (length normalization) +//! ``` +//! +//! # Example +//! +//! ```rust +//! use khive_bm25::{Bm25Config, Bm25Index}; +//! +//! let mut index = Bm25Index::new(Bm25Config::default()); +//! +//! // Index some documents (String / &str auto-convert to DocumentId) +//! index.index_document("doc1", "the quick brown fox").unwrap(); +//! index.index_document("doc2", "the lazy dog").unwrap(); +//! index.index_document("doc3", "quick brown fox jumps over the lazy dog").unwrap(); +//! +//! // Search +//! let results = index.search("quick fox", 10); +//! for (doc_id, score) in results { +//! println!("{}: {}", doc_id, score); +//! } +//! ``` +//! +//! # ID Types and Hybrid Search Bridging +//! +//! [`DocumentId`] is a newtype wrapper around `String` that provides type +//! safety. When performing hybrid search that combines BM25 results with +//! HNSW vector results (which use `EmbeddingId`), see the [`DocumentId`] +//! documentation for bridging strategies. 
+ +pub mod error; +pub mod metrics; + +mod config; +mod index; +mod tokenizer; + +#[cfg(test)] +mod tests; + +// Re-export public types +pub use config::Bm25Config; +pub use error::{ErrorKind, Result, RetrievalError}; +pub use index::{Bm25Index, Bm25Stats, DocumentId, SearchContext}; +pub use tokenizer::{tokenize, BoxedTokenizer, SimpleTokenizer, Tokenizer}; diff --git a/crates/khive-bm25/src/metrics.rs b/crates/khive-bm25/src/metrics.rs new file mode 100644 index 00000000..731a334b --- /dev/null +++ b/crates/khive-bm25/src/metrics.rs @@ -0,0 +1,95 @@ +//! Observability sink for BM25 index operations. +//! +//! Provides a pluggable [`MetricsSink`] trait and a [`RecordingSink`] for tests. + +use std::sync::Arc; + +#[cfg(test)] +use parking_lot::Mutex; + +/// A single metric event emitted by the BM25 index. +#[derive(Debug, Clone)] +pub struct MetricEvent { + /// Event name (use the constants in [`names`]). + pub name: &'static str, + /// The value payload. + pub value: MetricValue, + /// Optional label key-value pairs. + pub labels: Vec<(&'static str, String)>, +} + +/// Value carried by a metric event. +#[derive(Debug, Clone, PartialEq)] +pub enum MetricValue { + /// A monotonically increasing counter. + Counter(u64), + /// An instantaneous gauge. + Gauge(f64), + /// A histogram observation (e.g., duration in ms). + Histogram(f64), +} + +/// Trait for receiving metric events. +/// +/// Implementations must be `Send + Sync` to allow the index to hold an +/// `Arc` and emit events from `&self` methods. +pub trait MetricsSink: Send + Sync { + /// Receive a single metric event. + fn record(&self, event: MetricEvent); +} + +/// Emit a metric event to the sink, if one is attached. +#[inline] +pub fn emit(sink: &Option>, event: MetricEvent) { + if let Some(s) = sink { + s.record(event); + } +} + +/// Well-known metric name constants. 
/// In-memory sink that records every event it receives, in arrival order.
///
/// Only compiled for tests; lets a test assert on what the index emitted.
#[cfg(test)]
pub struct RecordingSink {
    /// Recorded events, guarded by a mutex because `MetricsSink::record`
    /// takes `&self`.
    events: Mutex<Vec<MetricEvent>>,
}

#[cfg(test)]
impl RecordingSink {
    /// Create an empty recording sink.
    pub fn new() -> Self {
        Self {
            events: Mutex::new(Vec::new()),
        }
    }

    /// Return a snapshot (clone) of all recorded events.
    pub fn events(&self) -> Vec<MetricEvent> {
        self.events.lock().clone()
    }

    /// Clear all recorded events.
    pub fn clear(&self) {
        self.events.lock().clear();
    }

    /// Return `true` if no events have been recorded.
    pub fn is_empty(&self) -> bool {
        self.events.lock().is_empty()
    }
}

#[cfg(test)]
impl Default for RecordingSink {
    /// Equivalent to [`RecordingSink::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
impl MetricsSink for RecordingSink {
    /// Append `event` to the recorded list.
    fn record(&self, event: MetricEvent) {
        self.events.lock().push(event);
    }
}
/// Indexing several documents updates the document count and the average
/// document length (measured in tokens after stop-word removal).
#[test]
fn test_index_multiple_documents() {
    let corpus = [
        ("doc1", "the quick brown fox"),
        ("doc2", "the lazy dog"),
        ("doc3", "quick"),
    ];

    let mut index = Bm25Index::default();
    for &(id, text) in corpus.iter() {
        index.index_document(id.to_string(), text).unwrap();
    }

    assert_eq!(index.doc_count(), 3);
    // After stop-word removal the documents hold 3, 2 and 1 tokens:
    // (3 + 2 + 1) / 3 = 2.0
    let avg = index.avg_doc_length();
    assert!((avg - 2.0).abs() < f64::EPSILON);
}
/// A query term shared by several documents returns every one of them and
/// nothing else.
#[test]
fn test_search_multiple_matches() {
    let mut index = Bm25Index::default();
    index
        .index_document("doc1".to_string(), "the quick brown fox")
        .unwrap();
    index
        .index_document("doc2".to_string(), "the lazy dog")
        .unwrap();
    index
        .index_document("doc3".to_string(), "the cat and the dog")
        .unwrap();

    let results = index.search("the dog", 10);

    // "the" is a stop word and is never indexed, so the query effectively
    // reduces to "dog". doc2 and doc3 both contain "dog"; doc1 does not.
    let contains = |wanted: &str| results.iter().any(|(id, _)| id.as_ref() == wanted);

    assert!(contains("doc2"), "doc2 contains 'dog' and must be returned");
    assert!(contains("doc3"), "doc3 contains 'dog' and must be returned");
    assert!(
        !contains("doc1"),
        "doc1 shares only the stop word 'the' and must not be returned"
    );
}
/// `clear` drops every document: the count resets to zero and searches
/// that previously matched come back empty.
#[test]
fn test_clear() {
    let mut index = Bm25Index::default();
    let corpus = [("doc1", "the quick brown fox"), ("doc2", "the lazy dog")];
    for &(id, text) in corpus.iter() {
        index.index_document(id.to_string(), text).unwrap();
    }

    index.clear();

    assert_eq!(index.doc_count(), 0);
    let hits = index.search("fox", 10);
    assert!(hits.is_empty());
}
/// A matching document yields a strictly positive, finite score once the
/// `DeterministicScore` is converted back to `f64`.
#[test]
fn test_deterministic_score_output() {
    let mut index = Bm25Index::default();
    index
        .index_document("doc1".to_string(), "test document")
        .unwrap();

    let results = index.search("test", 10);
    assert_eq!(results.len(), 1);

    // The score is a DeterministicScore (fixed-point i64; no NaN concept),
    // so the f64 view must be an ordinary positive number.
    let as_float = results[0].1.to_f64();
    assert!(as_float > 0.0);
    assert!(as_float.is_finite());
}
/// Running the same query twice against an unchanged index returns the same
/// documents with the same scores, in the same order.
#[test]
fn test_consistent_ordering() {
    let mut index = Bm25Index::default();
    let corpus = [
        ("doc1", "fox quick"),
        ("doc2", "fox slow"),
        ("doc3", "quick quick fox"),
    ];
    for &(id, text) in corpus.iter() {
        index.index_document(id.to_string(), text).unwrap();
    }

    let first_run = index.search("quick fox", 10);
    let second_run = index.search("quick fox", 10);

    // Lengths are checked first so the pairwise comparison below covers
    // every entry of both runs.
    assert_eq!(first_run.len(), second_run.len());
    for (a, b) in first_run.iter().zip(second_run.iter()) {
        assert_eq!(a.0, b.0);
        assert_eq!(a.1, b.1);
    }
}
serde_json::to_string(&index).unwrap(); + + // Deserialize + let restored: Bm25Index = serde_json::from_str(&json).unwrap(); + + // Should work the same + assert_eq!(restored.doc_count(), 2); + let results = restored.search("fox", 10); + assert_eq!(results.len(), 1); + assert_eq!(&*results[0].0, "doc1"); + } + + #[test] + fn test_custom_tokenizer() { + // Create a custom tokenizer with minimum length 4 + let tokenizer: BoxedTokenizer = Arc::new(SimpleTokenizer::new(true, 4)); + let mut index = Bm25Index::with_tokenizer(Bm25Config::default(), tokenizer); + + // "the", "a" will be filtered out (< 4 chars) + index + .index_document("doc1".to_string(), "the quick brown fox") + .unwrap(); + index + .index_document("doc2".to_string(), "a lazy brown dog") + .unwrap(); + + // "the" and "a" should not be indexed + let results = index.search("the", 10); + assert!(results.is_empty(), "Short words should not be indexed"); + + // "quick" and "brown" should be indexed + let results = index.search("quick", 10); + assert_eq!(results.len(), 1); + assert_eq!(&*results[0].0, "doc1"); + + // "brown" in both docs + let results = index.search("brown", 10); + assert_eq!(results.len(), 2); + } + + #[test] + fn test_tokenizer_accessor() { + let index = Bm25Index::default(); + let tokenizer = index.tokenizer(); + + // Should tokenize correctly + let tokens = tokenizer.tokenize("Hello, World!"); + assert_eq!(tokens, vec!["hello", "world"]); + } + + #[test] + fn test_set_tokenizer() { + let mut index = Bm25Index::default(); + + // Index with default tokenizer (min_length=1, stop words on) + // Use "ox" — not a stop word, not filtered by default min_length=1 + index + .index_document("doc1".to_string(), "ox quick fox") + .unwrap(); + let results = index.search("ox", 10); + assert_eq!(results.len(), 1, "Default tokenizer should index 'ox'"); + + // Change tokenizer to min_length=3 (this won't re-index existing docs) + let new_tokenizer: BoxedTokenizer = Arc::new(SimpleTokenizer::new(true, 3)); 
+ index.set_tokenizer(new_tokenizer); + + // New document with new tokenizer + index + .index_document("doc2".to_string(), "ox slow fox") + .unwrap(); + + // doc1 still has "ox" indexed, but search tokenizer now filters "ox" (len < 3) + // Since query "ox" becomes empty after tokenization, no results + let results = index.search("ox", 10); + assert!( + results.is_empty(), + "Query 'ox' should be filtered by min_length=3" + ); + + // "fox" should find both docs + let results = index.search("fox", 10); + assert_eq!(results.len(), 2); + } + + #[test] + fn test_concurrent_search() { + use std::thread; + + let mut index = Bm25Index::default(); + index + .index_document("doc1".to_string(), "the quick brown fox") + .unwrap(); + index + .index_document("doc2".to_string(), "the lazy dog") + .unwrap(); + index + .index_document("doc3".to_string(), "quick fox jumps") + .unwrap(); + + // Wrap in Arc for sharing across threads (search takes &self now) + let index = Arc::new(index); + + // Spawn multiple threads doing concurrent searches + let handles: Vec<_> = (0..4) + .map(|i| { + let index = Arc::clone(&index); + thread::spawn(move || { + // Each thread does multiple searches + for _ in 0..100 { + let query = if i % 2 == 0 { "quick fox" } else { "lazy dog" }; + let results = index.search(query, 10); + assert!(!results.is_empty()); + } + }) + }) + .collect(); + + // Wait for all threads to complete + for handle in handles { + handle.join().expect("Thread panicked"); + } + } +} + +/// Golden tests for BM25 scoring (RETRIEVAL-04). +/// +/// These tests verify known expected values to detect drift in scoring behavior +/// across versions or platforms. The expected values were computed with the +/// standard BM25 formula (k1=1.2, b=0.75) and verified manually. +/// +/// # Cross-Platform CI Note +/// +/// These tests should run on all CI platforms (Linux, macOS, Windows) to verify +/// consistent scoring. 
The tolerance (1e-6) accounts for minor FP differences +/// while still catching significant regressions. +/// +/// If these tests fail on a specific platform, investigate: +/// 1. FMA instruction availability differences +/// 2. Compiler optimization flags +/// 3. Extended precision (x87) on older x86 +#[cfg(test)] +mod golden_tests { + use crate::{Bm25Config, Bm25Index}; + + /// Tolerance for floating-point comparison in golden tests. + /// 1e-6 is tight enough to catch bugs but loose enough for cross-platform variance. + const GOLDEN_TOLERANCE: f64 = 1e-6; + + /// Golden test corpus for reproducible scoring. + fn setup_golden_corpus() -> Bm25Index { + let mut index = Bm25Index::new(Bm25Config::default()); + // Fixed corpus with known characteristics: + // doc1: 4 tokens (quick, brown, fox, jumps) + // doc2: 3 tokens (lazy, brown, dog) + // doc3: 2 tokens (quick, fox) + // Total: 9 tokens, avgdl = 3.0 + index + .index_document("doc1".to_string(), "quick brown fox jumps") + .unwrap(); + index + .index_document("doc2".to_string(), "lazy brown dog") + .unwrap(); + index + .index_document("doc3".to_string(), "quick fox") + .unwrap(); + index + } + + #[test] + fn golden_single_term_query() { + let index = setup_golden_corpus(); + + // Query for "brown" (appears in doc1 and doc2) + // IDF("brown") = ln((3 - 2 + 0.5) / (2 + 0.5) + 1) = ln(1.6) ≈ 0.470003629 + let results = index.search("brown", 10); + + assert_eq!(results.len(), 2); + + // Both docs contain "brown" once, but doc2 is shorter (3 tokens vs 4) + // so doc2 should score slightly higher with length normalization + let doc1_score = results + .iter() + .find(|(id, _)| id.as_ref() == "doc1") + .unwrap() + .1; + let doc2_score = results + .iter() + .find(|(id, _)| id.as_ref() == "doc2") + .unwrap() + .1; + + // Golden values (empirically verified from implementation with k1=1.2, b=0.75, avgdl=3.0) + // These values are the actual outputs and serve as regression tests. 
+ // doc1: len=4, higher length penalty + // doc2: len=3 (at avgdl), no length adjustment + assert!( + (doc1_score.to_f64() - 0.4136031938251108).abs() < GOLDEN_TOLERANCE, + "doc1 score {} differs from golden 0.4136031938251108", + doc1_score.to_f64() + ); + assert!( + (doc2_score.to_f64() - 0.47000362924573563).abs() < GOLDEN_TOLERANCE, + "doc2 score {} differs from golden 0.47000362924573563", + doc2_score.to_f64() + ); + } + + #[test] + fn golden_multi_term_query() { + let index = setup_golden_corpus(); + + // Query for "quick fox" (doc1 has both, doc3 has both) + let results = index.search("quick fox", 10); + + assert_eq!(results.len(), 2); + + let doc1_score = results + .iter() + .find(|(id, _)| id.as_ref() == "doc1") + .unwrap() + .1; + let doc3_score = results + .iter() + .find(|(id, _)| id.as_ref() == "doc3") + .unwrap() + .1; + + // doc3 is shorter (2 tokens) and has both terms -> should score higher + assert!(doc3_score > doc1_score); + + // Golden values for multi-term query (empirically verified) + // "quick": df=2, "fox": df=2 + // doc3 (len=2): shorter doc gets boost from length normalization + assert!( + (doc3_score.to_f64() - 1.088429457275197).abs() < GOLDEN_TOLERANCE, + "doc3 score {} differs from golden 1.088429457275197", + doc3_score.to_f64() + ); + } + + #[test] + fn golden_rare_term_high_idf() { + let index = setup_golden_corpus(); + + // "jumps" only in doc1 (df=1), "lazy" only in doc2 (df=1) + // Both have high IDF = ln((3-1+0.5)/(1+0.5)+1) = ln(2.667) ≈ 0.981 + let results = index.search("jumps", 10); + + assert_eq!(results.len(), 1); + assert_eq!(&*results[0].0, "doc1"); + + // Golden value for rare term (empirically verified) + // "jumps" has high IDF due to appearing in only 1 document + // doc1 has length penalty (len=4, avgdl=3) + assert!( + (results[0].1.to_f64() - 0.8631297426763922).abs() < GOLDEN_TOLERANCE, + "rare term score {} differs from golden 0.8631297426763922", + results[0].1.to_f64() + ); + } + + #[test] + fn 
golden_term_frequency_saturation() { + // Test that repeated terms show saturation (TF component approaches k1+1=2.2) + let mut index = Bm25Index::new(Bm25Config::default()); + + // doc1 has "test" once, doc2 has it 5 times + index.index_document("doc1".to_string(), "test").unwrap(); + index + .index_document("doc2".to_string(), "test test test test test") + .unwrap(); + + let results = index.search("test", 10); + assert_eq!(results.len(), 2); + + let doc1_score = results + .iter() + .find(|(id, _)| id.as_ref() == "doc1") + .unwrap() + .1; + let doc2_score = results + .iter() + .find(|(id, _)| id.as_ref() == "doc2") + .unwrap() + .1; + + // doc2 has higher TF but saturation limits the boost + // The score ratio should be much less than 5x + let ratio = doc2_score.to_f64() / doc1_score.to_f64(); + assert!( + ratio < 2.5, + "TF saturation not working: ratio {ratio} should be < 2.5" + ); + assert!( + ratio > 1.0, + "Higher TF should still score higher: ratio {ratio}" + ); + + // Golden: with avgdl=3, k1=1.2, b=0.75: + // doc1 (tf=1, len=1, L=0.333): denom=1+1.2*(0.25+0.75*0.333)=1.6, TF=2.2/1.6=1.375 + // doc2 (tf=5, len=5, L=1.667): denom=5+1.2*(0.25+0.75*1.667)=6.8, TF=11/6.8=1.618 + // Score ratio ≈ 1.618/1.375 ≈ 1.177 + assert!( + (ratio - 1.17682).abs() < 0.01, + "TF saturation ratio {ratio} differs from golden 1.177" + ); + } + + #[test] + fn golden_length_normalization() { + // Test length normalization with same term frequency + let mut index = Bm25Index::new(Bm25Config::default()); + + // Both have "test" once, but different lengths + index.index_document("short".to_string(), "test").unwrap(); + index + .index_document("long".to_string(), "test padding padding padding padding") + .unwrap(); + + let results = index.search("test", 10); + assert_eq!(results.len(), 2); + + let short_score = results + .iter() + .find(|(id, _)| id.as_ref() == "short") + .unwrap() + .1; + let long_score = results + .iter() + .find(|(id, _)| id.as_ref() == "long") + .unwrap() + .1; + 
+ // Shorter doc should score higher (b=0.75 applies length penalty) + assert!( + short_score > long_score, + "Short doc should score higher than long doc" + ); + + // Golden: avgdl=3, k1=1.2, b=0.75 + // short (len=1, L=0.333): denom=1+1.2*(0.25+0.25)=1.6, TF=2.2/1.6=1.375 + // long (len=5, L=1.667): denom=1+1.2*(0.25+1.25)=2.8, TF=2.2/2.8=0.786 + let ratio = short_score.to_f64() / long_score.to_f64(); + assert!( + (ratio - 1.75).abs() < 0.1, + "Length normalization ratio {ratio} differs from expected ~1.75" + ); + } + + #[test] + fn golden_deterministic_across_runs() { + // Verify that multiple searches produce identical results + let index = setup_golden_corpus(); + + let results1 = index.search("quick brown", 10); + let results2 = index.search("quick brown", 10); + let results3 = index.search("quick brown", 10); + + assert_eq!(results1.len(), results2.len()); + assert_eq!(results2.len(), results3.len()); + + for i in 0..results1.len() { + assert_eq!( + results1[i].0, results2[i].0, + "Doc ID mismatch at position {i}" + ); + assert_eq!( + results1[i].1, results2[i].1, + "Score mismatch at position {i}" + ); + assert_eq!( + results2[i].1, results3[i].1, + "Score mismatch at position {i}" + ); + } + } +} + +/// Memory budget enforcement tests for BM25. 
+#[cfg(test)] +mod memory_budget_tests { + use crate::{Bm25Config, Bm25Index}; + use crate::error::{ErrorKind, RetrievalError}; + + #[test] + fn test_no_budget_allows_unlimited_indexing() { + let mut index = Bm25Index::default(); + for i in 0..100 { + index + .index_document(format!("doc{i}"), &format!("content words number {i}")) + .expect("index should succeed without budget"); + } + assert_eq!(index.doc_count(), 100); + } + + #[test] + fn test_budget_blocks_new_document_when_exceeded() { + let config = Bm25Config::default().with_memory_budget(1_100); + let mut index = Bm25Index::new(config); + + // First doc should succeed (index starts empty) + index + .index_document("doc1", "hello world") + .expect("first doc should succeed"); + + // Keep indexing until budget is hit + let mut rejected = false; + for i in 2..=200 { + let result = index.index_document( + format!("doc{i}"), + &format!("some content words for document number {i} with extra text"), + ); + if result.is_err() { + rejected = true; + let err = result.unwrap_err(); + assert!( + matches!(err, RetrievalError::BudgetExceeded { .. 
}), + "Expected BudgetExceeded, got: {err:?}" + ); + assert_eq!(err.kind(), ErrorKind::Permanent); + assert!(!err.is_retryable()); + break; + } + } + assert!( + rejected, + "Budget should have rejected an index_document call" + ); + } + + #[test] + fn test_budget_reindex_bypasses_check() { + let config = Bm25Config::default().with_memory_budget(2_000); + let mut index = Bm25Index::new(config); + + // Index initial doc + index + .index_document("doc1", "initial content") + .expect("first doc"); + + // Fill until budget hit + for i in 2..=500 { + if index + .index_document(format!("doc{i}"), &format!("fill content {i}")) + .is_err() + { + break; + } + } + + // Re-indexing an existing document should bypass the budget + index + .index_document("doc1", "updated content with more words") + .expect("re-index should bypass budget"); + } + + #[test] + fn test_memory_usage_increases_with_documents() { + let mut index = Bm25Index::default(); + + let before = index.memory_usage(); + // Empty index has fixed overhead only + assert!(before >= 128, "Empty index should have fixed overhead"); + + index.index_document("doc1", "hello world").unwrap(); + let after_one = index.memory_usage(); + assert!(after_one > before, "Usage should increase after indexing"); + + index + .index_document("doc2", "another document here") + .unwrap(); + let after_two = index.memory_usage(); + assert!( + after_two > after_one, + "Usage should increase with more docs" + ); + } + + #[test] + fn test_estimate_document_cost_is_positive() { + let index = Bm25Index::default(); + let cost = index.estimate_document_cost("some test document with words"); + assert!(cost > 0, "Document cost should be positive"); + } + + #[test] + fn test_estimate_document_cost_empty_text() { + let index = Bm25Index::default(); + let cost = index.estimate_document_cost(""); + assert_eq!(cost, 0, "Empty document should have zero cost"); + } + + #[test] + fn test_memory_budget_getter_setter() { + let mut index = 
Bm25Index::default(); + + // Default: no budget + assert_eq!(index.memory_budget(), None); + + // Set budget at runtime + index.set_memory_budget(Some(50_000)); + assert_eq!(index.memory_budget(), Some(50_000)); + + // Clear budget + index.set_memory_budget(None); + assert_eq!(index.memory_budget(), None); + } + + #[test] + fn test_budget_from_config() { + let config = Bm25Config::default().with_memory_budget(10_000); + let index = Bm25Index::new(config); + assert_eq!(index.memory_budget(), Some(10_000)); + } + + #[test] + fn test_budget_exceeded_error_details() { + let config = Bm25Config::default().with_memory_budget(1); + let mut index = Bm25Index::new(config); + + // Budget of 1 byte is too small for any document + let result = index.index_document("doc1", "hello world"); + assert!(result.is_err()); + + let err = result.unwrap_err(); + match err { + RetrievalError::BudgetExceeded { + current_usage, + item_size, + limit, + } => { + assert!(item_size > 0, "Item should have non-zero cost"); + assert_eq!(limit, 1, "Limit should match config"); + assert!(current_usage + item_size > limit, "Should genuinely exceed"); + } + other => panic!("Expected BudgetExceeded, got: {other:?}"), + } + } + + #[test] + fn test_search_unaffected_by_budget() { + let config = Bm25Config::default().with_memory_budget(100_000); + let mut index = Bm25Index::new(config); + + index.index_document("doc1", "quick brown fox").unwrap(); + index.index_document("doc2", "lazy brown dog").unwrap(); + + // Search should work regardless of budget + let results = index.search("brown", 10); + assert_eq!(results.len(), 2); + } + + #[test] + fn test_budget_allows_removal_then_insert() { + let config = Bm25Config::default().with_memory_budget(3_000); + let mut index = Bm25Index::new(config); + + // Fill the index + let mut last_success = 0; + for i in 1..=500 { + if index + .index_document(format!("doc{i}"), &format!("content {i}")) + .is_ok() + { + last_success = i; + } else { + break; + } + } + 
assert!(last_success > 0, "Should have indexed at least one doc"); + + // Remove some documents to free memory + for i in 1..=(last_success / 2) { + index.remove_document(&format!("doc{i}")); + } + + // Now we should be able to insert again + let result = index.index_document("new_doc", "newly inserted content"); + assert!( + result.is_ok(), + "Should be able to insert after removing docs" + ); + } +} + +#[cfg(test)] +mod metrics_tests { + use crate::{Bm25Config, Bm25Index}; + use crate::metrics::{names, MetricValue, RecordingSink}; + use std::sync::Arc; + + #[test] + fn index_document_emits_metrics() { + let sink = Arc::new(RecordingSink::new()); + let mut index = Bm25Index::new(Bm25Config::default()).with_metrics(sink.clone()); + + index.index_document("doc1", "the quick brown fox").unwrap(); + + let events = sink.events(); + let event_names: Vec<&str> = events.iter().map(|e| e.name).collect(); + + assert!( + event_names.contains(&names::BM25_INDEX_DURATION_MS), + "Missing index_document duration metric" + ); + assert!( + event_names.contains(&names::BM25_INDEX_COUNT), + "Missing index_document count metric" + ); + assert!( + event_names.contains(&names::BM25_INDEX_SIZE), + "Missing index size metric" + ); + + // Index size should be 1 + let size_event = events + .iter() + .find(|e| e.name == names::BM25_INDEX_SIZE) + .unwrap(); + assert_eq!(size_event.value, MetricValue::Gauge(1.0)); + } + + #[test] + fn search_emits_metrics() { + let sink = Arc::new(RecordingSink::new()); + let mut index = Bm25Index::new(Bm25Config::default()).with_metrics(sink.clone()); + + index.index_document("doc1", "the quick brown fox").unwrap(); + index.index_document("doc2", "the lazy dog").unwrap(); + + // Clear indexing metrics + sink.clear(); + + let results = index.search("quick fox", 10); + + let events = sink.events(); + let event_names: Vec<&str> = events.iter().map(|e| e.name).collect(); + + assert!( + event_names.contains(&names::BM25_SEARCH_DURATION_MS), + "Missing search 
duration metric" + ); + assert!( + event_names.contains(&names::BM25_SEARCH_COUNT), + "Missing search count metric" + ); + assert!( + event_names.contains(&names::BM25_SEARCH_RESULTS), + "Missing search results metric" + ); + + // Results count should match + let results_event = events + .iter() + .find(|e| e.name == names::BM25_SEARCH_RESULTS) + .unwrap(); + assert_eq!( + results_event.value, + MetricValue::Gauge(results.len() as f64) + ); + } + + #[test] + fn no_metrics_without_sink() { + // Ensure no panic when metrics is None (default) + let mut index = Bm25Index::new(Bm25Config::default()); + index.index_document("doc1", "hello world").unwrap(); + let _ = index.search("hello", 5); + } + + #[test] + fn set_metrics_at_runtime() { + let mut index = Bm25Index::new(Bm25Config::default()); + index.index_document("doc1", "hello world").unwrap(); + + // Attach sink + let sink = Arc::new(RecordingSink::new()); + index.set_metrics(Some(sink.clone())); + + index.index_document("doc2", "goodbye world").unwrap(); + + assert!(!sink.is_empty()); + + // Detach + index.set_metrics(None); + sink.clear(); + + index.index_document("doc3", "another document").unwrap(); + assert!(sink.is_empty(), "No events after detaching sink"); + } + + #[test] + fn search_on_empty_index_still_emits() { + let sink = Arc::new(RecordingSink::new()); + let index = Bm25Index::new(Bm25Config::default()).with_metrics(sink.clone()); + + let results = index.search("anything", 5); + assert!(results.is_empty()); + + // Should still emit duration/count/results + let events = sink.events(); + let event_names: Vec<&str> = events.iter().map(|e| e.name).collect(); + assert!(event_names.contains(&names::BM25_SEARCH_DURATION_MS)); + assert!(event_names.contains(&names::BM25_SEARCH_COUNT)); + assert!(event_names.contains(&names::BM25_SEARCH_RESULTS)); + } + + #[test] + fn multiple_operations_accumulate_events() { + let sink = Arc::new(RecordingSink::new()); + let mut index = 
Bm25Index::new(Bm25Config::default()).with_metrics(sink.clone());
+
+        // 3 index operations
+        index.index_document("d1", "alpha beta").unwrap();
+        index.index_document("d2", "gamma delta").unwrap();
+        index.index_document("d3", "epsilon zeta").unwrap();
+
+        // Count index_document.count events
+        let count_events: usize = sink
+            .events()
+            .iter()
+            .filter(|e| e.name == names::BM25_INDEX_COUNT)
+            .count();
+        assert_eq!(count_events, 3, "Expected 3 index count events");
+    }
+
+    #[test]
+    fn index_duration_is_nonnegative() {
+        let sink = Arc::new(RecordingSink::new());
+        let mut index = Bm25Index::new(Bm25Config::default()).with_metrics(sink.clone());
+
+        index
+            .index_document("doc1", "test document content")
+            .unwrap();
+
+        let duration_event = sink
+            .events()
+            .into_iter()
+            .find(|e| e.name == names::BM25_INDEX_DURATION_MS)
+            .unwrap();
+
+        match duration_event.value {
+            MetricValue::Histogram(ms) => assert!(ms >= 0.0, "Duration must be >= 0"),
+            other => panic!("Expected Histogram, got {other:?}"),
+        }
+    }
+}
diff --git a/crates/khive-bm25/src/tokenizer.rs b/crates/khive-bm25/src/tokenizer.rs
new file mode 100644
index 00000000..653bd446
--- /dev/null
+++ b/crates/khive-bm25/src/tokenizer.rs
@@ -0,0 +1,261 @@
+//! Tokenization for BM25.
+//!
+//! Provides a pluggable tokenizer system with a simple default implementation.
+//!
+//! # RETRIEVAL-10: Advanced Tokenization (Deferred)
+//!
+//! The following advanced tokenization features are **intentionally deferred**
+//! to future iterations:
+//!
+//! | Feature | Status | Rationale |
+//! |---------|--------|-----------|
+//! | CJK segmentation | Deferred | Requires jieba/mecab integration |
+//! | Arabic normalization | Deferred | Requires ICU or custom rules |
+//! | Stemming | Deferred | Language-specific (Snowball, Porter) |
+//! | Lemmatization | Deferred | Requires NLP models |
+//! | Stop word removal | Partial | English list built in; other languages deferred |
+//! | N-gram support | Deferred | Memory/performance tradeoffs |
+//!
+//! **Current scope**: English whitespace tokenization with optional lowercasing, minimum
+//! length filtering, and English stop-word removal. This covers the primary use case.
+//!
+//! **Extension point**: Implement the [`Tokenizer`] trait for custom tokenization.
+//! The trait is designed to be language-agnostic and composable.
+//!
+//! # Examples
+//!
+//! Using the default SimpleTokenizer:
+//! ```rust
+//! use khive_bm25::{Tokenizer, SimpleTokenizer};
+//!
+//! let tokenizer = SimpleTokenizer::default();
+//! let tokens = tokenizer.tokenize("Hello, World!");
+//! assert_eq!(tokens, vec!["hello", "world"]);
+//! ```
+//!
+//! Custom tokenizer with minimum length:
+//! ```rust
+//! use khive_bm25::{Tokenizer, SimpleTokenizer};
+//!
+//! let tokenizer = SimpleTokenizer::new(true, 3);
+//! let tokens = tokenizer.tokenize("I am a cat");
+//! assert_eq!(tokens, vec!["cat"]); // "I", "am", "a" filtered out (< 3 chars)
+//! ```
+
+use std::collections::HashSet;
+use std::sync::{Arc, LazyLock};
+
+/// Tokenizer trait for extensible text tokenization.
+///
+/// Implement this trait to provide custom tokenization for BM25 search.
+/// This enables:
+/// - Language-specific tokenization (CJK, Arabic, etc.)
+/// - Stemming/lemmatization
+/// - Stop word removal
+/// - N-gram support
+pub trait Tokenizer: Send + Sync {
+    /// Tokenize the input text into a list of tokens.
+    ///
+    /// # Arguments
+    ///
+    /// * `text` - Input text to tokenize
+    ///
+    /// # Returns
+    ///
+    /// Vector of tokens (strings). Empty vector for empty input.
+    fn tokenize(&self, text: &str) -> Vec<String>;
+}
+
+/// Shared, dynamically dispatched tokenizer handle (an `Arc` around a [`Tokenizer`]).
+pub type BoxedTokenizer = Arc<dyn Tokenizer>;
+
+/// English stop words — high-frequency terms that add noise to BM25 postings
+/// without improving retrieval quality. Removing these reduces BM25 memory by
+/// ~170 MB at 15K docs (each stop word creates N postings × 64 bytes).
+static STOP_WORDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
+    HashSet::from([
+        "a", "an", "and", "are", "as", "at", "be", "been", "being", "but", "by", "can", "did",
+        "do", "does", "doing", "done", "for", "from", "had", "has", "have", "having", "he", "her",
+        "here", "hers", "him", "his", "how", "i", "if", "in", "into", "is", "it", "its", "just",
+        "may", "me", "might", "my", "no", "nor", "not", "of", "on", "or", "our", "out", "own",
+        "say", "she", "should", "so", "some", "such", "than", "that", "the", "their", "them",
+        "then", "there", "these", "they", "this", "those", "through", "to", "too", "up", "us",
+        "very", "was", "we", "were", "what", "when", "where", "which", "while", "who", "whom",
+        "why", "will", "with", "would", "you", "your",
+    ])
+});
+
+/// Simple whitespace tokenizer with optional lowercase, minimum length,
+/// and stop-word filtering.
+///
+/// This is the default tokenizer suitable for English text.
+/// For production use with non-English text, consider implementing
+/// a custom tokenizer with proper segmentation for your language.
+#[derive(Debug, Clone)]
+pub struct SimpleTokenizer {
+    /// Whether to lowercase tokens.
+    pub lowercase: bool,
+    /// Minimum token length in characters (shorter tokens are filtered out).
+    pub min_length: usize,
+    /// Whether to filter out English stop words.
+    pub filter_stop_words: bool,
+}
+
+impl Default for SimpleTokenizer {
+    fn default() -> Self {
+        Self {
+            lowercase: true,
+            min_length: 1,
+            filter_stop_words: true,
+        }
+    }
+}
+
+impl SimpleTokenizer {
+    /// Create a new SimpleTokenizer with specified options (stop-word filtering enabled).
+    ///
+    /// # Arguments
+    ///
+    /// * `lowercase` - Whether to convert tokens to lowercase
+    /// * `min_length` - Minimum token length in characters (shorter tokens are filtered out)
+    pub fn new(lowercase: bool, min_length: usize) -> Self {
+        Self {
+            lowercase,
+            min_length,
+            filter_stop_words: true,
+        }
+    }
+}
+
+impl Tokenizer for SimpleTokenizer {
+    fn tokenize(&self, text: &str) -> Vec<String> {
+        // Estimate capacity (capped at 32) to avoid re-allocations for typical inputs.
+        // Average English word ~5 chars + 1 space, so text.len()/6 is a reasonable estimate.
+        let estimated_tokens = text.len() / 6 + 1;
+        let mut result = Vec::with_capacity(estimated_tokens.min(32));
+
+        for word in text.split_whitespace() {
+            // Remove leading/trailing ASCII punctuation (interior punctuation is kept)
+            let trimmed = word.trim_matches(|c: char| c.is_ascii_punctuation());
+
+            if trimmed.chars().count() < self.min_length {
+                continue;
+            }
+
+            // ASCII fast path: if all bytes are ASCII, lowercase byte-by-byte
+            // to avoid the overhead of `str::to_lowercase()` (which handles Unicode).
+            let token = if self.lowercase {
+                if trimmed.is_ascii() {
+                    // Fast path: ASCII-only, lowercase via byte manipulation
+                    let mut s = String::with_capacity(trimmed.len());
+                    for &byte in trimmed.as_bytes() {
+                        s.push(byte.to_ascii_lowercase() as char);
+                    }
+                    s
+                } else {
+                    trimmed.to_lowercase()
+                }
+            } else {
+                trimmed.to_string()
+            };
+
+            if self.filter_stop_words && STOP_WORDS.contains(token.as_str()) {
+                continue;
+            }
+
+            result.push(token);
+        }
+
+        result
+    }
+}
+
+/// Convenience function for simple tokenization (backwards compatibility).
+///
+/// Uses the default SimpleTokenizer configuration.
+pub fn tokenize(text: &str) -> Vec<String> {
+    SimpleTokenizer::default().tokenize(text)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tokenize_filters_stop_words() {
+        let tokens = tokenize("The Quick, Brown FOX!");
+        // "the" is a stop word, filtered out
+        assert_eq!(tokens, vec!["quick", "brown", "fox"]);
+    }
+
+    #[test]
+    fn test_tokenize_empty() {
+        let tokens = tokenize("");
+        assert!(tokens.is_empty());
+
+        let tokens = tokenize(" ");
+        assert!(tokens.is_empty());
+    }
+
+    #[test]
+    fn test_tokenize_punctuation_only() {
+        let tokens = tokenize("... !!! ???");
+        assert!(tokens.is_empty());
+    }
+
+    #[test]
+    fn test_tokenize_case_insensitive() {
+        let tokens = tokenize("HELLO World hElLo");
+        assert_eq!(tokens, vec!["hello", "world", "hello"]);
+    }
+
+    #[test]
+    fn test_tokenize_stop_words_removed() {
+        // "how", "are", "you" are stop words
+        let tokens = tokenize("Hello, World! How are you?");
+        assert_eq!(tokens, vec!["hello", "world"]);
+    }
+
+    #[test]
+    fn test_tokenize_multiple_spaces() {
+        let tokens = tokenize("hello    world");
+        assert_eq!(tokens, vec!["hello", "world"]);
+    }
+
+    #[test]
+    fn test_simple_tokenizer_no_lowercase() {
+        let tokenizer = SimpleTokenizer::new(false, 1);
+        // "Hello" and "World" are not stop words (case-sensitive, and stop words are lowercase)
+        let tokens = tokenizer.tokenize("Hello World");
+        assert_eq!(tokens, vec!["Hello", "World"]);
+    }
+
+    #[test]
+    fn test_simple_tokenizer_min_length() {
+        let tokenizer = SimpleTokenizer::new(true, 3);
+        let tokens = tokenizer.tokenize("I am a cat");
+        // "I", "am", "a" filtered by min_length; also stop words
+        assert_eq!(tokens, vec!["cat"]);
+    }
+
+    #[test]
+    fn test_trait_object_usage() {
+        let tokenizer: BoxedTokenizer = Arc::new(SimpleTokenizer::default());
+        let tokens = tokenizer.tokenize("hello world");
+        assert_eq!(tokens, vec!["hello", "world"]);
+    }
+
+    #[test]
+    fn test_stop_words_disabled() {
+        let mut tokenizer = SimpleTokenizer::default();
+        tokenizer.filter_stop_words = false;
+        let tokens = tokenizer.tokenize("The Quick, Brown FOX!");
+        assert_eq!(tokens, vec!["the", "quick", "brown", "fox"]);
+    }
+
+    #[test]
+    fn test_all_stop_words_returns_empty() {
+        let tokens = tokenize("the and or but");
+        assert!(tokens.is_empty());
+    }
+}