diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e46d23e3..12f887d6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,7 @@ jobs: cargo-lint: uses: ./.github/workflows/ci.yml with: - toolchain: 1.87.0 + toolchain: 1.89.0 cleanup: needs: - cargo-lint diff --git a/Cargo.toml b/Cargo.toml index 08cee221..d0aae745 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["crate/findex", "crate/memories"] resolver = "2" [workspace.package] -version = "8.0.1" +version = "8.1.0" authors = [ "Bruno Grieder ", "Célia Corsin ", @@ -30,3 +30,4 @@ cosmian_crypto_core = { version = "10.2", default-features = false, features = [ criterion = { version = "0.6" } smol-macros = { version = "0.1" } tokio = { version = "1.46" } +futures = "0.3.31" diff --git a/crate/findex/Cargo.toml b/crate/findex/Cargo.toml index d3ed614e..5cc351f6 100644 --- a/crate/findex/Cargo.toml +++ b/crate/findex/Cargo.toml @@ -15,14 +15,15 @@ name = "cosmian_findex" path = "src/lib.rs" [features] +batch = ["cosmian_sse_memories/batch"] test-utils = ["agnostic-lite", "criterion"] [dependencies] aes = "0.8" cosmian_crypto_core.workspace = true -cosmian_sse_memories = { path = "../memories", version = "8.0" } +cosmian_sse_memories = { path = "../memories", version = "8.1.0" } xts-mode = "0.5" - +futures.workspace = true # Optional dependencies for testing and benchmarking. 
agnostic-lite = { workspace = true, optional = true, features = [ "tokio", @@ -33,7 +34,7 @@ criterion = { workspace = true, optional = true } [dev-dependencies] agnostic-lite = { workspace = true, features = ["tokio"] } -cosmian_sse_memories = { path = "../memories", version = "8.0", features = [ +cosmian_sse_memories = { path = "../memories", version = "8.1.0", features = [ "redis-mem", "sqlite-mem", "postgres-mem", diff --git a/crate/findex/README.md b/crate/findex/README.md index d22b9a2d..fa9bad90 100644 --- a/crate/findex/README.md +++ b/crate/findex/README.md @@ -2,6 +2,8 @@ This crate provides the core functionality of Findex, defining the abstract data types, cryptographic operations, and encoding algorithms. +Supports batching operations into a single call to the memory interface, which reduces connection overhead and avoids file descriptor limits on some Linux systems. + ## Setup Add `cosmian_findex` as dependency to your project : diff --git a/crate/findex/src/adt.rs b/crate/findex/src/adt.rs index ad61b09f..41e6b5c6 100644 --- a/crate/findex/src/adt.rs +++ b/crate/findex/src/adt.rs @@ -51,6 +51,40 @@ pub trait VectorADT: Send { fn read(&self) -> impl Send + Future, Self::Error>>; } +/// This trait provides methods that let an index operate on multiple keywords +/// or entries simultaneously. +#[cfg(feature = "batch")] +pub trait IndexBatcher { + type Error: std::error::Error; + + /// Search the index for the values bound to the given keywords. + fn batch_search( + &self, + keywords: Vec<&Keyword>, + ) -> impl Future>, Self::Error>>; + + /// Binds each value to their associated keyword in this index. + fn batch_insert( + &self, + entries: Entries, + ) -> impl Send + Future> + where + Values: Send + IntoIterator, + Entries: Send + IntoIterator, + Entries::IntoIter: ExactSizeIterator, + ::IntoIter: Send; + + /// Removes the given values from the index. 
+ fn batch_delete( + &self, + entries: Entries, + ) -> impl Send + Future> + where + Values: Send + IntoIterator, + Entries: Send + IntoIterator, + Entries::IntoIter: ExactSizeIterator + Send; +} + #[cfg(test)] pub mod tests { diff --git a/crate/findex/src/error.rs b/crate/findex/src/error.rs index 8954f877..123e2c2e 100644 --- a/crate/findex/src/error.rs +++ b/crate/findex/src/error.rs @@ -1,5 +1,8 @@ use std::fmt::{Debug, Display}; +#[cfg(feature = "batch")] +pub use batch_findex_error::*; + #[derive(Debug)] pub enum Error
{ Parsing(String), @@ -16,3 +19,59 @@ impl Display for Error
{ } impl std::error::Error for Error
{} + +#[cfg(feature = "batch")] +pub mod batch_findex_error { + use cosmian_sse_memories::{MemoryADT, MemoryBatcherError}; + + use super::*; + + #[derive(Debug)] + pub enum BatchFindexError + where + ::Word: Debug, + { + BatchingLayer(MemoryBatcherError), + Findex(Error), + Other(String), + } + + impl Display for BatchFindexError + where + ::Address: Debug, + ::Word: Debug, + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::BatchingLayer(e) => write!(f, "Batching layer error: {e}"), + Self::Findex(error) => write!(f, "Findex error: {error:?}"), + Self::Other(msg) => write!(f, "{msg}"), + } + } + } + + impl From> for BatchFindexError + where + ::Word: Debug, + { + fn from(e: Error) -> Self { + Self::Findex(e) + } + } + + impl From> for BatchFindexError + where + ::Word: Debug, + { + fn from(e: MemoryBatcherError) -> Self { + Self::BatchingLayer(e) + } + } + + impl std::error::Error for BatchFindexError + where + ::Address: Debug, + ::Word: Debug, + { + } +} diff --git a/crate/findex/src/findex.rs b/crate/findex/src/findex.rs index 9b1e9302..43bf0291 100644 --- a/crate/findex/src/findex.rs +++ b/crate/findex/src/findex.rs @@ -1,5 +1,3 @@ -#![allow(clippy::type_complexity)] - use std::{ collections::HashSet, fmt::Debug, diff --git a/crate/findex/src/findex_batcher.rs b/crate/findex/src/findex_batcher.rs new file mode 100644 index 00000000..f5f18937 --- /dev/null +++ b/crate/findex/src/findex_batcher.rs @@ -0,0 +1,319 @@ +use std::{collections::HashSet, fmt::Debug, hash::Hash, num::NonZero}; + +use cosmian_sse_memories::{ADDRESS_LENGTH, Address, BatchingMemoryADT, MemoryBatcher}; + +use crate::{Decoder, Encoder, Findex, IndexADT, adt::IndexBatcher, error::BatchFindexError}; + +#[derive(Debug)] +pub struct FindexBatcher< + const WORD_LENGTH: usize, + Value: Send + Hash + Eq, + EncodingError: Send + Debug, + BatcherMemory: Clone + Send + BatchingMemoryADT
, Word = [u8; WORD_LENGTH]>, +> { + memory: BatcherMemory, + encode: Encoder, + decode: Decoder, +} + +impl< + const WORD_LENGTH: usize, + Value: Send + Hash + Eq, + BatcherMemory: Debug + + Send + + Sync + + Clone + + BatchingMemoryADT
, Word = [u8; WORD_LENGTH]>, + EncodingError: Send + Debug, +> FindexBatcher +{ + pub fn new( + memory: BatcherMemory, + encode: Encoder, + decode: Decoder, + ) -> Self { + Self { + memory, + encode, + decode, + } + } + + async fn batch_insert_or_delete( + &self, + entries: Entries, + is_insert: bool, + ) -> Result<(), BatchFindexError> + where + Keyword: Send + Sync + Hash + Eq, + Bindings: Send + IntoIterator, + Entries: IntoIterator + Send, + Entries::IntoIter: ExactSizeIterator, + { + let entries = entries.into_iter(); + let n = entries.len(); + let mut futures = Vec::with_capacity(n); + let memory = MemoryBatcher::new( + self.memory.clone(), + NonZero::new(n) + .ok_or_else(|| BatchFindexError::Other("Batch size can not be zero".to_owned()))?, + ); + + for (guard_keyword, bindings) in entries { + let memory = memory.clone(); + // Create a temporary Findex instance using the shared batching layer. + let findex = Findex::::new( + memory.clone(), + self.encode, + self.decode, + ); + + let future = async move { + if is_insert { + findex.insert(guard_keyword, bindings).await + } else { + findex.delete(guard_keyword, bindings).await + }?; + // Within findex, insert/delete operations may perform a variable number of + // memory writes. This requires explicit `unsubscribe()` calls + // to adjust the expected buffer size once one of the operations succeeds. + memory.unsubscribe().await?; + Ok::<_, BatchFindexError<_>>(()) + }; + + futures.push(future); + } + + // Execute all futures concurrently and collect results. + futures::future::try_join_all(futures).await?; + + Ok(()) + } +} + +impl< + const WORD_LENGTH: usize, + Keyword: Send + Sync + Hash + Eq, + Value: Send + Hash + Eq, + EncodingError: Send + Debug, + BatchingMemoryLayer: Debug + + Send + + Sync + + Clone + + BatchingMemoryADT
, Word = [u8; WORD_LENGTH]>, +> IndexBatcher + for FindexBatcher +{ + type Error = BatchFindexError; + + async fn batch_insert(&self, entries: Entries) -> Result<(), Self::Error> + where + Bindings: Send + IntoIterator, + Entries: Send + IntoIterator, + Entries::IntoIter: ExactSizeIterator, + { + self.batch_insert_or_delete(entries, true).await + } + + async fn batch_delete(&self, entries: Entries) -> Result<(), Self::Error> + where + Bindings: Send + IntoIterator, + Entries: Send + IntoIterator, + Entries::IntoIter: ExactSizeIterator, + { + self.batch_insert_or_delete(entries, false).await + } + + async fn batch_search( + &self, + keywords: Vec<&Keyword>, + ) -> Result>, Self::Error> { + let n = keywords.len(); + let mut futures = Vec::with_capacity(n); + let memory = MemoryBatcher::new( + self.memory.clone(), + NonZero::new(n) + .ok_or_else(|| BatchFindexError::Other("Batch size can not be zero".to_owned()))?, + ); + + for keyword in keywords { + let memory = memory.clone(); + let findex = Findex::::new( + memory, + self.encode, + self.decode, + ); + // Search operations do not require calling `unsubscribe()` on the memory + // batcher. This is because all Findex search operations perform the + // same deterministic number of memory read operations. + // Specifically, each (safe) Findex search completes after + // performing exactly two reads. + let future = async move { findex.search(keyword).await }; + futures.push(future); + } + + futures::future::try_join_all(futures) + .await + .map_err(|e| BatchFindexError::Findex(e)) + } +} + +// These tests implement dual testing against the base Findex implementation. 
+#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use cosmian_crypto_core::define_byte_type; + use cosmian_sse_memories::{ADDRESS_LENGTH, InMemory}; + + use super::*; + use crate::{Findex, IndexADT, dummy_decode, dummy_encode}; + + type Value = Bytes<8>; + define_byte_type!(Bytes); + + impl TryFrom for Bytes { + type Error = String; + + fn try_from(value: usize) -> Result { + Self::try_from(value.to_be_bytes().as_slice()).map_err(|e| e.to_string()) + } + } + + const WORD_LENGTH: usize = 16; + + #[tokio::test] + async fn test_batch_insert_and_delete() { + let trivial_memory = InMemory::, [u8; WORD_LENGTH]>::default(); + + // Initial data for insertion + let cat_bindings = vec![ + Value::try_from(1).unwrap(), + Value::try_from(2).unwrap(), + Value::try_from(3).unwrap(), + Value::try_from(7).unwrap(), + ]; + let dog_bindings = vec![ + Value::try_from(4).unwrap(), + Value::try_from(5).unwrap(), + Value::try_from(6).unwrap(), + ]; + + // Insert using normal findex + let findex = Findex::new( + trivial_memory.clone(), + dummy_encode::, + dummy_decode, + ); + + findex + .insert("cat".to_string(), cat_bindings.clone()) + .await + .unwrap(); + findex + .insert("dog".to_string(), dog_bindings.clone()) + .await + .unwrap(); + + // Create a `findex_batcher` instance + let findex_batcher = FindexBatcher::::new( + trivial_memory.clone(), + dummy_encode, + dummy_decode, + ); + + // Test batch delete + let deletion_entries = vec![ + ( + "cat".to_string(), + vec![Value::try_from(1).unwrap(), Value::try_from(3).unwrap()], // Partial deletion + ), + ("dog".to_string(), dog_bindings), // Complete deletion + ]; + + findex_batcher.batch_delete(deletion_entries).await.unwrap(); + + // Verify deletions using normal findex + let cat_result_after_delete = findex.search(&"cat".to_string()).await.unwrap(); + let dog_result_after_delete = findex.search(&"dog".to_string()).await.unwrap(); + + let expected_cat = vec![ + Value::try_from(2).unwrap(), // 1 and 3 removed, 2 and 7 
remain + Value::try_from(7).unwrap(), + ] + .into_iter() + .collect::>(); + let expected_dog = HashSet::new(); // All dog bindings removed + + assert_eq!(cat_result_after_delete, expected_cat); + assert_eq!(dog_result_after_delete, expected_dog); + + // Test batch insert + let insert_entries = vec![( + "dog".to_string(), + vec![Value::try_from(8).unwrap(), Value::try_from(9).unwrap()], + )]; + + findex_batcher.batch_insert(insert_entries).await.unwrap(); + + // Verify insertions using normal findex + let new_dog_results = findex.search(&"dog".to_string()).await.unwrap(); + + let expected_dog = vec![Value::try_from(8).unwrap(), Value::try_from(9).unwrap()] + .into_iter() + .collect::>(); + + assert_eq!(new_dog_results, expected_dog); + } + + #[tokio::test] + async fn test_batch_search() { + let trivial_memory = InMemory::, [u8; WORD_LENGTH]>::default(); + + let findex = Findex::new( + trivial_memory.clone(), + dummy_encode::, + dummy_decode, + ); + let cat_bindings = [ + Value::try_from(1).unwrap(), + Value::try_from(3).unwrap(), + Value::try_from(5).unwrap(), + ]; + let dog_bindings = [ + Value::try_from(0).unwrap(), + Value::try_from(2).unwrap(), + Value::try_from(4).unwrap(), + ]; + findex + .insert("cat".to_string(), cat_bindings.clone()) + .await + .unwrap(); + findex + .insert("dog".to_string(), dog_bindings.clone()) + .await + .unwrap(); + + let findex_batcher = FindexBatcher::::new( + trivial_memory.clone(), + dummy_encode, + dummy_decode, + ); + + let key1 = "cat".to_string(); + let key2 = "dog".to_string(); + // Perform batch search + let batch_search_results = findex_batcher + .batch_search(vec![&key1, &key2]) + .await + .unwrap(); + + assert_eq!( + batch_search_results, + vec![ + cat_bindings.iter().cloned().collect::>(), + dog_bindings.iter().cloned().collect::>() + ] + ); + } +} diff --git a/crate/findex/src/lib.rs b/crate/findex/src/lib.rs index d328cc55..19924690 100644 --- a/crate/findex/src/lib.rs +++ b/crate/findex/src/lib.rs @@ -30,6 +30,13 @@ 
pub use encryption_layer::{KEY_LENGTH, MemoryEncryptionLayer}; pub use error::Error; pub use findex::{Findex, Op}; +#[cfg(feature = "batch")] +mod findex_batcher; +#[cfg(feature = "batch")] +pub use adt::IndexBatcher; +#[cfg(feature = "batch")] +pub use findex_batcher::FindexBatcher; + #[cfg(feature = "test-utils")] pub mod reexport { // Re-exporting the most commonly used runtime interfaces for convenience. diff --git a/crate/memories/Cargo.toml b/crate/memories/Cargo.toml index 2a3b2454..2fffc7a1 100644 --- a/crate/memories/Cargo.toml +++ b/crate/memories/Cargo.toml @@ -19,6 +19,7 @@ redis-mem = ["redis"] sqlite-mem = ["async-sqlite"] postgres-mem = ["deadpool-postgres", "tokio", "tokio-postgres"] test-utils = ["tokio"] +batch = [] [dependencies] agnostic-lite = { workspace = true } @@ -33,6 +34,7 @@ tokio = { workspace = true, optional = true } tokio-postgres = { version = "0.7", optional = true, features = [ "array-impls", ] } +futures.workspace = true [dev-dependencies] agnostic-lite = { workspace = true, features = ["tokio", "smol"] } diff --git a/crate/memories/src/batching_layer/buffer.rs b/crate/memories/src/batching_layer/buffer.rs new file mode 100644 index 00000000..b3305a98 --- /dev/null +++ b/crate/memories/src/batching_layer/buffer.rs @@ -0,0 +1,88 @@ +//! Thread-safe buffer for batching memory operations. +//! +//! Operations are accumulated until capacity is reached, then flushed in one +//! call to the inner memory. All operations are synchronized via `Mutex` to +//! ensure thread-safe concurrent access. + +use std::{ + mem, + num::{NonZero, NonZeroUsize}, + sync::Mutex, +}; + +use crate::{ + BatchingMemoryADT, + batching_layer::operation::{Operation, PendingOperations}, +}; + +struct Buffer { + capacity: NonZeroUsize, + data: PendingOperations, +} + +impl Buffer { + /// Flushes the buffer if it contains data and returns the flushed + /// operations. Returns None if the buffer is empty. 
+ fn flush_if_not_empty(&mut self) -> Option> { + if !self.data.is_empty() { + Some(mem::take(&mut self.data)) + } else { + None + } + } +} + +pub(crate) struct ThreadSafeBuffer(Mutex>); + +impl ThreadSafeBuffer +where + M: BatchingMemoryADT, +{ + pub(crate) fn new(capacity: NonZeroUsize) -> Self { + Self(Mutex::new(Buffer:: { + capacity, + data: Vec::with_capacity(capacity.into()), + })) + } + + pub(crate) fn shrink_capacity(&self) -> Result>, BufferError> { + let mut buffer = self.0.lock().expect("poisoned lock"); + if buffer.capacity == NonZero::new(1).unwrap() { + return Ok(buffer.flush_if_not_empty()); + } + buffer.capacity = + NonZero::new(buffer.capacity.get() - 1).expect("buffer capacity should not reach zero"); + Ok(buffer.flush_if_not_empty()) + } + + pub(crate) fn push( + &self, + item: Operation, + ) -> Result>, BufferError> { + let mut buffer = self.0.lock().expect("poisoned lock"); + buffer.data.push(item); + Ok(buffer.flush_if_not_empty()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BufferError { + TypeMismatch, + Overflow, + Underflow, +} + +impl std::fmt::Display for BufferError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::TypeMismatch => write!( + f, + "Type mismatch: cannot mix read and write operations in the same buffer." 
+ ), + Self::Overflow => write!(f, "Buffer overflow: cannot push below capacity."), + Self::Underflow => write!(f, "Buffer underflow: cannot shrink capacity below zero."), + } + } +} + +impl std::error::Error for BufferError {} diff --git a/crate/memories/src/batching_layer/error.rs b/crate/memories/src/batching_layer/error.rs new file mode 100644 index 00000000..26ad6bd3 --- /dev/null +++ b/crate/memories/src/batching_layer/error.rs @@ -0,0 +1,69 @@ +use std::fmt::{Debug, Display}; + +use futures::channel::oneshot::Canceled; + +use crate::{ + MemoryADT, + batching_layer::{buffer::BufferError, operation::MemoryOutput}, +}; + +#[derive(Debug)] +pub enum MemoryBatcherError +where + M::Word: std::fmt::Debug, +{ + Memory(M::Error), /* the from will not be implemented due to conflicting + * implementations with Rust's `core` library. Use `map_err` instead of + * `?`. */ + ClosedChannel, + Buffering(BufferError), + WrongResultType(MemoryOutput), +} + +impl Display for MemoryBatcherError +where + M::Word: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Memory(err) => write!(f, "Memory error: {:?}", err), + Self::ClosedChannel => write!( + f, + "Channel closed unexpectedly, the sender was dropped before sending its results \ + with the `send` function.." 
+ ), + Self::Buffering(err) => write!(f, "Internal buffering error: {:?}", err), + Self::WrongResultType(out) => { + write!( + f, + "Wrong result type, expected {:?}", + match out { + MemoryOutput::Read(_) => "Read, got Write.", + MemoryOutput::Write(_) => "Write, got Read.", + } + ) + } + } + } +} + +impl From for MemoryBatcherError +where + M::Word: std::fmt::Debug, +{ + fn from(_: Canceled) -> Self { + Self::ClosedChannel + } +} + +impl From for MemoryBatcherError +where + M::Word: std::fmt::Debug, +{ + fn from(e: BufferError) -> Self { + Self::Buffering(e) + } +} + +impl std::error::Error for MemoryBatcherError where M::Word: std::fmt::Debug +{} diff --git a/crate/memories/src/batching_layer/memory.rs b/crate/memories/src/batching_layer/memory.rs new file mode 100644 index 00000000..c71f2f2d --- /dev/null +++ b/crate/memories/src/batching_layer/memory.rs @@ -0,0 +1,193 @@ +use std::{fmt::Debug, num::NonZeroUsize, sync::Arc}; + +use futures::channel::oneshot; + +use crate::{ + BatchingMemoryADT, MemoryADT, + batching_layer::{ + MemoryBatcherError, + buffer::ThreadSafeBuffer, + operation::{ + MemoryInput, MemoryOutput, Operation, OperationResultReceiver, PendingOperations, + }, + }, +}; + +pub struct MemoryBatcher { + pub inner: Arc, // The actual memory that does the R/W operations. + buffer: Arc>, // The buffer that holds the operations to be batched. 
+} + +impl Clone for MemoryBatcher { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + buffer: Arc::clone(&self.buffer), + } + } +} + +impl MemoryADT for MemoryBatcher +where + M::Address: Clone, + M::Word: std::fmt::Debug, +{ + type Address = M::Address; + type Error = MemoryBatcherError; + type Word = M::Word; + + async fn batch_read( + &self, + addresses: Vec, + ) -> Result>, Self::Error> { + let res = self.apply(MemoryInput::Read(addresses)).await?; + + if let MemoryOutput::Read(words) = res { + Ok(words) + } else { + Err(MemoryBatcherError::WrongResultType(res)) + } + } + + async fn guarded_write( + &self, + guard: (Self::Address, Option), + bindings: Vec<(Self::Address, Self::Word)>, + ) -> Result, Self::Error> { + let res = self.apply(MemoryInput::Write((guard, bindings))).await?; + + if let MemoryOutput::Write(word) = res { + Ok(word) + } else { + Err(MemoryBatcherError::WrongResultType(res)) + } + } +} + +impl MemoryBatcher +where + M::Address: Clone + Send, + M::Word: Send + std::fmt::Debug, +{ + pub fn new(inner: M, capacity: NonZeroUsize) -> Self { + Self { + inner: Arc::new(inner), + buffer: Arc::new(ThreadSafeBuffer::new(capacity)), + } + } + + pub async fn unsubscribe(&self) -> Result<(), MemoryBatcherError> { + if let Some(ops) = self.buffer.shrink_capacity()? { + self.manage(ops).await?; + } + Ok(()) + } + + async fn apply(&self, op: MemoryInput) -> Result, MemoryBatcherError> { + let (operation, receiver) = match op { + MemoryInput::Read(addresses) => { + let (sender, receiver) = oneshot::channel(); + ( + Operation::Read((addresses, sender)), + OperationResultReceiver::::Read(receiver), + ) + } + MemoryInput::Write((guard, bindings)) => { + let (sender, receiver) = oneshot::channel(); + ( + Operation::Write(((guard, bindings), sender)), + OperationResultReceiver::::Write(receiver), + ) + } + }; + + if let Some(ops) = self.buffer.push(operation)? 
{ + self.manage(ops).await?; + } + + Ok(match receiver { + OperationResultReceiver::Read(receiver) => { + let result = receiver.await?.map_err(MemoryBatcherError::Memory)?; + MemoryOutput::Read(result) + } + OperationResultReceiver::Write(receiver) => { + let result = receiver.await?.map_err(MemoryBatcherError::Memory)?; + MemoryOutput::Write(result) + } + }) + } + + async fn manage(&self, ops: PendingOperations) -> Result<(), MemoryBatcherError> { + // Assumes the vector is homogeneous, i.e. all operations are of the same type. + // This should be guaranteed by the buffer. + match ops[0] { + Operation::Read(_) => { + // Build combined address list while tracking which addresses belong to which + // batch. + let all_addresses: Vec<_> = ops + .iter() + .map(|op| match op { + Operation::Read((addresses, _)) => Ok(addresses.clone()), + _ => Err(MemoryBatcherError::Buffering( + crate::batching_layer::buffer::BufferError::TypeMismatch, + )), + }) + .collect::, _>>()? // Short-circuit on first error. + .into_iter() + .flatten() + .collect(); + + let mut words = self + .inner + .batch_read(all_addresses) + .await + .map_err(MemoryBatcherError::Memory)?; + + // Distribute results to each batch's sender. + for (input_addresses, sender) in ops + .into_iter() + .map(|op| match op { + Operation::Read((addresses, sender)) => Ok((addresses, sender)), + _ => Err(MemoryBatcherError::Buffering( + crate::batching_layer::buffer::BufferError::TypeMismatch, + )), + }) + .collect::, _>>()? + .into_iter() + .rev() + { + let batch_results = words.split_off(words.len() - input_addresses.len()); // After this call, all_results will be left containing the elements [0, split_point). 
+ sender + .send(Ok(batch_results)) + .map_err(|_| MemoryBatcherError::::ClosedChannel)?; + } + } + Operation::Write(_) => { + let (bindings, senders): (Vec<_>, Vec<_>) = ops + .into_iter() + .map(|op| match op { + Operation::Write((bindings, sender)) => Ok((bindings, sender)), + _ => Err(MemoryBatcherError::Buffering( + crate::batching_layer::buffer::BufferError::TypeMismatch, + )), + }) + .collect::, _>>()? + .into_iter() + .unzip(); + + let aggregated_writes_results = self + .inner + .batch_guarded_write(bindings) + .await + .map_err(MemoryBatcherError::Memory)?; + + for (res, sender) in aggregated_writes_results.into_iter().zip(senders) { + sender + .send(Ok(res)) + .map_err(|_| MemoryBatcherError::::ClosedChannel)?; + } + } + }; + Ok(()) + } +} diff --git a/crate/memories/src/batching_layer/mod.rs b/crate/memories/src/batching_layer/mod.rs new file mode 100644 index 00000000..d98e2bf7 --- /dev/null +++ b/crate/memories/src/batching_layer/mod.rs @@ -0,0 +1,9 @@ +mod buffer; +mod error; +mod memory; +mod operation; + +pub use error::MemoryBatcherError; +pub use memory::MemoryBatcher; + +pub use crate::batching_layer::operation::{BatchReadInput, GuardedWriteInput}; diff --git a/crate/memories/src/batching_layer/operation.rs b/crate/memories/src/batching_layer/operation.rs new file mode 100644 index 00000000..4faff8dd --- /dev/null +++ b/crate/memories/src/batching_layer/operation.rs @@ -0,0 +1,62 @@ +//! This module strongly types and defines the variables that are used within +//! the batching layer. It adds clear distinction between : +//! - And `input`(resp `output`) type, which is the type that the memory backend +//! accepts (resp returns) via its MemoryADT implementation. +//! - An operation type, which is a pair of an input and a oneshot channel. An +//! operation is considered *pending* starting the moment it is pushed to the +//! buffer, and to each operation corresponds exactly one result consisting of +//! 
an output that can be retrieved from the oneshot channel (or otherwise an +//! error). +use futures::channel::oneshot; + +use crate::{BatchingMemoryADT, MemoryADT}; + +pub type BatchReadInput = Vec<::Address>; +pub type GuardedWriteInput = ( + (::Address, Option<::Word>), + Vec<(::Address, ::Word)>, +); + +// Notice: to avoid breaking changes, the MemoryADT I/O types are kept here for +// now. If a major release is planned, consider moving them to the MemoryADT +// module. +pub(crate) enum MemoryInput { + Read(BatchReadInput), + Write(GuardedWriteInput), +} + +pub(crate) type BatchReadOutput = Vec::Word>>; +pub(crate) type GuardedWriteOutput = Option<::Word>; + +#[derive(Debug)] +pub enum MemoryOutput +where + M::Word: std::fmt::Debug, +{ + Read(BatchReadOutput), + Write(GuardedWriteOutput), +} + +pub(crate) type ReadOperation = ( + BatchReadInput, + oneshot::Sender, ::Error>>, +); + +pub(crate) type WriteOperation = ( + GuardedWriteInput, + oneshot::Sender, ::Error>>, +); + +pub(crate) enum Operation { + Read(ReadOperation), + Write(WriteOperation), +} + +pub(crate) type PendingOperations = Vec>; + +// Match arms do not support heterogeneous types, this enum is the only way to +// escape a 2 branch `apply` function and the code duplication that would imply. 
+pub(crate) enum OperationResultReceiver { + Read(oneshot::Receiver, ::Error>>), + Write(oneshot::Receiver, ::Error>>), +} diff --git a/crate/memories/src/databases/postgresql_mem/memory.rs b/crate/memories/src/databases/postgresql_mem/memory.rs index 61399d27..b9d8f35a 100644 --- a/crate/memories/src/databases/postgresql_mem/memory.rs +++ b/crate/memories/src/databases/postgresql_mem/memory.rs @@ -205,6 +205,8 @@ impl MemoryADT #[cfg(test)] mod tests { + use std::future::Future; + use deadpool_postgres::Config; use tokio_postgres::NoTls; @@ -235,7 +237,7 @@ mod tests { ) -> Result<(), PostgresMemoryError> where F: FnOnce(PostgresMemory, [u8; 129]>) -> Fut + Send, - Fut: std::future::Future + Send, + Fut: Future + Send, { let test_pool = create_testing_pool(DB_URL).await.unwrap(); let m = PostgresMemory::new_with_pool(test_pool.clone(), table_name.to_string()).await; diff --git a/crate/memories/src/in_memory.rs b/crate/memories/src/in_memory.rs index 101e0e8b..48c8e713 100644 --- a/crate/memories/src/in_memory.rs +++ b/crate/memories/src/in_memory.rs @@ -7,6 +7,8 @@ use std::{ sync::{Arc, Mutex}, }; +#[cfg(feature = "batch")] +use crate::BatchingMemoryADT; use crate::MemoryADT; #[derive(Debug, Clone, PartialEq, Eq)] @@ -74,6 +76,31 @@ impl Memory } } +#[cfg(feature = "batch")] +impl BatchingMemoryADT + for InMemory +{ + async fn batch_guarded_write( + &self, + operations: Vec<((Address, Option), Vec<(Address, Value)>)>, + ) -> Result>, Self::Error> { + let store = &mut *self.inner.lock().expect("poisoned lock"); + Ok(operations + .into_iter() + .map(|(guard, bindings)| { + let (a, old) = guard; + let cur = store.get(&a).cloned(); + if old == cur { + for (k, v) in bindings { + store.insert(k, v); + } + } + cur + }) + .collect()) + } +} + impl IntoIterator for InMemory { diff --git a/crate/memories/src/lib.rs b/crate/memories/src/lib.rs index d2bccf8b..e82c0676 100644 --- a/crate/memories/src/lib.rs +++ b/crate/memories/src/lib.rs @@ -2,6 +2,8 @@ mod address; mod 
databases; mod in_memory; +use std::future::Future; + pub use address::Address; #[cfg(feature = "postgres-mem")] pub use databases::postgresql_mem::{PostgresMemory, PostgresMemoryError}; @@ -44,7 +46,7 @@ pub trait MemoryADT { fn batch_read( &self, addresses: Vec, - ) -> impl Send + std::future::Future>, Self::Error>>; + ) -> impl Send + Future>, Self::Error>>; /// Write the given bindings if the word currently stored at the guard /// address is the guard word, and returns this word. @@ -52,5 +54,24 @@ pub trait MemoryADT { &self, guard: (Self::Address, Option), bindings: Vec<(Self::Address, Self::Word)>, - ) -> impl Send + std::future::Future, Self::Error>>; + ) -> impl Send + Future, Self::Error>>; +} + +#[cfg(feature = "batch")] +mod batching_layer; + +#[cfg(feature = "batch")] +pub use batching_layer::{MemoryBatcher, MemoryBatcherError}; + +#[cfg(feature = "batch")] +pub use crate::batching_layer::{BatchReadInput, GuardedWriteInput}; + +// Super trait for MemoryADT that allows doing write operations in batches. +#[cfg(feature = "batch")] +pub trait BatchingMemoryADT: MemoryADT { + #[allow(clippy::type_complexity)] // Refactoring this type will make the code unnecessarily more difficult to read without any actual benefit. + fn batch_guarded_write( + &self, + write_operations: Vec>, + ) -> impl Send + Future>, Self::Error>>; }