From c135e6678df8589a7daba0289efba3c88f15db95 Mon Sep 17 00:00:00 2001 From: Carson Poole Date: Wed, 22 May 2024 11:52:36 -0400 Subject: [PATCH 1/6] stuff for v0.2 --- Cargo.lock | 172 ++++- Cargo.toml | 7 + src/constants.rs | 8 +- src/errors.rs | 27 + src/lib.rs | 1 + src/main.rs | 12 +- src/main1.rs | 28 +- src/services/commit.rs | 259 ++----- src/services/namespace_state.rs | 34 +- src/services/query.rs | 103 +-- src/structures.rs | 5 +- src/structures/ann_tree.rs | 896 ++++++++++++++++++++++ src/structures/ann_tree/k_modes.rs | 351 +++++++++ src/structures/ann_tree/metadata.rs | 274 +++++++ src/structures/ann_tree/node.rs | 455 +++++++++++ src/structures/ann_tree/serialization.rs | 45 ++ src/structures/ann_tree/storage.rs | 379 +++++++++ src/structures/filters.rs | 660 +++++++++++++--- src/structures/inverted_index.rs | 4 +- src/structures/metadata_index.rs | 233 +++++- src/structures/mmap_tree.rs | 8 +- src/structures/mmap_tree/serialization.rs | 26 + src/structures/mmap_tree/storage.rs | 2 +- src/structures/wal.rs | 393 +++++----- 24 files changed, 3727 insertions(+), 655 deletions(-) create mode 100644 src/errors.rs create mode 100644 src/structures/ann_tree.rs create mode 100644 src/structures/ann_tree/k_modes.rs create mode 100644 src/structures/ann_tree/metadata.rs create mode 100644 src/structures/ann_tree/node.rs create mode 100644 src/structures/ann_tree/serialization.rs create mode 100644 src/structures/ann_tree/storage.rs diff --git a/Cargo.lock b/Cargo.lock index 38ce3b0..526317c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,19 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +39,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anes" version = "0.1.6" @@ -113,6 +141,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + [[package]] name = "block-buffer" version = "0.10.4" @@ -158,6 +192,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + 
"windows-targets 0.52.5", +] + [[package]] name = "ciborium" version = "0.2.2" @@ -216,6 +264,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "cpufeatures" version = "0.2.12" @@ -362,6 +416,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fnv" version = "1.0.7" @@ -490,17 +556,33 @@ name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown", +] [[package]] name = "haystackdb" version = "0.1.0" dependencies = [ + "ahash", + "chrono", "criterion", "env_logger", "fs2", "log", "memmap", + "rand", "rayon", + "rusqlite", "serde", "serde_json", "tokio", @@ -613,6 +695,29 @@ dependencies = [ "want", ] +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "idna" version = "0.5.0" @@ -674,6 +779,16 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libsqlite3-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +dependencies = [ + "pkg-config", + "vcpkg", +] + [[package]] name = "lock_api" version = "0.4.11" @@ -861,6 +976,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + [[package]] name = "plotters" version = "0.3.5" @@ -969,7 +1090,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -1001,6 +1122,20 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +[[package]] +name = "rusqlite" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +dependencies = [ + "bitflags 2.5.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -1360,6 +1495,12 @@ dependencies = [ "serde", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.4" @@ -1515,6 +1656,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.5", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -1653,3 +1803,23 @@ name = "windows_x86_64_msvc" version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 4ab60b7..1d32712 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,10 @@ log = "0.4.14" fs2 = "0.4.0" env_logger = "0.11.3" serde_json = "1.0.68" +ahash = "0.8.11" +rand = "0.8.4" +rusqlite = "0.31.0" +chrono = "0.4" [profile.release] opt-level = 3 @@ -25,6 +29,9 @@ opt-level = 3 [profile.bench] opt-level = 3 +# [profile.dev] +# debug = true + # [[bench]] # name = "hamming_distance" # harness = false diff --git a/src/constants.rs b/src/constants.rs index cef93b1..89968db 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,2 +1,8 @@ pub const VECTOR_SIZE: usize = 1024; -pub const QUANTIZED_VECTOR_SIZE: usize = VECTOR_SIZE / 8; +pub const QUANTIZED_VECTOR_SIZE: usize = 128; +pub const K: usize = 512; +pub const C: usize = 1; +pub const ALPHA: usize = 64; +pub const BETA: usize = 3; +pub const GAMMA: usize = 1; +pub const RHO: usize = 1; diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..0685ea6 --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,27 @@ +use std::error::Error; +use std::fmt; + +#[derive(Debug)] +pub struct HaystackError { + details: String, +} + +impl HaystackError { + pub fn new(msg: &str) -> HaystackError { + HaystackError { + details: msg.to_string(), + } + } +} + +impl fmt::Display for HaystackError { + fn 
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.details) + } +} + +impl Error for HaystackError { + fn description(&self) -> &str { + &self.details + } +} diff --git a/src/lib.rs b/src/lib.rs index c0aaa47..1e80a0c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod constants; +pub mod errors; pub mod math; pub mod services; pub mod structures; diff --git a/src/main.rs b/src/main.rs index 8f44465..f3ee897 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use haystackdb::constants::VECTOR_SIZE; use haystackdb::services::CommitService; use haystackdb::services::QueryService; use haystackdb::structures::filters::Filter as QueryFilter; -use haystackdb::structures::metadata_index::KVPair; +use haystackdb::structures::filters::KVPair; use log::info; use log::LevelFilter; use std::io::Write; @@ -36,7 +36,7 @@ async fn main() { .and(with_active_namespaces(active_namespaces.clone())) .then( |namespace_id: String, body: (Vec, QueryFilter, usize), active_namespaces| async move { - let base_path = PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone())); + let base_path = PathBuf::from(format!("/Users/carsonpoole/haystackdb/tests/data/{}/current", namespace_id.clone())); ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone()) .await; @@ -73,7 +73,7 @@ async fn main() { body: (Vec, Vec, String), active_namespaces| async move { let base_path = PathBuf::from(format!( - "/workspace/data/{}/current", + "/Users/carsonpoole/haystackdb/tests/data/{}/current", namespace_id.clone() )); @@ -107,8 +107,10 @@ async fn main() { .then( |namespace_id: String, timestamp: String, active_namespaces| async move { println!("PITR for namespace: {}", namespace_id); - let base_path = - PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone())); + let base_path = PathBuf::from(format!( + "/Users/carsonpoole/haystackdb/tests/data/{}/current", + namespace_id.clone() + )); ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone()) .await; diff --git a/src/main1.rs b/src/main1.rs index 8334b63..132406b 100644 --- a/src/main1.rs +++ b/src/main1.rs @@ -2,14 +2,18 @@ extern crate haystackdb; use haystackdb::constants::VECTOR_SIZE; use haystackdb::services::commit::CommitService; use haystackdb::services::query::QueryService; -use haystackdb::structures::metadata_index::KVPair; +use haystackdb::structures::filters::{Filter, KVPair, KVValue}; use std::fs; use std::path::PathBuf; use std::str::FromStr; use uuid; fn random_vec() -> [f32; VECTOR_SIZE] { - return [0.0; VECTOR_SIZE]; + let mut vec = [0.0; VECTOR_SIZE]; + for i in 0..VECTOR_SIZE { + vec[i] = rand::random::() * 2.0 - 1.0; + } + vec } fn main() { @@ -35,7 +39,7 @@ fn main() { // .expect("Failed to add to WAL"); // } - const NUM_VECTORS: usize = 100_000; + const NUM_VECTORS: usize = 10_000_000; let batch_vectors: Vec> = (0..NUM_VECTORS).map(|_| vec![random_vec()]).collect(); @@ -43,7 +47,7 @@ fn main() { .map(|_| { vec![vec![KVPair { key: "key".to_string(), - value: "value".to_string(), + value: KVValue::String("value".to_string()), }]] }) .collect(); @@ -71,6 +75,10 @@ fn main() { println!("Commit took: {:?}", start.elapsed()); + commit_service.calibrate(); + + commit_service.state.vectors.summarize_tree(); + let mut query_service = QueryService::new(path.clone(), namespace_id).expect("Failed to create query service"); @@ -81,18 +89,18 @@ fn main() { let start = std::time::Instant::now(); for _ in 0..NUM_RUNS { - let _ = query_service + let 
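The warp handlers above hardcode a user-specific absolute path (/Users/carsonpoole/haystackdb/tests/data) where the previous revision used /workspace/data. A minimal sketch of reading the root from the environment instead, so both setups keep working; the HAYSTACK_DATA_ROOT variable name is an assumption, not something this codebase defines:

    use std::path::PathBuf;

    // Hypothetical helper: resolve the per-namespace data directory from an
    // environment variable, falling back to the previous container path.
    fn namespace_base_path(namespace_id: &str) -> PathBuf {
        let root = std::env::var("HAYSTACK_DATA_ROOT")
            .unwrap_or_else(|_| "/workspace/data".to_string());
        PathBuf::from(root).join(namespace_id).join("current")
    }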
result = query_service .query( - &[0.0; VECTOR_SIZE], - vec![KVPair { - key: "key".to_string(), - value: "value".to_string(), - }], + &random_vec(), + &Filter::Eq("key".to_string(), "value".to_string()), 1, ) .expect("Failed to query"); // println!("{:?}", result); + if result.len() == 0 { + println!("No results found"); + } } println!("Query took: {:?}", start.elapsed().div_f32(NUM_RUNS as f32)); diff --git a/src/services/commit.rs b/src/services/commit.rs index 98c62e1..515fc77 100644 --- a/src/services/commit.rs +++ b/src/services/commit.rs @@ -1,6 +1,8 @@ use crate::constants::VECTOR_SIZE; -use crate::structures::inverted_index::InvertedIndexItem; -use crate::structures::metadata_index::{KVPair, MetadataIndexItem}; +// use crate::structures::inverted_index::InvertedIndexItem; +// use crate::structures::metadata_index::{KVPair, MetadataIndexItem}; +use crate::structures::filters::{KVPair, KVValue}; +use rusqlite::Result; use super::namespace_state::NamespaceState; use std::collections::HashMap; @@ -19,132 +21,63 @@ impl CommitService { Ok(CommitService { state }) } - pub fn commit(&mut self) -> io::Result<()> { + pub fn commit(&mut self) -> Result<()> { let commits = self.state.wal.get_uncommitted(100000)?; - let commits_len = commits.len(); - - if commits.len() == 0 { - return Ok(()); - } - - println!("Commits: {:?}", commits_len); - let mut processed = 0; - let merged_commits = commits - .iter() - .fold((Vec::new(), Vec::new()), |mut items, commit| { - let vectors = commit.vectors.clone(); - let kvs = commit.kvs.clone(); - - items.0.extend(vectors); - items.1.extend(kvs); + println!("Commits to process: {:?}", commits.len()); - items - }); + let mut vectors = Vec::new(); + let mut kvs = Vec::new(); + let mut ids = Vec::new(); - for (vectors, kvs) in vec![merged_commits] { - // let vectors = commit.vectors; - // let kvs = commit.kvs; - - if vectors.len() != kvs.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Quantized vectors length mismatch", - )); - } - - println!( - "Processing commit: {} of {} with vectors of len: {}", - processed, - commits_len, - vectors.len() - ); - - processed += 1; - - // generate u128 ids - - let ids = (0..vectors.len()) + for commit in commits.iter() { + let inner_vectors = commit.vectors.clone(); + let inner_kvs = commit.kvs.clone(); + let inner_ids: Vec = inner_vectors + .iter() .map(|_| uuid::Uuid::new_v4().as_u128()) - .collect::>(); - - println!("Generated ids"); - - let vector_indices = self.state.vectors.batch_push(vectors)?; - - println!("Vector indices: {:?}", vector_indices); - - println!("Pushed vectors"); - - let mut inverted_index_items: HashMap> = HashMap::new(); - - // let mut metadata_index_items = Vec::new(); - - let mut batch_metadata_to_insert = Vec::new(); - - for (idx, kv) in kvs.iter().enumerate() { - let metadata_index_item = MetadataIndexItem { - id: ids[idx], - kvs: kv.clone(), - vector_index: vector_indices[idx], - // namespaced_id: self.state.namespace_id.clone(), - }; - - // println!("Inserting id: {}, {} of {}", ids[idx], idx, ids.len()); - - batch_metadata_to_insert.push((ids[idx], metadata_index_item)); - - // self.state - // .metadata_index - // .insert(ids[idx], metadata_index_item); - - for kv in kv { - // let inverted_index_item = InvertedIndexItem { - // indices: vec![vector_indices[idx]], - // ids: vec![ids[idx]], - // }; - - // self.state - // .inverted_index - // .insert_append(kv.clone(), inverted_index_item); - - inverted_index_items - .entry(kv.clone()) - .or_insert_with(Vec::new) - 
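The rewritten commit() above flattens every uncommitted WAL entry into a single bulk_insert, but the loop that previously called mark_commit_finished is left commented out, so the same entries would be returned by get_uncommitted and replayed on the next commit. If that replay is not intentional, a sketch of restoring the acknowledgment step, using the same call the recovery path below still makes:

    self.state.vectors.bulk_insert(vectors, ids, kvs);

    // Acknowledge processed WAL entries so the next commit() does not replay them.
    for commit in commits {
        self.state.wal.mark_commit_finished(commit.hash)?;
    }

    Ok(())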
.push((vector_indices[idx], ids[idx])); - } + .collect(); + + for ((vector, kv), id) in inner_vectors + .iter() + .zip(inner_kvs.iter()) + .zip(inner_ids.iter()) + { + vectors.push(vector.clone()); + kvs.push(kv.clone()); + ids.push(id.clone()); } + } - self.state - .metadata_index - .batch_insert(batch_metadata_to_insert); + self.state.vectors.bulk_insert(vectors, ids, kvs); - // self.state.metadata_index.batch_insert(metadata_index_items); + // for commit in commits { + // // let vectors = commit.vectors; + // // let kvs = commit.kvs; - for (kv, items) in inverted_index_items { - let inverted_index_item = InvertedIndexItem { - indices: items.iter().map(|(idx, _)| *idx).collect(), - ids: items.iter().map(|(_, id)| *id).collect(), - }; + // // println!("Processing commit: {:?}", processed); + // // processed += 1; - self.state - .inverted_index - .insert_append(kv, inverted_index_item); - } - } + // // for (vector, kv) in vectors.iter().zip(kvs.iter()) { + // // let id = uuid::Uuid::new_v4().as_u128(); - for commit in commits { - self.state.wal.mark_commit_finished(commit.hash)?; - } + // // self.state.vectors.insert(vector.clone(), id, kv.clone()); + // // } + + // self.state.wal.mark_commit_finished(commit.hash)?; + // } Ok(()) } - pub fn recover_point_in_time(&mut self, timestamp: u64) -> io::Result<()> { + pub fn recover_point_in_time(&mut self, timestamp: u64) -> Result<()> { println!("Recovering to timestamp: {}", timestamp); - let versions: Vec = self.state.get_all_versions()?; + let versions: Vec = self + .state + .get_all_versions() + .expect("Failed to get versions"); let max_version = versions.iter().max().unwrap(); let new_version = max_version + 1; @@ -160,7 +93,8 @@ impl CommitService { .join(format!("v{}", new_version)); let mut fresh_state = - NamespaceState::new(new_version_path.clone(), self.state.namespace_id.clone())?; + NamespaceState::new(new_version_path.clone(), self.state.namespace_id.clone()) + .expect("Failed to create fresh state"); let commits = self.state.wal.get_commits_before(timestamp)?; let commits_len = commits.len(); @@ -173,77 +107,25 @@ impl CommitService { let mut processed = 0; - for commit in commits.iter() { - let vectors = commit.vectors.clone(); - let kvs = commit.kvs.clone(); - - if vectors.len() != kvs.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Quantized vectors length mismatch", - )); - } - - println!( - "Processing commit: {} of {} with vectors of len: {}", - processed, - commits_len, - vectors.len() - ); - - processed += 1; - - // generate u128 ids - let ids = (0..vectors.len()) - .map(|_| uuid::Uuid::new_v4().as_u128()) - .collect::>(); - - println!("Generated ids"); - - let vector_indices = fresh_state.vectors.batch_push(vectors)?; - - println!("Pushed vectors"); + // fresh_state.wal.mark_commit_finished(commit.hash)?; - let mut inverted_index_items: HashMap> = HashMap::new(); - - let mut metadata_index_items = Vec::new(); - - for (idx, kv) in kvs.iter().enumerate() { - let metadata_index_item = MetadataIndexItem { - id: ids[idx], - kvs: kv.clone(), - vector_index: vector_indices[idx], - // namespaced_id: self.state.namespace_id.clone(), - }; - - // println!("Inserting id: {}, {} of {}", ids[idx], idx, ids.len()); + for commit in commits { + let vectors = commit.vectors; + let kvs = commit.kvs; - metadata_index_items.push((ids[idx], metadata_index_item)); + for (vector, kv) in vectors.iter().zip(kvs.iter()) { + let id = uuid::Uuid::new_v4().as_u128(); - for kv in kv { - inverted_index_items - .entry(kv.clone()) 
- .or_insert_with(Vec::new) - .push((vector_indices[idx], ids[idx])); - } + fresh_state.vectors.insert(vector.clone(), id, kv.clone()); } - fresh_state - .metadata_index - .batch_insert(metadata_index_items); + fresh_state.wal.mark_commit_finished(commit.hash)?; - for (kv, items) in inverted_index_items { - let inverted_index_item = InvertedIndexItem { - indices: items.iter().map(|(idx, _)| *idx).collect(), - ids: items.iter().map(|(_, id)| *id).collect(), - }; + processed += 1; - fresh_state - .inverted_index - .insert_append(kv, inverted_index_item); + if processed % 1000 == 0 { + println!("Processed: {}/{}", processed, commits_len); } - - fresh_state.wal.mark_commit_finished(commit.hash)?; } // update symlink for /current @@ -251,24 +133,24 @@ impl CommitService { println!("Removing current symlink: {:?}", current_path); - std::fs::remove_file(¤t_path)?; - unix_fs::symlink(&new_version_path, ¤t_path)?; + std::fs::remove_file(¤t_path).expect("Failed to remove current symlink"); + unix_fs::symlink(&new_version_path, ¤t_path).expect("Failed to create symlink"); Ok(()) } + pub fn calibrate(&mut self) { + self.state + .vectors + .true_calibrate() + .expect("Failed to calibrate"); + } + pub fn add_to_wal( &mut self, vectors: Vec<[f32; VECTOR_SIZE]>, kvs: Vec>, - ) -> io::Result<()> { - if vectors.len() != vectors.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Quantized vectors length mismatch", - )); - } - + ) -> Result<()> { // self.state.wal.commit(hash, quantized_vectors, kvs) self.state .wal @@ -282,14 +164,7 @@ impl CommitService { &mut self, vectors: Vec>, kvs: Vec>>, - ) -> io::Result<()> { - if vectors.len() != kvs.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Quantized vectors length mismatch", - )); - } - + ) -> Result<()> { self.state.wal.batch_add_to_wal(vectors, kvs)?; Ok(()) diff --git a/src/services/namespace_state.rs b/src/services/namespace_state.rs index 22ca8a8..90d1c39 100644 --- a/src/services/namespace_state.rs +++ b/src/services/namespace_state.rs @@ -1,6 +1,8 @@ -use crate::structures::dense_vector_list::DenseVectorList; -use crate::structures::inverted_index::InvertedIndex; -use crate::structures::metadata_index::MetadataIndex; +use crate::structures::ann_tree::ANNTree; +// use crate::structures::dense_vector_list::DenseVectorList; +// use crate::structures::inverted_index::InvertedIndex; +// use crate::structures::metadata_index::MetadataIndex; +use crate::structures::mmap_tree::Tree; use crate::structures::wal::WAL; use std::fs; use std::io; @@ -11,9 +13,10 @@ use super::LockService; pub struct NamespaceState { pub namespace_id: String, - pub metadata_index: MetadataIndex, - pub inverted_index: InvertedIndex, - pub vectors: DenseVectorList, + // pub metadata_index: MetadataIndex, + // pub inverted_index: InvertedIndex, + pub texts: Tree>, + pub vectors: ANNTree, pub wal: WAL, pub locks: LockService, pub path: PathBuf, @@ -69,22 +72,25 @@ impl NamespaceState { let wal_path = path.clone().join("wal"); let locks_path = path.clone().join("locks"); - fs::create_dir_all(&wal_path).expect("Failed to create directory"); + fs::create_dir_all(&wal_path).unwrap_or_default(); - fs::create_dir_all(&locks_path).expect("Failed to create directory"); + fs::create_dir_all(&locks_path).unwrap_or_default(); let vectors_path = path.clone().join("vectors.bin"); + let texts_path = path.clone().join("texts.bin"); - let metadata_index = MetadataIndex::new(metadata_path); - let inverted_index = InvertedIndex::new(inverted_index_path); - let 
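recover_point_in_time repoints the current symlink by deleting it and then recreating it, which leaves a short window in which current does not exist and a concurrent reader would fail. On POSIX filesystems rename(2) replaces its destination atomically, so linking under a temporary name and renaming closes that window. A sketch, assuming the same new_version_path and current_path values as above:

    use std::fs;
    use std::os::unix::fs as unix_fs;
    use std::path::Path;

    // Atomically repoint `current`: readers always see either the old or the
    // new target, never a missing link.
    fn swap_current(new_version: &Path, current: &Path) -> std::io::Result<()> {
        let tmp = current.with_extension("swap");
        let _ = fs::remove_file(&tmp); // ignore a stale temp link left by a crash
        unix_fs::symlink(new_version, &tmp)?;
        fs::rename(&tmp, current)
    }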
wal = WAL::new(wal_path, namespace_id.clone())?; - let vectors = DenseVectorList::new(vectors_path, 100_000)?; + // let metadata_index = MetadataIndex::new(metadata_path); + // let inverted_index = InvertedIndex::new(inverted_index_path); + let wal = WAL::new(wal_path, namespace_id.clone()).expect("Failed to create WAL"); + let vectors = ANNTree::new(vectors_path)?; let locks = LockService::new(locks_path); + let texts = Tree::new(texts_path)?; Ok(NamespaceState { namespace_id, - metadata_index, - inverted_index, + // metadata_index, + // inverted_index, + texts, vectors, wal, locks, diff --git a/src/services/query.rs b/src/services/query.rs index df944e1..90d1553 100644 --- a/src/services/query.rs +++ b/src/services/query.rs @@ -3,8 +3,7 @@ use rayon::prelude::*; use super::namespace_state::NamespaceState; use crate::constants::VECTOR_SIZE; use crate::math::hamming_distance; -use crate::structures::filters::{Filter, Filters}; -use crate::structures::metadata_index::KVPair; +use crate::structures::filters::{Filter, KVPair}; use crate::utils::quantize; use std::io; use std::path::PathBuf; @@ -27,97 +26,21 @@ impl QueryService { ) -> io::Result>> { let quantized_query_vector = quantize(query_vector); - let (indices, ids) = - Filters::evaluate(filters, &mut self.state.inverted_index).get_indices(); - // group contiguous indices to batch get vectors - let mut batch_indices: Vec> = Vec::new(); - - let mut current_batch = Vec::new(); - - for index in indices { - if current_batch.len() == 0 { - current_batch.push(index); - } else { - let last_index = current_batch[current_batch.len() - 1]; - if index == last_index + 1 { - current_batch.push(index); - } else { - batch_indices.push(current_batch); - current_batch = Vec::new(); - current_batch.push(index); - } - } - } - - current_batch.sort(); - current_batch.dedup(); - - if current_batch.len() > 0 { - batch_indices.push(current_batch); - } - - // println!("BATCH INDICES: {:?}", batch_indices.len()); - - let mut top_k_indices = Vec::new(); - - let top_k_to_use = top_k.min(ids.len()); - - for batch in batch_indices { - let vectors = self.state.vectors.get_contiguous(batch[0], batch.len())?; - top_k_indices.extend( - vectors - .par_iter() - .enumerate() - .fold( - || Vec::new(), - |mut acc, (idx, vector)| { - let distance = hamming_distance(&quantized_query_vector, vector); - - if acc.len() < top_k_to_use { - acc.push((ids[idx], distance)); - acc.sort(); - } else { - let worst_best_distance = acc[acc.len() - 1].1; - if distance < worst_best_distance { - acc.pop(); - acc.push((ids[idx], distance)); - acc.sort(); - } - } - - acc - }, - ) - .reduce( - || Vec::new(), // Initializer for the reduce step - |mut a, mut b| { - // How to combine results from different threads - a.append(&mut b); - a.sort_by_key(|&(_, dist)| dist); // Sort by distance - a.truncate(top_k_to_use); // Keep only the top k elements - a - }, - ), - ); - } - - let mut kvs = Vec::new(); + let result = self + .state + .vectors + .search(quantized_query_vector, top_k, filters) + .iter() + .map(|(_, metadata)| { + // let mut metadata = metadata.clone(); + // metadata.push(KVPair::new("id".to_string(), id.to_string())); - for (id, _) in top_k_indices { - let r = self.state.metadata_index.get(id); - match r { - Some(item) => { - kvs.push(item.kvs); - } - None => { - println!("Metadata not found"); - continue; - } - } - } + metadata.clone() + }) + .collect(); - Ok(kvs) + Ok(result) } } diff --git a/src/structures.rs b/src/structures.rs index 066f830..33facb3 100644 --- a/src/structures.rs +++ 
b/src/structures.rs @@ -1,7 +1,8 @@ +pub mod ann_tree; pub mod dense_vector_list; pub mod filters; -pub mod inverted_index; -pub mod metadata_index; +// pub mod inverted_index; +// pub mod metadata_index; pub mod mmap_tree; pub mod tree; pub mod wal; diff --git a/src/structures/ann_tree.rs b/src/structures/ann_tree.rs new file mode 100644 index 0000000..0b76de4 --- /dev/null +++ b/src/structures/ann_tree.rs @@ -0,0 +1,896 @@ +pub mod k_modes; +pub mod metadata; +pub mod node; +pub mod serialization; +pub mod storage; + +use node::{Node, NodeType}; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator}; +use storage::StorageManager; + +use crate::constants::QUANTIZED_VECTOR_SIZE; +use std::io; + +use self::k_modes::find_modes; +use self::metadata::{NodeMetadata, NodeMetadataIndex}; +use self::node::Vector; +use crate::math::hamming_distance; + +use super::filters::{combine_filters, Filter, Filters}; +// use super::metadata_index::{KVPair, KVValue}; +use super::mmap_tree::serialization::{TreeDeserialization, TreeSerialization}; +use crate::structures::filters::{calc_metadata_index_for_metadata, KVPair, KVValue}; + +use rayon::prelude::*; + +use ahash::{AHashMap as HashMap, AHashSet as HashSet}; +use std::fmt::{Debug, Display}; +use std::path::PathBuf; + +pub struct ANNTree { + pub k: usize, + pub storage_manager: storage::StorageManager, +} + +#[derive(Eq, PartialEq)] +struct PathNode { + distance: u16, + offset: usize, +} + +// Implement `Ord` and `PartialOrd` for `PathNode` to use it in a min-heap +impl Ord for PathNode { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + other.distance.cmp(&self.distance) // Reverse order for min-heap + } +} + +impl PartialOrd for PathNode { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl ANNTree { + pub fn new(path: PathBuf) -> Result { + let mut storage_manager = + StorageManager::new(path).expect("Failed to make storage manager in ANN Tree"); + + // println!("INIT Used space: {}", storage_manager.used_space); + + if storage_manager.root_offset() != 0 { + return Ok(ANNTree { + storage_manager, + k: crate::constants::K, + }); + } + + let mut root = Node::new_leaf(); + root.is_root = true; + + storage_manager.store_node(&mut root)?; + storage_manager.set_root_offset(root.offset); + + Ok(ANNTree { + storage_manager, + k: crate::constants::K, + }) + } + + pub fn batch_insert( + &mut self, + vectors: Vec, + ids: Vec, + metadata: Vec>, + ) { + for ((vector, id), metadata) in vectors.iter().zip(ids.iter()).zip(metadata.iter()) { + self.insert(vector.clone(), *id, metadata.clone()); + } + } + + pub fn bulk_insert( + &mut self, + vectors: Vec, + ids: Vec, + metadata: Vec>, + ) { + let mut current_leaves = Vec::new(); + self.collect_leaf_nodes(self.storage_manager.root_offset(), &mut current_leaves) + .expect("Failed to collect leaf nodes"); + + let mut new_root = Node::new_internal(); + + new_root.is_root = true; + + self.storage_manager.store_node(&mut new_root).unwrap(); + + self.storage_manager.set_root_offset(new_root.offset); + + // self.storage_manager.set_root_offset(leaf.offset); + + println!("Current leaves: {:?}", current_leaves.len()); + + // for leaf in current_leaves.iter_mut() { + // if leaf.is_root { + // leaf.is_root = false; + // } + // leaf.parent_offset = Some(leaf.offset); + // leaf.children.push(leaf.offset); + // leaf.vectors.push(find_modes(leaf.vectors.clone())); + // self.storage_manager.store_node(leaf).unwrap(); + // } + + let mut all_vectors = Vec::new(); + let mut all_ids = 
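PathNode above inverts cmp so that std::collections::BinaryHeap, which is a max-heap, yields the smallest distance first. A self-contained sketch of the pattern (the struct is unused by the code shown so far; this only demonstrates the ordering trick):

    use std::collections::BinaryHeap;

    #[derive(Eq, PartialEq)]
    struct PathNode {
        distance: u16,
        offset: usize,
    }

    impl Ord for PathNode {
        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
            other.distance.cmp(&self.distance) // reversed: smallest distance on top
        }
    }

    impl PartialOrd for PathNode {
        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }

    fn main() {
        let mut heap = BinaryHeap::new();
        heap.push(PathNode { distance: 9, offset: 1 });
        heap.push(PathNode { distance: 3, offset: 2 });
        // Despite BinaryHeap being a max-heap, the reversed Ord pops distance 3 first.
        assert_eq!(heap.pop().unwrap().offset, 2);
    }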
Vec::new(); + let mut all_metadata = Vec::new(); + + for leaf in current_leaves.iter_mut() { + leaf.is_root = false; + self.storage_manager.store_node(leaf).unwrap(); + all_vectors.extend(leaf.vectors.clone()); + all_ids.extend(leaf.ids.clone()); + all_metadata.extend(leaf.metadata.clone()); + } + + all_vectors.extend(vectors); + all_ids.extend(ids); + all_metadata.extend(metadata); + + println!("All vectors: {:?}", all_vectors.len()); + println!("All ids: {:?}", all_ids.len()); + println!("All metadata: {:?}", all_metadata.len()); + + let mut leaf = Node::new_leaf(); + + for ((vector, id), metadata) in all_vectors + .iter() + .zip(all_ids.iter()) + .zip(all_metadata.iter()) + { + if leaf.is_full() { + leaf.parent_offset = Some(new_root.offset); + leaf.node_metadata = calc_metadata_index_for_metadata(leaf.metadata.clone()); + self.storage_manager.store_node(&mut leaf).unwrap(); + new_root.children.push(leaf.offset); + new_root.vectors.push(find_modes(leaf.vectors.clone())); + self.storage_manager.store_node(&mut new_root).unwrap(); + leaf = Node::new_leaf(); + } + leaf.vectors.push(vector.clone()); + leaf.ids.push(*id); + leaf.metadata.push(metadata.clone()); + } + + new_root.node_metadata = self.compute_node_metadata(&new_root); + + self.storage_manager.store_node(&mut new_root).unwrap(); + + self.storage_manager.set_root_offset(new_root.offset); + + // self.true_calibrate(); + + // self.summarize_tree(); + } + + pub fn insert(&mut self, vector: Vector, id: u128, metadata: Vec) { + let entrypoint = self.find_entrypoint(vector); + let mut node = self.storage_manager.load_node(entrypoint).unwrap(); + + // println!("Entrypoint: {:?}", entrypoint); + + if node.is_full() { + let mut siblings = node.split().expect("Failed to split node"); + let sibling_offsets: Vec = siblings + .iter_mut() + .map(|sibling| { + sibling.parent_offset = node.parent_offset; // Set parent offset before storing + sibling.node_metadata = self.compute_node_metadata(&sibling); + self.storage_manager.store_node(sibling).unwrap() + }) + .collect(); + + for sibling in siblings.clone() { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!("Internal node has different number of children and vectors"); + } + } + + if node.is_root { + let mut new_root = Node::new_internal(); + new_root.is_root = true; + new_root.children.push(node.offset); + new_root.vectors.push(find_modes(node.vectors.clone())); + for sibling_offset in &sibling_offsets { + let sibling = self.storage_manager.load_node(*sibling_offset).unwrap(); + new_root.vectors.push(find_modes(sibling.vectors)); + new_root.children.push(*sibling_offset); + } + self.storage_manager.store_node(&mut new_root).unwrap(); + self.storage_manager.set_root_offset(new_root.offset); + node.is_root = false; + node.parent_offset = Some(new_root.offset); + siblings + .iter_mut() + .for_each(|sibling| sibling.parent_offset = Some(new_root.offset)); + self.storage_manager.store_node(&mut node).unwrap(); + siblings.iter_mut().for_each(|sibling| { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!("Internal node has different number of children and vectors v3"); + } + sibling.node_metadata = self.compute_node_metadata(sibling); + self.storage_manager.store_node(sibling).unwrap(); + }); + } else { + let parent_offset = node.parent_offset.unwrap(); + let mut parent = self.storage_manager.load_node(parent_offset).unwrap(); + parent.children.push(node.offset); + 
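Both split paths in insert() maintain the invariant that an internal node holds exactly one routing vector, the k-modes representative, per child offset, and they panic whenever the two counts drift apart. Pushing both sides through one helper makes the invariant hold by construction; a sketch against the Node fields and find_modes used in this module (attach_child is a hypothetical name):

    // Register `child` under `parent`, keeping children and routing vectors in
    // lockstep so the children.len() == vectors.len() checks cannot fail.
    fn attach_child(parent: &mut Node, child: &Node) {
        parent.children.push(child.offset);
        parent.vectors.push(find_modes(child.vectors.clone()));
        debug_assert_eq!(parent.children.len(), parent.vectors.len());
    }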
parent.vectors.push(find_modes(node.vectors.clone())); + sibling_offsets + .iter() + .for_each(|&offset| parent.children.push(offset)); + siblings + .iter() + .for_each(|sibling| parent.vectors.push(find_modes(sibling.vectors.clone()))); + if parent.node_type == NodeType::Internal + && parent.children.len() != parent.vectors.len() + { + println!("Parent vectors: {:?}", parent.vectors.len()); + println!("Parent children: {:?}", parent.children); + println!("Sibling offsets: {:?}", sibling_offsets.len()); + + panic!("parent node has different number of children and vectors"); + } + self.storage_manager.store_node(&mut parent).unwrap(); + node.parent_offset = Some(parent_offset); + self.storage_manager.store_node(&mut node).unwrap(); + siblings.into_iter().for_each(|mut sibling| { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!("Internal node has different number of children and vectors v3"); + } + sibling.parent_offset = Some(parent_offset); + sibling.node_metadata = self.compute_node_metadata(&sibling); + self.storage_manager.store_node(&mut sibling).unwrap(); + }); + + let mut current_node = parent; + while current_node.is_full() { + println!("Current node is full"); + let mut siblings = current_node.split().expect("Failed to split node"); + let sibling_offsets: Vec = siblings + .iter_mut() + .map(|sibling| { + sibling.parent_offset = Some(current_node.parent_offset.unwrap()); + sibling.node_metadata = self.compute_node_metadata(sibling); + self.storage_manager.store_node(sibling).unwrap() + }) + .collect(); + + for sibling in siblings.clone() { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!("Internal node has different number of children and vectors v2"); + } + } + + if current_node.is_root { + let mut new_root = Node::new_internal(); + new_root.is_root = true; + new_root.children.push(current_node.offset); + new_root.children.extend(sibling_offsets.clone()); + new_root + .vectors + .push(find_modes(current_node.vectors.clone())); + siblings.iter().for_each(|sibling| { + new_root.vectors.push(find_modes(sibling.vectors.clone())) + }); + self.storage_manager.store_node(&mut new_root).unwrap(); + self.storage_manager.set_root_offset(new_root.offset); + current_node.is_root = false; + current_node.parent_offset = Some(new_root.offset); + siblings + .iter_mut() + .for_each(|sibling| sibling.parent_offset = Some(new_root.offset)); + self.storage_manager.store_node(&mut current_node).unwrap(); + siblings.into_iter().for_each(|mut sibling| { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!( + "Internal node has different number of children and vectors v4" + ); + } + sibling.node_metadata = self.compute_node_metadata(&sibling); + self.storage_manager.store_node(&mut sibling).unwrap(); + }); + new_root.node_metadata = self.compute_node_metadata(&new_root); + } else { + let parent_offset = current_node.parent_offset.unwrap(); + let mut parent = self.storage_manager.load_node(parent_offset).unwrap(); + parent.children.push(current_node.offset); + sibling_offsets + .iter() + .for_each(|&offset| parent.children.push(offset)); + parent + .vectors + .push(find_modes(current_node.vectors.clone())); + siblings.iter().for_each(|sibling| { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!( + "Internal node has different number of children and vectors v5" + ); + 
} + parent.vectors.push(find_modes(sibling.vectors.clone())) + }); + self.storage_manager.store_node(&mut parent).unwrap(); + current_node.parent_offset = Some(parent_offset); + self.storage_manager.store_node(&mut current_node).unwrap(); + siblings.into_iter().for_each(|mut sibling| { + sibling.parent_offset = Some(parent_offset); + sibling.node_metadata = self.compute_node_metadata(&sibling); + self.storage_manager.store_node(&mut sibling).unwrap(); + }); + parent.node_metadata = self.compute_node_metadata(&parent); + current_node = parent; + } + } + } + } else { + if node.node_type != NodeType::Leaf { + panic!("Entrypoint is not a leaf node"); + } + node.vectors.push(vector); + node.ids.push(id); + node.metadata.push(metadata.clone()); + for kv in metadata { + match node.node_metadata.get(kv.key.clone()) { + Some(res) => { + let mut set = res.clone(); + match kv.value { + KVValue::String(value) => { + set.values.insert(value); + } + KVValue::Float(val) => { + let mut float_range = set.float_range.unwrap_or((val, val)); + if val < float_range.0 { + float_range.0 = val; + } + if val > float_range.1 { + float_range.1 = val; + } + set.float_range = Some(float_range); + } + KVValue::Integer(val) => { + let mut int_range = set.int_range.unwrap_or((val, val)); + if val < int_range.0 { + int_range.0 = val; + } + if val > int_range.1 { + int_range.1 = val; + } + set.int_range = Some(int_range); + } + } + + node.node_metadata.insert(kv.key.clone(), set); + } + None => { + let mut set = NodeMetadata::new(); + match kv.value { + KVValue::String(v) => { + set.values.insert(v); + } + KVValue::Float(val) => { + set.float_range = Some((val, val)); + } + KVValue::Integer(val) => { + set.int_range = Some((val, val)); + } + } + + node.node_metadata.insert(kv.key.clone(), set); + } + } + } + self.storage_manager.store_node(&mut node).unwrap(); + } + } + + fn find_entrypoint(&mut self, vector: Vector) -> usize { + let mut node = self + .storage_manager + .load_node(self.storage_manager.root_offset()) + .unwrap(); + + while node.node_type == NodeType::Internal { + let mut distances: Vec<(usize, u16)> = node + .vectors + .par_iter() + .map(|key| hamming_distance(&vector, key)) + .enumerate() + .collect(); + + distances.sort_by_key(|&(_, distance)| distance); + + let best = distances.get(0).unwrap(); + + let best_node = self + .storage_manager + .load_node(node.children[best.0]) + .unwrap(); + + node = best_node; + } + + // Now node is a leaf node + node.offset + } + + pub fn search( + &mut self, + vector: Vector, + top_k: usize, + filters: &Filter, + ) -> Vec<(u128, Vec)> { + let node = self + .storage_manager + .load_node(self.storage_manager.root_offset()) + .unwrap(); + + // let mut visited = HashSet::new(); + + // println!( + // "Root node: {:?}, {:?}", + // node.offset, + // self.storage_manager.root_offset() + // ); + + let mut candidates = self.traverse(&vector, &node, top_k, filters, 0); + + // Sort by distance and truncate to top_k results + candidates.sort_by_key(|&(_, distance, _)| distance); + candidates.truncate(top_k); + + candidates + .into_iter() + .map(|(id, _, pairs)| (id, pairs)) + .collect() + } + + fn traverse( + &self, + vector: &Vector, + node: &Node, + k: usize, + filters: &Filter, + depth: usize, + ) -> Vec<(u128, u16, Vec)> { + if node.node_type == NodeType::Leaf { + // println!("Leaf node: {:?}", node.offset); + return self + .collect_top_k_with_filters(vector, &node.vectors, &node.metadata, filters, k) + .par_iter() + .map(|(idx, distance)| { + let id = node.ids[*idx]; + let 
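When inserting into a leaf, the per-key metadata update above widens the stored (min, max) ranges field by field. The same fold expressed as one small helper, a sketch of the pattern rather than a drop-in replacement:

    // Widen an optional (min, max) range to cover `val`; this is what insert()
    // does when folding a KVValue::Integer into node_metadata.
    fn widen(range: Option<(i64, i64)>, val: i64) -> Option<(i64, i64)> {
        let (lo, hi) = range.unwrap_or((val, val));
        Some((lo.min(val), hi.max(val)))
    }

    #[test]
    fn widen_grows_both_ends() {
        assert_eq!(widen(None, 5), Some((5, 5)));
        assert_eq!(widen(Some((2, 4)), 7), Some((2, 7)));
        assert_eq!(widen(Some((2, 4)), 1), Some((1, 4)));
    }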
metadata = node.metadata[*idx].clone(); + (id, *distance, metadata) + }) + .collect::>(); + // .into_iter() + // .filter(|(id, _, _)| visited.insert(*id)) + // .collect(); + } + + let mut alpha = crate::constants::ALPHA; + if alpha <= 1 { + alpha = 1; + } + + // Collect top alpha nodes + let best_children: Vec<(usize, u16, Node)> = + self.collect_top_k_with_nodes(vector, &node.vectors, &node.children, filters, alpha); + + // println!("Best children: {:?}", best_children.len()); + + // Track results from top alpha nodes + // let mut all_results: Vec<(u128, u16, Vec)> = vec![]; + // for (_, _, child_node) in best_children.iter() { + // let mut current_results = + // self.traverse(vector, child_node, k, filters, depth + 1, visited); + // all_results.append(&mut current_results); + // } + let mut all_results: Vec<_> = best_children + .par_iter() + .flat_map(|(_, _, child_node)| self.traverse(vector, child_node, k, filters, depth + 1)) + .collect(); + + // Evaluate paths and choose the best based on distance + all_results.sort_by_key(|&(_, distance, _)| distance); + all_results.truncate(alpha); + + all_results + } + + fn collect_top_k_with_filters( + &self, + query: &Vector, + vector_items: &Vec, + metadata_items: &Vec>, + filters: &Filter, + k: usize, + ) -> Vec<(usize, u16)> { + let mut top_k_values: Vec<(usize, u16)> = Vec::with_capacity(k); + + let mut distances: Vec<(usize, u16)> = vector_items + .par_iter() + .enumerate() + .map(|(idx, item)| (idx, hamming_distance(item, query))) + .collect(); + + // Sort distances to find the top-k closest items + distances.sort_by_key(|&(_, distance)| distance); + + // Load nodes and filter top-k items + for &(idx, distance) in distances.iter() { + if top_k_values.len() >= k && distance >= top_k_values[k - 1].1 { + break; // No need to check further if we already have top-k and current distance is not better + } + + // Evaluate filters for the loaded node + if !Filters::should_prune_metadata(filters, &&metadata_items[idx]) { + // Add to top-k if it matches the filter + if top_k_values.len() < k { + top_k_values.push((idx, distance)); + } else { + // Replace the worst in top-k if current distance is better + let worst_best_distance = top_k_values[k - 1].1; + if distance < worst_best_distance { + top_k_values.pop(); + top_k_values.push((idx, distance)); + top_k_values.sort_by_key(|&(_, distance)| distance); + } + } + } else { + // println!("Pruned"); + } + } + + top_k_values + } + + fn collect_top_k_with_nodes( + &self, + query: &Vector, + items: &Vec, + children: &Vec, + filters: &Filter, + k: usize, + ) -> Vec<(usize, u16, Node)> { + let mut top_k_values: Vec<(usize, u16, Node)> = Vec::with_capacity(k); + + let mut distances: Vec<(usize, u16)> = items + .par_iter() + .enumerate() + .map(|(idx, item)| (idx, hamming_distance(item, query))) + .collect(); + + // Sort distances to find the top-k closest items + distances.sort_by_key(|&(_, distance)| distance); + + // Load nodes and filter top-k items + for &(idx, distance) in distances.iter() { + if top_k_values.len() >= k && distance >= top_k_values[k - 1].1 { + break; // No need to check further if we already have top-k and current distance is not better + } + + let child_node = self.storage_manager.load_node(children[idx]).unwrap(); + + // Evaluate filters for the loaded node + if !Filters::should_prune(filters, &child_node.node_metadata) { + // Add to top-k if it matches the filter + if top_k_values.len() < k { + top_k_values.push((idx, distance, child_node)); + } else { + // Replace the worst in top-k 
if current distance is better + let worst_best_distance = top_k_values[k - 1].1; + if distance < worst_best_distance { + top_k_values.pop(); + top_k_values.push((idx, distance, child_node)); + top_k_values.sort_by_key(|&(_, distance, _)| distance); + } + } + } else { + // println!("Pruned"); + // println!("Filters: {:?}", filters); + // println!("Node metadata: {:?}", child_node.node_metadata); + } + } + + top_k_values + } + + fn collect_leaf_nodes( + &mut self, + offset: usize, + leaf_nodes: &mut Vec, + ) -> Result<(), io::Error> { + let node = self.storage_manager.load_node(offset).unwrap().clone(); + if node.node_type == NodeType::Leaf { + leaf_nodes.push(node); + } else { + for &child_offset in &node.children { + self.collect_leaf_nodes(child_offset, leaf_nodes)?; + } + } + Ok(()) + } + + pub fn true_calibrate(&mut self) -> Result<(), io::Error> { + // Step 1: Get all leaf nodes + let mut leaf_nodes = Vec::new(); + self.collect_leaf_nodes(self.storage_manager.root_offset(), &mut leaf_nodes)?; + + // Step 2: Make a new root + let mut new_root = Node::new_internal(); + new_root.is_root = true; + + // Step 3: Store the new root to set its offset + self.storage_manager.store_node(&mut new_root)?; + self.storage_manager.set_root_offset(new_root.offset); + + // Step 4: Make all the leaf nodes the new root's children, and set all their parent_offsets to the new root's offset + for leaf_node in &mut leaf_nodes { + leaf_node.parent_offset = Some(new_root.offset); + new_root.children.push(leaf_node.offset); + new_root.vectors.push(find_modes(leaf_node.vectors.clone())); + // new_root.node_metadata = self.compute_node_metadata(&new_root); + self.storage_manager.store_node(leaf_node)?; + } + + new_root.node_metadata = self.compute_node_metadata(&new_root); + + // new_root.node_metadata = combine_filters( + // leaf_nodes + // .iter() + // .map(|node| node.node_metadata.clone()) + // .collect(), + // ); + + // Update the root node with its children and vectors + self.storage_manager.store_node(&mut new_root)?; + + // Step 5: Split the nodes until it is balanced/there are no nodes that are full + let mut current_nodes = vec![new_root]; + while let Some(mut node) = current_nodes.pop() { + if node.is_full() { + let mut siblings = node.split().expect("Failed to split node"); + let sibling_offsets: Vec = siblings + .iter_mut() + .map(|sibling| { + sibling.parent_offset = node.parent_offset; // Set parent offset before storing + sibling.node_metadata = self.compute_node_metadata(sibling); + self.storage_manager.store_node(sibling).unwrap() + }) + .collect(); + + for sibling in siblings.clone() { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!("Internal node has different number of children and vectors"); + } + } + + if node.is_root { + let mut new_root = Node::new_internal(); + new_root.is_root = true; + new_root.children.push(node.offset); + new_root.vectors.push(find_modes(node.vectors.clone())); + + for sibling_offset in &sibling_offsets { + let sibling = self + .storage_manager + .load_node(*sibling_offset) + .unwrap() + .clone(); + new_root.vectors.push(find_modes(sibling.vectors)); + new_root.children.push(*sibling_offset); + } + + new_root.node_metadata = self.compute_node_metadata(&new_root); + self.storage_manager.store_node(&mut new_root)?; + self.storage_manager.set_root_offset(new_root.offset); + node.is_root = false; + node.parent_offset = Some(new_root.offset); + self.storage_manager.store_node(&mut node)?; + siblings + 
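collect_top_k_with_nodes only descends into a child when Filters::should_prune cannot rule the subtree out from its NodeMetadata summary. The idea, sketched with plain std types and hypothetical helper names (the real signatures live in filters.rs, which this patch does not show in full):

    use std::collections::HashSet;

    // For `key == value`: the subtree is prunable when its value set for that
    // key provably lacks the wanted value.
    fn prune_eq(summary_values: &HashSet<String>, wanted: &str) -> bool {
        !summary_values.contains(wanted)
    }

    // For an integer predicate like `key > n`: prunable when the subtree's
    // stored maximum cannot exceed n, or the key never occurs in the subtree.
    fn prune_gt(int_range: Option<(i64, i64)>, n: i64) -> bool {
        match int_range {
            Some((_, max)) => max <= n,
            None => true,
        }
    }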
.iter_mut() + .for_each(|sibling| sibling.parent_offset = Some(new_root.offset)); + self.storage_manager.store_node(&mut node)?; + siblings.iter_mut().for_each(|sibling| { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!("Internal node has different number of children and vectors v3"); + } + sibling.node_metadata = self.compute_node_metadata(sibling); + self.storage_manager.store_node(sibling); + }); + } else { + let parent_offset = node.parent_offset.unwrap(); + let mut parent = self + .storage_manager + .load_node(parent_offset) + .unwrap() + .clone(); + parent.children.push(node.offset); + parent.vectors.push(find_modes(node.vectors.clone())); + sibling_offsets + .iter() + .for_each(|&offset| parent.children.push(offset)); + siblings.iter().for_each(|sibling| { + parent.vectors.push(find_modes(sibling.vectors.clone())) + }); + parent.node_metadata = self.compute_node_metadata(&parent); + + if parent.node_type == NodeType::Internal + && parent.children.len() != parent.vectors.len() + { + panic!("parent node has different number of children and vectors"); + } + self.storage_manager.store_node(&mut parent)?; + node.parent_offset = Some(parent_offset); + self.storage_manager.store_node(&mut node)?; + siblings.into_iter().for_each(|mut sibling| { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!("Internal node has different number of children and vectors v3"); + } + sibling.parent_offset = Some(parent_offset); + sibling.node_metadata = self.compute_node_metadata(&sibling); + self.storage_manager.store_node(&mut sibling); + }); + + let mut current_node = parent; + while current_node.is_full() { + let mut siblings = current_node.split().expect("Failed to split node"); + let sibling_offsets: Vec = siblings + .iter_mut() + .map(|sibling| { + sibling.parent_offset = Some(current_node.parent_offset.unwrap()); + sibling.node_metadata = self.compute_node_metadata(sibling); + self.storage_manager.store_node(sibling).unwrap() + }) + .collect(); + + for sibling in siblings.clone() { + if sibling.node_type == NodeType::Internal + && sibling.children.len() != sibling.vectors.len() + { + panic!( + "Internal node has different number of children and vectors v2" + ); + } + } + + if current_node.is_root { + let mut new_root = Node::new_internal(); + new_root.is_root = true; + new_root.children.push(current_node.offset); + new_root.children.extend(sibling_offsets.clone()); + new_root + .vectors + .push(find_modes(current_node.vectors.clone())); + siblings.iter().for_each(|sibling| { + new_root.vectors.push(find_modes(sibling.vectors.clone())) + }); + new_root.node_metadata = self.compute_node_metadata(&new_root); + self.storage_manager.store_node(&mut new_root)?; + self.storage_manager.set_root_offset(new_root.offset); + current_node.is_root = false; + current_node.parent_offset = Some(new_root.offset); + siblings + .iter_mut() + .for_each(|sibling| sibling.parent_offset = Some(new_root.offset)); + self.storage_manager.store_node(&mut current_node)?; + siblings.into_iter().for_each(|mut sibling| { + if sibling.node_type == NodeType::Internal && sibling.children.len() != sibling.vectors.len() { + panic!("Internal node has different number of children and vectors v4"); + } + sibling.node_metadata = self.compute_node_metadata(&sibling); + self.storage_manager.store_node(&mut sibling); + }); + } else { + let parent_offset = current_node.parent_offset.unwrap(); + let mut parent = self + 
.storage_manager + .load_node(parent_offset) + .unwrap() + .clone(); + parent.children.push(current_node.offset); + sibling_offsets + .iter() + .for_each(|&offset| parent.children.push(offset)); + parent + .vectors + .push(find_modes(current_node.vectors.clone())); + siblings.iter_mut().for_each(|sibling| { + if sibling.node_type == NodeType::Internal && sibling.children.len() != sibling.vectors.len() { + panic!("Internal node has different number of children and vectors v5"); + } + sibling.node_metadata = self.compute_node_metadata(sibling); + parent.vectors.push(find_modes(sibling.vectors.clone())) + }); + parent.node_metadata = self.compute_node_metadata(&parent); + self.storage_manager.store_node(&mut parent)?; + current_node.parent_offset = Some(parent_offset); + current_node.node_metadata = self.compute_node_metadata(¤t_node); + self.storage_manager.store_node(&mut current_node)?; + siblings.into_iter().for_each(|mut sibling| { + sibling.parent_offset = Some(parent_offset); + sibling.node_metadata = self.compute_node_metadata(&sibling); + self.storage_manager.store_node(&mut sibling); + }); + current_node = parent.clone(); + current_node.node_metadata = self.compute_node_metadata(¤t_node); + } + } + } + } + + node.node_metadata = self.compute_node_metadata(&node); + } + Ok(()) + } + + fn compute_node_metadata(&self, node: &Node) -> NodeMetadataIndex { + let mut children_metadatas = Vec::new(); + + for child_offset in &node.children { + let child = self.storage_manager.load_node(*child_offset).unwrap(); + + children_metadatas.push(child.node_metadata); + } + + combine_filters(children_metadatas) + } + + pub fn summarize_tree(&self) { + let mut queue = vec![self.storage_manager.root_offset()]; + let mut depth = 0; + + while !queue.is_empty() { + let mut next_queue = Vec::new(); + + for offset in queue { + let node = self.storage_manager.load_node(offset).unwrap(); + println!( + "Depth: {}, Node type: {:?}, Offset: {}, Children: {}, Vectors: {}", + depth, + node.node_type, + node.offset, + node.children.len(), + node.vectors.len() + ); + + if node.node_type == NodeType::Internal { + next_queue.extend(node.children.clone()); + } + } + + queue = next_queue; + depth += 1; + } + + println!("Tree depth: {}", depth); + } +} diff --git a/src/structures/ann_tree/k_modes.rs b/src/structures/ann_tree/k_modes.rs new file mode 100644 index 0000000..3c87b8a --- /dev/null +++ b/src/structures/ann_tree/k_modes.rs @@ -0,0 +1,351 @@ +use rand::seq::SliceRandom; +use rand::thread_rng; +use rand::Rng; + +use crate::constants::QUANTIZED_VECTOR_SIZE; + +// Function to calculate Hamming distance between two vectors +fn hamming_distance(v1: &[u8; QUANTIZED_VECTOR_SIZE], v2: &[u8; QUANTIZED_VECTOR_SIZE]) -> u32 { + v1.iter() + .zip(v2.iter()) + .fold(0, |acc, (&x, &y)| acc + (x ^ y).count_ones()) +} + +// Function to find the mode of bits at each position for vectors in a cluster +pub fn find_modes(vectors: Vec<[u8; QUANTIZED_VECTOR_SIZE]>) -> [u8; QUANTIZED_VECTOR_SIZE] { + let mut modes = [0; QUANTIZED_VECTOR_SIZE]; + for i in 0..QUANTIZED_VECTOR_SIZE * 8 { + let count_ones = vectors + .iter() + .filter(|vec| (vec[i / 8] & (1 << (i % 8))) != 0) + .count(); + if count_ones * 2 > vectors.len() { + // majority of ones + modes[i / 8] |= 1 << (i % 8); + } + } + modes +} + +// K-modes clustering function +fn k_modes_clustering( + data: Vec<[u8; QUANTIZED_VECTOR_SIZE]>, + k: usize, +) -> Vec<[u8; QUANTIZED_VECTOR_SIZE]> { + assert!( + data.len() >= k, + "Not enough data points to form the requested number of 
clusters." + ); + + let mut centroids: Vec<[u8; QUANTIZED_VECTOR_SIZE]> = Vec::new(); + let mut assignments = vec![0usize; data.len()]; + + // Initialize centroids (naively picking the first k elements) + for i in 0..k { + centroids.push(data[i].clone()); + } + + let mut change = true; + while change { + change = false; + + // Assign data points to centroids + for (idx, vec) in data.iter().enumerate() { + let mut min_distance = u32::MAX; + let mut min_index = 0; + for (centroid_idx, centroid) in centroids.iter().enumerate() { + let distance = hamming_distance(&vec, centroid); + if distance < min_distance { + min_distance = distance; + min_index = centroid_idx; + } + } + if assignments[idx] != min_index { + assignments[idx] = min_index; + change = true; + } + } + + // Update centroids + for centroid_idx in 0..k { + let cluster: Vec<[u8; QUANTIZED_VECTOR_SIZE]> = data + .iter() + .zip(assignments.iter()) + .filter_map(|(vec, &assignment)| { + if assignment == centroid_idx { + Some(vec.clone()) + } else { + None + } + }) + .map(|vec| vec) + .collect(); + + if !cluster.is_empty() { + centroids[centroid_idx] = find_modes(cluster); + } + } + } + + centroids +} + +pub fn balanced_k_modes(data: Vec<[u8; QUANTIZED_VECTOR_SIZE]>) -> (Vec, Vec) { + let mut centroids: [[u8; QUANTIZED_VECTOR_SIZE]; 2] = [data[0], data[1]]; // Correct syntax for fixed-size array initialization + let mut cluster_indices = vec![0usize; data.len()]; + let mut changes = true; + let mut iterations = 0; + + while changes && iterations < 50 { + // Avoid infinite loops + changes = false; + let mut cluster0 = Vec::new(); + let mut cluster1 = Vec::new(); + + // Assign vectors to the closest centroid + for (index, vector) in data.iter().enumerate() { + let dist0 = hamming_distance(vector, ¢roids[0]); + let dist1 = hamming_distance(vector, ¢roids[1]); + let current_assignment = cluster_indices[index]; + let new_assignment = if dist0 <= dist1 { 0 } else { 1 }; + + if new_assignment != current_assignment { + cluster_indices[index] = new_assignment; + changes = true; + } + + if new_assignment == 0 { + cluster0.push(index); + } else { + cluster1.push(index); + } + } + + // Update centroids for each cluster + if !cluster0.is_empty() { + centroids[0] = find_modes(cluster0.iter().map(|&i| data[i]).collect::>()); + } + if !cluster1.is_empty() { + centroids[1] = find_modes(cluster1.iter().map(|&i| data[i]).collect::>()); + } + + iterations += 1; + } + + // Distribute indices evenly + let (mut indices0, mut indices1) = (Vec::new(), Vec::new()); + for (i, &cluster) in cluster_indices.iter().enumerate() { + if cluster == 0 && indices0.len() < data.len() / 2 || indices1.len() >= data.len() / 2 { + indices0.push(i); + } else { + indices1.push(i); + } + } + + (indices0, indices1) +} + +pub fn balanced_k_modes_4( + mut data: Vec<[u8; QUANTIZED_VECTOR_SIZE]>, +) -> (Vec, Vec, Vec, Vec) { + if data.len() < 4 { + panic!("Not enough data points to initialize four clusters"); + } + + // Improved initial centroid selection using random sampling + let mut rng = thread_rng(); + let mut centroids = data + .choose_multiple(&mut rng, 4) + .cloned() + .collect::>(); + + let mut cluster_indices = vec![0usize; data.len()]; + let mut changes = true; + let mut iterations = 0; + let mut clusters = vec![Vec::new(), Vec::new(), Vec::new(), Vec::new()]; + + while changes && iterations < 50 { + changes = false; + clusters.iter_mut().for_each(|cluster| cluster.clear()); + + // Assign vectors to the closest centroid + for (index, vector) in data.iter().enumerate() { + 
+pub fn balanced_k_modes_4(
+    mut data: Vec<[u8; QUANTIZED_VECTOR_SIZE]>,
+) -> (Vec<usize>, Vec<usize>, Vec<usize>, Vec<usize>) {
+    if data.len() < 4 {
+        panic!("Not enough data points to initialize four clusters");
+    }
+
+    // Improved initial centroid selection using random sampling
+    let mut rng = thread_rng();
+    let mut centroids = data
+        .choose_multiple(&mut rng, 4)
+        .cloned()
+        .collect::<Vec<_>>();
+
+    let mut cluster_indices = vec![0usize; data.len()];
+    let mut changes = true;
+    let mut iterations = 0;
+    let mut clusters = vec![Vec::new(), Vec::new(), Vec::new(), Vec::new()];
+
+    while changes && iterations < 50 {
+        changes = false;
+        clusters.iter_mut().for_each(|cluster| cluster.clear());
+
+        // Assign vectors to the closest centroid
+        for (index, vector) in data.iter().enumerate() {
+            let distances = centroids
+                .iter()
+                .map(|&centroid| hamming_distance(vector, &centroid))
+                .collect::<Vec<_>>();
+            let new_assignment = distances
+                .iter()
+                .enumerate()
+                .min_by_key(|&(_, dist)| dist)
+                .map(|(idx, _)| idx)
+                .unwrap();
+
+            if cluster_indices[index] != new_assignment {
+                cluster_indices[index] = new_assignment;
+                changes = true;
+            }
+            clusters[new_assignment].push(index);
+        }
+
+        // Update centroids and manage empty clusters immediately
+        for (i, cluster) in clusters.iter_mut().enumerate() {
+            if cluster.is_empty() {
+                // Assign a random vector to an empty cluster
+                let random_index = rng.gen_range(0..data.len());
+                cluster.push(random_index);
+                centroids[i] = data[random_index];
+                changes = true;
+            } else {
+                centroids[i] = find_modes(cluster.iter().map(|&idx| data[idx]).collect::<Vec<_>>());
+            }
+        }
+
+        iterations += 1;
+    }
+
+    // Ensure balanced clusters for final output
+    balance_clusters(&mut clusters, data.len(), 4);
+
+    (
+        clusters[0].clone(),
+        clusters[1].clone(),
+        clusters[2].clone(),
+        clusters[3].clone(),
+    )
+}
+
+fn find_modes_bits(vectors: &[[u8; QUANTIZED_VECTOR_SIZE]]) -> [u8; QUANTIZED_VECTOR_SIZE] {
+    let mut modes = [0u8; QUANTIZED_VECTOR_SIZE];
+    for i in 0..QUANTIZED_VECTOR_SIZE {
+        let mut counts = [0usize; 8];
+        for vector in vectors {
+            for j in 0..8 {
+                counts[j] += ((vector[i] >> j) & 1) as usize;
+            }
+        }
+        modes[i] = counts
+            .iter()
+            .enumerate()
+            .map(|(j, &count)| ((count >= vectors.len() / 2) as u8) << j)
+            .fold(0, |acc, bit| acc | bit);
+    }
+    modes
+}
+
+fn find_medoid(vectors: &[[u8; QUANTIZED_VECTOR_SIZE]]) -> [u8; QUANTIZED_VECTOR_SIZE] {
+    let mut min_sum_distance = u32::MAX;
+    let mut medoid = [0u8; QUANTIZED_VECTOR_SIZE];
+    for &vector in vectors {
+        let sum_distance: u32 = vectors.iter().map(|&v| hamming_distance(&vector, &v)).sum();
+        if sum_distance < min_sum_distance {
+            min_sum_distance = sum_distance;
+            medoid = vector;
+        }
+    }
+    medoid
+}
+
+pub fn balanced_k_modes_k_clusters(
+    mut data: Vec<[u8; QUANTIZED_VECTOR_SIZE]>,
+    k: usize,
+) -> Vec<Vec<usize>> {
+    if data.len() < k {
+        panic!("Not enough data points to initialize the specified number of clusters");
+    }
+
+    // Improved initial centroid selection using random sampling
+    let mut rng = thread_rng();
+    let mut centroids = data
+        .choose_multiple(&mut rng, k)
+        .cloned()
+        .collect::<Vec<_>>();
+
+    let mut cluster_indices = vec![0usize; data.len()];
+    let mut changes = true;
+    let mut iterations = 0;
+    let mut clusters = vec![Vec::new(); k];
+
+    while changes && iterations < 100 {
+        // println!("Iteration {}", iterations);
+        changes = false;
+        clusters.iter_mut().for_each(|cluster| cluster.clear());
+
+        // Assign vectors to the closest centroid
+        for (index, vector) in data.iter().enumerate() {
+            let distances = centroids
+                .iter()
+                .map(|&centroid| hamming_distance(vector, &centroid))
+                .collect::<Vec<_>>();
+            let new_assignment = distances
+                .iter()
+                .enumerate()
+                .min_by_key(|&(_, dist)| dist)
+                .map(|(idx, _)| idx)
+                .unwrap();
+
+            if cluster_indices[index] != new_assignment {
+                cluster_indices[index] = new_assignment;
+                changes = true;
+            }
+            clusters[new_assignment].push(index);
+        }
+
+        // Update centroids and manage empty clusters immediately
+        for (i, cluster) in clusters.iter_mut().enumerate() {
+            if cluster.is_empty() {
+                // Assign a random vector to an empty cluster
+                let random_index = rng.gen_range(0..data.len());
+                cluster.push(random_index);
+                centroids[i] = data[random_index];
+                changes = true;
+            } else {
+                let vectors = cluster.iter().map(|&idx| data[idx]).collect::<Vec<_>>();
+                centroids[i] =
find_medoid(&vectors);
+            }
+        }
+
+        iterations += 1;
+    }
+
+    // Ensure balanced clusters for final output
+    // balance_clusters(&mut clusters, data.len(), k);
+
+    clusters
+}
+
+fn balance_clusters(clusters: &mut Vec<Vec<usize>>, total: usize, k: usize) {
+    let target_size = total / k;
+    let mut rng = thread_rng();
+    let mut all_indices = clusters
+        .iter_mut()
+        .flat_map(|cluster| cluster.drain(..))
+        .collect::<Vec<_>>();
+    all_indices.shuffle(&mut rng);
+    clusters.iter_mut().for_each(|cluster| cluster.clear());
+
+    // Distribute indices to ensure no cluster is empty and all are balanced
+    for (i, index) in all_indices.into_iter().enumerate() {
+        clusters[i % k].push(index);
+    }
+}
+
+#[test]
+fn test_clustering() {
+    let vectors = vec![
+        [0u8; QUANTIZED_VECTOR_SIZE],
+        [0u8; QUANTIZED_VECTOR_SIZE],
+        [0u8; QUANTIZED_VECTOR_SIZE],
+        [0u8; QUANTIZED_VECTOR_SIZE],
+        [1u8; QUANTIZED_VECTOR_SIZE],
+        [1u8; QUANTIZED_VECTOR_SIZE],
+        [1u8; QUANTIZED_VECTOR_SIZE],
+        [1u8; QUANTIZED_VECTOR_SIZE],
+    ];
+
+    let clusters = balanced_k_modes_k_clusters(vectors, 2);
+
+    for cluster in clusters {
+        println!("{:?}", cluster);
+    }
+}
diff --git a/src/structures/ann_tree/metadata.rs b/src/structures/ann_tree/metadata.rs
new file mode 100644
index 0000000..d631cc3
--- /dev/null
+++ b/src/structures/ann_tree/metadata.rs
@@ -0,0 +1,274 @@
+use std::path::PathBuf;
+
+// use std::collections::HashSet;
+use ahash::{AHashMap as HashMap, AHashSet as HashSet};
+
+use serde::{Deserialize, Serialize};
+
+use crate::structures::filters::Filter;
+// use crate::structures::metadata_index::{KVPair, KVValue};
+use crate::structures::filters::{KVPair, KVValue};
+use crate::structures::mmap_tree::Tree;
+use std::io;
+
+use crate::structures::mmap_tree::serialization::{TreeDeserialization, TreeSerialization};
+
+#[derive(Debug, Clone)]
+pub struct NodeMetadata {
+    pub values: HashSet<String>,
+    pub int_range: Option<(i64, i64)>,
+    pub float_range: Option<(f32, f32)>,
+}
+
+impl NodeMetadata {
+    pub fn new() -> Self {
+        NodeMetadata {
+            values: HashSet::new(),
+            int_range: None,
+            float_range: None,
+        }
+    }
+}
+
+impl TreeSerialization for NodeMetadata {
+    fn serialize(&self) -> Vec<u8> {
+        let mut serialized = Vec::new();
+
+        serialized.extend_from_slice(self.values.len().to_le_bytes().as_ref());
+        for value in &self.values {
+            serialized.extend_from_slice(value.len().to_le_bytes().as_ref());
+            serialized.extend_from_slice(value.as_bytes());
+        }
+
+        if let Some((start, end)) = self.int_range {
+            serialized.extend_from_slice(start.to_le_bytes().as_ref());
+            serialized.extend_from_slice(end.to_le_bytes().as_ref());
+        }
+
+        if let Some((start, end)) = self.float_range {
+            serialized.extend_from_slice(start.to_le_bytes().as_ref());
+            serialized.extend_from_slice(end.to_le_bytes().as_ref());
+        }
+
+        serialized
+    }
+}
+
+impl TreeDeserialization for NodeMetadata {
+    fn deserialize(serialized: &[u8]) -> Self {
+        let mut values = HashSet::new();
+
+        let mut offset = 0;
+
+        let values_len =
+            u64::from_le_bytes(serialized[offset..offset + 8].try_into().unwrap()) as usize;
+        offset += 8;
+
+        for _ in 0..values_len {
+            let value_len =
+                u64::from_le_bytes(serialized[offset..offset + 8].try_into().unwrap()) as usize;
+            offset += 8;
+
+            let value = String::from_utf8(serialized[offset..offset + value_len].to_vec()).unwrap();
+            offset += value_len;
+
+            values.insert(value);
+        }
+
+        let int_range = if offset < serialized.len() {
+            let start = i64::from_le_bytes(serialized[offset..offset + 8].try_into().unwrap());
+            offset += 8;
+            let end = i64::from_le_bytes(serialized[offset..offset + 8].try_into().unwrap());
+            offset += 8;
+
+            Some((start, end))
+        } else {
+            None
+        };
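+        // Note: `serialize` appends the int and float ranges only when they are
+        // `Some`, while deserialization infers presence from the bytes remaining,
+        // so an entry with only a float range would be read back as an int range.
+        // Callers are assumed to populate both ranges together or neither.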
+        let float_range = if offset < serialized.len() {
+            let start = f32::from_le_bytes(serialized[offset..offset + 4].try_into().unwrap());
+            offset += 4;
+            let end = f32::from_le_bytes(serialized[offset..offset + 4].try_into().unwrap());
+            offset += 4;
+
+            Some((start, end))
+        } else {
+            None
+        };
+
+        NodeMetadata {
+            values,
+            int_range,
+            float_range,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct NodeMetadataIndex {
+    pub data: HashMap<String, NodeMetadata>,
+}
+
+impl NodeMetadataIndex {
+    pub fn new() -> Self {
+        NodeMetadataIndex {
+            data: HashMap::new(),
+        }
+    }
+
+    pub fn from_kv_pairs(kv_pairs: Vec<&KVPair>) -> Result<Self, io::Error> {
+        let mut data: HashMap<String, NodeMetadata> = HashMap::new();
+
+        for kv_pair in kv_pairs {
+            // if let Some(result) = tree
+            //     .search(kv_pair.key.clone())
+            //     .expect("Failed to search tree")
+            // {
+            //     let mut node_metadata = result.clone();
+            //     node_metadata.values.insert(kv_pair.value.clone());
+            //     tree.insert(kv_pair.key.clone(), node_metadata);
+            // } else {
+            //     let mut node_metadata = NodeMetadata {
+            //         values: HashSet::new(),
+            //     };
+            //     node_metadata.values.insert(kv_pair.value.clone());
+            //     tree.insert(kv_pair.key.clone(), node_metadata);
+            // }
+
+            match kv_pair.value.clone() {
+                KVValue::String(val) => {
+                    if let Some(result) = data.get(&kv_pair.key) {
+                        let mut node_metadata = result.clone();
+                        node_metadata.values.insert(val.clone());
+                        data.insert(kv_pair.key.clone(), node_metadata);
+                    } else {
+                        let mut node_metadata = NodeMetadata {
+                            values: HashSet::new(),
+                            float_range: None,
+                            int_range: None,
+                        };
+                        node_metadata.values.insert(val.clone());
+                        data.insert(kv_pair.key.clone(), node_metadata);
+                    }
+                }
+
+                KVValue::Integer(val) => {
+                    if let Some(result) = data.get(&kv_pair.key) {
+                        let mut node_metadata = result.clone();
+                        let current_min = node_metadata.int_range.unwrap().0;
+                        let current_max = node_metadata.int_range.unwrap().1;
+                        let new_value = val;
+                        node_metadata.int_range =
+                            Some((current_min.min(new_value), current_max.max(new_value)));
+                        data.insert(kv_pair.key.clone(), node_metadata);
+                    } else {
+                        let node_metadata = NodeMetadata {
+                            values: HashSet::new(),
+                            float_range: None,
+                            int_range: Some((val, val)),
+                        };
+                        data.insert(kv_pair.key.clone(), node_metadata);
+                    }
+                }
+
+                KVValue::Float(val) => {
+                    if let Some(result) = data.get(&kv_pair.key.clone()) {
+                        let mut node_metadata = result.clone();
+                        let current_min = node_metadata.float_range.unwrap().0;
+                        let current_max = node_metadata.float_range.unwrap().1;
+                        let new_value = val;
+                        node_metadata.float_range =
+                            Some((current_min.min(new_value), current_max.max(new_value)));
+                        data.insert(kv_pair.key.clone(), node_metadata);
+                    } else {
+                        let node_metadata = NodeMetadata {
+                            values: HashSet::new(),
+                            float_range: Some((val, val)),
+                            int_range: None,
+                        };
+                        data.insert(kv_pair.key.clone(), node_metadata);
+                    }
+                }
+            }
+        }
+
+        Ok(NodeMetadataIndex { data })
+    }
+
+    pub fn insert_kv_pair(&mut self, kv_pair: &KVPair) {
+        match kv_pair.value.clone() {
+            KVValue::String(val) => {
+                if let Some(result) = self.data.get(&kv_pair.key.clone()) {
+                    let mut node_metadata = result.clone();
+                    node_metadata.values.insert(val.clone());
+                    self.data.insert(kv_pair.key.clone(), node_metadata);
+                } else {
+                    let mut node_metadata = NodeMetadata {
+                        values: HashSet::new(),
+                        float_range: None,
+                        int_range: None,
+                    };
+                    node_metadata.values.insert(val.clone());
+                    self.data.insert(kv_pair.key.clone(),
node_metadata); + } + } + + KVValue::Integer(val) => { + if let Some(result) = self.data.get(&kv_pair.key.clone()) { + let mut node_metadata = result.clone(); + let current_min = node_metadata.int_range.unwrap().0; + let current_max = node_metadata.int_range.unwrap().1; + let new_value = val; + node_metadata.int_range = + Some((current_min.min(new_value), current_max.max(new_value))); + self.data.insert(kv_pair.key.clone(), node_metadata); + } else { + let node_metadata = NodeMetadata { + values: HashSet::new(), + float_range: None, + int_range: Some((val, val)), + }; + self.data.insert(kv_pair.key.clone(), node_metadata); + } + } + + KVValue::Float(val) => { + if let Some(result) = self.data.get(&kv_pair.key.clone()) { + let mut node_metadata = result.clone(); + let current_min = node_metadata.float_range.unwrap().0; + let current_max = node_metadata.float_range.unwrap().1; + let new_value = val; + node_metadata.float_range = + Some((current_min.min(new_value), current_max.max(new_value))); + self.data.insert(kv_pair.key.clone(), node_metadata); + } else { + let node_metadata = NodeMetadata { + values: HashSet::new(), + float_range: Some((val, val)), + int_range: None, + }; + self.data.insert(kv_pair.key.clone(), node_metadata); + } + } + } + } + + pub fn get_all_values(&self) -> Vec<(&String, &NodeMetadata)> { + let mut all_values = Vec::new(); + + for (key, node_metadata) in self.data.iter() { + all_values.push((key, node_metadata)); + } + + all_values + } + + pub fn get(&self, key: String) -> Option<&NodeMetadata> { + self.data.get(&key) + } + + pub fn insert(&mut self, key: String, node_metadata: NodeMetadata) { + self.data.insert(key, node_metadata); + } +} diff --git a/src/structures/ann_tree/node.rs b/src/structures/ann_tree/node.rs new file mode 100644 index 0000000..885a7a6 --- /dev/null +++ b/src/structures/ann_tree/node.rs @@ -0,0 +1,455 @@ +use serde::Serialize; + +use crate::structures::ann_tree::k_modes::{balanced_k_modes, balanced_k_modes_4}; +// use crate::structures::metadata_index::{KVPair, KVValue}; +use crate::structures::filters::{KVPair, KVValue}; +use crate::{constants::QUANTIZED_VECTOR_SIZE, errors::HaystackError}; +use std::fmt::Debug; +use std::hash::Hash; +use std::path::PathBuf; + +use super::k_modes::balanced_k_modes_k_clusters; +use super::metadata::{NodeMetadata, NodeMetadataIndex}; + +use ahash::AHashMap as HashMap; +use ahash::AHashSet as HashSet; + +#[derive(Debug, PartialEq, Clone)] +pub enum NodeType { + Leaf, + Internal, +} + +pub fn serialize_node_type(node_type: &NodeType) -> [u8; 1] { + match node_type { + NodeType::Leaf => [0], + NodeType::Internal => [1], + } +} + +pub fn deserialize_node_type(data: &[u8]) -> NodeType { + match data[0] { + 0 => NodeType::Leaf, + 1 => NodeType::Internal, + _ => panic!("Invalid node type"), + } +} + +const K: usize = crate::constants::K; + +pub type Vector = [u8; QUANTIZED_VECTOR_SIZE]; + +#[derive(Clone)] +pub struct Node { + pub vectors: Vec, + pub ids: Vec, + pub children: Vec, + pub metadata: Vec>, + pub k: usize, + pub node_type: NodeType, + pub offset: usize, + pub is_root: bool, + pub parent_offset: Option, + pub node_metadata: NodeMetadataIndex, +} + +impl Node { + pub fn new_leaf() -> Self { + Node { + vectors: Vec::new(), + ids: Vec::new(), + children: Vec::new(), + metadata: Vec::new(), + k: K, + node_type: NodeType::Leaf, + offset: 0, + is_root: false, + parent_offset: None, + node_metadata: NodeMetadataIndex::new(), + } + } + + pub fn new_internal() -> Self { + Node { + vectors: Vec::new(), + ids: 
Vec::new(), + children: Vec::new(), + metadata: Vec::new(), + k: K, + node_type: NodeType::Internal, + offset: 0, + is_root: false, + parent_offset: None, + node_metadata: NodeMetadataIndex::new(), + } + } + + pub fn split(&mut self) -> Result, HaystackError> { + let k = match self.node_type { + NodeType::Leaf => 2, + NodeType::Internal => 2, + }; + if self.vectors.len() < k { + panic!("Cannot split a node with less than k keys"); + } + + // Assuming a modified balanced_k_modes that returns k clusters of indices + let clusters_indices = balanced_k_modes_k_clusters(self.vectors.clone(), k); + + let mut clusters_vectors = vec![Vec::new(); k]; + let mut clusters_ids = vec![Vec::new(); k]; + let mut clusters_children = vec![Vec::new(); k]; + let mut clusters_metadata = vec![Vec::new(); k]; + + // Distribute vectors, IDs, children, and metadata based on indices from clustering + for (i, indices) in clusters_indices.iter().enumerate() { + if indices.is_empty() { + panic!("Empty cluster found"); + } + for &index in indices { + clusters_vectors[i].push(self.vectors[index].clone()); + if self.node_type == NodeType::Leaf { + clusters_ids[i].push(self.ids[index].clone()); + clusters_metadata[i].push(self.metadata[index].clone()); + } + if self.node_type == NodeType::Internal { + clusters_children[i].push(self.children[index].clone()); + } + } + } + + let mut siblings = Vec::new(); + + // Create sibling nodes for the second, third, ..., k-th clusters + for i in 1..k { + let mut node_metadata = NodeMetadataIndex::new(); + + for kv in &clusters_metadata[i] { + for pair in kv { + node_metadata.insert_kv_pair(pair); + } + } + + let sibling = Node { + vectors: clusters_vectors[i].clone(), + ids: clusters_ids[i].clone(), + children: clusters_children[i].clone(), + metadata: clusters_metadata[i].clone(), + k: self.k, + node_type: self.node_type.clone(), + offset: 0, // This should be set when the node is stored + is_root: false, + parent_offset: self.parent_offset, + node_metadata, + }; + siblings.push(sibling.clone()); + + if sibling.node_type == NodeType::Internal + && (sibling.vectors.len() != sibling.clone().children.len() + || sibling.children.is_empty()) + { + panic!("Internal node vectors and children must be the same length"); + } + } + + // Update the current node with the first cluster + self.vectors = clusters_vectors[0].clone(); + self.ids = clusters_ids[0].clone(); + self.children = clusters_children[0].clone(); + self.metadata = clusters_metadata[0].clone(); + + println!("Node split into {} siblings", siblings.len()); + println!("Current node has {} vectors", self.vectors.len()); + + for sibling in &siblings { + println!("Sibling has {} vectors", sibling.vectors.len()); + } + + let mut node_metadata = NodeMetadataIndex::new(); + + for kv in &self.metadata { + for pair in kv { + node_metadata.insert_kv_pair(pair); + } + } + + self.node_metadata = node_metadata; + + Ok(siblings) + } + + pub fn is_full(&self) -> bool { + return self.vectors.len() >= self.k; + } + + pub fn serialize(&self) -> Vec { + let mut serialized = Vec::new(); + + // Serialize node_type and is_root + serialized.extend_from_slice(&serialize_node_type(&self.node_type)); + serialized.push(self.is_root as u8); + + // Serialize parent_offset + serialized.extend_from_slice(&(self.parent_offset.unwrap_or(0) as i64).to_le_bytes()); + + // Serialize offset + serialized.extend_from_slice(&self.offset.to_le_bytes()); + + // Serialize vectors + serialize_length(&mut serialized, self.vectors.len() as u32); + for vector in &self.vectors { 
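+            // Each quantized vector is a fixed QUANTIZED_VECTOR_SIZE-byte block,
+            // so only the count above needs a length prefix; elements are written
+            // back-to-back.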
+ serialized.extend_from_slice(vector); + } + + // Serialize ids + serialize_length(&mut serialized, self.ids.len() as u32); + for id in &self.ids { + serialized.extend_from_slice(&id.to_le_bytes()); + } + + // Serialize children + serialize_length(&mut serialized, self.children.len() as u32); + for child in &self.children { + serialized.extend_from_slice(&child.to_le_bytes()); + } + + // Serialize metadata + serialize_length(&mut serialized, self.metadata.len() as u32); + for meta in &self.metadata { + // let serialized_meta = serialize_metadata(meta); // Function to serialize a Vec + // serialized.extend_from_slice(&serialized_meta); + serialize_metadata(&mut serialized, meta); + } + + // Serialize node_metadata + serialize_length( + &mut serialized, + self.node_metadata.get_all_values().len() as u32, + ); + for (key, item) in self.node_metadata.get_all_values() { + let serialized_key = key.as_bytes(); + serialize_length(&mut serialized, serialized_key.len() as u32); + serialized.extend_from_slice(serialized_key); + + let values = item.values.clone(); + + serialize_length(&mut serialized, values.len() as u32); + for value in values { + let serialized_value = value.as_bytes(); + serialize_length(&mut serialized, serialized_value.len() as u32); + serialized.extend_from_slice(serialized_value); + } + + let int_range = item.int_range.clone(); + if int_range.is_none() { + serialized.extend_from_slice(&(0 as i64).to_le_bytes()); + serialized.extend_from_slice(&(0 as i64).to_le_bytes()); + } else { + serialized.extend_from_slice(&int_range.unwrap().0.to_le_bytes()); + serialized.extend_from_slice(&int_range.unwrap().1.to_le_bytes()); + } + + let float_range = item.float_range.clone(); + if float_range.is_none() { + serialized.extend_from_slice(&(0 as f32).to_le_bytes()); + serialized.extend_from_slice(&(0 as f32).to_le_bytes()); + } else { + serialized.extend_from_slice(&float_range.unwrap().0.to_le_bytes()); + serialized.extend_from_slice(&float_range.unwrap().1.to_le_bytes()); + } + } + + serialized + } + + pub fn deserialize(data: &[u8]) -> Self { + let mut offset = 0; + + // Deserialize node_type and is_root + let node_type = deserialize_node_type(&data[offset..offset + 1]); + offset += 1; + let is_root = data[offset] == 1; + offset += 1; + + // Deserialize parent_offset + let parent_offset = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + + // Deserialize offset + let node_offset = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + + // Deserialize vectors + let vectors_len = read_length(&data[offset..offset + 4]); + offset += 4; + let mut vectors = Vec::with_capacity(vectors_len); + for _ in 0..vectors_len { + vectors.push( + data[offset..offset + QUANTIZED_VECTOR_SIZE] + .try_into() + .unwrap(), + ); + offset += QUANTIZED_VECTOR_SIZE; + } + + // Deserialize ids + let ids_len = read_length(&data[offset..offset + 4]); + offset += 4; + let mut ids = Vec::with_capacity(ids_len); + for _ in 0..ids_len { + let id = u128::from_le_bytes(data[offset..offset + 16].try_into().unwrap()); + offset += 16; + ids.push(id); + } + + // Deserialize children + let children_len = read_length(&data[offset..offset + 4]); + offset += 4; + let mut children = Vec::with_capacity(children_len); + for _ in 0..children_len { + let child = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + children.push(child); + } + + // Deserialize metadata + let metadata_len = 
read_length(&data[offset..offset + 4]); + offset += 4; + let mut metadata = Vec::with_capacity(metadata_len); + for _ in 0..metadata_len { + let (meta, meta_size) = deserialize_metadata(&data[offset..]); + metadata.push(meta); + offset += meta_size; // Increment offset based on actual size of deserialized metadata + } + + // Deserialize node_metadata + let mut node_metadata = NodeMetadataIndex::new(); + let node_metadata_len = read_length(&data[offset..offset + 4]); + offset += 4; + + for _ in 0..node_metadata_len { + let key_len = read_length(&data[offset..offset + 4]); + offset += 4; + + let key = String::from_utf8(data[offset..offset + key_len as usize].to_vec()).unwrap(); + offset += key_len as usize; + + let mut values = HashSet::new(); + let values_len = read_length(&data[offset..offset + 4]); + offset += 4; + + for _ in 0..values_len { + let value_len = read_length(&data[offset..offset + 4]); + offset += 4; + + let value = + String::from_utf8(data[offset..offset + value_len as usize].to_vec()).unwrap(); + offset += value_len as usize; + + values.insert(value); + } + + let mut item = NodeMetadata { + values: values.clone(), + int_range: None, + float_range: None, + }; + + let min_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + offset += 8; + let max_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + offset += 8; + + let min_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + offset += 4; + let max_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + offset += 4; + + item.int_range = Some((min_int, max_int)); + item.float_range = Some((min_float, max_float)); + + node_metadata.insert(key, item); + } + + Node { + vectors, + ids, + children, + metadata, + k: K, + node_type, + offset: node_offset, + is_root, + parent_offset: if parent_offset != 0 { + Some(parent_offset) + } else { + None + }, + node_metadata, + } + } +} +impl Default for Node { + fn default() -> Self { + Node { + vectors: Vec::new(), + ids: Vec::new(), + children: Vec::new(), + metadata: Vec::new(), + k: K, // Adjust this as necessary + node_type: NodeType::Leaf, // Or another appropriate default NodeType + offset: 0, + is_root: false, + parent_offset: None, + node_metadata: NodeMetadataIndex::new(), + } + } +} + +fn serialize_metadata(serialized: &mut Vec, metadata: &[KVPair]) { + // Serialize the length of the metadata vector + serialize_length(serialized, metadata.len() as u32); + + for kv in metadata { + let serialized_kv = kv.serialize(); // Assuming KVPair has a serialize method that returns Vec + // Serialize the length of this KVPair + serialize_length(serialized, serialized_kv.len() as u32); + // Append the serialized KVPair + serialized.extend_from_slice(&serialized_kv); + } +} + +fn deserialize_metadata(data: &[u8]) -> (Vec, usize) { + let mut offset = 0; + + // Read the length of the metadata vector + let metadata_len = read_length(&data[offset..offset + 4]) as usize; + offset += 4; + + let mut metadata = Vec::with_capacity(metadata_len); + for _ in 0..metadata_len { + // Read the length of the next KVPair + let kv_length = read_length(&data[offset..offset + 4]) as usize; + offset += 4; + + // Deserialize the KVPair from the next segment + let kv = KVPair::deserialize(&data[offset..offset + kv_length]); + offset += kv_length; + + metadata.push(kv); + } + + (metadata, offset) +} + +pub fn serialize_length(buffer: &mut Vec, length: u32) -> &Vec { + buffer.extend_from_slice(&length.to_le_bytes()); + + // Return the 
buffer to allow chaining
+    buffer
+}
+
+pub fn read_length(data: &[u8]) -> usize {
+    u32::from_le_bytes(data.try_into().unwrap()) as usize
+}
diff --git a/src/structures/ann_tree/serialization.rs b/src/structures/ann_tree/serialization.rs
new file mode 100644
index 0000000..9ec9f29
--- /dev/null
+++ b/src/structures/ann_tree/serialization.rs
@@ -0,0 +1,45 @@
+pub trait TreeSerialization {
+    fn serialize(&self) -> Vec<u8>;
+}
+
+pub trait TreeDeserialization {
+    fn deserialize(data: &[u8]) -> Self
+    where
+        Self: Sized;
+}
+
+impl TreeDeserialization for i32 {
+    fn deserialize(data: &[u8]) -> Self {
+        let mut bytes = [0; 4];
+        bytes.copy_from_slice(&data[..4]);
+        i32::from_le_bytes(bytes)
+    }
+}
+
+impl TreeSerialization for i32 {
+    fn serialize(&self) -> Vec<u8> {
+        self.to_le_bytes().to_vec()
+    }
+}
+impl TreeDeserialization for String {
+    fn deserialize(data: &[u8]) -> Self {
+        if data.len() < 4 {
+            panic!("Data too short to contain length prefix");
+        }
+        let len = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize; // Read length
+        if data.len() < 4 + len {
+            panic!("Data too short for specified string length");
+        }
+        let string_data = &data[4..4 + len]; // Extract string data
+        String::from_utf8(string_data.to_vec()).unwrap()
+    }
+}
+
+impl TreeSerialization for String {
+    fn serialize(&self) -> Vec<u8> {
+        let mut data = Vec::new();
+        data.extend_from_slice(&(self.len() as u32).to_le_bytes()); // Write length
+        data.extend_from_slice(self.as_bytes()); // Write string data
+        data
+    }
+}
diff --git a/src/structures/ann_tree/storage.rs b/src/structures/ann_tree/storage.rs
new file mode 100644
index 0000000..aedb814
--- /dev/null
+++ b/src/structures/ann_tree/storage.rs
@@ -0,0 +1,379 @@
+use crate::services::LockService;
+
+use super::node::Node;
+use memmap::MmapMut;
+use std::fs;
+use std::fs::OpenOptions;
+use std::io;
+use std::path::PathBuf;
+
+use super::serialization::{TreeDeserialization, TreeSerialization};
+use std::fmt::Debug;
+
+pub struct StorageManager {
+    pub mmap: MmapMut,
+    pub used_space: usize,
+    path: PathBuf,
+    locks: LockService,
+}
+
+pub const SIZE_OF_USIZE: usize = std::mem::size_of::<usize>();
+pub const HEADER_SIZE: usize = SIZE_OF_USIZE * 2; // Used space + root offset
+
+pub const BLOCK_SIZE: usize = 4096;
+pub const OVERFLOW_POINTER_SIZE: usize = SIZE_OF_USIZE;
+pub const BLOCK_HEADER_SIZE: usize = SIZE_OF_USIZE + 1; // one byte for if it is the primary block or overflow block
+pub const BLOCK_DATA_SIZE: usize = BLOCK_SIZE - OVERFLOW_POINTER_SIZE - BLOCK_HEADER_SIZE;
+
+impl StorageManager {
+    pub fn new(path: PathBuf) -> io::Result<Self> {
+        let exists = path.exists();
+        let file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .create(!exists)
+            .open(path.clone())?;
+
+        if !exists {
+            file.set_len(1_000_000)?;
+        }
+
+        let mmap = unsafe { MmapMut::map_mut(&file)?
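+        // Safety (assumed): the mapping is over a regular file opened read/write
+        // above, and it is recreated rather than reused whenever the file grows
+        // in `resize_mmap`.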
}; + + // take path, remove everything after the last dot (the extension), and add _locks + let mut locks_path = path.clone().to_str().unwrap().to_string(); + let last_dot = locks_path.rfind('.').unwrap(); + locks_path.replace_range(last_dot.., "_locks"); + + fs::create_dir_all(&locks_path).expect("Failed to create directory"); + + let mut manager = StorageManager { + mmap, + used_space: 0, + path, + locks: LockService::new(locks_path.into()), + }; + + let used_space = if exists && manager.mmap.len() > HEADER_SIZE { + manager.used_space() + } else { + 0 + }; + + // println!("INIT Used space: {}", used_space); + + manager.set_used_space(used_space); + + Ok(manager) + } + + pub fn store_node(&mut self, node: &mut Node) -> io::Result { + let serialized = node.serialize(); + + // println!("Storing Serialized len: {}", serialized.len()); + + let serialized_len = serialized.len(); + + let num_blocks_required = (serialized_len + BLOCK_DATA_SIZE - 1) / BLOCK_DATA_SIZE; + + let mut needs_new_blocks = true; + + let mut prev_num_blocks_required = 0; + + if node.offset == 0 { + node.offset = self.increment_and_allocate_block()?; + // println!("Allocating block offset: {}", node.offset); + } else { + // println!("Using previous node offset: {}", node.offset); + let prev_serialized_len = usize::from_le_bytes( + self.read_from_offset(node.offset + 1, SIZE_OF_USIZE) + .try_into() + .unwrap(), + ); + prev_num_blocks_required = + (prev_serialized_len + BLOCK_DATA_SIZE - 1) / BLOCK_DATA_SIZE; + needs_new_blocks = num_blocks_required > prev_num_blocks_required; + + // println!( + // "Prev serialized len: {}, prev num blocks required: {}", + // prev_serialized_len, prev_num_blocks_required + // ); + } + + // println!( + // "Storing node at offset: {}, serialized len: {}", + // node.offset, serialized_len + // ); + + let mut current_block_offset = node.offset.clone(); + + let original_offset = current_block_offset.clone(); + + let mut remaining_bytes_to_write = serialized_len; + + let mut serialized_bytes_written = 0; + + let mut is_primary = 1u8; + + let mut blocks_written = 0; + + // + + // println!( + // "Num blocks required: {}, num blocks prev: {}, needs new blocks: {}", + // num_blocks_required, prev_num_blocks_required, needs_new_blocks + // ); + + self.acquire_block_lock(original_offset)?; + + while remaining_bytes_to_write > 0 { + let bytes_to_write = std::cmp::min(remaining_bytes_to_write, BLOCK_DATA_SIZE); + + // println!( + // "writing is primary: {}, at offset: {}", + // is_primary, current_block_offset + // ); + + self.write_to_offset(current_block_offset, is_primary.to_le_bytes().as_ref()); + + current_block_offset += 1; // one for the primary byte + + self.write_to_offset(current_block_offset, &serialized_len.to_le_bytes()); + + current_block_offset += SIZE_OF_USIZE; + self.write_to_offset( + current_block_offset, + &serialized[serialized_bytes_written..serialized_bytes_written + bytes_to_write], + ); + + blocks_written += 1; + serialized_bytes_written += bytes_to_write; + + remaining_bytes_to_write -= bytes_to_write; + // current_block_offset += BLOCK_DATA_SIZE; + current_block_offset += BLOCK_DATA_SIZE; // Move to the end of written data + + // println!( + // "Remaining bytes to write: {}, bytes written: {}", + // remaining_bytes_to_write, serialized_bytes_written + // ); + + if remaining_bytes_to_write > 0 { + let next_block_offset: usize; + + if needs_new_blocks && blocks_written >= prev_num_blocks_required { + next_block_offset = self.increment_and_allocate_block()?; + + 
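+                    // Chain the freshly allocated overflow block by writing its
+                    // offset into the trailing pointer slot of the block just
+                    // filled; `load_node` follows these pointers until it reads a
+                    // zero terminator.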
self.write_to_offset(current_block_offset, &next_block_offset.to_le_bytes()); + } else { + next_block_offset = usize::from_le_bytes( + self.read_from_offset(current_block_offset, SIZE_OF_USIZE) + .try_into() + .unwrap(), + ); + + // if next_block_offset == 0 { + // next_block_offset = self.increment_and_allocate_block()?; + // println!("allocating bc 0 Next block offset: {}", next_block_offset); + // self.write_to_offset( + // current_block_offset, + // &next_block_offset.to_le_bytes(), + // ); + // } + + // println!("Next block offset: {}", next_block_offset); + } + + current_block_offset = next_block_offset; + } else { + self.write_to_offset(current_block_offset, &0u64.to_le_bytes()); + + // println!( + // "Setting next block offset to 0 at offset: {}", + // current_block_offset + // ); + // // Clear the remaining unused overflow blocks + // let mut next_block_offset = usize::from_le_bytes( + // self.read_from_offset(current_block_offset, SIZE_OF_USIZE) + // .try_into() + // .unwrap(), + // ); + + // while next_block_offset != 0 { + // let next_next_block_offset = usize::from_le_bytes( + // self.read_from_offset(next_block_offset + BLOCK_DATA_SIZE, SIZE_OF_USIZE) + // .try_into() + // .unwrap(), + // ); + + // println!("Clearing next block offset: {}", next_block_offset); + + // self.write_to_offset(next_block_offset + BLOCK_DATA_SIZE, &0u64.to_le_bytes()); + + // next_block_offset = next_next_block_offset; + // } + } + + is_primary = 0; + } + + self.release_block_lock(original_offset)?; + + Ok(node.offset) + } + + pub fn load_node(&self, offset: usize) -> io::Result { + let original_offset = offset.clone(); + let mut offset = offset.clone(); + + // println!("Loading node at offset: {}", offset); + + let mut serialized = Vec::new(); + + let mut is_primary; + + let mut serialized_len; + + let mut bytes_read = 0; + + self.acquire_block_lock(original_offset)?; + + loop { + let block_is_primary = + u8::from_le_bytes(self.read_from_offset(offset, 1).try_into().unwrap()); + + if block_is_primary == 0 { + is_primary = false; + } else if block_is_primary == 1 { + is_primary = true; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid block type", + )); + } + + if !is_primary && bytes_read == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Primary block not found", + )); + } + + offset += 1; // one for the primary byte + + serialized_len = usize::from_le_bytes( + self.read_from_offset(offset, SIZE_OF_USIZE) + .try_into() + .unwrap(), + ); + + offset += SIZE_OF_USIZE; + + // println!("Serialized len: {}", serialized_len); + + let bytes_to_read = std::cmp::min(serialized_len - bytes_read, BLOCK_DATA_SIZE); + // println!( + // "Bytes read: {}, bytes to read: {}", + // bytes_read, bytes_to_read + // ); + + bytes_read += bytes_to_read; + + serialized.extend_from_slice(&self.read_from_offset(offset, bytes_to_read)); + + offset += BLOCK_DATA_SIZE; + + let next_block_offset = usize::from_le_bytes( + self.read_from_offset(offset, SIZE_OF_USIZE) + .try_into() + .unwrap(), + ); + + // println!("Next block offset: {}", next_block_offset); + + if next_block_offset == 0 { + break; + } + + offset = next_block_offset; + } + + self.release_block_lock(original_offset)?; + + let mut node = Node::deserialize(&serialized); + node.offset = original_offset; + + Ok(node) + } + + fn resize_mmap(&mut self) -> io::Result<()> { + let current_len = self.mmap.len(); + let new_len = current_len * 2; + + let file = OpenOptions::new() + .read(true) + .write(true) + 
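+            // Reopening the backing file lets `set_len` below double its size
+            // before a fresh mapping is created over the larger region.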
.open(self.path.clone())?; // Ensure this path is handled correctly + + file.set_len(new_len as u64)?; + + self.mmap = unsafe { MmapMut::map_mut(&file)? }; + Ok(()) + } + + pub fn used_space(&self) -> usize { + usize::from_le_bytes(self.read_from_offset(0, SIZE_OF_USIZE).try_into().unwrap()) + } + + pub fn set_used_space(&mut self, used_space: usize) { + self.write_to_offset(0, &used_space.to_le_bytes()); + } + + pub fn root_offset(&self) -> usize { + usize::from_le_bytes( + self.read_from_offset(SIZE_OF_USIZE, SIZE_OF_USIZE) + .try_into() + .unwrap(), + ) + // self.root_offset + } + + pub fn set_root_offset(&mut self, root_offset: usize) { + self.write_to_offset(SIZE_OF_USIZE, &root_offset.to_le_bytes()); + // self.root_offset = root_offset; + } + + pub fn increment_and_allocate_block(&mut self) -> io::Result { + let used_space = self.used_space(); + // println!("Used space: {}", used_space); + self.set_used_space(used_space + BLOCK_SIZE); + let out = used_space + HEADER_SIZE; + // println!("Allocating block at offset: {}", out); + + if out + BLOCK_SIZE > self.mmap.len() { + self.resize_mmap()?; + } + + Ok(out) + } + + fn write_to_offset(&mut self, offset: usize, data: &[u8]) { + self.mmap[offset..offset + data.len()].copy_from_slice(data); + // self.mmap.flush().unwrap(); + } + + fn read_from_offset(&self, offset: usize, len: usize) -> &[u8] { + &self.mmap[offset..offset + len] + } + + fn acquire_block_lock(&self, offset: usize) -> io::Result<()> { + self.locks.acquire(offset.to_string())?; + Ok(()) + } + + fn release_block_lock(&self, offset: usize) -> io::Result<()> { + self.locks.release(offset.to_string())?; + Ok(()) + } +} diff --git a/src/structures/filters.rs b/src/structures/filters.rs index dda6c10..5ff757e 100644 --- a/src/structures/filters.rs +++ b/src/structures/filters.rs @@ -1,158 +1,590 @@ -use crate::structures::inverted_index::InvertedIndex; -use crate::structures::metadata_index::KVPair; +use ahash::{AHashMap as HashMap, AHashSet as HashSet}; use rayon::prelude::*; -use std::collections::HashSet; - use serde::{Deserialize, Serialize}; +use super::ann_tree::metadata::{NodeMetadata, NodeMetadataIndex}; +use crate::structures::mmap_tree::serialization::{TreeDeserialization, TreeSerialization}; +use std::fmt::Display; +use std::hash::{Hash, Hasher}; + #[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(tag = "type", content = "args")] -pub enum Filter { - And(Vec), - Or(Vec), - In(String, Vec), // Assuming first String is the key and Vec is the list of values - Eq(String, String), // Assuming first String is the key and second String is the value +#[serde(untagged)] +pub enum KVValue { + String(String), + Integer(i64), + Float(f32), } -#[derive(Debug, Serialize, Deserialize)] -pub struct Query { - filters: Filter, +impl Display for KVValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + KVValue::String(s) => write!(f, "{}", s), + KVValue::Integer(i) => write!(f, "{}", i), + KVValue::Float(fl) => write!(f, "{}", fl), + } + } } -pub struct Filters { - pub current_indices: Vec, - pub current_ids: Vec, +impl Hash for KVValue { + fn hash(&self, state: &mut H) { + match self { + KVValue::String(s) => s.hash(state), + KVValue::Integer(i) => i.hash(state), + KVValue::Float(f) => { + let bits: u32 = f.to_bits(); + bits.hash(state); + } + } + } } -impl Filters { - pub fn new(indices: Vec, current_ids: Vec) -> Self { - Filters { - current_indices: indices, - current_ids: current_ids, +impl PartialEq for KVValue { + fn eq(&self, other: 
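+    // Caution: float equality here is approximate (1e-6 tolerance) while `Hash`
+    // hashes exact bit patterns, so `Float` keys that compare equal are assumed
+    // to share identical bits wherever hashing is involved.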
&Self) -> bool { + match (self, other) { + (KVValue::String(s1), KVValue::String(s2)) => s1 == s2, + (KVValue::Integer(i1), KVValue::Integer(i2)) => i1 == i2, + (KVValue::Float(f1), KVValue::Float(f2)) => (f1 - f2).abs() < 1e-6, + _ => false, } } +} + +impl Eq for KVValue {} - pub fn get_indices(&self) -> (Vec, Vec) { - (self.current_indices.clone(), self.current_ids.clone()) +impl PartialOrd for KVValue { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } +} - pub fn set_indices(&mut self, indices: Vec, ids: Vec) { - self.current_indices = indices; - self.current_ids = ids; +impl Ord for KVValue { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match (self, other) { + (KVValue::String(s1), KVValue::String(s2)) => s1.cmp(s2), + (KVValue::Integer(i1), KVValue::Integer(i2)) => i1.cmp(i2), + (KVValue::Float(f1), KVValue::Float(f2)) => f1.partial_cmp(f2).unwrap(), + _ => std::cmp::Ordering::Less, + } } +} - pub fn intersection(&self, other: &Filters) -> Filters { - let intersection_indices: Vec = self - .current_indices - .par_iter() - .filter(|&x| other.current_indices.contains(x)) - .cloned() - .collect(); +#[derive(Debug, Serialize, Deserialize, Clone, Hash)] +pub struct KVPair { + pub key: String, + pub value: KVValue, +} - let intersection_ids: Vec = self - .current_ids - .par_iter() - .filter(|&x| other.current_ids.contains(x)) - .cloned() - .collect(); +impl KVPair { + pub fn new(key: String, value: String) -> Self { + KVPair { + key, + value: KVValue::String(value), + } + } - Filters::new(intersection_indices, intersection_ids) + pub fn new_int(key: String, value: i64) -> Self { + KVPair { + key, + value: KVValue::Integer(value), + } } - pub fn union(&self, other: &Filters) -> Filters { - let mut union_indices = self.current_indices.clone(); - union_indices.extend(other.current_indices.iter().cloned()); - union_indices.sort_unstable(); - union_indices.dedup(); - let mut union_ids = self.current_ids.clone(); - union_ids.extend(other.current_ids.iter().cloned()); - union_ids.sort_unstable(); - union_ids.dedup(); + pub fn new_float(key: String, value: f32) -> Self { + KVPair { + key, + value: KVValue::Float(value), + } + } +} - Filters::new(union_indices, union_ids) +impl PartialEq for KVPair { + fn eq(&self, other: &Self) -> bool { + self.key == other.key && self.value == other.value } +} - pub fn difference(&self, other: &Filters) -> Filters { - let other_indices_set: HashSet<_> = other.current_indices.iter().collect(); - let difference_indices = self - .current_indices - .iter() - .filter(|&x| !other_indices_set.contains(x)) - .cloned() - .collect::>(); +impl Eq for KVPair {} - let other_ids_set: HashSet<_> = other.current_ids.iter().collect(); - let difference_ids = self - .current_ids - .iter() - .filter(|&x| !other_ids_set.contains(x)) - .cloned() - .collect::>(); +impl PartialOrd for KVPair { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} - Filters::new(difference_indices, difference_ids) +impl Ord for KVPair { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.key + .cmp(&other.key) + .then_with(|| self.value.cmp(&other.value)) } +} - pub fn is_subset(&self, other: &Filters) -> bool { - self.current_indices - .par_iter() - .all(|x| other.current_indices.contains(x)) - && self - .current_ids - .par_iter() - .all(|x| other.current_ids.contains(x)) +impl Display for KVPair { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "KVPair {{ key: {}, value: {} }}", self.key, 
self.value) } +} + +impl TreeSerialization for KVPair { + fn serialize(&self) -> Vec { + let mut serialized = Vec::new(); + + serialized.extend_from_slice(self.key.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(self.key.as_bytes()); + // serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); + // serialized.extend_from_slice(self.value.as_bytes()); + + match self.value.clone() { + KVValue::String(s) => { + serialized.push(0); + serialized.extend_from_slice(s.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(s.as_bytes()); + } + KVValue::Integer(i) => { + serialized.push(1); + serialized.extend_from_slice(i.to_le_bytes().as_ref()); + } + KVValue::Float(f) => { + serialized.push(2); + serialized.extend_from_slice(f.to_bits().to_le_bytes().as_ref()); + } + } - pub fn is_superset(&self, other: &Filters) -> bool { - other.is_subset(self) + serialized } +} + +impl KVPair { + pub fn serialize(&self) -> Vec { + let mut serialized = Vec::new(); + + serialized.extend_from_slice(self.key.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(self.key.as_bytes()); + // serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); + // serialized.extend_from_slice(self.value.as_bytes()); - pub fn from_index(index: &mut InvertedIndex, key: &KVPair) -> Self { - match index.get(key.clone()) { - Some(item) => Filters::new(item.indices, item.ids), - None => Filters::new(vec![], vec![]), + match self.value.clone() { + KVValue::String(s) => { + serialized.push(0); + serialized.extend_from_slice(s.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(s.as_bytes()); + } + KVValue::Integer(i) => { + serialized.push(1); + serialized.extend_from_slice(i.to_le_bytes().as_ref()); + } + KVValue::Float(f) => { + serialized.push(2); + serialized.extend_from_slice(f.to_bits().to_le_bytes().as_ref()); + } } + + serialized } - // Evaluate a Filter and return the resulting Filters object - pub fn evaluate(filter: &Filter, index: &mut InvertedIndex) -> Filters { + pub fn deserialize(data: &[u8]) -> Self { + let mut offset = 0; + + let key_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + let key = String::from_utf8(data[offset..offset + key_len].to_vec()).unwrap(); + offset += key_len; + + // let value_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + // offset += 8; + // let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + // // offset += value_len; + + let value_flag = data[offset]; + offset += 1; + + let value = match value_flag { + 0 => { + let value_len = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + KVValue::String(value) + } + 1 => { + let value = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + KVValue::Integer(value) + } + 2 => { + let bits = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + let value = f32::from_bits(bits); + KVValue::Float(value) + } + _ => KVValue::String("".to_string()), + }; + + KVPair { key, value } + } +} + +impl TreeDeserialization for KVPair { + fn deserialize(data: &[u8]) -> Self { + let mut offset = 0; + + let key_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + let key = String::from_utf8(data[offset..offset + key_len].to_vec()).unwrap(); + offset += key_len; + + // let value_len = 
u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + // offset += 8; + // let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + // // offset += value_len; + + let value_flag = data[offset]; + offset += 1; + + let value = match value_flag { + 0 => { + let value_len = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + KVValue::String(value) + } + 1 => { + let value = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + KVValue::Integer(value) + } + 2 => { + let bits = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + let value = f32::from_bits(bits); + KVValue::Float(value) + } + _ => KVValue::String("".to_string()), + }; + + KVPair { key, value } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type", content = "args")] +pub enum Filter { + And(Vec), + Or(Vec), + In(String, Vec), + Eq(String, String), + Gt(String, f64), // Greater than + Gte(String, f64), // Greater than or equal + Lt(String, f64), // Less than + Lte(String, f64), // Less than or equal +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Query { + filters: Filter, +} + +pub struct Filters { + pub metadata: HashMap>, +} + +impl Filters { + pub fn new(metadata: HashMap>) -> Self { + Filters { metadata } + } + + // pub fn matches(&self, filters: &Filter) -> bool { + // match filters { + // Filter::And(filters) => filters.par_iter().all(|f| self.matches(f)), + // Filter::Or(filters) => filters.par_iter().any(|f| self.matches(f)), + // Filter::In(key, values) => match self.metadata.get(key) { + // Some(set) => values.iter().any(|v| set.contains(v)), + // None => false, + // }, + // Filter::Eq(key, value) => match self.metadata.get(key) { + // Some(set) => set.contains(value), + // None => false, + // }, + // } + // } + + pub fn should_prune(filter: &Filter, node_metadata: &NodeMetadataIndex) -> bool { match filter { - Filter::And(filters) => { - let mut result = Filters::new(vec![], vec![]); // Start with an empty set or universal set if applicable - for f in filters.iter() { - let current = Filters::evaluate(f, index); - if result.current_indices.is_empty() && result.current_ids.is_empty() { - result = current; - } else { - result = result.intersection(¤t); + Filter::And(filters) => filters + .par_iter() + .all(|f| Filters::should_prune(f, node_metadata)), + Filter::Or(filters) => filters + .par_iter() + .any(|f| Filters::should_prune(f, node_metadata)), + Filter::In(key, values) => match node_metadata.get(key.to_string()) { + Some(node_values) => !values.par_iter().any(|v| node_values.values.contains(v)), + None => true, + }, + Filter::Eq(key, value) => match node_metadata.get(key.to_string()) { + Some(node_values) => !node_values.values.contains(value), + None => true, + }, + Filter::Gt(key, value) => match node_metadata.get(key.to_string()) { + Some(node_values) => match node_values.float_range { + Some((min, _)) => min > (*value as f32), + None => match node_values.int_range { + Some((min, _)) => min > (*value as i64), + None => true, + }, + }, + None => true, + }, + Filter::Gte(key, value) => match node_metadata.get(key.to_string()) { + Some(node_values) => match node_values.float_range { + Some((min, _)) => min >= (*value as f32), + None => match node_values.int_range { + Some((min, _)) => min >= (*value as i64), + None => true, + }, + }, + None => true, + }, + Filter::Lt(key, 
value) => match node_metadata.get(key.to_string()) { + Some(node_values) => match node_values.float_range { + Some((_, max)) => max < (*value as f32), + None => match node_values.int_range { + Some((_, max)) => max < (*value as i64), + None => true, + }, + }, + None => true, + }, + Filter::Lte(key, value) => match node_metadata.get(key.to_string()) { + Some(node_values) => match node_values.float_range { + Some((_, max)) => max <= (*value as f32), + None => match node_values.int_range { + Some((_, max)) => max <= (*value as i64), + None => true, + }, + }, + None => true, + }, + } + } + + pub fn should_prune_metadata(filter: &Filter, metadata: &Vec) -> bool { + let node_metadata: NodeMetadataIndex = metadata + .into_iter() + .map(|kv_pair| (kv_pair.key.clone(), kv_pair.value.clone())) + .fold(NodeMetadataIndex::new(), |mut acc, (key, value)| { + match acc.get(key.clone()) { + Some(node_values) => match value { + KVValue::String(v) => { + let mut node_values = node_values.clone(); + node_values.values.insert(v); + acc.insert(key, node_values); + } + KVValue::Float(v) => { + let mut node_values = node_values.clone(); + let float_range = match node_values.float_range { + Some((min, max)) => Some((min.min(v), max.max(v))), + None => Some((v, v)), + }; + node_values.float_range = float_range; + acc.insert(key, node_values); + } + KVValue::Integer(v) => { + let mut node_values = node_values.clone(); + let int_range = match node_values.int_range { + Some((min, max)) => Some((min.min(v as i64), max.max(v as i64))), + None => Some((v as i64, v as i64)), + }; + node_values.int_range = int_range; + acc.insert(key, node_values); + } + }, + None => { + let mut node_values = NodeMetadata { + float_range: None, + int_range: None, + values: HashSet::new(), + }; + match value { + KVValue::String(v) => { + node_values.values.insert(v); + } + KVValue::Float(v) => { + node_values.float_range = Some((v, v)); + } + KVValue::Integer(v) => { + node_values.int_range = Some((v as i64, v as i64)); + } + } + + acc.insert(key, node_values); } } - result - } - Filter::Or(filters) => { - let mut result = Filters::new(vec![], vec![]); - for f in filters.iter() { - let current = Filters::evaluate(f, index); - result = result.union(¤t); + + acc + }); + + Filters::should_prune(filter, &node_metadata) + } +} + +pub fn combine_filters(filters: Vec) -> NodeMetadataIndex { + let mut result = NodeMetadataIndex::new(); + + for filter in filters { + for (key, values) in filter.get_all_values() { + match result.get(key.clone()) { + Some(node_values) => { + let mut node_values = node_values.clone(); + node_values.values.extend(values.values.iter().cloned()); + + let float_range = match (node_values.float_range, values.float_range) { + (Some((min1, max1)), Some((min2, max2))) => { + Some((min1.min(min2), max1.max(max2))) + } + (Some((min1, max1)), None) => Some((min1, max1)), + (None, Some((min2, max2))) => Some((min2, max2)), + _ => None, + }; + + let int_range = match (node_values.int_range, values.int_range) { + (Some((min1, max1)), Some((min2, max2))) => { + Some((min1.min(min2), max1.max(max2))) + } + (Some((min1, max1)), None) => Some((min1, max1)), + (None, Some((min2, max2))) => Some((min2, max2)), + _ => None, + }; + + node_values.float_range = float_range; + node_values.int_range = int_range; + + result.insert(key.clone(), node_values); } - result - } - Filter::In(key, values) => { - let mut result = Filters::new(vec![], vec![]); - for value in values.iter() { - let kv_pair = KVPair::new(key.clone(), value.clone()); // 
Ensure correct KVPair creation - let current = Filters::from_index(index, &kv_pair); - result = result.union(¤t); + None => { + let mut node_values = NodeMetadata { + float_range: None, + int_range: None, + values: HashSet::new(), + }; + node_values.values.extend(values.values.iter().cloned()); + node_values.float_range = values.float_range; + node_values.int_range = values.int_range; + + result.insert(key.clone(), node_values); } - result - } - Filter::Eq(key, value) => { - println!( - "Evaluating EQ filter for key: {:?}, value: {:?}", - key, value - ); // Debug output - let kv_pair = KVPair::new(key.clone(), value.clone()); // Ensure correct KVPair creation - Filters::from_index(index, &kv_pair) } } } + + result +} + +pub fn calc_metadata_index_for_metadata(kvs: Vec>) -> NodeMetadataIndex { + let node_metadata: NodeMetadataIndex = kvs + .into_iter() + .map(|metadata| { + metadata + .into_iter() + .map(|kv_pair| (kv_pair.key.clone(), kv_pair.value.clone())) + .fold(NodeMetadataIndex::new(), |mut acc, (key, value)| { + match acc.get(key.clone()) { + Some(node_values) => match value { + KVValue::String(v) => { + let mut node_values = node_values.clone(); + node_values.values.insert(v); + acc.insert(key, node_values); + } + KVValue::Float(v) => { + let mut node_values = node_values.clone(); + let float_range = match node_values.float_range { + Some((min, max)) => Some((min.min(v), max.max(v))), + None => Some((v, v)), + }; + node_values.float_range = float_range; + acc.insert(key, node_values); + } + KVValue::Integer(v) => { + let mut node_values = node_values.clone(); + let int_range = match node_values.int_range { + Some((min, max)) => { + Some((min.min(v as i64), max.max(v as i64))) + } + None => Some((v as i64, v as i64)), + }; + node_values.int_range = int_range; + acc.insert(key, node_values); + } + }, + None => { + let mut node_values = NodeMetadata { + float_range: None, + int_range: None, + values: HashSet::new(), + }; + match value { + KVValue::String(v) => { + node_values.values.insert(v); + } + KVValue::Float(v) => { + node_values.float_range = Some((v, v)); + } + KVValue::Integer(v) => { + node_values.int_range = Some((v as i64, v as i64)); + } + } + + acc.insert(key, node_values); + } + } + + acc + }) + }) + .fold(NodeMetadataIndex::new(), |mut acc, metadata| { + for (key, values) in metadata.get_all_values() { + match acc.get(key.clone()) { + Some(node_values) => { + let mut node_values = node_values.clone(); + node_values.values.extend(values.values.iter().cloned()); + + let float_range = match (node_values.float_range, values.float_range) { + (Some((min1, max1)), Some((min2, max2))) => { + Some((min1.min(min2), max1.max(max2))) + } + (Some((min1, max1)), None) => Some((min1, max1)), + (None, Some((min2, max2))) => Some((min2, max2)), + _ => None, + }; + + let int_range = match (node_values.int_range, values.int_range) { + (Some((min1, max1)), Some((min2, max2))) => { + Some((min1.min(min2), max1.max(max2))) + } + (Some((min1, max1)), None) => Some((min1, max1)), + (None, Some((min2, max2))) => Some((min2, max2)), + _ => None, + }; + + node_values.float_range = float_range; + + node_values.int_range = int_range; + + acc.insert(key.clone(), node_values); + } + None => { + let mut node_values = NodeMetadata { + float_range: None, + int_range: None, + values: HashSet::new(), + }; + node_values.values.extend(values.values.iter().cloned()); + node_values.float_range = values.float_range; + node_values.int_range = values.int_range; + + acc.insert(key.clone(), node_values); + } + } 
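+                // Merging above unions string sets and widens the int/float
+                // ranges, so the outer fold mirrors `combine_filters` while
+                // collapsing the per-record indexes built by the inner fold.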
+ } + + acc + }); + + node_metadata } diff --git a/src/structures/inverted_index.rs b/src/structures/inverted_index.rs index 2671d8f..afdcaff 100644 --- a/src/structures/inverted_index.rs +++ b/src/structures/inverted_index.rs @@ -179,10 +179,10 @@ impl InvertedIndex { // decompress the indices match v { Some(mut item) => { - println!("Search result: {:?}", item); // Add this + // println!("Search result: {:?}", item); // Add this item.indices = decompress_indices(item.indices); - println!("Decompressed indices: {:?}", item.indices); // Check output + // println!("Decompressed indices: {:?}", item.indices); // Check output Some(item) } diff --git a/src/structures/metadata_index.rs b/src/structures/metadata_index.rs index 502c0c2..63d95d9 100644 --- a/src/structures/metadata_index.rs +++ b/src/structures/metadata_index.rs @@ -1,22 +1,100 @@ use std::fmt::Display; use std::hash::Hash; -use std::path::PathBuf; +use std::hash::Hasher; use serde::{Deserialize, Serialize}; -use crate::structures::mmap_tree::Tree; +use crate::structures::tree::Tree; -use super::mmap_tree::serialization::{TreeDeserialization, TreeSerialization}; +use super::tree::serialization::{TreeDeserialization, TreeSerialization}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum KVValue { + String(String), + Integer(i64), + Float(f32), +} + +impl Display for KVValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + KVValue::String(s) => write!(f, "{}", s), + KVValue::Integer(i) => write!(f, "{}", i), + KVValue::Float(fl) => write!(f, "{}", fl), + } + } +} + +impl Hash for KVValue { + fn hash(&self, state: &mut H) { + match self { + KVValue::String(s) => s.hash(state), + KVValue::Integer(i) => i.hash(state), + KVValue::Float(f) => { + let bits: u32 = f.to_bits(); + bits.hash(state); + } + } + } +} + +impl PartialEq for KVValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (KVValue::String(s1), KVValue::String(s2)) => s1 == s2, + (KVValue::Integer(i1), KVValue::Integer(i2)) => i1 == i2, + (KVValue::Float(f1), KVValue::Float(f2)) => (f1 - f2).abs() < 1e-6, + _ => false, + } + } +} + +impl Eq for KVValue {} + +impl PartialOrd for KVValue { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for KVValue { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match (self, other) { + (KVValue::String(s1), KVValue::String(s2)) => s1.cmp(s2), + (KVValue::Integer(i1), KVValue::Integer(i2)) => i1.cmp(i2), + (KVValue::Float(f1), KVValue::Float(f2)) => f1.partial_cmp(f2).unwrap(), + _ => std::cmp::Ordering::Less, + } + } +} #[derive(Debug, Serialize, Deserialize, Clone, Hash)] pub struct KVPair { pub key: String, - pub value: String, + pub value: KVValue, } impl KVPair { pub fn new(key: String, value: String) -> Self { - KVPair { key, value } + KVPair { + key, + value: KVValue::String(value), + } + } + + pub fn new_int(key: String, value: i64) -> Self { + KVPair { + key, + value: KVValue::Integer(value), + } + } + + pub fn new_float(key: String, value: f32) -> Self { + KVPair { + key, + value: KVValue::Float(value), + } } } @@ -54,13 +132,97 @@ impl TreeSerialization for KVPair { serialized.extend_from_slice(self.key.len().to_le_bytes().as_ref()); serialized.extend_from_slice(self.key.as_bytes()); - serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); - serialized.extend_from_slice(self.value.as_bytes()); + // serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); + // 
serialized.extend_from_slice(self.value.as_bytes()); + + match self.value.clone() { + KVValue::String(s) => { + serialized.push(0); + serialized.extend_from_slice(s.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(s.as_bytes()); + } + KVValue::Integer(i) => { + serialized.push(1); + serialized.extend_from_slice(i.to_le_bytes().as_ref()); + } + KVValue::Float(f) => { + serialized.push(2); + serialized.extend_from_slice(f.to_bits().to_le_bytes().as_ref()); + } + } serialized } } +impl KVPair { + pub fn serialize(&self) -> Vec { + let mut serialized = Vec::new(); + + serialized.extend_from_slice(self.key.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(self.key.as_bytes()); + // serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); + // serialized.extend_from_slice(self.value.as_bytes()); + + match self.value.clone() { + KVValue::String(s) => { + serialized.push(0); + serialized.extend_from_slice(s.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(s.as_bytes()); + } + KVValue::Integer(i) => { + serialized.push(1); + serialized.extend_from_slice(i.to_le_bytes().as_ref()); + } + KVValue::Float(f) => { + serialized.push(2); + serialized.extend_from_slice(f.to_bits().to_le_bytes().as_ref()); + } + } + + serialized + } + + pub fn deserialize(data: &[u8]) -> Self { + let mut offset = 0; + + let key_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + let key = String::from_utf8(data[offset..offset + key_len].to_vec()).unwrap(); + offset += key_len; + + // let value_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + // offset += 8; + // let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + // // offset += value_len; + + let value_flag = data[offset]; + offset += 1; + + let value = match value_flag { + 0 => { + let value_len = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + KVValue::String(value) + } + 1 => { + let value = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + KVValue::Integer(value) + } + 2 => { + let bits = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + let value = f32::from_bits(bits); + KVValue::Float(value) + } + _ => KVValue::String("".to_string()), + }; + + KVPair { key, value } + } +} + impl TreeDeserialization for KVPair { fn deserialize(data: &[u8]) -> Self { let mut offset = 0; @@ -70,10 +232,33 @@ impl TreeDeserialization for KVPair { let key = String::from_utf8(data[offset..offset + key_len].to_vec()).unwrap(); offset += key_len; - let value_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); - // offset += value_len; + // let value_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + // offset += 8; + // let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + // // offset += value_len; + + let value_flag = data[offset]; + offset += 1; + + let value = match value_flag { + 0 => { + let value_len = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); + KVValue::String(value) + } + 1 => { + let value = i64::from_le_bytes(data[offset..offset + 
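// The `_ =>` fallback above silently turns an unknown tag byte into an empty
// string, which can mask on-disk corruption. A sketch of the same tagged
// decoding (tag 0 = String, 1 = Integer, 2 = Float) that surfaces it as an
// error instead; KVValue is redeclared locally to keep the example
// self-contained.
#[derive(Debug)]
enum KVValueSketch {
    String(String),
    Integer(i64),
    Float(f32),
}

fn decode_value(tag: u8, payload: &[u8]) -> Result<KVValueSketch, String> {
    match tag {
        0 => {
            let len = u64::from_le_bytes(payload[..8].try_into().unwrap()) as usize;
            let s = String::from_utf8(payload[8..8 + len].to_vec()).map_err(|e| e.to_string())?;
            Ok(KVValueSketch::String(s))
        }
        1 => Ok(KVValueSketch::Integer(i64::from_le_bytes(
            payload[..8].try_into().unwrap(),
        ))),
        2 => Ok(KVValueSketch::Float(f32::from_bits(u32::from_le_bytes(
            payload[..4].try_into().unwrap(),
        )))),
        t => Err(format!("unknown KVValue tag: {}", t)),
    }
}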
8].try_into().unwrap()); + KVValue::Integer(value) + } + 2 => { + let bits = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + let value = f32::from_bits(bits); + KVValue::Float(value) + } + _ => KVValue::String("".to_string()), + }; KVPair { key, value } } @@ -150,8 +335,7 @@ impl TreeDeserialization for MetadataIndexItem { // kvs.push(KVPair { key, value }); - let kv_len = - usize::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + let kv_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; offset += 8; let kv = TreeDeserialization::deserialize(&data[offset..offset + kv_len]); @@ -166,7 +350,8 @@ impl TreeDeserialization for MetadataIndexItem { let id = u128::from_le_bytes(data[offset..offset + 16].try_into().unwrap()); offset += 16; - let vector_index = usize::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + let vector_index = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; // offset += 8; // let namespaced_id_len = @@ -199,14 +384,13 @@ impl TreeDeserialization for u128 { } pub struct MetadataIndex { - pub path: PathBuf, pub tree: Tree, } impl MetadataIndex { - pub fn new(path: PathBuf) -> Self { - let tree = Tree::new(path.clone()).expect("Failed to create tree"); - MetadataIndex { path, tree } + pub fn new() -> Self { + let tree = Tree::new().expect("Failed to create tree"); + MetadataIndex { tree } } pub fn insert(&mut self, key: u128, value: MetadataIndexItem) { @@ -226,4 +410,17 @@ impl MetadataIndex { Err(_) => None, } } + + pub fn len(&self) -> usize { + self.tree.len() + } + + pub fn to_binary(&mut self) -> Vec { + self.tree.to_binary() + } + + pub fn from_binary(data: Vec) -> Self { + let tree = Tree::from_binary(data).expect("Failed to create tree from binary"); + MetadataIndex { tree } + } } diff --git a/src/structures/mmap_tree.rs b/src/structures/mmap_tree.rs index 525754a..a9dffe7 100644 --- a/src/structures/mmap_tree.rs +++ b/src/structures/mmap_tree.rs @@ -172,11 +172,11 @@ where Ok(()) } - pub fn search(&mut self, key: K) -> Result, io::Error> { + pub fn search(&self, key: K) -> Result, io::Error> { self.search_node(self.storage_manager.root_offset(), key) } - fn search_node(&mut self, node_offset: usize, key: K) -> Result, io::Error> { + fn search_node(&self, node_offset: usize, key: K) -> Result, io::Error> { // println!("Searching for key: {} at offset: {}", key, node_offset); let node = self.storage_manager.load_node(node_offset)?; @@ -192,11 +192,11 @@ where } } - pub fn has_key(&mut self, key: K) -> Result { + pub fn has_key(&self, key: K) -> Result { self.has_key_node(self.storage_manager.root_offset(), key) } - pub fn has_key_node(&mut self, node_offset: usize, key: K) -> Result { + pub fn has_key_node(&self, node_offset: usize, key: K) -> Result { let node = self.storage_manager.load_node(node_offset)?; match node.node_type { diff --git a/src/structures/mmap_tree/serialization.rs b/src/structures/mmap_tree/serialization.rs index 9ec9f29..512ed7c 100644 --- a/src/structures/mmap_tree/serialization.rs +++ b/src/structures/mmap_tree/serialization.rs @@ -16,6 +16,32 @@ impl TreeDeserialization for i32 { } } +impl TreeSerialization for u128 { + fn serialize(&self) -> Vec { + self.to_le_bytes().to_vec() + } +} + +impl TreeDeserialization for u128 { + fn deserialize(data: &[u8]) -> Self { + let mut bytes = [0; 16]; + bytes.copy_from_slice(&data[..16]); + u128::from_le_bytes(bytes) + } +} + +impl TreeSerialization for Vec { + fn serialize(&self) -> Vec { + 
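// The switch from usize::from_le_bytes to u64::from_le_bytes above is a
// portability fix: usize is 4 bytes on 32-bit targets, so converting an
// 8-byte slice would panic there. Pinning the on-disk width to u64 keeps the
// format stable across architectures:
fn write_len(buf: &mut Vec<u8>, len: usize) {
    buf.extend_from_slice(&(len as u64).to_le_bytes()); // always 8 bytes
}

fn read_len(data: &[u8], offset: usize) -> usize {
    u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize
}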
self.clone() + } +} + +impl TreeDeserialization for Vec { + fn deserialize(data: &[u8]) -> Self { + data.to_vec() + } +} + impl TreeSerialization for i32 { fn serialize(&self) -> Vec { self.to_le_bytes().to_vec() diff --git a/src/structures/mmap_tree/storage.rs b/src/structures/mmap_tree/storage.rs index a3d0b47..ddee8cf 100644 --- a/src/structures/mmap_tree/storage.rs +++ b/src/structures/mmap_tree/storage.rs @@ -228,7 +228,7 @@ where Ok(node.offset) } - pub fn load_node(&mut self, offset: usize) -> io::Result> { + pub fn load_node(&self, offset: usize) -> io::Result> { let original_offset = offset.clone(); let mut offset = offset.clone(); diff --git a/src/structures/wal.rs b/src/structures/wal.rs index e14a0c6..1400b2b 100644 --- a/src/structures/wal.rs +++ b/src/structures/wal.rs @@ -1,21 +1,26 @@ use crate::constants::{QUANTIZED_VECTOR_SIZE, VECTOR_SIZE}; -use super::{ - metadata_index::KVPair, - mmap_tree::{ - serialization::{TreeDeserialization, TreeSerialization}, - Tree, - }, +use super::mmap_tree::{ + serialization::{TreeDeserialization, TreeSerialization}, + Tree, }; +use crate::structures::filters::KVPair; use crate::utils::quantize; -use std::hash::{Hash, Hasher}; +use chrono::{NaiveDateTime, Utc}; +use rusqlite::{params, Connection, Error, Result, ToSql}; +use serde_json::json; use std::{ fmt::Display, + fs, hash::DefaultHasher, io, path::PathBuf, time::{SystemTime, UNIX_EPOCH}, }; +use std::{ + hash::{Hash, Hasher}, + sync::Arc, +}; #[derive(Debug, Clone)] pub struct CommitListItem { @@ -167,169 +172,152 @@ impl TreeDeserialization for Vec { } pub struct WAL { - pub commit_list: Tree, - pub timestamps: Tree>, // maps a timestamp to a hash - pub commit_finish: Tree, + pub conn: Connection, pub path: PathBuf, pub namespace_id: String, } impl WAL { - pub fn new(path: PathBuf, namespace_id: String) -> io::Result { - let commit_list_path = path.clone().join("commit_list.bin"); - let commit_list = Tree::::new(commit_list_path)?; - let timestamps_path = path.clone().join("timestamps.bin"); - let timestamps = Tree::>::new(timestamps_path)?; - let commit_finish_path = path.clone().join("commit_finish.bin"); - let commit_finish = Tree::::new(commit_finish_path)?; + pub fn new(path: PathBuf, namespace_id: String) -> Result { + let db_path = path.join("wal.db"); + + // Create the directory if it doesn't exist + // fs::create_dir_all(&path).expect("Failed to create directory"); + + let conn = Connection::open(db_path.clone())?; + + // Enable WAL mode + conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?; + + // Create the table if it doesn't exist + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS wal ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + hash INTEGER NOT NULL, + data BLOB NOT NULL, + metadata TEXT NOT NULL, + added_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + committed_timestamp DATETIME + );", + )?; Ok(WAL { - commit_list, - path, + conn, + path: db_path, namespace_id, - timestamps, - commit_finish, }) } + fn u64_to_i64(&self, value: u64) -> i64 { + // Safely convert u64 to i64 by reinterpreting the bits + i64::from_ne_bytes(value.to_ne_bytes()) + } + + fn i64_to_u64(&self, value: i64) -> u64 { + // Safely convert i64 to u64 by reinterpreting the bits + u64::from_ne_bytes(value.to_ne_bytes()) + } + pub fn add_to_commit_list( &mut self, hash: u64, vectors: Vec<[u8; QUANTIZED_VECTOR_SIZE]>, kvs: Vec>, - ) -> Result<(), io::Error> { + ) -> Result<()> { let timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); - let 
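// SQLite INTEGER columns are signed 64-bit, so the helpers above store u64
// hashes by reinterpreting the bits rather than converting the value. The
// round trip is lossless even for hashes above i64::MAX:
fn main() {
    let hash: u64 = u64::MAX - 41; // overflows i64's positive range
    let stored = i64::from_ne_bytes(hash.to_ne_bytes()); // negative in SQLite
    let restored = u64::from_ne_bytes(stored.to_ne_bytes());
    assert_eq!(restored, hash);
}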
commit_list_item = CommitListItem { - hash, - timestamp, - vectors, - kvs, - }; - - self.commit_list.insert(hash, commit_list_item)?; - - // self.commit_finish.insert(hash, false)?; + let metadata = json!(kvs).to_string(); + let data: Vec = vectors.iter().flat_map(|v| v.to_vec()).collect(); - // self.timestamps.insert(timestamp, hash)?; + self.conn.execute( + "INSERT INTO wal (hash, data, metadata, added_timestamp) VALUES (?1, ?2, ?3, ?4);", + params![ + self.u64_to_i64(hash), + &data, + &metadata, + self.u64_to_i64(timestamp) + ], + )?; Ok(()) } - pub fn has_been_committed(&mut self, hash: u64) -> Result { - match self.commit_list.has_key(hash) { - Ok(r) => Ok(r), - Err(_) => Ok(false), - } - } - - // pub fn get_commits_after(&self, timestamp: u64) -> Result, io::Error> { - // let hashes = self.timestamps.get_range(timestamp, u64::MAX)?; - - // let mut commits = Vec::new(); - - // for (_, hash) in hashes { - // match self.commit_list.search(hash) { - // Ok(commit) => match commit { - // Some(c) => { - // commits.push(c); - // } - // None => {} - // }, - // Err(_) => {} - // } - // } - - // Ok(commits) - // } - - pub fn get_commits(&mut self) -> Result, io::Error> { - let start = 0; - let end = u64::MAX; - - let commits = self - .commit_list - .get_range(start, end) - .expect("Error getting commits"); - - Ok(commits.into_iter().map(|(_, v)| v).collect()) - } - - pub fn get_commit(&mut self, hash: u64) -> Result, io::Error> { - match self.commit_list.search(hash) { - Ok(v) => Ok(v), - Err(_) => Ok(None), - } + pub fn has_been_committed(&mut self, hash: u64) -> Result { + let mut stmt = self + .conn + .prepare("SELECT 1 FROM wal WHERE hash = ?1 AND committed_timestamp IS NOT NULL;")?; + let mut rows = stmt.query(params![hash])?; + Ok(rows.next()?.is_some()) } - pub fn get_commits_before(&mut self, timestamp: u64) -> Result, io::Error> { - let hash_end = self.timestamps.get_range(0, timestamp)?; + pub fn get_commits(&mut self) -> Result> { + let mut stmt = self.conn.prepare( + "SELECT hash, data, metadata, added_timestamp, committed_timestamp FROM wal;", + )?; + let rows = stmt.query_map(params![], |row| { + let data: Vec = row.get(1)?; + let vectors = data + .chunks(QUANTIZED_VECTOR_SIZE) + .map(|chunk| { + let mut arr = [0; QUANTIZED_VECTOR_SIZE]; + arr.copy_from_slice(chunk); + arr + }) + .collect(); + + Ok(CommitListItem { + hash: self.i64_to_u64(row.get(0)?), + timestamp: self.i64_to_u64(row.get(3)?), + vectors, + kvs: serde_json::from_str(&row.get::<_, String>(2)?).unwrap(), + }) + })?; let mut commits = Vec::new(); - - for (_, hash) in hash_end { - for h in hash { - match self.commit_list.search(h) { - Ok(commit) => match commit { - Some(c) => { - commits.push(c); - } - None => {} - }, - Err(_) => {} - } - } + for commit in rows { + commits.push(commit?); } - // println!("Commits before: {:?}", commits.len()); - Ok(commits) } - pub fn get_uncommitted(&mut self, last_seconds: u64) -> Result, io::Error> { - let start = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - - last_seconds; - - let end = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - + 1; - - let all_hashes = self.timestamps.get_range(start, end)?; - - let mut commits = Vec::new(); - - for (_, hashes) in all_hashes { - for hash in hashes { - match self.commit_finish.has_key(hash) { - Ok(has_key) => { - if !has_key { - match self.commit_list.search(hash) { - Ok(commit) => match commit { - Some(c) => { - commits.push(c); - } - None => {} - }, - Err(_) => {} - } - } - } - Err(_) 
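// One subtlety in the lookup above: rows are inserted with the hash passed
// through u64_to_i64, so binding the raw u64 only matches while the hash
// fits in i64; above i64::MAX the bound value and the stored value diverge.
// A hypothetical helper sketch with the same conversion applied on the read
// side as well:
fn has_been_committed_sketch(conn: &rusqlite::Connection, hash: u64) -> rusqlite::Result<bool> {
    let stored = i64::from_ne_bytes(hash.to_ne_bytes()); // same mapping as u64_to_i64
    let mut stmt =
        conn.prepare("SELECT 1 FROM wal WHERE hash = ?1 AND committed_timestamp IS NOT NULL;")?;
    let mut rows = stmt.query(rusqlite::params![stored])?;
    Ok(rows.next()?.is_some())
}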
=> {} - } - } + pub fn get_commit(&mut self, hash: u64) -> Result> { + let mut stmt = self.conn.prepare("SELECT hash, data, metadata, added_timestamp, committed_timestamp FROM wal WHERE hash = ?1;")?; + let mut rows = stmt.query(params![hash])?; + + if let Some(row) = rows.next()? { + let data: Vec = row.get(1)?; + let vectors = data + .chunks(QUANTIZED_VECTOR_SIZE) + .map(|chunk| { + let mut arr = [0; QUANTIZED_VECTOR_SIZE]; + arr.copy_from_slice(chunk); + arr + }) + .collect(); + + return Ok(Some(CommitListItem { + hash: self.i64_to_u64(row.get(0)?), + timestamp: self.i64_to_u64(row.get(3)?), + vectors, + kvs: serde_json::from_str(&row.get::<_, String>(2)?).unwrap(), + })); } - // commits.dedup_by_key(|c| c.hash); + Ok(None) + } + + pub fn mark_commit_finished(&mut self, hash: u64) -> Result<()> { + let committed_timestamp = Utc::now().naive_utc(); + self.conn.execute( + "UPDATE wal SET committed_timestamp = ?1 WHERE hash = ?2;", + params![committed_timestamp.to_string(), self.u64_to_i64(hash)], + )?; - Ok(commits) + Ok(()) } pub fn compute_hash( @@ -337,15 +325,9 @@ impl WAL { vectors: &Vec<[u8; QUANTIZED_VECTOR_SIZE]>, kvs: &Vec>, ) -> u64 { - let mut hasher = DefaultHasher::default(); - - // for vector in vectors { - // vector.hash(&mut hasher); - // } + let mut hasher = DefaultHasher::new(); vectors.hash(&mut hasher); - kvs.hash(&mut hasher); - hasher.finish() } @@ -353,92 +335,101 @@ impl WAL { &mut self, vectors: Vec<[f32; VECTOR_SIZE]>, kvs: Vec>, - ) -> io::Result<()> { - if vectors.len() != kvs.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Quantized vectors length mismatch", - )); - } - + ) -> Result<()> { let quantized_vectors: Vec<[u8; QUANTIZED_VECTOR_SIZE]> = vectors.iter().map(|v| quantize(v)).collect(); - let hash = self.compute_hash(&quantized_vectors, &kvs); - - let current_timestamp = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - // println!("Current timestamp: {}", current_timestamp); - - let mut current_timestamp_vals = match self.timestamps.search(current_timestamp) { - Ok(v) => v, - Err(_) => Some(Vec::new()), - } - .unwrap_or(Vec::new()); - - current_timestamp_vals.push(hash); - - self.timestamps - .insert(current_timestamp, current_timestamp_vals)?; - - self.add_to_commit_list(hash, quantized_vectors, kvs)?; - - Ok(()) + self.add_to_commit_list(hash, quantized_vectors, kvs) } pub fn batch_add_to_wal( &mut self, vectors: Vec>, kvs: Vec>>, - ) -> io::Result<()> { - if vectors.len() != kvs.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Quantized vectors length mismatch", - )); - } - + ) -> Result<()> { let quantized_vectors: Vec> = vectors .iter() .map(|v| v.iter().map(|v| quantize(v)).collect()) .collect(); - let mut hashes = Vec::new(); + for (v, k) in quantized_vectors.iter().zip(kvs.iter()) { + let hash = self.compute_hash(v, k); + self.add_to_commit_list(hash, v.clone(), k.clone())?; + } - let current_timestamp = SystemTime::now() + Ok(()) + } + + pub fn get_uncommitted(&mut self, last_seconds: u64) -> Result> { + let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); - let mut current_timestamp_vals = match self.timestamps.search(current_timestamp) { - Ok(v) => v, - Err(_) => Some(Vec::new()), - } - .unwrap_or(Vec::new()); + let start_time = self.u64_to_i64(current_time.saturating_sub(last_seconds)); + + let mut stmt = self.conn.prepare( + "SELECT hash, data, metadata, added_timestamp, committed_timestamp + FROM wal + WHERE added_timestamp >= ?1 AND 
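// The WAL flattens every quantized vector into one BLOB column and rebuilds
// them by chunking on the fixed vector width, as in the query_map above. A
// self-contained sketch with a width of 4 standing in for
// QUANTIZED_VECTOR_SIZE:
const W: usize = 4;

fn flatten(vectors: &[[u8; W]]) -> Vec<u8> {
    vectors.iter().flat_map(|v| v.to_vec()).collect()
}

fn unflatten(blob: &[u8]) -> Vec<[u8; W]> {
    assert!(blob.len() % W == 0, "blob must hold whole vectors");
    blob.chunks(W)
        .map(|chunk| {
            let mut arr = [0u8; W];
            arr.copy_from_slice(chunk); // panics if the chunk is short
            arr
        })
        .collect()
}

fn main() {
    let vs = vec![[1, 2, 3, 4], [5, 6, 7, 8]];
    assert_eq!(unflatten(&flatten(&vs)), vs);
}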
committed_timestamp IS NULL;", + )?; + let rows = stmt.query_map(params![start_time], |row| { + let data: Vec = row.get(1)?; + let vectors = data + .chunks(QUANTIZED_VECTOR_SIZE) + .map(|chunk| { + let mut arr = [0; QUANTIZED_VECTOR_SIZE]; + arr.copy_from_slice(chunk); + arr + }) + .collect(); + + Ok(CommitListItem { + hash: self.i64_to_u64(row.get(0)?), + timestamp: self.i64_to_u64(row.get(3)?), + vectors, + kvs: serde_json::from_str(&row.get::<_, String>(2)?).unwrap(), + }) + })?; - for (_i, (v, k)) in quantized_vectors.iter().zip(kvs.iter()).enumerate() { - let hash = self.compute_hash(v, k); - hashes.push(hash); - - current_timestamp_vals.push(hash); - } - - self.timestamps - .insert(current_timestamp, current_timestamp_vals)?; - - for (hash, (v, k)) in hashes.iter().zip(quantized_vectors.iter().zip(kvs.iter())) { - self.add_to_commit_list(*hash, v.clone(), k.clone())?; + let mut commits = Vec::new(); + for commit in rows { + commits.push(commit?); } - Ok(()) + Ok(commits) } - pub fn mark_commit_finished(&mut self, hash: u64) -> io::Result<()> { - self.commit_finish.insert(hash, true)?; + pub fn get_commits_before(&self, ts: u64) -> Result> { + let mut stmt = self.conn.prepare( + "SELECT hash, data, metadata, added_timestamp, committed_timestamp + FROM wal + WHERE added_timestamp < ?1 AND committed_timestamp IS NOT NULL;", + )?; + let rows = stmt.query_map(params![self.u64_to_i64(ts)], |row| { + let data: Vec = row.get(1)?; + let vectors = data + .chunks(QUANTIZED_VECTOR_SIZE) + .map(|chunk| { + let mut arr = [0; QUANTIZED_VECTOR_SIZE]; + arr.copy_from_slice(chunk); + arr + }) + .collect(); + + Ok(CommitListItem { + hash: self.i64_to_u64(row.get(0)?), + timestamp: self.i64_to_u64(row.get(3)?), + vectors, + kvs: serde_json::from_str(&row.get::<_, String>(2)?).unwrap(), + }) + })?; - Ok(()) + let mut commits = Vec::new(); + for commit in rows { + commits.push(commit?); + } + + Ok(commits) } } From 789e8deb24e1ec98e27e53f45a1bcfc29078299e Mon Sep 17 00:00:00 2001 From: Carson Poole Date: Wed, 22 May 2024 11:55:44 -0400 Subject: [PATCH 2/6] fix paths --- src/main.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/main.rs b/src/main.rs index f3ee897..7485d77 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,7 +36,7 @@ async fn main() { .and(with_active_namespaces(active_namespaces.clone())) .then( |namespace_id: String, body: (Vec, QueryFilter, usize), active_namespaces| async move { - let base_path = PathBuf::from(format!("/Users/carsonpoole/haystackdb/tests/data/{}/current", namespace_id.clone())); + let base_path = PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone())); ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone()) .await; @@ -73,7 +73,7 @@ async fn main() { body: (Vec, Vec, String), active_namespaces| async move { let base_path = PathBuf::from(format!( - "/Users/carsonpoole/haystackdb/tests/data/{}/current", + "/workspace/data/{}/current", namespace_id.clone() )); @@ -107,10 +107,8 @@ async fn main() { .then( |namespace_id: String, timestamp: String, active_namespaces| async move { println!("PITR for namespace: {}", namespace_id); - let base_path = PathBuf::from(format!( - "/Users/carsonpoole/haystackdb/tests/data/{}/current", - namespace_id.clone() - )); + let base_path = + PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone())); ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone()) .await; From 75c61721e6401817fad893eba639665427d4cf64 Mon Sep 17 
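// The path fix in the patch above swaps one hardcoded data root for another.
// A common follow-up is reading the root from the environment; this sketch
// is purely hypothetical (HAYSTACK_DATA is not a variable these patches
// introduce):
use std::path::PathBuf;

fn base_path(namespace_id: &str) -> PathBuf {
    let root = std::env::var("HAYSTACK_DATA").unwrap_or_else(|_| "/workspace/data".to_string());
    PathBuf::from(root).join(namespace_id).join("current")
}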
00:00:00 2001 From: Carson Poole Date: Thu, 23 May 2024 03:11:28 +0000 Subject: [PATCH 3/6] lots of bug fixes and redesigned storage engine --- Cargo.lock | 37 ++ Cargo.toml | 1 + src/constants.rs | 2 +- src/services/commit.rs | 105 ++++-- src/services/query.rs | 15 +- src/structures/ann_tree.rs | 2 +- src/structures/ann_tree/node.rs | 65 +++- src/structures/ann_tree/storage.rs | 531 ++++++++++++++-------------- src/structures/filters.rs | 4 +- src/structures/mmap_tree.rs | 225 ++++++------ src/structures/mmap_tree/storage.rs | 2 +- src/structures/wal.rs | 2 +- src/utils.rs | 26 ++ 13 files changed, 588 insertions(+), 429 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 526317c..871f4df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -156,6 +171,27 @@ dependencies = [ "generic-array", ] +[[package]] +name = "brotli" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6221fe77a248b9117d431ad93761222e1cf8ff282d9d1d5d9f53d6299a1cf76" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -574,6 +610,7 @@ name = "haystackdb" version = "0.1.0" dependencies = [ "ahash", + "brotli", "chrono", "criterion", "env_logger", diff --git a/Cargo.toml b/Cargo.toml index 1d32712..5da202f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ ahash = "0.8.11" rand = "0.8.4" rusqlite = "0.31.0" chrono = "0.4" +brotli = "6.0.0" [profile.release] opt-level = 3 diff --git a/src/constants.rs b/src/constants.rs index 89968db..efd3c16 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,6 +1,6 @@ pub const VECTOR_SIZE: usize = 1024; pub const QUANTIZED_VECTOR_SIZE: usize = 128; -pub const K: usize = 512; +pub const K: usize = 128; pub const C: usize = 1; pub const ALPHA: usize = 64; pub const BETA: usize = 3; diff --git a/src/services/commit.rs b/src/services/commit.rs index 515fc77..d2abe40 100644 --- a/src/services/commit.rs +++ b/src/services/commit.rs @@ -2,6 +2,7 @@ use crate::constants::VECTOR_SIZE; // use crate::structures::inverted_index::InvertedIndexItem; // use crate::structures::metadata_index::{KVPair, MetadataIndexItem}; use crate::structures::filters::{KVPair, KVValue}; +use crate::utils::compress_string; use rusqlite::Result; use super::namespace_state::NamespaceState; @@ -28,46 +29,86 @@ impl CommitService { println!("Commits to process: {:?}", commits.len()); - let mut vectors = Vec::new(); - let mut kvs = Vec::new(); - let mut ids = Vec::new(); + // let mut vectors = Vec::new(); + // let mut kvs = Vec::new(); + // let mut ids = Vec::new(); + + // for commit in commits.iter() { + // let inner_vectors = commit.vectors.clone(); + // let inner_kvs = 
commit.kvs.clone(); + // let inner_ids: Vec = inner_vectors + // .iter() + // .map(|_| uuid::Uuid::new_v4().as_u128()) + // .collect(); + + // for ((vector, kv), id) in inner_vectors + // .iter() + // .zip(inner_kvs.iter()) + // .zip(inner_ids.iter()) + // { + // vectors.push(vector.clone()); + // kvs.push(kv.clone()); + // ids.push(id.clone()); + // } + // } - for commit in commits.iter() { - let inner_vectors = commit.vectors.clone(); - let inner_kvs = commit.kvs.clone(); - let inner_ids: Vec = inner_vectors - .iter() - .map(|_| uuid::Uuid::new_v4().as_u128()) - .collect(); + // self.state.vectors.bulk_insert(vectors, ids, kvs); - for ((vector, kv), id) in inner_vectors + for commit in commits { + let vectors = commit.vectors; + let kvs: Vec> = commit + .kvs + .clone() .iter() - .zip(inner_kvs.iter()) - .zip(inner_ids.iter()) - { - vectors.push(vector.clone()); - kvs.push(kv.clone()); - ids.push(id.clone()); - } - } - - self.state.vectors.bulk_insert(vectors, ids, kvs); + .map(|kv| { + kv.clone() + .iter() + .filter(|item| item.key != "text") + .cloned() + .collect() + }) + .collect::>(); + + let texts: Vec = commit + .kvs + .clone() + .iter() + .map(|kv| { + kv.clone() + .iter() + .filter(|item| item.key == "text") + .collect::>() + .first() + .unwrap() + .value + .clone() + }) + .collect::>(); + + println!("Processing commit: {:?}", processed); - // for commit in commits { - // // let vectors = commit.vectors; - // // let kvs = commit.kvs; + processed += 1; - // // println!("Processing commit: {:?}", processed); - // // processed += 1; + for ((vector, kv), texts) in vectors.iter().zip(kvs).zip(texts) { + let id = uuid::Uuid::new_v4().as_u128(); - // // for (vector, kv) in vectors.iter().zip(kvs.iter()) { - // // let id = uuid::Uuid::new_v4().as_u128(); + self.state.vectors.insert(vector.clone(), id, kv.clone()); + // self.state.texts.insert(id, texts.clone()); + match texts { + KVValue::String(text) => { + self.state + .texts + .insert(id, compress_string(&text)) + .expect("Failed to insert text"); + } + _ => {} + } + } - // // self.state.vectors.insert(vector.clone(), id, kv.clone()); - // // } + self.state.wal.mark_commit_finished(commit.hash)?; + } - // self.state.wal.mark_commit_finished(commit.hash)?; - // } + self.state.vectors.true_calibrate(); Ok(()) } diff --git a/src/services/query.rs b/src/services/query.rs index 90d1553..614d2d4 100644 --- a/src/services/query.rs +++ b/src/services/query.rs @@ -4,7 +4,7 @@ use super::namespace_state::NamespaceState; use crate::constants::VECTOR_SIZE; use crate::math::hamming_distance; use crate::structures::filters::{Filter, KVPair}; -use crate::utils::quantize; +use crate::utils::{decompress_string, quantize}; use std::io; use std::path::PathBuf; @@ -33,10 +33,21 @@ impl QueryService { .vectors .search(quantized_query_vector, top_k, filters) .iter() - .map(|(_, metadata)| { + .map(|(id, metadata)| { // let mut metadata = metadata.clone(); // metadata.push(KVPair::new("id".to_string(), id.to_string())); + // let text = self + // .state + // .texts + // .search(*id) + // .unwrap() + // .expect("Text not found"); + + // let mut metadata = metadata.clone(); + + // metadata.push(KVPair::new("text".to_string(), decompress_string(&text))); + metadata.clone() }) .collect(); diff --git a/src/structures/ann_tree.rs b/src/structures/ann_tree.rs index 0b76de4..f191511 100644 --- a/src/structures/ann_tree.rs +++ b/src/structures/ann_tree.rs @@ -265,7 +265,7 @@ impl ANNTree { let sibling_offsets: Vec = siblings .iter_mut() .map(|sibling| { - 
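// Sketch of the partition the commit loop above performs: the "text" pair is
// peeled off for the compressed text store and everything else remains
// filterable metadata. Types are local stand-ins, and returning Option avoids
// the unwrap that panics on commits without a "text" field:
#[derive(Clone, Debug)]
enum KVValueSketch {
    String(String),
    Integer(i64),
    Float(f32),
}

#[derive(Clone, Debug)]
struct KVPairSketch {
    key: String,
    value: KVValueSketch,
}

fn split_text(kvs: Vec<KVPairSketch>) -> (Vec<KVPairSketch>, Option<String>) {
    let mut text = None;
    let metadata = kvs
        .into_iter()
        .filter(|kv| {
            if kv.key == "text" {
                if let KVValueSketch::String(s) = &kv.value {
                    text = Some(s.clone());
                    return false; // dropped from metadata, kept as text payload
                }
            }
            true
        })
        .collect();
    (metadata, text)
}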
sibling.parent_offset = Some(current_node.parent_offset.unwrap()); + sibling.parent_offset = current_node.parent_offset; sibling.node_metadata = self.compute_node_metadata(sibling); self.storage_manager.store_node(sibling).unwrap() }) diff --git a/src/structures/ann_tree/node.rs b/src/structures/ann_tree/node.rs index 885a7a6..06dd643 100644 --- a/src/structures/ann_tree/node.rs +++ b/src/structures/ann_tree/node.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::structures::ann_tree::k_modes::{balanced_k_modes, balanced_k_modes_4}; // use crate::structures::metadata_index::{KVPair, KVValue}; -use crate::structures::filters::{KVPair, KVValue}; +use crate::structures::filters::{calc_metadata_index_for_metadata, KVPair, KVValue}; use crate::{constants::QUANTIZED_VECTOR_SIZE, errors::HaystackError}; use std::fmt::Debug; use std::hash::Hash; @@ -339,10 +339,27 @@ impl Node { let values_len = read_length(&data[offset..offset + 4]); offset += 4; - for _ in 0..values_len { + for idx in 0..values_len { let value_len = read_length(&data[offset..offset + 4]); offset += 4; + if value_len > data.len() - offset { + println!("Current IDX: {}", idx); + println!("Value length: {}", value_len); + println!("Value len binary: {:?}", (value_len as u32).to_le_bytes()); + println!("Data length: {}", data.len()); + // add some more debug prints for the current state of things to figure out where it's going wrong + + println!("Offset: {}", offset); + println!("Key: {}", key); + println!("Values: {:?}", values); + println!("Values len: {}", values_len); + println!("Node metadata: {:?}", node_metadata); + println!("Node metadata len: {}", node_metadata_len); + + panic!("Value length exceeds data length"); + } + let value = String::from_utf8(data[offset..offset + value_len as usize].to_vec()).unwrap(); offset += value_len as usize; @@ -453,3 +470,47 @@ pub fn serialize_length(buffer: &mut Vec, length: u32) -> &Vec { pub fn read_length(data: &[u8]) -> usize { u32::from_le_bytes(data.try_into().unwrap()) as usize } + +fn random_string(len: usize) -> String { + use rand::distributions::Alphanumeric; + use rand::{thread_rng, Rng}; + + String::from_utf8_lossy( + thread_rng() + .sample_iter(&Alphanumeric) + .take(len) + .collect::>() + .as_slice(), + ) + .to_string() +} + +#[test] +fn test_serialize_deserialize() { + let mut node = Node::new_leaf(); + let mut vectors = Vec::new(); + let mut ids = Vec::new(); + let mut kvs = Vec::new(); + + for _ in 0..96 { + let vector: [u8; 128] = [0; 128]; + vectors.push(vector); + ids.push(0); + kvs.push(vec![ + KVPair::new("key".to_string(), random_string(77)), + KVPair::new("url".to_string(), random_string(44)), + KVPair::new("name".to_string(), random_string(17)), + ]); + } + + node.vectors = vectors; + node.ids = ids; + node.metadata = kvs.clone(); + node.node_metadata = calc_metadata_index_for_metadata(kvs.clone()); + + let serialized = node.serialize(); + + let deserialized = Node::deserialize(&serialized); + + // assert_eq!(node, deserialized); +} diff --git a/src/structures/ann_tree/storage.rs b/src/structures/ann_tree/storage.rs index aedb814..932e39d 100644 --- a/src/structures/ann_tree/storage.rs +++ b/src/structures/ann_tree/storage.rs @@ -1,15 +1,11 @@ -use crate::services::LockService; - use super::node::Node; +use crate::services::LockService; use memmap::MmapMut; use std::fs; use std::fs::OpenOptions; use std::io; use std::path::PathBuf; -use super::serialization::{TreeDeserialization, TreeSerialization}; -use std::fmt::Debug; - pub struct StorageManager { pub mmap: 
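// The guard above turns a corrupt length prefix into a loud panic with debug
// context. An alternative sketch that reports the same condition as an error
// instead of panicking mid-deserialization:
fn read_string(data: &[u8], offset: &mut usize) -> Result<String, String> {
    let len = u32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap()) as usize;
    *offset += 4;
    if len > data.len() - *offset {
        return Err(format!(
            "value length {} exceeds remaining {} bytes",
            len,
            data.len() - *offset
        ));
    }
    let s = String::from_utf8(data[*offset..*offset + len].to_vec()).map_err(|e| e.to_string())?;
    *offset += len;
    Ok(s)
}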
MmapMut, pub used_space: usize, @@ -17,13 +13,45 @@ pub struct StorageManager { locks: LockService, } -pub const SIZE_OF_USIZE: usize = std::mem::size_of::(); -pub const HEADER_SIZE: usize = SIZE_OF_USIZE * 2; // Used space + root offset +/* + + Schema for Storage Manager: + + - Header: + - Used blocks (u64) + - Root index (u64) + - Blocks: + - Block Header: + - Is primary (u8) + - Index in chain (u64) + - Primary index (u64) + - Next block index (u64) + - Previous block index (u64) + - Serialized node length (u64) + + - Data: + - Node data + +*/ + +pub const SIZE_OF_U64: usize = std::mem::size_of::(); +pub const HEADER_SIZE: usize = SIZE_OF_U64 * 2; // Used space + root offset + +pub const BLOCK_SIZE: usize = 1024; +pub const BLOCK_HEADER_SIZE: usize = SIZE_OF_U64 * 5 + 1; // Index in chain + Primary index + Next block offset + Previous block offset + Serialized node length + Is primary +pub const BLOCK_DATA_SIZE: usize = BLOCK_SIZE - BLOCK_HEADER_SIZE; + +#[derive(Debug, Clone)] +pub struct BlockHeaderData { + pub is_primary: bool, + pub index_in_chain: u64, + pub primary_index: u64, + pub next_block_offset: u64, + pub previous_block_offset: u64, + pub serialized_node_length: u64, +} -pub const BLOCK_SIZE: usize = 4096; -pub const OVERFLOW_POINTER_SIZE: usize = SIZE_OF_USIZE; -pub const BLOCK_HEADER_SIZE: usize = SIZE_OF_USIZE + 1; // one byte for if it is the primary block or overflow block -pub const BLOCK_DATA_SIZE: usize = BLOCK_SIZE - OVERFLOW_POINTER_SIZE - BLOCK_HEADER_SIZE; +impl Copy for BlockHeaderData {} impl StorageManager { pub fn new(path: PathBuf) -> io::Result { @@ -54,326 +82,301 @@ impl StorageManager { locks: LockService::new(locks_path.into()), }; - let used_space = if exists && manager.mmap.len() > HEADER_SIZE { - manager.used_space() - } else { - 0 - }; + // let used_blocks = if exists && manager.mmap.len() > HEADER_SIZE { + // manager.used_blocks() + // } else { + // 0 + // }; - // println!("INIT Used space: {}", used_space); - - manager.set_used_space(used_space); + // manager.set_used_blocks(used_blocks); Ok(manager) } - pub fn store_node(&mut self, node: &mut Node) -> io::Result { - let serialized = node.serialize(); - - // println!("Storing Serialized len: {}", serialized.len()); - - let serialized_len = serialized.len(); + pub fn used_blocks(&self) -> usize { + u64::from_le_bytes(self.mmap[0..SIZE_OF_U64].try_into().unwrap()) as usize + } - let num_blocks_required = (serialized_len + BLOCK_DATA_SIZE - 1) / BLOCK_DATA_SIZE; + pub fn set_used_blocks(&mut self, used_blocks: usize) { + self.mmap[0..SIZE_OF_U64].copy_from_slice(&(used_blocks as u64).to_le_bytes()); + } - let mut needs_new_blocks = true; + pub fn root_offset(&self) -> usize { + u64::from_le_bytes( + self.mmap[SIZE_OF_U64..(2 * SIZE_OF_U64)] + .try_into() + .unwrap(), + ) as usize + } - let mut prev_num_blocks_required = 0; + pub fn set_root_offset(&mut self, root_offset: usize) { + self.mmap[SIZE_OF_U64..(2 * SIZE_OF_U64)] + .copy_from_slice(&(root_offset as u64).to_le_bytes()); + } - if node.offset == 0 { - node.offset = self.increment_and_allocate_block()?; - // println!("Allocating block offset: {}", node.offset); - } else { - // println!("Using previous node offset: {}", node.offset); - let prev_serialized_len = usize::from_le_bytes( - self.read_from_offset(node.offset + 1, SIZE_OF_USIZE) - .try_into() - .unwrap(), - ); - prev_num_blocks_required = - (prev_serialized_len + BLOCK_DATA_SIZE - 1) / BLOCK_DATA_SIZE; - needs_new_blocks = num_blocks_required > prev_num_blocks_required; + pub fn 
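// Offset arithmetic implied by the schema comment above: block i starts at
// HEADER_SIZE + i * BLOCK_SIZE, a 41-byte header (1 flag byte plus five u64
// fields) precedes each data region, and node storage is sized in whole
// blocks:
const SIZE_OF_U64: usize = 8;
const HEADER_SIZE: usize = SIZE_OF_U64 * 2;
const BLOCK_SIZE: usize = 1024;
const BLOCK_HEADER_SIZE: usize = SIZE_OF_U64 * 5 + 1; // 41
const BLOCK_DATA_SIZE: usize = BLOCK_SIZE - BLOCK_HEADER_SIZE; // 983

fn data_range(index: usize) -> (usize, usize) {
    let start = HEADER_SIZE + index * BLOCK_SIZE + BLOCK_HEADER_SIZE;
    (start, start + BLOCK_DATA_SIZE)
}

fn blocks_required(serialized_len: usize) -> usize {
    // ceiling division: a 2000-byte node needs 3 blocks of 983 data bytes
    (serialized_len + BLOCK_DATA_SIZE - 1) / BLOCK_DATA_SIZE
}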
increment_and_allocate_block(&mut self) -> usize { + // self.mmap.flush().unwrap(); + let used_blocks = self.used_blocks(); + self.set_used_blocks(used_blocks + 1); + // self.mmap.flush().unwrap(); - // println!( - // "Prev serialized len: {}, prev num blocks required: {}", - // prev_serialized_len, prev_num_blocks_required - // ); + if (used_blocks + 1) * BLOCK_SIZE > self.mmap.len() { + self.resize_mmap().unwrap(); } - // println!( - // "Storing node at offset: {}, serialized len: {}", - // node.offset, serialized_len - // ); + // println!("Allocated block at index {}", used_blocks); - let mut current_block_offset = node.offset.clone(); + used_blocks + } - let original_offset = current_block_offset.clone(); + fn resize_mmap(&mut self) -> io::Result<()> { + println!("Resizing mmap"); + let current_len = self.mmap.len(); + let new_len = current_len * 2; - let mut remaining_bytes_to_write = serialized_len; + let file = OpenOptions::new() + .read(true) + .write(true) + .open(self.path.clone())?; + + file.set_len(new_len as u64)?; - let mut serialized_bytes_written = 0; + self.mmap = unsafe { MmapMut::map_mut(&file)? }; + Ok(()) + } - let mut is_primary = 1u8; + pub fn get_block_header_data(&self, index: usize) -> BlockHeaderData { + let start = HEADER_SIZE + index * BLOCK_SIZE; + let end = start + BLOCK_HEADER_SIZE; + + let is_primary = self.mmap[start] == 1; + let index_in_chain = + u64::from_le_bytes(self.mmap[start + 1..start + 9].try_into().unwrap()); + let primary_index = + u64::from_le_bytes(self.mmap[start + 9..start + 17].try_into().unwrap()); + let next_block_offset = + u64::from_le_bytes(self.mmap[start + 17..start + 25].try_into().unwrap()); + let previous_block_offset = + u64::from_le_bytes(self.mmap[start + 25..start + 33].try_into().unwrap()); + let serialized_node_length = + u64::from_le_bytes(self.mmap[start + 33..end].try_into().unwrap()); + + BlockHeaderData { + is_primary, + index_in_chain, + primary_index, + next_block_offset, + previous_block_offset, + serialized_node_length, + } + } - let mut blocks_written = 0; + pub fn set_block_header_data(&mut self, index: usize, data: BlockHeaderData) { + let start = HEADER_SIZE + index * BLOCK_SIZE; + + self.mmap[start] = data.is_primary as u8; + self.mmap[start + 1..start + 9].copy_from_slice(&data.index_in_chain.to_le_bytes()); + self.mmap[start + 9..start + 17].copy_from_slice(&data.primary_index.to_le_bytes()); + self.mmap[start + 17..start + 25].copy_from_slice(&data.next_block_offset.to_le_bytes()); + self.mmap[start + 25..start + 33] + .copy_from_slice(&data.previous_block_offset.to_le_bytes()); + self.mmap[start + 33..start + 41] + .copy_from_slice(&data.serialized_node_length.to_le_bytes()); + } - // + pub fn get_block_bytes(&self, index: usize) -> &[u8] { + let start = HEADER_SIZE + index * BLOCK_SIZE + BLOCK_HEADER_SIZE; + let end = start + BLOCK_DATA_SIZE; - // println!( - // "Num blocks required: {}, num blocks prev: {}, needs new blocks: {}", - // num_blocks_required, prev_num_blocks_required, needs_new_blocks - // ); + &self.mmap[start..end] + } - self.acquire_block_lock(original_offset)?; + pub fn store_block(&mut self, index: usize, data: &[u8]) { + let start = HEADER_SIZE + index * BLOCK_SIZE + BLOCK_HEADER_SIZE; - while remaining_bytes_to_write > 0 { - let bytes_to_write = std::cmp::min(remaining_bytes_to_write, BLOCK_DATA_SIZE); + self.mmap[start..start + data.len()].copy_from_slice(data); + } - // println!( - // "writing is primary: {}, at offset: {}", - // is_primary, current_block_offset - // ); + pub fn 
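// Pack/unpack sketch for the block header the accessors above read and
// write: one flag byte followed by five little-endian u64 fields, at the
// same byte offsets as get_block_header_data / set_block_header_data.
#[derive(Debug, Clone, Copy, PartialEq)]
struct HeaderSketch {
    is_primary: bool,
    index_in_chain: u64,
    primary_index: u64,
    next_block_offset: u64,
    previous_block_offset: u64,
    serialized_node_length: u64,
}

fn pack(h: &HeaderSketch) -> [u8; 41] {
    let mut out = [0u8; 41];
    out[0] = h.is_primary as u8;
    out[1..9].copy_from_slice(&h.index_in_chain.to_le_bytes());
    out[9..17].copy_from_slice(&h.primary_index.to_le_bytes());
    out[17..25].copy_from_slice(&h.next_block_offset.to_le_bytes());
    out[25..33].copy_from_slice(&h.previous_block_offset.to_le_bytes());
    out[33..41].copy_from_slice(&h.serialized_node_length.to_le_bytes());
    out
}

fn unpack(b: &[u8; 41]) -> HeaderSketch {
    HeaderSketch {
        is_primary: b[0] == 1,
        index_in_chain: u64::from_le_bytes(b[1..9].try_into().unwrap()),
        primary_index: u64::from_le_bytes(b[9..17].try_into().unwrap()),
        next_block_offset: u64::from_le_bytes(b[17..25].try_into().unwrap()),
        previous_block_offset: u64::from_le_bytes(b[25..33].try_into().unwrap()),
        serialized_node_length: u64::from_le_bytes(b[33..41].try_into().unwrap()),
    }
}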
store_node(&mut self, node: &mut Node) -> io::Result { + let serialized = node.serialize(); + let serialized_len = serialized.len() as u64; + let blocks_required = + ((serialized_len + BLOCK_DATA_SIZE as u64 - 1) / BLOCK_DATA_SIZE as u64) as usize; - self.write_to_offset(current_block_offset, is_primary.to_le_bytes().as_ref()); + // Allocate new block if this is a new node + let mut current_block_index = if node.offset == 0 { + self.increment_and_allocate_block() + } else { + node.offset + }; - current_block_offset += 1; // one for the primary byte + // Initialize writing state + let mut remaining_bytes_to_write = serialized_len; + let mut bytes_written = 0; + let mut prev_block_index = 0; + + let original_block_index = current_block_index; + + self.acquire_lock(original_block_index)?; + + // Clear previous overflow chain if it exists + let mut used_blocks = Vec::new(); + if node.offset != 0 { + let mut temp_block_index = node.offset; + while temp_block_index != 0 { + let block_header = self.get_block_header_data(temp_block_index); + used_blocks.push(temp_block_index); + temp_block_index = block_header.next_block_offset as usize; + if used_blocks.len() >= blocks_required { + break; + } + } + } - self.write_to_offset(current_block_offset, &serialized_len.to_le_bytes()); + // Clear excess blocks if the node is smaller + if used_blocks.len() > blocks_required { + for &index in &used_blocks[blocks_required..] { + self.clear_block(index); + } + used_blocks.truncate(blocks_required); + } - current_block_offset += SIZE_OF_USIZE; - self.write_to_offset( - current_block_offset, - &serialized[serialized_bytes_written..serialized_bytes_written + bytes_to_write], + // Write new node data into blocks + for i in 0..blocks_required { + let bytes_to_write = std::cmp::min(remaining_bytes_to_write as usize, BLOCK_DATA_SIZE); + self.store_block( + current_block_index, + &serialized[bytes_written as usize..bytes_written + bytes_to_write], ); - blocks_written += 1; - serialized_bytes_written += bytes_to_write; + let next_block_index = if remaining_bytes_to_write > bytes_to_write as u64 { + if i + 1 < used_blocks.len() { + used_blocks[i + 1] + } else { + self.increment_and_allocate_block() + } + } else { + 0 + }; + + let block_header = BlockHeaderData { + is_primary: i == 0, + index_in_chain: i as u64, + primary_index: original_block_index as u64, + next_block_offset: next_block_index as u64, + previous_block_offset: if i == 0 { 0 } else { prev_block_index as u64 }, + serialized_node_length: serialized_len, + }; - remaining_bytes_to_write -= bytes_to_write; - // current_block_offset += BLOCK_DATA_SIZE; - current_block_offset += BLOCK_DATA_SIZE; // Move to the end of written data + self.set_block_header_data(current_block_index, block_header); + // Debug statements // println!( - // "Remaining bytes to write: {}, bytes written: {}", - // remaining_bytes_to_write, serialized_bytes_written + // "Block {}: is_primary={}, index_in_chain={}, primary_index={}, next_block_offset={}, previous_block_offset={}, serialized_node_length={}", + // current_block_index, + // block_header.is_primary, + // block_header.index_in_chain, + // block_header.primary_index, + // block_header.next_block_offset, + // block_header.previous_block_offset, + // block_header.serialized_node_length // ); - if remaining_bytes_to_write > 0 { - let next_block_offset: usize; - - if needs_new_blocks && blocks_written >= prev_num_blocks_required { - next_block_offset = self.increment_and_allocate_block()?; - - 
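// Core of the rewrite above: an existing node keeps its overflow chain, so a
// rewrite first walks next_block_offset links (0 acts as the null terminator)
// and reuses those blocks before allocating new ones. Sketch with a slice of
// next-links standing in for the block headers:
fn collect_chain(next_of: &[u64], primary: usize, blocks_required: usize) -> Vec<usize> {
    let mut chain = Vec::new();
    let mut idx = primary;
    while idx != 0 && chain.len() < blocks_required {
        chain.push(idx);
        idx = next_of[idx] as usize; // follow the header's next_block_offset
    }
    chain // old blocks past blocks_required get cleared; missing ones allocated
}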
self.write_to_offset(current_block_offset, &next_block_offset.to_le_bytes()); - } else { - next_block_offset = usize::from_le_bytes( - self.read_from_offset(current_block_offset, SIZE_OF_USIZE) - .try_into() - .unwrap(), - ); - - // if next_block_offset == 0 { - // next_block_offset = self.increment_and_allocate_block()?; - // println!("allocating bc 0 Next block offset: {}", next_block_offset); - // self.write_to_offset( - // current_block_offset, - // &next_block_offset.to_le_bytes(), - // ); - // } - - // println!("Next block offset: {}", next_block_offset); - } + prev_block_index = current_block_index; + current_block_index = next_block_index; + remaining_bytes_to_write -= bytes_to_write as u64; + bytes_written += bytes_to_write; + } - current_block_offset = next_block_offset; - } else { - self.write_to_offset(current_block_offset, &0u64.to_le_bytes()); - - // println!( - // "Setting next block offset to 0 at offset: {}", - // current_block_offset - // ); - // // Clear the remaining unused overflow blocks - // let mut next_block_offset = usize::from_le_bytes( - // self.read_from_offset(current_block_offset, SIZE_OF_USIZE) - // .try_into() - // .unwrap(), - // ); - - // while next_block_offset != 0 { - // let next_next_block_offset = usize::from_le_bytes( - // self.read_from_offset(next_block_offset + BLOCK_DATA_SIZE, SIZE_OF_USIZE) - // .try_into() - // .unwrap(), - // ); - - // println!("Clearing next block offset: {}", next_block_offset); - - // self.write_to_offset(next_block_offset + BLOCK_DATA_SIZE, &0u64.to_le_bytes()); - - // next_block_offset = next_next_block_offset; - // } - } + if bytes_written != serialized_len as usize { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Bytes written does not match serialized length", + )); + } - is_primary = 0; + if remaining_bytes_to_write != 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Remaining bytes to write is not 0", + )); } - self.release_block_lock(original_offset)?; + self.release_lock(original_block_index)?; + node.offset = original_block_index; + self.mmap.flush()?; Ok(node.offset) } - pub fn load_node(&self, offset: usize) -> io::Result { - let original_offset = offset.clone(); - let mut offset = offset.clone(); - - // println!("Loading node at offset: {}", offset); + fn clear_block(&mut self, index: usize) { + let start = HEADER_SIZE + index * BLOCK_SIZE; + let end = start + BLOCK_SIZE; + self.mmap[start..end].fill(0); + // Debug statement + println!("Cleared block at index {}", index); + } + pub fn load_node(&self, offset: usize) -> io::Result { + let mut current_block_index = offset; let mut serialized = Vec::new(); - let mut is_primary; - - let mut serialized_len; - - let mut bytes_read = 0; - - self.acquire_block_lock(original_offset)?; + self.acquire_lock(offset)?; loop { - let block_is_primary = - u8::from_le_bytes(self.read_from_offset(offset, 1).try_into().unwrap()); + let block_header = self.get_block_header_data(current_block_index); - if block_is_primary == 0 { - is_primary = false; - } else if block_is_primary == 1 { - is_primary = true; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid block type", - )); - } - - if !is_primary && bytes_read == 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Primary block not found", - )); + if block_header.is_primary { + serialized = Vec::with_capacity(block_header.serialized_node_length as usize); } - offset += 1; // one for the primary byte + let data = self.get_block_bytes(current_block_index); + 
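// Read-side counterpart of the writer above: follow next links from the
// primary block, concatenate each data region, and validate against the
// declared node length. Blocks are stubbed here; the real code reads
// BlockHeaderData and the mmap directly.
struct BlockSketch<'a> {
    next: usize,     // 0 terminates the chain
    node_len: usize, // serialized_node_length (meaningful on the primary)
    data: &'a [u8],  // this block's data region
}

fn load(blocks: &[BlockSketch], primary: usize) -> Result<Vec<u8>, &'static str> {
    let mut out = Vec::new();
    let mut idx = primary;
    loop {
        out.extend_from_slice(blocks[idx].data);
        if blocks[idx].next == 0 {
            break;
        }
        idx = blocks[idx].next;
    }
    if out.len() < blocks[primary].node_len {
        return Err("serialized node length does not match actual length");
    }
    out.truncate(blocks[primary].node_len); // drop the last block's padding
    Ok(out)
}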
serialized.extend_from_slice(data); - serialized_len = usize::from_le_bytes( - self.read_from_offset(offset, SIZE_OF_USIZE) - .try_into() - .unwrap(), - ); - - offset += SIZE_OF_USIZE; - - // println!("Serialized len: {}", serialized_len); - - let bytes_to_read = std::cmp::min(serialized_len - bytes_read, BLOCK_DATA_SIZE); // println!( - // "Bytes read: {}, bytes to read: {}", - // bytes_read, bytes_to_read + // "LOADING Block {}: is_primary={}, index_in_chain={}, primary_index={}, next_block_offset={}, previous_block_offset={}, serialized_node_length={}", + // current_block_index, + // block_header.is_primary, + // block_header.index_in_chain, + // block_header.primary_index, + // block_header.next_block_offset, + // block_header.previous_block_offset, + // block_header.serialized_node_length // ); - bytes_read += bytes_to_read; - - serialized.extend_from_slice(&self.read_from_offset(offset, bytes_to_read)); - - offset += BLOCK_DATA_SIZE; - - let next_block_offset = usize::from_le_bytes( - self.read_from_offset(offset, SIZE_OF_USIZE) - .try_into() - .unwrap(), - ); + if block_header.next_block_offset == 0 { + if serialized.len() < block_header.serialized_node_length as usize { + println!("Serialized node length does not match actual length"); + println!("Serialized length: {}", serialized.len()); + println!( + "Actual length: {}", + block_header.serialized_node_length as usize + ); - // println!("Next block offset: {}", next_block_offset); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Serialized node length does not match actual length", + )); + } - if next_block_offset == 0 { break; } - offset = next_block_offset; + current_block_index = block_header.next_block_offset as usize; } - self.release_block_lock(original_offset)?; + self.release_lock(offset)?; let mut node = Node::deserialize(&serialized); - node.offset = original_offset; - + node.offset = offset; Ok(node) } - fn resize_mmap(&mut self) -> io::Result<()> { - let current_len = self.mmap.len(); - let new_len = current_len * 2; - - let file = OpenOptions::new() - .read(true) - .write(true) - .open(self.path.clone())?; // Ensure this path is handled correctly - - file.set_len(new_len as u64)?; - - self.mmap = unsafe { MmapMut::map_mut(&file)? 
}; - Ok(()) - } - - pub fn used_space(&self) -> usize { - usize::from_le_bytes(self.read_from_offset(0, SIZE_OF_USIZE).try_into().unwrap()) - } - - pub fn set_used_space(&mut self, used_space: usize) { - self.write_to_offset(0, &used_space.to_le_bytes()); - } - - pub fn root_offset(&self) -> usize { - usize::from_le_bytes( - self.read_from_offset(SIZE_OF_USIZE, SIZE_OF_USIZE) - .try_into() - .unwrap(), - ) - // self.root_offset - } - - pub fn set_root_offset(&mut self, root_offset: usize) { - self.write_to_offset(SIZE_OF_USIZE, &root_offset.to_le_bytes()); - // self.root_offset = root_offset; - } - - pub fn increment_and_allocate_block(&mut self) -> io::Result { - let used_space = self.used_space(); - // println!("Used space: {}", used_space); - self.set_used_space(used_space + BLOCK_SIZE); - let out = used_space + HEADER_SIZE; - // println!("Allocating block at offset: {}", out); - - if out + BLOCK_SIZE > self.mmap.len() { - self.resize_mmap()?; - } - - Ok(out) - } - - fn write_to_offset(&mut self, offset: usize, data: &[u8]) { - self.mmap[offset..offset + data.len()].copy_from_slice(data); - // self.mmap.flush().unwrap(); - } - - fn read_from_offset(&self, offset: usize, len: usize) -> &[u8] { - &self.mmap[offset..offset + len] + pub fn acquire_lock(&self, index: usize) -> io::Result<()> { + self.locks.acquire(index.to_string()) } - fn acquire_block_lock(&self, offset: usize) -> io::Result<()> { - self.locks.acquire(offset.to_string())?; - Ok(()) - } - - fn release_block_lock(&self, offset: usize) -> io::Result<()> { - self.locks.release(offset.to_string())?; - Ok(()) + pub fn release_lock(&self, index: usize) -> io::Result<()> { + self.locks.release(index.to_string()) } } diff --git a/src/structures/filters.rs b/src/structures/filters.rs index 5ff757e..ede8f42 100644 --- a/src/structures/filters.rs +++ b/src/structures/filters.rs @@ -158,7 +158,7 @@ impl KVPair { pub fn serialize(&self) -> Vec { let mut serialized = Vec::new(); - serialized.extend_from_slice(self.key.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(self.key.as_bytes().len().to_le_bytes().as_ref()); serialized.extend_from_slice(self.key.as_bytes()); // serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); // serialized.extend_from_slice(self.value.as_bytes()); @@ -166,7 +166,7 @@ impl KVPair { match self.value.clone() { KVValue::String(s) => { serialized.push(0); - serialized.extend_from_slice(s.len().to_le_bytes().as_ref()); + serialized.extend_from_slice(s.as_bytes().len().to_le_bytes().as_ref()); serialized.extend_from_slice(s.as_bytes()); } KVValue::Integer(i) => { diff --git a/src/structures/mmap_tree.rs b/src/structures/mmap_tree.rs index a9dffe7..86402b2 100644 --- a/src/structures/mmap_tree.rs +++ b/src/structures/mmap_tree.rs @@ -10,6 +10,8 @@ use node::{Node, NodeType}; use serialization::{TreeDeserialization, TreeSerialization}; use storage::StorageManager; +use crate::errors::HaystackError; + pub struct Tree { pub b: usize, pub storage_manager: storage::StorageManager, @@ -42,47 +44,112 @@ where pub fn insert(&mut self, key: K, value: V) -> Result<(), io::Error> { // println!("Inserting key: {}, value: {}", key, value); - let mut root = self - .storage_manager - .load_node(self.storage_manager.root_offset())?; - - // println!("Root offset: {}, {}", self.root_offset, root.offset); - - if root.is_full() { - // println!("Root is full, needs splitting"); - let mut new_root = Node::new_internal(0); - new_root.is_root = true; - let (median, mut sibling) = root.split(self.b)?; - // 
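// On the filters.rs tweak above: for Rust strings, len() already counts
// bytes, not characters, so key.len() and key.as_bytes().len() are always
// equal; the rewrite only makes the byte-length intent explicit.
fn main() {
    let key = "héllo"; // 5 characters, 6 bytes (é is 2 bytes in UTF-8)
    assert_eq!(key.len(), key.as_bytes().len());
    assert_eq!(key.len(), 6);
    assert_eq!(key.chars().count(), 5);
}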
println!("Root split: median = {}, new sibling created", median); - // println!("Root split: median = {}, new sibling created", median); - root.is_root = false; - self.storage_manager.store_node(&mut root)?; - // println!("Root stored"); - let sibling_offset = self.storage_manager.store_node(&mut sibling)?; - new_root.keys.push(median); - new_root.children.push(self.storage_manager.root_offset()); // old root offset - new_root.children.push(sibling_offset); // new sibling offset - new_root.is_root = true; - self.storage_manager.store_node(&mut new_root)?; - self.storage_manager.set_root_offset(new_root.offset); - - root.parent_offset = Some(new_root.offset); - sibling.parent_offset = Some(new_root.offset); - self.storage_manager.store_node(&mut root)?; - self.storage_manager.store_node(&mut sibling)?; - // println!( - // "New root created with children offsets: {} and {}", - // self.root_offset, sibling_offset - // ); + let vals = vec![(key, value)]; + + self.batch_insert(vals) + } + + pub fn batch_insert(&mut self, entries: Vec<(K, V)>) -> Result<(), io::Error> { + if entries.is_empty() { + println!("No entries to insert"); + return Ok(()); } - // println!("Inserting into non-full root"); - self.insert_non_full(self.storage_manager.root_offset(), key, value, 0)?; - // println!("Inserted key, root offset: {}", self.root_offset); + let mut entries = entries; + entries.sort_by(|a, b| a.0.cmp(&b.0)); + + let entrypoint = self + .find_entrypoint(entries[0].0.clone()) + .expect("Failed to find entrypoint"); + + let mut current_offset = entrypoint; + let mut node = self.storage_manager.load_node(current_offset)?; + + for (key, value) in entries.iter() { + while node.node_type == NodeType::Internal { + // We should only be operating on leaf nodes in this loop + let idx = node.keys.binary_search(key).unwrap_or_else(|x| x); + current_offset = node.children[idx]; + node = self.storage_manager.load_node(current_offset)?; + } + + if node.is_full() { + let (median, mut sibling) = node.split(self.b)?; + let sibling_offset = self.storage_manager.store_node(&mut sibling)?; + self.storage_manager.store_node(&mut node)?; // Store changes to the original node after splitting + + if node.is_root { + // println!("Creating new root"); + // Create a new root if the current node is the root + let mut new_root = Node::new_internal(0); + new_root.is_root = true; + new_root.keys.push(median.clone()); + new_root.children.push(current_offset); // old root offset + new_root.children.push(sibling_offset); // new sibling offset + new_root.parent_offset = None; + let new_root_offset = self.storage_manager.store_node(&mut new_root)?; + self.storage_manager.set_root_offset(new_root_offset); + node.is_root = false; + node.parent_offset = Some(new_root_offset); + sibling.parent_offset = Some(new_root_offset); + // println!("New root offset: {}", new_root_offset); + self.storage_manager.store_node(&mut node)?; + self.storage_manager.store_node(&mut sibling)?; + } else { + // Update the parent node with the new median + let parent_offset = node.parent_offset.unwrap(); + // println!("Parent offset: {}", parent_offset); + let mut parent = self.storage_manager.load_node(parent_offset)?; + let idx = parent + .keys + .binary_search(&median.clone()) + .unwrap_or_else(|x| x); + parent.keys.insert(idx, median.clone()); + parent.children.insert(idx + 1, sibling_offset); + self.storage_manager.store_node(&mut parent)?; + } + + // Decide which node to continue insertion + if *key >= median { + current_offset = sibling_offset; + node = 
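// The descend and insert steps above lean on the same idiom: for a sorted
// key slice, binary_search returns Ok(i) on a hit and Err(i) with the
// insertion point on a miss, so unwrap_or_else(|x| x) collapses both into
// the child index to follow:
fn child_index<K: Ord>(keys: &[K], key: &K) -> usize {
    keys.binary_search(key).unwrap_or_else(|i| i)
}

fn main() {
    let keys = vec![10, 20, 30];
    assert_eq!(child_index(&keys, &5), 0); // left of every key
    assert_eq!(child_index(&keys, &20), 1); // exact hit
    assert_eq!(child_index(&keys, &25), 2); // between 20 and 30
    assert_eq!(child_index(&keys, &99), 3); // rightmost child
}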
sibling; + } + } + + // Insert the key into the correct leaf node + let position = node.keys.binary_search(key).unwrap_or_else(|x| x); + + if node.keys.get(position) == Some(&key) { + node.values[position] = Some(value.clone()); + } else { + node.keys.insert(position, key.clone()); + node.values.insert(position, Some(value.clone())); + } + self.storage_manager.store_node(&mut node)?; // Store changes after each insertion + } Ok(()) } + fn find_entrypoint(&mut self, key: K) -> Result { + let mut current_offset = self.storage_manager.root_offset(); + let mut node = self + .storage_manager + .load_node(current_offset) + .expect("Failed to load node"); + + while node.node_type == NodeType::Internal { + let idx = node.keys.binary_search(&key).unwrap_or_else(|x| x); + current_offset = node.children[idx] as usize; + node = self + .storage_manager + .load_node(current_offset) + .expect("Failed to load node"); + } + + Ok(current_offset) + } + fn insert_non_full( &mut self, node_offset: usize, @@ -265,92 +332,4 @@ where Ok(()) } - - pub fn batch_insert(&mut self, entries: Vec<(K, V)>) -> Result<(), io::Error> { - if entries.is_empty() { - println!("No entries to insert"); - return Ok(()); - } - - let mut entries = entries; - entries.sort_by(|a, b| a.0.cmp(&b.0)); - - let entrypoint = self.find_entrypoint(entries[0].0.clone())?; - - let mut current_offset = entrypoint; - let mut node = self.storage_manager.load_node(current_offset)?; - - for (key, value) in entries.iter() { - while node.node_type == NodeType::Internal { - // We should only be operating on leaf nodes in this loop - let idx = node.keys.binary_search(key).unwrap_or_else(|x| x); - current_offset = node.children[idx]; - node = self.storage_manager.load_node(current_offset)?; - } - - if node.is_full() { - let (median, mut sibling) = node.split(self.b)?; - let sibling_offset = self.storage_manager.store_node(&mut sibling)?; - self.storage_manager.store_node(&mut node)?; // Store changes to the original node after splitting - - if node.is_root { - // println!("Creating new root"); - // Create a new root if the current node is the root - let mut new_root = Node::new_internal(0); - new_root.is_root = true; - new_root.keys.push(median.clone()); - new_root.children.push(current_offset); // old root offset - new_root.children.push(sibling_offset); // new sibling offset - new_root.parent_offset = None; - let new_root_offset = self.storage_manager.store_node(&mut new_root)?; - self.storage_manager.set_root_offset(new_root_offset); - node.is_root = false; - node.parent_offset = Some(new_root_offset); - sibling.parent_offset = Some(new_root_offset); - // println!("New root offset: {}", new_root_offset); - self.storage_manager.store_node(&mut node)?; - self.storage_manager.store_node(&mut sibling)?; - } else { - // Update the parent node with the new median - let parent_offset = node.parent_offset.unwrap(); - // println!("Parent offset: {}", parent_offset); - let mut parent = self.storage_manager.load_node(parent_offset)?; - let idx = parent - .keys - .binary_search(&median.clone()) - .unwrap_or_else(|x| x); - parent.keys.insert(idx, median.clone()); - parent.children.insert(idx + 1, sibling_offset); - self.storage_manager.store_node(&mut parent)?; - } - - // Decide which node to continue insertion - if *key >= median { - current_offset = sibling_offset; - node = sibling; - } - } - - // Insert the key into the correct leaf node - let position = node.keys.binary_search(key).unwrap_or_else(|x| x); - node.keys.insert(position, key.clone()); - 
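// The reworked leaf insert above is an upsert: when the key already sits at
// the insertion point it overwrites the value in place, whereas the old
// batch_insert being deleted here always inserted and could duplicate keys.
// Minimal sketch over parallel key/value vectors:
fn upsert<K: Ord, V>(keys: &mut Vec<K>, values: &mut Vec<V>, key: K, value: V) {
    let pos = keys.binary_search(&key).unwrap_or_else(|i| i);
    if keys.get(pos) == Some(&key) {
        values[pos] = value; // existing key: update in place
    } else {
        keys.insert(pos, key);
        values.insert(pos, value);
    }
}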
node.values.insert(position, Some(value.clone())); - self.storage_manager.store_node(&mut node)?; // Store changes after each insertion - } - - Ok(()) - } - - fn find_entrypoint(&mut self, key: K) -> Result { - let mut current_offset = self.storage_manager.root_offset(); - let mut node = self.storage_manager.load_node(current_offset)?; - - while node.node_type == NodeType::Internal { - let idx = node.keys.binary_search(&key).unwrap_or_else(|x| x); - current_offset = node.children[idx]; - node = self.storage_manager.load_node(current_offset)?; - } - - Ok(current_offset) - } } diff --git a/src/structures/mmap_tree/storage.rs b/src/structures/mmap_tree/storage.rs index ddee8cf..f8ff7ed 100644 --- a/src/structures/mmap_tree/storage.rs +++ b/src/structures/mmap_tree/storage.rs @@ -21,7 +21,7 @@ pub struct StorageManager { pub const SIZE_OF_USIZE: usize = std::mem::size_of::(); pub const HEADER_SIZE: usize = SIZE_OF_USIZE * 2; // Used space + root offset -pub const BLOCK_SIZE: usize = 4096; +pub const BLOCK_SIZE: usize = 16384; pub const OVERFLOW_POINTER_SIZE: usize = SIZE_OF_USIZE; pub const BLOCK_HEADER_SIZE: usize = SIZE_OF_USIZE + 1; // one byte for if it is the primary block or overflow block pub const BLOCK_DATA_SIZE: usize = BLOCK_SIZE - OVERFLOW_POINTER_SIZE - BLOCK_HEADER_SIZE; diff --git a/src/structures/wal.rs b/src/structures/wal.rs index 1400b2b..8da38c8 100644 --- a/src/structures/wal.rs +++ b/src/structures/wal.rs @@ -187,7 +187,7 @@ impl WAL { let conn = Connection::open(db_path.clone())?; // Enable WAL mode - conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?; + conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA busy_timeout = 30000;")?; // Create the table if it doesn't exist conn.execute_batch( diff --git a/src/utils.rs b/src/utils.rs index 3b1f24b..a068266 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,29 @@ pub mod quantization; pub use quantization::{dequantize, quantize}; + +use brotli::CompressorWriter; +use brotli::Decompressor; +use std::io::prelude::*; + +pub fn compress_string(input: &str) -> Vec { + let mut compressed = Vec::new(); + { + let mut compressor = CompressorWriter::new(&mut compressed, 4096, 11, 22); + compressor + .write_all(input.as_bytes()) + .expect("Failed to write data"); + } + compressed +} + +pub fn decompress_string(input: &[u8]) -> String { + let mut decompressed = Vec::new(); + { + let mut decompressor = Decompressor::new(input, 4096); + decompressor + .read_to_end(&mut decompressed) + .expect("Failed to read data"); + } + String::from_utf8(decompressed).expect("Failed to convert to string") +} From f5b73a9e05a0a328bd64958b1bca2f5962c50d2b Mon Sep 17 00:00:00 2001 From: Carson Poole Date: Thu, 23 May 2024 23:52:53 -0400 Subject: [PATCH 4/6] completely reworked storage manager and added lazy values for sp33d --- src/constants.rs | 1 + src/structures.rs | 3 +- src/structures/ann_tree.rs | 322 ++++++++++------- src/structures/ann_tree/node.rs | 505 +++++++++++++++++++-------- src/structures/block_storage.rs | 359 +++++++++++++++++++ src/structures/dense_vector_list.rs | 193 ---------- src/structures/filters.rs | 11 +- src/structures/inverted_index.rs | 228 ------------ src/structures/metadata_index.rs | 426 ---------------------- src/structures/tree.rs | 302 ++++++++++++---- src/structures/tree/node.rs | 475 ++++++++++++++++++++++--- src/structures/tree/serialization.rs | 48 --- src/structures/tree/storage.rs | 0 13 files changed, 1586 insertions(+), 1287 deletions(-) create mode 100644 
src/structures/block_storage.rs
 delete mode 100644 src/structures/dense_vector_list.rs
 delete mode 100644 src/structures/inverted_index.rs
 delete mode 100644 src/structures/metadata_index.rs
 delete mode 100644 src/structures/tree/serialization.rs
 delete mode 100644 src/structures/tree/storage.rs

diff --git a/src/constants.rs b/src/constants.rs
index efd3c16..2e3d22c 100644
--- a/src/constants.rs
+++ b/src/constants.rs
@@ -6,3 +6,4 @@ pub const ALPHA: usize = 64;
 pub const BETA: usize = 3;
 pub const GAMMA: usize = 1;
 pub const RHO: usize = 1;
+pub const B: usize = 128;
diff --git a/src/structures.rs b/src/structures.rs
index 33facb3..e0c9b1d 100644
--- a/src/structures.rs
+++ b/src/structures.rs
@@ -1,5 +1,6 @@
 pub mod ann_tree;
-pub mod dense_vector_list;
+// pub mod dense_vector_list;
+pub mod block_storage;
 pub mod filters;
 // pub mod inverted_index;
 // pub mod metadata_index;
diff --git a/src/structures/ann_tree.rs b/src/structures/ann_tree.rs
index f191511..9d93b96 100644
--- a/src/structures/ann_tree.rs
+++ b/src/structures/ann_tree.rs
@@ -9,6 +9,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator};
 use storage::StorageManager;
 
 use crate::constants::QUANTIZED_VECTOR_SIZE;
+use crate::structures::ann_tree::node::LazyValue;
 use std::io;
 
 use self::k_modes::find_modes;
@@ -16,6 +17,7 @@ use self::metadata::{NodeMetadata, NodeMetadataIndex};
 use self::node::Vector;
 use crate::math::hamming_distance;
 
+use super::block_storage::BlockStorage;
 use super::filters::{combine_filters, Filter, Filters};
 // use super::metadata_index::{KVPair, KVValue};
 use super::mmap_tree::serialization::{TreeDeserialization, TreeSerialization};
@@ -29,7 +31,7 @@ use std::path::PathBuf;
 
 pub struct ANNTree {
     pub k: usize,
-    pub storage_manager: storage::StorageManager,
+    pub storage_manager: BlockStorage,
 }
 
 #[derive(Eq, PartialEq)]
@@ -54,7 +56,7 @@ impl PartialOrd for PathNode {
 impl ANNTree {
     pub fn new(path: PathBuf) -> Result<Self, io::Error> {
         let mut storage_manager =
-            StorageManager::new(path).expect("Failed to make storage manager in ANN Tree");
+            BlockStorage::new(path).expect("Failed to make storage manager in ANN Tree");
 
         // println!("INIT Used space: {}", storage_manager.used_space);
 
@@ -68,8 +70,13 @@
         let mut root = Node::new_leaf();
         root.is_root = true;
 
-        storage_manager.store_node(&mut root)?;
-        storage_manager.set_root_offset(root.offset);
+        // store_node(&mut root)?;
+        // storage_manager.set_root_offset(root.offset);
+
+        let serialized_root = root.serialize();
+
+        let offset = storage_manager.store(serialized_root, 0)?;
+        storage_manager.set_root_offset(offset);
 
         Ok(ANNTree {
             storage_manager,
@@ -77,6 +84,20 @@
         })
     }
 
+    pub fn store_node(&mut self, node: &mut Node) -> Result<usize, io::Error> {
+        let serialized_node = node.serialize();
+        let offset = self.storage_manager.store(serialized_node, node.offset)?;
+        node.offset = offset;
+        Ok(offset)
+    }
+
+    pub fn load_node(&self, offset: usize) -> Result<Node, io::Error> {
+        let serialized_node = self.storage_manager.load(offset)?;
+        let mut node = Node::deserialize(&serialized_node);
+        node.offset = offset;
+        Ok(node)
+    }
+
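// A minimal round-trip sketch for the new store_node/load_node helpers,
// assuming a mutable `tree: ANNTree` and a `node: Node` (illustrative names,
// not part of this patch):
//
//     let offset = tree.store_node(&mut node)?; // serialize + write blocks
//     let loaded = tree.load_node(offset)?;     // read blocks + deserialize
//     assert_eq!(loaded.offset, offset);        // offset is rewritten on store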
    pub fn batch_insert(
         &mut self,
         vectors: Vec<Vector>,
@@ -102,7 +123,7 @@
 
         new_root.is_root = true;
 
-        self.storage_manager.store_node(&mut new_root).unwrap();
+        self.store_node(&mut new_root).unwrap();
 
         self.storage_manager.set_root_offset(new_root.offset);
 
@@ -117,7 +138,7 @@
         //     leaf.parent_offset = Some(leaf.offset);
         //     leaf.children.push(leaf.offset);
         //     leaf.vectors.push(find_modes(leaf.vectors.clone()));
-        //     self.storage_manager.store_node(leaf).unwrap();
+        //     self.store_node(leaf).unwrap();
         // }
 
         let mut all_vectors = Vec::new();
@@ -126,7 +147,7 @@
 
         for leaf in current_leaves.iter_mut() {
             leaf.is_root = false;
-            self.storage_manager.store_node(leaf).unwrap();
+            self.store_node(leaf).unwrap();
             all_vectors.extend(leaf.vectors.clone());
             all_ids.extend(leaf.ids.clone());
             all_metadata.extend(leaf.metadata.clone());
@@ -134,7 +155,15 @@
 
         all_vectors.extend(vectors);
         all_ids.extend(ids);
-        all_metadata.extend(metadata);
+        all_metadata.extend(
+            metadata
+                .iter()
+                .map(|md| {
+                    LazyValue::new(md.clone(), &mut self.storage_manager)
+                        .expect("Failed to create lazy value")
+                })
+                .collect::<Vec<_>>(),
+        );
 
         println!("All vectors: {:?}", all_vectors.len());
         println!("All ids: {:?}", all_ids.len());
@@ -149,11 +178,17 @@
         {
             if leaf.is_full() {
                 leaf.parent_offset = Some(new_root.offset);
-                leaf.node_metadata = calc_metadata_index_for_metadata(leaf.metadata.clone());
-                self.storage_manager.store_node(&mut leaf).unwrap();
+                let new_node_metadata =
+                    calc_metadata_index_for_metadata(leaf.metadata.clone(), &self.storage_manager);
+
+                leaf.node_metadata = Some(
+                    LazyValue::new(new_node_metadata, &mut self.storage_manager)
+                        .expect("Failed to create lazy value"),
+                );
+                self.store_node(&mut leaf).unwrap();
                 new_root.children.push(leaf.offset);
                 new_root.vectors.push(find_modes(leaf.vectors.clone()));
-                self.storage_manager.store_node(&mut new_root).unwrap();
+                self.store_node(&mut new_root).unwrap();
                 leaf = Node::new_leaf();
             }
             leaf.vectors.push(vector.clone());
@@ -161,9 +196,9 @@
             leaf.metadata.push(metadata.clone());
         }
 
-        new_root.node_metadata = self.compute_node_metadata(&new_root);
+        new_root.node_metadata = Some(self.compute_node_metadata(&new_root));
 
-        self.storage_manager.store_node(&mut new_root).unwrap();
+        self.store_node(&mut new_root).unwrap();
 
         self.storage_manager.set_root_offset(new_root.offset);
 
@@ -174,18 +209,20 @@
     pub fn insert(&mut self, vector: Vector, id: u128, metadata: Vec<KVPair>) {
         let entrypoint = self.find_entrypoint(vector);
 
-        let mut node = self.storage_manager.load_node(entrypoint).unwrap();
+        let mut node = self.load_node(entrypoint).unwrap();
 
         // println!("Entrypoint: {:?}", entrypoint);
 
         if node.is_full() {
-            let mut siblings = node.split().expect("Failed to split node");
+            let mut siblings = node
+                .split(&mut self.storage_manager)
+                .expect("Failed to split node");
             let sibling_offsets: Vec<usize> = siblings
                 .iter_mut()
                 .map(|sibling| {
                     sibling.parent_offset = node.parent_offset; // Set parent offset before storing
-                    sibling.node_metadata = self.compute_node_metadata(&sibling);
-                    self.storage_manager.store_node(sibling).unwrap()
+                    sibling.node_metadata = Some(self.compute_node_metadata(&sibling));
+                    self.store_node(sibling).unwrap()
                 })
                 .collect();
 
@@ -203,30 +240,30 @@
             new_root.children.push(node.offset);
             new_root.vectors.push(find_modes(node.vectors.clone()));
             for sibling_offset in &sibling_offsets {
-                let sibling = self.storage_manager.load_node(*sibling_offset).unwrap();
+                let sibling = self.load_node(*sibling_offset).unwrap();
                 new_root.vectors.push(find_modes(sibling.vectors));
                 new_root.children.push(*sibling_offset);
             }
-            self.storage_manager.store_node(&mut new_root).unwrap();
+            self.store_node(&mut new_root).unwrap();
             self.storage_manager.set_root_offset(new_root.offset);
             node.is_root = false;
             node.parent_offset = Some(new_root.offset);
             siblings
                 .iter_mut()
                 .for_each(|sibling| sibling.parent_offset = 
Some(new_root.offset)); - self.storage_manager.store_node(&mut node).unwrap(); + self.store_node(&mut node).unwrap(); siblings.iter_mut().for_each(|sibling| { if sibling.node_type == NodeType::Internal && sibling.children.len() != sibling.vectors.len() { panic!("Internal node has different number of children and vectors v3"); } - sibling.node_metadata = self.compute_node_metadata(sibling); - self.storage_manager.store_node(sibling).unwrap(); + sibling.node_metadata = Some(self.compute_node_metadata(sibling)); + self.store_node(sibling).unwrap(); }); } else { let parent_offset = node.parent_offset.unwrap(); - let mut parent = self.storage_manager.load_node(parent_offset).unwrap(); + let mut parent = self.load_node(parent_offset).unwrap(); parent.children.push(node.offset); parent.vectors.push(find_modes(node.vectors.clone())); sibling_offsets @@ -244,9 +281,9 @@ impl ANNTree { panic!("parent node has different number of children and vectors"); } - self.storage_manager.store_node(&mut parent).unwrap(); + self.store_node(&mut parent).unwrap(); node.parent_offset = Some(parent_offset); - self.storage_manager.store_node(&mut node).unwrap(); + self.store_node(&mut node).unwrap(); siblings.into_iter().for_each(|mut sibling| { if sibling.node_type == NodeType::Internal && sibling.children.len() != sibling.vectors.len() @@ -254,20 +291,22 @@ impl ANNTree { panic!("Internal node has different number of children and vectors v3"); } sibling.parent_offset = Some(parent_offset); - sibling.node_metadata = self.compute_node_metadata(&sibling); - self.storage_manager.store_node(&mut sibling).unwrap(); + sibling.node_metadata = Some(self.compute_node_metadata(&sibling)); + self.store_node(&mut sibling).unwrap(); }); let mut current_node = parent; while current_node.is_full() { println!("Current node is full"); - let mut siblings = current_node.split().expect("Failed to split node"); + let mut siblings = current_node + .split(&mut self.storage_manager) + .expect("Failed to split node"); let sibling_offsets: Vec = siblings .iter_mut() .map(|sibling| { sibling.parent_offset = current_node.parent_offset; - sibling.node_metadata = self.compute_node_metadata(sibling); - self.storage_manager.store_node(sibling).unwrap() + sibling.node_metadata = Some(self.compute_node_metadata(sibling)); + self.store_node(sibling).unwrap() }) .collect(); @@ -290,14 +329,14 @@ impl ANNTree { siblings.iter().for_each(|sibling| { new_root.vectors.push(find_modes(sibling.vectors.clone())) }); - self.storage_manager.store_node(&mut new_root).unwrap(); + self.store_node(&mut new_root).unwrap(); self.storage_manager.set_root_offset(new_root.offset); current_node.is_root = false; current_node.parent_offset = Some(new_root.offset); siblings .iter_mut() .for_each(|sibling| sibling.parent_offset = Some(new_root.offset)); - self.storage_manager.store_node(&mut current_node).unwrap(); + self.store_node(&mut current_node).unwrap(); siblings.into_iter().for_each(|mut sibling| { if sibling.node_type == NodeType::Internal && sibling.children.len() != sibling.vectors.len() @@ -306,13 +345,13 @@ impl ANNTree { "Internal node has different number of children and vectors v4" ); } - sibling.node_metadata = self.compute_node_metadata(&sibling); - self.storage_manager.store_node(&mut sibling).unwrap(); + sibling.node_metadata = Some(self.compute_node_metadata(&sibling)); + self.store_node(&mut sibling).unwrap(); }); - new_root.node_metadata = self.compute_node_metadata(&new_root); + new_root.node_metadata = Some(self.compute_node_metadata(&new_root)); } 
else { let parent_offset = current_node.parent_offset.unwrap(); - let mut parent = self.storage_manager.load_node(parent_offset).unwrap(); + let mut parent = self.load_node(parent_offset).unwrap(); parent.children.push(current_node.offset); sibling_offsets .iter() @@ -330,15 +369,15 @@ impl ANNTree { } parent.vectors.push(find_modes(sibling.vectors.clone())) }); - self.storage_manager.store_node(&mut parent).unwrap(); + self.store_node(&mut parent).unwrap(); current_node.parent_offset = Some(parent_offset); - self.storage_manager.store_node(&mut current_node).unwrap(); + self.store_node(&mut current_node).unwrap(); siblings.into_iter().for_each(|mut sibling| { sibling.parent_offset = Some(parent_offset); - sibling.node_metadata = self.compute_node_metadata(&sibling); - self.storage_manager.store_node(&mut sibling).unwrap(); + sibling.node_metadata = Some(self.compute_node_metadata(&sibling)); + self.store_node(&mut sibling).unwrap(); }); - parent.node_metadata = self.compute_node_metadata(&parent); + parent.node_metadata = Some(self.compute_node_metadata(&parent)); current_node = parent; } } @@ -349,9 +388,17 @@ impl ANNTree { } node.vectors.push(vector); node.ids.push(id); - node.metadata.push(metadata.clone()); + node.metadata + .push(LazyValue::new(metadata.clone(), &mut self.storage_manager).unwrap()); for kv in metadata { - match node.node_metadata.get(kv.key.clone()) { + match node + .node_metadata + .clone() + .expect("") + .get(&self.storage_manager) + .expect("Failed to get node metadata") + .get(kv.key.clone()) + { Some(res) => { let mut set = res.clone(); match kv.value { @@ -380,7 +427,18 @@ impl ANNTree { } } - node.node_metadata.insert(kv.key.clone(), set); + // node.node_metadata.insert(kv.key.clone(), set); + let mut current_node_metadata = node + .node_metadata + .expect("") + .get(&self.storage_manager) + .expect("Failed to get node metadata") + .clone(); + current_node_metadata.insert(kv.key.clone(), set); + node.node_metadata = Some( + LazyValue::new(current_node_metadata, &mut self.storage_manager) + .expect("Failed to create lazy value"), + ); } None => { let mut set = NodeMetadata::new(); @@ -396,19 +454,28 @@ impl ANNTree { } } - node.node_metadata.insert(kv.key.clone(), set); + // node.node_metadata.insert(kv.key.clone(), set); + let mut current_node_metadata = node + .node_metadata + .expect("") + .get(&self.storage_manager) + .expect("Failed to get node metadata") + .clone(); + + current_node_metadata.insert(kv.key.clone(), set); + node.node_metadata = Some( + LazyValue::new(current_node_metadata, &mut self.storage_manager) + .expect("Failed to create lazy value"), + ); } } } - self.storage_manager.store_node(&mut node).unwrap(); + self.store_node(&mut node).unwrap(); } } fn find_entrypoint(&mut self, vector: Vector) -> usize { - let mut node = self - .storage_manager - .load_node(self.storage_manager.root_offset()) - .unwrap(); + let mut node = self.load_node(self.storage_manager.root_offset()).unwrap(); while node.node_type == NodeType::Internal { let mut distances: Vec<(usize, u16)> = node @@ -422,10 +489,7 @@ impl ANNTree { let best = distances.get(0).unwrap(); - let best_node = self - .storage_manager - .load_node(node.children[best.0]) - .unwrap(); + let best_node = self.load_node(node.children[best.0]).unwrap(); node = best_node; } @@ -440,10 +504,7 @@ impl ANNTree { top_k: usize, filters: &Filter, ) -> Vec<(u128, Vec)> { - let node = self - .storage_manager - .load_node(self.storage_manager.root_offset()) - .unwrap(); + let node = 
self.load_node(self.storage_manager.root_offset()).unwrap(); // let mut visited = HashSet::new(); @@ -480,7 +541,10 @@ impl ANNTree { .par_iter() .map(|(idx, distance)| { let id = node.ids[*idx]; - let metadata = node.metadata[*idx].clone(); + let metadata = node.metadata[*idx] + .clone() + .get(&self.storage_manager) + .unwrap(); (id, *distance, metadata) }) .collect::>(); @@ -523,7 +587,7 @@ impl ANNTree { &self, query: &Vector, vector_items: &Vec, - metadata_items: &Vec>, + metadata_items: &Vec>>, filters: &Filter, k: usize, ) -> Vec<(usize, u16)> { @@ -545,7 +609,13 @@ impl ANNTree { } // Evaluate filters for the loaded node - if !Filters::should_prune_metadata(filters, &&metadata_items[idx]) { + if !Filters::should_prune_metadata( + filters, + &&metadata_items[idx] + .clone() + .get(&self.storage_manager) + .unwrap(), + ) { // Add to top-k if it matches the filter if top_k_values.len() < k { top_k_values.push((idx, distance)); @@ -591,10 +661,18 @@ impl ANNTree { break; // No need to check further if we already have top-k and current distance is not better } - let child_node = self.storage_manager.load_node(children[idx]).unwrap(); + let child_node = self.load_node(children[idx]).unwrap(); // Evaluate filters for the loaded node - if !Filters::should_prune(filters, &child_node.node_metadata) { + if !Filters::should_prune( + filters, + &child_node + .node_metadata + .clone() + .expect("") + .get(&self.storage_manager) + .expect(""), + ) { // Add to top-k if it matches the filter if top_k_values.len() < k { top_k_values.push((idx, distance, child_node)); @@ -622,7 +700,7 @@ impl ANNTree { offset: usize, leaf_nodes: &mut Vec, ) -> Result<(), io::Error> { - let node = self.storage_manager.load_node(offset).unwrap().clone(); + let node = self.load_node(offset).unwrap().clone(); if node.node_type == NodeType::Leaf { leaf_nodes.push(node); } else { @@ -643,7 +721,7 @@ impl ANNTree { new_root.is_root = true; // Step 3: Store the new root to set its offset - self.storage_manager.store_node(&mut new_root)?; + self.store_node(&mut new_root)?; self.storage_manager.set_root_offset(new_root.offset); // Step 4: Make all the leaf nodes the new root's children, and set all their parent_offsets to the new root's offset @@ -651,11 +729,11 @@ impl ANNTree { leaf_node.parent_offset = Some(new_root.offset); new_root.children.push(leaf_node.offset); new_root.vectors.push(find_modes(leaf_node.vectors.clone())); - // new_root.node_metadata = self.compute_node_metadata(&new_root); - self.storage_manager.store_node(leaf_node)?; + // new_root.node_metadata = Some(self.compute_node_metadata(&new_root)); + self.store_node(leaf_node)?; } - new_root.node_metadata = self.compute_node_metadata(&new_root); + new_root.node_metadata = Some(self.compute_node_metadata(&new_root)); // new_root.node_metadata = combine_filters( // leaf_nodes @@ -665,19 +743,21 @@ impl ANNTree { // ); // Update the root node with its children and vectors - self.storage_manager.store_node(&mut new_root)?; + self.store_node(&mut new_root)?; // Step 5: Split the nodes until it is balanced/there are no nodes that are full let mut current_nodes = vec![new_root]; while let Some(mut node) = current_nodes.pop() { if node.is_full() { - let mut siblings = node.split().expect("Failed to split node"); + let mut siblings = node + .split(&mut self.storage_manager) + .expect("Failed to split node"); let sibling_offsets: Vec = siblings .iter_mut() .map(|sibling| { sibling.parent_offset = node.parent_offset; // Set parent offset before storing - 
sibling.node_metadata = self.compute_node_metadata(sibling); - self.storage_manager.store_node(sibling).unwrap() + sibling.node_metadata = Some(self.compute_node_metadata(sibling)); + self.store_node(sibling).unwrap() }) .collect(); @@ -696,41 +776,33 @@ impl ANNTree { new_root.vectors.push(find_modes(node.vectors.clone())); for sibling_offset in &sibling_offsets { - let sibling = self - .storage_manager - .load_node(*sibling_offset) - .unwrap() - .clone(); + let sibling = self.load_node(*sibling_offset).unwrap().clone(); new_root.vectors.push(find_modes(sibling.vectors)); new_root.children.push(*sibling_offset); } - new_root.node_metadata = self.compute_node_metadata(&new_root); - self.storage_manager.store_node(&mut new_root)?; + new_root.node_metadata = Some(self.compute_node_metadata(&new_root)); + self.store_node(&mut new_root)?; self.storage_manager.set_root_offset(new_root.offset); node.is_root = false; node.parent_offset = Some(new_root.offset); - self.storage_manager.store_node(&mut node)?; + self.store_node(&mut node)?; siblings .iter_mut() .for_each(|sibling| sibling.parent_offset = Some(new_root.offset)); - self.storage_manager.store_node(&mut node)?; + self.store_node(&mut node)?; siblings.iter_mut().for_each(|sibling| { if sibling.node_type == NodeType::Internal && sibling.children.len() != sibling.vectors.len() { panic!("Internal node has different number of children and vectors v3"); } - sibling.node_metadata = self.compute_node_metadata(sibling); - self.storage_manager.store_node(sibling); + sibling.node_metadata = Some(self.compute_node_metadata(sibling)); + self.store_node(sibling); }); } else { let parent_offset = node.parent_offset.unwrap(); - let mut parent = self - .storage_manager - .load_node(parent_offset) - .unwrap() - .clone(); + let mut parent = self.load_node(parent_offset).unwrap().clone(); parent.children.push(node.offset); parent.vectors.push(find_modes(node.vectors.clone())); sibling_offsets @@ -739,16 +811,16 @@ impl ANNTree { siblings.iter().for_each(|sibling| { parent.vectors.push(find_modes(sibling.vectors.clone())) }); - parent.node_metadata = self.compute_node_metadata(&parent); + parent.node_metadata = Some(self.compute_node_metadata(&parent)); if parent.node_type == NodeType::Internal && parent.children.len() != parent.vectors.len() { panic!("parent node has different number of children and vectors"); } - self.storage_manager.store_node(&mut parent)?; + self.store_node(&mut parent)?; node.parent_offset = Some(parent_offset); - self.storage_manager.store_node(&mut node)?; + self.store_node(&mut node)?; siblings.into_iter().for_each(|mut sibling| { if sibling.node_type == NodeType::Internal && sibling.children.len() != sibling.vectors.len() @@ -756,19 +828,21 @@ impl ANNTree { panic!("Internal node has different number of children and vectors v3"); } sibling.parent_offset = Some(parent_offset); - sibling.node_metadata = self.compute_node_metadata(&sibling); - self.storage_manager.store_node(&mut sibling); + sibling.node_metadata = Some(self.compute_node_metadata(&sibling)); + self.store_node(&mut sibling); }); let mut current_node = parent; while current_node.is_full() { - let mut siblings = current_node.split().expect("Failed to split node"); + let mut siblings = current_node + .split(&mut self.storage_manager) + .expect("Failed to split node"); let sibling_offsets: Vec = siblings .iter_mut() .map(|sibling| { sibling.parent_offset = Some(current_node.parent_offset.unwrap()); - sibling.node_metadata = self.compute_node_metadata(sibling); - 
self.storage_manager.store_node(sibling).unwrap()
+                        sibling.node_metadata = Some(self.compute_node_metadata(sibling));
+                        self.store_node(sibling).unwrap()
                     })
                     .collect();
 
@@ -793,29 +867,25 @@
                         siblings.iter().for_each(|sibling| {
                             new_root.vectors.push(find_modes(sibling.vectors.clone()))
                         });
-                        new_root.node_metadata = self.compute_node_metadata(&new_root);
-                        self.storage_manager.store_node(&mut new_root)?;
+                        new_root.node_metadata = Some(self.compute_node_metadata(&new_root));
+                        self.store_node(&mut new_root)?;
                         self.storage_manager.set_root_offset(new_root.offset);
                         current_node.is_root = false;
                         current_node.parent_offset = Some(new_root.offset);
                         siblings
                             .iter_mut()
                             .for_each(|sibling| sibling.parent_offset = Some(new_root.offset));
-                        self.storage_manager.store_node(&mut current_node)?;
+                        self.store_node(&mut current_node)?;
                         siblings.into_iter().for_each(|mut sibling| {
                             if sibling.node_type == NodeType::Internal
                                 && sibling.children.len() != sibling.vectors.len()
                             {
                                 panic!("Internal node has different number of children and vectors v4");
                             }
-                            sibling.node_metadata = self.compute_node_metadata(&sibling);
-                            self.storage_manager.store_node(&mut sibling);
+                            sibling.node_metadata = Some(self.compute_node_metadata(&sibling));
+                            self.store_node(&mut sibling);
                         });
                     } else {
                         let parent_offset = current_node.parent_offset.unwrap();
-                        let mut parent = self
-                            .storage_manager
-                            .load_node(parent_offset)
-                            .unwrap()
-                            .clone();
+                        let mut parent = self.load_node(parent_offset).unwrap().clone();
                         parent.children.push(current_node.offset);
                         sibling_offsets
                             .iter()
@@ -827,41 +897,51 @@
                             if sibling.node_type == NodeType::Internal
                                 && sibling.children.len() != sibling.vectors.len()
                             {
                                 panic!("Internal node has different number of children and vectors v5");
                             }
-                            sibling.node_metadata = self.compute_node_metadata(sibling);
+                            sibling.node_metadata = Some(self.compute_node_metadata(sibling));
                             parent.vectors.push(find_modes(sibling.vectors.clone()))
                         });
-                        parent.node_metadata = self.compute_node_metadata(&parent);
-                        self.storage_manager.store_node(&mut parent)?;
+                        parent.node_metadata = Some(self.compute_node_metadata(&parent));
+                        self.store_node(&mut parent)?;
                         current_node.parent_offset = Some(parent_offset);
-                        current_node.node_metadata = self.compute_node_metadata(&current_node);
-                        self.storage_manager.store_node(&mut current_node)?;
+                        current_node.node_metadata =
+                            Some(self.compute_node_metadata(&current_node));
+                        self.store_node(&mut current_node)?;
                         siblings.into_iter().for_each(|mut sibling| {
                             sibling.parent_offset = Some(parent_offset);
-                            sibling.node_metadata = self.compute_node_metadata(&sibling);
-                            self.storage_manager.store_node(&mut sibling);
+                            sibling.node_metadata = Some(self.compute_node_metadata(&sibling));
+                            self.store_node(&mut sibling);
                         });
                         current_node = parent.clone();
-                        current_node.node_metadata = self.compute_node_metadata(&current_node);
+                        current_node.node_metadata =
+                            Some(self.compute_node_metadata(&current_node));
                     }
                 }
             }
 
-            node.node_metadata = self.compute_node_metadata(&node);
+            node.node_metadata = Some(self.compute_node_metadata(&node));
         }
 
         Ok(())
     }
 
-    fn compute_node_metadata(&self, node: &Node) -> NodeMetadataIndex {
+    fn compute_node_metadata(&mut self, node: &Node) -> LazyValue<NodeMetadataIndex> {
         let mut children_metadatas = Vec::new();
         for child_offset in &node.children {
-            let child = self.storage_manager.load_node(*child_offset).unwrap();
-
-            children_metadatas.push(child.node_metadata);
+            let child = self.load_node(*child_offset).unwrap();
+
+            children_metadatas.push(
+                child
+                    .node_metadata
.expect("Child has no metadata") + .get(&self.storage_manager) + .expect("Failed to get child metadata"), + ); } - combine_filters(children_metadatas) + let out = combine_filters(children_metadatas); + + LazyValue::new(out, &mut self.storage_manager).unwrap() } pub fn summarize_tree(&self) { @@ -872,7 +952,7 @@ impl ANNTree { let mut next_queue = Vec::new(); for offset in queue { - let node = self.storage_manager.load_node(offset).unwrap(); + let node = self.load_node(offset).unwrap(); println!( "Depth: {}, Node type: {:?}, Offset: {}, Children: {}, Vectors: {}", depth, diff --git a/src/structures/ann_tree/node.rs b/src/structures/ann_tree/node.rs index 06dd643..642f63b 100644 --- a/src/structures/ann_tree/node.rs +++ b/src/structures/ann_tree/node.rs @@ -1,7 +1,10 @@ use serde::Serialize; +use std::io; use crate::structures::ann_tree::k_modes::{balanced_k_modes, balanced_k_modes_4}; +use crate::structures::block_storage::BlockStorage; // use crate::structures::metadata_index::{KVPair, KVValue}; +use crate::structures::ann_tree::serialization::{TreeDeserialization, TreeSerialization}; use crate::structures::filters::{calc_metadata_index_for_metadata, KVPair, KVValue}; use crate::{constants::QUANTIZED_VECTOR_SIZE, errors::HaystackError}; use std::fmt::Debug; @@ -35,6 +38,176 @@ pub fn deserialize_node_type(data: &[u8]) -> NodeType { } } +#[derive(Debug, PartialEq, Clone)] +pub struct LazyValue { + offset: usize, + value: Option, +} + +impl LazyValue +where + T: Clone + TreeDeserialization + TreeSerialization, +{ + pub fn get(&mut self, storage: &BlockStorage) -> Result { + match self.value.clone() { + Some(value) => Ok(value), + None => { + let bytes = storage.load(self.offset)?; + let value = T::deserialize(&bytes); + self.value = Some(value.clone()); + Ok(value) + } + } + } + + pub fn new(value: T, storage: &mut BlockStorage) -> Result { + let offset = storage.store(value.serialize(), 0)?; + Ok(LazyValue { + offset, + value: Some(value), + }) + } +} + +impl TreeSerialization for Vec { + fn serialize(&self) -> Vec { + let mut serialized = Vec::new(); + + // Serialize the length of the vector + serialize_length(&mut serialized, self.len() as u32); + + for kv in self { + let serialized_kv = kv.serialize(); + serialize_length(&mut serialized, serialized_kv.len() as u32); + serialized.extend_from_slice(&serialized_kv); + } + + serialized + } +} + +impl TreeDeserialization for Vec { + fn deserialize(data: &[u8]) -> Self { + let mut offset = 0; + let mut metadata = Vec::new(); + + let metadata_len = read_length(&data[offset..offset + 4]); + offset += 4; + + for _ in 0..metadata_len { + let kv_length = read_length(&data[offset..offset + 4]); + offset += 4; + + let kv = KVPair::deserialize(&data[offset..offset + kv_length]); + offset += kv_length; + + metadata.push(kv); + } + + metadata + } +} + +impl TreeSerialization for NodeMetadataIndex { + fn serialize(&self) -> Vec { + let mut serialized = Vec::new(); + + // Serialize the length of the metadata index + serialize_length(&mut serialized, self.data.len() as u32); + + for (key, item) in self.get_all_values() { + let serialized_key = key.as_bytes(); + serialize_length(&mut serialized, serialized_key.len() as u32); + serialized.extend_from_slice(serialized_key); + + let values = item.values.clone(); + + serialize_length(&mut serialized, values.len() as u32); + for value in values { + let serialized_value = value.as_bytes(); + serialize_length(&mut serialized, serialized_value.len() as u32); + serialized.extend_from_slice(serialized_value); + } + + 
+impl TreeSerialization for Vec<KVPair> {
+    fn serialize(&self) -> Vec<u8> {
+        let mut serialized = Vec::new();
+
+        // Serialize the length of the vector
+        serialize_length(&mut serialized, self.len() as u32);
+
+        for kv in self {
+            let serialized_kv = kv.serialize();
+            serialize_length(&mut serialized, serialized_kv.len() as u32);
+            serialized.extend_from_slice(&serialized_kv);
+        }
+
+        serialized
+    }
+}
+
+impl TreeDeserialization for Vec<KVPair> {
+    fn deserialize(data: &[u8]) -> Self {
+        let mut offset = 0;
+        let mut metadata = Vec::new();
+
+        let metadata_len = read_length(&data[offset..offset + 4]);
+        offset += 4;
+
+        for _ in 0..metadata_len {
+            let kv_length = read_length(&data[offset..offset + 4]);
+            offset += 4;
+
+            let kv = KVPair::deserialize(&data[offset..offset + kv_length]);
+            offset += kv_length;
+
+            metadata.push(kv);
+        }
+
+        metadata
+    }
+}
+
+impl TreeSerialization for NodeMetadataIndex {
+    fn serialize(&self) -> Vec<u8> {
+        let mut serialized = Vec::new();
+
+        // Serialize the length of the metadata index
+        serialize_length(&mut serialized, self.data.len() as u32);
+
+        for (key, item) in self.get_all_values() {
+            let serialized_key = key.as_bytes();
+            serialize_length(&mut serialized, serialized_key.len() as u32);
+            serialized.extend_from_slice(serialized_key);
+
+            let values = item.values.clone();
+
+            serialize_length(&mut serialized, values.len() as u32);
+            for value in values {
+                let serialized_value = value.as_bytes();
+                serialize_length(&mut serialized, serialized_value.len() as u32);
+                serialized.extend_from_slice(serialized_value);
+            }
+
+            let int_range = item.int_range.clone();
+            if int_range.is_none() {
+                serialized.extend_from_slice(&(0 as i64).to_le_bytes());
+                serialized.extend_from_slice(&(0 as i64).to_le_bytes());
+            } else {
+                serialized.extend_from_slice(&int_range.unwrap().0.to_le_bytes());
+                serialized.extend_from_slice(&int_range.unwrap().1.to_le_bytes());
+            }
+
+            let float_range = item.float_range.clone();
+            if float_range.is_none() {
+                serialized.extend_from_slice(&(0 as f32).to_le_bytes());
+                serialized.extend_from_slice(&(0 as f32).to_le_bytes());
+            } else {
+                serialized.extend_from_slice(&float_range.unwrap().0.to_le_bytes());
+                serialized.extend_from_slice(&float_range.unwrap().1.to_le_bytes());
+            }
+        }
+
+        serialized
+    }
+}
+
+impl TreeDeserialization for NodeMetadataIndex {
+    fn deserialize(data: &[u8]) -> Self {
+        let mut offset = 0;
+        let mut metadata = NodeMetadataIndex::new();
+
+        let metadata_len = read_length(&data[offset..offset + 4]);
+        offset += 4;
+
+        for _ in 0..metadata_len {
+            let key_len = read_length(&data[offset..offset + 4]);
+            offset += 4;
+
+            let key = String::from_utf8(data[offset..offset + key_len as usize].to_vec()).unwrap();
+            offset += key_len as usize;
+
+            let mut values = HashSet::new();
+            let values_len = read_length(&data[offset..offset + 4]);
+            offset += 4;
+
+            for idx in 0..values_len {
+                let value_len = read_length(&data[offset..offset + 4]);
+                offset += 4;
+
+                let value =
+                    String::from_utf8(data[offset..offset + value_len as usize].to_vec()).unwrap();
+                offset += value_len as usize;
+
+                values.insert(value);
+            }
+
+            let mut item = NodeMetadata {
+                values: values.clone(),
+                int_range: None,
+                float_range: None,
+            };
+
+            let min_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
+            offset += 8;
+            let max_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
+            offset += 8;
+
+            let min_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
+            offset += 4;
+            let max_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
+            offset += 4;
+
+            item.int_range = Some((min_int, max_int));
+            item.float_range = Some((min_float, max_float));
+
+            metadata.insert(key, item);
+        }
+
+        metadata
+    }
+}
+
 const K: usize = crate::constants::K;
 
 pub type Vector = [u8; QUANTIZED_VECTOR_SIZE];
@@ -44,13 +217,13 @@ pub struct Node {
     pub vectors: Vec<Vector>,
     pub ids: Vec<u128>,
     pub children: Vec<usize>,
-    pub metadata: Vec<Vec<KVPair>>,
+    pub metadata: Vec<LazyValue<Vec<KVPair>>>,
     pub k: usize,
     pub node_type: NodeType,
     pub offset: usize,
     pub is_root: bool,
     pub parent_offset: Option<usize>,
-    pub node_metadata: NodeMetadataIndex,
+    pub node_metadata: Option<LazyValue<NodeMetadataIndex>>,
 }
 
 impl Node {
@@ -65,7 +238,7 @@
             offset: 0,
             is_root: false,
             parent_offset: None,
-            node_metadata: NodeMetadataIndex::new(),
+            node_metadata: None,
         }
     }
 
@@ -80,11 +253,11 @@
             offset: 0,
             is_root: false,
             parent_offset: None,
-            node_metadata: NodeMetadataIndex::new(),
+            node_metadata: None,
         }
     }
 
-    pub fn split(&mut self) -> Result<Vec<Node>, HaystackError> {
+    pub fn split(&mut self, storage: &mut BlockStorage) -> Result<Vec<Node>, io::Error> {
         let k = match self.node_type {
             NodeType::Leaf => 2,
             NodeType::Internal => 2,
@@ -125,8 +298,8 @@
 
             let mut node_metadata = NodeMetadataIndex::new();
             for kv in &clusters_metadata[i] {
-                for pair in kv {
-                    node_metadata.insert_kv_pair(pair);
+                for pair in kv.clone().get(storage)? 
{ + node_metadata.insert_kv_pair(&pair); } } @@ -140,7 +313,7 @@ impl Node { offset: 0, // This should be set when the node is stored is_root: false, parent_offset: self.parent_offset, - node_metadata, + node_metadata: Some(LazyValue::new(node_metadata, storage)?), }; siblings.push(sibling.clone()); @@ -168,12 +341,12 @@ impl Node { let mut node_metadata = NodeMetadataIndex::new(); for kv in &self.metadata { - for pair in kv { - node_metadata.insert_kv_pair(pair); + for pair in kv.clone().get(storage)? { + node_metadata.insert_kv_pair(&pair); } } - self.node_metadata = node_metadata; + self.node_metadata = Some(LazyValue::new(node_metadata, storage)?); Ok(siblings) } @@ -213,51 +386,60 @@ impl Node { serialized.extend_from_slice(&child.to_le_bytes()); } - // Serialize metadata + // serialize metadata offsets serialize_length(&mut serialized, self.metadata.len() as u32); for meta in &self.metadata { - // let serialized_meta = serialize_metadata(meta); // Function to serialize a Vec - // serialized.extend_from_slice(&serialized_meta); - serialize_metadata(&mut serialized, meta); + serialized.extend_from_slice(&meta.offset.to_le_bytes()); } - // Serialize node_metadata - serialize_length( - &mut serialized, - self.node_metadata.get_all_values().len() as u32, - ); - for (key, item) in self.node_metadata.get_all_values() { - let serialized_key = key.as_bytes(); - serialize_length(&mut serialized, serialized_key.len() as u32); - serialized.extend_from_slice(serialized_key); - - let values = item.values.clone(); - - serialize_length(&mut serialized, values.len() as u32); - for value in values { - let serialized_value = value.as_bytes(); - serialize_length(&mut serialized, serialized_value.len() as u32); - serialized.extend_from_slice(serialized_value); - } - - let int_range = item.int_range.clone(); - if int_range.is_none() { - serialized.extend_from_slice(&(0 as i64).to_le_bytes()); - serialized.extend_from_slice(&(0 as i64).to_le_bytes()); - } else { - serialized.extend_from_slice(&int_range.unwrap().0.to_le_bytes()); - serialized.extend_from_slice(&int_range.unwrap().1.to_le_bytes()); - } - - let float_range = item.float_range.clone(); - if float_range.is_none() { - serialized.extend_from_slice(&(0 as f32).to_le_bytes()); - serialized.extend_from_slice(&(0 as f32).to_le_bytes()); - } else { - serialized.extend_from_slice(&float_range.unwrap().0.to_le_bytes()); - serialized.extend_from_slice(&float_range.unwrap().1.to_le_bytes()); - } - } + // serialize node_metadata offset + serialized.extend_from_slice(&self.node_metadata.as_ref().unwrap().offset.to_le_bytes()); + + // // Serialize metadata + // serialize_length(&mut serialized, self.metadata.len() as u32); + // for meta in &self.metadata { + // // let serialized_meta = serialize_metadata(meta); // Function to serialize a Vec + // // serialized.extend_from_slice(&serialized_meta); + // serialize_metadata(&mut serialized, meta); + // } + + // // Serialize node_metadata + // serialize_length( + // &mut serialized, + // self.node_metadata.get_all_values().len() as u32, + // ); + // for (key, item) in self.node_metadata.get_all_values() { + // let serialized_key = key.as_bytes(); + // serialize_length(&mut serialized, serialized_key.len() as u32); + // serialized.extend_from_slice(serialized_key); + + // let values = item.values.clone(); + + // serialize_length(&mut serialized, values.len() as u32); + // for value in values { + // let serialized_value = value.as_bytes(); + // serialize_length(&mut serialized, serialized_value.len() as u32); + // 
serialized.extend_from_slice(serialized_value); + // } + + // let int_range = item.int_range.clone(); + // if int_range.is_none() { + // serialized.extend_from_slice(&(0 as i64).to_le_bytes()); + // serialized.extend_from_slice(&(0 as i64).to_le_bytes()); + // } else { + // serialized.extend_from_slice(&int_range.unwrap().0.to_le_bytes()); + // serialized.extend_from_slice(&int_range.unwrap().1.to_le_bytes()); + // } + + // let float_range = item.float_range.clone(); + // if float_range.is_none() { + // serialized.extend_from_slice(&(0 as f32).to_le_bytes()); + // serialized.extend_from_slice(&(0 as f32).to_le_bytes()); + // } else { + // serialized.extend_from_slice(&float_range.unwrap().0.to_le_bytes()); + // serialized.extend_from_slice(&float_range.unwrap().1.to_le_bytes()); + // } + // } serialized } @@ -313,81 +495,106 @@ impl Node { children.push(child); } - // Deserialize metadata + // deserialize metadata let metadata_len = read_length(&data[offset..offset + 4]); offset += 4; + let mut metadata = Vec::with_capacity(metadata_len); for _ in 0..metadata_len { - let (meta, meta_size) = deserialize_metadata(&data[offset..]); - metadata.push(meta); - offset += meta_size; // Increment offset based on actual size of deserialized metadata + let meta_offset = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + metadata.push(LazyValue { + offset: meta_offset, + value: None, + }); } // Deserialize node_metadata - let mut node_metadata = NodeMetadataIndex::new(); - let node_metadata_len = read_length(&data[offset..offset + 4]); - offset += 4; - - for _ in 0..node_metadata_len { - let key_len = read_length(&data[offset..offset + 4]); - offset += 4; - - let key = String::from_utf8(data[offset..offset + key_len as usize].to_vec()).unwrap(); - offset += key_len as usize; - - let mut values = HashSet::new(); - let values_len = read_length(&data[offset..offset + 4]); - offset += 4; - - for idx in 0..values_len { - let value_len = read_length(&data[offset..offset + 4]); - offset += 4; - - if value_len > data.len() - offset { - println!("Current IDX: {}", idx); - println!("Value length: {}", value_len); - println!("Value len binary: {:?}", (value_len as u32).to_le_bytes()); - println!("Data length: {}", data.len()); - // add some more debug prints for the current state of things to figure out where it's going wrong - - println!("Offset: {}", offset); - println!("Key: {}", key); - println!("Values: {:?}", values); - println!("Values len: {}", values_len); - println!("Node metadata: {:?}", node_metadata); - println!("Node metadata len: {}", node_metadata_len); - - panic!("Value length exceeds data length"); - } - - let value = - String::from_utf8(data[offset..offset + value_len as usize].to_vec()).unwrap(); - offset += value_len as usize; - - values.insert(value); - } - - let mut item = NodeMetadata { - values: values.clone(), - int_range: None, - float_range: None, - }; - - let min_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); - offset += 8; - let max_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); - offset += 8; - - let min_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); - offset += 4; - let max_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); - offset += 4; + let node_metadata_offset = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; - item.int_range = Some((min_int, max_int)); - item.float_range = Some((min_float, 
max_float)); + let node_metadata = LazyValue { + offset: node_metadata_offset, + value: None, + }; - node_metadata.insert(key, item); - } + // // Deserialize metadata + // let metadata_len = read_length(&data[offset..offset + 4]); + // offset += 4; + // let mut metadata = Vec::with_capacity(metadata_len); + // for _ in 0..metadata_len { + // let (meta, meta_size) = deserialize_metadata(&data[offset..]); + // metadata.push(meta); + // offset += meta_size; // Increment offset based on actual size of deserialized metadata + // } + + // // Deserialize node_metadata + // let mut node_metadata = NodeMetadataIndex::new(); + // let node_metadata_len = read_length(&data[offset..offset + 4]); + // offset += 4; + + // for _ in 0..node_metadata_len { + // let key_len = read_length(&data[offset..offset + 4]); + // offset += 4; + + // let key = String::from_utf8(data[offset..offset + key_len as usize].to_vec()).unwrap(); + // offset += key_len as usize; + + // let mut values = HashSet::new(); + // let values_len = read_length(&data[offset..offset + 4]); + // offset += 4; + + // for idx in 0..values_len { + // let value_len = read_length(&data[offset..offset + 4]); + // offset += 4; + + // if value_len > data.len() - offset { + // println!("Current IDX: {}", idx); + // println!("Value length: {}", value_len); + // println!("Value len binary: {:?}", (value_len as u32).to_le_bytes()); + // println!("Data length: {}", data.len()); + // // add some more debug prints for the current state of things to figure out where it's going wrong + + // println!("Offset: {}", offset); + // println!("Key: {}", key); + // println!("Values: {:?}", values); + // println!("Values len: {}", values_len); + // println!("Node metadata: {:?}", node_metadata); + // println!("Node metadata len: {}", node_metadata_len); + + // panic!("Value length exceeds data length"); + // } + + // let value = + // String::from_utf8(data[offset..offset + value_len as usize].to_vec()).unwrap(); + // offset += value_len as usize; + + // values.insert(value); + // } + + // let mut item = NodeMetadata { + // values: values.clone(), + // int_range: None, + // float_range: None, + // }; + + // let min_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + // offset += 8; + // let max_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); + // offset += 8; + + // let min_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + // offset += 4; + // let max_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + // offset += 4; + + // item.int_range = Some((min_int, max_int)); + // item.float_range = Some((min_float, max_float)); + + // node_metadata.insert(key, item); + // } Node { vectors, @@ -403,7 +610,7 @@ impl Node { } else { None }, - node_metadata, + node_metadata: Some(node_metadata), } } } @@ -419,7 +626,7 @@ impl Default for Node { offset: 0, is_root: false, parent_offset: None, - node_metadata: NodeMetadataIndex::new(), + node_metadata: None, } } } @@ -485,32 +692,32 @@ fn random_string(len: usize) -> String { .to_string() } -#[test] -fn test_serialize_deserialize() { - let mut node = Node::new_leaf(); - let mut vectors = Vec::new(); - let mut ids = Vec::new(); - let mut kvs = Vec::new(); - - for _ in 0..96 { - let vector: [u8; 128] = [0; 128]; - vectors.push(vector); - ids.push(0); - kvs.push(vec![ - KVPair::new("key".to_string(), random_string(77)), - KVPair::new("url".to_string(), random_string(44)), - KVPair::new("name".to_string(), random_string(17)), - ]); - } - - 
node.vectors = vectors;
-    node.ids = ids;
-    node.metadata = kvs.clone();
-    node.node_metadata = calc_metadata_index_for_metadata(kvs.clone());
-
-    let serialized = node.serialize();
-
-    let deserialized = Node::deserialize(&serialized);
-
-    // assert_eq!(node, deserialized);
-}
+// #[test]
+// fn test_serialize_deserialize() {
+//     let mut node = Node::new_leaf();
+//     let mut vectors = Vec::new();
+//     let mut ids = Vec::new();
+//     let mut kvs = Vec::new();
+
+//     for _ in 0..96 {
+//         let vector: [u8; 128] = [0; 128];
+//         vectors.push(vector);
+//         ids.push(0);
+//         kvs.push(vec![
+//             KVPair::new("key".to_string(), random_string(77)),
+//             KVPair::new("url".to_string(), random_string(44)),
+//             KVPair::new("name".to_string(), random_string(17)),
+//         ]);
+//     }
+
+//     node.vectors = vectors;
+//     node.ids = ids;
+//     node.metadata = kvs.clone();
+//     node.node_metadata = calc_metadata_index_for_metadata(kvs.clone());
+
+//     let serialized = node.serialize();
+
+//     let deserialized = Node::deserialize(&serialized);
+
+//     // assert_eq!(node, deserialized);
+// }
diff --git a/src/structures/block_storage.rs b/src/structures/block_storage.rs
new file mode 100644
index 0000000..f9ec87c
--- /dev/null
+++ b/src/structures/block_storage.rs
@@ -0,0 +1,359 @@
+use crate::services::LockService;
+use memmap::MmapMut;
+use std::fs;
+use std::fs::OpenOptions;
+use std::io;
+use std::path::PathBuf;
+
+pub struct BlockStorage {
+    pub mmap: MmapMut,
+    path: PathBuf,
+    locks: LockService,
+}
+
+/*
+
+    Schema for Block Storage:
+
+    - Header:
+        - Used blocks (u64)
+        - Root index (u64)
+    - Blocks:
+        - Block Header:
+            - Is primary (u8)
+            - Index in chain (u64)
+            - Primary index (u64)
+            - Next block index (u64)
+            - Previous block index (u64)
+            - Serialized node length (u64)
+        - Data:
+            - Node data
+
+*/
+
+pub const SIZE_OF_U64: usize = std::mem::size_of::<u64>();
+pub const HEADER_SIZE: usize = SIZE_OF_U64 * 2; // Used space + root offset
+
+pub const BLOCK_SIZE: usize = 4096; // typical page size
+pub const BLOCK_HEADER_SIZE: usize = SIZE_OF_U64 * 5 + 1; // Index in chain + Primary index + Next block offset + Previous block offset + Serialized node length + Is primary
+pub const BLOCK_DATA_SIZE: usize = BLOCK_SIZE - BLOCK_HEADER_SIZE;
+
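// Worked layout arithmetic implied by the constants above: with SIZE_OF_U64 = 8,
// BLOCK_HEADER_SIZE = 8 * 5 + 1 = 41 bytes, so each 4096-byte block carries
// BLOCK_DATA_SIZE = 4096 - 41 = 4055 bytes of payload. A node of n serialized
// bytes therefore occupies ceil(n / 4055) chained blocks, e.g.:
//
//     let blocks_required = (n + BLOCK_DATA_SIZE - 1) / BLOCK_DATA_SIZE;
//     // n = 4055 -> 1 block; n = 4056 -> 2 blocks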
+#[derive(Debug, Clone)]
+pub struct BlockHeaderData {
+    pub is_primary: bool,
+    pub index_in_chain: u64,
+    pub primary_index: u64,
+    pub next_block_offset: u64,
+    pub previous_block_offset: u64,
+    pub serialized_node_length: u64,
+}
+
+impl Copy for BlockHeaderData {}
+
+impl BlockStorage {
+    pub fn new(path: PathBuf) -> io::Result<Self> {
+        let exists = path.exists();
+        let file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .create(!exists)
+            .open(path.clone())?;
+
+        if !exists {
+            file.set_len(1_000_000)?;
+        }
+
+        let mmap = unsafe { MmapMut::map_mut(&file)? };
+
+        // take path, remove everything after the last dot (the extension), and add _locks
+        let mut locks_path = path.clone().to_str().unwrap().to_string();
+        let last_dot = locks_path.rfind('.').unwrap();
+        locks_path.replace_range(last_dot.., "_locks");
+
+        fs::create_dir_all(&locks_path).expect("Failed to create directory");
+
+        Ok(BlockStorage {
+            mmap,
+            path,
+            locks: LockService::new(locks_path.into()),
+        })
+    }
+
+    pub fn used_blocks(&self) -> usize {
+        (u64::from_le_bytes(self.mmap[0..SIZE_OF_U64].try_into().unwrap()) as usize) + 1
+    }
+
+    pub fn set_used_blocks(&mut self, used_blocks: usize) {
+        self.mmap[0..SIZE_OF_U64].copy_from_slice(&(used_blocks as u64).to_le_bytes());
+    }
+
+    pub fn root_offset(&self) -> usize {
+        u64::from_le_bytes(
+            self.mmap[SIZE_OF_U64..(2 * SIZE_OF_U64)]
+                .try_into()
+                .unwrap(),
+        ) as usize
+    }
+
+    pub fn set_root_offset(&mut self, root_offset: usize) {
+        self.mmap[SIZE_OF_U64..(2 * SIZE_OF_U64)]
+            .copy_from_slice(&(root_offset as u64).to_le_bytes());
+    }
+
+    pub fn increment_and_allocate_block(&mut self) -> usize {
+        let mut used_blocks = self.used_blocks();
+        self.set_used_blocks(used_blocks + 1);
+
+        if (used_blocks + 1) * BLOCK_SIZE > self.mmap.len() {
+            self.resize_mmap().unwrap();
+        }
+
+        used_blocks
+    }
+
+    fn resize_mmap(&mut self) -> io::Result<()> {
+        println!("Resizing mmap");
+        let current_len = self.mmap.len();
+        let new_len = current_len * 2;
+
+        let file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .open(self.path.clone())?;
+
+        file.set_len(new_len as u64)?;
+
+        self.mmap = unsafe { MmapMut::map_mut(&file)? };
+        Ok(())
+    }
+
+    pub fn get_block_header_data(&self, index: usize) -> BlockHeaderData {
+        let start = HEADER_SIZE + index * BLOCK_SIZE;
+        let end = start + BLOCK_HEADER_SIZE;
+
+        let is_primary = self.mmap[start] == 1;
+        let index_in_chain =
+            u64::from_le_bytes(self.mmap[start + 1..start + 9].try_into().unwrap());
+        let primary_index =
+            u64::from_le_bytes(self.mmap[start + 9..start + 17].try_into().unwrap());
+        let next_block_offset =
+            u64::from_le_bytes(self.mmap[start + 17..start + 25].try_into().unwrap());
+        let previous_block_offset =
+            u64::from_le_bytes(self.mmap[start + 25..start + 33].try_into().unwrap());
+        let serialized_node_length =
+            u64::from_le_bytes(self.mmap[start + 33..end].try_into().unwrap());
+
+        BlockHeaderData {
+            is_primary,
+            index_in_chain,
+            primary_index,
+            next_block_offset,
+            previous_block_offset,
+            serialized_node_length,
+        }
+    }
+
+    pub fn set_block_header_data(&mut self, index: usize, data: BlockHeaderData) {
+        let start = HEADER_SIZE + index * BLOCK_SIZE;
+
+        self.mmap[start] = data.is_primary as u8;
+        self.mmap[start + 1..start + 9].copy_from_slice(&data.index_in_chain.to_le_bytes());
+        self.mmap[start + 9..start + 17].copy_from_slice(&data.primary_index.to_le_bytes());
+        self.mmap[start + 17..start + 25].copy_from_slice(&data.next_block_offset.to_le_bytes());
+        self.mmap[start + 25..start + 33]
+            .copy_from_slice(&data.previous_block_offset.to_le_bytes());
+        self.mmap[start + 33..start + 41]
+            .copy_from_slice(&data.serialized_node_length.to_le_bytes());
+    }
+
+    pub fn get_block_bytes(&self, index: usize) -> &[u8] {
+        let start = HEADER_SIZE + index * BLOCK_SIZE + BLOCK_HEADER_SIZE;
+        let end = start + BLOCK_DATA_SIZE;
+
+        &self.mmap[start..end]
+    }
+
+    pub fn store_block(&mut self, index: usize, data: &[u8]) {
+        let start = HEADER_SIZE + index * BLOCK_SIZE + BLOCK_HEADER_SIZE;
+
+        self.mmap[start..start + data.len()].copy_from_slice(data);
+    }
+
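// store() below chains blocks: the payload is cut into BLOCK_DATA_SIZE slices,
// each block header records its position in the chain plus the next block's
// index, and a next index of 0 terminates the chain (block 0 is never handed
// out, since increment_and_allocate_block starts at used_blocks() >= 1).
// A hypothetical round trip over the public API (names illustrative):
//
//     let off = storage.store(bytes, 0)?;        // 0 = allocate fresh blocks
//     let off2 = storage.store(more_bytes, off)?; // rewrite in place, reusing the chain
//     let back = storage.load(off2)?;             // walks the chain back together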
+    pub fn store(&mut self, serialized: Vec<u8>, index: usize) -> io::Result<usize> {
+        let serialized_len = serialized.len() as u64;
+        let blocks_required =
+            ((serialized_len + BLOCK_DATA_SIZE as u64 - 1) / BLOCK_DATA_SIZE as u64) as usize;
+
+        // Allocate new block if this is a new node
+        let mut current_block_index = if index == 0 {
+            self.increment_and_allocate_block()
+        } else {
+            index
+        };
+
+        // Initialize writing state
+        let mut remaining_bytes_to_write = serialized_len;
+        let mut bytes_written = 0;
+        let mut prev_block_index = 0;
+
+        let original_block_index = current_block_index;
+
+        self.acquire_lock(original_block_index)?;
+
+        // Clear previous overflow chain if it exists
+        let mut used_blocks = Vec::new();
+        if index != 0 {
+            let mut temp_block_index = index;
+            while temp_block_index != 0 {
+                let block_header = self.get_block_header_data(temp_block_index);
+                used_blocks.push(temp_block_index);
+                temp_block_index = block_header.next_block_offset as usize;
+                if used_blocks.len() >= blocks_required {
+                    break;
+                }
+            }
+        }
+
+        // Clear excess blocks if the node is smaller
+        if used_blocks.len() > blocks_required {
+            for &index in &used_blocks[blocks_required..] {
+                self.clear_block(index);
+            }
+            used_blocks.truncate(blocks_required);
+        }
+
+        // Write new node data into blocks
+        for i in 0..blocks_required {
+            let bytes_to_write = std::cmp::min(remaining_bytes_to_write as usize, BLOCK_DATA_SIZE);
+            self.store_block(
+                current_block_index,
+                &serialized[bytes_written as usize..bytes_written + bytes_to_write],
+            );
+
+            let next_block_index = if remaining_bytes_to_write > bytes_to_write as u64 {
+                if i + 1 < used_blocks.len() {
+                    used_blocks[i + 1]
+                } else {
+                    self.increment_and_allocate_block()
+                }
+            } else {
+                0
+            };
+
+            let block_header = BlockHeaderData {
+                is_primary: i == 0,
+                index_in_chain: i as u64,
+                primary_index: original_block_index as u64,
+                next_block_offset: next_block_index as u64,
+                previous_block_offset: if i == 0 { 0 } else { prev_block_index as u64 },
+                serialized_node_length: serialized_len,
+            };
+
+            self.set_block_header_data(current_block_index, block_header);
+
+            // Debug statements
+            // println!(
+            //     "Block {}: is_primary={}, index_in_chain={}, primary_index={}, next_block_offset={}, previous_block_offset={}, serialized_node_length={}",
+            //     current_block_index,
+            //     block_header.is_primary,
+            //     block_header.index_in_chain,
+            //     block_header.primary_index,
+            //     block_header.next_block_offset,
+            //     block_header.previous_block_offset,
+            //     block_header.serialized_node_length
+            // );
+
+            prev_block_index = current_block_index;
+            current_block_index = next_block_index;
+            remaining_bytes_to_write -= bytes_to_write as u64;
+            bytes_written += bytes_to_write;
+        }
+
+        if bytes_written != serialized_len as usize {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "Bytes written does not match serialized length",
+            ));
+        }
+
+        if remaining_bytes_to_write != 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "Remaining bytes to write is not 0",
+            ));
+        }
+
+        self.release_lock(original_block_index)?;
+        Ok(original_block_index)
+    }
+
+    fn clear_block(&mut self, index: usize) {
+        let start = HEADER_SIZE + index * BLOCK_SIZE;
+        let end = start + BLOCK_SIZE;
+        self.mmap[start..end].fill(0);
+        // Debug statement
+        println!("Cleared block at index {}", index);
+    }
+
+    pub fn load(&self, offset: usize) -> io::Result<Vec<u8>> {
+        let mut current_block_index = offset;
+        let mut serialized = Vec::new();
+
+        self.acquire_lock(offset)?;
+
+        loop {
+            let block_header = 
self.get_block_header_data(current_block_index); + + if block_header.is_primary { + serialized = Vec::with_capacity(block_header.serialized_node_length as usize); + } + + let data = self.get_block_bytes(current_block_index); + serialized.extend_from_slice(data); + + // println!( + // "LOADING Block {}: is_primary={}, index_in_chain={}, primary_index={}, next_block_offset={}, previous_block_offset={}, serialized_node_length={}", + // current_block_index, + // block_header.is_primary, + // block_header.index_in_chain, + // block_header.primary_index, + // block_header.next_block_offset, + // block_header.previous_block_offset, + // block_header.serialized_node_length + // ); + + if block_header.next_block_offset == 0 { + if serialized.len() < block_header.serialized_node_length as usize { + println!("Serialized node length does not match actual length"); + println!("Serialized length: {}", serialized.len()); + println!( + "Actual length: {}", + block_header.serialized_node_length as usize + ); + + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Serialized node length does not match actual length", + )); + } + + break; + } + + current_block_index = block_header.next_block_offset as usize; + } + + self.release_lock(offset)?; + + Ok(serialized) + } + + pub fn acquire_lock(&self, index: usize) -> io::Result<()> { + self.locks.acquire(index.to_string()) + } + + pub fn release_lock(&self, index: usize) -> io::Result<()> { + self.locks.release(index.to_string()) + } +} diff --git a/src/structures/dense_vector_list.rs b/src/structures/dense_vector_list.rs deleted file mode 100644 index 8243704..0000000 --- a/src/structures/dense_vector_list.rs +++ /dev/null @@ -1,193 +0,0 @@ -use crate::constants::QUANTIZED_VECTOR_SIZE; -use memmap::MmapMut; -use std::fs::OpenOptions; -use std::io; -use std::path::PathBuf; - -const SIZE_OF_U64: usize = std::mem::size_of::(); -const HEADER_SIZE: usize = SIZE_OF_U64; - -pub struct DenseVectorList { - mmap: MmapMut, - used_space: usize, // Track the used space within the mmap beyond the header - pub path: PathBuf, -} - -impl DenseVectorList { - pub fn new(path: PathBuf, elements: u64) -> io::Result { - let exists = path.exists(); - let file = OpenOptions::new() - .read(true) - .write(true) - .create(!exists) - .open(path.clone())?; - - if !exists { - // Set the file size, accounting for the header - file.set_len(elements * (QUANTIZED_VECTOR_SIZE as u64) + HEADER_SIZE as u64)?; - } - - let mut mmap = unsafe { MmapMut::map_mut(&file)? 
}; - - let used_space = if exists && file.metadata().unwrap().len() as usize > HEADER_SIZE { - // Read the existing used space from the file - let used_bytes = &mmap[0..HEADER_SIZE]; - u64::from_le_bytes(used_bytes.try_into().unwrap()) as usize - } else { - 0 // No data written yet, or file did not exist - }; - - if !exists { - // Initialize the header if the file is newly created - mmap[0..HEADER_SIZE].copy_from_slice(&(used_space as u64).to_le_bytes()); - } - - Ok(DenseVectorList { - mmap, - used_space, - path, - }) - } - - pub fn push(&mut self, vector: [u8; QUANTIZED_VECTOR_SIZE]) -> io::Result { - let offset = self.used_space + HEADER_SIZE; - let required_space = offset + QUANTIZED_VECTOR_SIZE; - - if required_space > self.mmap.len() { - self.resize_mmap(required_space * 2)?; - } - - self.mmap[offset..required_space].copy_from_slice(&vector); - self.used_space += QUANTIZED_VECTOR_SIZE; - // Update the header in the mmap - self.mmap[0..HEADER_SIZE].copy_from_slice(&(self.used_space as u64).to_le_bytes()); - - Ok(self.used_space / QUANTIZED_VECTOR_SIZE - 1) - } - - fn resize_mmap(&mut self, new_len: usize) -> io::Result<()> { - println!("Resizing mmap in DenseVectorList"); - - let file = OpenOptions::new() - .read(true) - .write(true) - .open(self.path.clone())?; // Ensure this path is handled correctly - - file.set_len(new_len as u64)?; - - self.mmap = unsafe { MmapMut::map_mut(&file)? }; - Ok(()) - } - - pub fn batch_push( - &mut self, - vectors: Vec<[u8; QUANTIZED_VECTOR_SIZE]>, - ) -> io::Result> { - let start_offset = self.used_space + HEADER_SIZE; - let total_size = vectors.len() * QUANTIZED_VECTOR_SIZE; - let required_space = start_offset + total_size; - - // println!( - // "Required space: {}, mmap len: {}", - // required_space, - // self.mmap.len() - // ); - - if required_space > self.mmap.len() { - self.resize_mmap(required_space * 2)?; - } - - // println!("Batch push"); - - for (i, vector) in vectors.iter().enumerate() { - let offset = start_offset + i * QUANTIZED_VECTOR_SIZE; - self.mmap[offset..offset + QUANTIZED_VECTOR_SIZE].copy_from_slice(vector); - } - - // println!("Batch push done"); - - self.used_space += total_size; - // Update the header in the mmap - self.mmap[0..HEADER_SIZE].copy_from_slice(&(self.used_space as u64).to_le_bytes()); - - Ok((((start_offset - HEADER_SIZE) / QUANTIZED_VECTOR_SIZE) - ..(self.used_space / QUANTIZED_VECTOR_SIZE)) - .collect()) - } - - pub fn get(&self, index: usize) -> io::Result<&[u8; QUANTIZED_VECTOR_SIZE]> { - let offset = HEADER_SIZE + index * QUANTIZED_VECTOR_SIZE; - let end = offset + QUANTIZED_VECTOR_SIZE; - - if end > self.used_space + HEADER_SIZE { - // print everything for debugging - println!("Offset: {}", offset); - println!("End: {}", end); - println!("Used space: {}", self.used_space); - - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Index out of bounds", - )); - } - - //don't use unsafe - let bytes = &self.mmap[offset..end]; - let val = bytes.try_into().unwrap(); - Ok(val) - } - - pub fn get_contiguous( - &self, - index: usize, - num_elements: usize, - ) -> io::Result<&[[u8; QUANTIZED_VECTOR_SIZE]]> { - let start = HEADER_SIZE + index * QUANTIZED_VECTOR_SIZE; - let end = start + num_elements * QUANTIZED_VECTOR_SIZE; - - if end > self.used_space + HEADER_SIZE { - println!("start: {}", start); - println!("End: {}", end); - println!("Used space: {}", self.used_space); - println!("Num elements: {}", num_elements); - println!("Index: {}", index); - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - 
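
The list being removed here addressed vectors purely by arithmetic over a flat mmap: an 8-byte used-space header followed by fixed-size slots, which is also what made the borrowed get_contiguous slice possible. A standalone sketch of that addressing scheme (constants mirror the crate's; names are illustrative):

const HEADER_SIZE: usize = 8; // u64 used-space counter at the front
const VEC_SIZE: usize = 128;  // QUANTIZED_VECTOR_SIZE in this crate

// Byte range of the slot holding vector `index`; slots are contiguous,
// so a run of indices maps to one contiguous byte range.
fn slot(index: usize) -> std::ops::Range<usize> {
    let start = HEADER_SIZE + index * VEC_SIZE;
    start..start + VEC_SIZE
}

fn main() {
    assert_eq!(slot(0), 8..136);
    assert_eq!(slot(2), 264..392);
}
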
"Index out of bounds", - )); - } - - // let mut vectors = Vec::with_capacity(num_elements); - // for i in 0..num_elements { - // let offset = HEADER_SIZE + (index + i) * QUANTIZED_VECTOR_SIZE; - // vectors.push(self.get(index + i)?); - // } - - // the indices are contiguous, so we can just get a slice of the mmap - let vectors: &[[u8; QUANTIZED_VECTOR_SIZE]] = unsafe { - std::slice::from_raw_parts( - self.mmap.as_ptr().add(start) as *const [u8; QUANTIZED_VECTOR_SIZE], - num_elements, - ) - }; - - Ok(vectors) - } - - pub fn len(&self) -> usize { - self.used_space / QUANTIZED_VECTOR_SIZE - } - - pub fn insert(&mut self, index: usize, vector: [u8; QUANTIZED_VECTOR_SIZE]) -> io::Result<()> { - let offset = HEADER_SIZE + index * QUANTIZED_VECTOR_SIZE; - let end = offset + QUANTIZED_VECTOR_SIZE; - - if end > self.used_space + HEADER_SIZE { - self.resize_mmap(end * 2)?; - } - - self.mmap[offset..end].copy_from_slice(&vector); - - Ok(()) - } -} diff --git a/src/structures/filters.rs b/src/structures/filters.rs index ede8f42..f760e22 100644 --- a/src/structures/filters.rs +++ b/src/structures/filters.rs @@ -3,6 +3,8 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use super::ann_tree::metadata::{NodeMetadata, NodeMetadataIndex}; +use super::ann_tree::node::LazyValue; +use super::block_storage::BlockStorage; use crate::structures::mmap_tree::serialization::{TreeDeserialization, TreeSerialization}; use std::fmt::Display; use std::hash::{Hash, Hasher}; @@ -476,11 +478,16 @@ pub fn combine_filters(filters: Vec) -> NodeMetadataIndex { result } -pub fn calc_metadata_index_for_metadata(kvs: Vec>) -> NodeMetadataIndex { +pub fn calc_metadata_index_for_metadata( + kvs: Vec>>, + storage: &BlockStorage, +) -> NodeMetadataIndex { let node_metadata: NodeMetadataIndex = kvs .into_iter() - .map(|metadata| { + .map(|mut metadata| { metadata + .get(storage) + .expect("Failed to get metadata") .into_iter() .map(|kv_pair| (kv_pair.key.clone(), kv_pair.value.clone())) .fold(NodeMetadataIndex::new(), |mut acc, (key, value)| { diff --git a/src/structures/inverted_index.rs b/src/structures/inverted_index.rs deleted file mode 100644 index afdcaff..0000000 --- a/src/structures/inverted_index.rs +++ /dev/null @@ -1,228 +0,0 @@ -use std::fmt::Display; -use std::path::PathBuf; - -use serde::{Deserialize, Serialize}; - -use crate::structures::mmap_tree::Tree; - -use super::metadata_index::KVPair; -use super::mmap_tree::serialization::{TreeDeserialization, TreeSerialization}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct InvertedIndexItem { - pub indices: Vec, - pub ids: Vec, -} - -impl Display for InvertedIndexItem { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "InvertedIndexItem {{ ... 
}}") - } -} - -impl TreeSerialization for InvertedIndexItem { - fn serialize(&self) -> Vec { - let mut serialized = Vec::new(); - - serialized.extend_from_slice(self.indices.len().to_le_bytes().as_ref()); - - let len_of_index_bytes: usize = 8; - - serialized.extend_from_slice(len_of_index_bytes.to_le_bytes().as_ref()); - - for index in &self.indices { - serialized.extend_from_slice(index.to_le_bytes().as_ref()); - } - - serialized.extend_from_slice(self.ids.len().to_le_bytes().as_ref()); - - let len_of_id_bytes: usize = 16; - - serialized.extend_from_slice(len_of_id_bytes.to_le_bytes().as_ref()); - - for id in &self.ids { - serialized.extend_from_slice(id.to_le_bytes().as_ref()); - } - - serialized - } -} - -impl TreeDeserialization for InvertedIndexItem { - fn deserialize(data: &[u8]) -> Self { - let mut offset = 0; - - let indices_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - // let mut indices = Vec::new(); - let len_of_index_bytes = usize::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); - offset += 8; - - let start = offset; - let end = start + indices_len * len_of_index_bytes; - - let indices_bytes = &data[start..end]; - - let indices_chunks = indices_bytes.chunks(len_of_index_bytes); - - // for chunk in indices_chunks { - // let index = usize::from_le_bytes(chunk.try_into().unwrap()); - // indices.push(index); - // } - - let indices = indices_chunks - .map(|chunk| usize::from_le_bytes(chunk.try_into().unwrap())) - .collect(); - - offset = end; - - let ids_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - // let mut ids = Vec::new(); - let len_of_id_bytes = usize::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); - offset += 8; - - // get them all and split the bytes into chunks - - let start = offset; - let end = start + ids_len * len_of_id_bytes; - let ids_bytes = &data[start..end]; - - let ids_chunks = ids_bytes.chunks(len_of_id_bytes); - - // for chunk in ids_chunks { - // let id = String::from_utf8(chunk.to_vec()).unwrap(); - // ids.push(id); - // } - let ids = ids_chunks - .map(|chunk| u128::from_le_bytes(chunk.try_into().unwrap())) - .collect(); - - InvertedIndexItem { indices, ids } - } -} - -pub struct InvertedIndex { - pub path: PathBuf, - pub tree: Tree, -} - -pub fn compress_indices(indices: Vec) -> Vec { - let mut compressed = Vec::new(); - if indices.is_empty() { - return compressed; - } - - let mut current_start = indices[0]; - let mut count = 1; - - for i in 1..indices.len() { - if indices[i] == current_start + count { - count += 1; - } else { - compressed.push(current_start); - compressed.push(count); - current_start = indices[i]; - count = 1; - } - } - compressed.push(current_start); - compressed.push(count); - - compressed -} - -pub fn decompress_indices(compressed: Vec) -> Vec { - let mut decompressed = Vec::new(); - let mut i = 0; - - while i < compressed.len() { - let start = compressed[i]; - let count = compressed[i + 1]; - decompressed.extend((start..start + count).collect::>()); - i += 2; // Move to the next pair - } - - decompressed -} - -impl InvertedIndex { - pub fn new(path: PathBuf) -> Self { - let tree = Tree::new(path.clone()).expect("Failed to create tree"); - InvertedIndex { path, tree } - } - - pub fn insert(&mut self, key: KVPair, value: InvertedIndexItem, skip_compression: bool) { - // println!("Inserting INTO INVERTED INDEX: {:?}", key); - if !skip_compression { - let compressed_indices = compress_indices(value.indices); - 
let value = InvertedIndexItem { - indices: compressed_indices, - ids: value.ids, - }; - self.tree.insert(key, value).expect("Failed to insert"); - } else { - self.tree.insert(key, value).expect("Failed to insert"); - } - // let compressed_indices = compress_indices(value.indices); - // let value = InvertedIndexItem { - // indices: compressed_indices, - // ids: value.ids, - // }; - // self.tree.insert(key, value).expect("Failed to insert"); - } - - pub fn get(&mut self, key: KVPair) -> Option { - // println!("Getting key: {:?}", key); - match self.tree.search(key) { - Ok(v) => { - // decompress the indices - match v { - Some(mut item) => { - // println!("Search result: {:?}", item); // Add this - - item.indices = decompress_indices(item.indices); - // println!("Decompressed indices: {:?}", item.indices); // Check output - - Some(item) - } - None => None, - } - } - Err(_) => None, - } - } - - pub fn insert_append(&mut self, key: KVPair, mut value: InvertedIndexItem) { - match self.get(key.clone()) { - Some(mut v) => { - // v.indices.extend(value.indices); - v.ids.extend(value.ids); - - let mut decompressed = v.indices.clone(); - - // binary search to insert all of the ones to append - for index in value.indices { - let idx = decompressed.binary_search(&index).unwrap_or_else(|x| x); - decompressed.insert(idx, index); - } - - decompressed.sort_unstable(); - decompressed.dedup(); - - // println!("Before compression: {:?}", decompressed); - - v.indices = compress_indices(decompressed); - - // println!("After compression: {:?}", v.indices); - - self.insert(key, v, true); - } - None => { - value.indices = compress_indices(value.indices); - // println!("Compressed: {:?}", value.indices); - self.insert(key, value, true); - } - } - } -} diff --git a/src/structures/metadata_index.rs b/src/structures/metadata_index.rs deleted file mode 100644 index 63d95d9..0000000 --- a/src/structures/metadata_index.rs +++ /dev/null @@ -1,426 +0,0 @@ -use std::fmt::Display; -use std::hash::Hash; -use std::hash::Hasher; - -use serde::{Deserialize, Serialize}; - -use crate::structures::tree::Tree; - -use super::tree::serialization::{TreeDeserialization, TreeSerialization}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(untagged)] -pub enum KVValue { - String(String), - Integer(i64), - Float(f32), -} - -impl Display for KVValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - KVValue::String(s) => write!(f, "{}", s), - KVValue::Integer(i) => write!(f, "{}", i), - KVValue::Float(fl) => write!(f, "{}", fl), - } - } -} - -impl Hash for KVValue { - fn hash(&self, state: &mut H) { - match self { - KVValue::String(s) => s.hash(state), - KVValue::Integer(i) => i.hash(state), - KVValue::Float(f) => { - let bits: u32 = f.to_bits(); - bits.hash(state); - } - } - } -} - -impl PartialEq for KVValue { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (KVValue::String(s1), KVValue::String(s2)) => s1 == s2, - (KVValue::Integer(i1), KVValue::Integer(i2)) => i1 == i2, - (KVValue::Float(f1), KVValue::Float(f2)) => (f1 - f2).abs() < 1e-6, - _ => false, - } - } -} - -impl Eq for KVValue {} - -impl PartialOrd for KVValue { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for KVValue { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match (self, other) { - (KVValue::String(s1), KVValue::String(s2)) => s1.cmp(s2), - (KVValue::Integer(i1), KVValue::Integer(i2)) => i1.cmp(i2), - (KVValue::Float(f1), KVValue::Float(f2)) => 
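
One caveat in the Float arm of this Ord impl: partial_cmp(..).unwrap() panics if either side is NaN. Should NaN-valued metadata ever reach a key, f32::total_cmp is the panic-free total order; a minimal sketch:

use std::cmp::Ordering;

// f32::total_cmp orders every bit pattern, NaN included, so it cannot panic.
fn cmp_f32(a: f32, b: f32) -> Ordering {
    a.total_cmp(&b)
}

fn main() {
    assert_eq!(cmp_f32(1.0, 2.0), Ordering::Less);
    assert_eq!(cmp_f32(f32::NAN, f32::NAN), Ordering::Equal);
}
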
f1.partial_cmp(f2).unwrap(), - _ => std::cmp::Ordering::Less, - } - } -} - -#[derive(Debug, Serialize, Deserialize, Clone, Hash)] -pub struct KVPair { - pub key: String, - pub value: KVValue, -} - -impl KVPair { - pub fn new(key: String, value: String) -> Self { - KVPair { - key, - value: KVValue::String(value), - } - } - - pub fn new_int(key: String, value: i64) -> Self { - KVPair { - key, - value: KVValue::Integer(value), - } - } - - pub fn new_float(key: String, value: f32) -> Self { - KVPair { - key, - value: KVValue::Float(value), - } - } -} - -impl PartialEq for KVPair { - fn eq(&self, other: &Self) -> bool { - self.key == other.key && self.value == other.value - } -} - -impl Eq for KVPair {} - -impl PartialOrd for KVPair { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for KVPair { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.key - .cmp(&other.key) - .then_with(|| self.value.cmp(&other.value)) - } -} - -impl Display for KVPair { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "KVPair {{ key: {}, value: {} }}", self.key, self.value) - } -} - -impl TreeSerialization for KVPair { - fn serialize(&self) -> Vec { - let mut serialized = Vec::new(); - - serialized.extend_from_slice(self.key.len().to_le_bytes().as_ref()); - serialized.extend_from_slice(self.key.as_bytes()); - // serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); - // serialized.extend_from_slice(self.value.as_bytes()); - - match self.value.clone() { - KVValue::String(s) => { - serialized.push(0); - serialized.extend_from_slice(s.len().to_le_bytes().as_ref()); - serialized.extend_from_slice(s.as_bytes()); - } - KVValue::Integer(i) => { - serialized.push(1); - serialized.extend_from_slice(i.to_le_bytes().as_ref()); - } - KVValue::Float(f) => { - serialized.push(2); - serialized.extend_from_slice(f.to_bits().to_le_bytes().as_ref()); - } - } - - serialized - } -} - -impl KVPair { - pub fn serialize(&self) -> Vec { - let mut serialized = Vec::new(); - - serialized.extend_from_slice(self.key.len().to_le_bytes().as_ref()); - serialized.extend_from_slice(self.key.as_bytes()); - // serialized.extend_from_slice(self.value.len().to_le_bytes().as_ref()); - // serialized.extend_from_slice(self.value.as_bytes()); - - match self.value.clone() { - KVValue::String(s) => { - serialized.push(0); - serialized.extend_from_slice(s.len().to_le_bytes().as_ref()); - serialized.extend_from_slice(s.as_bytes()); - } - KVValue::Integer(i) => { - serialized.push(1); - serialized.extend_from_slice(i.to_le_bytes().as_ref()); - } - KVValue::Float(f) => { - serialized.push(2); - serialized.extend_from_slice(f.to_bits().to_le_bytes().as_ref()); - } - } - - serialized - } - - pub fn deserialize(data: &[u8]) -> Self { - let mut offset = 0; - - let key_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - let key = String::from_utf8(data[offset..offset + key_len].to_vec()).unwrap(); - offset += key_len; - - // let value_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // offset += 8; - // let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); - // // offset += value_len; - - let value_flag = data[offset]; - offset += 1; - - let value = match value_flag { - 0 => { - let value_len = - u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - let value = String::from_utf8(data[offset..offset + 
value_len].to_vec()).unwrap(); - KVValue::String(value) - } - 1 => { - let value = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); - KVValue::Integer(value) - } - 2 => { - let bits = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); - let value = f32::from_bits(bits); - KVValue::Float(value) - } - _ => KVValue::String("".to_string()), - }; - - KVPair { key, value } - } -} - -impl TreeDeserialization for KVPair { - fn deserialize(data: &[u8]) -> Self { - let mut offset = 0; - - let key_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - let key = String::from_utf8(data[offset..offset + key_len].to_vec()).unwrap(); - offset += key_len; - - // let value_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // offset += 8; - // let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); - // // offset += value_len; - - let value_flag = data[offset]; - offset += 1; - - let value = match value_flag { - 0 => { - let value_len = - u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); - KVValue::String(value) - } - 1 => { - let value = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); - KVValue::Integer(value) - } - 2 => { - let bits = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); - let value = f32::from_bits(bits); - KVValue::Float(value) - } - _ => KVValue::String("".to_string()), - }; - - KVPair { key, value } - } -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct MetadataIndexItem { - pub kvs: Vec, - pub id: u128, - pub vector_index: usize, - // pub namespaced_id: String, -} - -impl Display for MetadataIndexItem { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "MetadataIndexItem {{ kvs: {:?}, id: {}, vector_index: {}, namespaced_id: }}", - self.kvs, self.id, self.vector_index - ) - } -} - -impl TreeSerialization for MetadataIndexItem { - fn serialize(&self) -> Vec { - let mut serialized = Vec::new(); - - serialized.extend_from_slice(self.kvs.len().to_le_bytes().as_ref()); - // for kv in &self.kvs { - // serialized.extend_from_slice(kv.key.len().to_le_bytes().as_ref()); - // serialized.extend_from_slice(kv.key.as_bytes()); - // serialized.extend_from_slice(kv.value.len().to_le_bytes().as_ref()); - // serialized.extend_from_slice(kv.value.as_bytes()); - // } - for kv in &self.kvs { - let serialized_kv = TreeSerialization::serialize(kv); - serialized.extend_from_slice(serialized_kv.len().to_le_bytes().as_ref()); - serialized.extend_from_slice(serialized_kv.as_ref()); - } - - // serialized.extend_from_slice(self.id.len().to_le_bytes().as_ref()); - serialized.extend_from_slice(self.id.to_le_bytes().as_ref()); - - serialized.extend_from_slice(self.vector_index.to_le_bytes().as_ref()); - - // serialized.extend_from_slice(self.namespaced_id.len().to_le_bytes().as_ref()); - // serialized.extend_from_slice(self.namespaced_id.as_bytes()); - - serialized - } -} - -impl TreeDeserialization for MetadataIndexItem { - fn deserialize(data: &[u8]) -> Self { - let mut offset = 0; - - let kvs_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - - let mut kvs = Vec::new(); - for _ in 0..kvs_len { - // let key_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // offset += 8; - - // let key = 
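
The KVPair wire format above is a one-byte tag (0 = string with a u64 length prefix, 1 = i64, 2 = f32 bits), everything little-endian. The integer case round-trips like this (a standalone sketch mirroring that layout):

// Tag 1 = i64 with a little-endian payload, matching the serializer above.
fn encode_i64(v: i64) -> Vec<u8> {
    let mut out = vec![1u8];
    out.extend_from_slice(&v.to_le_bytes());
    out
}

fn decode_i64(data: &[u8]) -> i64 {
    assert_eq!(data[0], 1, "expected the i64 tag");
    i64::from_le_bytes(data[1..9].try_into().unwrap())
}

fn main() {
    assert_eq!(decode_i64(&encode_i64(-42)), -42);
}
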
String::from_utf8(data[offset..offset + key_len].to_vec()).unwrap(); - // offset += key_len; - - // let value_len = - // u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // offset += 8; - - // let value = String::from_utf8(data[offset..offset + value_len].to_vec()).unwrap(); - // offset += value_len; - - // kvs.push(KVPair { key, value }); - - let kv_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - offset += 8; - - let kv = TreeDeserialization::deserialize(&data[offset..offset + kv_len]); - offset += kv_len; - - kvs.push(kv); - } - - // let id_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // offset += 8; - - let id = u128::from_le_bytes(data[offset..offset + 16].try_into().unwrap()); - offset += 16; - - let vector_index = - u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // offset += 8; - - // let namespaced_id_len = - // u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // offset += 8; - - // let namespaced_id = - // String::from_utf8(data[offset..offset + namespaced_id_len].to_vec()).unwrap(); - // offset += namespaced_id_len; - - MetadataIndexItem { - kvs, - id, - vector_index, - // namespaced_id, - } - } -} - -impl TreeSerialization for u128 { - fn serialize(&self) -> Vec { - self.to_le_bytes().to_vec() - } -} - -impl TreeDeserialization for u128 { - fn deserialize(data: &[u8]) -> Self { - u128::from_le_bytes(data.try_into().unwrap()) - } -} - -pub struct MetadataIndex { - pub tree: Tree, -} - -impl MetadataIndex { - pub fn new() -> Self { - let tree = Tree::new().expect("Failed to create tree"); - MetadataIndex { tree } - } - - pub fn insert(&mut self, key: u128, value: MetadataIndexItem) { - // self.tree.insert(key, value).expect("Failed to insert"); - self.tree.insert(key, value).expect("Failed to insert"); - } - - pub fn batch_insert(&mut self, items: Vec<(u128, MetadataIndexItem)>) { - self.tree - .batch_insert(items) - .expect("Failed to batch insert"); - } - - pub fn get(&mut self, key: u128) -> Option { - match self.tree.search(key) { - Ok(v) => v, - Err(_) => None, - } - } - - pub fn len(&self) -> usize { - self.tree.len() - } - - pub fn to_binary(&mut self) -> Vec { - self.tree.to_binary() - } - - pub fn from_binary(data: Vec) -> Self { - let tree = Tree::from_binary(data).expect("Failed to create tree from binary"); - MetadataIndex { tree } - } -} diff --git a/src/structures/tree.rs b/src/structures/tree.rs index 5bc73ef..a7aa164 100644 --- a/src/structures/tree.rs +++ b/src/structures/tree.rs @@ -1,103 +1,267 @@ pub mod node; -pub mod serialization; use std::fmt::{Debug, Display}; use std::io; +use std::path::PathBuf; -use node::{Node, NodeType}; -use serialization::{TreeDeserialization, TreeSerialization}; +use node::{Node, NodeType, NodeValue}; + +use super::ann_tree::serialization::{TreeDeserialization, TreeSerialization}; +use super::block_storage::BlockStorage; pub struct Tree { - pub root: Box>, - pub b: usize, + pub storage: BlockStorage, + phantom: std::marker::PhantomData<(K, V)>, } impl Tree where - K: Clone + Ord + TreeSerialization + TreeDeserialization + Display + Debug + Copy, - V: Clone + TreeSerialization + TreeDeserialization + Display + Debug, + K: Clone + Ord + TreeSerialization + TreeDeserialization + Debug + Display, + V: Clone + TreeSerialization + TreeDeserialization, { - pub fn new() -> Self { - Tree { - root: Box::new(Node::new_leaf()), // Initially the root is a leaf node - b: 4, + pub fn 
new(path: PathBuf) -> io::Result { + let mut storage = BlockStorage::new(path)?; + + if storage.used_blocks() <= 1 { + let root_offset: usize; + let mut root: Node = Node::new_leaf(0); + root.is_root = true; + + let serialized_root = root.serialize(); + + root_offset = storage.store(serialized_root, 0)?; + + println!("Root offset: {}", root_offset); + storage.set_root_offset(root_offset); } + + Ok(Tree { + storage, + phantom: std::marker::PhantomData, + }) } - pub fn insert(&mut self, key: K, value: V) -> Result<(), io::Error> { - let mut root = std::mem::replace(&mut self.root, Box::new(Node::new_leaf())); - if self.is_node_full(&root)? { - let mut new_root = Node::new_internal(); - let (median, sibling) = root.split()?; - new_root.keys.push(median); - new_root.children.push(root); - new_root.children.push(Box::new(sibling)); - root = Box::new(new_root); - } - self.insert_non_full(&mut *root, key, value)?; - self.root = root; - Ok(()) + pub fn store_node(&mut self, node: &mut Node) -> io::Result { + let serialized_node = node.serialize(); + println!("Storing node: {:?}", node.offset); + println!("Node has {} keys", node.keys.len()); + let offset = self.storage.store(serialized_node, node.offset)?; + node.offset = offset; + + Ok(offset) + } + + pub fn load_node(&self, offset: usize) -> io::Result> { + let serialized_node = self.storage.load(offset)?; + let mut node = Node::::deserialize(&serialized_node); + node.offset = offset; + Ok(node) } - fn insert_non_full( - &mut self, - node: &mut Node, - key: K, - value: V, - ) -> Result<(), io::Error> { - match &mut node.node_type { + pub fn get(&self, key: &K) -> io::Result>> { + let root_offset = self.storage.root_offset(); + let root = self.load_node(root_offset)?; + + println!("Root offset: {}", root_offset); + println!("Root keys: {:?}", root.keys); + + let result = self.get_recursive(&root, key); + + Ok(result) + } + + fn get_recursive(&self, node: &Node, key: &K) -> Option> { + match node.node_type { NodeType::Leaf => { - let idx = node.keys.binary_search(&key).unwrap_or_else(|x| x); - node.keys.insert(idx, key); - node.values.insert(idx, Some(value)); - Ok(()) + for i in 0..node.keys.len() { + if node.keys[i] == *key { + println!("Found key: {}", key); + return node.values[i].clone(); + } + } + None } NodeType::Internal => { - let idx = node.keys.binary_search(&key).unwrap_or_else(|x| x); - let child_idx = if idx == node.keys.len() || key < node.keys[idx] { - idx - } else { - idx + 1 - }; - - if self.is_node_full(&node.children[child_idx])? 
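
Both the removed in-memory tree and the new block-backed one descend with the same separator rule: go to child i, where i counts the keys strictly less than the search key. In isolation:

// Separator rule used by get_recursive/find_entrypoint: child i covers
// everything below keys[i]; the final child covers the tail.
fn child_index<K: Ord>(keys: &[K], key: &K) -> usize {
    keys.iter().take_while(|&k| key > k).count()
}

fn main() {
    let keys = [10, 20, 30];
    assert_eq!(child_index(&keys, &5), 0);  // left of every separator
    assert_eq!(child_index(&keys, &20), 1); // equal keys stop the scan
    assert_eq!(child_index(&keys, &35), 3); // past the last separator
}
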
{ - let (median, sibling) = node.children[child_idx].split()?; - node.keys.insert(idx, median); - node.children.insert(child_idx + 1, Box::new(sibling)); - if key >= node.keys[idx] { - self.insert_non_full(&mut *node.children[child_idx + 1], key, value) - } else { - self.insert_non_full(&mut *node.children[child_idx], key, value) - } - } else { - self.insert_non_full(&mut *node.children[child_idx], key, value) + let mut i = 0; + while i < node.keys.len() && *key > node.keys[i] { + i += 1; } + + println!("Searching in child: {}", i); + + let child_offset = node.children[i]; + let child = self.load_node(child_offset).unwrap(); + + self.get_recursive(&child, key) } } } - fn is_node_full(&self, node: &Node) -> Result { - Ok(node.keys.len() == node.max_keys) - } + pub fn insert(&mut self, key: K, value: V) -> Result<(), io::Error> { + // println!("Inserting key: {}, value: {}", key, value); + let vals = vec![(key, value)]; - pub fn search(&self, key: K) -> Result, io::Error> { - self.search_node(&*self.root, key) + self.batch_insert(vals) } - fn search_node(&self, node: &Node, key: K) -> Result, io::Error> { - match node.node_type { - NodeType::Internal => { - let idx = node.keys.binary_search(&key).unwrap_or_else(|x| x); - if idx < node.keys.len() && node.keys[idx] == key { - self.search_node(&node.children[idx + 1], key) + pub fn batch_insert(&mut self, entries: Vec<(K, V)>) -> Result<(), io::Error> { + if entries.is_empty() { + println!("No entries to insert"); + return Ok(()); + } + + let mut entries = entries; + entries.sort_by(|a, b| a.0.cmp(&b.0)); + + let entrypoint = self.find_entrypoint(&entries[0].0)?; + + let mut current_node = entrypoint; + + for (key, value) in entries.iter() { + if current_node.is_full() { + let (median, mut sibling) = current_node.split(crate::constants::B)?; + let sibling_offset = self.store_node(&mut sibling)?; + self.store_node(&mut current_node)?; // Store changes to the original node after splitting + if current_node.is_root { + let mut new_root = Node::new_internal(0); + new_root.is_root = true; + new_root.keys.push(median.clone()); + new_root.children.push(current_node.offset); // old root offset + new_root.children.push(sibling_offset); // new sibling offset + new_root.parent_offset = None; + let new_root_offset = self.store_node(&mut new_root)?; + self.storage.set_root_offset(new_root_offset); + current_node.is_root = false; + current_node.parent_offset = Some(new_root_offset); + sibling.parent_offset = Some(new_root_offset); + self.store_node(&mut current_node)?; + self.store_node(&mut sibling)?; + self.storage.set_root_offset(new_root_offset); } else { - self.search_node(&node.children[idx], key) + let parent_offset = current_node.parent_offset.unwrap(); + let mut parent = self.load_node(parent_offset)?; + let idx = parent + .keys + .binary_search(&median.clone()) + .unwrap_or_else(|x| x); + parent.keys.insert(idx, median.clone()); + parent.children.insert(idx + 1, sibling_offset); + self.store_node(&mut parent)?; } + + if *key >= median { + current_node = sibling; + } + } + + // Insert the key into the correct leaf node + let position = current_node.keys.binary_search(key).unwrap_or_else(|x| x); + + if current_node.keys.get(position) == Some(&key) { + current_node.values[position] = + Some(NodeValue::new(value.clone(), &mut self.storage)?); + } else { + current_node.keys.insert(position, key.clone()); + current_node.values.insert( + position, + Some(NodeValue::new(value.clone(), &mut self.storage)?), + ); } - NodeType::Leaf => match 
node.keys.binary_search(&key) { - Ok(idx) => Ok(node.values.get(idx).expect("could not get value").clone()), - Err(_) => Ok(None), - }, + self.store_node(&mut current_node)?; // Store changes after each insertion } + + Ok(()) + } + + pub fn find_entrypoint(&self, key: &K) -> io::Result> { + let root_offset = self.storage.root_offset(); + let root = self.load_node(root_offset)?; + + let mut current_node = root; + loop { + match current_node.node_type { + NodeType::Leaf => { + return Ok(current_node); + } + NodeType::Internal => { + let mut i = 0; + while i < current_node.keys.len() && *key > current_node.keys[i] { + i += 1; + } + + let child_offset = current_node.children[i]; + current_node = self.load_node(child_offset)?; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::path::PathBuf; + + #[test] + fn test_tree_insert() { + let path = PathBuf::from("test_tree_insert.bin"); + let _ = fs::remove_dir_all(&path); + + let mut tree: Tree = Tree::new(path).unwrap(); + + let entries = vec![(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]; + tree.batch_insert(entries).unwrap(); + + let result = tree + .get(&1) + .unwrap() + .unwrap() + .get(&tree.storage) + .expect("Value not found"); + assert_eq!(result, 1); + + let result = tree + .get(&2) + .unwrap() + .unwrap() + .get(&tree.storage) + .expect("Value not found"); + assert_eq!(result, 2); + + let result = tree + .get(&3) + .unwrap() + .unwrap() + .get(&tree.storage) + .expect("Value not found"); + + assert_eq!(result, 3); + + let result = tree + .get(&4) + .unwrap() + .unwrap() + .get(&tree.storage) + .expect("Value not found"); + + assert_eq!(result, 4); + + let result = tree + .get(&5) + .unwrap() + .unwrap() + .get(&tree.storage) + .expect("Value not found"); + + assert_eq!(result, 5); + + let result = tree.get(&6).unwrap(); + + assert_eq!(result, None); + + let result = tree.get(&0).unwrap(); + + assert_eq!(result, None); } } diff --git a/src/structures/tree/node.rs b/src/structures/tree/node.rs index 632f2ec..ebe1bad 100644 --- a/src/structures/tree/node.rs +++ b/src/structures/tree/node.rs @@ -1,9 +1,8 @@ -use std::{ - fmt::{Debug, Display}, - io, +use crate::structures::{ + ann_tree::serialization::{TreeDeserialization, TreeSerialization}, + block_storage::BlockStorage, }; - -use super::serialization::{TreeDeserialization, TreeSerialization}; +use std::io; #[derive(Debug, PartialEq, Clone)] pub enum NodeType { @@ -11,104 +10,480 @@ pub enum NodeType { Internal, } -const MAX_KEYS: usize = 10; +pub fn serialize_node_type(node_type: &NodeType) -> [u8; 1] { + match node_type { + NodeType::Leaf => [0], + NodeType::Internal => [1], + } +} + +pub fn deserialize_node_type(data: &[u8]) -> NodeType { + match data[0] { + 0 => NodeType::Leaf, + 1 => NodeType::Internal, + _ => panic!("Invalid node type"), + } +} + +fn serialize_length(buffer: &mut Vec, length: u32) -> &Vec { + buffer.extend_from_slice(&length.to_le_bytes()); + + // Return the buffer to allow chaining + buffer +} + +fn read_length(data: &[u8]) -> usize { + u32::from_le_bytes(data.try_into().unwrap()) as usize +} + +#[derive(Debug, PartialEq, Clone)] +pub struct NodeValue { + offset: usize, + value: Option, +} +impl NodeValue +where + T: Clone + TreeDeserialization + TreeSerialization, +{ + pub fn get(&mut self, storage: &BlockStorage) -> Result { + match self.value.clone() { + Some(value) => Ok(value), + None => { + let bytes = storage.load(self.offset)?; + let value = T::deserialize(&bytes); + self.value = Some(value.clone()); + Ok(value) + } + } + } + + 
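
NodeValue below is an offset plus a memoized payload: the first get deserializes from BlockStorage, later calls hit the cache. The same shape in miniature (a Vec of byte buffers stands in for the block store; names are illustrative):

// Offset-plus-cache cell in the spirit of NodeValue::get.
struct Lazy {
    offset: usize,
    cached: Option<String>,
}

impl Lazy {
    fn get(&mut self, storage: &[Vec<u8>]) -> &str {
        if self.cached.is_none() {
            let bytes = &storage[self.offset]; // one "disk" read, then cached
            self.cached = Some(String::from_utf8(bytes.clone()).unwrap());
        }
        self.cached.as_deref().unwrap()
    }
}

fn main() {
    let storage = vec![b"hello".to_vec()];
    let mut v = Lazy { offset: 0, cached: None };
    assert_eq!(v.get(&storage), "hello");
    assert_eq!(v.get(&storage), "hello"); // second call hits the cache
}
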
pub fn new(value: T, storage: &mut BlockStorage) -> Result { + let offset = storage.store(value.serialize(), 0)?; + Ok(NodeValue { + offset, + value: Some(value), + }) + } +} + +#[derive(Debug, PartialEq, Clone)] pub struct Node { pub keys: Vec, - pub values: Vec>, // Option for handling deletion in COW - pub children: Vec>>, // Using Box for heap allocation - pub max_keys: usize, // Maximum number of keys a node can hold + pub values: Vec>>, + pub children: Vec, pub node_type: NodeType, + pub offset: usize, + pub is_root: bool, + pub parent_offset: Option, } impl Node where - K: Clone + Ord + TreeSerialization + TreeDeserialization + Display + Debug + Copy, - V: Clone + TreeSerialization + TreeDeserialization + Display + Debug, + K: Clone + Ord + TreeSerialization + TreeDeserialization, + V: Clone + TreeSerialization + TreeDeserialization, { - pub fn new_leaf() -> Self { + pub fn new_leaf(offset: usize) -> Self { Node { keys: Vec::new(), values: Vec::new(), children: Vec::new(), - max_keys: MAX_KEYS, // Assuming a small number for testing purposes node_type: NodeType::Leaf, + offset, + is_root: false, + parent_offset: Some(0), } } - pub fn new_internal() -> Self { + pub fn new_internal(offset: usize) -> Self { Node { keys: Vec::new(), values: Vec::new(), children: Vec::new(), - max_keys: MAX_KEYS, node_type: NodeType::Internal, + offset, + is_root: false, + parent_offset: Some(0), } } - pub fn clone(&self) -> Self { - Node { - keys: self.keys.clone(), - values: self.values.clone(), - children: self - .children - .iter() - .map(|c| Box::new((**c).clone())) - .collect(), - max_keys: self.max_keys, - node_type: self.node_type.clone(), - } - } + pub fn split(&mut self, b: usize) -> Result<(K, Node), io::Error> { + // println!("Splitting node: {:?}", self.keys); - pub fn split(&mut self) -> Result<(K, Node), io::Error> { match self.node_type { NodeType::Internal => { - let split_index = (self.keys.len() + 1) / 2; - let median_key = self.keys[split_index].clone(); + if b <= 1 || b > self.keys.len() { + return Err(io::Error::new( + io::ErrorKind::Other, + "Invalid split point for internal node", + )); + } + let mut sibling_keys = self.keys.split_off(b - 1); + let median_key = sibling_keys.remove(0); - let sibling_keys = self.keys.split_off(split_index + 1); - let sibling_children = self.children.split_off(split_index + 1); + let sibling_children = self.children.split_off(b); let sibling = Node { keys: sibling_keys, values: Vec::new(), children: sibling_children, - max_keys: self.max_keys, node_type: NodeType::Internal, + offset: 0, + is_root: false, + parent_offset: self.parent_offset, }; - self.keys.pop(); - Ok((median_key, sibling)) } NodeType::Leaf => { - let split_index = (self.keys.len() + 1) / 2; - let median_key = self.keys[split_index].clone(); - - let sibling_keys = self.keys.split_off(split_index); - let sibling_values = self.values.split_off(split_index); + if b < 1 || b >= self.keys.len() { + return Err(io::Error::new( + io::ErrorKind::Other, + "Invalid split point for leaf node", + )); + } + let sibling_keys = self.keys.split_off(b); + let median_key = self.keys.get(b - 1).unwrap().clone(); + let sibling_values = self.values.split_off(b); let sibling = Node { keys: sibling_keys, values: sibling_values, children: Vec::new(), - max_keys: self.max_keys, node_type: NodeType::Leaf, + offset: 0, + is_root: false, + parent_offset: self.parent_offset, }; Ok((median_key, sibling)) } } } -} -impl Default for Node { - fn default() -> Self { + + pub fn is_full(&self) -> bool { + let b = 
crate::constants::B; + return self.keys.len() >= (2 * b - 1); + } + + pub fn serialize(&self) -> Vec { + let mut buffer = Vec::new(); + + buffer.extend_from_slice(&serialize_node_type(&self.node_type)); + buffer.extend_from_slice(&(self.is_root as u8).to_le_bytes()); + match &self.parent_offset { + Some(parent_offset) => { + buffer.extend_from_slice(&(*parent_offset as u64).to_le_bytes()); + } + None => { + buffer.extend_from_slice(&0u64.to_le_bytes()); + } + } + + serialize_length(&mut buffer, self.keys.len() as u32); + + for key in &self.keys { + let serialized_key = key.serialize(); + serialize_length(&mut buffer, serialized_key.len() as u32); + buffer.extend_from_slice(&serialized_key); + } + + match &self.node_type { + NodeType::Leaf => { + for value in &self.values { + match value { + Some(value) => { + buffer.extend_from_slice(&(value.offset as u64).to_le_bytes()); + } + None => { + buffer.extend_from_slice(&0u64.to_le_bytes()); + } + } + } + } + NodeType::Internal => { + for child in &self.children { + buffer.extend_from_slice(&(*child as u64).to_le_bytes()); + } + } + } + + buffer + } + + pub fn deserialize(data: &[u8]) -> Self { + let mut offset = 0; + + let node_type = deserialize_node_type(&data[offset..offset + 1]); + offset += 1; + let is_root = u8::from_le_bytes(data[offset..offset + 1].try_into().unwrap()) == 1; + offset += 1; + let parent_offset = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + + let keys_len = read_length(&data[offset..offset + 4]); + offset += 4; + + let mut keys = Vec::new(); + for _ in 0..keys_len { + let key_len = read_length(&data[offset..offset + 4]); + offset += 4; + let key = K::deserialize(&data[offset..offset + key_len]); + offset += key_len; + keys.push(key); + } + + let mut values = Vec::new(); + let mut children = Vec::new(); + match node_type { + NodeType::Leaf => { + for _ in 0..keys_len { + let value_offset = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + values.push(Some(NodeValue { + offset: value_offset, + value: None, + })); + } + } + NodeType::Internal => { + for _ in 0..keys_len { + let child_offset = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + children.push(child_offset); + } + } + } + Node { - keys: Vec::new(), - values: Vec::new(), - children: Vec::new(), - max_keys: 0, // Adjust this as necessary - node_type: NodeType::Leaf, // Or another appropriate default NodeType + keys, + values, + children, + node_type, + offset: 0, + is_root, + parent_offset: Some(parent_offset), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::structures::ann_tree::serialization::{TreeDeserialization, TreeSerialization}; + + #[derive(Debug, PartialEq, Clone, PartialOrd, Eq)] + struct TestKey { + key: String, + } + + impl Ord for TestKey { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.key.cmp(&other.key) } } + + impl TreeSerialization for TestKey { + fn serialize(&self) -> Vec { + self.key.as_bytes().to_vec() + } + } + + impl TreeDeserialization for TestKey { + fn deserialize(data: &[u8]) -> Self { + TestKey { + key: String::from_utf8(data.to_vec()).unwrap(), + } + } + } + + #[derive(Debug, PartialEq, Clone)] + struct TestValue { + value: String, + } + + impl TreeSerialization for TestValue { + fn serialize(&self) -> Vec { + self.value.as_bytes().to_vec() + } + } + + impl TreeDeserialization for TestValue { + fn deserialize(data: &[u8]) -> Self { + TestValue { + value: 
String::from_utf8(data.to_vec()).unwrap(), + } + } + } + + #[test] + fn test_serialize_node_type() { + assert_eq!(serialize_node_type(&NodeType::Leaf), [0]); + assert_eq!(serialize_node_type(&NodeType::Internal), [1]); + } + + #[test] + fn test_deserialize_node_type() { + assert_eq!(deserialize_node_type(&[0]), NodeType::Leaf); + assert_eq!(deserialize_node_type(&[1]), NodeType::Internal); + } + + #[test] + fn test_serialize_length() { + let mut buffer = Vec::new(); + serialize_length(&mut buffer, 10); + assert_eq!(buffer, [10, 0, 0, 0]); + } + + #[test] + fn test_read_length() { + assert_eq!(read_length(&[10, 0, 0, 0]), 10); + } + + #[test] + fn test_serialize_node() { + let node = Node:: { + keys: vec![TestKey { + key: "key1".to_string(), + }], + values: vec![Some(NodeValue { + offset: 0, + value: Some(TestValue { + value: "value1".to_string(), + }), + })], + children: vec![1], + node_type: NodeType::Leaf, + offset: 0, + is_root: true, + parent_offset: Some(0), + }; + + let serialized = node.serialize(); + let deserialized = Node::::deserialize(&serialized); + + assert_eq!(node, deserialized); + } + + #[test] + fn test_serialize_internal_node() { + let node = Node:: { + keys: vec![TestKey { + key: "key1".to_string(), + }], + values: vec![], + children: vec![1], + node_type: NodeType::Internal, + offset: 0, + is_root: true, + parent_offset: Some(0), + }; + + let serialized = node.serialize(); + let deserialized = Node::::deserialize(&serialized); + + assert_eq!(node, deserialized); + } + + #[test] + fn test_serialize_node_with_multiple_keys() { + let node = Node:: { + keys: vec![ + TestKey { + key: "key1".to_string(), + }, + TestKey { + key: "key2".to_string(), + }, + ], + values: vec![ + Some(NodeValue { + offset: 0, + value: Some(TestValue { + value: "value1".to_string(), + }), + }), + Some(NodeValue { + offset: 1, + value: Some(TestValue { + value: "value2".to_string(), + }), + }), + ], + children: vec![1, 2], + node_type: NodeType::Leaf, + offset: 0, + is_root: true, + parent_offset: Some(0), + }; + + let serialized = node.serialize(); + let deserialized = Node::::deserialize(&serialized); + + assert_eq!(node, deserialized); + } + + #[test] + fn test_serialize_internal_node_with_multiple_keys() { + let node = Node:: { + keys: vec![ + TestKey { + key: "key1".to_string(), + }, + TestKey { + key: "key2".to_string(), + }, + ], + values: vec![], + children: vec![1, 2], + node_type: NodeType::Internal, + offset: 0, + is_root: true, + parent_offset: Some(0), + }; + + let serialized = node.serialize(); + let deserialized = Node::::deserialize(&serialized); + + assert_eq!(node, deserialized); + } + + #[test] + fn test_serialize_node_with_no_keys() { + let node = Node:: { + keys: vec![], + values: vec![], + children: vec![], + node_type: NodeType::Leaf, + offset: 0, + is_root: true, + parent_offset: Some(0), + }; + + let serialized = node.serialize(); + let deserialized = Node::::deserialize(&serialized); + + assert_eq!(node, deserialized); + } + + #[test] + fn test_serialize_internal_node_with_no_keys() { + let node = Node:: { + keys: vec![], + values: vec![], + children: vec![], + node_type: NodeType::Internal, + offset: 0, + is_root: true, + parent_offset: Some(0), + }; + + let serialized = node.serialize(); + let deserialized = Node::::deserialize(&serialized); + + assert_eq!(node, deserialized); + } } diff --git a/src/structures/tree/serialization.rs b/src/structures/tree/serialization.rs deleted file mode 100644 index 23604ca..0000000 --- a/src/structures/tree/serialization.rs +++ 
/dev/null @@ -1,48 +0,0 @@ -pub trait TreeSerialization { - fn serialize(&self) -> Vec; -} - -pub trait TreeDeserialization { - fn deserialize(data: &[u8]) -> Self - where - Self: Sized; -} - -impl TreeDeserialization for i32 { - fn deserialize(data: &[u8]) -> Self { - let mut bytes = [0; 4]; - bytes.copy_from_slice(&data[..4]); - i32::from_le_bytes(bytes) - } -} - -impl TreeSerialization for i32 { - fn serialize(&self) -> Vec { - self.to_le_bytes().to_vec() - } -} - -impl TreeDeserialization for String { - fn deserialize(data: &[u8]) -> Self { - let mut bytes = Vec::new(); - let mut i = 4; - while i < data.len() { - let len = data[i..i + 4].try_into().unwrap(); - let len = i32::from_le_bytes(len) as usize; - let start = i + 4; - let end = start + len; - bytes.extend_from_slice(&data[start..end]); - i = end; - } - String::from_utf8(bytes).unwrap() - } -} - -impl TreeSerialization for String { - fn serialize(&self) -> Vec { - let mut data = Vec::new(); - data.extend_from_slice(&(self.len() as i32).to_le_bytes()); - data.extend_from_slice(self.as_bytes()); - data - } -} diff --git a/src/structures/tree/storage.rs b/src/structures/tree/storage.rs deleted file mode 100644 index e69de29..0000000 From 7f3cd0d1bcc4731fe7d98b0a8394316542d5f29b Mon Sep 17 00:00:00 2001 From: Carson Poole Date: Fri, 24 May 2024 05:44:14 +0000 Subject: [PATCH 5/6] more bug fixes --- src/main1.rs | 2 +- src/services/commit.rs | 32 ++++++++----- src/services/namespace_state.rs | 2 +- src/services/query.rs | 22 +++++---- src/structures/ann_tree/node.rs | 10 +++- src/structures/ann_tree/serialization.rs | 37 +++++++++++++++ src/structures/wal.rs | 60 +++++++++++++++++++++--- 7 files changed, 133 insertions(+), 32 deletions(-) diff --git a/src/main1.rs b/src/main1.rs index 132406b..f0ce77e 100644 --- a/src/main1.rs +++ b/src/main1.rs @@ -39,7 +39,7 @@ fn main() { // .expect("Failed to add to WAL"); // } - const NUM_VECTORS: usize = 10_000_000; + const NUM_VECTORS: usize = 10_000; let batch_vectors: Vec> = (0..NUM_VECTORS).map(|_| vec![random_vec()]).collect(); diff --git a/src/services/commit.rs b/src/services/commit.rs index d2abe40..b463a94 100644 --- a/src/services/commit.rs +++ b/src/services/commit.rs @@ -79,7 +79,10 @@ impl CommitService { .filter(|item| item.key == "text") .collect::>() .first() - .unwrap() + .unwrap_or(&&KVPair { + key: "text".to_string(), + value: KVValue::String("".to_string()), + }) .value .clone() }) @@ -92,23 +95,26 @@ impl CommitService { for ((vector, kv), texts) in vectors.iter().zip(kvs).zip(texts) { let id = uuid::Uuid::new_v4().as_u128(); - self.state.vectors.insert(vector.clone(), id, kv.clone()); + self.state.vectors.insert(*vector, id, kv); // self.state.texts.insert(id, texts.clone()); - match texts { - KVValue::String(text) => { - self.state - .texts - .insert(id, compress_string(&text)) - .expect("Failed to insert text"); - } - _ => {} - } + // match texts { + // KVValue::String(text) => { + // self.state + // .texts + // .insert(id, compress_string(&text)) + // .expect("Failed to insert text"); + // } + // _ => {} + // } } - self.state.wal.mark_commit_finished(commit.hash)?; + // self.state.wal.mark_commit_finished(commit.hash)?; } - self.state.vectors.true_calibrate(); + self.state + .vectors + .true_calibrate() + .expect("Failed to calibrate"); Ok(()) } diff --git a/src/services/namespace_state.rs b/src/services/namespace_state.rs index 90d1c39..2d64f0c 100644 --- a/src/services/namespace_state.rs +++ b/src/services/namespace_state.rs @@ -2,7 +2,7 @@ use 
crate::structures::ann_tree::ANNTree; // use crate::structures::dense_vector_list::DenseVectorList; // use crate::structures::inverted_index::InvertedIndex; // use crate::structures::metadata_index::MetadataIndex; -use crate::structures::mmap_tree::Tree; +use crate::structures::tree::Tree; use crate::structures::wal::WAL; use std::fs; use std::io; diff --git a/src/services/query.rs b/src/services/query.rs index 614d2d4..83e22cf 100644 --- a/src/services/query.rs +++ b/src/services/query.rs @@ -37,16 +37,18 @@ impl QueryService { // let mut metadata = metadata.clone(); // metadata.push(KVPair::new("id".to_string(), id.to_string())); - // let text = self - // .state - // .texts - // .search(*id) - // .unwrap() - // .expect("Text not found"); - - // let mut metadata = metadata.clone(); - - // metadata.push(KVPair::new("text".to_string(), decompress_string(&text))); + let text = self + .state + .texts + .get(id) + .unwrap() + .expect("Text not found") + .get(&self.state.texts.storage) + .expect("Text not found"); + + let mut metadata = metadata.clone(); + + metadata.push(KVPair::new("text".to_string(), decompress_string(&text))); metadata.clone() }) diff --git a/src/structures/ann_tree/node.rs b/src/structures/ann_tree/node.rs index 642f63b..6fd00d5 100644 --- a/src/structures/ann_tree/node.rs +++ b/src/structures/ann_tree/node.rs @@ -393,7 +393,15 @@ impl Node { } // serialize node_metadata offset - serialized.extend_from_slice(&self.node_metadata.as_ref().unwrap().offset.to_le_bytes()); + // serialized.extend_from_slice(&self.node_metadata.as_ref().unwrap_or(LazyValue::).offset.to_le_bytes()); + match self.node_metadata { + Some(ref node_metadata) => { + serialized.extend_from_slice(&node_metadata.offset.to_le_bytes()); + } + None => { + serialized.extend_from_slice(&(0 as i64).to_le_bytes()); + } + } // // Serialize metadata // serialize_length(&mut serialized, self.metadata.len() as u32); diff --git a/src/structures/ann_tree/serialization.rs b/src/structures/ann_tree/serialization.rs index 9ec9f29..efd85d6 100644 --- a/src/structures/ann_tree/serialization.rs +++ b/src/structures/ann_tree/serialization.rs @@ -35,6 +35,43 @@ impl TreeDeserialization for String { } } +impl TreeSerialization for u128 { + fn serialize(&self) -> Vec { + self.to_le_bytes().to_vec() + } +} + +impl TreeDeserialization for u128 { + fn deserialize(data: &[u8]) -> Self { + let mut bytes = [0; 16]; + bytes.copy_from_slice(&data[..16]); + u128::from_le_bytes(bytes) + } +} + +impl TreeSerialization for Vec { + fn serialize(&self) -> Vec { + let mut data = Vec::new(); + data.extend_from_slice(&(self.len() as u32).to_le_bytes()); // Write length + data.extend_from_slice(self); // Write string data + data + } +} + +impl TreeDeserialization for Vec { + fn deserialize(data: &[u8]) -> Self { + if data.len() < 4 { + panic!("Data too short to contain length prefix"); + } + let len = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize; // Read length + if data.len() < 4 + len { + panic!("Data too short for specified string length"); + } + let string_data = &data[4..4 + len]; // Extract string data + string_data.to_vec() + } +} + impl TreeSerialization for String { fn serialize(&self) -> Vec { let mut data = Vec::new(); diff --git a/src/structures/wal.rs b/src/structures/wal.rs index 8da38c8..775a36b 100644 --- a/src/structures/wal.rs +++ b/src/structures/wal.rs @@ -210,12 +210,12 @@ impl WAL { fn u64_to_i64(&self, value: u64) -> i64 { // Safely convert u64 to i64 by reinterpreting the bits - 
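
Either spelling of the conversion below is a bit-for-bit identity cast; the patch standardizes on the explicit little-endian form so the intent (storing u64 hashes losslessly in SQLite's signed INTEGER column) is obvious. A quick check of the round-trip:

// Lossless u64<->i64 bridge, matching the WAL's u64_to_i64/i64_to_u64.
fn u64_to_i64(v: u64) -> i64 { i64::from_le_bytes(v.to_le_bytes()) }
fn i64_to_u64(v: i64) -> u64 { u64::from_le_bytes(v.to_le_bytes()) }

fn main() {
    assert_eq!(i64_to_u64(u64_to_i64(u64::MAX - 1)), u64::MAX - 1);
    assert_eq!(u64_to_i64(u64::MAX), -1); // the top bit lands in the sign
}
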
i64::from_ne_bytes(value.to_ne_bytes()) + i64::from_le_bytes(value.to_le_bytes()) } fn i64_to_u64(&self, value: i64) -> u64 { // Safely convert i64 to u64 by reinterpreting the bits - u64::from_ne_bytes(value.to_ne_bytes()) + u64::from_le_bytes(value.to_le_bytes()) } pub fn add_to_commit_list( @@ -245,6 +245,42 @@ impl WAL { Ok(()) } + pub fn batch_add_to_commit_list( + &mut self, + hashes: Vec, + vectors: Vec>, + kvs: Vec>>, + ) -> Result<()> { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + // let mut stmt = self.conn.prepare( + // "INSERT INTO wal (hash, data, metadata, added_timestamp) VALUES (?1, ?2, ?3, ?4);", + // )?; + + let timestamp_i64 = self.u64_to_i64(timestamp); + + let tx = self.conn.transaction()?; + + for ((hash, vectors), kvs) in hashes.iter().zip(vectors.iter()).zip(kvs.iter()) { + let metadata = json!(kvs).to_string(); + let data: Vec = vectors.iter().flat_map(|v| v.to_vec()).collect(); + + let hash_i64 = i64::from_le_bytes(hash.to_le_bytes()); + + tx.execute( + "INSERT INTO wal (hash, data, metadata, added_timestamp) VALUES (?1, ?2, ?3, ?4);", + params![&hash_i64, &data, &metadata, ×tamp_i64], + )?; + } + + tx.commit()?; + + Ok(()) + } + pub fn has_been_committed(&mut self, hash: u64) -> Result { let mut stmt = self .conn @@ -352,10 +388,22 @@ impl WAL { .map(|v| v.iter().map(|v| quantize(v)).collect()) .collect(); - for (v, k) in quantized_vectors.iter().zip(kvs.iter()) { - let hash = self.compute_hash(v, k); - self.add_to_commit_list(hash, v.clone(), k.clone())?; - } + let mut i = 0; + + // for (v, k) in quantized_vectors.iter().zip(kvs.iter()) { + // i += 1; + // println!("Adding to WAL: {}/{}", i, vectors.len()); + // let hash = self.compute_hash(v, k); + // self.add_to_commit_list(hash, v.clone(), k.clone())?; + // } + + let hashes: Vec = quantized_vectors + .iter() + .zip(kvs.iter()) + .map(|(v, k)| self.compute_hash(v, k)) + .collect(); + + self.batch_add_to_commit_list(hashes, quantized_vectors, kvs)?; Ok(()) } From a9f8ddf1ed20136153e4ff485f1474c1c9fda6d1 Mon Sep 17 00:00:00 2001 From: Carson Poole Date: Fri, 24 May 2024 22:30:12 +0000 Subject: [PATCH 6/6] new version of storage manager --- src/constants.rs | 4 +- src/errors.rs | 0 src/lib.rs | 0 src/main.rs | 317 ++++++++++------------ src/main1.rs | 154 ----------- src/main3.rs | 173 ++++++++++++ src/math.rs | 0 src/math/gemm.rs | 0 src/math/gemv.rs | 0 src/math/hamming_distance.rs | 0 src/services.rs | 0 src/services/commit.rs | 18 +- src/services/lock_service.rs | 0 src/services/namespace_state.rs | 0 src/services/query.rs | 0 src/structures.rs | 1 + src/structures/ann_tree.rs | 28 +- src/structures/ann_tree/k_modes.rs | 0 src/structures/ann_tree/metadata.rs | 0 src/structures/ann_tree/node.rs | 144 +++------- src/structures/ann_tree/serialization.rs | 0 src/structures/ann_tree/storage.rs | 0 src/structures/block_storage.rs | 2 + src/structures/filters.rs | 4 +- src/structures/mmap_tree.rs | 0 src/structures/mmap_tree/node.rs | 0 src/structures/mmap_tree/serialization.rs | 0 src/structures/mmap_tree/storage.rs | 0 src/structures/storage_layer.rs | 170 ++++++++++++ src/structures/tree.rs | 28 +- src/structures/tree/node.rs | 5 +- src/structures/wal.rs | 0 src/utils.rs | 0 src/utils/quantization.rs | 0 34 files changed, 564 insertions(+), 484 deletions(-) mode change 100644 => 100755 src/constants.rs mode change 100644 => 100755 src/errors.rs mode change 100644 => 100755 src/lib.rs mode change 100644 => 100755 src/main.rs delete mode 100644 src/main1.rs create 
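
batch_add_to_commit_list above is the important fix in this hunk: wrapping the whole batch in one transaction means SQLite journals and syncs once per batch instead of once per row. The same pattern in a freestanding sketch (assumes the rusqlite crate; the table and columns are illustrative):

// One transaction around many inserts, as in batch_add_to_commit_list.
use rusqlite::{params, Connection, Result};

fn main() -> Result<()> {
    let mut conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE wal (hash INTEGER, data BLOB)", [])?;

    let tx = conn.transaction()?;
    for i in 0..1000i64 {
        tx.execute(
            "INSERT INTO wal (hash, data) VALUES (?1, ?2)",
            params![i, vec![0u8; 16]],
        )?;
    }
    tx.commit()?; // the whole batch becomes durable at once
    Ok(())
}
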
From a9f8ddf1ed20136153e4ff485f1474c1c9fda6d1 Mon Sep 17 00:00:00 2001
From: Carson Poole
Date: Fri, 24 May 2024 22:30:12 +0000
Subject: [PATCH 6/6] new version of storage manager

---
 src/constants.rs                          |   4 +-
 src/errors.rs                             |   0
 src/lib.rs                                |   0
 src/main.rs                               | 317 ++++++++++------------
 src/main1.rs                              | 154 -----------
 src/main3.rs                              | 173 ++++++++++++
 src/math.rs                               |   0
 src/math/gemm.rs                          |   0
 src/math/gemv.rs                          |   0
 src/math/hamming_distance.rs              |   0
 src/services.rs                           |   0
 src/services/commit.rs                    |  18 +-
 src/services/lock_service.rs              |   0
 src/services/namespace_state.rs           |   0
 src/services/query.rs                     |   0
 src/structures.rs                         |   1 +
 src/structures/ann_tree.rs                |  28 +-
 src/structures/ann_tree/k_modes.rs        |   0
 src/structures/ann_tree/metadata.rs       |   0
 src/structures/ann_tree/node.rs           | 144 +++-------
 src/structures/ann_tree/serialization.rs  |   0
 src/structures/ann_tree/storage.rs        |   0
 src/structures/block_storage.rs           |   2 +
 src/structures/filters.rs                 |   4 +-
 src/structures/mmap_tree.rs               |   0
 src/structures/mmap_tree/node.rs          |   0
 src/structures/mmap_tree/serialization.rs |   0
 src/structures/mmap_tree/storage.rs       |   0
 src/structures/storage_layer.rs           | 170 ++++++++++++
 src/structures/tree.rs                    |  28 +-
 src/structures/tree/node.rs               |   5 +-
 src/structures/wal.rs                     |   0
 src/utils.rs                              |   0
 src/utils/quantization.rs                 |   0
 34 files changed, 564 insertions(+), 484 deletions(-)
 mode change 100644 => 100755 src/constants.rs
 mode change 100644 => 100755 src/errors.rs
 mode change 100644 => 100755 src/lib.rs
 mode change 100644 => 100755 src/main.rs
 delete mode 100644 src/main1.rs
 create mode 100755 src/main3.rs
 mode change 100644 => 100755 src/math.rs
 mode change 100644 => 100755 src/math/gemm.rs
 mode change 100644 => 100755 src/math/gemv.rs
 mode change 100644 => 100755 src/math/hamming_distance.rs
 mode change 100644 => 100755 src/services.rs
 mode change 100644 => 100755 src/services/commit.rs
 mode change 100644 => 100755 src/services/lock_service.rs
 mode change 100644 => 100755 src/services/namespace_state.rs
 mode change 100644 => 100755 src/services/query.rs
 mode change 100644 => 100755 src/structures.rs
 mode change 100644 => 100755 src/structures/ann_tree.rs
 mode change 100644 => 100755 src/structures/ann_tree/k_modes.rs
 mode change 100644 => 100755 src/structures/ann_tree/metadata.rs
 mode change 100644 => 100755 src/structures/ann_tree/node.rs
 mode change 100644 => 100755 src/structures/ann_tree/serialization.rs
 mode change 100644 => 100755 src/structures/ann_tree/storage.rs
 mode change 100644 => 100755 src/structures/block_storage.rs
 mode change 100644 => 100755 src/structures/filters.rs
 mode change 100644 => 100755 src/structures/mmap_tree.rs
 mode change 100644 => 100755 src/structures/mmap_tree/node.rs
 mode change 100644 => 100755 src/structures/mmap_tree/serialization.rs
 mode change 100644 => 100755 src/structures/mmap_tree/storage.rs
 create mode 100644 src/structures/storage_layer.rs
 mode change 100644 => 100755 src/structures/tree.rs
 mode change 100644 => 100755 src/structures/tree/node.rs
 mode change 100644 => 100755 src/structures/wal.rs
 mode change 100644 => 100755 src/utils.rs
 mode change 100644 => 100755 src/utils/quantization.rs

diff --git a/src/constants.rs b/src/constants.rs
old mode 100644
new mode 100755
index 2e3d22c..7a55361
--- a/src/constants.rs
+++ b/src/constants.rs
@@ -1,9 +1,9 @@
 pub const VECTOR_SIZE: usize = 1024;
 pub const QUANTIZED_VECTOR_SIZE: usize = 128;
-pub const K: usize = 128;
+pub const K: usize = 512;
 pub const C: usize = 1;
 pub const ALPHA: usize = 64;
 pub const BETA: usize = 3;
 pub const GAMMA: usize = 1;
 pub const RHO: usize = 1;
-pub const B: usize = 128;
+pub const B: usize = 512;
diff --git a/src/errors.rs b/src/errors.rs
old mode 100644
new mode 100755
diff --git a/src/lib.rs b/src/lib.rs
old mode 100644
new mode 100755
diff --git a/src/main.rs b/src/main.rs
old mode 100644
new mode 100755
index 7485d77..bbe66be
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,173 +1,154 @@
-use env_logger::Builder;
+extern crate haystackdb;
 use haystackdb::constants::VECTOR_SIZE;
-use haystackdb::services::CommitService;
-use haystackdb::services::QueryService;
-use haystackdb::structures::filters::Filter as QueryFilter;
-use haystackdb::structures::filters::KVPair;
-use log::info;
-use log::LevelFilter;
-use std::io::Write;
-use std::sync::{Arc, Mutex};
-use std::{self, path::PathBuf};
-use tokio::time::{interval, Duration};
-
-use std::collections::HashMap;
-use tokio::sync::OnceCell;
-use warp::Filter;
-
-static ACTIVE_NAMESPACES: OnceCell<Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>> =
-    OnceCell::const_new();
-
-#[tokio::main]
-async fn main() {
-    let mut builder = Builder::new();
-    builder
-        .format(|buf, record| writeln!(buf, "{}: {}", record.level(), record.args()))
-        .filter(None, LevelFilter::Info)
-        .init();
-
-    let active_namespaces = ACTIVE_NAMESPACES
-        .get_or_init(|| async { Arc::new(Mutex::new(HashMap::new())) })
-        .await;
-
-    let search_route = warp::path!("query" / String)
-        .and(warp::post())
-        .and(warp::body::json())
-        .and(with_active_namespaces(active_namespaces.clone()))
-        .then(
-            |namespace_id: String, body: (Vec<f64>, QueryFilter, usize), active_namespaces| async move {
-                let base_path = PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone()));
-                ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone())
-                    .await;
-
-                let mut query_service = QueryService::new(base_path, namespace_id.clone()).unwrap();
-                let fvec = &body.0;
-                let metadata = &body.1;
-                let top_k = body.2;
-
-                let mut vec: [f32; VECTOR_SIZE] = [0.0; VECTOR_SIZE];
-                fvec.iter()
-                    .enumerate()
-                    .for_each(|(i, &val)| vec[i] = val as f32);
-
-                let start = std::time::Instant::now();
-
-                let search_result = query_service
-                    .query(&vec, metadata, top_k)
-                    .expect("Failed to query");
-
-                let duration = start.elapsed();
-
-                println!("Query took {:?} to complete", duration);
-                warp::reply::json(&search_result)
-            },
-        );
-
-    let add_vector_route =
-        warp::path!("addVector" / String)
-            .and(warp::post())
-            .and(warp::body::json())
-            .and(with_active_namespaces(active_namespaces.clone()))
-            .then(
-                |namespace_id: String,
-                 body: (Vec<f64>, Vec<KVPair>, String),
-                 active_namespaces| async move {
-                    let base_path = PathBuf::from(format!(
-                        "/workspace/data/{}/current",
-                        namespace_id.clone()
-                    ));
-
-                    ensure_namespace_initialized(
-                        &namespace_id,
-                        &active_namespaces,
-                        base_path.clone(),
-                    )
-                    .await;
-
-                    let mut commit_service =
-                        CommitService::new(base_path, namespace_id.clone()).unwrap();
-                    let fvec = &body.0;
-                    let metadata = &body.1;
-
-                    let mut vec: [f32; VECTOR_SIZE] = [0.0; VECTOR_SIZE];
-                    fvec.iter()
-                        .enumerate()
-                        .for_each(|(i, &val)| vec[i] = val as f32);
-
-                    // let id = uuid::Uuid::from_str(id_str).unwrap();
-                    commit_service.add_to_wal(vec![vec], vec![metadata.clone()]).expect("Failed to add to WAL");
-                    warp::reply::json(&"Success")
-                },
-            );
-
-    // add a PITR route
-    let pitr_route = warp::path!("pitr" / String / String)
-        .and(warp::get())
-        .and(with_active_namespaces(active_namespaces.clone()))
-        .then(
-            |namespace_id: String, timestamp: String, active_namespaces| async move {
-                println!("PITR for namespace: {}", namespace_id);
-                let base_path =
-                    PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone()));
-
-                ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone())
-                    .await;
-
-                let mut commit_service =
-                    CommitService::new(base_path, namespace_id.clone()).unwrap();
-
-                let timestamp = timestamp.parse::<u64>().unwrap();
-                commit_service
-                    .recover_point_in_time(timestamp)
-                    .expect("Failed to PITR");
-                warp::reply::json(&"Success")
-            },
-        );
-
-    let routes = search_route
-        .or(add_vector_route)
-        .or(pitr_route)
-        .with(warp::cors().allow_any_origin());
-    warp::serve(routes).run(([0, 0, 0, 0], 8080)).await;
-}
-
-fn with_active_namespaces(
-    active_namespaces: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
-) -> impl Filter<
-    Extract = (Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,),
-    Error = std::convert::Infallible,
-> + Clone {
-    warp::any().map(move || active_namespaces.clone())
+use haystackdb::services::commit::CommitService;
+use haystackdb::services::query::QueryService;
+use haystackdb::structures::filters::{Filter, KVPair, KVValue};
+use std::fs;
+use std::path::PathBuf;
+use std::str::FromStr;
+use uuid;
+
+fn random_vec() -> [f32; VECTOR_SIZE] {
+    let mut vec = [0.0; VECTOR_SIZE];
+    for i in 0..VECTOR_SIZE {
+        vec[i] = rand::random::<f32>() * 2.0 - 1.0;
+    }
+    vec
 }
 
-async fn ensure_namespace_initialized(
-    namespace_id: &String,
-    active_namespaces: &Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
-    base_path_for_async: PathBuf,
-) {
-    let mut namespaces = active_namespaces.lock().unwrap();
-    if !namespaces.contains_key(namespace_id) {
-        let namespace_id_cloned = namespace_id.clone();
-        let handle = tokio::spawn(async move {
-            let mut interval = interval(Duration::from_secs(10));
-            loop {
-                interval.tick().await;
-                println!("Committing for namespace {}", namespace_id_cloned);
-                let start = std::time::Instant::now();
-                let commit_worker = std::sync::Arc::new(std::sync::Mutex::new(
-                    CommitService::new(base_path_for_async.clone(), namespace_id_cloned.clone())
-                        .unwrap(),
-                ));
-
-                commit_worker
-                    .lock()
-                    .unwrap()
-                    .commit()
-                    .expect("Failed to commit");
-                let duration = start.elapsed();
-                info!("Commit worker took {:?} to complete", duration);
-            }
-        });
-        namespaces.insert(namespace_id.clone(), handle);
-    }
+fn main() {
+    let namespace_id = uuid::Uuid::new_v4().to_string();
+    let path = PathBuf::from_str("tests/data")
+        .expect("Failed to create path")
+        .join("namespaces")
+        .join(namespace_id.clone());
+    fs::create_dir_all(&path).expect("Failed to create directory");
+    let mut commit_service = CommitService::new(path.clone(), namespace_id.clone())
+        .expect("Failed to create commit service");
+
+    let start = std::time::Instant::now();
+    // for _ in 0..20000 {
+    //     commit_service
+    //         .add_to_wal(
+    //             vec![random_vec()],
+    //             vec![vec![KVPair {
+    //                 key: "key".to_string(),
+    //                 value: "value".to_string(),
+    //             }]],
+    //         )
+    //         .expect("Failed to add to WAL");
+    // }
+
+    const NUM_VECTORS: usize = 1_000;
+
+    let batch_vectors: Vec<Vec<[f32; VECTOR_SIZE]>> =
+        (0..NUM_VECTORS).map(|_| vec![random_vec()]).collect();
+    let batch_kvs: Vec<Vec<Vec<KVPair>>> = (0..NUM_VECTORS)
+        .map(|_| {
+            vec![vec![KVPair {
+                key: "key".to_string(),
+                value: KVValue::String("value".to_string()),
+            }]]
+        })
+        .collect();
+
+    println!("Batch creation took: {:?}", start.elapsed());
+    commit_service
+        .batch_add_to_wal(batch_vectors, batch_kvs)
+        .expect("Failed to add to WAL");
+
+    println!("Add to WAL took: {:?}", start.elapsed());
+
+    // commit_service
+    //     .add_to_wal(
+    //         vec![[0.0; VECTOR_SIZE]],
+    //         vec![vec![KVPair {
+    //             key: "key".to_string(),
+    //             value: "value".to_string(),
+    //         }]],
+    //     )
+    //     .expect("Failed to add to WAL");
+
+    let start = std::time::Instant::now();
+
+    commit_service.commit().expect("Failed to commit");
+
+    println!("Commit took: {:?}", start.elapsed());
+
+    // commit_service.calibrate();
+
+    // commit_service.state.vectors.summarize_tree();
+
+    let mut query_service =
+        QueryService::new(path.clone(), namespace_id).expect("Failed to create query service");
+
+    let _start = std::time::Instant::now();
+
+    const NUM_RUNS: usize = 100;
+
+    let start = std::time::Instant::now();
+
+    for _ in 0..NUM_RUNS {
+        let result = query_service
+            .query(
+                &random_vec(),
+                &Filter::Eq("key".to_string(), "value".to_string()),
+                1,
+            )
+            .expect("Failed to query");
+
+        // println!("{:?}", result);
+        if result.len() == 0 {
+            println!("No results found");
+        }
+    }
+
+    println!("Query took: {:?}", start.elapsed().div_f32(NUM_RUNS as f32));
+
+    // let result = query_service
+    //     .query(
+    //         &[0.0; VECTOR_SIZE],
+    //         vec![KVPair {
+    //             key: "key".to_string(),
+    //             value: "value".to_string(),
+    //         }],
+    //         1,
+    //     )
+    //     .expect("Failed to query");
+
+    // println!("{:?}", result);
+
+    // println!("Query took: {:?}", start.elapsed());
 }
+
+// fn main() {
+//     let mut storage_manager: StorageManager = StorageManager::new(
+//         PathBuf::from_str("tests/data/test.db").expect("Failed to create path"),
+//     )
+//     .expect("Failed to create storage manager");
+
+//     let mut node: Node = Node::new_leaf(0);
+
+//     for i in 0..2048 {
+//         node.set_key_value(i, uuid::Uuid::new_v4().to_string());
+//     }
+
+//     let serialized = Node::serialize(&node);
+//     let deserialized = Node::deserialize(&serialized);
+
+//     assert_eq!(node, deserialized);
+
+//     let offset = storage_manager
+//         .store_node(&mut node)
+//         .expect("Failed to store node");
+
+//     node.offset = offset;
+
+//     let mut loaded_node = storage_manager
+//         .load_node(offset)
+//         .expect("Failed to load node");
+
+//     loaded_node.offset = offset;
+
+//     assert_eq!(loaded_node, node);
+// }
diff --git a/src/main1.rs b/src/main1.rs
deleted file mode 100644
index f0ce77e..0000000
--- a/src/main1.rs
+++ /dev/null
@@ -1,154 +0,0 @@
-extern crate haystackdb;
-use haystackdb::constants::VECTOR_SIZE;
-use haystackdb::services::commit::CommitService;
-use haystackdb::services::query::QueryService;
-use haystackdb::structures::filters::{Filter, KVPair, KVValue};
-use std::fs;
-use std::path::PathBuf;
-use std::str::FromStr;
-use uuid;
-
-fn random_vec() -> [f32; VECTOR_SIZE] {
-    let mut vec = [0.0; VECTOR_SIZE];
-    for i in 0..VECTOR_SIZE {
-        vec[i] = rand::random::<f32>() * 2.0 - 1.0;
-    }
-    vec
-}
-
-fn main() {
-    let namespace_id = uuid::Uuid::new_v4().to_string();
-    let path = PathBuf::from_str("tests/data")
-        .expect("Failed to create path")
-        .join("namespaces")
-        .join(namespace_id.clone());
-    fs::create_dir_all(&path).expect("Failed to create directory");
-    let mut commit_service = CommitService::new(path.clone(), namespace_id.clone())
-        .expect("Failed to create commit service");
-
-    let start = std::time::Instant::now();
-    // for _ in 0..20000 {
-    //     commit_service
-    //         .add_to_wal(
-    //             vec![random_vec()],
-    //             vec![vec![KVPair {
-    //                 key: "key".to_string(),
-    //                 value: "value".to_string(),
-    //             }]],
-    //         )
-    //         .expect("Failed to add to WAL");
-    // }
-
-    const NUM_VECTORS: usize = 10_000;
-
-    let batch_vectors: Vec<Vec<[f32; VECTOR_SIZE]>> =
-        (0..NUM_VECTORS).map(|_| vec![random_vec()]).collect();
-    let batch_kvs: Vec<Vec<Vec<KVPair>>> = (0..NUM_VECTORS)
-        .map(|_| {
-            vec![vec![KVPair {
-                key: "key".to_string(),
-                value: KVValue::String("value".to_string()),
-            }]]
-        })
-        .collect();
-
-    println!("Batch creation took: {:?}", start.elapsed());
-    commit_service
-        .batch_add_to_wal(batch_vectors, batch_kvs)
-        .expect("Failed to add to WAL");
-
-    println!("Add to WAL took: {:?}", start.elapsed());
-
-    // commit_service
-    //     .add_to_wal(
-    //         vec![[0.0; VECTOR_SIZE]],
-    //         vec![vec![KVPair {
-    //             key: "key".to_string(),
-    //             value: "value".to_string(),
-    //         }]],
-    //     )
-    //     .expect("Failed to add to WAL");
-
-    let start = std::time::Instant::now();
-
-    commit_service.commit().expect("Failed to commit");
-
-    println!("Commit took: {:?}", start.elapsed());
-
-    commit_service.calibrate();
-
-    commit_service.state.vectors.summarize_tree();
-
-    let mut query_service =
-        QueryService::new(path.clone(), namespace_id).expect("Failed to create query service");
-
-    let _start = std::time::Instant::now();
-
-    const NUM_RUNS: usize = 100;
-
-    let start = std::time::Instant::now();
-
-    for _ in 0..NUM_RUNS {
-        let result = query_service
-            .query(
-                &random_vec(),
-                &Filter::Eq("key".to_string(), "value".to_string()),
-                1,
-            )
-            .expect("Failed to query");
-
-        // println!("{:?}", result);
-        if result.len() == 0 {
-            println!("No results found");
-        }
-    }
-
-    println!("Query took: {:?}", start.elapsed().div_f32(NUM_RUNS as f32));
-
-    // let result = query_service
-    //     .query(
-    //         &[0.0; VECTOR_SIZE],
-    //         vec![KVPair {
-    //             key: "key".to_string(),
-    //             value: "value".to_string(),
-    //         }],
-    //         1,
-    //     )
-    //     .expect("Failed to query");
-
-    // println!("{:?}", result);
-
-    // println!("Query took: {:?}", start.elapsed());
-}
-
-// fn main() {
-//     let mut storage_manager: StorageManager = StorageManager::new(
-//         PathBuf::from_str("tests/data/test.db").expect("Failed to create path"),
-//     )
-//     .expect("Failed to create storage manager");
-
-//     let mut node: Node = Node::new_leaf(0);
-
-//     for i in 0..2048 {
-//         node.set_key_value(i, uuid::Uuid::new_v4().to_string());
-//     }
-
-//     let serialized = Node::serialize(&node);
-//     let deserialized = Node::deserialize(&serialized);
-
-//     assert_eq!(node, deserialized);
-
-//     let offset = storage_manager
-//         .store_node(&mut node)
-//         .expect("Failed to store node");
-
-//     node.offset = offset;
-
-//     let mut loaded_node = storage_manager
-//         .load_node(offset)
-//         .expect("Failed to load node");
-
-//     loaded_node.offset = offset;
-
-//     assert_eq!(loaded_node, node);
-// }
diff --git a/src/main3.rs b/src/main3.rs
new file mode 100755
index 0000000..7485d77
--- /dev/null
+++ b/src/main3.rs
@@ -0,0 +1,173 @@
+use env_logger::Builder;
+use haystackdb::constants::VECTOR_SIZE;
+use haystackdb::services::CommitService;
+use haystackdb::services::QueryService;
+use haystackdb::structures::filters::Filter as QueryFilter;
+use haystackdb::structures::filters::KVPair;
+use log::info;
+use log::LevelFilter;
+use std::io::Write;
+use std::sync::{Arc, Mutex};
+use std::{self, path::PathBuf};
+use tokio::time::{interval, Duration};
+
+use std::collections::HashMap;
+use tokio::sync::OnceCell;
+use warp::Filter;
+
+static ACTIVE_NAMESPACES: OnceCell<Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>> =
+    OnceCell::const_new();
+
+#[tokio::main]
+async fn main() {
+    let mut builder = Builder::new();
+    builder
+        .format(|buf, record| writeln!(buf, "{}: {}", record.level(), record.args()))
+        .filter(None, LevelFilter::Info)
+        .init();
+
+    let active_namespaces = ACTIVE_NAMESPACES
+        .get_or_init(|| async { Arc::new(Mutex::new(HashMap::new())) })
+        .await;
+
+    let search_route = warp::path!("query" / String)
+        .and(warp::post())
+        .and(warp::body::json())
+        .and(with_active_namespaces(active_namespaces.clone()))
+        .then(
+            |namespace_id: String, body: (Vec<f64>, QueryFilter, usize), active_namespaces| async move {
+                let base_path = PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone()));
+                ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone())
+                    .await;
+
+                let mut query_service = QueryService::new(base_path, namespace_id.clone()).unwrap();
+                let fvec = &body.0;
+                let metadata = &body.1;
+                let top_k = body.2;
+
+                let mut vec: [f32; VECTOR_SIZE] = [0.0; VECTOR_SIZE];
+                fvec.iter()
+                    .enumerate()
+                    .for_each(|(i, &val)| vec[i] = val as f32);
+
+                let start = std::time::Instant::now();
+
+                let search_result = query_service
+                    .query(&vec, metadata, top_k)
+                    .expect("Failed to query");
+
+                let duration = start.elapsed();
+
+                println!("Query took {:?} to complete", duration);
+                warp::reply::json(&search_result)
+            },
+        );
+
+    let add_vector_route =
+        warp::path!("addVector" / String)
+            .and(warp::post())
+            .and(warp::body::json())
+            .and(with_active_namespaces(active_namespaces.clone()))
+            .then(
+                |namespace_id: String,
+                 body: (Vec<f64>, Vec<KVPair>, String),
+                 active_namespaces| async move {
+                    let base_path = PathBuf::from(format!(
+                        "/workspace/data/{}/current",
+                        namespace_id.clone()
+                    ));
+
+                    ensure_namespace_initialized(
+                        &namespace_id,
+                        &active_namespaces,
+                        base_path.clone(),
+                    )
+                    .await;
+
+                    let mut commit_service =
+                        CommitService::new(base_path, namespace_id.clone()).unwrap();
+                    let fvec = &body.0;
+                    let metadata = &body.1;
+
+                    let mut vec: [f32; VECTOR_SIZE] = [0.0; VECTOR_SIZE];
+                    fvec.iter()
+                        .enumerate()
+                        .for_each(|(i, &val)| vec[i] = val as f32);
+
+                    // let id = uuid::Uuid::from_str(id_str).unwrap();
+                    commit_service.add_to_wal(vec![vec], vec![metadata.clone()]).expect("Failed to add to WAL");
+                    warp::reply::json(&"Success")
+                },
+            );
+
+    // add a PITR route
+    let pitr_route = warp::path!("pitr" / String / String)
+        .and(warp::get())
+        .and(with_active_namespaces(active_namespaces.clone()))
+        .then(
+            |namespace_id: String, timestamp: String, active_namespaces| async move {
+                println!("PITR for namespace: {}", namespace_id);
+                let base_path =
+                    PathBuf::from(format!("/workspace/data/{}/current", namespace_id.clone()));
+
+                ensure_namespace_initialized(&namespace_id, &active_namespaces, base_path.clone())
+                    .await;
+
+                let mut commit_service =
+                    CommitService::new(base_path, namespace_id.clone()).unwrap();
+
+                let timestamp = timestamp.parse::<u64>().unwrap();
+                commit_service
+                    .recover_point_in_time(timestamp)
+                    .expect("Failed to PITR");
+                warp::reply::json(&"Success")
+            },
+        );
+
+    let routes = search_route
+        .or(add_vector_route)
+        .or(pitr_route)
+        .with(warp::cors().allow_any_origin());
+    warp::serve(routes).run(([0, 0, 0, 0], 8080)).await;
+}
+
+fn with_active_namespaces(
+    active_namespaces: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
+) -> impl Filter<
+    Extract = (Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,),
+    Error = std::convert::Infallible,
+> + Clone {
+    warp::any().map(move || active_namespaces.clone())
+}
+
+async fn ensure_namespace_initialized(
+    namespace_id: &String,
+    active_namespaces: &Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
+    base_path_for_async: PathBuf,
+) {
+    let mut namespaces = active_namespaces.lock().unwrap();
+    if !namespaces.contains_key(namespace_id) {
+        let namespace_id_cloned = namespace_id.clone();
+        let handle = tokio::spawn(async move {
+            let mut interval = interval(Duration::from_secs(10));
+            loop {
+                interval.tick().await;
+                println!("Committing for namespace {}", namespace_id_cloned);
+                let start = std::time::Instant::now();
+                let commit_worker = std::sync::Arc::new(std::sync::Mutex::new(
+                    CommitService::new(base_path_for_async.clone(), namespace_id_cloned.clone())
+                        .unwrap(),
+                ));
+
+                commit_worker
+                    .lock()
+                    .unwrap()
+                    .commit()
+                    .expect("Failed to commit");
+                let duration = start.elapsed();
+                info!("Commit worker took {:?} to complete", duration);
+            }
+        });
+        namespaces.insert(namespace_id.clone(), handle);
+    }
+}
diff --git a/src/math.rs b/src/math.rs
old mode 100644
new mode 100755
diff --git a/src/math/gemm.rs b/src/math/gemm.rs
old mode 100644
new mode 100755
diff --git a/src/math/gemv.rs b/src/math/gemv.rs
old mode 100644
new mode 100755
diff --git a/src/math/hamming_distance.rs b/src/math/hamming_distance.rs
old mode 100644
new mode 100755
diff --git a/src/services.rs b/src/services.rs
old mode 100644
new mode 100755
diff --git a/src/services/commit.rs b/src/services/commit.rs
old mode 100644
new mode 100755
index b463a94..ebb3129
--- a/src/services/commit.rs
+++ b/src/services/commit.rs
@@ -97,15 +97,15 @@ impl CommitService {
             self.state.vectors.insert(*vector, id, kv);
 
             // self.state.texts.insert(id, texts.clone());
-            // match texts {
-            //     KVValue::String(text) => {
-            //         self.state
-            //             .texts
-            //             .insert(id, compress_string(&text))
-            //             .expect("Failed to insert text");
-            //     }
-            //     _ => {}
-            // }
+            match texts {
+                KVValue::String(text) => {
+                    self.state
+                        .texts
+                        .insert(id, compress_string(&text))
+                        .expect("Failed to insert text");
+                }
+                _ => {}
+            }
         }
 
         // self.state.wal.mark_commit_finished(commit.hash)?;
diff --git a/src/services/lock_service.rs b/src/services/lock_service.rs
old mode 100644
new mode 100755
diff --git a/src/services/namespace_state.rs b/src/services/namespace_state.rs
old mode 100644
new mode 100755
diff --git a/src/services/query.rs b/src/services/query.rs
old mode 100644
new mode 100755
diff --git a/src/structures.rs b/src/structures.rs
old mode 100644
new mode 100755
index e0c9b1d..bafc44d
--- a/src/structures.rs
+++ b/src/structures.rs
@@ -5,5 +5,6 @@ pub mod filters;
 // pub mod inverted_index;
 // pub mod metadata_index;
 pub mod mmap_tree;
+pub mod storage_layer;
 pub mod tree;
 pub mod wal;
diff --git a/src/structures/ann_tree.rs b/src/structures/ann_tree.rs
old mode 100644
new mode 100755
index 9d93b96..40a83ec
--- a/src/structures/ann_tree.rs
+++ b/src/structures/ann_tree.rs
@@ -21,6 +21,7 @@ use super::block_storage::BlockStorage;
 use super::filters::{combine_filters, Filter, Filters};
 // use super::metadata_index::{KVPair, KVValue};
 use super::mmap_tree::serialization::{TreeDeserialization, TreeSerialization};
+use super::storage_layer::StorageLayer;
 use crate::structures::filters::{calc_metadata_index_for_metadata, KVPair, KVValue};
 
 use rayon::prelude::*;
@@ -31,32 +32,13 @@ use std::path::PathBuf;
 
 pub struct ANNTree {
     pub k: usize,
-    pub storage_manager: BlockStorage,
-}
-
-#[derive(Eq, PartialEq)]
-struct PathNode {
-    distance: u16,
-    offset: usize,
-}
-
-// Implement `Ord` and `PartialOrd` for `PathNode` to use it in a min-heap
-impl Ord for PathNode {
-    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        other.distance.cmp(&self.distance) // Reverse order for min-heap
-    }
-}
-
-impl PartialOrd for PathNode {
-    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-        Some(self.cmp(other))
-    }
+    pub storage_manager: StorageLayer,
 }
 
 impl ANNTree {
     pub fn new(path: PathBuf) -> Result<Self> {
         let mut storage_manager =
-            BlockStorage::new(path).expect("Failed to make storage manager in ANN Tree");
+            StorageLayer::new(path).expect("Failed to make storage manager in ANN Tree");
 
         // println!("INIT Used space: {}", storage_manager.used_space);
 
@@ -396,7 +378,7 @@ impl ANNTree {
                 .clone()
                 .expect("")
                 .get(&self.storage_manager)
-                .expect("Failed to get node metadata")
+                .unwrap_or(NodeMetadataIndex::new())
                 .get(kv.key.clone())
             {
                 Some(res) => {
@@ -459,7 +441,7 @@ impl ANNTree {
                 .node_metadata
                 .expect("")
                 .get(&self.storage_manager)
-                .expect("Failed to get node metadata")
+                .unwrap_or(NodeMetadataIndex::new())
                 .clone();
 
             current_node_metadata.insert(kv.key.clone(), set);
diff --git a/src/structures/ann_tree/k_modes.rs b/src/structures/ann_tree/k_modes.rs
old mode 100644
new mode 100755
diff --git a/src/structures/ann_tree/metadata.rs b/src/structures/ann_tree/metadata.rs
old mode 100644
new mode 100755
diff --git a/src/structures/ann_tree/node.rs b/src/structures/ann_tree/node.rs
old mode 100644
new mode 100755
index 6fd00d5..c8821ab
--- a/src/structures/ann_tree/node.rs
+++ b/src/structures/ann_tree/node.rs
@@ -2,7 +2,7 @@ use serde::Serialize;
 use std::io;
 
 use crate::structures::ann_tree::k_modes::{balanced_k_modes, balanced_k_modes_4};
-use crate::structures::block_storage::BlockStorage;
+use crate::structures::storage_layer::StorageLayer;
 // use crate::structures::metadata_index::{KVPair, KVValue};
 use crate::structures::ann_tree::serialization::{TreeDeserialization, TreeSerialization};
 use crate::structures::filters::{calc_metadata_index_for_metadata, KVPair, KVValue};
@@ -48,7 +48,7 @@ impl<T> LazyValue<T>
 where
     T: Clone + TreeDeserialization + TreeSerialization,
 {
-    pub fn get(&mut self, storage: &BlockStorage) -> Result<T, io::Error> {
+    pub fn get(&mut self, storage: &StorageLayer) -> Result<T, io::Error> {
         match self.value.clone() {
             Some(value) => Ok(value),
             None => {
@@ -60,7 +60,7 @@ where
         }
     }
 
-    pub fn new(value: T, storage: &mut BlockStorage) -> Result<Self, io::Error> {
+    pub fn new(value: T, storage: &mut StorageLayer) -> Result<Self, io::Error> {
         let offset = storage.store(value.serialize(), 0)?;
         Ok(LazyValue {
             offset,
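Worth noting in the node.rs hunks: LazyValue<T> keeps only a storage offset until the value is actually needed, then deserializes it once on first access. A reduced, self-contained sketch of that lazy-load pattern (Store, FromBytes, and the stub load are stand-ins for StorageLayer and the tree serialization traits, not the real haystackdb API):

    use std::io;

    trait FromBytes: Sized {
        fn from_bytes(bytes: &[u8]) -> Self;
    }

    struct Store; // placeholder backend
    impl Store {
        fn load(&self, _offset: usize) -> io::Result<Vec<u8>> {
            Ok(vec![]) // a real backend would read the block at `offset`
        }
    }

    struct Lazy<T> {
        offset: usize,    // where the serialized value lives
        value: Option<T>, // cached once deserialized
    }

    impl<T: Clone + FromBytes> Lazy<T> {
        fn get(&mut self, store: &Store) -> io::Result<T> {
            if self.value.is_none() {
                // First access: fetch and decode, then keep the decoded copy
                // so later gets skip the storage round trip.
                let bytes = store.load(self.offset)?;
                self.value = Some(T::from_bytes(&bytes));
            }
            Ok(self.value.clone().unwrap())
        }
    }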
@@ -257,7 +257,7 @@ impl Node {
         }
     }
 
-    pub fn split(&mut self, storage: &mut BlockStorage) -> Result<Vec<Node>, io::Error> {
+    pub fn split(&mut self, storage: &mut StorageLayer) -> Result<Vec<Node>, io::Error> {
         let k = match self.node_type {
             NodeType::Leaf => 2,
             NodeType::Internal => 2,
@@ -473,50 +473,44 @@ impl Node {
         // Deserialize vectors
         let vectors_len = read_length(&data[offset..offset + 4]);
         offset += 4;
-        let mut vectors = Vec::with_capacity(vectors_len);
-        for _ in 0..vectors_len {
-            vectors.push(
-                data[offset..offset + QUANTIZED_VECTOR_SIZE]
-                    .try_into()
-                    .unwrap(),
-            );
-            offset += QUANTIZED_VECTOR_SIZE;
-        }
+        let vectors = data[offset..offset + vectors_len * QUANTIZED_VECTOR_SIZE]
+            .chunks_exact(QUANTIZED_VECTOR_SIZE)
+            .map(|chunk| chunk.try_into().unwrap())
+            .collect::<Vec<_>>();
+        offset += vectors_len * QUANTIZED_VECTOR_SIZE;
 
         // Deserialize ids
         let ids_len = read_length(&data[offset..offset + 4]);
         offset += 4;
-        let mut ids = Vec::with_capacity(ids_len);
-        for _ in 0..ids_len {
-            let id = u128::from_le_bytes(data[offset..offset + 16].try_into().unwrap());
-            offset += 16;
-            ids.push(id);
-        }
+        let ids = data[offset..offset + ids_len * 16]
+            .chunks_exact(16)
+            .map(|chunk| u128::from_le_bytes(chunk.try_into().unwrap()))
+            .collect::<Vec<_>>();
+        offset += ids_len * 16;
 
         // Deserialize children
         let children_len = read_length(&data[offset..offset + 4]);
         offset += 4;
-        let mut children = Vec::with_capacity(children_len);
-        for _ in 0..children_len {
-            let child = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize;
-            offset += 8;
-            children.push(child);
-        }
+        let children = data[offset..offset + children_len * 8]
+            .chunks_exact(8)
+            .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()) as usize)
+            .collect::<Vec<_>>();
+        offset += children_len * 8;
 
-        // deserialize metadata
+        // Deserialize metadata
         let metadata_len = read_length(&data[offset..offset + 4]);
         offset += 4;
-
-        let mut metadata = Vec::with_capacity(metadata_len);
-        for _ in 0..metadata_len {
-            let meta_offset =
-                u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize;
-            offset += 8;
-            metadata.push(LazyValue {
-                offset: meta_offset,
-                value: None,
-            });
-        }
+        let metadata = data[offset..offset + metadata_len * 8]
+            .chunks_exact(8)
+            .map(|chunk| {
+                let meta_offset = u64::from_le_bytes(chunk.try_into().unwrap()) as usize;
+                LazyValue {
+                    offset: meta_offset,
+                    value: None,
+                }
+            })
+            .collect::<Vec<_>>();
+        offset += metadata_len * 8;
 
         // Deserialize node_metadata
         let node_metadata_offset =
@@ -528,82 +522,6 @@ impl Node {
             u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize;
         offset += 8;
 
         let node_metadata = LazyValue {
             offset: node_metadata_offset,
             value: None,
         };
 
-        // // Deserialize metadata
-        // let metadata_len = read_length(&data[offset..offset + 4]);
-        // offset += 4;
-        // let mut metadata = Vec::with_capacity(metadata_len);
-        // for _ in 0..metadata_len {
-        //     let (meta, meta_size) = deserialize_metadata(&data[offset..]);
-        //     metadata.push(meta);
-        //     offset += meta_size; // Increment offset based on actual size of deserialized metadata
-        // }
 
-        // // Deserialize node_metadata
-        // let mut node_metadata = NodeMetadataIndex::new();
-        // let node_metadata_len = read_length(&data[offset..offset + 4]);
-        // offset += 4;
 
-        // for _ in 0..node_metadata_len {
-        //     let key_len = read_length(&data[offset..offset + 4]);
-        //     offset += 4;
 
-        //     let key = String::from_utf8(data[offset..offset + key_len as usize].to_vec()).unwrap();
-        //     offset += key_len as usize;
 
-        //     let mut values = HashSet::new();
-        //     let values_len = read_length(&data[offset..offset + 4]);
-        //     offset += 4;
 
-        //     for idx in 0..values_len {
-        //         let value_len = read_length(&data[offset..offset + 4]);
-        //         offset += 4;
 
-        //         if value_len > data.len() - offset {
-        //             println!("Current IDX: {}", idx);
-        //             println!("Value length: {}", value_len);
-        //             println!("Value len binary: {:?}", (value_len as u32).to_le_bytes());
-        //             println!("Data length: {}", data.len());
-        //             // add some more debug prints for the current state of things to figure out where it's going wrong
 
-        //             println!("Offset: {}", offset);
-        //             println!("Key: {}", key);
-        //             println!("Values: {:?}", values);
-        //             println!("Values len: {}", values_len);
-        //             println!("Node metadata: {:?}", node_metadata);
-        //             println!("Node metadata len: {}", node_metadata_len);
 
-        //             panic!("Value length exceeds data length");
-        //         }
 
-        //         let value =
-        //             String::from_utf8(data[offset..offset + value_len as usize].to_vec()).unwrap();
-        //         offset += value_len as usize;
 
-        //         values.insert(value);
-        //     }
 
-        //     let mut item = NodeMetadata {
-        //         values: values.clone(),
-        //         int_range: None,
-        //         float_range: None,
-        //     };
 
-        //     let min_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
-        //     offset += 8;
-        //     let max_int = i64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
-        //     offset += 8;
 
-        //     let min_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
-        //     offset += 4;
-        //     let max_float = f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
-        //     offset += 4;
 
-        //     item.int_range = Some((min_int, max_int));
-        //     item.float_range = Some((min_float, max_float));
 
-        //     node_metadata.insert(key, item);
-        // }
-
         Node {
             vectors,
             ids,
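The deserialization rewrite above swaps index-advancing loops for chunks_exact over the fixed-width sections of the buffer. The same pattern in miniature (self-contained demo, not haystackdb code):

    // Fixed-width records are sliced and mapped in one pass instead of a loop
    // that pushes into a Vec while manually advancing an offset.
    fn decode_u64s(data: &[u8], count: usize) -> Vec<u64> {
        data[..count * 8]
            .chunks_exact(8) // yields exactly `count` 8-byte chunks
            .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
            .collect()
    }

    fn main() {
        let bytes: Vec<u8> = [1u64, 2, 3].iter().flat_map(|v| v.to_le_bytes()).collect();
        assert_eq!(decode_u64s(&bytes, 3), vec![1, 2, 3]);
    }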
diff --git a/src/structures/ann_tree/serialization.rs b/src/structures/ann_tree/serialization.rs
old mode 100644
new mode 100755
diff --git a/src/structures/ann_tree/storage.rs b/src/structures/ann_tree/storage.rs
old mode 100644
new mode 100755
diff --git a/src/structures/block_storage.rs b/src/structures/block_storage.rs
old mode 100644
new mode 100755
index f9ec87c..91977c5
--- a/src/structures/block_storage.rs
+++ b/src/structures/block_storage.rs
@@ -350,10 +350,12 @@ impl BlockStorage {
     }
 
     pub fn acquire_lock(&self, index: usize) -> io::Result<()> {
+        return Ok(());
         self.locks.acquire(index.to_string())
     }
 
     pub fn release_lock(&self, index: usize) -> io::Result<()> {
+        return Ok(());
         self.locks.release(index.to_string())
     }
 }
diff --git a/src/structures/filters.rs b/src/structures/filters.rs
old mode 100644
new mode 100755
index f760e22..bdc0381
--- a/src/structures/filters.rs
+++ b/src/structures/filters.rs
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
 
 use super::ann_tree::metadata::{NodeMetadata, NodeMetadataIndex};
 use super::ann_tree::node::LazyValue;
-use super::block_storage::BlockStorage;
+use super::storage_layer::StorageLayer;
 use crate::structures::mmap_tree::serialization::{TreeDeserialization, TreeSerialization};
 use std::fmt::Display;
 use std::hash::{Hash, Hasher};
@@ -480,7 +480,7 @@ pub fn combine_filters(filters: Vec<Filter>) -> NodeMetadataIndex {
 
 pub fn calc_metadata_index_for_metadata(
     kvs: Vec<Vec<Vec<KVPair>>>,
-    storage: &BlockStorage,
+    storage: &StorageLayer,
 ) -> NodeMetadataIndex {
     let node_metadata: NodeMetadataIndex = kvs
         .into_iter()
diff --git a/src/structures/mmap_tree.rs b/src/structures/mmap_tree.rs
old mode 100644
new mode 100755
diff --git a/src/structures/mmap_tree/node.rs b/src/structures/mmap_tree/node.rs
old mode 100644
new mode 100755
diff --git a/src/structures/mmap_tree/serialization.rs b/src/structures/mmap_tree/serialization.rs
old mode 100644
new mode 100755
diff --git a/src/structures/mmap_tree/storage.rs b/src/structures/mmap_tree/storage.rs
old mode 100644
new mode 100755
diff --git a/src/structures/storage_layer.rs b/src/structures/storage_layer.rs
new file mode 100644
index 0000000..1b9b978
--- /dev/null
+++ b/src/structures/storage_layer.rs
@@ -0,0 +1,170 @@
+use crate::services::LockService;
+use std::fs;
+use std::fs::OpenOptions;
+use std::io;
+use std::io::{Read, Write};
+use std::path::PathBuf;
+
+pub struct StorageLayer {
+    path: PathBuf,
+    root_offset_path: PathBuf,
+    used_blocks_path: PathBuf,
+}
+
+pub const SIZE_OF_U64: usize = std::mem::size_of::<u64>();
+pub const HEADER_SIZE: usize = SIZE_OF_U64 * 2; // Used space + root offset
+
+pub const BLOCK_SIZE: usize = 4096; // typical page size
+pub const BLOCK_HEADER_SIZE: usize = SIZE_OF_U64 * 5 + 1; // Index in chain + Primary index + Next block offset + Previous block offset + Serialized node length + Is primary
+pub const BLOCK_DATA_SIZE: usize = BLOCK_SIZE - BLOCK_HEADER_SIZE;
+
+impl StorageLayer {
+    pub fn new(path: PathBuf) -> io::Result<Self> {
+        let root_offset_path = path.join("root_offset.bin");
+        let used_blocks_path = path.join("used_blocks.bin");
+
+        fs::create_dir_all(&path).expect("Failed to create directory");
+
+        let mut root_offset_file = OpenOptions::new()
+            .write(true)
+            .create(true)
+            .open(root_offset_path.clone())
+            .expect("Failed to create root offset file");
+        root_offset_file
+            .write_all(&[0u8; SIZE_OF_U64])
+            .expect("Failed to write to root offset file");
+        root_offset_file
+            .sync_all()
+            .expect("Failed to sync root offset file");
+
+        let mut used_blocks_file = OpenOptions::new()
+            .write(true)
+            .create(true)
+            .open(used_blocks_path.clone())
+            .expect("Failed to create used blocks file");
+
+        used_blocks_file
+            .write_all(&[0u8; SIZE_OF_U64])
+            .expect("Failed to write to used blocks file");
+        used_blocks_file
+            .sync_all()
+            .expect("Failed to sync used blocks file");
+
+        Ok(StorageLayer {
+            path,
+            root_offset_path,
+            used_blocks_path,
+        })
+    }
+
+    pub fn used_blocks(&self) -> usize {
+        // (u64::from_le_bytes(self.mmap[0..SIZE_OF_U64].try_into().unwrap()) as usize) + 1
+
+        let mut file = OpenOptions::new()
+            .read(true)
+            .open(self.used_blocks_path.clone())
+            .expect("Failed to open used blocks file");
+        let mut bytes = Vec::new();
+        file.read_to_end(&mut bytes)
+            .expect("Failed to read used blocks file");
+        u64::from_le_bytes(bytes.try_into().unwrap()) as usize + 1usize
+    }
+
+    pub fn set_used_blocks(&mut self, used_blocks: usize) {
+        // self.mmap[0..SIZE_OF_U64].copy_from_slice(&(used_blocks as u64).to_le_bytes());
+
+        let mut file = OpenOptions::new()
+            .write(true)
+            .open(self.used_blocks_path.clone())
+            .expect("Failed to open used blocks file");
+        file.write_all(&(used_blocks as u64).to_le_bytes());
+        // file.sync_all().expect("Failed to sync used blocks file");
+    }
+
+    pub fn root_offset(&self) -> usize {
+        // u64::from_le_bytes(
+        //     self.mmap[SIZE_OF_U64..(2 * SIZE_OF_U64)]
+        //         .try_into()
+        //         .unwrap(),
+        // ) as usize
+
+        let mut file = OpenOptions::new()
+            .read(true)
+            .open(self.root_offset_path.clone())
+            .expect("Failed to open root offset file");
+
+        let mut bytes = Vec::new();
+        file.read_to_end(&mut bytes)
+            .expect("Failed to read root offset file");
+
+        u64::from_le_bytes(bytes.try_into().unwrap()) as usize
+    }
+
+    pub fn set_root_offset(&mut self, root_offset: usize) {
+        // self.mmap[SIZE_OF_U64..(2 * SIZE_OF_U64)]
+        //     .copy_from_slice(&(root_offset as u64).to_le_bytes());
+
+        let mut file = OpenOptions::new()
+            .write(true)
+            .open(self.root_offset_path.clone())
+            .expect("Failed to open root offset file");
+
+        file.write_all(&(root_offset as u64).to_le_bytes());
+        // file.sync_all().expect("Failed to sync root offset file");
+    }
+
+    pub fn increment_and_allocate_block(&mut self) -> usize {
+        let mut used_blocks = self.used_blocks();
+        self.set_used_blocks(used_blocks + 1);
+
+        used_blocks
+    }
+
+    pub fn store(&mut self, serialized: Vec<u8>, index: usize) -> io::Result<usize> {
+        // save to temporary file and atomically rename
+        let block_index = if index == 0 {
+            self.increment_and_allocate_block()
+        } else {
+            index
+        };
+
+        let temp_file_path = self.path.join(format!("{}.tmp", block_index));
+        let file_path = self.path.join(format!("{}.bin", block_index));
+
+        let mut file = OpenOptions::new()
+            .write(true)
+            .create(true)
+            .open(temp_file_path.clone())
+            .expect("Failed to open temp file");
+
+        file.write_all(&serialized)
+            .expect("Failed to write to file");
+
+        fs::rename(temp_file_path, file_path).expect("Failed to rename temp file");
+
+        Ok(block_index)
+    }
+
+    pub fn load(&self, offset: usize) -> io::Result<Vec<u8>> {
+        let file_path = self.path.join(format!("{}.bin", offset));
+
+        if !file_path.exists() {
+            return Err(io::Error::new(
+                io::ErrorKind::NotFound,
+                format!("File not found: {:?}", file_path),
+            ));
+        }
+
+        let mut file = OpenOptions::new()
+            .read(true)
+            .open(file_path.clone())
+            .expect(format!("Failed to open file: {:?}", file_path).as_str());
+
+        let mut serialized = Vec::new();
+
+        file.read_to_end(&mut serialized)
+            .expect("Failed to read file");
+
+        Ok(serialized)
+    }
+}
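StorageLayer replaces the single-file BlockStorage with one file per block: store writes the payload to <index>.tmp and renames it over <index>.bin, so readers ever see either the old bytes or the new bytes, never a torn write. A round-trip sketch against the API exactly as added above (the demo directory path is illustrative, and error handling is collapsed into expects):

    use haystackdb::structures::storage_layer::StorageLayer;
    use std::path::PathBuf;

    fn main() {
        let mut storage = StorageLayer::new(PathBuf::from("tests/data/storage_demo"))
            .expect("Failed to create storage layer");

        // Index 0 asks the layer to allocate a fresh block; the returned index
        // names the `<index>.bin` file the payload was renamed into.
        let payload = b"hello blocks".to_vec();
        let block = storage.store(payload.clone(), 0).expect("store failed");

        // Passing that index back would overwrite the block in place
        // (write `<index>.tmp`, then an atomic rename over `<index>.bin`).
        let roundtrip = storage.load(block).expect("load failed");
        assert_eq!(roundtrip, payload);
    }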
diff --git a/src/structures/tree.rs b/src/structures/tree.rs
old mode 100644
new mode 100755
index a7aa164..458a1dd
--- a/src/structures/tree.rs
+++ b/src/structures/tree.rs
@@ -7,10 +7,11 @@ use std::path::PathBuf;
 
 use node::{Node, NodeType, NodeValue};
 
 use super::ann_tree::serialization::{TreeDeserialization, TreeSerialization};
-use super::block_storage::BlockStorage;
+// use super::block_storage::StorageLayer;
+use super::storage_layer::StorageLayer;
 
 pub struct Tree<K, V> {
-    pub storage: BlockStorage,
+    pub storage: StorageLayer,
     phantom: std::marker::PhantomData<(K, V)>,
 }
 
@@ -20,7 +21,7 @@ where
     V: Clone + TreeSerialization + TreeDeserialization,
 {
     pub fn new(path: PathBuf) -> io::Result<Self> {
-        let mut storage = BlockStorage::new(path)?;
+        let mut storage = StorageLayer::new(path)?;
 
         if storage.used_blocks() <= 1 {
             let root_offset: usize;
@@ -144,8 +145,12 @@ where
                         .keys
                         .binary_search(&median.clone())
                         .unwrap_or_else(|x| x);
-                    parent.keys.insert(idx, median.clone());
-                    parent.children.insert(idx + 1, sibling_offset);
+                    parent
+                        .keys
+                        .insert(idx.min(parent.children.len() - 1), median.clone());
+                    parent
+                        .children
+                        .insert((idx + 1).min(parent.children.len() - 1), sibling_offset);
 
                     self.store_node(&mut parent)?;
                 }
@@ -155,7 +160,11 @@ where
         }
 
         // Insert the key into the correct leaf node
-        let position = current_node.keys.binary_search(key).unwrap_or_else(|x| x);
+        let position = current_node
+            .keys
+            .binary_search(key)
+            .unwrap_or_else(|x| x)
+            .min(current_node.keys.len() - 1);
 
         if current_node.keys.get(position) == Some(&key) {
             current_node.values[position] =
@@ -184,12 +193,9 @@ where
                     return Ok(current_node);
                 }
                 NodeType::Internal => {
-                    let mut i = 0;
-                    while i < current_node.keys.len() && *key > current_node.keys[i] {
-                        i += 1;
-                    }
+                    let i = current_node.keys.binary_search(key).unwrap_or_else(|x| x);
 
-                    let child_offset = current_node.children[i];
+                    let child_offset = current_node.children[i.min(current_node.keys.len() - 1)];
                     current_node = self.load_node(child_offset)?;
                 }
             }
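The traversal hunks above replace a linear key scan with binary_search, folding the Err(insertion point) case into an index via unwrap_or_else(|x| x) and clamping it before indexing into children. The same index-selection logic in isolation (a sketch; the keys and expected values are illustrative):

    // binary_search returns Ok(i) on an exact hit and Err(i) with the
    // insertion point on a miss, so unwrap_or_else(|x| x) collapses both
    // into a usable descent index.
    fn child_index(keys: &[i32], key: i32) -> usize {
        let i = keys.binary_search(&key).unwrap_or_else(|x| x);
        // Clamp as the patch does, so an index past the end of `keys`
        // cannot run off the slice.
        i.min(keys.len().saturating_sub(1))
    }

    fn main() {
        let keys = [10, 20, 30];
        assert_eq!(child_index(&keys, 5), 0);  // before the first separator
        assert_eq!(child_index(&keys, 20), 1); // exact hit
        assert_eq!(child_index(&keys, 99), 2); // clamped to the last index
    }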
diff --git a/src/structures/tree/node.rs b/src/structures/tree/node.rs
old mode 100644
new mode 100755
index ebe1bad..adec629
--- a/src/structures/tree/node.rs
+++ b/src/structures/tree/node.rs
@@ -1,6 +1,7 @@
 use crate::structures::{
     ann_tree::serialization::{TreeDeserialization, TreeSerialization},
     block_storage::BlockStorage,
+    storage_layer::StorageLayer,
 };
 
 use std::io;
@@ -46,7 +47,7 @@ impl<T> NodeValue<T>
 where
     T: Clone + TreeDeserialization + TreeSerialization,
 {
-    pub fn get(&mut self, storage: &BlockStorage) -> Result<T, io::Error> {
+    pub fn get(&mut self, storage: &StorageLayer) -> Result<T, io::Error> {
         match self.value.clone() {
             Some(value) => Ok(value),
             None => {
@@ -58,7 +59,7 @@ where
         }
     }
 
-    pub fn new(value: T, storage: &mut BlockStorage) -> Result<Self, io::Error> {
+    pub fn new(value: T, storage: &mut StorageLayer) -> Result<Self, io::Error> {
         let offset = storage.store(value.serialize(), 0)?;
         Ok(NodeValue {
             offset,
diff --git a/src/structures/wal.rs b/src/structures/wal.rs
old mode 100644
new mode 100755
diff --git a/src/utils.rs b/src/utils.rs
old mode 100644
new mode 100755
diff --git a/src/utils/quantization.rs b/src/utils/quantization.rs
old mode 100644
new mode 100755