diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 3c7b9d34..0659cdb6 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -307,11 +307,8 @@ pub fn compute_eval_eq_base_packed( } #[inline] -pub fn compute_eval_eq_base_packed_batched( - evals: &[MultilinearPoint], - out: &mut [EF::ExtensionPacking], - scalars: &[EF], -) where +pub fn compute_eval_eq_base_batched(evals: &[MultilinearPoint], out: &mut [EF], scalars: &[EF]) +where F: Field, EF: ExtensionField, { @@ -321,22 +318,21 @@ pub fn compute_eval_eq_base_packed_batched( } let n = evals[0].len(); - let packing_width = F::Packing::WIDTH; - let log_packing_width = log2_strict_usize(packing_width); + let log_packing_width = log2_strict_usize(F::Packing::WIDTH); assert!(log_packing_width <= n); - assert_eq!(out.len(), 1 << (n - log_packing_width)); + assert_eq!(out.len(), 1 << n); let k = n.min(LOG_BATCHED_TILE_SIZE); if k <= log_packing_width || k >= n { for (eval, &scalar) in evals.iter().zip(scalars) { - compute_eval_eq_base_packed::(eval, out, scalar); + compute_eval_eq_base::(eval, out, scalar); } return; } let n_prefix_levels = n - k; - let tile_packed_size = 1 << (k - log_packing_width); + let tile_size = 1 << k; let per_query: Vec<_> = evals .iter() @@ -350,19 +346,14 @@ pub fn compute_eval_eq_base_packed_batched( }) .collect(); - // `out` already splits into `2^n_prefix_levels` tiles — many more than there are - // workers — so the pool's task counter load-balances these directly. - parallel::par_chunks_mut(out, tile_packed_size, |tile_idx, out_tile| { + // `out` splits into `2^n_prefix_levels` tiles — many more than there are workers — + // so the pool's task counter load-balances these directly. + parallel::par_chunks_mut(out, tile_size, |tile_idx, out_tile| { for (eq_prefix, middle, eq_suffix) in &per_query { // Here e could precompute the eq poly, trading some memory for less computation // (2x faster on M4 max, but 2x slower on machines with smaller caches. // TODO implement both and choose based on cache size?) - base_eval_eq_packed_with_packed_output::( - middle, - out_tile, - *eq_suffix, - EF::ExtensionPacking::from(eq_prefix[tile_idx]), - ); + base_eval_eq_packed::(middle, out_tile, *eq_suffix, eq_prefix[tile_idx]); } }); } diff --git a/crates/backend/poly/src/point.rs b/crates/backend/poly/src/point.rs index 5af8ed6b..89da5ba7 100644 --- a/crates/backend/poly/src/point.rs +++ b/crates/backend/poly/src/point.rs @@ -106,6 +106,15 @@ where } } +impl MultilinearPoint { + #[must_use] + pub fn reversed(&self) -> Self { + let mut v = self.0.clone(); + v.reverse(); + Self(v) + } +} + impl From> for MultilinearPoint { fn from(v: Vec) -> Self { Self(v) diff --git a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py index 253b95a6..40e530cb 100644 --- a/crates/rec_aggregation/zkdsl_implem/whir.py +++ b/crates/rec_aggregation/zkdsl_implem/whir.py @@ -126,17 +126,23 @@ def whir_open( folding_randomness_global = Array(n_vars * DIM) - start_buf = Array(n_rounds + 2) - start_buf[0] = folding_randomness_global + # WHIR sumcheck folds LSB-first, so chronological challenges are in reverse polynomial-var + # order: chronological challenge #c is written to global position (n_vars - 1 - c), so the + # cumulative reads as [x_0, x_1, ..., x_{n_vars-1}]. `chrono_buf` carries the running + # chronological index across the `range` loop (range loops may not mutate outer-scope vars). + chrono_buf = Array(n_rounds + 2) + chrono_buf[0] = 0 for i in range(0, n_rounds + 1): - start: Mut = start_buf[i] + chrono: Mut = chrono_buf[i] for j in range(0, folding_factors[i]): - copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) - start += folding_factors[i] * DIM - start_buf[i + 1] = start - start = start_buf[n_rounds + 1] + target_pos = n_vars - 1 - (chrono + j) + copy_5(all_folding_randomness[i] + j * DIM, folding_randomness_global + target_pos * DIM) + chrono += folding_factors[i] + chrono_buf[i + 1] = chrono + chrono = chrono_buf[n_rounds + 1] for j in range(0, n_final_vars): - copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM) + target_pos = n_vars - 1 - (chrono + j) + copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, folding_randomness_global + target_pos * DIM) all_ood_recovered_evals = Array(num_oods[0] * DIM) for i in range(0, num_oods[0]): @@ -152,6 +158,9 @@ def whir_open( num_oods[0], ) + # LSB-fold: at round i the polynomial's remaining vars are [x_0, ..., x_{n_vars_remaining-1}], + # i.e. the FIRST n_vars_remaining entries of folding_randomness_global (no pointer advance). + # eval_carry carries (n_vars_remaining, folding_randomness ptr, running sum) across the loop. eval_carry = Array((n_rounds + 1) * 3) eval_carry[0] = n_vars eval_carry[1] = folding_randomness_global @@ -164,12 +173,9 @@ def whir_open( n_vars_remaining -= folding_factors[i] my_ood_recovered_evals = Array(num_oods[i + 1] * DIM) combination_randomness_powers = all_combination_randomness_powers[i] - my_folding_randomness += folding_factors[i] * DIM for j in range(0, num_oods[i + 1]): expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars_remaining) - poly_eq_extension_dynamic_to( - expanded_from_univariate, my_folding_randomness, my_ood_recovered_evals + j * DIM, n_vars_remaining - ) + poly_eq_extension_dynamic_to(expanded_from_univariate, folding_randomness_global, my_ood_recovered_evals + j * DIM, n_vars_remaining) summed_ood = Array(DIM) dot_product_ee_dynamic( my_ood_recovered_evals, @@ -182,7 +188,7 @@ def whir_open( circle_value_i = all_circle_values[i] for j in range(0, num_queries[i]): # unroll ? expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars_remaining) - poly_eq_base_extension_to(expanded_from_univariate, my_folding_randomness, s6s + j * DIM, n_vars_remaining) + poly_eq_base_extension_to(expanded_from_univariate, folding_randomness_global, s6s + j * DIM, n_vars_remaining) s7 = Array(DIM) dot_product_ee_dynamic( s6s, @@ -196,10 +202,18 @@ def whir_open( eval_carry[base + 4] = my_folding_randomness eval_carry[base + 5] = s s = eval_carry[n_rounds * 3 + 2] + + # WHIR sumcheck folds LSB-first: final_sumcheck challenges are [r_1=x_{m-1}, ..., r_m=x_0]. + # eval_multilinear_coeffs_rev computes f(x_j = point[j]); for LSB-fold we need + # f(x_j = r_{m-j}) = point[j] = r_{j+1} = x_{m-j-1} which is wrong, so reverse first. + final_sumcheck_chals_rev = Array(n_final_vars * DIM) + final_sumcheck_chals = all_folding_randomness[n_rounds + 1] + for j in range(0, n_final_vars): + copy_5(final_sumcheck_chals + (n_final_vars - 1 - j) * DIM, final_sumcheck_chals_rev + j * DIM) final_value = match_range( n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), - lambda n: eval_multilinear_coeffs_rev(final_coeffcients, all_folding_randomness[n_rounds + 1], n), + lambda n: eval_multilinear_coeffs_rev(final_coeffcients, final_sumcheck_chals_rev, n), ) # copy_5(mul_extension_ret(s, final_value), end_sum); @@ -376,7 +390,12 @@ def sample_stir_indexes_and_fold( folds = Array(num_queries * DIM) - poly_eq = compute_eq_mle_extension_dynamic(folding_randomness, folding_factor) + # WHIR sumcheck folds LSB-first; the leaf is laid out so its first var is the polynomial's + # last LSB-folded var. evaluate (poly_eq) is MSB-first, so reverse the per-round challenges. + folding_randomness_reversed = Array(folding_factor * DIM) + for j in range(0, folding_factor): + copy_5(folding_randomness + (folding_factor - 1 - j) * DIM, folding_randomness_reversed + j * DIM) + poly_eq = compute_eq_mle_extension_dynamic(folding_randomness_reversed, folding_factor) if merkle_leaves_in_basefield == 1: for i in range(0, num_queries): diff --git a/crates/sub_protocols/src/quotient_gkr/layers.rs b/crates/sub_protocols/src/quotient_gkr/layers.rs index eb8a769e..6951dc4a 100644 --- a/crates/sub_protocols/src/quotient_gkr/layers.rs +++ b/crates/sub_protocols/src/quotient_gkr/layers.rs @@ -84,7 +84,7 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { } } - pub fn materialise_in_full(self) -> (Vec, Vec) { + pub(super) fn materialise_in_full(self) -> (Vec, Vec) { let natural = match self { Self::Natural { .. } => self, other => other.convert_to_natural(), diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index b64bb350..1af6ca36 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -65,24 +65,25 @@ where &self, prover_state: &mut impl FSProver, polynomial: &MleOwned, - actual_data_len: usize, // polynomial[actual_data_len..] is zero + _actual_data_len: usize, // polynomial[_actual_data_len..] is zero ) -> Witness { let n_blocks = 1usize << self.folding_factor.at_round(0); - let evals_len = 1usize << self.num_variables; - let effective_n_cols = actual_data_len.div_ceil(evals_len / n_blocks); - // DFT matrix width: skip as many zero columns as possible, aligned to packing (SIMD) - let dft_n_cols = effective_n_cols.next_multiple_of(packing_width::()).min(n_blocks); + // NOTE: main's zero-COLUMN skip optimization (dft_n_cols / effective_n_cols < n_blocks) + // assumed an MSB-cols matrix layout, where the polynomial's zero suffix lands in trailing + // columns. The split-eq LSB-cols layout puts the zero suffix in trailing ROWS instead, so + // skipping columns would drop live data. We commit all columns (no skip): same root, just + // without the prover-side speedup. (The branch optimized this via row-skip in the DFT.) let folded_matrix = info_span!("FFT").in_scope(|| { reorder_and_dft( &polynomial.by_ref(), self.folding_factor.at_round(0), self.starting_log_inv_rate, - dft_n_cols, + n_blocks, ) }); - let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks, effective_n_cols); + let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks, n_blocks); prover_state.add_base_scalars(&root); diff --git a/crates/whir/src/lib.rs b/crates/whir/src/lib.rs index a3e84f9c..9ed139a1 100644 --- a/crates/whir/src/lib.rs +++ b/crates/whir/src/lib.rs @@ -27,6 +27,9 @@ pub(crate) use utils::*; mod matrix; pub(crate) use matrix::*; +mod svo; +pub(crate) use svo::*; + #[derive(Clone, Debug)] pub struct SparseStatement { pub total_num_variables: usize, diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index cb773528..a80c2f5a 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -1,11 +1,11 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). +use std::ops::{Mul, Sub}; + use ::utils::log2_strict_usize; use fiat_shamir::{FSProver, MerklePath, ProofResult}; -use field::PrimeCharacteristicRing; -use field::{ExtensionField, Field, TwoAdicField}; +use field::{ExtensionField, Field, PrimeCharacteristicRing, TwoAdicField}; use poly::*; -use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; use tracing::{info_span, instrument}; use crate::{config::WhirConfig, *}; @@ -60,23 +60,18 @@ where prover_state: &mut impl FSProver, round_state: &mut RoundState, ) -> ProofResult<()> { - let folded_evaluations = &round_state.sumcheck_prover.evals; - let num_variables = self.num_variables - self.folding_factor.total_number(round_index); - - // Base case: final round reached if round_index == self.n_rounds() { return self.final_round(round_index, prover_state, round_state); } + let num_variables = self.num_variables - self.folding_factor.total_number(round_index); let round_params = &self.round_parameters[round_index]; - - // Compute the folding factors for later use let folding_factor_next = self.folding_factor.at_round(round_index + 1); - // Compute polynomial evaluations and build Merkle tree let domain_reduction = 1 << self.rs_reduction_factor(round_index); let new_domain_size = round_state.domain_size / domain_reduction; let inv_rate = new_domain_size >> num_variables; + let folded_evaluations = &round_state.sumcheck_prover.evals; let folded_matrix = info_span!("FFT").in_scope(|| { reorder_and_dft( &folded_evaluations.by_ref(), @@ -87,11 +82,11 @@ where }); let full = 1 << folding_factor_next; + // Round commitments have no zero-column suffix, so effective == full. let (prover_data, root) = MerkleData::build(folded_matrix, full, full); prover_state.add_base_scalars(&root); - // Handle OOD (Out-Of-Domain) samples let (ood_points, ood_answers) = sample_ood_points::(prover_state, round_params.ood_samples, num_variables, |point| { info_span!("ood evaluation").in_scope(|| folded_evaluations.evaluate(point)) @@ -108,37 +103,14 @@ where round_index, )?; - let folding_randomness = round_state.folding_randomness( - self.folding_factor.at_round(round_index) + round_state.commitment_merkle_prover_data_b.is_some() as usize, - ); - - let stir_evaluations = if let Some(data_b) = &round_state.commitment_merkle_prover_data_b { - let answers_a = - open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes); - let answers_b = open_merkle_tree_at_challenges(data_b, prover_state, &stir_challenges_indexes); - let mut stir_evaluations = Vec::new(); - for (answer_a, answer_b) in answers_a.iter().zip(&answers_b) { - let vars_a = answer_a.by_ref().n_vars(); - let vars_b = answer_b.by_ref().n_vars(); - let a_trunc = folding_randomness[1..].to_vec(); - let eval_a = answer_a.evaluate(&MultilinearPoint(a_trunc)); - let b_trunc = folding_randomness[vars_a - vars_b + 1..].to_vec(); - let eval_b = answer_b.evaluate(&MultilinearPoint(b_trunc)); - let last_fold_rand_a = folding_randomness[0]; - let last_fold_rand_b = folding_randomness[..vars_a - vars_b + 1] - .iter() - .map(|&x| EF::ONE - x) - .product::(); - stir_evaluations.push(eval_a * last_fold_rand_a + eval_b * last_fold_rand_b); - } + let folding_randomness = round_state.folding_randomness(self.folding_factor.at_round(round_index)); + let folding_randomness_reversed = folding_randomness.reversed(); - stir_evaluations - } else { + let stir_evaluations: Vec = open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes) .iter() - .map(|answer| answer.evaluate(&folding_randomness)) - .collect() - }; + .map(|answer| answer.evaluate(&folding_randomness_reversed)) + .collect(); // Randomness for combination prover_state.duplex(); @@ -160,7 +132,6 @@ where ); let next_folding_randomness = round_state.sumcheck_prover.run_sumcheck_many_rounds( - None, prover_state, folding_factor_next, round_params.folding_pow_bits, @@ -168,12 +139,10 @@ where round_state.randomness_vec.extend_from_slice(&next_folding_randomness.0); - // Update round state round_state.domain_size = new_domain_size; round_state.next_domain_gen = PF::::two_adic_generator(log2_strict_usize(new_domain_size) - folding_factor_next); round_state.merkle_prover_data = prover_data; - round_state.commitment_merkle_prover_data_b = None; Ok(()) } @@ -185,60 +154,30 @@ where round_state: &mut RoundState, ) -> ProofResult<()> { // Convert evaluations to coefficient form and send to the verifier. - let mut coeffs = match &round_state.sumcheck_prover.evals { - MleOwned::Extension(evals) => evals.clone(), - MleOwned::ExtensionPacked(evals) => unpack_extension::(evals), - _ => unreachable!(), - }; + let mut coeffs = round_state + .sumcheck_prover + .evals + .as_extension() + .expect("WHIR sumcheck stores evals as extension") + .to_vec(); evals_to_coeffs(&mut coeffs); prover_state.add_extension_scalars(&coeffs); prover_state.pow_grinding(self.final_query_pow_bits); - // Final verifier queries and answers. The indices are over the folded domain. let final_challenge_indexes = get_challenge_stir_queries( - // The size of the original domain before folding round_state.domain_size >> self.folding_factor.at_round(round_index), self.final_queries, prover_state, ); - let mut base_paths = Vec::new(); - let mut ext_paths = Vec::new(); - for challenge in final_challenge_indexes { - let (answer, sibling_hashes) = round_state.merkle_prover_data.open(challenge); - - match answer { - MleOwned::Base(leaf) => { - base_paths.push(MerklePath { - leaf_data: leaf, - sibling_hashes, - leaf_index: challenge, - }); - } - MleOwned::Extension(leaf) => { - ext_paths.push(MerklePath { - leaf_data: leaf, - sibling_hashes, - leaf_index: challenge, - }); - } - _ => unreachable!(), - } - } - if !base_paths.is_empty() { - prover_state.hint_merkle_paths_base(base_paths); - } - if !ext_paths.is_empty() { - prover_state.hint_merkle_paths_extension(ext_paths); - } + open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &final_challenge_indexes); - // Run final sumcheck if required if self.final_sumcheck_rounds > 0 { let final_folding_randomness = round_state .sumcheck_prover - .run_sumcheck_many_rounds(None, prover_state, self.final_sumcheck_rounds, 0); + .run_sumcheck_many_rounds(prover_state, self.final_sumcheck_rounds, 0); round_state.randomness_vec.extend(final_folding_randomness.0); } @@ -320,11 +259,8 @@ fn open_merkle_tree_at_challenges>>( #[derive(Debug, Clone)] pub struct SumcheckSingle>> { - /// Evaluations of the polynomial `p(X)`. pub(crate) evals: MleOwned, - /// Evaluations of the equality polynomial used for enforcing constraints. - pub(crate) weights: MleOwned, - /// Accumulated sum incorporating equality constraints. + pub(crate) weights: Vec, pub(crate) sum: EF, } @@ -332,30 +268,37 @@ impl SumcheckSingle where EF: ExtensionField>, { - #[instrument(skip_all)] - pub(crate) fn add_new_equality( + fn add_equality_inner( &mut self, - points: &[MultilinearPoint], + points: &[MultilinearPoint], evaluations: &[EF], combination_randomness: &[EF], + eval_fn: impl Fn(&[T], &mut [EF], EF), ) { assert_eq!(combination_randomness.len(), points.len()); assert_eq!(evaluations.len(), points.len()); - - points - .iter() - .zip(combination_randomness.iter()) - .for_each(|(point, &rand)| { - compute_eval_eq_packed::<_, true>(point, self.weights.as_extension_packed_mut().unwrap(), rand); - }); - + for (point, &rand) in points.iter().zip(combination_randomness) { + eval_fn(&point.0, &mut self.weights, rand); + } self.sum += combination_randomness .iter() - .zip(evaluations.iter()) + .zip(evaluations) .map(|(&rand, &eval)| rand * eval) .sum::(); } + #[instrument(skip_all)] + pub(crate) fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + evaluations: &[EF], + combination_randomness: &[EF], + ) { + self.add_equality_inner(points, evaluations, combination_randomness, |p, w, r| { + compute_eval_eq::, EF, true>(p, w, r); + }); + } + #[instrument(skip_all)] pub(crate) fn add_new_base_equality( &mut self, @@ -366,13 +309,8 @@ where assert_eq!(combination_randomness.len(), points.len()); assert_eq!(evaluations.len(), points.len()); - compute_eval_eq_base_packed_batched::, EF>( - points, - self.weights.as_extension_packed_mut().unwrap(), - combination_randomness, - ); + compute_eval_eq_base_batched::, EF>(points, &mut self.weights, combination_randomness); - // Accumulate the weighted sum (cheap, done sequentially) self.sum += combination_randomness .iter() .zip(evaluations.iter()) @@ -382,33 +320,32 @@ where fn run_sumcheck_many_rounds( &mut self, - prev_folding_scalar: Option, prover_state: &mut impl FSProver, n_rounds: usize, pow_bits: usize, ) -> MultilinearPoint { - let (challenges, folds, new_sum) = sumcheck_prove_many_rounds( - MleGroupRef::merge(&[&self.evals.by_ref(), &self.weights.by_ref()]), - prev_folding_scalar, - &ProductComputation {}, - &vec![], - None, - prover_state, - self.sum, - None, - n_rounds, - false, - pow_bits, - ); - - self.sum = new_sum; - [self.evals, self.weights] = folds.split().try_into().unwrap(); + let mut challenges = Vec::with_capacity(n_rounds); + for _ in 0..n_rounds { + let r = lsb_sumcheck_round( + self.evals.as_extension().expect("WHIR sumcheck operates on Vec"), + &self.weights, + &mut self.sum, + prover_state, + pow_bits, + ); + challenges.push(r); - challenges + let evals_ref = self.evals.as_extension().unwrap(); + let new_evals = lsb_fold(evals_ref, r); + let new_weights = lsb_fold(&self.weights, r); + self.evals = MleOwned::Extension(new_evals); + self.weights = new_weights; + } + MultilinearPoint(challenges) } #[instrument(skip_all)] - pub(crate) fn run_initial_sumcheck_rounds( + pub(crate) fn run_initial_sumcheck_rounds_svo( evals: &MleRef<'_, EF>, statement: &[SparseStatement], combination_randomness: EF, @@ -416,32 +353,271 @@ where folding_factor: usize, pow_bits: usize, ) -> (Self, MultilinearPoint) { - assert_ne!(folding_factor, 0); + let l = statement[0].total_num_variables; + let l_0 = folding_factor; - let (weights, sum) = combine_statement::(statement, combination_randomness); + assert!( + statement.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0), + "next-spill is currently unimplemented", + ); - let mut evals = evals.pack(); - let mut weights = Mle::Owned(MleOwned::ExtensionPacked(weights)); - let (challengess, new_sum, new_evals, new_weights) = run_product_sumcheck( - &evals.by_ref(), - &weights.by_ref(), - prover_state, + let relaxed_statement = relax_eq_spill_statements(statement, l_0); + + let mut sum = build_initial_sum(&relaxed_statement, combination_randomness); + + let unpacked_mle = evals.unpack(); + let unpacked_ref = unpacked_mle.by_ref(); + let f = unpacked_ref + .as_base() + .expect("WHIR committed polynomial must be base field"); + + let groups = build_all_compressed_groups::(&relaxed_statement, combination_randomness, f, l, l_0); + let accs = build_accumulators::(&groups, l_0); + + let mut challenges: Vec = Vec::with_capacity(l_0); + + let mut lagrange: Vec = vec![EF::ONE]; + while challenges.len() < l_0 { + let r = challenges.len(); + let (c0, c2) = round_message_with_tensor(r, &lagrange, &accs); + let rho = sumcheck_finish_round(c0, c2, &mut sum, prover_state, pow_bits); + challenges.push(rho); + lagrange_tensor_extend(&mut lagrange, rho); + } + + let evals_ext: Vec = fold_by_tensor::(f, &challenges); + + let weights = build_post_svo_weights(&relaxed_statement, combination_randomness, &challenges); + debug_assert_eq!(weights.len(), evals_ext.len()); + let sumcheck = Self { + evals: MleOwned::Extension(evals_ext), + weights, sum, - folding_factor, - pow_bits, + }; + (sumcheck, MultilinearPoint(challenges)) + } +} + +fn relax_eq_spill_statements(statements: &[SparseStatement], l_0: usize) -> Vec> +where + EF: ExtensionField>, +{ + let mut out: Vec> = Vec::with_capacity(statements.len()); + for smt in statements { + let m = smt.inner_num_variables(); + if smt.is_next || m >= l_0 { + out.push(smt.clone()); + continue; + } + let l = smt.total_num_variables; + let extra = l_0 - m; + let s = l - m; + debug_assert!(s >= extra); + for v in &smt.values { + let top = v.selector >> extra; + let bot = v.selector & ((1usize << extra) - 1); + let mut new_point: Vec = Vec::with_capacity(l_0); + for k in (0..extra).rev() { + new_point.push(if (bot >> k) & 1 == 1 { EF::ONE } else { EF::ZERO }); + } + new_point.extend_from_slice(&smt.point.0); + out.push(SparseStatement { + total_num_variables: l, + point: MultilinearPoint(new_point), + values: vec![SparseValue { + selector: top, + value: v.value, + }], + is_next: false, + }); + } + } + out +} + +fn build_initial_sum(statements: &[SparseStatement], gamma: EF) -> EF +where + EF: ExtensionField>, +{ + let mut combined_sum = EF::ZERO; + let mut gamma_pow = EF::ONE; + for smt in statements { + for v in &smt.values { + combined_sum += v.value * gamma_pow; + gamma_pow *= gamma; + } + } + combined_sum +} + +fn take_next_powers(gamma_pow: &mut EF, gamma: EF, k: usize) -> Vec { + let mut out = Vec::with_capacity(k); + for _ in 0..k { + out.push(*gamma_pow); + *gamma_pow *= gamma; + } + out +} + +fn build_post_svo_weights(statements: &[SparseStatement], gamma: EF, rhos: &[EF]) -> Vec +where + EF: ExtensionField>, +{ + let n = statements[0].total_num_variables; + let l_0 = rhos.len(); + assert!(l_0 <= n); + let target_size = 1usize << (n - l_0); + let mut out = EF::zero_vec(target_size); + let mut gamma_pow = EF::ONE; + + for smt in statements { + let m = smt.inner_num_variables(); + let p = &smt.point.0; + assert!( + m >= l_0, + "build_post_svo_weights requires m >= l_0 (pre-relax eq spills)" ); - evals = new_evals.into(); - weights = new_weights.into(); + let alpha_powers = take_next_powers(&mut gamma_pow, gamma, smt.values.len()); - let sumcheck = Self { - evals: evals.as_owned().unwrap(), - weights: weights.as_owned().unwrap(), - sum: new_sum, + let tail_eval: Vec = if smt.is_next { + rhos.iter().fold(matrix_next_mle_folded(p), |buf, &r| lsb_fold(&buf, r)) + } else { + let scalar_eq: EF = (0..l_0) + .map(|k| { + let (p_k, r_k) = (p[m - 1 - k], rhos[k]); + p_k * r_k + (EF::ONE - p_k) * (EF::ONE - r_k) + }) + .product(); + let tail = &p[..m - l_0]; + if tail.is_empty() { + vec![scalar_eq] + } else { + eval_eq_scaled(tail, scalar_eq) + } }; - (sumcheck, challengess) + let tail_len = tail_eval.len(); + for (v, &alpha_j) in smt.values.iter().zip(&alpha_powers) { + let base = v.selector * tail_len; + let dst = &mut out[base..base + tail_len]; + parallel::par_for_each_mut(dst, |i, o| *o += alpha_j * tail_eval[i]); + } + } + + out +} + +#[instrument(skip_all)] +fn build_all_compressed_groups( + statement: &[SparseStatement], + gamma: EF, + f: &[PF], + l: usize, + l_0: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + let mut groups: Vec> = Vec::new(); + let mut gamma_pow = EF::ONE; + for smt in statement { + let s = smt.selector_num_variables(); + assert!(s + l_0 <= l, "build_all_compressed_groups requires s + l_0 <= l"); + let sel_bits: Vec = smt.values.iter().map(|v| v.selector).collect(); + let alpha_powers = take_next_powers(&mut gamma_pow, gamma, smt.values.len()); + if smt.is_next { + groups.extend(compress_next_claim::( + f, + &sel_bits, + &smt.point.0, + &alpha_powers, + l, + l_0, + s, + )); + } else { + groups.push(compress_eq_claim::( + f, + &sel_bits, + &smt.point.0, + &alpha_powers, + l, + l_0, + s, + )); + } + } + groups +} + +fn round_coeffs_flat(evals: &[E], weights: &[EF]) -> (EF, EF) +where + EF: ExtensionField> + Mul, + E: Copy + Send + Sync + Sub, +{ + assert_eq!(evals.len(), weights.len()); + assert!(evals.len() >= 2 && evals.len().is_power_of_two()); + // EF on the left so `Mul for EF` is used (Algebra for the base case). + parallel::map_reduce( + evals.len() / 2, + || (EF::ZERO, EF::ZERO), + |i| { + let (e, w) = (&evals[2 * i..2 * i + 2], &weights[2 * i..2 * i + 2]); + (w[0] * e[0], (w[1] - w[0]) * (e[1] - e[0])) + }, + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ) +} + +fn fold_by_tensor(evals: &[E], rhos: &[EF]) -> Vec +where + EF: ExtensionField> + Mul + From, + E: Copy + Send + Sync, +{ + let width = 1usize << rhos.len(); + assert!(evals.len() >= width && evals.len().is_multiple_of(width)); + if rhos.is_empty() { + return evals.iter().map(|&v| EF::from(v)).collect(); } + let tensor = eval_eq(&rhos.iter().rev().copied().collect::>()); + parallel::par_map_collect(evals.len() / width, |i| { + let chunk = &evals[i * width..i * width + width]; + tensor.iter().zip(chunk).map(|(&t, &e)| t * e).sum::() + }) +} + +fn sumcheck_finish_round>>( + c0: EF, + c2: EF, + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + let c1 = *sum - c0.double() - c2; + let poly = DensePolynomial::new(vec![c0, c1, c2]); + prover_state.add_sumcheck_polynomial(&poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r: EF = prover_state.sample(); + *sum = poly.evaluate(r); + r +} + +#[instrument(skip_all)] +fn lsb_sumcheck_round>>( + evals: &[EF], + weights: &[EF], + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + let (c0, c2) = round_coeffs_flat(evals, weights); + sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) +} + +/// LSB-fold a slice of evaluations: `out[i] = m[2i] + r * (m[2i+1] - m[2i])`. +fn lsb_fold>>(m: &[EF], r: EF) -> Vec { + fold_multilinear_at_bit(m, r, 0, &|diff, alpha| alpha * diff, false) } #[derive(Debug)] @@ -452,7 +628,6 @@ where domain_size: usize, next_domain_gen: PF, sumcheck_prover: SumcheckSingle, - commitment_merkle_prover_data_b: Option>, merkle_prover_data: MerkleData, randomness_vec: Vec, } @@ -487,7 +662,7 @@ where prover_state.duplex(); let combination_randomness_gen: EF = prover_state.sample(); - let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds( + let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds_svo( polynomial, &statement, combination_randomness_gen, @@ -503,7 +678,6 @@ where ), sumcheck_prover, merkle_prover_data: witness.prover_data, - commitment_merkle_prover_data_b: None, randomness_vec: folding_randomness.0.clone(), }) } @@ -512,107 +686,3 @@ where MultilinearPoint(self.randomness_vec[self.randomness_vec.len() - folding_factor..].to_vec()) } } - -#[instrument(skip_all, fields(num_constraints = statements.len(), n_vars = statements[0].total_num_variables))] -fn combine_statement(statements: &[SparseStatement], gamma: EF) -> (Vec>, EF) -where - EF: ExtensionField>, -{ - let num_variables = statements[0].total_num_variables; - assert!(statements.iter().all(|e| e.total_num_variables == num_variables)); - - let out_len = 1 << (num_variables - packing_log_width::()); - - let is_full = |s: &SparseStatement| { - !s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables - }; - - let mut combined_weights: Vec>; - let mut combined_sum = EF::ZERO; - let mut gamma_pow = EF::ONE; - - let start_idx = match statements { - [a, b, ..] if is_full(a) && is_full(b) => { - combined_weights = unsafe { uninitialized_vec(out_len) }; - let sa = gamma_pow; - let sb = gamma_pow * gamma; - combined_sum = a.values[0].value * sa + b.values[0].value * sb; - gamma_pow = sb * gamma; - compute_eval_eq_packed_dual::(&a.point.0, &b.point.0, &mut combined_weights, sa, sb); - 2 - } - [a, ..] if is_full(a) => { - combined_weights = unsafe { uninitialized_vec(out_len) }; - let sa = gamma_pow; - combined_sum = a.values[0].value * sa; - gamma_pow *= gamma; - compute_eval_eq_packed::(&a.point.0, &mut combined_weights, sa); - 1 - } - _ => { - combined_weights = EFPacking::::zero_vec(out_len); - 0 - } - }; - - for smt in &statements[start_idx..] { - if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) { - for evaluation in &smt.values { - compute_sparse_eval_eq_packed::(evaluation.selector, &smt.point, &mut combined_weights, gamma_pow); - combined_sum += evaluation.value * gamma_pow; - gamma_pow *= gamma; - } - } else { - let inner_poly = if smt.is_next { - let next = matrix_next_mle_folded(&smt.point.0); - pack_extension(&next) - } else { - eval_eq_packed(&smt.point) - }; - let shift = smt.inner_num_variables() - packing_log_width::(); - let mut indexed_smt_values = smt.values.iter().enumerate().collect::>(); - indexed_smt_values.sort_by_key(|(_, e)| e.selector); - indexed_smt_values.dedup_by_key(|(_, e)| e.selector); - assert_eq!( - indexed_smt_values.len(), - smt.values.len(), - "Duplicate selectors in sparse statement" - ); - let mut chunks_mut = split_at_mut_many( - &mut combined_weights, - &indexed_smt_values - .iter() - .map(|(_, e)| e.selector << shift) - .collect::>(), - ); - chunks_mut.remove(0); - let mut next_gamma_powers = vec![gamma_pow]; - for _ in 1..indexed_smt_values.len() { - next_gamma_powers.push(*next_gamma_powers.last().unwrap() * gamma); - } - for (e, &scalar) in smt.values.iter().zip(&next_gamma_powers) { - combined_sum += e.value * scalar; - } - let n = 1usize << shift; - let mask = n - 1; - let ptrs: Vec<(parallel::SendPtr>, EF)> = chunks_mut - .iter_mut() - .zip(&indexed_smt_values) - .map(|(out_buff, &(origin_index, _))| { - ( - parallel::SendPtr(out_buff.as_mut_ptr()), - next_gamma_powers[origin_index], - ) - }) - .collect(); - let inner = inner_poly.as_slice(); - parallel::for_each_index(ptrs.len() << shift, |flat| { - let (ptr, scalar) = &ptrs[flat >> shift]; - let i = flat & mask; - unsafe { *ptr.add(i) += inner[i] * *scalar }; - }); - gamma_pow = *next_gamma_powers.last().unwrap() * gamma; - } - } - (combined_weights, combined_sum) -} diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs new file mode 100644 index 00000000..cb887097 --- /dev/null +++ b/crates/whir/src/svo.rs @@ -0,0 +1,462 @@ +#![allow(clippy::needless_range_loop)] +use field::{BasedVectorSpace, ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing}; +use poly::{EFPacking, PARALLEL_THRESHOLD, PF, PFPacking, compute_eval_eq, eval_eq, packing_log_width}; + +#[derive(Debug, Clone)] +pub(crate) struct CompressedGroup { + pub(crate) w_svo: Vec, + pub(crate) p_bar: Vec, +} + +#[derive(Debug)] +pub(crate) struct AccGroup { + pub(crate) acc_0: Vec>, + pub(crate) acc_inf: Vec>, +} + +pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, scratch: &mut Vec) { + assert_eq!(f.len(), 1 << l, "grid_expand_into: f.len() must be 2^l"); + let out_len = 3_usize.pow(l as u32); + if l == 0 { + out.clear(); + out.extend_from_slice(f); + return; + } + // Pick parity so the final stage lands in `out`. + let (mut cur, mut nxt): (&mut Vec, &mut Vec) = if l.is_multiple_of(2) { + (out, scratch) + } else { + (scratch, out) + }; + cur.clear(); + cur.extend_from_slice(f); + cur.resize(out_len, EF::ZERO); + nxt.clear(); + nxt.resize(out_len, EF::ZERO); + for stage in 0..l { + let s = 3_usize.pow(stage as u32); + let block_count = 1usize << (l - stage - 1); + let in_total = block_count * 2 * s; + let out_total = block_count * 3 * s; + let cur_slice = &cur[..in_total]; + let next_slice = &mut nxt[..out_total]; + let block_kernel = |(in_block, out_block): (&[EF], &mut [EF])| { + let (lo, hi) = in_block.split_at(s); + for j in 0..s { + let f0 = lo[j]; + let f1 = hi[j]; + out_block[3 * j] = f0; + out_block[3 * j + 1] = f1; + out_block[3 * j + 2] = f1 - f0; + } + }; + if out_total < PARALLEL_THRESHOLD { + for pair in cur_slice.chunks_exact(2 * s).zip(next_slice.chunks_exact_mut(3 * s)) { + block_kernel(pair); + } + } else { + parallel::par_chunks_mut(next_slice, 3 * s, |block_idx, out_block| { + let in_block = &cur_slice[block_idx * (2 * s)..block_idx * (2 * s) + 2 * s]; + block_kernel((in_block, out_block)); + }); + } + std::mem::swap(&mut cur, &mut nxt); + } + debug_assert_eq!(cur.len(), out_len); +} + +pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { + // Lagrange basis at `c` for the evaluation set {0, 1, ∞}: L_0 = 1 - c, L_1 = c, L_∞ = c(c - 1). + let l0 = EF::ONE - c; + let l_inf = c * (c - EF::ONE); + *out = out.iter().flat_map(|&v| [v * l0, v * c, v * l_inf]).collect(); +} + +fn reduce_svo_rows_one( + rows: &[PF], + eq_lo: &[EF], + eq_hi: &[EF], + sel_offset: usize, + svo_len: usize, +) -> impl IntoIterator +where + EF: ExtensionField>, +{ + let w = packing_log_width::(); + debug_assert!(svo_len.is_multiple_of(1 << w)); + debug_assert!(sel_offset.is_multiple_of(1 << w)); + debug_assert!(eq_lo.len().is_power_of_two()); + debug_assert!(eq_hi.len().is_power_of_two()); + + let rows_packed = PFPacking::::pack_slice(rows); + let svo_len_p = svo_len >> w; + let sel_off_p = sel_offset >> w; + let n_lo = eq_lo.len(); + let stride = eq_hi.len(); // = 2^m_hi — coefficient of b_lo in the full b index + debug_assert_eq!(EF::DIMENSION, 5); + + EFPacking::::to_ext_iter(reduce_svo_rows_one_inner::( + rows_packed, + eq_lo, + eq_hi, + sel_off_p, + stride, + n_lo, + svo_len_p, + )) +} + +#[inline] +fn reduce_svo_rows_one_inner( + rows_packed: &[PFPacking], + eq_lo: &[EF], + eq_hi: &[EF], + sel_off_p: usize, + stride: usize, + n_lo: usize, + svo_len_p: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + const SVO_DOT_CHUNK: usize = 4; + debug_assert_eq!(EF::DIMENSION, D); + + let mut cs: [Vec>; D] = core::array::from_fn(|_| Vec::with_capacity(stride)); + for &e_hi in eq_hi.iter() { + let coefs = e_hi.as_basis_coefficients_slice(); + for (d, c) in cs.iter_mut().enumerate() { + c.push(PFPacking::::from(coefs[d])); + } + } + + let zero = || vec![EFPacking::::ZERO; svo_len_p]; + let step = |mut acc: Vec>, b_lo: usize| { + let base = b_lo * stride; + + let mut tmp_basis = vec![PFPacking::::ZERO; D * svo_len_p]; + + let mut b_hi = 0; + while b_hi + SVO_DOT_CHUNK <= stride { + let lhs: [[PFPacking; SVO_DOT_CHUNK]; D] = + core::array::from_fn(|d| core::array::from_fn(|i| cs[d][b_hi + i])); + + for k in 0..svo_len_p { + let row_off = sel_off_p + (base + b_hi) * svo_len_p + k; + let rhs: [PFPacking; SVO_DOT_CHUNK] = + core::array::from_fn(|i| rows_packed[row_off + i * svo_len_p]); + for d in 0..D { + tmp_basis[d * svo_len_p + k] += PFPacking::::dot_product::(&lhs[d], &rhs); + } + } + b_hi += SVO_DOT_CHUNK; + } + while b_hi < stride { + let row_off = sel_off_p + (base + b_hi) * svo_len_p; + for k in 0..svo_len_p { + let r = rows_packed[row_off + k]; + for d in 0..D { + tmp_basis[d * svo_len_p + k] += cs[d][b_hi] * r; + } + } + b_hi += 1; + } + + let e_lo = EFPacking::::from(eq_lo[b_lo]); + for k in 0..svo_len_p { + let tmp_k = EFPacking::::from_basis_coefficients_fn(|d| tmp_basis[d * svo_len_p + k]); + acc[k] += e_lo * tmp_k; + } + acc + }; + let merge = |mut a: Vec>, b: Vec>| { + for (x, y) in a.iter_mut().zip(&b) { + *x += *y; + } + a + }; + let total_work = n_lo * stride * svo_len_p; + if total_work < PARALLEL_THRESHOLD { + (0..n_lo).fold(zero(), step) + } else { + parallel::map_reduce_with_state( + n_lo, + || (), + zero, + |_, acc, idx| *acc = step(std::mem::take(acc), idx), + merge, + ) + } +} + +fn reduce_svo_rows_two( + rows: &[PF], + coef_a: &[EF], + coef_b: &[EF], + sel_offset: usize, + svo_len: usize, +) -> (Vec, Vec) +where + EF: ExtensionField>, +{ + let e_len = coef_a.len(); + debug_assert_eq!(coef_b.len(), e_len); + let zero = || (EF::zero_vec(svo_len), EF::zero_vec(svo_len)); + let step = |(mut a, mut b): (Vec, Vec), idx: usize| { + let ca = coef_a[idx]; + let cb = coef_b[idx]; + let row = &rows[sel_offset + idx * svo_len..][..svo_len]; + for bsvo in 0..svo_len { + let v = row[bsvo]; + a[bsvo] += ca * v; + b[bsvo] += cb * v; + } + (a, b) + }; + let merge = |(mut ax, mut bx): (Vec, Vec), (ay, by): (Vec, Vec)| { + for (x, y) in ax.iter_mut().zip(&ay) { + *x += *y; + } + for (x, y) in bx.iter_mut().zip(&by) { + *x += *y; + } + (ax, bx) + }; + if e_len * svo_len < PARALLEL_THRESHOLD { + (0..e_len).fold(zero(), step) + } else { + parallel::map_reduce_with_state( + e_len, + || (), + zero, + |_, acc, idx| *acc = step(std::mem::take(acc), idx), + merge, + ) + } +} + +pub(crate) fn compress_eq_claim( + f: &[PF], + sel_bits: &[usize], + inner_point: &[EF], + alpha_powers: &[EF], + l: usize, + l_0: usize, + s: usize, +) -> CompressedGroup +where + EF: ExtensionField>, +{ + assert_eq!(sel_bits.len(), alpha_powers.len()); + assert_eq!(inner_point.len(), l - s); + assert!(s + l_0 <= l, "compress_eq_claim non-spill requires s <= l - l_0"); + let m_split = l - l_0 - s; + let p_split = &inner_point[..m_split]; + let p_svo = &inner_point[m_split..]; + + // Factored eq(p_split, ·): split at the midpoint so storage is + // `2^⌊m/2⌋ + 2^⌈m/2⌉` instead of `2^m`. + let m_lo = m_split / 2; + let eq_lo = eval_eq(&p_split[..m_lo]); + let eq_hi = eval_eq(&p_split[m_lo..]); + let svo_len = 1usize << l_0; + let mut p_bar = vec![EF::ZERO; svo_len]; + + for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { + let sel_offset = sel_j << (l - s); + let contrib = reduce_svo_rows_one::(f, &eq_lo, &eq_hi, sel_offset, svo_len); + for (p, s) in p_bar.iter_mut().zip(contrib) { + *p += alpha_j * s; + } + } + + CompressedGroup { + w_svo: p_svo.to_vec(), + p_bar, + } +} + +pub(crate) fn compress_next_claim( + f: &[PF], + sel_bits: &[usize], + inner_point: &[EF], + alpha_powers: &[EF], + l: usize, + l_0: usize, + s: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + assert_eq!(sel_bits.len(), alpha_powers.len()); + let m = l - s; + assert_eq!(inner_point.len(), m); + assert!(s + l_0 <= l, "selector-inside-split requires s <= l - l_0"); + let m_split = m - l_0; + let split_len = 1usize << m_split; + let svo_len = 1usize << l_0; + + let (bar_t_split, c_omega) = build_bar_t_split(inner_point, m_split, m); + let e_split = eval_eq(&inner_point[..m_split]); + debug_assert_eq!(bar_t_split.len(), split_len); + debug_assert_eq!(e_split.len(), split_len); + + let c_pivot: Vec = (m_split..m) + .map(|j| { + let tail: EF = inner_point[j + 1..].iter().copied().product(); + (EF::ONE - inner_point[j]) * tail + }) + .collect(); + + let mut sigma_split = vec![EF::ZERO; svo_len]; + let mut p_eq = vec![EF::ZERO; svo_len]; + let mut s_omega = vec![EF::ZERO; svo_len]; + + for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { + let sel_offset = sel_j << (l - s); + + let (sig_contrib, eq_contrib) = reduce_svo_rows_two::(f, &bar_t_split, &e_split, sel_offset, svo_len); + let c_base = sel_offset + ((split_len - 1) << l_0); + for bsvo in 0..svo_len { + s_omega[bsvo] += alpha_j * f[c_base + bsvo]; + sigma_split[bsvo] += alpha_j * sig_contrib[bsvo]; + p_eq[bsvo] += alpha_j * eq_contrib[bsvo]; + } + } + + let mut out: Vec> = Vec::with_capacity(l_0 + 2); + out.push(CompressedGroup { + w_svo: vec![EF::ZERO; l_0], + p_bar: sigma_split, + }); + for (pivot_pos, &cp) in c_pivot.iter().enumerate() { + let mut w = vec![EF::ZERO; l_0]; + w[..pivot_pos].copy_from_slice(&inner_point[m_split..m_split + pivot_pos]); + w[pivot_pos] = EF::ONE; + out.push(CompressedGroup { + w_svo: w, + p_bar: p_eq.iter().map(|v| *v * cp).collect(), + }); + } + out.push(CompressedGroup { + w_svo: vec![EF::ONE; l_0], + p_bar: s_omega.into_iter().map(|v| v * c_omega).collect(), + }); + debug_assert_eq!(out.len(), l_0 + 2); + out +} + +fn build_bar_t_split(p: &[EF], m_split: usize, m: usize) -> (Vec, EF) { + let out_len = 1usize << m_split; + let mut bar_t = vec![EF::ZERO; out_len]; + + let mut suf = vec![EF::ONE; m + 1]; + for j in (0..m).rev() { + suf[j] = suf[j + 1] * p[j]; + } + let mut prefix = vec![EF::ONE]; + for j in 0..m_split { + let c_j = suf[j + 1] * (EF::ONE - p[j]); + let stride = 1usize << (m_split - j); + let offset = 1usize << (m_split - 1 - j); + let prefix_len = prefix.len(); + debug_assert_eq!(prefix_len, 1 << j); + for k in 0..prefix_len { + bar_t[k * stride + offset] = c_j * prefix[k]; + } + if j + 1 < m_split { + let p_j = p[j]; + let one_minus = EF::ONE - p_j; + prefix = prefix.iter().flat_map(|&v| [v * one_minus, v * p_j]).collect(); + } + } + (bar_t, suf[0]) +} + +pub(crate) fn build_accumulators_single(group: &CompressedGroup, l_0: usize) -> AccGroup +where + EF: ExtensionField>, +{ + assert_eq!(group.w_svo.len(), l_0); + assert_eq!(group.p_bar.len(), 1 << l_0); + + let mut acc_0: Vec> = vec![Vec::new(); l_0]; + let mut acc_inf: Vec> = vec![Vec::new(); l_0]; + + let cap = 3_usize.pow(l_0 as u32); + let mut q: Vec = group.p_bar.clone(); + let mut tilde_q: Vec = Vec::with_capacity(cap); + let mut tilde_e: Vec = Vec::with_capacity(cap); + let mut scratch_q: Vec = Vec::with_capacity(cap); + let mut scratch_e: Vec = Vec::with_capacity(cap); + let mut e_buf: Vec = Vec::with_capacity(1 << l_0); + for r_idx in 0..l_0 { + let r = l_0 - 1 - r_idx; + let r_f = l_0 - r - 1; + let big_l = r + 1; + debug_assert_eq!(q.len(), 1 << big_l); + + e_buf.clear(); + e_buf.resize(1 << big_l, EF::ZERO); + compute_eval_eq::, EF, false>(&group.w_svo[r_f..], &mut e_buf, EF::ONE); + + grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q); + grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e); + + // Keep only the x_{big_l-1}=0 face (indices 3j) and x_{big_l-1}=∞ face (indices 3j+2). + let s = 3_usize.pow(r as u32); + let mut a = EF::zero_vec(s); + let mut b = EF::zero_vec(s); + let fill = |(j, (a_j, b_j)): (usize, (&mut EF, &mut EF))| { + *a_j = tilde_q[3 * j] * tilde_e[3 * j]; + *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2]; + }; + if s < PARALLEL_THRESHOLD { + a.iter_mut().zip(b.iter_mut()).enumerate().for_each(fill); + } else { + parallel::par_for_each_mut2(&mut a, &mut b, |j, a_j, b_j| fill((j, (a_j, b_j)))); + } + acc_0[r] = a; + acc_inf[r] = b; + + if r_idx + 1 < l_0 { + let alpha = group.w_svo[r_f]; + let half = q.len() / 2; + for i in 0..half { + let lo = q[i]; + let hi = q[i + half]; + q[i] = lo + alpha * (hi - lo); + } + q.truncate(half); + } + } + AccGroup { acc_0, acc_inf } +} + +pub(crate) fn build_accumulators(groups: &[CompressedGroup], l_0: usize) -> Vec> +where + EF: ExtensionField>, +{ + // Sequential across groups: each `build_accumulators_single` parallelizes internally + // (`compute_eval_eq`, `grid_expand_into`, ...), and the pool forbids nested dispatch. + groups.iter().map(|g| build_accumulators_single(g, l_0)).collect() +} + +pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], accs: &[AccGroup]) -> (EF, EF) { + debug_assert_eq!(lagrange.len(), 3_usize.pow(r as u32)); + let group_reduce = |acc: &AccGroup| { + lagrange + .iter() + .zip(&acc.acc_0[r]) + .zip(&acc.acc_inf[r]) + .fold((EF::ZERO, EF::ZERO), |(c0, c2), ((&l, &a0), &ainf)| { + (c0 + l * a0, c2 + l * ainf) + }) + }; + let add2 = |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2); + if 2 * lagrange.len() * accs.len() < PARALLEL_THRESHOLD { + accs.iter().map(group_reduce).fold((EF::ZERO, EF::ZERO), add2) + } else { + parallel::map_reduce(accs.len(), || (EF::ZERO, EF::ZERO), |i| group_reduce(&accs[i]), add2) + } +} diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index 9a8ec359..dbbd0bee 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -10,7 +10,6 @@ use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; use tracing::instrument; -use utils::log2_strict_usize; use crate::EvalsDft; use crate::Matrix; @@ -132,31 +131,16 @@ fn prepare_evals_for_fft_unpacked( let n_blocks = 1 << folding_factor; let full_len = evals.len() << log_inv_rate; let block_size = full_len / n_blocks; - let log_block_size = log2_strict_usize(block_size); let out_len = block_size * dft_n_cols; - let mut out: Vec = unsafe { uninitialized_vec(out_len) }; - if block_size == 0 || dft_n_cols == 0 { - return out; - } - - let rows_per_band = ((system_info::L1_CACHE_SIZE / 2) / (dft_n_cols * size_of::())).clamp(1, block_size); - let band_len = rows_per_band * dft_n_cols; - - parallel::par_chunks_mut(&mut out, band_len, |band_idx, band| { - let row0 = band_idx * rows_per_band; - let n_rows = band.len() / dft_n_cols; - for col in 0..dft_n_cols { - let col_base = col << log_block_size; - for r in 0..n_rows { - let src = (col_base + row0 + r) >> log_inv_rate; - unsafe { - *band.get_unchecked_mut(r * dft_n_cols + col) = *evals.get_unchecked(src); - } - } - } - }); - out + // LSB-cols layout (split-eq): column = LSB k bits of source index, row's high bits = remaining + // vars, row's low log_inv_rate bits = rate-extension dummy (data is constant in those). + parallel::par_map_collect(out_len, |i| { + let block_index = i % dft_n_cols; + let offset_in_block = i / dft_n_cols; + let src_index = ((offset_in_block >> log_inv_rate) << folding_factor) | block_index; + unsafe { *evals.get_unchecked(src_index) } + }) } fn prepare_evals_for_fft_packed_extension>>( @@ -168,40 +152,23 @@ fn prepare_evals_for_fft_packed_extension>>( assert!((evals.len() << log_packing).is_multiple_of(1 << folding_factor)); let n_blocks = 1 << folding_factor; let full_len = evals.len() << (log_inv_rate + log_packing); - let block_size = full_len / n_blocks; - let log_block_size = log2_strict_usize(block_size); + let n_blocks_mask = n_blocks - 1; let packing_mask = (1 << log_packing) - 1; - let mut out: Vec = unsafe { uninitialized_vec(full_len) }; - if block_size == 0 || n_blocks == 0 { - return out; - } - - let rows_per_band = ((system_info::L1_CACHE_SIZE / 2) / (n_blocks * size_of::())).clamp(1, block_size); - let band_len = rows_per_band * n_blocks; - - parallel::par_chunks_mut(&mut out, band_len, |band_idx, band| { - let row0 = band_idx * rows_per_band; - let n_rows = band.len() / n_blocks; - for col in 0..n_blocks { - let col_base = col << log_block_size; - for r in 0..n_rows { - let src_index = (col_base + row0 + r) >> log_inv_rate; - let packed_src_index = src_index >> log_packing; - let offset_in_packing = src_index & packing_mask; - let packed = unsafe { evals.get_unchecked(packed_src_index) }; - let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); - let val = EF::from_basis_coefficients_fn(|j| unsafe { - let u: &PFPacking = unpacked.get_unchecked(j); - *u.as_slice().get_unchecked(offset_in_packing) - }); - unsafe { - *band.get_unchecked_mut(r * n_blocks + col) = val; - } - } - } - }); - out + // LSB-cols layout (split-eq): see prepare_evals_for_fft_unpacked. + parallel::par_map_collect(full_len, |i| { + let block_index = i & n_blocks_mask; + let offset_in_block = i >> folding_factor; + let src_index = ((offset_in_block >> log_inv_rate) << folding_factor) | block_index; + let packed_src_index = src_index >> log_packing; + let offset_in_packing = src_index & packing_mask; + let packed = unsafe { evals.get_unchecked(packed_src_index) }; + let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); + EF::from_basis_coefficients_fn(|i| unsafe { + let u: &PFPacking = unpacked.get_unchecked(i); + *u.as_slice().get_unchecked(offset_in_packing) + }) + }) } type CacheKey = TypeId; diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 53ed173f..17c5ca42 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -193,12 +193,13 @@ where .collect(), ); - let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.clone()); + // WHIR sumcheck folds LSB-first; eval_constraints_poly expects polynomial-var order. + let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.reversed()); - // Check the final sumcheck evaluation (coefficient form, reversed point) - let mut reversed_point = final_sumcheck_randomness.0.clone(); - reversed_point.reverse(); - let final_value = eval_multilinear_coeffs(&final_coefficients, &reversed_point); + // Check the final sumcheck evaluation (coefficient form). For LSB-fold, the sumcheck + // challenges are already in the order eval_multilinear_coeffs expects (point[0] is the + // last variable of the polynomial), so no reversal needed. + let final_value = eval_multilinear_coeffs(&final_coefficients, &final_sumcheck_randomness.0); if claimed_sum != evaluation_of_weights * final_value { return Err(ProofError::InvalidProof); } @@ -248,27 +249,22 @@ where verifier_state, ); - // dbg!(&stir_challenges_indexes); - // dbg!(verifier_state.challenger().state()); - - let dimensions = [Dimensions { + let dimensions = Dimensions { height: params.domain_size >> params.folding_factor, width: 1 << params.folding_factor, - }]; + }; let answers = self.verify_merkle_proof::( verifier_state, &commitment.root, &stir_challenges_indexes, - &dimensions, + dimensions, leafs_base_field, - round_index, - 0, )?; - // Compute STIR Constraints + let folding_randomness_reversed = folding_randomness.reversed(); let folds: Vec<_> = answers .into_iter() - .map(|answers| answers.evaluate(folding_randomness)) + .map(|answers| answers.evaluate(&folding_randomness_reversed)) .collect(); let stir_constraints = stir_challenges_indexes @@ -286,21 +282,20 @@ where Ok(stir_constraints) } - #[allow(clippy::too_many_arguments)] fn verify_merkle_proof( &self, verifier_state: &mut impl FSVerifier, root: &[PF; DIGEST_ELEMS], indices: &[usize], - dimensions: &[Dimensions], + dimensions: Dimensions, leafs_base_field: bool, - _round_index: usize, - _var_shift: usize, ) -> ProofResult>> where F: Field + ExtensionField>, EF: ExtensionField, { + // DoS-hardened (main): bound the opening batch and reject malformed leaf_data lengths + // instead of panicking. Single `dimensions` (branch): no per-leaf padding/skip variants. verifier_state.begin_merkle_opening_batch(indices.len())?; let res = if leafs_base_field { let mut answers = Vec::>::new(); @@ -315,17 +310,17 @@ where } for (i, &index) in indices.iter().enumerate() { - if !merkle_verify::, F>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) { + if !merkle_verify::, F>(*root, index, dimensions, answers[i].clone(), &merkle_proofs[i]) { return Err(ProofError::InvalidProof); } } answers .into_iter() - .map(|inner| inner.iter().map(|&f_el| f_el.into()).collect()) + .map(|inner| inner.into_iter().map(Into::into).collect()) .collect() } else { - let mut answers = vec![]; + let mut answers = Vec::>::new(); let mut merkle_proofs = Vec::new(); for _ in 0..indices.len() { @@ -337,7 +332,7 @@ where } for (i, &index) in indices.iter().enumerate() { - if !merkle_verify::, EF>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) { + if !merkle_verify::, EF>(*root, index, dimensions, answers[i].clone(), &merkle_proofs[i]) { return Err(ProofError::InvalidProof); } } @@ -357,8 +352,11 @@ where for (round, (randomness, constraints)) in constraints.iter().enumerate() { if round > 0 { + // LSB-fold drops the polynomial's high-indexed (last) k vars at each round. + // The reversed cumulative point places those at the END. let k = self.folding_factor.at_round(round - 1); - point = MultilinearPoint(point[k..].to_vec()); + let new_len = point.len() - k; + point = MultilinearPoint(point[..new_len].to_vec()); } let mut i = 0; for smt in constraints {