diff --git a/Cargo.lock b/Cargo.lock index 9d48a98..575912a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "aho-corasick" @@ -64,16 +64,16 @@ dependencies = [ ] [[package]] -name = "byte-slice-cast" -version = "1.2.2" +name = "bumpalo" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] -name = "bytemuck" -version = "1.14.0" +name = "byte-slice-cast" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" [[package]] name = "cast" @@ -145,6 +145,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -179,6 +188,31 @@ dependencies = [ "itertools", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "derive_more" version = "0.99.17" @@ -208,19 +242,21 @@ version = "0.1.0" dependencies = [ "blake2b_simd", "bounded-collections", + "bumpalo", "criterion", "parity-scale-codec", "quickcheck", "rand", + "rayon", "reed-solomon-simd", "thiserror", ] [[package]] name = "fixedbitset" -version = "0.4.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "getrandom" @@ -300,9 +336,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "libc" -version = "0.2.151" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "log" @@ -327,9 +363,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "oorandom" @@ -444,6 +480,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "readme-rustdocifier" version = "0.1.1" @@ -452,10 +508,11 @@ checksum = "08ad765b21a08b1a8e5cdce052719188a23772bcbefb3c439f0baaf62c56ceac" [[package]] name = "reed-solomon-simd" -version = "2.1.0" -source = "git+https://github.com/ordian/reed-solomon-simd?branch=simd-feature#52d42754a13508581cdc48dc7ea6321cfdf918db" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cffef0520d30fbd4151fb20e262947ae47fb0ab276a744a19b6398438105a072" dependencies = [ - "bytemuck", + "cpufeatures", "fixedbitset", "once_cell", "readme-rustdocifier", diff --git a/Cargo.toml b/Cargo.toml index 86d3594..c8bf445 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,13 +7,14 @@ license = "Apache-2.0" [dependencies] blake2b_simd = { version = "1", default-features = false } bounded-collections = { version = "0.1.9", default-features = false } -reed-solomon = { package = "reed-solomon-simd", git = "https://github.com/ordian/reed-solomon-simd", branch = "simd-feature", default-features = false } +reed-solomon = { package = "reed-solomon-simd", version = "3.1.0" } scale = { package = "parity-scale-codec", version = "3.6.9", default-features = false, features = ["derive"] } thiserror = { version = "1.0.56", default-features = false } +rayon = { version = "1.8" } +bumpalo = { version = "3.14", optional = true } [features] -default = ["simd"] -simd = ["reed-solomon/simd"] +arena = ["bumpalo"] [profile.dev] panic = "abort" diff --git a/benches/all.rs b/benches/all.rs index 5b0b5bc..fedcc86 100644 --- a/benches/all.rs +++ b/benches/all.rs @@ -2,13 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through use erasure_coding::*; use std::time::Duration; -fn chunks(n_chunks: u16, pov: &[u8]) -> Vec> { - construct_chunks(n_chunks, pov).unwrap() +fn chunks(n_chunks: u16, pov: &[u8], mode: &ThreadMode) -> Vec> { + construct_chunks(n_chunks, pov, mode).unwrap() } -fn erasure_root(n_chunks: u16, pov: &[u8]) -> ErasureRoot { - let chunks = chunks(n_chunks, pov); - MerklizedChunks::compute(chunks).root() +fn erasure_root(n_chunks: u16, pov: &[u8], mode: &ThreadMode) -> ErasureRoot { + let chunks = chunks(n_chunks, pov, mode); + MerklizedChunks::compute(chunks, mode).unwrap().root() } struct BenchParam { @@ -28,17 +28,19 @@ fn bench_all(c: &mut Criterion) { const POV_SIZES: [usize; 3] = [128 * KB, MB, 5 * MB]; const N_CHUNKS: [u16; 2] = [1023, 1024]; + let mode_multi = ThreadMode::multi_with_num_threads(None).unwrap(); + let mut group = c.benchmark_group("construct"); for pov_size in POV_SIZES { for n_chunks in N_CHUNKS { let param = BenchParam { pov_size, n_chunks }; let pov = vec![0xfe; pov_size]; - let expected_root = erasure_root(n_chunks, &pov); + let expected_root = erasure_root(n_chunks, &pov, &mode_multi); group.throughput(Throughput::Bytes(pov.len() as u64)); group.bench_with_input(BenchmarkId::from_parameter(param), &n_chunks, |b, &n| { b.iter(|| { - let root = erasure_root(n, &pov); + let root = erasure_root(n, &pov, &mode_multi); assert_eq!(root, expected_root); }); }); @@ -51,7 +53,7 @@ fn bench_all(c: &mut Criterion) { for n_chunks in N_CHUNKS { let param = BenchParam { pov_size, n_chunks }; let pov = vec![0xfe; pov_size]; - let all_chunks = chunks(n_chunks, &pov); + let all_chunks = chunks(n_chunks, &pov, &mode_multi); let chunks: Vec<_> = all_chunks .into_iter() @@ -64,7 +66,7 @@ fn bench_all(c: &mut Criterion) { group.throughput(Throughput::Bytes(pov.len() as u64)); group.bench_with_input(BenchmarkId::from_parameter(param), &n_chunks, |b, &n| { b.iter(|| { - let _pov: Vec = reconstruct(n, chunks.clone(), pov.len()).unwrap(); + let _pov: Vec = reconstruct(n, chunks.clone()).unwrap(); }); }); } @@ -76,7 +78,7 @@ fn bench_all(c: &mut Criterion) { for n_chunks in N_CHUNKS { let param = BenchParam { pov_size, n_chunks }; let pov = vec![0xfe; pov_size]; - let all_chunks = chunks(n_chunks, &pov); + let all_chunks = chunks(n_chunks, &pov, &mode_multi); let chunks = all_chunks .into_iter() @@ -86,13 +88,12 @@ fn bench_all(c: &mut Criterion) { group.throughput(Throughput::Bytes(pov.len() as u64)); group.bench_with_input(BenchmarkId::from_parameter(param), &n_chunks, |b, &n| { b.iter(|| { - let _pov: Vec = reconstruct_from_systematic( - n, - chunks.len(), - &mut chunks.iter().map(Vec::as_slice), - pov.len(), - ) - .unwrap(); + let _pov: Vec = reconstruct_from_systematic( + n, + chunks.len(), + &mut chunks.iter().map(Vec::as_slice), + ) + .unwrap(); }); }); } @@ -104,12 +105,13 @@ fn bench_all(c: &mut Criterion) { for n_chunks in N_CHUNKS { let param = BenchParam { pov_size, n_chunks }; let pov = vec![0xfe; pov_size]; - let all_chunks = chunks(n_chunks, &pov); + let all_chunks = chunks(n_chunks, &pov, &mode_multi); group.throughput(Throughput::Bytes(pov.len() as u64)); group.bench_with_input(BenchmarkId::from_parameter(param), &n_chunks, |b, _| { b.iter(|| { - let iter = MerklizedChunks::compute(all_chunks.clone()); + let iter = MerklizedChunks::compute(all_chunks.clone(), &mode_multi).unwrap(); + let n = iter.collect::>().len(); assert_eq!(n, all_chunks.len()); }); @@ -123,10 +125,12 @@ fn bench_all(c: &mut Criterion) { for n_chunks in N_CHUNKS { let param = BenchParam { pov_size, n_chunks }; let pov = vec![0xfe; pov_size]; - let all_chunks = chunks(n_chunks, &pov); - let merkle = MerklizedChunks::compute(all_chunks); + let all_chunks = chunks(n_chunks, &pov, &mode_multi); + + let merkle = MerklizedChunks::compute(all_chunks, &mode_multi).unwrap(); let root = merkle.root(); let chunks: Vec<_> = merkle.collect(); + let chunk = chunks[n_chunks as usize / 2].clone(); group.throughput(Throughput::Bytes(pov.len() as u64)); diff --git a/fuzz/fuzz_targets/round_trip.rs b/fuzz/fuzz_targets/round_trip.rs index f04512b..59a13a4 100644 --- a/fuzz/fuzz_targets/round_trip.rs +++ b/fuzz/fuzz_targets/round_trip.rs @@ -10,7 +10,8 @@ fuzz_target!(|data: (Vec, u16)| { if data.is_empty() || data.len() > 1 * 1024 * 1024 { return; } - let chunks = construct_chunks(n_chunks, &data).unwrap(); + let mode = ThreadMode::single(); + let chunks = construct_chunks(n_chunks, &data, &mode).unwrap(); assert_eq!(chunks.len() as u16, n_chunks); let threshold = systematic_recovery_threshold(n_chunks).unwrap(); @@ -18,7 +19,6 @@ fuzz_target!(|data: (Vec, u16)| { n_chunks, chunks.len(), &mut chunks.iter().map(Vec::as_slice), - data.len(), ) .unwrap(); @@ -29,7 +29,7 @@ fuzz_target!(|data: (Vec, u16)| { .map(|(i, v)| (ChunkIndex::from(i as u16), v)) .collect(); let some_chunks = map.into_iter().take(threshold as usize); - let reconstructed: Vec = reconstruct(n_chunks, some_chunks, data.len()).unwrap(); + let reconstructed: Vec = reconstruct(n_chunks, some_chunks).unwrap(); assert_eq!(reconstructed, data); assert_eq!(reconstructed_systematic, data); diff --git a/src/error.rs b/src/error.rs index ff72789..9bb14a6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,6 +26,8 @@ pub enum Error { Bug, #[error("An unknown error has appeared when (re)constructing erasure code chunks")] Unknown, + #[error("Invalid data length")] + InvalidDataLength, } impl From for Error { diff --git a/src/lib.rs b/src/lib.rs index 837baa7..6ad90f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,11 +15,84 @@ use scale::{Decode, Encode}; use std::ops::AddAssign; pub use subshard::*; +use rayon::prelude::*; +use std::sync::Arc; + +#[cfg(feature = "arena")] +use bumpalo::Bump; + +// Prefetch hints for cache locality optimization +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +// Branch prediction hints +#[cold] +#[inline(never)] +fn cold() {} + +#[inline(always)] +fn likely(b: bool) -> bool { + if !b { + cold(); + } + b +} + +#[inline(always)] +fn unlikely(b: bool) -> bool { + if b { + cold(); + } + b +} + pub const MAX_CHUNKS: u16 = 16384; // The reed-solomon library requires each shards to be 64 bytes aligned. const SHARD_ALIGNMENT: usize = 64; +const PADDING_ALIGNMENT: usize = 4; + +#[derive(Clone)] +pub enum ThreadMode { + Multi(Arc), + Single, +} + +impl ThreadMode { + pub fn multi_with_num_threads(num_threads: Option) -> Result { + let threads = match num_threads { + None => { + let logical_cores = + std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1); + (logical_cores / 2).max(1) + }, + Some(0) => std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1), + Some(n) => n, + }; + + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(threads) + .build() + .map_err(|_| Error::Unknown)?; + + Ok(ThreadMode::Multi(Arc::new(pool))) + } + + pub fn single() -> Self { + ThreadMode::Single + } + + pub fn num_threads(&self) -> usize { + match self { + ThreadMode::Multi(pool) => pool.current_num_threads(), + ThreadMode::Single => 1, + } + } +} + /// The index of an erasure chunk. #[derive(Eq, Ord, PartialEq, PartialOrd, Copy, Clone, Encode, Decode, Hash, Debug)] pub struct ChunkIndex(pub u16); @@ -48,6 +121,7 @@ pub struct ErasureChunk { } /// Obtain a threshold of chunks that should be enough to recover the data. +#[inline] pub const fn recovery_threshold(n_chunks: u16) -> Result { if n_chunks > MAX_CHUNKS { return Err(Error::TooManyTotalChunks); @@ -61,6 +135,7 @@ pub const fn recovery_threshold(n_chunks: u16) -> Result { } /// Obtain the threshold of systematic chunks that should be enough to recover the data. +#[inline] pub fn systematic_recovery_threshold(n_chunks: u16) -> Result { recovery_threshold(n_chunks) } @@ -69,48 +144,49 @@ pub fn systematic_recovery_threshold(n_chunks: u16) -> Result { /// /// Provide a vector containing the first k chunks in order. If too few chunks are provided, /// recovery is not possible. -/// -/// Due to the internals of the erasure coding algorithm, the output might be -/// larger than the original data and padded with zeroes; passing `data_len` -/// allows to truncate the output to the original data size. pub fn reconstruct_from_systematic<'a>( n_chunks: u16, systematic_len: usize, systematic_chunks: &'a mut impl Iterator, - data_len: usize, ) -> Result, Error> { let k = systematic_recovery_threshold(n_chunks)? as usize; - if systematic_len < k { + if unlikely(systematic_len < k) { return Err(Error::NotEnoughChunks); } + let mut bytes: Vec = Vec::with_capacity(0); let mut shard_len = 0; let mut nb = 0; + for chunk in systematic_chunks.by_ref() { nb += 1; - if shard_len == 0 { + if unlikely(shard_len == 0) { shard_len = chunk.len(); - if shard_len % SHARD_ALIGNMENT != 0 && nb != k { + if unlikely(shard_len % SHARD_ALIGNMENT != 0 && nb != k) { return Err(Error::UnalignedChunk); } - if k == 1 { - return Ok(chunk[..data_len].to_vec()); + if unlikely(k == 1) { + let mut result = chunk.to_vec(); + remove_padding(&mut result); + return Ok(result); } bytes = Vec::with_capacity(shard_len * k); } - if chunk.len() != shard_len { - return Err(Error::NonUniformChunks) + if unlikely(chunk.len() != shard_len) { + return Err(Error::NonUniformChunks); } + // extend_from_slice uses optimized memcpy bytes.extend_from_slice(chunk); - if nb == k { + + if unlikely(nb == k) { break; } } - bytes.resize(data_len, 0); + remove_padding(&mut bytes); Ok(bytes) } @@ -118,15 +194,44 @@ pub fn reconstruct_from_systematic<'a>( /// /// Works only for 1..65536 chunks. /// The data must be non-empty. -pub fn construct_chunks(n_chunks: u16, data: &[u8]) -> Result>, Error> { - if data.is_empty() { +pub fn construct_chunks( + n_chunks: u16, + data: &[u8], + mode: &ThreadMode, +) -> Result>, Error> { + if unlikely(data.is_empty()) { return Err(Error::BadPayload); } - if n_chunks == 1 { - return Ok(vec![data.to_vec()]); + + let padded = add_padding(data); + + if unlikely(n_chunks == 1) { + return Ok(vec![padded]); } + + #[cfg(feature = "arena")] + { + construct_chunks_arena(n_chunks, &padded, mode) + } + + #[cfg(not(feature = "arena"))] + { + construct_chunks_default(n_chunks, &padded, mode) + } +} + +/// Construct erasure-coded chunks. +/// +/// Works only for 1..65536 chunks. +/// The data must be non-empty. +#[inline] +fn construct_chunks_default( + n_chunks: u16, + data: &[u8], + mode: &ThreadMode, +) -> Result>, Error> { let systematic = systematic_recovery_threshold(n_chunks)?; - let original_data = make_original_shards(systematic, data); + let original_data = make_original_shards(systematic, data, mode)?; let original_iter = original_data.iter(); let original_count = systematic as usize; let recovery_count = (n_chunks - systematic) as usize; @@ -139,29 +244,152 @@ pub fn construct_chunks(n_chunks: u16, data: &[u8]) -> Result>, Erro Ok(result) } +#[cfg(feature = "arena")] +fn construct_chunks_arena( + n_chunks: u16, + data: &[u8], + _mode: &ThreadMode, +) -> Result>, Error> { + let systematic = systematic_recovery_threshold(n_chunks)?; + let original_count = systematic as usize; + let recovery_count = (n_chunks - systematic) as usize; + let shard_size = shard_bytes(systematic, data.len()); + + // Arena for temporary allocations + let arena = Bump::with_capacity(original_count * shard_size + 4096); + + // Create shards using arena for intermediate data + let original_data = make_original_shards_arena(&arena, systematic, data, shard_size); + let original_iter = original_data.iter(); + + let recovery = reed_solomon::encode(original_count, recovery_count, original_iter)?; + + let mut result = original_data; + result.extend(recovery); + + Ok(result) +} + +/// Creating shards using arena allocator +#[cfg(feature = "arena")] +fn make_original_shards_arena( + _arena: &Bump, + original_count: u16, + data: &[u8], + shard_size: usize, +) -> Vec> { + let total_size = original_count as usize * shard_size; + let mut flat_buffer = vec![0u8; total_size]; + + let data_to_copy = data.len().min(total_size); + flat_buffer[..data_to_copy].copy_from_slice(&data[..data_to_copy]); + + let mut result = Vec::with_capacity(original_count as usize); + for chunk_data in flat_buffer.chunks_exact(shard_size) { + result.push(chunk_data.to_vec()); + } + + result +} + +#[inline(always)] fn next_aligned(n: usize, alignment: usize) -> usize { ((n + alignment - 1) / alignment) * alignment } +#[inline] fn shard_bytes(systematic: u16, data_len: usize) -> usize { let shard_bytes = (data_len + systematic as usize - 1) / systematic as usize; next_aligned(shard_bytes, SHARD_ALIGNMENT) } +#[inline] +fn add_padding(data: &[u8]) -> Vec { + let remainder = data.len() % PADDING_ALIGNMENT; + let padding_len = if remainder == 0 { PADDING_ALIGNMENT } else { PADDING_ALIGNMENT - remainder }; + let mut padded = Vec::with_capacity(data.len() + padding_len); + padded.extend_from_slice(data); + padded.resize(data.len() + padding_len, padding_len as u8); + padded +} + +#[inline] +fn remove_padding(bytes: &mut Vec) { + // Find the last non-zero byte + if let Some(last_non_zero) = bytes.iter().rposition(|&b| b != 0) { + // Truncate trailing zeros + bytes.truncate(last_non_zero + 1); + // Last byte is the padding length + let padding_len = bytes[last_non_zero] as usize; + // Remove padding bytes + bytes.truncate(bytes.len().saturating_sub(padding_len)); + } else { + // All zeros — shouldn't happen if padding was added correctly + bytes.clear(); + } +} + // The reed-solomon library takes sharded data as input. -fn make_original_shards(original_count: u16, data: &[u8]) -> Vec> { +fn make_original_shards( + original_count: u16, + data: &[u8], + mode: &ThreadMode, +) -> Result>, Error> { assert!(!data.is_empty(), "data must be non-empty"); assert_ne!(original_count, 0); let shard_bytes = shard_bytes(original_count, data.len()); assert_ne!(shard_bytes, 0); - let mut result = vec![vec![0u8; shard_bytes]; original_count as usize]; - for (i, chunk) in data.chunks(shard_bytes).enumerate() { - result[i][..chunk.len()].as_mut().copy_from_slice(chunk); + match mode { + ThreadMode::Multi(pool) => Ok(pool.install(|| { + (0..original_count as usize) + .into_par_iter() + .map(|i| { + let mut chunk = vec![0u8; shard_bytes]; + let start = i * shard_bytes; + let end = (start + shard_bytes).min(data.len()); + + if likely(start < data.len()) { + let copy_len = end - start; + chunk[..copy_len].copy_from_slice(&data[start..end]); + } + + chunk + }) + .collect() + })), + ThreadMode::Single => { + let mut result = Vec::with_capacity(original_count as usize); + let mut remaining_data = data; + + for i in 0..original_count as usize { + let mut chunk = vec![0u8; shard_bytes]; + let copy_len = remaining_data.len().min(shard_bytes); + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if i + 1 < original_count as usize && remaining_data.len() > shard_bytes { + unsafe { + let next_ptr = remaining_data.as_ptr().add(shard_bytes); + if (remaining_data.len() - shard_bytes) >= 64 { + _mm_prefetch(next_ptr as *const i8, _MM_HINT_T0); + } + } + } + + chunk[..copy_len].copy_from_slice(&remaining_data[..copy_len]); + + if likely(remaining_data.len() >= shard_bytes) { + remaining_data = &remaining_data[shard_bytes..]; + } else { + remaining_data = &[]; + } + + result.push(chunk); + } + Ok(result) + }, } - - result } /// Reconstruct the original data from a set of chunks. @@ -171,16 +399,14 @@ fn make_original_shards(original_count: u16, data: &[u8]) -> Vec> { /// are provided, recovery is not possible. /// /// Works only for 1..65536 chunks. -/// -/// Due to the internals of the erasure coding algorithm, the output might be -/// larger than the original data and padded with zeroes; passing `data_len` -/// allows to truncate the output to the original data size. -pub fn reconstruct(n_chunks: u16, chunks: I, data_len: usize) -> Result, Error> +pub fn reconstruct(n_chunks: u16, chunks: I) -> Result, Error> where I: IntoIterator)>, { if n_chunks == 1 { - return chunks.into_iter().next().map(|(_, v)| v).ok_or(Error::NotEnoughChunks); + let mut data = chunks.into_iter().next().map(|(_, v)| v).ok_or(Error::NotEnoughChunks)?; + remove_padding(&mut data); + return Ok(data); } let n = n_chunks as usize; let original_count = systematic_recovery_threshold(n_chunks)? as usize; @@ -198,7 +424,13 @@ where let mut recovered = reed_solomon::decode(original_count, recovery_count, original_iter, recovery)?; - let shard_bytes = shard_bytes(original_count as u16, data_len); + let shard_bytes = recovered + .values() + .next() + .or_else(|| original.first().map(|(_, v)| v)) + .map(|v| v.len()) + .ok_or(Error::NotEnoughChunks)?; + let mut bytes = Vec::with_capacity(shard_bytes * original_count); let mut original = original.into_iter(); @@ -211,7 +443,7 @@ where bytes.extend_from_slice(chunk.as_slice()); } - bytes.truncate(data_len); + remove_padding(&mut bytes); Ok(bytes) } @@ -255,16 +487,18 @@ mod tests { fn property(available_data: ArbitraryAvailableData, n_chunks: u16) { let n_chunks = n_chunks.max(1).min(MAX_CHUNKS); let threshold = systematic_recovery_threshold(n_chunks).unwrap(); - let data_len = available_data.0.len(); - let chunks = construct_chunks(n_chunks, &available_data.0).unwrap(); - let reconstructed: Vec = reconstruct_from_systematic( - n_chunks, - chunks.len(), - &mut chunks.iter().take(threshold as usize).map(Vec::as_slice), - data_len, - ) - .unwrap(); - assert_eq!(reconstructed, available_data.0); + + for mode in [ThreadMode::single(), ThreadMode::multi_with_num_threads(None).unwrap()] { + let chunks = construct_chunks(n_chunks, &available_data.0, &mode).unwrap(); + + let reconstructed: Vec = reconstruct_from_systematic( + n_chunks, + chunks.len(), + &mut chunks.iter().take(threshold as usize).map(Vec::as_slice), + ) + .unwrap(); + assert_eq!(reconstructed, available_data.0); + } } QuickCheck::new().quickcheck(property as fn(ArbitraryAvailableData, u16)) @@ -274,17 +508,19 @@ mod tests { fn round_trip_works() { fn property(available_data: ArbitraryAvailableData, n_chunks: u16) { let n_chunks = n_chunks.max(1).min(MAX_CHUNKS); - let data_len = available_data.0.len(); let threshold = recovery_threshold(n_chunks).unwrap(); - let chunks = construct_chunks(n_chunks, &available_data.0).unwrap(); - let map: HashMap> = chunks - .into_iter() - .enumerate() - .map(|(i, v)| (ChunkIndex::from(i as u16), v)) - .collect(); - let some_chunks = map.into_iter().take(threshold as usize); - let reconstructed: Vec = reconstruct(n_chunks, some_chunks, data_len).unwrap(); - assert_eq!(reconstructed, available_data.0); + + for mode in [ThreadMode::single(), ThreadMode::multi_with_num_threads(None).unwrap()] { + let chunks = construct_chunks(n_chunks, &available_data.0, &mode).unwrap(); + let map: HashMap> = chunks + .into_iter() + .enumerate() + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let some_chunks = map.into_iter().take(threshold as usize); + let reconstructed: Vec = reconstruct(n_chunks, some_chunks).unwrap(); + assert_eq!(reconstructed, available_data.0); + } } QuickCheck::new().quickcheck(property as fn(ArbitraryAvailableData, u16)) @@ -294,27 +530,394 @@ mod tests { fn proof_verification_works() { fn property(data: SmallAvailableData, n_chunks: u16) { let n_chunks = n_chunks.max(1).min(2048); - let chunks = construct_chunks(n_chunks, &data.0).unwrap(); - assert_eq!(chunks.len() as u16, n_chunks); - let iter = MerklizedChunks::compute(chunks.clone()); - let root = iter.root(); - let erasure_chunks: Vec<_> = iter.collect(); + for mode in [ThreadMode::single(), ThreadMode::multi_with_num_threads(None).unwrap()] { + let chunks = construct_chunks(n_chunks, &data.0, &mode).unwrap(); + assert_eq!(chunks.len() as u16, n_chunks); + let iter = MerklizedChunks::compute(chunks.clone(), &mode).unwrap(); + let root = iter.root(); + let erasure_chunks: Vec<_> = iter.collect(); - assert_eq!(erasure_chunks.len(), chunks.len()); + assert_eq!(erasure_chunks.len(), chunks.len()); - for erasure_chunk in erasure_chunks.into_iter() { - let encode = Encode::encode(&erasure_chunk.proof); - let decode = Decode::decode(&mut &encode[..]).unwrap(); - assert_eq!(erasure_chunk.proof, decode); - assert_eq!(encode, Encode::encode(&decode)); + for erasure_chunk in erasure_chunks.into_iter() { + let encode = Encode::encode(&erasure_chunk.proof); + let decode = Decode::decode(&mut &encode[..]).unwrap(); + assert_eq!(erasure_chunk.proof, decode); + assert_eq!(encode, Encode::encode(&decode)); - assert_eq!(&erasure_chunk.chunk, &chunks[erasure_chunk.index.0 as usize]); + assert_eq!(&erasure_chunk.chunk, &chunks[erasure_chunk.index.0 as usize]); - assert!(erasure_chunk.verify(&root)); + assert!(erasure_chunk.verify(&root)); + } } } QuickCheck::new().quickcheck(property as fn(SmallAvailableData, u16)) } + + #[test] + fn stress_test_various_sizes_with_random_chunk_loss() { + use rand::{seq::SliceRandom, Rng, SeedableRng}; + + let data_sizes = vec![10, 1000, 10_000, 100_000, 1_000_000, 10_000_000, 50_000_000]; + + let chunk_configs = vec![16, 64, 256, 1024]; + + for data_size in data_sizes.iter() { + println!("Testing data size: {} bytes", data_size); + + for &n_chunks in chunk_configs.iter() { + if *data_size < 1000 && n_chunks > 64 { + continue; + } + + println!(" Testing with {} chunks", n_chunks); + + let mut rng = + rand::rngs::SmallRng::seed_from_u64((*data_size as u64) ^ (n_chunks as u64)); + let original_data: Vec = (0..*data_size).map(|_| rng.gen()).collect(); + + for (mode_name, mode) in [ + ("Single", ThreadMode::single()), + ("Multi", ThreadMode::multi_with_num_threads(None).unwrap()), + ] { + let chunks = construct_chunks(n_chunks, &original_data, &mode).unwrap(); + + assert_eq!(chunks.len(), n_chunks as usize); + + let threshold = recovery_threshold(n_chunks).unwrap() as usize; + + let mut chunk_indices: Vec = (0..n_chunks as usize).collect(); + + chunk_indices.shuffle(&mut rng); + + let selected_indices = &chunk_indices[..threshold]; + + let available_chunks: HashMap> = selected_indices + .iter() + .map(|&idx| (ChunkIndex(idx as u16), chunks[idx].clone())) + .collect(); + + let reconstructed = + reconstruct(n_chunks, available_chunks.into_iter()) + .unwrap(); + + assert_eq!( + reconstructed.len(), + original_data.len(), + "Reconstructed data length mismatch for size {} with {} chunks (mode: {})", + data_size, + n_chunks, + mode_name + ); + + assert_eq!( + reconstructed, original_data, + "Reconstructed data does not match original for size {} with {} chunks (mode: {})", + data_size, n_chunks, mode_name + ); + } + } + + println!(" ✓ All chunk configurations passed for size {}", data_size); + } + + println!("✓ All stress tests passed!"); + } + + #[test] + fn test_thread_mode_configurations() { + use std::thread::available_parallelism; + + let data = vec![1u8; 1024]; + let n_chunks = 16; + + let mode_default = ThreadMode::multi_with_num_threads(None).unwrap(); + let logical_cores = available_parallelism().map(|n| n.get()).unwrap_or(1); + let expected_default = (logical_cores / 2).max(1); + assert_eq!( + mode_default.num_threads(), + expected_default, + "Thread mode with None should use half of logical cores" + ); + let chunks = construct_chunks(n_chunks, &data, &mode_default).unwrap(); + assert_eq!(chunks.len(), n_chunks as usize); + + let all_cores = available_parallelism().map(|n| n.get()).unwrap_or(1); + let mode_all = ThreadMode::multi_with_num_threads(Some(0)).unwrap(); + assert_eq!( + mode_all.num_threads(), + all_cores, + "Thread mode with Some(0) should use all logical cores" + ); + let chunks = construct_chunks(n_chunks, &data, &mode_all).unwrap(); + assert_eq!(chunks.len(), n_chunks as usize); + + let mode_2 = ThreadMode::multi_with_num_threads(Some(2)).unwrap(); + assert_eq!( + mode_2.num_threads(), + 2, + "Thread mode with Some(2) should use exactly 2 threads" + ); + let chunks = construct_chunks(n_chunks, &data, &mode_2).unwrap(); + assert_eq!(chunks.len(), n_chunks as usize); + + let mode_4 = ThreadMode::multi_with_num_threads(Some(4)).unwrap(); + assert_eq!( + mode_4.num_threads(), + 4, + "Thread mode with Some(4) should use exactly 4 threads" + ); + let chunks = construct_chunks(n_chunks, &data, &mode_4).unwrap(); + assert_eq!(chunks.len(), n_chunks as usize); + + let mode_single = ThreadMode::single(); + assert_eq!(mode_single.num_threads(), 1, "Single thread mode should report 1 thread"); + let chunks = construct_chunks(n_chunks, &data, &mode_single).unwrap(); + assert_eq!(chunks.len(), n_chunks as usize); + + println!("✓ Thread mode configuration test passed!"); + } + + #[test] + fn test_padding_add_remove() { + // Alignment 4: data of length 3 → 1 byte of padding [1] + let data = vec![10, 20, 30]; + let padded = add_padding(&data); + assert_eq!(padded, vec![10, 20, 30, 1]); + + // Alignment 4: data of length 4 → 4 bytes of padding [4,4,4,4] + let data = vec![10, 20, 30, 40]; + let padded = add_padding(&data); + assert_eq!(padded, vec![10, 20, 30, 40, 4, 4, 4, 4]); + + // Alignment 4: data of length 5 → 3 bytes of padding [3,3,3] + let data = vec![1, 2, 3, 4, 5]; + let padded = add_padding(&data); + assert_eq!(padded, vec![1, 2, 3, 4, 5, 3, 3, 3]); + + // Alignment 4: data of length 1 → 3 bytes of padding [3,3,3] + let data = vec![42]; + let padded = add_padding(&data); + assert_eq!(padded, vec![42, 3, 3, 3]); + + // Test remove_padding reverses add_padding + for len in 1..=20 { + let data: Vec = (0..len).map(|i| (i * 7 + 13) as u8).collect(); + let mut padded = add_padding(&data); + // Simulate trailing zeros from reed-solomon + padded.extend_from_slice(&[0u8; 100]); + remove_padding(&mut padded); + assert_eq!(padded, data, "Round-trip failed for data length {}", len); + } + } + + #[test] + fn test_padding_data_ending_with_zeros() { + // Data consisting entirely of zeros + for len in 1..=16 { + let data = vec![0u8; len]; + + for mode in [ThreadMode::single(), ThreadMode::multi_with_num_threads(None).unwrap()] { + let n_chunks = 4u16; + let chunks = construct_chunks(n_chunks, &data, &mode).unwrap(); + + // Test reconstruct_from_systematic + let systematic = systematic_recovery_threshold(n_chunks).unwrap() as usize; + let reconstructed_sys = reconstruct_from_systematic( + n_chunks, + chunks.len(), + &mut chunks.iter().take(systematic).map(Vec::as_slice), + ) + .unwrap(); + assert_eq!( + reconstructed_sys, data, + "Systematic failed for zero-data of length {} (mode: {:?})", + len, mode.num_threads() + ); + + // Test reconstruct + let threshold = recovery_threshold(n_chunks).unwrap(); + let map: HashMap> = chunks + .into_iter() + .enumerate() + .take(threshold as usize) + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let reconstructed = reconstruct(n_chunks, map.into_iter()).unwrap(); + assert_eq!( + reconstructed, data, + "Reconstruct failed for zero-data of length {} (mode: {:?})", + len, mode.num_threads() + ); + } + } + } + + #[test] + fn test_padding_aligned_and_unaligned_data() { + // Test various data lengths: multiples of 4 and non-multiples + let test_sizes = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, + 15, 16, 17, + 63, 64, 65, + 100, 127, 128, 129, + 255, 256, 257, + 1000, 1023, 1024, 1025, + ]; + + for data_len in test_sizes { + let original_data: Vec = (0..data_len).map(|i| (i % 256) as u8).collect(); + + for n_chunks in [2u16, 4, 8, 16] { + for mode in [ThreadMode::single(), ThreadMode::multi_with_num_threads(None).unwrap()] { + let chunks = construct_chunks(n_chunks, &original_data, &mode).unwrap(); + + // Test reconstruct_from_systematic + let systematic = systematic_recovery_threshold(n_chunks).unwrap() as usize; + let reconstructed_sys = reconstruct_from_systematic( + n_chunks, + chunks.len(), + &mut chunks.iter().take(systematic).map(Vec::as_slice), + ) + .unwrap(); + assert_eq!( + reconstructed_sys, original_data, + "Systematic failed: data_len={}, n_chunks={}", + data_len, n_chunks + ); + + // Test reconstruct + let threshold = recovery_threshold(n_chunks).unwrap(); + let map: HashMap> = chunks + .into_iter() + .enumerate() + .take(threshold as usize) + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let reconstructed = reconstruct(n_chunks, map.into_iter()).unwrap(); + assert_eq!( + reconstructed, original_data, + "Reconstruct failed: data_len={}, n_chunks={}", + data_len, n_chunks + ); + } + } + } + } + + #[test] + fn test_padding_data_with_padding_like_values() { + // Data ending with bytes that look like padding values [4,4,4,4] + let data = vec![4u8; 4]; + for n_chunks in [2u16, 4, 8] { + let mode = ThreadMode::single(); + let chunks = construct_chunks(n_chunks, &data, &mode).unwrap(); + let threshold = recovery_threshold(n_chunks).unwrap(); + let map: HashMap> = chunks + .into_iter() + .enumerate() + .take(threshold as usize) + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let reconstructed = reconstruct(n_chunks, map.into_iter()).unwrap(); + assert_eq!(reconstructed, data, "Failed for data=[4,4,4,4], n_chunks={}", n_chunks); + } + + // Data ending with [1] + let data = vec![1u8]; + for n_chunks in [2u16, 4, 8] { + let mode = ThreadMode::single(); + let chunks = construct_chunks(n_chunks, &data, &mode).unwrap(); + let threshold = recovery_threshold(n_chunks).unwrap(); + let map: HashMap> = chunks + .into_iter() + .enumerate() + .take(threshold as usize) + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let reconstructed = reconstruct(n_chunks, map.into_iter()).unwrap(); + assert_eq!(reconstructed, data, "Failed for data=[1], n_chunks={}", n_chunks); + } + + // Data ending with [3, 3, 3] + let data = vec![3u8; 3]; + for n_chunks in [2u16, 4, 8] { + let mode = ThreadMode::single(); + let chunks = construct_chunks(n_chunks, &data, &mode).unwrap(); + let threshold = recovery_threshold(n_chunks).unwrap(); + let map: HashMap> = chunks + .into_iter() + .enumerate() + .take(threshold as usize) + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let reconstructed = reconstruct(n_chunks, map.into_iter()).unwrap(); + assert_eq!(reconstructed, data, "Failed for data=[3,3,3], n_chunks={}", n_chunks); + } + } + + #[test] + fn test_padding_random_data() { + use rand::{Rng, SeedableRng}; + + let mut rng = rand::rngs::SmallRng::seed_from_u64(12345); + + for _ in 0..50 { + let data_len = rng.gen_range(1..=4096); + let original_data: Vec = (0..data_len).map(|_| rng.gen()).collect(); + let n_chunks = [2u16, 4, 8, 16, 32][rng.gen_range(0..5)]; + + let mode = ThreadMode::single(); + let chunks = construct_chunks(n_chunks, &original_data, &mode).unwrap(); + + // Test reconstruct_from_systematic + let systematic = systematic_recovery_threshold(n_chunks).unwrap() as usize; + let reconstructed_sys = reconstruct_from_systematic( + n_chunks, + chunks.len(), + &mut chunks.iter().take(systematic).map(Vec::as_slice), + ) + .unwrap(); + assert_eq!( + reconstructed_sys, original_data, + "Systematic failed: data_len={}, n_chunks={}", + data_len, n_chunks + ); + + // Test reconstruct + let threshold = recovery_threshold(n_chunks).unwrap(); + let map: HashMap> = chunks + .into_iter() + .enumerate() + .take(threshold as usize) + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let reconstructed = reconstruct(n_chunks, map.into_iter()).unwrap(); + assert_eq!( + reconstructed, original_data, + "Reconstruct failed: data_len={}, n_chunks={}", + data_len, n_chunks + ); + } + } + + #[test] + fn test_padding_single_chunk() { + // n_chunks == 1: special case + let data = vec![1, 2, 3, 4, 5]; + let mode = ThreadMode::single(); + let chunks = construct_chunks(1, &data, &mode).unwrap(); + assert_eq!(chunks.len(), 1); + + // reconstruct with n_chunks == 1 + let map: HashMap> = chunks + .into_iter() + .enumerate() + .map(|(i, v)| (ChunkIndex::from(i as u16), v)) + .collect(); + let reconstructed = reconstruct(1, map.into_iter()).unwrap(); + assert_eq!(reconstructed, data); + } } diff --git a/src/merklize.rs b/src/merklize.rs index f2522c4..9a3eefb 100644 --- a/src/merklize.rs +++ b/src/merklize.rs @@ -6,6 +6,9 @@ use scale::{Decode, Encode}; use blake2b_simd::{blake2b as hash_fn, Hash as InnerHash, State as InnerHasher}; +use crate::ThreadMode; +use rayon::prelude::*; + // Binary Merkle Tree with 16-bit `ChunkIndex` has depth at most 17. // The proof has at most `depth - 1` length. const MAX_MERKLE_PROOF_DEPTH: u32 = 16; @@ -98,7 +101,9 @@ impl Iterator for MerklizedChunks { let d = self.tree.len() - 1; let idx = self.current_index.0; let mut index = idx as usize; + let mut path = Vec::with_capacity(d); + for i in 0..d { let layer = &self.tree[i]; if index % 2 == 0 { @@ -119,46 +124,66 @@ impl Iterator for MerklizedChunks { impl MerklizedChunks { /// Compute `MerklizedChunks` from a list of erasure chunks. - pub fn compute(chunks: Vec>) -> Self { - let mut hashes: Vec = chunks - .iter() - .map(|chunk| { - let hash = hash_fn(chunk); - Hash::from(hash) - }) - .collect(); - hashes.resize(chunks.len().next_power_of_two(), Hash::default()); + pub fn compute(chunks: Vec>, mode: &ThreadMode) -> Result { + let chunks_len = chunks.len(); + let target_size = chunks_len.next_power_of_two(); + + let mut hashes: Vec = match mode { + ThreadMode::Multi(pool) => pool.install(|| { + chunks.par_iter().map(|chunk| Hash::from(hash_fn(chunk))).collect::>() + }), + ThreadMode::Single => + chunks.iter().map(|chunk| Hash::from(hash_fn(chunk))).collect::>(), + }; + + hashes.resize(target_size, Hash::default()); let depth = hashes.len().ilog2() as usize + 1; - let mut tree = vec![Vec::new(); depth]; + let mut tree = Vec::with_capacity(depth); + + for lvl in 0..depth { + let len = if lvl == 0 { target_size } else { 2usize.pow((depth - 1 - lvl) as u32) }; + tree.push(Vec::with_capacity(len)); + } + tree[0] = hashes; - // Build the tree bottom-up. - (1..depth).for_each(|lvl| { + for lvl in 1..depth { let len = 2usize.pow((depth - 1 - lvl) as u32); tree[lvl].resize(len, Hash::default()); - // NOTE: This can be parallelized. - (0..len).for_each(|i| { - let prev = &tree[lvl - 1]; - - let hash = combine(prev[2 * i], prev[2 * i + 1]); - - tree[lvl][i] = hash; - }); - }); + let (prev_slice, out_slice) = tree.split_at_mut(lvl); + let prev = &*prev_slice.last().unwrap(); + let out = &mut out_slice[0]; + + match mode { + ThreadMode::Multi(pool) => { + pool.install(|| { + out.par_iter_mut().enumerate().for_each(|(i, out_val)| { + *out_val = combine(prev[2 * i], prev[2 * i + 1]); + }); + }); + }, + ThreadMode::Single => { + out.iter_mut().enumerate().for_each(|(i, out_val)| { + *out_val = combine(prev[2 * i], prev[2 * i + 1]); + }); + }, + } + } assert!(tree[tree.len() - 1].len() == 1, "root must be a single hash"); - Self { + Ok(Self { root: ErasureRoot::from(tree[tree.len() - 1][0]), data: chunks.into(), tree, current_index: ChunkIndex::from(0), - } + }) } } +#[inline(always)] fn combine(left: Hash, right: Hash) -> Hash { let mut hasher = InnerHasher::new(); @@ -172,6 +197,7 @@ fn combine(left: Hash, right: Hash) -> Hash { impl ErasureChunk { /// Verify the proof of the chunk against the erasure root and index. + #[inline] pub fn verify(&self, root: &ErasureRoot) -> bool { let leaf_hash = Hash::from(hash_fn(&self.chunk)); let bits = Bitfield(self.index.0); @@ -203,58 +229,65 @@ mod tests { #[test] fn zero_chunks_works() { let chunks = vec![]; - let iter = MerklizedChunks::compute(chunks.clone()); - let root = iter.root(); - let erasure_chunks: Vec = iter.collect(); - assert_eq!(erasure_chunks.len(), chunks.len()); - assert_eq!(root, ErasureRoot(Hash::default())); + + for mode in [ThreadMode::single(), ThreadMode::multi_with_num_threads(None).unwrap()] { + let iter = MerklizedChunks::compute(chunks.clone(), &mode).unwrap(); + + let root = iter.root(); + let erasure_chunks: Vec = iter.collect(); + assert_eq!(erasure_chunks.len(), chunks.len()); + assert_eq!(root, ErasureRoot(Hash::default())); + } } #[test] fn iter_works() { let chunks = vec![vec![1], vec![2], vec![3]]; - let iter = MerklizedChunks::compute(chunks.clone()); - let root = iter.root(); - let erasure_chunks: Vec = iter.collect(); - assert_eq!(erasure_chunks.len(), chunks.len()); - // compute the proof manually - let proof_0 = { - let a0 = hash_fn(&chunks[0]).into(); - let a1 = hash_fn(&chunks[1]).into(); - let a2 = hash_fn(&chunks[2]).into(); - let a3 = Hash::default(); + for mode in [ThreadMode::single(), ThreadMode::multi_with_num_threads(None).unwrap()] { + let iter = MerklizedChunks::compute(chunks.clone(), &mode).unwrap(); - let b0 = combine(a0, a1); - let b1 = combine(a2, a3); + let root = iter.root(); + let erasure_chunks: Vec = iter.collect(); + assert_eq!(erasure_chunks.len(), chunks.len()); - let c0 = combine(b0, b1); + let proof_0 = { + let a0 = hash_fn(&chunks[0]).into(); + let a1 = hash_fn(&chunks[1]).into(); + let a2 = hash_fn(&chunks[2]).into(); + let a3 = Hash::default(); - assert_eq!(c0, root.0); + let b0 = combine(a0, a1); + let b1 = combine(a2, a3); - let p = vec![a1, b1]; - Proof::try_from(p).unwrap() - }; + let c0 = combine(b0, b1); - assert_eq!(erasure_chunks[0].proof, proof_0); + assert_eq!(c0, root.0); - let invalid_1 = ErasureChunk { - chunk: erasure_chunks[0].chunk.clone(), - proof: erasure_chunks[0].proof.clone(), - index: ChunkIndex(erasure_chunks[0].index.0 + 1), - }; + let p = vec![a1, b1]; + Proof::try_from(p).unwrap() + }; - let invalid_2 = ErasureChunk { - chunk: erasure_chunks[0].chunk.clone(), - proof: erasure_chunks[0].proof.clone(), - index: ChunkIndex(erasure_chunks[0].index.0 | 1 << 15), - }; + assert_eq!(erasure_chunks[0].proof, proof_0); - assert!(!invalid_1.verify(&root)); - assert!(!invalid_2.verify(&root)); + let invalid_1 = ErasureChunk { + chunk: erasure_chunks[0].chunk.clone(), + proof: erasure_chunks[0].proof.clone(), + index: ChunkIndex(erasure_chunks[0].index.0 + 1), + }; - for chunk in erasure_chunks { - assert!(chunk.verify(&root)); + let invalid_2 = ErasureChunk { + chunk: erasure_chunks[0].chunk.clone(), + proof: erasure_chunks[0].proof.clone(), + index: ChunkIndex(erasure_chunks[0].index.0 | 1 << 15), + }; + + assert!(!invalid_1.verify(&root)); + assert!(!invalid_2.verify(&root)); + + for chunk in erasure_chunks { + assert!(chunk.verify(&root)); + } } } } diff --git a/src/subshard.rs b/src/subshard.rs index 6354336..dae8177 100644 --- a/src/subshard.rs +++ b/src/subshard.rs @@ -9,6 +9,19 @@ use std::{ mem::MaybeUninit, }; +/// Macro to create a vector without cloning the element. +/// The element expression is evaluated on each iteration. +macro_rules! vec_no_clone { + ($elem:expr; $n:expr) => {{ + let n = $n; + let mut result = Vec::with_capacity(n); + for _ in 0..n { + result.push($elem); + } + result + }}; +} + /// Fix segment size. pub const SEGMENT_SIZE: usize = 4096; @@ -94,7 +107,8 @@ impl SubShardEncoder { &mut self, segments: &[Segment], ) -> Result>, Error> { - let mut result = vec![Box::new([[0u8; SUBSHARD_SIZE]; TOTAL_SHARDS]); segments.len()]; + let mut result = + vec_no_clone![Box::new([[0u8; SUBSHARD_SIZE]; TOTAL_SHARDS]); segments.len()]; let mut seg_offset = 0; let mut shard = [0u8; BATCH_SHARD_SIZE]; @@ -200,11 +214,15 @@ impl SubShardDecoder { where I: Iterator, { - let mut ori = vec![Vec::new(); TOTAL_SHARDS]; + let mut ori = Vec::with_capacity(TOTAL_SHARDS); + for _ in 0..TOTAL_SHARDS { + ori.push(Vec::new()); + } + let mut segments = BTreeMap::::new(); let mut nb_decode = 0; - // TODO processed and run_segments could be skiped if we are sure to get + // TODO processed and run_segments could be skiped if we are sure to get // correct number of chunks all for the same given chunk ix and segments. for (segment, chunk_index, chunk) in subshards { ori[chunk_index.0 as usize].push((segment, chunk)); @@ -257,7 +275,7 @@ impl SubShardDecoder { } continue; } - // TODO max size 16, rather [;16] lookup? + // TODO max size 16, rather [;16] lookup? let mut segment_batch = BTreeSet::new(); for (seg, count) in run_segments.into_iter() { if count == N_SHARDS { @@ -345,7 +363,7 @@ impl SubShardDecoder { debug_assert!(nb_chunk == N_SHARDS); let ori_ret = self.decoder.decode()?; nb_decode += 1; - // TODO modify deps to also access original data and avoid self.ori_shards buffer. + // TODO modify deps to also access original data and avoid self.ori_shards buffer. // Also to avoid instantiating ori_map container. for (i, o) in ori_ret.restored_original_iter() { ori_map.insert(i, o);