diff --git a/piop/Cargo.toml b/piop/Cargo.toml index 089f92bd..3d6f589a 100644 --- a/piop/Cargo.toml +++ b/piop/Cargo.toml @@ -31,6 +31,7 @@ workspace = true [features] parallel = ["dep:rayon", "zinc-utils/parallel"] +simd = ["zinc-poly/simd"] [[bench]] name = "sumcheck" diff --git a/piop/benches/sumcheck.rs b/piop/benches/sumcheck.rs index 13db2d9a..9f5fc410 100644 --- a/piop/benches/sumcheck.rs +++ b/piop/benches/sumcheck.rs @@ -1,7 +1,7 @@ #![allow(non_local_definitions)] #![allow(clippy::eq_op)] -use std::{hint::black_box, ops::Mul}; +use std::{hint::black_box, ops::{Add, Mul}}; use criterion::{ AxisScale, BatchSize, BenchmarkGroup, BenchmarkId, Criterion, PlotConfiguration, @@ -126,22 +126,138 @@ pub fn bench_simple_product( ); } +#[allow(clippy::arithmetic_side_effects)] +pub fn bench_sum_of_40( + group: &mut BenchmarkGroup, + witness_size: usize, +) where + F: FromPrimitiveWithConfig + InnerTransparentField + FromRef + 'static, + F::Inner: FromRef + ConstTranscribable + ConstIntSemiring, + MillerRabin: PrimalityTest, + for<'a> &'a F: Mul<&'a F, Output = F> + Add<&'a F, Output = F>, +{ + const N_POLYS: usize = 40; + + let mut rng = rng(); + let nvars = zinc_utils::log2(witness_size) as usize; + let params = format!("LIMBS={}/nvars={}/npolys={}", LIMBS, nvars, N_POLYS); + + // Pre-generate the 40 random MLEs (not benchmarked). 
+ let polys: Vec>> = (0..N_POLYS) + .map(|_| { + let evals: Vec = (0..witness_size).map(|_| rng.random()).collect(); + DenseMultilinearExtension::from_evaluations_vec( + nvars, + evals.into_iter().map(BinaryPoly::from).collect(), + BinaryPoly::zero(), + ) + }) + .collect(); + + let transcript = KeccakTranscript::new(); + + // Prover closure: eq(x,r) * sum_i a_i(x) + let prove = + |(polys, mut transcript): (Vec>>, KeccakTranscript)| -> RFSumcheckProof> { + let field_cfg = + transcript.get_random_field_cfg::::Inner, MillerRabin>(); + + let eq_r = + build_eq_x_r_inner(&vec![F::from_with_cfg(2u32, &field_cfg); nvars], &field_cfg) + .expect("Failed to build eq_r"); + + // polys go into mles (projected), eq_r goes into mles_f. + // comb_fn receives vals[0..N_POLYS] = a_i's, vals[N_POLYS] = eq. + // G(alpha, x) = eq(x,r) * sum_{i=0}^{N_POLYS-1} a_i(x) + (RFSumcheck::::prove_as_subprotocol( + &mut transcript, + polys, + vec![eq_r], + nvars, + 1, + |_alpha, vals| { + let mut sum = vals[0].clone(); + for v in &vals[1..N_POLYS] { + sum = &sum + v; + } + sum * &vals[N_POLYS] + }, + field_cfg, + )) + .0 + }; + + group.bench_with_input( + BenchmarkId::new("Sum-of-40 Sumcheck Prover", &params), + &(polys.clone(), transcript.clone()), + |bench, (polys, transcript)| { + bench.iter_batched( + || (polys.clone(), transcript.clone()), + |(polys, transcript)| { + let _ = black_box(&prove((polys, transcript))); + }, + BatchSize::SmallInput, + ); + }, + ); + + let proof = prove((polys, transcript.clone())); + + group.bench_with_input( + BenchmarkId::new("Sum-of-40 Sumcheck Verifier", &params), + &(proof, transcript), + |bench, (proof, transcript)| { + bench.iter_batched( + || (proof.clone(), transcript.clone()), + |(proof, mut transcript)| { + let field_cfg = + transcript.get_random_field_cfg::::Inner, MillerRabin>(); + + let _ = black_box( + RFSumcheck::::verify_as_subprotocol( + &mut transcript, + nvars, + 1, + &proof, + field_cfg, + ) + .expect("Failed to verify"), + ); + }, + 
BatchSize::SmallInput, + ); + }, + ); +} + pub fn sumcheck_benches(c: &mut Criterion) { let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic); let mut group = c.benchmark_group("Sumcheck benchmarks"); group.plot_config(plot_config); - bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 13); - bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 13); - bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 14); - bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 14); - bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 15); - bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 15); - bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 16); - bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 16); - bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 17); - bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 17); + // bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 13); + // bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 13); + // bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 14); + // bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 14); + // bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 15); + // bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 15); + // bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 16); + // bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 16); + // bench_simple_product::<MontyField<3>, 3>(&mut group, 1 << 17); + // bench_simple_product::<MontyField<4>, 4>(&mut group, 1 << 17); + + bench_sum_of_40::<MontyField<3>, 3>(&mut group, 1 << 6); + bench_sum_of_40::<MontyField<4>, 4>(&mut group, 1 << 6); + bench_sum_of_40::<MontyField<3>, 3>(&mut group, 1 << 7); + bench_sum_of_40::<MontyField<4>, 4>(&mut group, 1 << 7); + bench_sum_of_40::<MontyField<3>, 3>(&mut group, 1 << 8); + bench_sum_of_40::<MontyField<4>, 4>(&mut group, 1 << 8); + bench_sum_of_40::<MontyField<3>, 3>(&mut group, 1 << 9); + bench_sum_of_40::<MontyField<4>, 4>(&mut group, 1 << 9); + bench_sum_of_40::<MontyField<3>, 3>(&mut group, 1 << 10); + bench_sum_of_40::<MontyField<4>, 4>(&mut group, 1 << 10); + group.finish(); } diff --git a/scripts/bench_sumcheck_to_latex.py b/scripts/bench_sumcheck_to_latex.py new file mode 100644 index 00000000..16100fcd --- /dev/null +++ b/scripts/bench_sumcheck_to_latex.py @@ -0,0 +1,268 @@ 
+#!/usr/bin/env python3 +"""Run the sumcheck benchmark and output results as a LaTeX table. + +This script executes ``cargo bench --bench sumcheck -p zinc-piop`` with the +requested Cargo feature flags, parses the Criterion output, and produces a +LaTeX table summarising the median wall-clock times for each benchmark +configuration. + +The generated table is organised as follows: + - Rows are indexed by the number of sumcheck variables *n* (i.e. the + multilinear extension is defined over {0,1}^n). + - Columns are grouped by benchmark type (Prover / Verifier). Within each + group there is one sub-column per field size (e.g. 192-bit for + ``MontyField<3>`` and 256-bit for ``MontyField<4>``). + +The table uses the ``booktabs``, ``multirow`` and ``siunitx`` LaTeX packages. +Times are automatically formatted in the most readable unit (µs, ms, or s). + +Usage +----- +Run the benchmarks and print the table to stdout:: + + python3 scripts/bench_sumcheck_to_latex.py + +Save to a file:: + + python3 scripts/bench_sumcheck_to_latex.py -o table.tex + +Specify custom Cargo features:: + + python3 scripts/bench_sumcheck_to_latex.py --features "parallel simd asm" + +Only run benchmarks matching a filter:: + + python3 scripts/bench_sumcheck_to_latex.py --filter "Sum-of-40" + +Parse previously captured output from stdin instead of running cargo:: + + cargo bench --bench sumcheck -p zinc-piop --features "parallel simd" 2>&1 \\ + | python3 scripts/bench_sumcheck_to_latex.py --from-stdin +""" + +import argparse +import re +import subprocess +import sys +from dataclasses import dataclass + + +@dataclass +class BenchResult: + benchmark: str # e.g. "Sum-of-40 Sumcheck Prover" + limbs: int + nvars: int + npolys: int + time_low: str # e.g. 
"769.09 µs" + time_mid: str + time_high: str + + +# Matches lines like: +# Sumcheck benchmarks/Sum-of-40 Sumcheck Prover/LIMBS=3/nvars=6/npolys=40 +BENCH_NAME_RE = re.compile( + r"^Sumcheck benchmarks/" + r"(?P<bench>[^/]+)/" + r"LIMBS=(?P<limbs>\d+)/" + r"nvars=(?P<nvars>\d+)" + r"(?:/npolys=(?P<npolys>\d+))?\s*$" +) + +# Matches lines like: +# time: [769.09 µs 799.94 µs 849.34 µs] +TIME_RE = re.compile( + r"time:\s+\[(?P<low>[\d.]+\s*\S+)\s+" + r"(?P<mid>[\d.]+\s*\S+)\s+" + r"(?P<high>[\d.]+\s*\S+)\]" +) + + +def parse_to_ms(time_str: str) -> float: + """Convert a criterion time string like '769.09 µs' to milliseconds.""" + parts = time_str.strip().split() + value = float(parts[0]) + unit = parts[1] + if unit in ("ns",): + return value / 1_000_000 + if unit in ("µs", "us"): + return value / 1_000 + if unit in ("ms",): + return value + if unit in ("s",): + return value * 1_000 + raise ValueError(f"Unknown time unit: {unit!r}") + + +def format_ms(ms: float) -> str: + """Format milliseconds for display, choosing the best unit.""" + if ms < 1: + return f"{ms * 1000:.1f}\\,\\micro{{s}}" + if ms < 1000: + return f"{ms:.2f}\\,ms" + return f"{ms / 1000:.2f}\\,s" + + +def parse_output(output: str) -> list[BenchResult]: + """Parse criterion benchmark output into structured results.""" + results: list[BenchResult] = [] + current_name: dict | None = None + + for line in output.splitlines(): + m = BENCH_NAME_RE.match(line.strip()) + if m: + current_name = m.groupdict() + continue + + m = TIME_RE.search(line) + if m and current_name is not None: + results.append(BenchResult( + benchmark=current_name["bench"], + limbs=int(current_name["limbs"]), + nvars=int(current_name["nvars"]), + npolys=int(current_name.get("npolys") or 0), + time_low=m.group("low"), + time_mid=m.group("mid"), + time_high=m.group("high"), + )) + current_name = None + + return results + + +def results_to_latex(results: list[BenchResult]) -> str: + """Convert parsed benchmark results into a LaTeX table string.""" + # Group by benchmark type (Prover / 
Verifier) + bench_types = sorted(set(r.benchmark for r in results)) + limbs_values = sorted(set(r.limbs for r in results)) + nvars_values = sorted(set(r.nvars for r in results)) + + # Build a lookup: (benchmark, limbs, nvars) -> BenchResult + lookup: dict[tuple[str, int, int], BenchResult] = {} + for r in results: + lookup[(r.benchmark, r.limbs, r.nvars)] = r + + # Number of time columns = len(limbs_values) per benchmark type + n_time_cols = len(limbs_values) * len(bench_types) + + lines: list[str] = [] + lines.append(r"\begin{table}[ht]") + lines.append(r" \centering") + lines.append(r" \sisetup{round-mode=places, round-precision=2}") + + # Column spec: nvars | for each bench type: one col per limbs value + col_spec = "c" + "".join("|" + "c" * len(limbs_values) for _ in bench_types) + lines.append(r" \begin{tabular}{" + col_spec + "}") + lines.append(r" \toprule") + + # Header row 1: nvars + benchmark type names spanning columns + header1_parts = [r" \multirow{2}{*}{$n$}"] + for bt in bench_types: + short_name = bt.replace("Sumcheck ", "") + header1_parts.append( + rf"\multicolumn{{{len(limbs_values)}}}{{c}}{{{short_name}}}" + ) + lines.append(" & ".join(header1_parts) + r" \\") + + # Header row 2: limbs sub-columns + header2_parts = [""] + for _ in bench_types: + for l in limbs_values: + bits = l * 64 + header2_parts.append(f"{bits}-bit") + lines.append(" & ".join(header2_parts) + r" \\") + lines.append(r" \midrule") + + # Data rows + for nv in nvars_values: + row_parts = [f"${nv}$"] + for bt in bench_types: + for l in limbs_values: + key = (bt, l, nv) + if key in lookup: + ms = parse_to_ms(lookup[key].time_mid) + row_parts.append(f"${format_ms(ms)}$") + else: + row_parts.append("--") + lines.append(" " + " & ".join(row_parts) + r" \\") + + lines.append(r" \bottomrule") + lines.append(r" \end{tabular}") + + # Caption with npolys if available + npolys_set = set(r.npolys for r in results if r.npolys > 0) + npolys_str = ", ".join(str(n) for n in sorted(npolys_set)) 
if npolys_set else "" + caption = r"Sumcheck benchmark times (median)" + if npolys_str: + caption += f", {npolys_str} polynomials" + lines.append(r" \caption{" + caption + "}") + lines.append(r" \label{tab:sumcheck-bench}") + lines.append(r"\end{table}") + + return "\n".join(lines) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run sumcheck benchmarks and produce a LaTeX table." + ) + parser.add_argument( + "--output", "-o", + help="Output .tex file (default: stdout)", + ) + parser.add_argument( + "--features", + default="parallel simd", + help='Cargo feature flags (default: "parallel simd")', + ) + parser.add_argument( + "--filter", + default=None, + help="Criterion filter regex passed after -- (default: run all)", + ) + parser.add_argument( + "--from-stdin", + action="store_true", + help="Read benchmark output from stdin instead of running cargo bench", + ) + args = parser.parse_args() + + if args.from_stdin: + output = sys.stdin.read() + else: + cmd = [ + "cargo", "bench", + "--bench", "sumcheck", + "-p", "zinc-piop", + ] + if args.features: + cmd += ["--features", args.features] + if args.filter: + cmd += ["--", args.filter] + + print(f"Running: {' '.join(cmd)}", file=sys.stderr) + proc = subprocess.run(cmd, capture_output=True, text=True) + output = proc.stdout + proc.stderr + if proc.returncode != 0: + print("Benchmark command failed:", file=sys.stderr) + print(output, file=sys.stderr) + sys.exit(1) + + results = parse_output(output) + if not results: + print("No benchmark results found in output.", file=sys.stderr) + sys.exit(1) + + print(f"Parsed {len(results)} benchmark result(s).", file=sys.stderr) + + latex = results_to_latex(results) + + if args.output: + with open(args.output, "w") as f: + f.write(latex + "\n") + print(f"Wrote LaTeX table to {args.output}", file=sys.stderr) + else: + print(latex) + + +if __name__ == "__main__": + main() diff --git a/zip-plus/benches/zip_benches.rs b/zip-plus/benches/zip_benches.rs index 
d0f0ba3b..6ab38043 100644 --- a/zip-plus/benches/zip_benches.rs +++ b/zip-plus/benches/zip_benches.rs @@ -12,9 +12,13 @@ use crypto_bigint::U64; use crypto_primitives::{crypto_bigint_int::Int, crypto_bigint_uint::Uint}; use zinc_utils::UNCHECKED; use zip_plus::{ - code::raa::{RaaCode, RaaConfig}, + code::{ + iprs::{IprsCode, PnttConfigF2_16_1_Depth2_Rate1_2}, + raa::{RaaCode, RaaConfig}, + }, pcs::structs::ZipTypes, }; +use zinc_utils::mul_by_scalar::WideningMulByScalar; const INT_LIMBS: usize = U64::LIMBS; @@ -43,11 +47,31 @@ impl RaaConfig for BenchRaaConfig { type Code = RaaCode; +#[derive(Clone, Default)] +struct I32WideningMulByScalar; + +impl WideningMulByScalar for I32WideningMulByScalar { + type Output = i64; + + #[allow(clippy::arithmetic_side_effects)] + fn mul_by_scalar_widen(lhs: &i32, rhs: &i64) -> Self::Output { + i64::from(*lhs) * *rhs + } +} + +type IprsCodeDepth2 = IprsCode; + fn zip_benchmarks(c: &mut Criterion) { let mut group = c.benchmark_group("Zip+"); do_bench::(&mut group); group.finish(); } -criterion_group!(benches, zip_benchmarks); +fn zip_benchmarks_iprs(c: &mut Criterion) { + let mut group = c.benchmark_group("Zip IPRS"); + do_bench_iprs_matrices::(&mut group); + group.finish(); +} + +criterion_group!(benches, zip_benchmarks, zip_benchmarks_iprs); criterion_main!(benches); diff --git a/zip-plus/benches/zip_common.rs b/zip-plus/benches/zip_common.rs index 23b6f45b..251b7ec0 100644 --- a/zip-plus/benches/zip_common.rs +++ b/zip-plus/benches/zip_common.rs @@ -80,6 +80,52 @@ pub fn do_bench, const CHECK_FOR_OVERFLOWS: boo verify::(group); } +pub fn do_bench_iprs_matrices, const CHECK_FOR_OVERFLOWS: bool>( + group: &mut BenchmarkGroup, +) where + StandardUniform: Distribution + Distribution, + F: for<'a> FromWithConfig<&'a Zt::Chal> + for<'a> FromWithConfig<&'a Zt::Pt>, + ::Inner: FromRef, + Zt::Eval: ProjectableToField, + Zt::Cw: ProjectableToField, +{ + encode_rows::(group); + encode_rows::(group); + encode_rows::(group); + 
encode_rows::(group); + encode_rows::(group); + + merkle_root::(group); + merkle_root::(group); + merkle_root::(group); + merkle_root::(group); + merkle_root::(group); + + commit::(group); + commit::(group); + commit::(group); + commit::(group); + commit::(group); + + test::(group); + test::(group); + test::(group); + test::(group); + test::(group); + + evaluate::(group); + evaluate::(group); + evaluate::(group); + evaluate::(group); + evaluate::(group); + + verify::(group); + verify::(group); + verify::(group); + verify::(group); + verify::(group); +} + pub fn encode_rows, const P: usize>( group: &mut BenchmarkGroup, ) where diff --git a/zip-plus/benches/zip_plus_benches.rs b/zip-plus/benches/zip_plus_benches.rs index 029ca6ef..e7fc5b9f 100644 --- a/zip-plus/benches/zip_plus_benches.rs +++ b/zip-plus/benches/zip_plus_benches.rs @@ -21,7 +21,7 @@ use crypto_primitives::{ }; use zip_plus::{ code::{ - iprs::{IprsCode, PnttConfigF2_16_1}, + iprs::{IprsCode, PnttConfigF2_16_1_Depth2_Rate1_2}, raa::{RaaCode, RaaConfig}, }, pcs::structs::ZipTypes, @@ -68,9 +68,9 @@ impl RaaConfig for BenchRaaConfig { type SomeRaaCode = RaaCode, BenchRaaConfig, 4>; -type SomeIprsCode = IprsCode< +type SomeIprsCodeDepth2 = IprsCode< BenchZipPlusTypes, - PnttConfigF2_16_1, + PnttConfigF2_16_1_Depth2_Rate1_2, BinaryPolyWideningMulByScalar, >; @@ -86,8 +86,12 @@ fn zip_plus_benchmarks_raa(c: &mut Criterion) { fn zip_plus_benchmarks_iprs(c: &mut Criterion) { let mut group = c.benchmark_group("Zip+ IPRS"); - do_bench::, SomeIprsCode, UNCHECKED>(&mut group); - do_bench::, SomeIprsCode, UNCHECKED>(&mut group); + do_bench_iprs_matrices::, SomeIprsCodeDepth2, UNCHECKED>( + &mut group, + ); + do_bench_iprs_matrices::, SomeIprsCodeDepth2, UNCHECKED>( + &mut group, + ); group.finish(); } diff --git a/zip-plus/src/code/iprs.rs b/zip-plus/src/code/iprs.rs index 864ddeb0..399f2c5a 100644 --- a/zip-plus/src/code/iprs.rs +++ b/zip-plus/src/code/iprs.rs @@ -12,7 +12,14 @@ use crate::{ }; use 
crypto_primitives::{FromPrimitiveWithConfig, FromWithConfig}; use num_traits::{CheckedAdd, CheckedMul}; -pub use pntt::radix8::params::{PnttConfigF2_16_1, PnttInt, Radix8PnttParams}; +pub use pntt::radix8::params::{ + PnttConfigF2_16_1, + PnttConfigF2_16_1_Depth2_Rate1_2, + PnttConfigF2_16_1_Depth2_Rate1_4, + PnttConfigF2_16_1_Rate1_4, + PnttInt, + Radix8PnttParams, +}; use std::{fmt::Debug, iter::Sum, marker::PhantomData, ops::AddAssign}; use zinc_utils::{ CHECKED, diff --git a/zip-plus/src/code/iprs/pntt/radix8/params.rs b/zip-plus/src/code/iprs/pntt/radix8/params.rs index 27681234..91125b43 100644 --- a/zip-plus/src/code/iprs/pntt/radix8/params.rs +++ b/zip-plus/src/code/iprs/pntt/radix8/params.rs @@ -159,9 +159,22 @@ impl Radix8PnttParams { /// the field `Fp` for `p = 2^16 + 1`. /// /// Supports `DEPTH` up to `3`. +/// +/// With `BASE_LEN = 32` and `BASE_DIM = 64`, +/// this yields rate $\frac{1}{2}$ codes. #[derive(Clone, Copy)] pub struct PnttConfigF2_16_1; +/// Pseudo NTT configuration derived from +/// the field `Fp` for `p = 2^16 + 1`. +/// +/// Supports `DEPTH` up to `3`. +/// +/// With `BASE_LEN = 32` and `BASE_DIM = 128`, +/// this yields rate $\frac{1}{4}$ codes. +#[derive(Clone, Copy)] +pub struct PnttConfigF2_16_1_Rate1_4<const DEPTH: usize>; + mod fq { #![allow(non_local_definitions)] use ark_ff::{Fp64, MontBackend, MontConfig}; @@ -192,6 +205,27 @@ impl Config for PnttConfigF2_16_1 { } } +impl<const DEPTH: usize> Config for PnttConfigF2_16_1_Rate1_4<DEPTH> { + type Field = fq::Fq; + const FIELD_MODULUS: u32 = fq::MODULUS; + const BASE_LEN: usize = 32; + const BASE_DIM: usize = 128; + const DEPTH: usize = DEPTH; + const BASE_TWIDDLES: [PnttInt; 8] = [1, 4096, -256, 16, -1, -4096, 256, -16]; + + fn field_to_int_normalized(x: Self::Field) -> PnttInt { + let big_int = fq::FqBackend::into_bigint(x); + + precompute::normalize_field_element(big_int.0[0], Self::FIELD_MODULUS) + } +} + +/// Depth-2 configuration for message size $2^{11}$ with rate $\frac{1}{2}$. 
+pub type PnttConfigF2_16_1_Depth2_Rate1_2 = PnttConfigF2_16_1<2>; + +/// Depth-2 configuration for message size $2^{11}$ with rate $\frac{1}{4}$. +pub type PnttConfigF2_16_1_Depth2_Rate1_4 = PnttConfigF2_16_1_Rate1_4<2>; + #[cfg(test)] mod tests { use super::*;