diff --git a/Cargo.lock b/Cargo.lock index b00939faa..e9d5f8bda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,7 @@ dependencies = [ "parallel", "tracing", "utils", + "zk-alloc", ] [[package]] @@ -637,6 +638,7 @@ dependencies = [ "serde", "system-info", "utils", + "zk-alloc", ] [[package]] @@ -650,6 +652,7 @@ dependencies = [ "mt-poly", "parallel", "tracing", + "zk-alloc", ] [[package]] @@ -679,6 +682,7 @@ dependencies = [ "tracing-forest", "tracing-subscriber", "utils", + "zk-alloc", ] [[package]] @@ -899,7 +903,6 @@ dependencies = [ "sub_protocols", "tracing", "xmss", - "zk-alloc", ] [[package]] @@ -1432,6 +1435,7 @@ name = "zk-alloc" version = "0.1.0" dependencies = [ "libc", + "parallel", "system-info", ] diff --git a/Cargo.toml b/Cargo.toml index 12eff415d..18515ec57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,8 +83,6 @@ include_dir = "0.7" [features] prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"] -# Build with the plain system allocator instead of zk-alloc (for comparison/debugging). -standard-alloc = ["rec_aggregation/standard-alloc"] [dependencies] clap.workspace = true @@ -102,3 +100,4 @@ system-info.workspace = true [profile.release] lto = "thin" +codegen-units = 1 diff --git a/crates/backend/Cargo.toml b/crates/backend/Cargo.toml index 7867f4217..618c122da 100644 --- a/crates/backend/Cargo.toml +++ b/crates/backend/Cargo.toml @@ -15,3 +15,4 @@ tracing.workspace = true fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" } koala-bear = { path = "koala-bear", package = "mt-koala-bear" } utils = { path = "utils", package = "utils" } +zk-alloc.workspace = true diff --git a/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs b/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs index 682f3cb82..1aa93cac5 100644 --- a/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs +++ b/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs @@ -48,7 +48,7 @@ impl> From> for P #[inline] fn from(x: QuinticExtensionField) -> Self { Self { - value: x.value.map(Into::into), + value: array::from_fn(|i| x.value[i].into()), } } } @@ -117,10 +117,11 @@ macro_rules! impl_packed_ext_scalar_ops { impl Mul for PackedQuinticExtensionField { type Output = Self; #[inline] - fn mul(self, rhs: KoalaBear) -> Self { - Self { - value: self.value.map(|x| x * rhs), + fn mul(mut self, rhs: KoalaBear) -> Self { + for v in &mut self.value { + *v *= rhs; } + self } } @@ -281,10 +282,12 @@ where type Output = Self; #[inline] - fn neg(self) -> Self { - Self { - value: self.value.map(PF::neg), + fn neg(mut self) -> Self { + // Loop, not `self.value.map(..)`: avoids a thin-LTO de-inlined `Wrapped` closure. + for v in &mut self.value { + *v = -*v; } + self } } @@ -478,7 +481,7 @@ where #[inline(always)] fn mul(self, rhs: QuinticExtensionField) -> Self { - let b: [PF; 5] = rhs.value.map(|x| x.into()); + let b: [PF; 5] = array::from_fn(|i| rhs.value[i].into()); Self { value: super::extension::quintic_mul(&self.value, &b, PF::dot_product::<5>), } @@ -493,10 +496,11 @@ where type Output = Self; #[inline] - fn mul(self, rhs: PF) -> Self { - Self { - value: self.value.map(|x| x * rhs), + fn mul(mut self, rhs: PF) -> Self { + for v in &mut self.value { + *v *= rhs; } + self } } diff --git a/crates/backend/parallel/src/lib.rs b/crates/backend/parallel/src/lib.rs index 8fb1823b0..ee4686a14 100644 --- a/crates/backend/parallel/src/lib.rs +++ b/crates/backend/parallel/src/lib.rs @@ -56,7 +56,7 @@ thread_local! { /// Calling worker's id in `0..NUM_THREADS` (`0` off-pool). #[must_use] -pub fn current_worker_id() -> usize { +pub(crate) fn current_worker_id() -> usize { WORKER_ID.with(Cell::get) } diff --git a/crates/backend/poly/Cargo.toml b/crates/backend/poly/Cargo.toml index 0d19d40dc..5e49d0407 100644 --- a/crates/backend/poly/Cargo.toml +++ b/crates/backend/poly/Cargo.toml @@ -8,6 +8,8 @@ field = { path = "../field", package = "mt-field" } utils = { path = "../utils", package = "utils" } system-info.workspace = true parallel.workspace = true +zk-alloc.workspace = true + itertools.workspace = true rand.workspace = true serde.workspace = true diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 3c7b9d34c..b11b003db 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -3,6 +3,7 @@ use crate::{EFPacking, PF}; use ::utils::{iter_array_chunks_padded, log2_ceil_usize, log2_strict_usize}; use field::*; use system_info::NUM_THREADS; +use zk_alloc::ArenaVec; const LOG_NUM_THREADS: usize = log2_ceil_usize(NUM_THREADS); const LOG_BATCHED_TILE_SIZE: usize = 14; @@ -59,26 +60,29 @@ fn par_eval_eq( /// defined on the boolean hypercube by: ∀ (x_1, ..., x_n) ∈ {0, 1}^n, /// P(x_1, ..., x_n) = Π_{i=1}^{n} (x_i.α_i + (1 - x_i).(1 - α_i)) /// (often denoted as P(x) = eq(x, evals)) -pub fn eval_eq>>(eval: &[F]) -> Vec { +/// Returns an arena-backed table (see [`ArenaVec`]). Every eq table is phase-local proof scratch +/// (consumed within the proving phase that built it, or system-backed when the arena is inactive, +/// e.g. in the verifier), so it never outlives a `begin_phase()` reset. +pub fn eval_eq>>(eval: &[F]) -> ArenaVec { eval_eq_scaled(eval, F::ONE) } -pub fn eval_eq_scaled>>(eval: &[F], scalar: F) -> Vec { +pub fn eval_eq_scaled>>(eval: &[F], scalar: F) -> ArenaVec { // Alloc memory without initializing it to zero. - // This is safe because we overwrite it inside `eval_eq`. - let mut out = unsafe { uninitialized_vec(1 << eval.len()) }; + // This is safe because we overwrite it inside `compute_eval_eq`. + let mut out = unsafe { ArenaVec::uninitialized(1 << eval.len()) }; compute_eval_eq::, F, false>(eval, &mut out, scalar); out } -pub fn eval_eq_packed>>(eval: &[F]) -> Vec> { +pub fn eval_eq_packed>>(eval: &[F]) -> ArenaVec> { eval_eq_packed_scaled(eval, F::ONE) } -pub fn eval_eq_packed_scaled>>(eval: &[F], scalar: F) -> Vec> { +pub fn eval_eq_packed_scaled>>(eval: &[F], scalar: F) -> ArenaVec> { // Alloc memory without initializing it to zero. - // This is safe because we overwrite it inside `eval_eq`. - let mut out = unsafe { uninitialized_vec(1 << (eval.len() - packing_log_width::())) }; + // This is safe because we overwrite it inside `compute_eval_eq_packed`. + let mut out = unsafe { ArenaVec::uninitialized(1 << (eval.len() - packing_log_width::())) }; compute_eval_eq_packed::(eval, &mut out, scalar); out } @@ -105,7 +109,7 @@ where let packed = &mut out[selector >> shift]; let mut unpacked: Vec = unpack_extension(&[*packed]); compute_sparse_eval_eq::(selector & ((1 << shift) - 1), eval, &mut unpacked, scalar); - *packed = pack_extension(&unpacked)[0]; + *packed = pack_extension::<_, Vec<_>>(&unpacked)[0]; return; } diff --git a/crates/backend/poly/src/evals.rs b/crates/backend/poly/src/evals.rs index 5624991a9..8c837a2ea 100644 --- a/crates/backend/poly/src/evals.rs +++ b/crates/backend/poly/src/evals.rs @@ -3,8 +3,6 @@ use crate::{EFPacking, PF}; use ::utils::log2_ceil_usize; use field::{ExtensionField, Field, PrimeCharacteristicRing}; use itertools::Itertools; -use std::borrow::Borrow; - pub trait EvaluationsList { fn num_variables(&self) -> usize; fn num_evals(&self) -> usize; @@ -14,30 +12,30 @@ pub trait EvaluationsList { fn evaluate_sparse>(&self, selector: usize, point: &MultilinearPoint) -> EF; } -impl> EvaluationsList for EL { +impl> EvaluationsList for EL { fn num_variables(&self) -> usize { - self.borrow().len().ilog2() as usize + self.as_ref().len().ilog2() as usize } fn num_evals(&self) -> usize { - self.borrow().len() + self.as_ref().len() } fn evaluate>(&self, point: &MultilinearPoint) -> EF { - eval_multilinear::<_, _, true>(self.borrow(), point) + eval_multilinear::<_, _, true>(self.as_ref(), point) } fn evaluate_sequential>(&self, point: &MultilinearPoint) -> EF { - eval_multilinear::<_, _, false>(self.borrow(), point) + eval_multilinear::<_, _, false>(self.as_ref(), point) } fn as_constant(&self) -> F { - assert_eq!(self.borrow().len(), 1); - self.borrow()[0] + assert_eq!(self.as_ref().len(), 1); + self.as_ref()[0] } fn evaluate_sparse>(&self, selector: usize, point: &MultilinearPoint) -> EF { - (&self.borrow()[selector << point.len()..][..(1 << point.len())]).evaluate(point) + (&self.as_ref()[selector << point.len()..][..(1 << point.len())]).evaluate(point) } } @@ -366,7 +364,7 @@ mod tests { let res_normal = eval_multilinear::<_, _, true>(&poly, &point); println!("Normal eval time: {:?}", time.elapsed()); - let packed_poly = pack_extension(&poly); + let packed_poly: Vec<_> = pack_extension(&poly); let time = Instant::now(); let res_packed = eval_packed::<_, true>(&packed_poly, &point); println!("Packed eval time: {:?}", time.elapsed()); diff --git a/crates/backend/poly/src/mle/mle_group_owned.rs b/crates/backend/poly/src/mle/mle_group_owned.rs index b2eae5497..d4e57ed1d 100644 --- a/crates/backend/poly/src/mle/mle_group_owned.rs +++ b/crates/backend/poly/src/mle/mle_group_owned.rs @@ -1,24 +1,25 @@ use crate::*; use ::utils::log2_strict_usize; use field::ExtensionField; +use zk_alloc::ArenaVec; #[derive(Debug)] pub enum MleGroupOwned>> { - Base(Vec>>), - Extension(Vec>), - BasePacked(Vec>>), - ExtensionPacked(Vec>>), + Base(Vec>>), + Extension(Vec>), + BasePacked(Vec>>), + ExtensionPacked(Vec>>), } impl>> MleGroupOwned { - pub fn as_extension_packed_mut(&mut self) -> Option<&mut Vec>>> { + pub fn as_extension_packed_mut(&mut self) -> Option<&mut Vec>>> { match self { Self::ExtensionPacked(e) => Some(e), _ => None, } } - pub fn as_extension(self) -> Option>> { + pub fn as_extension(self) -> Option>> { match self { Self::Extension(e) => Some(e), _ => None, diff --git a/crates/backend/poly/src/mle/mle_group_ref.rs b/crates/backend/poly/src/mle/mle_group_ref.rs index e4a993eb0..335399168 100644 --- a/crates/backend/poly/src/mle/mle_group_ref.rs +++ b/crates/backend/poly/src/mle/mle_group_ref.rs @@ -2,6 +2,7 @@ use crate::*; use ::utils::log2_strict_usize; use field::ExtensionField; use field::PackedValue; +use zk_alloc::ArenaVec; #[derive(Debug)] pub enum MleGroupRef<'a, EF: ExtensionField>> { @@ -158,10 +159,12 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { pub fn clone_to_owned(&self) -> MleGroupOwned { match self { - Self::Base(pols) => MleGroupOwned::Base(pols.iter().map(|v| v.to_vec()).collect()), - Self::Extension(pols) => MleGroupOwned::Extension(pols.iter().map(|v| v.to_vec()).collect()), - Self::BasePacked(pols) => MleGroupOwned::BasePacked(pols.iter().map(|v| v.to_vec()).collect()), - Self::ExtensionPacked(pols) => MleGroupOwned::ExtensionPacked(pols.iter().map(|v| v.to_vec()).collect()), + Self::Base(pols) => MleGroupOwned::Base(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()), + Self::Extension(pols) => MleGroupOwned::Extension(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()), + Self::BasePacked(pols) => MleGroupOwned::BasePacked(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()), + Self::ExtensionPacked(pols) => { + MleGroupOwned::ExtensionPacked(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()) + } } } } diff --git a/crates/backend/poly/src/mle/mle_single_owned.rs b/crates/backend/poly/src/mle/mle_single_owned.rs index 7a2d99c31..538060296 100644 --- a/crates/backend/poly/src/mle/mle_single_owned.rs +++ b/crates/backend/poly/src/mle/mle_single_owned.rs @@ -1,18 +1,19 @@ use crate::{EFPacking, Mle, MleRef, MultilinearPoint, PF, PFPacking, pack_extension, packing_width, unpack_extension}; use field::PackedValue; use field::{ExtensionField, PackedFieldExtension}; +use zk_alloc::ArenaVec; #[derive(Debug, Clone)] pub enum MleOwned>> { - Base(Vec>), - Extension(Vec), - BasePacked(Vec>), - ExtensionPacked(Vec>), + Base(ArenaVec>), + Extension(ArenaVec), + BasePacked(ArenaVec>), + ExtensionPacked(ArenaVec>), } impl>> Default for MleOwned { fn default() -> Self { - Self::Base(vec![]) + Self::Base(ArenaVec::new()) } } @@ -63,35 +64,35 @@ impl>> MleOwned { } } - pub fn as_extension_packed_mut(&mut self) -> Option<&mut Vec>> { + pub fn as_extension_packed_mut(&mut self) -> Option<&mut ArenaVec>> { match self { Self::ExtensionPacked(ep) => Some(ep), _ => None, } } - pub fn into_base(self) -> Option>> { + pub fn into_base(self) -> Option>> { match self { Self::Base(b) => Some(b), _ => None, } } - pub fn into_extension(self) -> Option> { + pub fn into_extension(self) -> Option> { match self { Self::Extension(e) => Some(e), _ => None, } } - pub fn into_base_backed(self) -> Option>> { + pub fn into_base_backed(self) -> Option>> { match self { Self::BasePacked(pb) => Some(pb), _ => None, } } - pub fn into_extension_packed(self) -> Option>> { + pub fn into_extension_packed(self) -> Option>> { match self { Self::ExtensionPacked(ep) => Some(ep), _ => None, diff --git a/crates/backend/poly/src/mle/mle_single_ref.rs b/crates/backend/poly/src/mle/mle_single_ref.rs index c0307c6c4..e925d7eff 100644 --- a/crates/backend/poly/src/mle/mle_single_ref.rs +++ b/crates/backend/poly/src/mle/mle_single_ref.rs @@ -2,6 +2,7 @@ use crate::*; use ::utils::log2_strict_usize; use field::ExtensionField; use field::PackedValue; +use zk_alloc::ArenaVec; #[derive(Debug)] pub enum MleRef<'a, EF: ExtensionField>> { @@ -106,10 +107,10 @@ impl<'a, EF: ExtensionField>> MleRef<'a, EF> { pub fn clone_to_owned(&self) -> MleOwned { match self { - Self::Base(v) => MleOwned::Base(v.to_vec()), - Self::Extension(v) => MleOwned::Extension(v.to_vec()), - Self::BasePacked(pb) => MleOwned::BasePacked(pb.to_vec()), - Self::ExtensionPacked(ep) => MleOwned::ExtensionPacked(ep.to_vec()), + Self::Base(v) => MleOwned::Base(ArenaVec::from_slice(v)), + Self::Extension(v) => MleOwned::Extension(ArenaVec::from_slice(v)), + Self::BasePacked(pb) => MleOwned::BasePacked(ArenaVec::from_slice(pb)), + Self::ExtensionPacked(ep) => MleOwned::ExtensionPacked(ArenaVec::from_slice(ep)), } } diff --git a/crates/backend/poly/src/utils.rs b/crates/backend/poly/src/utils.rs index 8edbdd50a..4f94b3f59 100644 --- a/crates/backend/poly/src/utils.rs +++ b/crates/backend/poly/src/utils.rs @@ -1,15 +1,16 @@ use std::ops::{Add, Sub}; use field::*; +use zk_alloc::{ArenaVec, OwnedBuffer}; use crate::{EFPacking, PF, PFPacking}; pub const PARALLEL_THRESHOLD: usize = 1 << 9; -pub fn pack_extension>>(slice: &[EF]) -> Vec> { +/// AoS->SoA transpose of `slice` into the already-sized packed buffer `out` (`out.len()` +/// packed elements, each consuming `packing_width` scalars). +fn fill_packed_extension>>(slice: &[EF], out: &mut [EFPacking]) { let width = packing_width::(); - let n_packed = slice.len() / width; - let mut out: Vec> = unsafe { uninitialized_vec(n_packed) }; let write = |slot: &mut EFPacking, chunk: &[EF]| { *slot = EFPacking::::from_ext_slice(chunk); }; @@ -18,17 +19,21 @@ pub fn pack_extension>>(slice: &[EF]) -> Vec>>(vec: &[EFPacking]) -> Vec { +pub fn pack_extension>, B: OwnedBuffer>>(slice: &[EF]) -> B { + B::build(slice.len() / packing_width::(), |out| { + fill_packed_extension(slice, out) + }) +} + +fn fill_unpacked_extension>>(vec: &[EFPacking], out: &mut [EF]) { let width = packing_width::(); - let total = vec.len() * width; - let mut out: Vec = unsafe { uninitialized_vec(total) }; + let total = out.len(); let write = |out_chunk: &mut [EF], x: &EFPacking| { let packed_coeffs = x.as_basis_coefficients_slice(); for (lane, slot) in out_chunk.iter_mut().enumerate() { @@ -43,13 +48,18 @@ pub fn unpack_extension>>(vec: &[EFPacking]) -> Ve // One pool task per group of `group` packed elements, each writing `group * width` // contiguous output scalars from a disjoint slice of `vec`. let group = parallel::recommended_chunk_size(vec.len()); - parallel::par_chunks_mut(&mut out, group * width, |ci, out_chunk| { + parallel::par_chunks_mut(out, group * width, |ci, out_chunk| { for (k, sub) in out_chunk.chunks_exact_mut(width).enumerate() { write(sub, &vec[ci * group + k]); } }); } - out +} + +pub fn unpack_extension>, B: OwnedBuffer>(vec: &[EFPacking]) -> B { + B::build(vec.len() * packing_width::(), |out| { + fill_unpacked_extension(vec, out) + }) } pub const fn packing_log_width() -> usize { @@ -65,50 +75,55 @@ pub const fn must_unpack_multilinears(n_vars: usize) -> bool { } #[inline] -fn fold_fill OF + Sync>(len: usize, seq: bool, compute: C) -> Vec { - let mut res = unsafe { uninitialized_vec(len) }; - if seq || len < PARALLEL_THRESHOLD { +fn fill_fold OF + Sync>(res: &mut [OF], seq: bool, compute: C) { + if seq || res.len() < PARALLEL_THRESHOLD { for (i, r) in res.iter_mut().enumerate() { *r = compute(i); } } else { - parallel::par_fill(&mut res, &compute); + parallel::par_fill(res, &compute); } - res } -fn fold_multilinear_lsb< +#[inline] +fn fold_fill, C: Fn(usize) -> OF + Sync>(len: usize, seq: bool, compute: C) -> B { + B::build(len, |res| fill_fold(res, seq, compute)) +} + +pub fn fold_multilinear< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, OF: Copy + Add + Send + Sync, - Mul: Fn(IF, EF) -> OF + Sync + Send, + F: Fn(IF, EF) -> OF + Sync + Send, + B: OwnedBuffer, >( m: &[IF], alpha: EF, - mul_if_of: &Mul, + mul_if_of: &F, seq: bool, -) -> Vec { - fold_fill(m.len() / 2, seq, |j| { - mul_if_of(m[2 * j + 1] - m[2 * j], alpha) + m[2 * j] - }) +) -> B { + let new_size = m.len() / 2; + fold_fill(new_size, seq, |i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) } -/// Fold `m` at variable `bit`. `seq` forces sequential execution (see [`fold_fill`]). pub fn fold_multilinear_at_bit< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, OF: Copy + Add + Send + Sync, - Mul: Fn(IF, EF) -> OF + Sync + Send, + F: Fn(IF, EF) -> OF + Sync + Send, + B: OwnedBuffer, >( m: &[IF], alpha: EF, bit: usize, - mul_if_of: &Mul, + mul_if_of: &F, seq: bool, -) -> Vec { +) -> B { assert!(m.len() >= 2 * (1 << bit), "bit out of range for slice length"); if bit == 0 { - return fold_multilinear_lsb(m, alpha, mul_if_of, seq); + return fold_fill(m.len() / 2, seq, |j| { + mul_if_of(m[2 * j + 1] - m[2 * j], alpha) + m[2 * j] + }); } let stride = 1usize << bit; let lo_mask = stride - 1; @@ -121,22 +136,6 @@ pub fn fold_multilinear_at_bit< }) } -/// Fold `m` at its top variable. `seq` forces sequential execution (see [`fold_fill`]). -pub fn fold_multilinear< - EF: PrimeCharacteristicRing + Copy + Send + Sync, - IF: Copy + Sub + Send + Sync, - OF: Copy + Add + Send + Sync, - F: Fn(IF, EF) -> OF + Sync + Send, ->( - m: &[IF], - alpha: EF, - mul_if_of: &F, - seq: bool, -) -> Vec { - let new_size = m.len() / 2; - fold_fill(new_size, seq, |i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) -} - pub fn batch_fold_multilinears< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, @@ -146,7 +145,7 @@ pub fn batch_fold_multilinears< polys: &[&[IF]], alpha: EF, mul_if_of: F, -) -> Vec> { +) -> Vec> { let total_size: usize = polys.iter().map(|p| p.len()).sum(); if total_size < PARALLEL_THRESHOLD { polys @@ -168,7 +167,7 @@ pub fn batch_fold_multilinears_at_bit< alpha: EF, bit: usize, mul_if_of: F, -) -> Vec> { +) -> Vec> { let total_size: usize = polys.iter().map(|p| p.len()).sum(); if total_size < PARALLEL_THRESHOLD { polys @@ -335,9 +334,9 @@ mod bench_tests { for &log_n in &LOG_SIZES { let n = 1usize << log_n; let ext_vec: Vec = (0..n).map(|_| rng.random()).collect(); - let packed = pack_extension(&ext_vec); - let _ = unpack_extension::(&packed); // warmup - let (avg, min_t, max_t) = measure(|| unpack_extension::(&packed)); + let packed: Vec<_> = pack_extension(&ext_vec); + let _ = unpack_extension::>(&packed); // warmup + let (avg, min_t, max_t) = measure(|| unpack_extension::>(&packed)); print_row(log_n, n, avg, min_t, max_t); } } @@ -349,8 +348,8 @@ mod bench_tests { for &log_n in &LOG_SIZES { let n = 1usize << log_n; let ext_vec: Vec = (0..n).map(|_| rng.random()).collect(); - let _ = pack_extension::(&ext_vec); // warmup - let (avg, min_t, max_t) = measure(|| pack_extension::(&ext_vec)); + let _ = pack_extension::>(&ext_vec); // warmup + let (avg, min_t, max_t) = measure(|| pack_extension::>(&ext_vec)); print_row(log_n, n, avg, min_t, max_t); } } diff --git a/crates/backend/src/lib.rs b/crates/backend/src/lib.rs index f4cc2d18f..fea3c62ee 100644 --- a/crates/backend/src/lib.rs +++ b/crates/backend/src/lib.rs @@ -8,3 +8,4 @@ pub use sumcheck::*; pub use symetric::*; pub use utils::*; pub use whir::*; +pub use zk_alloc::*; diff --git a/crates/backend/sumcheck/Cargo.toml b/crates/backend/sumcheck/Cargo.toml index 1d5f486ca..8d7188cf6 100644 --- a/crates/backend/sumcheck/Cargo.toml +++ b/crates/backend/sumcheck/Cargo.toml @@ -7,6 +7,7 @@ edition.workspace = true field = { path = "../field", package = "mt-field" } air = { path = "../air", package = "mt-air" } poly = { path = "../poly", package = "mt-poly" } +zk-alloc.workspace = true fiat-shamir = { path = "../fiat-shamir", package = "mt-fiat-shamir" } parallel.workspace = true tracing.workspace = true diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index 343141c4f..c069e7519 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -2,6 +2,7 @@ use fiat_shamir::*; use field::*; use poly::*; use tracing::instrument; +use zk_alloc::ArenaVec; use crate::{SumcheckComputation, sumcheck_prove_many_rounds}; @@ -171,14 +172,14 @@ pub fn fold_and_compute_product_sumcheck_polynomial< prev_folding_factor: EF, sum: EF, decompose: impl Fn(EFPacking) -> Vec, -) -> (DensePolynomial, Vec>) { +) -> (DensePolynomial, Vec>) { let n = pol_0.len(); assert_eq!(n, pol_1.len()); assert!(n.is_power_of_two()); let prev_folding_factor_packed = EFPacking::from(prev_folding_factor); - let mut pol_0_folded = unsafe { uninitialized_vec::(n / 2) }; - let mut pol_1_folded = unsafe { uninitialized_vec::(n / 2) }; + let mut pol_0_folded = unsafe { ArenaVec::::uninitialized(n / 2) }; + let mut pol_1_folded = unsafe { ArenaVec::::uninitialized(n / 2) }; #[allow(clippy::type_complexity)] let process_element = |(p0_prev, p0_f): (((&F, &F), (&F, &F)), (&mut EFPacking, &mut EFPacking)), diff --git a/crates/backend/sumcheck/src/sc_computation.rs b/crates/backend/sumcheck/src/sc_computation.rs index 344a27015..57b3ffd9f 100644 --- a/crates/backend/sumcheck/src/sc_computation.rs +++ b/crates/backend/sumcheck/src/sc_computation.rs @@ -4,6 +4,7 @@ use field::*; use poly::*; use std::any::TypeId; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub}; +use zk_alloc::ArenaVec; fn add_assign_vec(mut a: Vec, b: Vec) -> Vec { for (x, y) in a.iter_mut().zip(b) { @@ -456,7 +457,7 @@ fn sumcheck_fold_and_compute_core( fold_f: impl Fn(&[IF], usize) -> FT + Sync + Send, eval_fn: impl Fn(&SC, &[FT], &SC::ExtraData) -> FT + Sync + Send, unpack_sum: impl Fn(FT) -> EF, - wrap_f: impl FnOnce(Vec>) -> MleGroupOwned, + wrap_f: impl FnOnce(Vec>) -> MleGroupOwned, ) -> (Vec, MleGroupOwned) where EF: ExtensionField>, @@ -466,8 +467,8 @@ where { let prev_folded_size = 2 * compute_fold_size; - let folded_f: Vec> = (0..multilinears.len()) - .map(|_| FT::zero_vec(prev_folded_size)) + let folded_f: Vec> = (0..multilinears.len()) + .map(|_| unsafe { ArenaVec::::zeroed(prev_folded_size) }) .collect(); let n_mult = multilinears.len(); @@ -610,7 +611,7 @@ fn sumcheck_fold_and_compute_with_split_eq( fold_f: impl Fn(&[IF], usize) -> EFPacking + Sync + Send, eval_fn: impl Fn(&SC, &[EFPacking], &SC::ExtraData) -> EFPacking + Sync + Send, unpack_sum: impl Fn(EFPacking) -> EF, - wrap_f: impl FnOnce(Vec>>) -> MleGroupOwned, + wrap_f: impl FnOnce(Vec>>) -> MleGroupOwned, ) -> (Vec, MleGroupOwned) where EF: ExtensionField>, @@ -618,8 +619,8 @@ where SC: SumcheckComputation, { let prev_folded_size = 2 * compute_fold_size; - let folded_f: Vec>> = (0..multilinears.len()) - .map(|_| EFPacking::::zero_vec(prev_folded_size)) + let folded_f: Vec>> = (0..multilinears.len()) + .map(|_| unsafe { ArenaVec::>::zeroed(prev_folded_size) }) .collect(); let n_lo = split_eq.n_lo(); diff --git a/crates/backend/sumcheck/src/split_eq.rs b/crates/backend/sumcheck/src/split_eq.rs index a5ced294c..16a8908b8 100644 --- a/crates/backend/sumcheck/src/split_eq.rs +++ b/crates/backend/sumcheck/src/split_eq.rs @@ -1,13 +1,14 @@ use field::{ExtensionField, PackedFieldExtension}; use poly::*; +use zk_alloc::ArenaVec; #[derive(Debug)] pub struct SplitEq>> { - pub eq_lo: Vec, - pub eq_hi_packed: Vec>, + pub eq_lo: ArenaVec, + pub eq_hi_packed: ArenaVec>, pub log_packed_hi: u32, // = log2(eq_hi_packed.len()), cached for bit-shift in get_packed /// Unpacked remainder for when the packed table is empty or exhausted. - pub remainder: Vec, + pub remainder: ArenaVec, } impl>> SplitEq { @@ -16,8 +17,8 @@ impl>> SplitEq { if must_unpack_multilinears::(n + 1) { return Self { - eq_lo: vec![EF::ONE], - eq_hi_packed: Vec::new(), + eq_lo: ArenaVec::filled(EF::ONE, 1), + eq_hi_packed: ArenaVec::new(), log_packed_hi: 0, remainder: eval_eq(eq_point), }; @@ -32,7 +33,7 @@ impl>> SplitEq { eq_lo, eq_hi_packed, log_packed_hi, - remainder: Vec::new(), + remainder: ArenaVec::new(), } } @@ -53,7 +54,7 @@ impl>> SplitEq { self.log_packed_hi = new_len.trailing_zeros(); } else { // eq_hi_packed has 0 or 1 element — unpack to remainder and halve - let mut unpacked: Vec = EFPacking::::to_ext_iter(self.eq_hi_packed.iter().copied()).collect(); + let mut unpacked: ArenaVec = EFPacking::::to_ext_iter(self.eq_hi_packed.iter().copied()).collect(); let scale = self.eq_lo[0]; for v in &mut unpacked { *v *= scale; diff --git a/crates/backend/utils/src/lib.rs b/crates/backend/utils/src/lib.rs index d850a9494..38a651123 100644 --- a/crates/backend/utils/src/lib.rs +++ b/crates/backend/utils/src/lib.rs @@ -129,7 +129,6 @@ pub const fn indices_arr() -> [usize; N] { /// reallocations. /// /// # Safety -/// /// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`. #[inline] pub unsafe fn flatten_to_base(vec: Vec) -> Vec { @@ -137,27 +136,21 @@ pub unsafe fn flatten_to_base(vec: Vec) -> Vec assert!(align_of::() == align_of::()); assert!(size_of::().is_multiple_of(size_of::())); } - let d = size_of::() / size_of::(); - let mut values = std::mem::ManuallyDrop::new(vec); - let new_len = values.len() * d; - let new_cap = values.capacity() * d; - let ptr = values.as_mut_ptr() as *mut Base; - unsafe { Vec::from_raw_parts(ptr, new_len, new_cap) } + let mut me = mem::ManuallyDrop::new(vec); + unsafe { Vec::from_raw_parts(me.as_mut_ptr().cast::(), me.len() * d, me.capacity() * d) } } /// Convert a vector of `Base` elements to a vector of `BaseArray` elements. /// /// # Safety -/// /// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`. #[inline] -pub unsafe fn reconstitute_from_base(mut vec: Vec) -> Vec { +pub unsafe fn reconstitute_from_base(vec: Vec) -> Vec { const { assert!(align_of::() == align_of::()); assert!(size_of::().is_multiple_of(size_of::())); } - let d = size_of::() / size_of::(); assert!( vec.len().is_multiple_of(d), @@ -166,17 +159,13 @@ pub unsafe fn reconstitute_from_base(mut vec: Vec) d ); let new_len = vec.len() / d; - let cap = vec.capacity(); - - if cap.is_multiple_of(d) { - let mut values = std::mem::ManuallyDrop::new(vec); - let new_cap = cap / d; - let ptr = values.as_mut_ptr() as *mut BaseArray; - unsafe { Vec::from_raw_parts(ptr, new_len, new_cap) } + if vec.capacity().is_multiple_of(d) { + let mut me = mem::ManuallyDrop::new(vec); + unsafe { Vec::from_raw_parts(me.as_mut_ptr().cast::(), new_len, me.capacity() / d) } } else { - let buf_ptr = vec.as_mut_ptr().cast::(); - let slice_ref = unsafe { slice::from_raw_parts(buf_ptr, new_len) }; - slice_ref.to_vec() + // Capacity isn't a clean multiple: copy into a fresh buffer. + let buf_ptr = vec.as_ptr().cast::(); + unsafe { slice::from_raw_parts(buf_ptr, new_len) }.to_vec() } } diff --git a/crates/backend/utils/src/misc.rs b/crates/backend/utils/src/misc.rs index 8c2bafd4d..c33d139bf 100644 --- a/crates/backend/utils/src/misc.rs +++ b/crates/backend/utils/src/misc.rs @@ -5,8 +5,10 @@ pub fn from_end(slice: &[A], n: usize) -> &[A] { &slice[slice.len() - n..] } -pub fn transposed_par_for_each_mut(array: &mut [Vec; N], g: G) +pub fn transposed_par_for_each_mut(array: &mut [V; N], g: G) where + A: Send + Sync, + V: std::ops::DerefMut, G: Fn(usize, [&mut A; N]) + Sync, { // all vectors must have the same length diff --git a/crates/backend/zk-alloc/Cargo.toml b/crates/backend/zk-alloc/Cargo.toml index 0c4ab6a5f..85c078050 100644 --- a/crates/backend/zk-alloc/Cargo.toml +++ b/crates/backend/zk-alloc/Cargo.toml @@ -6,8 +6,7 @@ description = "Bump+reset arena allocator for ZK proving workloads" [dependencies] system-info.workspace = true - -[target.'cfg(not(all(target_os = "linux", target_arch = "x86_64")))'.dependencies] +parallel.workspace = true libc = "0.2" [lints] diff --git a/crates/backend/zk-alloc/src/arena_cow.rs b/crates/backend/zk-alloc/src/arena_cow.rs new file mode 100644 index 000000000..f35e3e845 --- /dev/null +++ b/crates/backend/zk-alloc/src/arena_cow.rs @@ -0,0 +1,73 @@ +use std::ops::Deref; + +use crate::ArenaVec; + +#[derive(Debug)] +pub enum ArenaCow<'a, T> { + Borrowed(&'a [T]), + Owned(ArenaVec), +} + +impl ArenaCow<'_, T> { + #[inline] + #[must_use] + pub fn as_slice(&self) -> &[T] { + match self { + Self::Borrowed(s) => s, + Self::Owned(v) => v, + } + } + + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.as_slice().is_empty() + } +} + +impl ArenaCow<'_, T> { + /// Take ownership of the buffer, copying into the arena only if currently borrowed. + #[inline] + #[must_use] + pub fn into_owned(self) -> ArenaVec { + match self { + Self::Borrowed(s) => ArenaVec::from_slice(s), + Self::Owned(v) => v, + } + } +} + +impl<'a, T> From<&'a [T]> for ArenaCow<'a, T> { + #[inline] + fn from(s: &'a [T]) -> Self { + Self::Borrowed(s) + } +} + +impl From> for ArenaCow<'_, T> { + #[inline] + fn from(v: ArenaVec) -> Self { + Self::Owned(v) + } +} + +impl Deref for ArenaCow<'_, T> { + type Target = [T]; + #[inline] + fn deref(&self) -> &[T] { + self.as_slice() + } +} + +impl AsRef<[T]> for ArenaCow<'_, T> { + #[inline] + fn as_ref(&self) -> &[T] { + self.as_slice() + } +} diff --git a/crates/backend/zk-alloc/src/arena_vec.rs b/crates/backend/zk-alloc/src/arena_vec.rs new file mode 100644 index 000000000..df5ae7db9 --- /dev/null +++ b/crates/backend/zk-alloc/src/arena_vec.rs @@ -0,0 +1,464 @@ +//! [`ArenaVec`] — a minimal owning vector backed by the proving arena. +//! +//! Allocation goes through [`raw_alloc`](crate::raw_alloc) (arena bump in a phase, else system) and +//! `Drop`/growth through [`raw_dealloc`](crate::raw_dealloc), which picks arena-vs-system by pointer +//! range — the dynamic choice that lets `ArenaVec` carry no allocator type parameter. An `ArenaVec` +//! allocated in a phase is invalidated by the next [`begin_phase`](crate::begin_phase); anything +//! that must outlive a phase uses the system allocator (a plain `Vec`, or an `ArenaVec` built outside a +//! phase). + +use std::alloc::handle_alloc_error; +use std::cmp; +use std::fmt; +use std::marker::PhantomData; +use std::mem::{ManuallyDrop, align_of, size_of}; +use std::ops::{Deref, DerefMut}; +use std::ptr::{self, NonNull}; +use std::slice; + +use crate::{raw_alloc, raw_dealloc}; + +/// Owning, growable buffer allocated from the proving arena (see the module docs). +pub struct ArenaVec { + /// Always aligned and non-null; dangling (and never dereferenced for reads) while `cap == 0`. + ptr: NonNull, + len: usize, + /// Element capacity. For zero-sized `T` this is fixed at `usize::MAX` and no memory is owned. + cap: usize, + _marker: PhantomData, +} + +unsafe impl Send for ArenaVec {} +unsafe impl Sync for ArenaVec {} + +pub trait OwnedBuffer: DerefMut + Sized { + /// `len` uninitialized elements. + /// + /// # Safety + /// Every element must be written before it is read. + unsafe fn uninit(len: usize) -> Self; + + /// `len` elements, initialized in place by `fill` — which **must** write all of them. + #[inline] + fn build(len: usize, fill: impl FnOnce(&mut [T])) -> Self { + // SAFETY: `fill` writes every one of the `len` elements before any is read. + let mut buf = unsafe { Self::uninit(len) }; + fill(&mut buf); + buf + } +} + +impl OwnedBuffer for Vec { + #[inline] + #[allow(clippy::uninit_vec)] + unsafe fn uninit(len: usize) -> Self { + let mut v = Vec::with_capacity(len); + // SAFETY: the `uninit`/`build` contract requires all `len` slots written before read. + unsafe { v.set_len(len) }; + v + } +} + +impl OwnedBuffer for ArenaVec { + #[inline] + unsafe fn uninit(len: usize) -> Self { + // SAFETY: as above. + unsafe { Self::uninitialized(len) } + } +} + +impl ArenaVec { + /// `usize::MAX` capacity stands in for "unbounded" for zero-sized elements (which never + /// allocate); `0` otherwise. + const EMPTY_CAP: usize = if size_of::() == 0 { usize::MAX } else { 0 }; + + /// A new, empty vector. No allocation. + #[inline] + #[must_use] + pub const fn new() -> Self { + Self { + ptr: NonNull::dangling(), + len: 0, + cap: Self::EMPTY_CAP, + _marker: PhantomData, + } + } + + /// A new, empty vector with room for `cap` elements pre-reserved (exact, no over-allocation). + #[inline] + #[must_use] + pub fn with_capacity(cap: usize) -> Self { + let mut v = Self::new(); + if size_of::() != 0 && cap != 0 { + v.realloc_to(cap); + } + v + } + + /// Arena-backed `vec![value; n]`. + #[inline] + #[must_use] + pub fn filled(value: T, n: usize) -> Self + where + T: Clone, + { + let mut v = Self::with_capacity(n); + v.resize(n, value); + v + } + + /// Arena-backed zero-initialized buffer of length `n`, zeroed with a single `write_bytes` + /// (`memset`) — far cheaper than [`filled`](Self::filled)'s element-wise clone loop. + /// + /// # Safety + /// `T`'s all-zero bit pattern must be a valid, fully-initialized value of `T` (true for the + /// Montgomery field types and their SIMD packings, whose `ZERO` is all-zero bytes). + #[inline] + #[must_use] + pub unsafe fn zeroed(n: usize) -> Self { + // SAFETY: every slot is initialized by the `write_bytes` below before it can be read. + let mut v = unsafe { Self::uninitialized(n) }; + // SAFETY: `v` owns `n` allocated slots; caller guarantees all-zero is a valid `T`. + unsafe { ptr::write_bytes(v.as_mut_ptr(), 0u8, n) }; + v + } + + /// Arena-backed `slice.to_vec()`. + #[inline] + #[must_use] + pub fn from_slice(slice: &[T]) -> Self + where + T: Clone, + { + let mut v = Self::with_capacity(slice.len()); + v.extend_from_slice(slice); + v + } + + /// `len` uninitialized slots. + /// + /// # Safety + /// Every element must be overwritten before it is read. + #[inline] + #[must_use] + pub unsafe fn uninitialized(len: usize) -> Self { + let mut v = Self::with_capacity(len); + // SAFETY: caller guarantees all `len` slots are written before being read. + unsafe { v.set_len(len) }; + v + } + + /// Arena-backed parallel `(0..n).map(f).collect()`: fill a vector of length `n` in parallel. + /// The single allocation happens on the calling thread; workers write disjoint slots. + #[inline] + #[must_use] + pub fn par_collect T + Sync>(n: usize, f: F) -> Self + where + T: Send, + { + // SAFETY: `par_fill` writes every slot in `0..n` exactly once before any is read. + let mut v = unsafe { Self::uninitialized(n) }; + parallel::par_fill(&mut v, f); + v + } + + #[inline] + #[must_use] + pub const fn len(&self) -> usize { + self.len + } + + #[inline] + #[must_use] + pub const fn capacity(&self) -> usize { + self.cap + } + + #[inline] + #[must_use] + pub const fn is_empty(&self) -> bool { + self.len == 0 + } + + #[inline] + #[must_use] + pub const fn as_ptr(&self) -> *const T { + self.ptr.as_ptr() + } + + #[inline] + pub const fn as_mut_ptr(&mut self) -> *mut T { + self.ptr.as_ptr() + } + + #[inline] + #[must_use] + pub fn as_slice(&self) -> &[T] { + self + } + + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + self + } + + /// Set the length without touching the buffer. + /// + /// # Safety + /// `new_len <= capacity()` and every element in `0..new_len` must be initialized. + #[inline] + pub unsafe fn set_len(&mut self, new_len: usize) { + debug_assert!(new_len <= self.cap); + self.len = new_len; + } + + /// Reserve space for at least `additional` more elements (amortized doubling). + #[inline] + pub fn reserve(&mut self, additional: usize) { + if size_of::() == 0 { + return; // capacity is conceptually unbounded for ZSTs + } + let required = self.len.checked_add(additional).expect("ArenaVec capacity overflow"); + if required > self.cap { + let new_cap = cmp::max(required, self.cap.saturating_mul(2)); + self.realloc_to(new_cap); + } + } + + #[inline] + pub fn push(&mut self, value: T) { + if self.len == self.cap { + // ZSTs never reach here (cap == usize::MAX); only sized types grow. + let new_cap = cmp::max(self.cap.saturating_mul(2), 4); + self.realloc_to(new_cap); + } + // SAFETY: `len < cap` now, so slot `len` is allocated and uninitialized. + unsafe { self.ptr.as_ptr().add(self.len).write(value) }; + self.len += 1; + } + + /// Append a clone of every element of `other`. + #[inline] + pub fn extend_from_slice(&mut self, other: &[T]) + where + T: Clone, + { + self.reserve(other.len()); + // Bump `len` per element so a panic mid-clone leaves a consistent vector (written clones drop). + for x in other { + // SAFETY: `reserve` guaranteed room for `other.len()` more; `len` stays < `cap`. + unsafe { self.ptr.as_ptr().add(self.len).write(x.clone()) }; + self.len += 1; + } + } + + /// Grow or shrink to `new_len`, filling new slots with clones of `value`. + pub fn resize(&mut self, new_len: usize, value: T) + where + T: Clone, + { + if new_len > self.len { + self.reserve(new_len - self.len); + while self.len < new_len { + // SAFETY: room reserved above; `len < new_len <= cap`. + unsafe { self.ptr.as_ptr().add(self.len).write(value.clone()) }; + self.len += 1; + } + } else { + self.truncate(new_len); + } + } + + /// Drop the elements past `len`, keeping capacity. + pub fn truncate(&mut self, len: usize) { + if len < self.len { + let drop_count = self.len - len; + // Shorten first so a panicking `Drop` can't observe/double-drop the tail. + self.len = len; + // SAFETY: `[len, old_len)` were initialized and are now logically removed. + unsafe { + ptr::drop_in_place(ptr::slice_from_raw_parts_mut(self.ptr.as_ptr().add(len), drop_count)); + } + } + } + + #[inline] + pub fn clear(&mut self) { + self.truncate(0); + } + + /// Decompose into raw parts, leaking the buffer. Inverse of [`from_raw_parts`](Self::from_raw_parts). + #[inline] + #[must_use] + pub fn into_raw_parts(self) -> (*mut T, usize, usize) { + let me = ManuallyDrop::new(self); + (me.ptr.as_ptr(), me.len, me.cap) + } + + /// Reconstruct from parts previously obtained via [`into_raw_parts`](Self::into_raw_parts) + /// (or a layout-compatible reinterpret thereof). + /// + /// # Safety + /// `ptr` is non-null and aligned for `T`; `len <= cap`; and `ptr` either was returned by + /// [`raw_alloc`](crate::raw_alloc) for `cap * size_of::()` bytes at `align_of::()`, or + /// `cap == 0` and `ptr` is dangling-but-aligned. Exactly one `ArenaVec` may own a given pointer. + #[inline] + #[must_use] + pub unsafe fn from_raw_parts(ptr: *mut T, len: usize, cap: usize) -> Self { + Self { + // SAFETY: caller guarantees `ptr` is non-null. + ptr: unsafe { NonNull::new_unchecked(ptr) }, + len, + cap, + _marker: PhantomData, + } + } + + /// Allocate a fresh `new_cap`-element buffer, move the `len` live elements into it, and free + /// the old one. Only called for sized `T` with `new_cap >= len` and `new_cap > 0`. + fn realloc_to(&mut self, new_cap: usize) { + debug_assert!(size_of::() != 0 && new_cap >= self.len && new_cap > 0); + let align = align_of::(); + let new_bytes = new_cap.checked_mul(size_of::()).expect("ArenaVec capacity overflow"); + assert!(new_bytes <= isize::MAX as usize, "ArenaVec capacity overflow"); + + // SAFETY: `align` is a valid power of two; `new_bytes > 0`. + let raw = unsafe { raw_alloc(new_bytes, align) }.cast::(); + let Some(new_ptr) = NonNull::new(raw) else { + // Matches `Vec`: an allocation failure aborts rather than unwinds. + handle_alloc_error(unsafe { std::alloc::Layout::from_size_align_unchecked(new_bytes, align) }); + }; + + if self.cap != 0 { + // SAFETY: the two buffers are distinct; `len <= old cap` initialized elements move. + unsafe { ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr.as_ptr(), self.len) }; + // SAFETY: old buffer came from `raw_alloc` with this size/align (range-checked free). + unsafe { raw_dealloc(self.ptr.as_ptr().cast::(), self.cap * size_of::(), align) }; + } + self.ptr = new_ptr; + self.cap = new_cap; + } +} + +impl Drop for ArenaVec { + fn drop(&mut self) { + // Drop the live elements first (no-op for `Copy`/trivial types; the compiler elides it). + if std::mem::needs_drop::() { + // SAFETY: `0..len` are initialized. + unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(self.ptr.as_ptr(), self.len)) }; + } + // Free the buffer. ZSTs and never-allocated vectors own nothing. + if size_of::() != 0 && self.cap != 0 { + // SAFETY: buffer came from `raw_alloc(cap * size, align)`; `raw_dealloc` range-checks + // arena-vs-system. Arena pointers free as a no-op (reclaimed at the next phase reset). + unsafe { + raw_dealloc( + self.ptr.as_ptr().cast::(), + self.cap * size_of::(), + align_of::(), + ) + }; + } + } +} + +impl Deref for ArenaVec { + type Target = [T]; + #[inline] + fn deref(&self) -> &[T] { + // SAFETY: `ptr` is aligned and `0..len` are initialized (valid for ZSTs too: a dangling + // aligned pointer is a valid base for a zero-byte-stride slice). + unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) } + } +} + +impl DerefMut for ArenaVec { + #[inline] + fn deref_mut(&mut self) -> &mut [T] { + // SAFETY: as `deref`, with unique access. + unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) } + } +} + +impl AsRef<[T]> for ArenaVec { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +impl AsMut<[T]> for ArenaVec { + #[inline] + fn as_mut(&mut self) -> &mut [T] { + self + } +} + +impl Default for ArenaVec { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Clone for ArenaVec { + fn clone(&self) -> Self { + let mut out = Self::with_capacity(self.len); + out.extend_from_slice(self); + out + } +} + +impl fmt::Debug for ArenaVec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl PartialEq for ArenaVec { + #[inline] + fn eq(&self, other: &Self) -> bool { + **self == **other + } +} + +impl Eq for ArenaVec {} + +impl Extend for ArenaVec { + #[inline] + fn extend>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + for x in iter { + self.push(x); + } + } +} + +impl FromIterator for ArenaVec { + #[inline] + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let mut v = Self::with_capacity(iter.size_hint().0); + v.extend(iter); + v + } +} + +impl<'a, T> IntoIterator for &'a ArenaVec { + type Item = &'a T; + type IntoIter = slice::Iter<'a, T>; + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T> IntoIterator for &'a mut ArenaVec { + type Item = &'a mut T; + type IntoIter = slice::IterMut<'a, T>; + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} diff --git a/crates/backend/zk-alloc/src/lib.rs b/crates/backend/zk-alloc/src/lib.rs index a3ff9d1df..a8500b730 100644 --- a/crates/backend/zk-alloc/src/lib.rs +++ b/crates/backend/zk-alloc/src/lib.rs @@ -1,109 +1,106 @@ -//! Bump-pointer arena allocator. -//! -//! One mmap region split into per-thread slabs. Allocation = increment a thread-local -//! pointer; free = no-op. `begin_phase()` resets the arena: each thread's next -//! allocation starts over at the beginning of its slab, overwriting the previous -//! phase's data. Allocations that don't fit (too large, or beyond `MAX_THREADS`) fall -//! back to the system allocator. -//! -//! ```ignore -//! loop { -//! begin_phase(); // arena ON; slabs reset lazily -//! let res = heavy_work(); // fast increments -//! end_phase(); // arena OFF; new allocations go to System -//! let copy = res.clone(); // detach from arena before next phase resets it -//! } -//! ``` +//! Bump-pointer arena, used explicitly (never as a `#[global_allocator]`). One mmap region split +//! into per-thread slabs: alloc bumps a thread-local pointer, free is a no-op, `begin_phase()` +//! resets every slab. Proof data lives in [`ArenaVec`]; `raw_dealloc` picks arena-vs-system by +//! pointer range, so `ArenaVec` carries no allocator parameter. use std::alloc::{GlobalAlloc, Layout}; use std::cell::Cell; -use std::sync::Once; +use std::sync::OnceLock; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use system_info::NUM_THREADS; +mod arena_cow; +mod arena_vec; mod syscall; -const SLAB_SIZE: usize = 8 << 30; // 8GB -const SLACK: usize = 4; // SLACK absorbs the main thread and any non-rayon helpers. -const MAX_THREADS: usize = NUM_THREADS + SLACK; -const REGION_SIZE: usize = SLAB_SIZE * MAX_THREADS; +pub use arena_cow::ArenaCow; +pub use arena_vec::{ArenaVec, OwnedBuffer}; -#[derive(Debug)] -pub struct ZkAllocator; +const SLAB_SIZE: usize = 8 << 30; // 8 GiB; per-thread soft cap, overflow falls back to System +const SLACK: usize = 4; // extra slabs for non-pool threads that allocate in a phase +const MAX_THREADS: usize = NUM_THREADS + SLACK; +const REGION_SIZE: usize = SLAB_SIZE * MAX_THREADS; // one contiguous region => O(1) pointer classification -/// Incremented by `begin_phase()`. Every thread caches the last value it saw in -/// `ARENA_GEN`; when they differ, the thread resets its allocation cursor to the start -/// of its slab on the next allocation. This is how a single store on the main thread -/// "resets" every other thread's slab without any cross-thread synchronization. +/// Bumped by `begin_phase()`; a thread resets its slab when its cached `ARENA_GEN` lags — one store +/// resets every thread, lock-free. static GENERATION: AtomicUsize = AtomicUsize::new(0); - -/// Master switch for the arena. `true` (set by `begin_phase`) routes allocations -/// through the arena; `false` (set by `end_phase`) routes them to the system allocator. +/// Arena on (route to arena) vs off (route to System). static ARENA_ACTIVE: AtomicBool = AtomicBool::new(false); - -/// Base address of the mmap'd region, or `0` before `ensure_region` runs. Read on -/// every `dealloc` to test whether a pointer belongs to us. -static REGION_BASE: AtomicUsize = AtomicUsize::new(0); - -/// Synchronizes the one-time mmap so concurrent first-allocators don't race. -static REGION_INIT: Once = Once::new(); - -/// Monotonic counter handed out to threads to pick their slab. `fetch_add`'d once per -/// thread on its first arena allocation. Threads that get `idx >= MAX_THREADS` mark -/// themselves `ARENA_NO_SLAB` and permanently fall through to the system allocator. +/// Process-wide opt-in; gates `begin_phase`'s all-thread reset so a stray call can't corrupt another +/// proving's buffers. Until [`enable_arena`], phases are no-ops and `ArenaVec` uses System. +static ARENA_ENGAGED: AtomicBool = AtomicBool::new(false); +/// mmap'd region base, mapped once; also the arena-vs-system discriminator in `raw_dealloc`. +static REGION: OnceLock = OnceLock::new(); +/// Slab index handed out once per thread; `idx >= MAX_THREADS` falls back to System. static THREAD_IDX: AtomicUsize = AtomicUsize::new(0); thread_local! { - /// Where this thread's next allocation lands. Advanced past each allocation. + /// This thread's next allocation address. static ARENA_PTR: Cell = const { Cell::new(0) }; - /// One past the last byte of this thread's slab. An alloc fits iff - /// `aligned + size <= ARENA_END`. + /// One past this thread's slab. static ARENA_END: Cell = const { Cell::new(0) }; - /// Base address of this thread's slab (`0` = not yet claimed). On reset, - /// `ARENA_PTR` is set back to this value. + /// This thread's slab base (`0` = unclaimed); the reset target. static ARENA_BASE: Cell = const { Cell::new(0) }; - /// Last `GENERATION` value this thread observed. When the global moves past - /// this, the next allocation resets `ARENA_PTR` to `ARENA_BASE` and updates - /// this field. + /// Last `GENERATION` seen; a mismatch triggers a slab reset. static ARENA_GEN: Cell = const { Cell::new(0) }; - /// `true` if this thread was created after `MAX_THREADS` was already exhausted. - /// Such threads skip arena logic entirely and always go to the system allocator. + /// Thread got no slab (`idx >= MAX_THREADS`) — always uses System. static ARENA_NO_SLAB: Cell = const { Cell::new(false) }; } -/// Returns the base address of the mmap'd region, mapping it on the first call. fn ensure_region() -> usize { - REGION_INIT.call_once(|| { - // SAFETY: mmap_anonymous returns a page-aligned pointer or null. MAP_NORESERVE - // means no physical memory is committed until pages are touched. + *REGION.get_or_init(|| { + // SAFETY: mmap returns a page-aligned pointer or null; lazily backed. let ptr = unsafe { syscall::mmap_anonymous(REGION_SIZE) }; if ptr.is_null() { std::process::abort(); } unsafe { syscall::madvise(ptr, REGION_SIZE, syscall::MADV_NOHUGEPAGE) }; - REGION_BASE.store(ptr as usize, Ordering::Release); - }); - REGION_BASE.load(Ordering::Acquire) + ptr as usize + }) } -/// Activates the arena and resets every thread's slab. All allocations until the next -/// `end_phase()` go to the arena; the previous phase's data is overwritten in place. +/// Opt into the arena (once, at startup). Until then phases are inert and `ArenaVec` uses System. +pub fn enable_arena() { + ARENA_ENGAGED.store(true, Ordering::Release); +} + +/// Activate the arena and reset every thread's slab (overwriting the previous phase). No-op until +/// [`enable_arena`]; phases must not nest. pub fn begin_phase() { + if !ARENA_ENGAGED.load(Ordering::Acquire) { + return; + } let prev_active = ARENA_ACTIVE.swap(true, Ordering::Release); - assert!( - !prev_active, - "begin_phase() called while another phase is already active — phases must not nest" - ); + assert!(!prev_active, "phases must not nest"); GENERATION.fetch_add(1, Ordering::Release); } -/// Deactivates the arena. New allocations go to the system allocator; existing arena -/// pointers stay valid until the next `begin_phase()` resets the slabs. +/// Deactivate the arena; existing arena pointers stay valid until the next `begin_phase()`. pub fn end_phase() { + if !ARENA_ENGAGED.load(Ordering::Acquire) { + return; + } ARENA_ACTIVE.store(false, Ordering::Release); } +/// Guard that [`end_phase`]s on drop. +#[derive(Debug)] +pub struct PhaseGuard(()); + +impl Drop for PhaseGuard { + fn drop(&mut self) { + end_phase(); + } +} + +/// [`begin_phase`] + an RAII guard that [`end_phase`]s on drop (incl. early return / panic). +#[must_use = "the phase ends the moment the guard is dropped"] +pub fn enter_phase() -> PhaseGuard { + begin_phase(); + PhaseGuard(()) +} + #[cold] #[inline(never)] unsafe fn arena_alloc_cold(size: usize, align: usize) -> *mut u8 { @@ -133,52 +130,41 @@ unsafe fn arena_alloc_cold(size: usize, align: usize) -> *mut u8 { unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) } } -// SAFETY: All pointers returned are either from our mmap'd region (valid, aligned, -// non-overlapping per thread) or from System. The arena is thread-local so no data -// races. Relaxed ordering on ARENA_ACTIVE/GENERATION is sound: worst case a thread -// sees a stale value and does one extra system-alloc before picking up the new -// generation on the next call. -unsafe impl GlobalAlloc for ZkAllocator { - #[inline(always)] - unsafe fn alloc(&self, layout: Layout) -> *mut u8 { - if ARENA_ACTIVE.load(Ordering::Relaxed) { - let generation = GENERATION.load(Ordering::Relaxed); - if ARENA_GEN.get() == generation { - let align = layout.align(); - let aligned = (ARENA_PTR.get() + align - 1) & !(align - 1); - let new_ptr = aligned + layout.size(); - if new_ptr <= ARENA_END.get() { - ARENA_PTR.set(new_ptr); - return aligned as *mut u8; - } +/// [`ArenaVec`]'s allocator: bump the thread's slab in an active phase, else System. The cursor is +/// thread-local, so the Relaxed reads can't race — a stale read just costs one extra System alloc. +/// +/// # Safety +/// `align` is a power of two; the result is valid for `size` bytes (or null on System failure) until +/// the next `begin_phase()`. +#[inline(always)] +pub(crate) unsafe fn raw_alloc(size: usize, align: usize) -> *mut u8 { + if ARENA_ACTIVE.load(Ordering::Relaxed) { + let generation = GENERATION.load(Ordering::Relaxed); + if ARENA_GEN.get() == generation { + let aligned = (ARENA_PTR.get() + align - 1) & !(align - 1); + let new_ptr = aligned + size; + if new_ptr <= ARENA_END.get() { + ARENA_PTR.set(new_ptr); + return aligned as *mut u8; } - return unsafe { arena_alloc_cold(layout.size(), layout.align()) }; - } - unsafe { std::alloc::System.alloc(layout) } - } - - #[inline(always)] - unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { - let addr = ptr as usize; - let base = REGION_BASE.load(Ordering::Relaxed); - if base != 0 && addr >= base && addr < base + REGION_SIZE { - return; // arena-owned pointer — free is a no-op } - unsafe { std::alloc::System.dealloc(ptr, layout) }; + return unsafe { arena_alloc_cold(size, align) }; } + unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) } +} - #[inline(always)] - unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { - if new_size <= layout.size() { - return ptr; - } - // SAFETY: new_size > layout.size() > 0, align unchanged from valid layout. - let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) }; - let new_ptr = unsafe { self.alloc(new_layout) }; - if !new_ptr.is_null() { - unsafe { std::ptr::copy(ptr, new_ptr, layout.size()) }; - unsafe { self.dealloc(ptr, layout) }; - } - new_ptr +/// Free for [`raw_alloc`]: no-op for arena pointers (reclaimed at the next `begin_phase()`), else System. +/// +/// # Safety +/// `ptr` came from [`raw_alloc`] with this `size`/`align`. +#[inline(always)] +pub(crate) unsafe fn raw_dealloc(ptr: *mut u8, size: usize, align: usize) { + let addr = ptr as usize; + if REGION + .get() + .is_some_and(|&base| addr >= base && addr < base + REGION_SIZE) + { + return; // arena pointer — free is a no-op } + unsafe { std::alloc::System.dealloc(ptr, Layout::from_size_align_unchecked(size, align)) }; } diff --git a/crates/backend/zk-alloc/src/syscall.rs b/crates/backend/zk-alloc/src/syscall.rs index 13d71531d..9b62f5e1a 100644 --- a/crates/backend/zk-alloc/src/syscall.rs +++ b/crates/backend/zk-alloc/src/syscall.rs @@ -1,163 +1,50 @@ -// Raw syscalls instead of libc wrappers to avoid reentrancy: libc's mmap/madvise -// may internally call malloc, which would deadlock when called from inside -// #[global_allocator]. - -#[cfg(all(target_os = "linux", target_arch = "x86_64"))] -mod imp { - use std::ptr; - - const SYS_MMAP: usize = 9; - const SYS_MADVISE: usize = 28; - - const PROT_READ: usize = 1; - const PROT_WRITE: usize = 2; - const MAP_PRIVATE: usize = 0x02; - const MAP_ANONYMOUS: usize = 0x20; - const MAP_NORESERVE: usize = 0x4000; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - unsafe fn syscall6(nr: usize, a1: usize, a2: usize, a3: usize, a4: usize, a5: usize, a6: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "syscall", - inlateout("rax") nr as isize => ret, - in("rdi") a1, - in("rsi") a2, - in("rdx") a3, - in("r10") a4, - in("r8") a5, - in("r9") a6, - lateout("rcx") _, - lateout("r11") _, - options(nostack), - ); - } - ret - } - - #[inline] - unsafe fn syscall3(nr: usize, a1: usize, a2: usize, a3: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "syscall", - inlateout("rax") nr as isize => ret, - in("rdi") a1, - in("rsi") a2, - in("rdx") a3, - lateout("rcx") _, - lateout("r11") _, - lateout("r10") _, - options(nostack), - ); - } - ret - } - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - let flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE; - let ret = unsafe { syscall6(SYS_MMAP, 0, size, PROT_READ | PROT_WRITE, flags, usize::MAX, 0) }; - if ret < 0 { ptr::null_mut() } else { ret as *mut u8 } - } - - #[inline] - pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { - unsafe { syscall3(SYS_MADVISE, ptr as usize, size, advice) }; +//! Anonymous `mmap` + `madvise` via `libc`. +//! +//! (Raw inline-asm syscalls when zk-alloc was a `#[global_allocator]`, to avoid `libc` re-entering +//! `malloc`. It no longer is, so `libc` is safe: its internal allocations hit the system allocator, +//! not this arena.) + +use std::ptr; + +/// `madvise` advice: disable transparent huge pages for the region. Consulted only on Linux +/// (a no-op elsewhere); see [`madvise`]. +pub const MADV_NOHUGEPAGE: usize = 15; + +/// Reserve `size` bytes of anonymous virtual address space, lazily backed by physical pages. +/// +/// # Safety +/// Always safe to call; returns a page-aligned pointer or null on failure, and the caller owns the +/// resulting mapping. +#[inline] +pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { + let flags = libc::MAP_PRIVATE | libc::MAP_ANON; + // MAP_NORESERVE (Linux) keeps the huge sparse reservation from committing swap up front; macOS + // backs anonymous mappings lazily without it. + #[cfg(target_os = "linux")] + let flags = flags | libc::MAP_NORESERVE; + // SAFETY: a null `addr` lets the kernel pick the placement; `fd` is -1 for an anonymous map. + let ret = unsafe { libc::mmap(ptr::null_mut(), size, libc::PROT_READ | libc::PROT_WRITE, flags, -1, 0) }; + if ret == libc::MAP_FAILED { + ptr::null_mut() + } else { + ret.cast::() } } -#[cfg(all(target_os = "linux", target_arch = "aarch64"))] -mod imp { - use std::ptr; - - const SYS_MMAP: usize = 222; - const SYS_MADVISE: usize = 233; - - const PROT_READ: usize = 1; - const PROT_WRITE: usize = 2; - const MAP_PRIVATE: usize = 0x02; - const MAP_ANONYMOUS: usize = 0x20; - const MAP_NORESERVE: usize = 0x4000; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - unsafe fn syscall6(nr: usize, a1: usize, a2: usize, a3: usize, a4: usize, a5: usize, a6: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "svc 0", - in("x8") nr, - inlateout("x0") a1 as isize => ret, - in("x1") a2, - in("x2") a3, - in("x3") a4, - in("x4") a5, - in("x5") a6, - options(nostack), - ); - } - ret - } - - #[inline] - unsafe fn syscall3(nr: usize, a1: usize, a2: usize, a3: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "svc 0", - in("x8") nr, - inlateout("x0") a1 as isize => ret, - in("x1") a2, - in("x2") a3, - options(nostack), - ); - } - ret - } - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - let flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE; - let ret = unsafe { syscall6(SYS_MMAP, 0, size, PROT_READ | PROT_WRITE, flags, usize::MAX, 0) }; - if ret < 0 { ptr::null_mut() } else { ret as *mut u8 } - } - - #[inline] - pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { - unsafe { syscall3(SYS_MADVISE, ptr as usize, size, advice) }; - } -} - -#[cfg(not(all(target_os = "linux", any(target_arch = "x86_64", target_arch = "aarch64"))))] -mod imp { - use std::ptr; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - // MAP_NORESERVE is Linux-only. macOS lazily backs anonymous mappings - // with physical memory by default, so the large virtual reservation - // is fine without NORESERVE. - let prot = libc::PROT_READ | libc::PROT_WRITE; - let flags = libc::MAP_PRIVATE | libc::MAP_ANON; - let ret = unsafe { libc::mmap(ptr::null_mut(), size, prot, flags, -1, 0) }; - if ret == libc::MAP_FAILED { - ptr::null_mut() - } else { - ret.cast::() - } - } - - #[inline] - pub unsafe fn madvise(_ptr: *mut u8, _size: usize, _advice: usize) { - // The advice values we pass are Linux-specific. +/// Apply `advice` to `[ptr, ptr + size)`. No-op on non-Linux (the advice values we use are +/// Linux-specific). +/// +/// # Safety +/// `ptr`/`size` must describe a live mapping returned by [`mmap_anonymous`]. +#[inline] +pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { + #[cfg(target_os = "linux")] + unsafe { + // SAFETY: the caller guarantees `[ptr, ptr + size)` is a live mapping. + libc::madvise(ptr.cast::(), size, advice as libc::c_int); + } + #[cfg(not(target_os = "linux"))] + { + let _ = (ptr, size, advice); } } - -pub use imp::{MADV_NOHUGEPAGE, madvise, mmap_anonymous}; diff --git a/crates/backend/zk-alloc/tests/test_alloc.rs b/crates/backend/zk-alloc/tests/test_alloc.rs new file mode 100644 index 000000000..e78583d9e --- /dev/null +++ b/crates/backend/zk-alloc/tests/test_alloc.rs @@ -0,0 +1,49 @@ +//! `ArenaVec` drives the arena explicitly, with the process keeping its **own** allocator (no +//! `#[global_allocator]` is installed here). Only `ArenaVec`-backed buffers touch the arena; +//! everything else is untouched by a phase reset — the property that lets a library use the +//! arena without forcing its allocator on consumers. + +use zk_alloc::{ArenaVec, begin_phase, enable_arena, end_phase}; + +const N: usize = 4096; + +#[test] +fn arena_vec_without_global_allocator() { + // Opt into the arena: without this, begin_phase/end_phase are inert and ArenaVec would + // transparently use the system allocator (no slab reuse to observe). + enable_arena(); + + // Phase 1: one arena allocation on this (main) thread → claims the slab at its base. + begin_phase(); + let mut v: ArenaVec = ArenaVec::with_capacity(N); + v.resize(N, 0xABCD); // fits the reservation: no realloc, pointer stays put + let p1 = v.as_ptr() as usize; + end_phase(); + + // Arena is off: this lands in the system allocator and must survive the next reset. + let canary = vec![0xAB_u8; 8192]; + + // Phase 2: the slab is reset, so an identically-shaped buffer reuses the same address. + begin_phase(); + let mut w: ArenaVec = ArenaVec::with_capacity(N); + w.resize(N, 0x1234); + let p2 = w.as_ptr() as usize; + end_phase(); + + assert_eq!( + p1, p2, + "phase reset should recycle the slab — ArenaVec must hit the arena" + ); + assert!( + canary.iter().all(|&b| b == 0xAB), + "a system allocation was corrupted by the arena reset" + ); + + // Outside any phase, ArenaVec transparently uses the system allocator (no panic). + let mut off: ArenaVec = ArenaVec::new(); + off.extend(0..1000); + assert_eq!(off.iter().sum::(), (0..1000).sum()); + + drop(v); + drop(w); +} diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 8da043506..b88bd784c 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -16,16 +16,6 @@ pub struct Program { pub filepaths: BTreeMap, } -impl Program { - pub fn inlined_function_names(&self) -> BTreeSet { - self.functions - .iter() - .filter(|(_, func)| func.inlined) - .map(|(name, _)| name.clone()) - .collect() - } -} - /// A function argument with its modifiers #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct FunctionArg { diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index 31046f37f..28fe114a9 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -14,7 +14,6 @@ pest.workspace = true pest_derive.workspace = true xmss.workspace = true rand.workspace = true - tracing.workspace = true sub_protocols.workspace = true lean_vm.workspace = true diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index e9f90dea2..41b27cdc7 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use crate::*; use backend::ansi::Colorize; +use backend::{ArenaVec, enter_phase}; use lean_vm::*; use serde::{Deserialize, Serialize}; use sub_protocols::*; @@ -22,6 +23,8 @@ pub fn prove_execution( whir_config: &WhirConfigBuilder, vm_profiler: bool, ) -> Result { + let _phase = enter_phase(); + check_rate(whir_config.starting_log_inv_rate).map_err(|_| ProverError::InvalidRate)?; let ExecutionTrace { traces, @@ -79,7 +82,7 @@ pub fn prove_execution( tracing::info!("Trace tables sizes: {}", table_log.magenta()); // TODO parrallelize - let mut memory_acc = F::zero_vec(memory.len()); + let mut memory_acc = unsafe { ArenaVec::::zeroed(memory.len()) }; info_span!("Building memory access count").in_scope(|| -> Result<(), ProverError> { for (table, trace) in &traces { let buses = table.bus_interactions(); @@ -99,7 +102,7 @@ pub fn prove_execution( })?; // // TODO parrallelize - let mut bytecode_acc = F::zero_vec(bytecode.padded_size()); + let mut bytecode_acc = unsafe { ArenaVec::::zeroed(bytecode.padded_size()) }; info_span!("Building bytecode access count").in_scope(|| -> Result<(), ProverError> { for pc in traces[&Table::execution()].columns[EXEC_COL_PC].iter() { *bytecode_acc.get_mut(pc.to_usize()).ok_or(RunnerError::PCOutOfBounds)? += F::ONE; @@ -158,12 +161,12 @@ pub fn prove_execution( .map(|table| { traces[table].columns[..table.n_columns()] .iter() - .map(Vec::as_slice) + .map(|c| c.as_slice()) .collect() }) .collect(); let _span = info_span!("Computing shifted columns for AIR sumcheck").entered(); - let shifted_rows: Vec>> = ALL_TABLES + let shifted_rows: Vec>> = ALL_TABLES .iter() .zip(&column_refs) .map(|(table, cols)| compute_shifted_columns(table.n_shift_columns(), cols)) @@ -190,10 +193,10 @@ pub fn prove_execution( let eq_suffix = from_end(gkr_point, log_n_rows).to_vec(); let alpha_slice = air_alpha_powers[alpha_offset..alpha_offset + n_constraints].to_vec(); - let extra_data = ExtraDataForBuses::new(logup_alphas_eq_poly.clone(), alpha_slice); + let extra_data = ExtraDataForBuses::new(&logup_alphas_eq_poly, alpha_slice); let mut flat_and_shift: Vec<&[PF]> = column_refs[idx].to_vec(); - flat_and_shift.extend(shifted_rows[idx].iter().map(Vec::as_slice)); + flat_and_shift.extend(shifted_rows[idx].iter().map(|c| c.as_slice())); let packed = MleGroupRef::::Base(flat_and_shift).pack(); let non_padded = traces[table].non_padded_n_rows; diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index b8ee966dd..e2427c580 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -5,7 +5,7 @@ use std::{array, collections::BTreeMap}; #[derive(Debug)] pub struct ExecutionTrace { pub traces: BTreeMap, - pub memory: Vec, // of length a multiple of public_memory_size + pub memory: ArenaVec, // of length a multiple of public_memory_size pub metadata: ExecutionMetadata, } @@ -18,8 +18,8 @@ pub fn get_execution_trace( let n_cycles = execution_result.pcs.len(); let memory = &execution_result.memory; - let mut main_trace: [Vec; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS] = - array::from_fn(|_| F::zero_vec(n_cycles.next_power_of_two())); + let mut main_trace: [ArenaVec; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS] = + array::from_fn(|_| unsafe { ArenaVec::::zeroed(n_cycles.next_power_of_two()) }); for col in &mut main_trace { unsafe { col.set_len(n_cycles); @@ -92,7 +92,7 @@ pub fn get_execution_trace( *trace_row[EXEC_COL_ADDR_C] = addr_c; }); - let mut memory_padded: Vec = parallel::par_map_collect(memory.0.len(), |i| memory.0[i].unwrap_or(F::ZERO)); + let mut memory_padded: ArenaVec = ArenaVec::par_collect(memory.0.len(), |i| memory.0[i].unwrap_or(F::ZERO)); // Write [0000000000000000 | poseidon_compress(0000000000000000)] (to make lookups work on padding-rows). let padding_zero_vec_ptr = memory_padded.len(); @@ -120,7 +120,7 @@ pub fn get_execution_trace( let flag_out8_col = &left[POSEIDON_COL_FLAG_OUT8]; let nu_c_col = &left[POSEIDON_COL_NU_C]; const N: usize = HALF_DIGEST_LEN + DIGEST_LEN; - let cols: &mut [Vec; N] = (&mut right[..N]).try_into().unwrap(); + let cols: &mut [ArenaVec; N] = (&mut right[..N]).try_into().unwrap(); transposed_par_for_each_mut(cols, |i, row| { let flag_out4 = flag_out4_col[i]; diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 07d1dd9db..cd1811c9a 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -125,7 +125,7 @@ pub fn verify_execution( let alpha_slice = air_alpha_powers[alpha_offset..alpha_offset + n_constraints].to_vec(); verify_data.push(TableVerifyData { table, - extra_data: ExtraDataForBuses::new(logup_alphas_eq_poly.clone(), alpha_slice), + extra_data: ExtraDataForBuses::new(&logup_alphas_eq_poly, alpha_slice), }); alpha_offset += n_constraints; diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 6c8a92be9..b6f809bf0 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -94,8 +94,8 @@ impl Trace { self.pending_deref_hints.extend(other.pending_deref_hints); for (table, other_t) in other.tables { let mine = self.tables.get_mut(&table).unwrap(); - for (col, new_data) in mine.columns.iter_mut().zip(other_t.columns) { - col.extend(new_data); + for (col, new_data) in mine.columns.iter_mut().zip(&other_t.columns) { + col.extend_from_slice(new_data); } } } diff --git a/crates/lean_vm/src/tables/poseidon/trace_gen.rs b/crates/lean_vm/src/tables/poseidon/trace_gen.rs index dc3963b75..94474740a 100644 --- a/crates/lean_vm/src/tables/poseidon/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon/trace_gen.rs @@ -7,7 +7,7 @@ use crate::{ use backend::*; #[instrument(name = "generate Poseidon16 AIR trace", skip_all)] -pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { +pub fn fill_trace_poseidon_16(trace: &mut [ArenaVec]) { let n = trace.iter().map(|col| col.len()).max().unwrap(); for col in trace.iter_mut() { if col.len() != n { @@ -33,7 +33,7 @@ pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { }); // fill the remaining rows (non packed) - let cols: &[Vec; N_COLS] = (&trace[..N_COLS]).try_into().unwrap(); + let cols: &[ArenaVec; N_COLS] = (&trace[..N_COLS]).try_into().unwrap(); for i in m..n { let ptrs: [*mut F; N_COLS] = std::array::from_fn(|c| unsafe { (cols[c].as_ptr() as *mut F).add(i) }); let perm: &mut Poseidon1Cols16<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut F>) }; diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index f630970ce..54d931db4 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -137,7 +137,7 @@ pub struct MemoryLookupGroup { #[derive(Debug, Default)] pub struct TableTrace { - pub columns: Vec>, + pub columns: Vec>, pub non_padded_n_rows: usize, pub log_n_rows: VarCount, } @@ -145,7 +145,7 @@ pub struct TableTrace { impl TableTrace { pub fn new(air: &A) -> Self { Self { - columns: vec![Vec::new(); air.n_columns_total()], + columns: (0..air.n_columns_total()).map(|_| ArenaVec::new()).collect(), non_padded_n_rows: 0, // filled later log_n_rows: 0, // filled later } @@ -166,10 +166,10 @@ pub struct ExtraDataForBuses>> { pub alpha_powers: Vec, } impl>> ExtraDataForBuses { - pub fn new(logup_alphas_eq_poly: Vec, alpha_powers: Vec) -> Self { + pub fn new(logup_alphas_eq_poly: &[EF], alpha_powers: Vec) -> Self { let logup_alphas_eq_poly_packed = logup_alphas_eq_poly.iter().map(|a| EFPacking::::from(*a)).collect(); Self { - logup_alphas_eq_poly, + logup_alphas_eq_poly: logup_alphas_eq_poly.to_vec(), logup_alphas_eq_poly_packed, alpha_powers, } diff --git a/crates/rec_aggregation/Cargo.toml b/crates/rec_aggregation/Cargo.toml index 6df98da6a..539b88271 100644 --- a/crates/rec_aggregation/Cargo.toml +++ b/crates/rec_aggregation/Cargo.toml @@ -8,12 +8,10 @@ workspace = true [features] prox-gaps-conjecture = ["lean_prover/prox-gaps-conjecture"] -standard-alloc = [] [dependencies] xmss.workspace = true rand.workspace = true -zk-alloc.workspace = true tracing.workspace = true include_dir.workspace = true diff --git a/crates/rec_aggregation/src/benchmark.rs b/crates/rec_aggregation/src/benchmark.rs index a64949f0d..61e67e305 100644 --- a/crates/rec_aggregation/src/benchmark.rs +++ b/crates/rec_aggregation/src/benchmark.rs @@ -397,9 +397,6 @@ fn build_aggregation( let mut last_result: Option = None; let own_display_index = display_index + count_nodes(topology) - 1; for _ in 0..repeat { - #[cfg(not(feature = "standard-alloc"))] - zk_alloc::begin_phase(); - let time = Instant::now(); let result = aggregate_single_message_signatures( &children, @@ -411,13 +408,6 @@ fn build_aggregation( .unwrap(); let elapsed = time.elapsed(); - // Clone the outputs out of the arena before the next phase resets its slabs. - #[cfg(not(feature = "standard-alloc"))] - let result = { - zk_alloc::end_phase(); - result.clone() - }; - times.push(elapsed.as_secs_f64()); last_result = Some(result); diff --git a/crates/rec_aggregation/src/single_message_aggregation.rs b/crates/rec_aggregation/src/single_message_aggregation.rs index 877bb9a72..d98e3f280 100644 --- a/crates/rec_aggregation/src/single_message_aggregation.rs +++ b/crates/rec_aggregation/src/single_message_aggregation.rs @@ -32,9 +32,6 @@ pub(crate) const N_TWEAKS: usize = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME; pub(crate) const TWEAK_SLOT_SIZE: usize = 4; pub(crate) const TWEAK_TABLE_SIZE_FE_PADDED: usize = (N_TWEAKS * TWEAK_SLOT_SIZE).next_multiple_of(DIGEST_LEN); -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub(crate) struct Digest(pub [F; DIGEST_LEN]); - #[derive(Debug, Clone, PartialEq, Eq)] pub struct SingleMessageInfo { pub message: [F; MESSAGE_LEN_FE], @@ -100,10 +97,6 @@ impl SingleMessageAggregateSignature { let (value, rest) = postcard::take_from_bytes::(&decompressed).ok()?; rest.is_empty().then_some(value) } - - pub(crate) fn bytecode_claim_flat(&self) -> Vec { - self.info.bytecode_claim_flat() - } } impl SingleMessageInfo { diff --git a/crates/sub_protocols/Cargo.toml b/crates/sub_protocols/Cargo.toml index 2c1ab3879..70b987fa2 100644 --- a/crates/sub_protocols/Cargo.toml +++ b/crates/sub_protocols/Cargo.toml @@ -8,7 +8,6 @@ workspace = true [dependencies] tracing.workspace = true - lean_vm.workspace = true backend.workspace = true diff --git a/crates/sub_protocols/src/air_sumcheck.rs b/crates/sub_protocols/src/air_sumcheck.rs index 8600de649..eb9ca0a02 100644 --- a/crates/sub_protocols/src/air_sumcheck.rs +++ b/crates/sub_protocols/src/air_sumcheck.rs @@ -89,10 +89,10 @@ where let _span = info_span!("chunk-bit-reversing columns").entered(); let chunk_size = 1usize << pivot; let shift = usize::BITS as usize - pivot; - let mut bit_reversed: Vec>> = (0..cols.len()).map(|_| Vec::new()).collect(); + let mut bit_reversed: Vec>> = vec![ArenaVec::new(); cols.len()]; parallel::par_chunks_mut(&mut bit_reversed, 1, |i, out_slot| { let src = cols[i]; - let mut dst: Vec> = unsafe { uninitialized_vec(src.len()) }; + let mut dst: ArenaVec> = unsafe { ArenaVec::uninitialized(src.len()) }; let src_u = PFPacking::::unpack_slice(src); let dst_u = PFPacking::::unpack_slice_mut(&mut dst); for (src_chunk, dst_chunk) in src_u.chunks_exact(chunk_size).zip(dst_u.chunks_exact_mut(chunk_size)) @@ -657,12 +657,12 @@ pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( MultilinearPoint(challenges) } -pub fn compute_shifted_columns(n_shift_columns: usize, columns: &[&[F]]) -> Vec> { +pub fn compute_shifted_columns(n_shift_columns: usize, columns: &[&[F]]) -> Vec> { // Convention: the first `n_shift_columns` columns are the ones that get shifted. - let mut out: Vec> = (0..n_shift_columns).map(|_| Vec::new()).collect(); + let mut out: Vec> = (0..n_shift_columns).map(|_| ArenaVec::new()).collect(); parallel::par_chunks_mut(&mut out, 1, |i, slot| { let column = columns[i]; - let mut shifted = unsafe { uninitialized_vec(column.len()) }; + let mut shifted = unsafe { ArenaVec::::uninitialized(column.len()) }; shifted[..column.len() - 1].copy_from_slice(&column[1..]); shifted[column.len() - 1] = column[column.len() - 1]; slot[0] = shifted; diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 2b08d8666..f8ebda98e 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -48,9 +48,10 @@ pub fn prove_generic_logup( &tables_log_heights_sorted, ); let total_gkr_n_vars = log2_ceil_usize(total_active_len); - let mut numerators: Vec = unsafe { uninitialized_vec(total_active_len) }; + let mut numerators: ArenaVec = unsafe { ArenaVec::::uninitialized(total_active_len) }; let width = packing_width::(); - let mut denominators: Vec> = unsafe { uninitialized_vec(total_active_len / width) }; + let mut denominators: ArenaVec> = + unsafe { ArenaVec::>::uninitialized(total_active_len / width) }; let c_packed = EFPacking::::from(c); let alphas_packed: Vec> = alphas_eq_poly.iter().map(|a| EFPacking::::from(*a)).collect(); let memory_domainsep_packed = PFPacking::::from(F::from_usize(LOGUP_MEMORY_DOMAINSEP)); diff --git a/crates/sub_protocols/src/quotient_gkr/layers.rs b/crates/sub_protocols/src/quotient_gkr/layers.rs index eb8a769e0..7a19da588 100644 --- a/crates/sub_protocols/src/quotient_gkr/layers.rs +++ b/crates/sub_protocols/src/quotient_gkr/layers.rs @@ -1,22 +1,21 @@ use backend::PackedValue; -use std::borrow::Cow; use backend::*; pub(super) enum LayerStorage<'a, EF: ExtensionField>> { Initial { - nums: Cow<'a, [PFPacking]>, - dens: Cow<'a, [EFPacking]>, + nums: ArenaCow<'a, PFPacking>, + dens: ArenaCow<'a, EFPacking>, chunk_log: usize, }, PackedBr { - nums: Cow<'a, [EFPacking]>, - dens: Cow<'a, [EFPacking]>, + nums: ArenaCow<'a, EFPacking>, + dens: ArenaCow<'a, EFPacking>, chunk_log: usize, }, Natural { - nums: Cow<'a, [EF]>, - dens: Cow<'a, [EF]>, + nums: ArenaCow<'a, EF>, + dens: ArenaCow<'a, EF>, }, } @@ -24,24 +23,24 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { pub(super) fn convert_to_natural(&self) -> Self { match self { Self::Initial { nums, dens, chunk_log } => { - let n_nat_base: Vec = unpack_base_and_unreverse_active::(nums.as_ref(), *chunk_log); + let n_nat_base: ArenaVec = unpack_base_and_unreverse_active::(nums.as_ref(), *chunk_log); let d_nat = unpack_and_unreverse_active::(dens.as_ref(), *chunk_log); Self::Natural { - nums: Cow::Owned(n_nat_base), - dens: Cow::Owned(d_nat), + nums: ArenaCow::Owned(n_nat_base), + dens: ArenaCow::Owned(d_nat), } } Self::PackedBr { nums, dens, chunk_log } => { let n_nat = unpack_and_unreverse_active::(nums.as_ref(), *chunk_log); let d_nat = unpack_and_unreverse_active::(dens.as_ref(), *chunk_log); Self::Natural { - nums: Cow::Owned(n_nat), - dens: Cow::Owned(d_nat), + nums: ArenaCow::Owned(n_nat), + dens: ArenaCow::Owned(d_nat), } } Self::Natural { nums, dens } => Self::Natural { - nums: Cow::Owned(nums.to_vec()), - dens: Cow::Owned(dens.to_vec()), + nums: ArenaCow::Owned(ArenaVec::from_slice(nums.as_ref())), + dens: ArenaCow::Owned(ArenaVec::from_slice(dens.as_ref())), }, } } @@ -52,8 +51,8 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { let (new_nums, new_dens) = sum_quotients_2_by_2_packed_br::(nums.as_ref(), dens.as_ref(), *chunk_log); Self::PackedBr { - nums: Cow::Owned(new_nums), - dens: Cow::Owned(new_dens), + nums: ArenaCow::Owned(new_nums), + dens: ArenaCow::Owned(new_dens), chunk_log: *chunk_log - 1, } } @@ -61,16 +60,16 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { let (new_nums, new_dens) = sum_quotients_2_by_2_packed_br::(nums.as_ref(), dens.as_ref(), *chunk_log); Self::PackedBr { - nums: Cow::Owned(new_nums), - dens: Cow::Owned(new_dens), + nums: ArenaCow::Owned(new_nums), + dens: ArenaCow::Owned(new_dens), chunk_log: *chunk_log - 1, } } Self::Natural { nums, dens } => { let (nn, nd) = sum_quotients_2_by_2(nums.as_ref(), dens.as_ref()); Self::Natural { - nums: Cow::Owned(nn), - dens: Cow::Owned(nd), + nums: ArenaCow::Owned(nn), + dens: ArenaCow::Owned(nd), } } } @@ -84,7 +83,7 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { } } - pub fn materialise_in_full(self) -> (Vec, Vec) { + pub fn materialise_in_full(self) -> (ArenaVec, ArenaVec) { let natural = match self { Self::Natural { .. } => self, other => other.convert_to_natural(), @@ -101,11 +100,11 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { } } -pub(super) fn bit_reverse_chunks(v: &[T], chunk_log: usize) -> Vec { +pub(super) fn bit_reverse_chunks(v: &[T], chunk_log: usize) -> ArenaVec { let n = v.len(); let chunk_size = 1usize << chunk_log; debug_assert!(n.is_multiple_of(chunk_size)); - let mut out: Vec = unsafe { uninitialized_vec(n) }; + let mut out: ArenaVec = unsafe { ArenaVec::uninitialized(n) }; if chunk_log == 0 { out.copy_from_slice(v); return out; @@ -120,14 +119,14 @@ pub(super) fn bit_reverse_chunks(v: &[T], chunk_log: usiz out } -fn sum_quotients_2_by_2>>(nums: &[EF], dens: &[EF]) -> (Vec, Vec) { +fn sum_quotients_2_by_2>>(nums: &[EF], dens: &[EF]) -> (ArenaVec, ArenaVec) { assert_eq!(nums.len(), dens.len()); let active_len = nums.len(); let new_active = active_len.div_ceil(2); let full_pairs = active_len / 2; - let mut new_nums: Vec = unsafe { uninitialized_vec(new_active) }; - let mut new_dens: Vec = unsafe { uninitialized_vec(new_active) }; + let mut new_nums: ArenaVec = unsafe { ArenaVec::uninitialized(new_active) }; + let mut new_dens: ArenaVec = unsafe { ArenaVec::uninitialized(new_active) }; parallel::par_for_each_mut2( &mut new_nums[..full_pairs], @@ -155,7 +154,7 @@ fn sum_quotients_2_by_2_packed_br>, N>( nums: &[N], dens: &[EFPacking], chunk_log: usize, -) -> (Vec>, Vec>) +) -> (ArenaVec>, ArenaVec>) where N: Copy + Send + Sync, EFPacking: Algebra, @@ -168,8 +167,8 @@ where let stride = 1usize << bit; let lo_mask = stride - 1; - let mut new_nums: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; - let mut new_dens: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; + let mut new_nums: ArenaVec> = unsafe { ArenaVec::uninitialized(nums.len() >> 1) }; + let mut new_dens: ArenaVec> = unsafe { ArenaVec::uninitialized(nums.len() >> 1) }; parallel::par_for_each_mut2(&mut new_nums, &mut new_dens, |new_j, num_out, den_out| { let i_hi = new_j >> bit; @@ -186,11 +185,11 @@ where pub(super) fn unpack_and_unreverse_active>>( v: &[EFPacking], chunk_log: usize, -) -> Vec { - bit_reverse_chunks(&unpack_extension::(v), chunk_log) +) -> ArenaVec { + bit_reverse_chunks(&unpack_extension::>(v), chunk_log) } -fn unpack_base_and_unreverse_active>>(v: &[PFPacking], chunk_log: usize) -> Vec { - let active_unpacked: Vec = PFPacking::::unpack_slice(v).iter().map(|x| EF::from(*x)).collect(); +fn unpack_base_and_unreverse_active>>(v: &[PFPacking], chunk_log: usize) -> ArenaVec { + let active_unpacked: ArenaVec = PFPacking::::unpack_slice(v).iter().map(|x| EF::from(*x)).collect(); bit_reverse_chunks(&active_unpacked, chunk_log) } diff --git a/crates/sub_protocols/src/quotient_gkr/mod.rs b/crates/sub_protocols/src/quotient_gkr/mod.rs index d8671e142..0d2153536 100644 --- a/crates/sub_protocols/src/quotient_gkr/mod.rs +++ b/crates/sub_protocols/src/quotient_gkr/mod.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use backend::*; use tracing::instrument; @@ -41,8 +39,8 @@ pub fn prove_gkr_quotient<'a, EF: ExtensionField>>( assert_eq!(nums_br.len(), dens_br.len()); let initial = LayerStorage::Initial { - nums: Cow::Borrowed(nums_br), - dens: Cow::Borrowed(dens_br), + nums: ArenaCow::Borrowed(nums_br), + dens: ArenaCow::Borrowed(dens_br), chunk_log: pivot, }; diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 35f5a84b3..c771d6e04 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -1,13 +1,10 @@ -use std::{ - borrow::Cow, - ops::{Add, AddAssign, Mul}, -}; +use std::ops::{Add, AddAssign, Mul}; use backend::*; use crate::quotient_gkr::layers::unpack_and_unreverse_active; -pub(super) fn even_odd_split(v: &[T]) -> (Vec, Vec) { +pub(super) fn even_odd_split(v: &[T]) -> (ArenaVec, ArenaVec) { ( v.iter().step_by(2).copied().collect(), v.iter().skip(1).step_by(2).copied().collect(), @@ -132,12 +129,12 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> let mut sum = expected_sum; let outer_point = remaining_eq[..head_len].to_vec(); - let eq_outer = eval_eq(&outer_point); + let eq_outer: ArenaVec = eval_eq(&outer_point); let padding_sum = alpha * mle_of_zeros_then_ones(active_chunks, &outer_point); let eq_alpha_0 = *remaining_eq.last().unwrap(); - let eq_within_0 = eval_eq_packed(&within_pt(&remaining_eq, head_len)); + let eq_within_0: ArenaVec<_> = eval_eq_packed(&within_pt(&remaining_eq, head_len)); let coeffs_0 = compute_round_packed::(packed_nums, packed_dens, parent_chunk_log, &eq_outer, &eq_within_0); let r0 = finalize_round( prover_state, @@ -152,7 +149,7 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> remaining_eq.pop(); let eq_alpha_1 = *remaining_eq.last().unwrap(); - let eq_within_1 = eval_eq_packed(&within_pt(&remaining_eq, head_len)); + let eq_within_1: ArenaVec<_> = eval_eq_packed(&within_pt(&remaining_eq, head_len)); let (nums_ext, dens_ext, coeffs_1) = fold_and_compute_round_packed::(packed_nums, packed_dens, parent_chunk_log, r0, &eq_outer, &eq_within_1); let r1 = finalize_round( @@ -169,8 +166,8 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> run_phase1_sumcheck( prover_state, - Cow::Owned(nums_ext), - Cow::Owned(dens_ext), + ArenaCow::Owned(nums_ext), + ArenaCow::Owned(dens_ext), parent_chunk_log - 2, remaining_eq, q_natural, @@ -186,15 +183,15 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> #[allow(clippy::too_many_arguments)] pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( prover_state: &mut impl FSProver, - mut nums: Cow<'a, [EFPacking]>, - mut dens: Cow<'a, [EFPacking]>, + mut nums: ArenaCow<'a, EFPacking>, + mut dens: ArenaCow<'a, EFPacking>, mut layer_chunk_log: usize, mut remaining_eq: Vec, mut q_natural: Vec, alpha: EF, mut sum: EF, mut mmf: EF, - precomputed_eq_outer: Option>, + precomputed_eq_outer: Option>, initial_pending_r: Option, ) -> (Vec, [EF; 4]) { let w = packing_log_width::(); @@ -219,7 +216,7 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( let head_len = (remaining_eq.len() + 1).saturating_sub(layer_chunk_log); let outer_point: Vec = remaining_eq[..head_len].to_vec(); - let eq_outer: Vec = precomputed_eq_outer.unwrap_or_else(|| eval_eq(&outer_point)); + let eq_outer: ArenaVec = precomputed_eq_outer.unwrap_or_else(|| eval_eq(&outer_point)); let active_chunks = (nums.len() << w) >> (layer_chunk_log + usize::from(initial_pending_r.is_some())); @@ -228,7 +225,7 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( let mut pending_r: Option = initial_pending_r; while layer_chunk_log > w + 1 && remaining_eq.len() > w + 1 { let eq_alpha = *remaining_eq.last().unwrap(); - let eq_within = eval_eq_packed(&within_pt(&remaining_eq, head_len)); + let eq_within: ArenaVec<_> = eval_eq_packed(&within_pt(&remaining_eq, head_len)); let coeffs = if let Some(prev_r) = pending_r.take() { let (new_nums, new_dens, c) = fold_and_compute_round_packed::( @@ -239,8 +236,8 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( &eq_outer, &eq_within, ); - nums = Cow::Owned(new_nums); - dens = Cow::Owned(new_dens); + nums = ArenaCow::Owned(new_nums); + dens = ArenaCow::Owned(new_dens); c } else { compute_round_packed::(nums.as_ref(), dens.as_ref(), layer_chunk_log, &eq_outer, &eq_within) @@ -256,8 +253,8 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( if let Some(prev_r) = pending_r { let prev_bit = layer_chunk_log - 1 - w; let mul = |x: EFPacking, a: EF| x * a; - nums = Cow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul, false)); - dens = Cow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, prev_bit, &mul, false)); + nums = ArenaCow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul, false)); + dens = ArenaCow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, prev_bit, &mul, false)); } let nums_nat = unpack_and_unreverse_active::(nums.as_ref(), layer_chunk_log); @@ -282,10 +279,10 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( #[allow(clippy::too_many_arguments)] pub(super) fn run_phase2_sumcheck>>( prover_state: &mut impl FSProver, - mut num_l: Vec, - mut num_r: Vec, - mut den_l: Vec, - mut den_r: Vec, + mut num_l: ArenaVec, + mut num_r: ArenaVec, + mut den_l: ArenaVec, + mut den_r: ArenaVec, mut remaining_eq: Vec, mut q_natural: Vec, alpha: EF, @@ -293,7 +290,7 @@ pub(super) fn run_phase2_sumcheck>>( mut mmf: EF, ) -> (Vec, [EF; 4]) { let eq_prefix_init = &remaining_eq[..remaining_eq.len().saturating_sub(1)]; - let mut eq_table = eval_eq(eq_prefix_init); + let mut eq_table: ArenaVec = eval_eq(eq_prefix_init); for _round in 0..remaining_eq.len() { let eq_alpha = *remaining_eq.last().unwrap(); @@ -360,7 +357,7 @@ pub(super) fn run_phase2_sumcheck>>( if new_eq_len > 0 { let fold_eq = |i: usize| eq_table[2 * i] + eq_table[2 * i + 1]; eq_table = if new_eq_len >= PARALLEL_THRESHOLD { - parallel::par_map_collect(new_eq_len, fold_eq) + ArenaVec::par_collect(new_eq_len, fold_eq) } else { (0..new_eq_len).map(fold_eq).collect() }; @@ -375,11 +372,11 @@ pub(super) fn run_phase2_sumcheck>>( (q_natural, evals) } -fn fold_normal_with_padding>>(m: &[EF], r: EF, pad_value: EF) -> Vec { +fn fold_normal_with_padding>>(m: &[EF], r: EF, pad_value: EF) -> ArenaVec { let active = m.len(); let new_active = active.div_ceil(2); assert!(new_active != 0); - let mut out: Vec = unsafe { uninitialized_vec(new_active) }; + let mut out: ArenaVec = unsafe { ArenaVec::uninitialized(new_active) }; let compute = |i: usize, slot: &mut EF| { let a = m[2 * i]; @@ -447,7 +444,11 @@ fn fold_and_compute_round_packed>, N>( prev_r: EF, eq_outer: &[EF], eq_within: &[EFPacking], -) -> (Vec>, Vec>, RoundCoeffs>) +) -> ( + ArenaVec>, + ArenaVec>, + RoundCoeffs>, +) where N: PrimeCharacteristicRing + Copy + Send + Sync, EFPacking: Algebra, @@ -466,8 +467,8 @@ where debug_assert_eq!(eq_within.len(), in_eighth); let active_out_packed = nums.len() / 2; - let mut new_nums: Vec> = unsafe { uninitialized_vec(active_out_packed) }; - let mut new_dens: Vec> = unsafe { uninitialized_vec(active_out_packed) }; + let mut new_nums: ArenaVec> = unsafe { ArenaVec::uninitialized(active_out_packed) }; + let mut new_dens: ArenaVec> = unsafe { ArenaVec::uninitialized(active_out_packed) }; let prev_r_packed: EFPacking = as From>::from(prev_r); let n_chunks = nums.len() / in_packed; diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 9883463b6..07fcabca6 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -118,7 +118,7 @@ pub fn stack_polynomials_and_commit( log2_strict_usize(bytecode_acc.len()), &tables_heights_sorted.iter().cloned().collect(), ); - let mut global_polynomial = F::zero_vec(1 << stacked_n_vars); // TODO avoid cloning all witness data + let mut global_polynomial = unsafe { ArenaVec::::zeroed(1 << stacked_n_vars) }; global_polynomial[..memory.len()].copy_from_slice(memory); let mut offset = memory.len(); global_polynomial[offset..][..memory_acc.len()].copy_from_slice(memory_acc); diff --git a/crates/sub_protocols/tests/prove_poseidon.rs b/crates/sub_protocols/tests/prove_poseidon.rs index 09efb6601..e0333e4ac 100644 --- a/crates/sub_protocols/tests/prove_poseidon.rs +++ b/crates/sub_protocols/tests/prove_poseidon.rs @@ -20,14 +20,14 @@ fn test_prove_poseidon() { let n_rows = 1 << log_n_rows; let mut rng = StdRng::seed_from_u64(0); let n_cols = num_cols_poseidon_16(); - let mut trace = vec![vec![F::ZERO; n_rows]; n_cols]; + let mut trace: Vec> = (0..n_cols).map(|_| ArenaVec::filled(F::ZERO, n_rows)).collect(); for t in trace.iter_mut().skip(POSEIDON_COL_INPUT_START).take(WIDTH) { - *t = (0..n_rows).map(|_| rng.random()).collect(); + *t = ArenaVec::from_iter((0..n_rows).map(|_| rng.random())); } - trace[POSEIDON_COL_MULTIPLICITY] = vec![F::ONE; n_rows]; - trace[POSEIDON_COL_FLAG_OUT8] = vec![F::ONE; n_rows]; - trace[POSEIDON_COL_ADDR_LEFT_LO] = vec![F::ZERO; n_rows]; - trace[POSEIDON_COL_ADDR_LEFT_HI] = vec![F::from_usize(HALF_DIGEST_LEN); n_rows]; + trace[POSEIDON_COL_MULTIPLICITY] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_COL_FLAG_OUT8] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_COL_ADDR_LEFT_LO] = ArenaVec::filled(F::ZERO, n_rows); + trace[POSEIDON_COL_ADDR_LEFT_HI] = ArenaVec::filled(F::from_usize(HALF_DIGEST_LEN), n_rows); fill_trace_poseidon_16(&mut trace); let air = Poseidon16Precompile::; @@ -51,7 +51,7 @@ fn test_prove_poseidon() { let time = Instant::now(); - let mut commitmed_pol = F::zero_vec((n_cols << log_n_rows).next_power_of_two()); + let mut commitmed_pol = ArenaVec::filled(F::ZERO, (n_cols << log_n_rows).next_power_of_two()); for (i, col) in trace.iter().enumerate() { commitmed_pol[i << log_n_rows..(i + 1) << log_n_rows].copy_from_slice(col); } @@ -61,10 +61,10 @@ fn test_prove_poseidon() { let alpha = prover_state.sample(); let air_alpha_powers: Vec = alpha.powers().collect_n(n_constraints); // BUS=false => `logup_alphas_eq_poly` is unused; only `alpha_powers` matter. - let extra_data = ExtraDataForBuses::new(Vec::new(), air_alpha_powers); + let extra_data = ExtraDataForBuses::new(&[], air_alpha_powers); prover_state.duplex(); let eq_factor: Vec = prover_state.sample_vec(log_n_rows); - let column_refs: Vec<&[F]> = trace.iter().map(Vec::as_slice).collect(); + let column_refs: Vec<&[F]> = trace.iter().map(|c| c.as_slice()).collect(); let packed = MleGroupRef::::Base(column_refs).pack(); let mut sessions: Vec + '_>> = vec![Box::new(AirSumcheckSession::new( @@ -104,7 +104,7 @@ fn test_prove_poseidon() { let alpha = verifier_state.sample(); let air_alpha_powers: Vec = alpha.powers().collect_n(n_constraints); - let extra_data = ExtraDataForBuses::new(Vec::new(), air_alpha_powers); + let extra_data = ExtraDataForBuses::new(&[], air_alpha_powers); verifier_state.duplex(); let eq_factor_v: Vec = verifier_state.sample_vec(log_n_rows); diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 7f882321e..7b2d3fc3d 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -8,6 +8,7 @@ field = { path = "../backend/field", package = "mt-field" } koala-bear = { path = "../backend/koala-bear", package = "mt-koala-bear" } poly = { path = "../backend/poly", package = "mt-poly" } sumcheck = { path = "../backend/sumcheck", package = "mt-sumcheck" } +zk-alloc.workspace = true fiat-shamir = { path = "../backend/fiat-shamir", package = "mt-fiat-shamir" } utils = { path = "../backend/utils", package = "utils" } symetric = { path = "../backend/symetric", package = "mt-symetric" } diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index b64bb3502..f74bcb7fe 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -4,6 +4,7 @@ use fiat_shamir::FSProver; use field::{ExtensionField, TwoAdicField}; use poly::*; use tracing::{info_span, instrument}; +use zk_alloc::ArenaVec; use crate::*; @@ -35,11 +36,11 @@ impl>> MerkleData { match self { MerkleData::Base(prover_data) => { let (leaf, proof) = merkle_open::, PF>(prover_data, index); - (MleOwned::Base(leaf), proof) + (MleOwned::Base(ArenaVec::from_slice(&leaf)), proof) } MerkleData::Extension(prover_data) => { let (leaf, proof) = merkle_open::, EF>(prover_data, index); - (MleOwned::Extension(leaf), proof) + (MleOwned::Extension(ArenaVec::from_slice(&leaf)), proof) } } } diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index c523cf379..4d015d988 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -28,6 +28,9 @@ use std::sync::RwLock; use field::PackedValue; use field::{BasedVectorSpace, Field, PackedField, TwoAdicField}; use itertools::Itertools; +use zk_alloc::ArenaVec; + +use crate::utils::{flatten_to_base_arena, reconstitute_from_base_arena}; use tracing::instrument; use utils::{as_base_slice, log2_strict_usize}; @@ -75,7 +78,10 @@ impl EvalsDft where F: TwoAdicField, { - pub(crate) fn dft_batch_by_evals(&self, mut mat: Matrix) -> Matrix { + pub(crate) fn dft_batch_by_evals + AsMut<[F]> + Send + Sync>( + &self, + mut mat: Matrix, + ) -> Matrix { let h = mat.height(); let w = mat.width(); let log_h = log2_strict_usize(h); @@ -100,7 +106,7 @@ where // We also divide by the height of the matrix while the data is nicely partitioned // on each core. par_initial_layers( - &mut mat.values, + mat.values.as_mut(), chunk_size, &root_table[root_table.len() - log_num_par_rows..], w, @@ -145,12 +151,12 @@ where #[instrument(skip_all)] pub(crate) fn dft_algebra_batch_by_evals + Clone + Send + Sync>( &self, - mat: Matrix, - ) -> Matrix { + mat: Matrix>, + ) -> Matrix> { let init_width = mat.width(); - let base_mat = Matrix::new(V::flatten_to_base(mat.values), init_width * V::DIMENSION); + let base_mat = Matrix::new(flatten_to_base_arena::(mat.values), init_width * V::DIMENSION); let base_dft_output = self.dft_batch_by_evals(base_mat); - Matrix::new(V::reconstitute_from_base(base_dft_output.values), init_width) + Matrix::new(reconstitute_from_base_arena::(base_dft_output.values), init_width) } } @@ -545,6 +551,7 @@ mod tests { use koala_bear::{KoalaBear, QuinticExtensionFieldKB}; use poly::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; + use zk_alloc::ArenaVec; use crate::*; @@ -560,7 +567,7 @@ mod tests { let evals = (0..(1 << n_vars)).map(|_| rng.random()).collect::>(); let dft = EvalsDft::::default(); - let evals_dft = dft.dft_algebra_batch_by_evals(Matrix::new(evals.clone(), 1)); + let evals_dft = dft.dft_algebra_batch_by_evals(Matrix::new(ArenaVec::from_slice(&evals), 1)); let fft_values = evals_dft.values; for _ in 0..10 { let i = rng.random_range(0..(1 << n_vars)); diff --git a/crates/whir/src/matrix.rs b/crates/whir/src/matrix.rs index c0b39ca43..55a252aa8 100644 --- a/crates/whir/src/matrix.rs +++ b/crates/whir/src/matrix.rs @@ -1,9 +1,6 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::marker::PhantomData; use field::PackedValue; @@ -28,13 +25,13 @@ pub struct Matrix> { _phantom: PhantomData, } -impl + Send + Sync> Matrix { +impl + Send + Sync> Matrix { /// Create a new dense matrix of the given dimensions, backed by the given storage. /// /// It is undefined behavior to create a matrix such that `values.len() % width != 0`. #[must_use] pub fn new(values: V, width: usize) -> Self { - debug_assert!(values.borrow().len().is_multiple_of(width)); + debug_assert!(values.as_ref().len().is_multiple_of(width)); Self { values, width, @@ -49,14 +46,14 @@ impl + Send + Sync> Matrix { #[inline] pub fn height(&self) -> usize { - self.values.borrow().len().checked_div(self.width).unwrap_or(0) + self.values.as_ref().len().checked_div(self.width).unwrap_or(0) } pub fn as_view_mut(&mut self) -> MatrixViewMut<'_, T> where - V: BorrowMut<[T]>, + V: AsMut<[T]>, { - MatrixViewMut::new(self.values.borrow_mut(), self.width) + MatrixViewMut::new(self.values.as_mut(), self.width) } /// Row `r` as an iterator over its values, or `None` if out of bounds. @@ -64,7 +61,7 @@ impl + Send + Sync> Matrix { pub fn row(&self, r: usize) -> Option + '_> { (r < self.height()).then(|| { let start = r * self.width; - self.values.borrow()[start..start + self.width].iter().cloned() + self.values.as_ref()[start..start + self.width].iter().cloned() }) } @@ -86,7 +83,7 @@ impl + Send + Sync> Matrix { let width = self.width; debug_assert!(effective_width <= width); debug_assert!(r + P::WIDTH <= self.height()); - let values = self.values.borrow(); + let values = self.values.as_ref(); let base = r * width; (0..n_leading_zeros).map(|_| P::default()).chain( (0..effective_width) diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 44e973d26..ee29acfdc 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -15,34 +15,38 @@ use poly::*; use symetric::merkle::unpack_array; use tracing::instrument; use utils::log2_ceil_usize; +use zk_alloc::ArenaVec; use crate::Dimensions; use crate::Matrix; +use crate::utils::flatten_to_base_arena; pub use symetric::DIGEST_ELEMS; pub(crate) type RoundMerkleTree = WhirMerkleTree; #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_commit>( - matrix: Matrix, + matrix: Matrix>, full_n_cols: usize, effective_n_cols: usize, ) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { - let matrix = unsafe { std::mem::transmute::<_, Matrix>(matrix) }; + let matrix = unsafe { + std::mem::transmute::<_, Matrix>>(matrix) + }; let dim = >::DIMENSION; let dft_base_width = matrix.width * dim; let full_base_width = full_n_cols * dim; let effective_base_width = effective_n_cols * dim; - let base_values = QuinticExtensionFieldKB::flatten_to_base(matrix.values); - let base_matrix = Matrix::::new(base_values, dft_base_width); + let base_values = flatten_to_base_arena::(matrix.values); + let base_matrix = Matrix::new(base_values, dft_base_width); let tree = build_merkle_tree_koalabear(base_matrix, full_base_width, effective_base_width); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { - let matrix = unsafe { std::mem::transmute::<_, Matrix>(matrix) }; + let matrix = unsafe { std::mem::transmute::<_, Matrix>>(matrix) }; let tree = build_merkle_tree_koalabear(matrix, full_n_cols, effective_n_cols); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; @@ -55,7 +59,7 @@ pub(crate) fn merkle_commit>( #[instrument(name = "build merkle tree", skip_all)] fn build_merkle_tree_koalabear( - leaf: Matrix, + leaf: Matrix>, full_base_width: usize, effective_base_width: usize, ) -> RoundMerkleTree { @@ -71,7 +75,7 @@ fn build_merkle_tree_koalabear( ); let packed_state: [PFPacking; 16] = std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); - let first_layer = first_digest_layer_with_initial_state::, _, DIGEST_ELEMS, 16, 8>( + let first_layer = first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, 16, 8>( &perm, &leaf, &packed_state, @@ -154,7 +158,7 @@ pub(crate) fn merkle_verify>( #[derive(Debug, Clone)] pub struct WhirMerkleTree { - pub(crate) leaf: Matrix, + pub(crate) leaf: Matrix>, pub(crate) tree: symetric::merkle::MerkleTree, full_leaf_base_width: usize, } @@ -175,14 +179,22 @@ impl } #[instrument(name = "first digest layer", level = "debug", skip_all)] -fn first_digest_layer_with_initial_state( +fn first_digest_layer_with_initial_state< + P, + Perm, + LV, + const DIGEST_ELEMS: usize, + const WIDTH: usize, + const RATE: usize, +>( perm: &Perm, - matrix: &Matrix, + matrix: &Matrix, packed_initial_state: &[P; WIDTH], effective_base_width: usize, ) -> Vec<[P::Value; DIGEST_ELEMS]> where P: PackedValue + Default, + LV: AsRef<[P::Value]> + Send + Sync, P::Value: Default + Copy + Send + Sync, Perm: koala_bear::symmetric::Permutation<[P::Value; WIDTH]> + koala_bear::symmetric::Permutation<[P; WIDTH]>, { diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index cb773528e..abda2eaf8 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -7,6 +7,7 @@ use field::{ExtensionField, Field, TwoAdicField}; use poly::*; use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; use tracing::{info_span, instrument}; +use zk_alloc::ArenaVec; use crate::{config::WhirConfig, *}; @@ -187,7 +188,7 @@ where // Convert evaluations to coefficient form and send to the verifier. let mut coeffs = match &round_state.sumcheck_prover.evals { MleOwned::Extension(evals) => evals.clone(), - MleOwned::ExtensionPacked(evals) => unpack_extension::(evals), + MleOwned::ExtensionPacked(evals) => unpack_extension(evals), _ => unreachable!(), }; evals_to_coeffs(&mut coeffs); @@ -211,14 +212,14 @@ where match answer { MleOwned::Base(leaf) => { base_paths.push(MerklePath { - leaf_data: leaf, + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); } MleOwned::Extension(leaf) => { ext_paths.push(MerklePath { - leaf_data: leaf, + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); @@ -291,14 +292,14 @@ fn open_merkle_tree_at_challenges>>( match &answer { MleOwned::Base(leaf) => { base_paths.push(MerklePath { - leaf_data: leaf.clone(), + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); } MleOwned::Extension(leaf) => { ext_paths.push(MerklePath { - leaf_data: leaf.clone(), + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); @@ -514,7 +515,7 @@ where } #[instrument(skip_all, fields(num_constraints = statements.len(), n_vars = statements[0].total_num_variables))] -fn combine_statement(statements: &[SparseStatement], gamma: EF) -> (Vec>, EF) +fn combine_statement(statements: &[SparseStatement], gamma: EF) -> (ArenaVec>, EF) where EF: ExtensionField>, { @@ -527,13 +528,13 @@ where !s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables }; - let mut combined_weights: Vec>; + let mut combined_weights: ArenaVec>; let mut combined_sum = EF::ZERO; let mut gamma_pow = EF::ONE; let start_idx = match statements { [a, b, ..] if is_full(a) && is_full(b) => { - combined_weights = unsafe { uninitialized_vec(out_len) }; + combined_weights = unsafe { ArenaVec::uninitialized(out_len) }; let sa = gamma_pow; let sb = gamma_pow * gamma; combined_sum = a.values[0].value * sa + b.values[0].value * sb; @@ -542,7 +543,7 @@ where 2 } [a, ..] if is_full(a) => { - combined_weights = unsafe { uninitialized_vec(out_len) }; + combined_weights = unsafe { ArenaVec::uninitialized(out_len) }; let sa = gamma_pow; combined_sum = a.values[0].value * sa; gamma_pow *= gamma; @@ -550,7 +551,7 @@ where 1 } _ => { - combined_weights = EFPacking::::zero_vec(out_len); + combined_weights = unsafe { ArenaVec::zeroed(out_len) }; 0 } }; @@ -563,7 +564,7 @@ where gamma_pow *= gamma; } } else { - let inner_poly = if smt.is_next { + let inner_poly: ArenaVec> = if smt.is_next { let next = matrix_next_mle_folded(&smt.point.0); pack_extension(&next) } else { diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index 9a8ec359e..1a64ee7db 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -4,6 +4,7 @@ use fiat_shamir::{ChallengeSampler, FSProver}; use field::BasedVectorSpace; use field::Field; use field::PackedValue; +use field::PrimeCharacteristicRing; use field::{ExtensionField, TwoAdicField}; use poly::*; use std::any::{Any, TypeId}; @@ -11,10 +12,52 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; use tracing::instrument; use utils::log2_strict_usize; +use zk_alloc::ArenaVec; use crate::EvalsDft; use crate::Matrix; +#[inline] +#[must_use] +pub(crate) fn flatten_to_base_arena>( + vec: ArenaVec, +) -> ArenaVec { + const { + assert!(align_of::() == align_of::()); + assert!(size_of::() == V::DIMENSION * size_of::()); + } + let (ptr, len, cap) = vec.into_raw_parts(); + unsafe { ArenaVec::from_raw_parts(ptr.cast::(), len * V::DIMENSION, cap * V::DIMENSION) } +} + +#[inline] +#[must_use] +pub(crate) fn reconstitute_from_base_arena + Clone>( + vec: ArenaVec, +) -> ArenaVec { + const { + assert!(align_of::() == align_of::()); + assert!(size_of::() == V::DIMENSION * size_of::()); + } + let d = V::DIMENSION; + assert!( + vec.len().is_multiple_of(d), + "ArenaVec length (got {}) must be a multiple of the extension field dimension ({}).", + vec.len(), + d + ); + let new_len = vec.len() / d; + if vec.capacity().is_multiple_of(d) { + let (ptr, _len, cap) = vec.into_raw_parts(); + unsafe { ArenaVec::from_raw_parts(ptr.cast::(), new_len, cap / d) } + } else { + let slice_ref = unsafe { std::slice::from_raw_parts(vec.as_ptr().cast::(), new_len) }; + let mut out = ArenaVec::with_capacity(new_len); + out.extend_from_slice(slice_ref); + out + } +} + pub(crate) fn get_challenge_stir_queries>( folded_domain_size: usize, num_queries: usize, @@ -56,13 +99,13 @@ where } pub(crate) enum DftInput { - Base(Vec>), - Extension(Vec), + Base(ArenaVec>), + Extension(ArenaVec), } pub(crate) enum DftOutput { - Base(Matrix>), - Extension(Matrix), + Base(Matrix, ArenaVec>>), + Extension(Matrix>), } pub(crate) fn reorder_and_dft>>( @@ -127,7 +170,7 @@ fn prepare_evals_for_fft_unpacked( folding_factor: usize, log_inv_rate: usize, dft_n_cols: usize, -) -> Vec { +) -> ArenaVec { assert!(evals.len().is_multiple_of(1 << folding_factor)); let n_blocks = 1 << folding_factor; let full_len = evals.len() << log_inv_rate; @@ -135,7 +178,7 @@ fn prepare_evals_for_fft_unpacked( let log_block_size = log2_strict_usize(block_size); let out_len = block_size * dft_n_cols; - let mut out: Vec = unsafe { uninitialized_vec(out_len) }; + let mut out: ArenaVec = unsafe { ArenaVec::uninitialized(out_len) }; if block_size == 0 || dft_n_cols == 0 { return out; } @@ -163,7 +206,7 @@ fn prepare_evals_for_fft_packed_extension>>( evals: &[EFPacking], folding_factor: usize, log_inv_rate: usize, -) -> Vec { +) -> ArenaVec { let log_packing = packing_log_width::(); assert!((evals.len() << log_packing).is_multiple_of(1 << folding_factor)); let n_blocks = 1 << folding_factor; @@ -172,7 +215,7 @@ fn prepare_evals_for_fft_packed_extension>>( let log_block_size = log2_strict_usize(block_size); let packing_mask = (1 << log_packing) - 1; - let mut out: Vec = unsafe { uninitialized_vec(full_len) }; + let mut out: ArenaVec = unsafe { ArenaVec::uninitialized(full_len) }; if block_size == 0 || n_blocks == 0 { return out; } diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 217a19c5d..22c69079d 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -10,6 +10,7 @@ use poly::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; +use zk_alloc::ArenaVec; type F = KoalaBear; type EF = QuinticExtensionFieldKB; @@ -104,7 +105,7 @@ fn test_run_whir() { precompute_dft_twiddles::(1 << F::TWO_ADICITY); - let polynomial: MleOwned = MleOwned::Base(polynomial); + let polynomial: MleOwned = MleOwned::Base(ArenaVec::from_iter(polynomial)); let time = Instant::now(); let witness = params.commit(&mut prover_state, &polynomial, num_coeffs); diff --git a/src/lib.rs b/src/lib.rs index 15ce81523..521bc65d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,14 @@ pub use xmss::{MESSAGE_LEN_FE, XmssPublicKey, XmssSecretKey, XmssSignature, xmss pub type F = KoalaBear; /// Call once before proving. +/// +/// # Safety +/// Never generate two proofs concurrently in one process. +/// +/// (The arena allocator has a single shared region per process, so concurrent proving corrupts each proof's buffers) +/// Use separate processes to parallelize pub fn setup_prover() { + zk_alloc::enable_arena(); parallel::init(); rec_aggregation::init_aggregation_bytecode(); precompute_dft_twiddles::(1 << 24); @@ -22,13 +29,3 @@ pub fn setup_prover() { pub fn setup_verifier() { rec_aggregation::init_aggregation_bytecode(); } - -/// Bump-arena allocator. -/// -/// To enable, set it as the `#[global_allocator]` in your binary. Then bracket each proving -/// call with [`begin_phase`] / [`end_phase`] and **clone the outputs after [`end_phase`]** so -/// the cloned copy lands in the system allocator before the next [`begin_phase`] resets the -/// arena slabs. -/// -/// See `tests/test_zk_alloc.rs` for a runnable end-to-end example. -pub use zk_alloc::{ZkAllocator, begin_phase, end_phase}; diff --git a/src/main.rs b/src/main.rs index c77e12ea5..c89f23021 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,6 @@ use clap::Parser; use rec_aggregation::benchmark::{AggregationTopology, biggest_leaf, run_aggregation_benchmark}; -#[cfg(not(feature = "standard-alloc"))] -#[global_allocator] -static ALLOC: zk_alloc::ZkAllocator = zk_alloc::ZkAllocator; - #[derive(Parser)] enum Cli { #[command(about = "Aggregate XMSS")] diff --git a/tests/test_multisignatures.rs b/tests/test_multisignatures.rs index a36874243..f7445c167 100644 --- a/tests/test_multisignatures.rs +++ b/tests/test_multisignatures.rs @@ -1,3 +1,4 @@ +use std::sync::{Mutex, MutexGuard}; use std::time::Instant; use lean_multisig::{ @@ -15,6 +16,12 @@ use xmss::{ xmss_key_gen, xmss_sign, xmss_verify, }; +static ARENA_TEST_LOCK: Mutex<()> = Mutex::new(()); + +fn serialize_arena_tests() -> MutexGuard<'static, ()> { + ARENA_TEST_LOCK.lock().unwrap() +} + #[test] fn test_xmss_signature() { let start_slot = 111; @@ -30,6 +37,7 @@ fn test_xmss_signature() { #[test] fn test_aggregation() { + let _arena_guard = serialize_arena_tests(); for n_signatures in [1, 2, 4, 8, 16, 32, 64, 128] { let topology = AggregationTopology { raw_xmss: n_signatures, @@ -43,6 +51,7 @@ fn test_aggregation() { #[test] fn test_single_message_aggregation() { + let _arena_guard = serialize_arena_tests(); setup_prover(); let log_inv_rate = 2; // [1, 2, 3 or 4] (lower = faster but bigger proofs) @@ -75,6 +84,7 @@ fn test_single_message_aggregation() { #[test] fn test_multi_message_aggregation() { + let _arena_guard = serialize_arena_tests(); setup_prover(); let log_inv_rate = 2; // [1, 2, 3 or 4] (lower = faster but bigger proofs) diff --git a/tests/test_zk_alloc.rs b/tests/test_zk_alloc.rs deleted file mode 100644 index 3db69e6c6..000000000 --- a/tests/test_zk_alloc.rs +++ /dev/null @@ -1,28 +0,0 @@ -use lean_multisig::{ - ZkAllocator, aggregate_single_message_signatures, begin_phase, end_phase, setup_prover, - verify_single_message_aggregate, -}; -use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; - -#[global_allocator] -static ALLOC: ZkAllocator = ZkAllocator; - -#[test] -#[allow(clippy::redundant_clone)] -fn test_aggregation_with_zk_alloc() { - setup_prover(); - - let log_inv_rate = 2; - let message = message_for_benchmark(); - let slot: u32 = BENCHMARK_SLOT; - let signatures = get_benchmark_signatures(); - let raw_xmss = signatures[0..6].to_vec(); - - begin_phase(); - let aggregated = aggregate_single_message_signatures(&[], raw_xmss, message, slot, log_inv_rate).unwrap(); - end_phase(); - // IMPORTANT: clone to move the data out of the arena memory - let aggregated = aggregated.clone(); - - verify_single_message_aggregate(&aggregated).unwrap(); -}