From f9b3424e4da52f984e7fc8d61e8b0ca2ab35356d Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 26 Feb 2026 17:57:56 +0530 Subject: [PATCH 1/5] feat : ec_ops added with proper constraints for field element --- .../src/witness/scheduling/dependency.rs | 6 +- .../common/src/witness/scheduling/remapper.rs | 6 + .../common/src/witness/witness_builder.rs | 12 + provekit/prover/src/witness/bigint_mod.rs | 422 ++++++++++++++++++ provekit/prover/src/witness/mod.rs | 1 + .../prover/src/witness/witness_builder.rs | 18 + provekit/r1cs-compiler/src/lib.rs | 1 + provekit/r1cs-compiler/src/msm/curve.rs | 71 +++ provekit/r1cs-compiler/src/msm/ec_ops.rs | 305 +++++++++++++ provekit/r1cs-compiler/src/msm/mod.rs | 2 + provekit/r1cs-compiler/src/noir_to_r1cs.rs | 6 + 11 files changed, 849 insertions(+), 1 deletion(-) create mode 100644 provekit/prover/src/witness/bigint_mod.rs create mode 100644 provekit/r1cs-compiler/src/msm/curve.rs create mode 100644 provekit/r1cs-compiler/src/msm/ec_ops.rs create mode 100644 provekit/r1cs-compiler/src/msm/mod.rs diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index a5cbaefd6..f2d480377 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -78,7 +78,9 @@ impl DependencyInfo { WitnessBuilder::Sum(_, ops) => ops.iter().map(|SumTerm(_, idx)| *idx).collect(), WitnessBuilder::Product(_, a, b) => vec![*a, *b], WitnessBuilder::MultiplicitiesForRange(_, _, values) => values.clone(), - WitnessBuilder::Inverse(_, x) => vec![*x], + WitnessBuilder::Inverse(_, x) + | WitnessBuilder::ModularInverse(_, x, _) + | WitnessBuilder::IntegerQuotient(_, x, _) => vec![*x], WitnessBuilder::IndexedLogUpDenominator( _, sz, @@ -240,6 +242,8 @@ impl DependencyInfo { | WitnessBuilder::Challenge(idx) | WitnessBuilder::IndexedLogUpDenominator(idx, ..) | WitnessBuilder::Inverse(idx, _) + | WitnessBuilder::ModularInverse(idx, ..) 
+ | WitnessBuilder::IntegerQuotient(idx, ..) | WitnessBuilder::ProductLinearOperation(idx, ..) | WitnessBuilder::LogUpDenominator(idx, ..) | WitnessBuilder::LogUpInverse(idx, ..) diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 9503847a3..76c30bd23 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -115,6 +115,12 @@ impl WitnessIndexRemapper { WitnessBuilder::Inverse(idx, operand) => { WitnessBuilder::Inverse(self.remap(*idx), self.remap(*operand)) } + WitnessBuilder::ModularInverse(idx, operand, modulus) => { + WitnessBuilder::ModularInverse(self.remap(*idx), self.remap(*operand), *modulus) + } + WitnessBuilder::IntegerQuotient(idx, dividend, divisor) => { + WitnessBuilder::IntegerQuotient(self.remap(*idx), self.remap(*dividend), *divisor) + } WitnessBuilder::ProductLinearOperation( idx, ProductLinearTerm(x, a, b), diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 0628fc2e3..5719ee99e 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -88,6 +88,18 @@ pub enum WitnessBuilder { /// The inverse of the value at a specified witness index /// (witness index, operand witness index) Inverse(usize, usize), + /// The modular inverse of the value at a specified witness index, modulo + /// a given prime modulus. Computes a^{-1} mod m using Fermat's little + /// theorem (a^{m-2} mod m). Unlike Inverse (BN254 field inverse), this + /// operates as integer modular arithmetic. + /// (witness index, operand witness index, modulus) + ModularInverse(usize, usize, #[serde(with = "serde_ark")] FieldElement), + /// The integer quotient floor(dividend / divisor). Used by reduce_mod to + /// compute k = floor(v / m) so that v = k*m + result with 0 <= result < m. 
+ /// Unlike field multiplication by the inverse, this performs true integer + /// division on the BigInteger representation. + /// (witness index, dividend witness index, divisor constant) + IntegerQuotient(usize, usize, #[serde(with = "serde_ark")] FieldElement), /// Products with linear operations on the witness indices. /// Fields are ProductLinearOperation(witness_idx, (index, a, b), (index, c, /// d)) such that we wish to compute (ax + b) * (cx + d). diff --git a/provekit/prover/src/witness/bigint_mod.rs b/provekit/prover/src/witness/bigint_mod.rs new file mode 100644 index 000000000..3252aeea2 --- /dev/null +++ b/provekit/prover/src/witness/bigint_mod.rs @@ -0,0 +1,422 @@ +/// BigInteger modular arithmetic on [u64; 4] limbs (256-bit). +/// +/// These helpers compute modular inverse via Fermat's little theorem: +/// a^{-1} = a^{m-2} mod m, using schoolbook multiplication and +/// square-and-multiply exponentiation. + +/// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → +/// 512-bit). +fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let product = (a[i] as u128) * (b[j] as u128) + (result[i + j] as u128) + carry; + result[i + j] = product as u64; + carry = product >> 64; + } + result[i + 4] = carry as u64; + } + result +} + +/// Compare 8-limb value with 4-limb value (zero-extended to 8 limbs). +/// Returns Ordering::Greater if wide > narrow, etc. 
+#[cfg(test)] +fn cmp_wide_narrow(wide: &[u64; 8], narrow: &[u64; 4]) -> std::cmp::Ordering { + // Check high limbs of wide (must all be zero for equality/less) + for i in (4..8).rev() { + if wide[i] != 0 { + return std::cmp::Ordering::Greater; + } + } + // Compare the low 4 limbs + for i in (0..4).rev() { + match wide[i].cmp(&narrow[i]) { + std::cmp::Ordering::Equal => continue, + other => return other, + } + } + std::cmp::Ordering::Equal +} + +/// Modular reduction of a 512-bit value by a 256-bit modulus. +/// Uses bit-by-bit long division. +fn reduce_wide(wide: &[u64; 8], modulus: &[u64; 4]) -> [u64; 4] { + // Find the highest set bit in wide + let mut highest_bit = 0; + for i in (0..8).rev() { + if wide[i] != 0 { + highest_bit = i * 64 + (64 - wide[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [0u64; 4]; + } + + // Bit-by-bit long division + // remainder starts at 0, we shift in bits from the dividend + let mut remainder = [0u64; 4]; + for bit_pos in (0..highest_bit).rev() { + // Left-shift remainder by 1 + let carry = shift_left_one(&mut remainder); + debug_assert_eq!(carry, 0, "remainder overflow during shift"); + + // Bring in the next bit from wide + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + let bit = (wide[limb_idx] >> bit_idx) & 1; + remainder[0] |= bit; + + // If remainder >= modulus, subtract + if cmp_4limb(&remainder, modulus) != std::cmp::Ordering::Less { + sub_4limb_inplace(&mut remainder, modulus); + } + } + + remainder +} + +/// Left-shift a 4-limb number by 1 bit. Returns the carry-out bit. +fn shift_left_one(a: &mut [u64; 4]) -> u64 { + let mut carry = 0u64; + for limb in a.iter_mut() { + let new_carry = *limb >> 63; + *limb = (*limb << 1) | carry; + carry = new_carry; + } + carry +} + +/// Compare two 4-limb numbers. 
+fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering { + for i in (0..4).rev() { + match a[i].cmp(&b[i]) { + std::cmp::Ordering::Equal => continue, + other => return other, + } + } + std::cmp::Ordering::Equal +} + +/// Subtract b from a in-place (a -= b). Assumes a >= b. +fn sub_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) { + let mut borrow = 0u64; + for i in 0..4 { + let (diff, borrow1) = a[i].overflowing_sub(b[i]); + let (diff2, borrow2) = diff.overflowing_sub(borrow); + a[i] = diff2; + borrow = (borrow1 as u64) + (borrow2 as u64); + } + debug_assert_eq!(borrow, 0, "subtraction underflow: a < b"); +} + +/// Modular multiplication: (a * b) mod m. +pub fn mul_mod(a: &[u64; 4], b: &[u64; 4], m: &[u64; 4]) -> [u64; 4] { + let wide = widening_mul(a, b); + reduce_wide(&wide, m) +} + +/// Modular exponentiation: base^exp mod m using square-and-multiply. +pub fn mod_pow(base: &[u64; 4], exp: &[u64; 4], m: &[u64; 4]) -> [u64; 4] { + // Find highest set bit in exp + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + // exp == 0 → result = 1 (for m > 1) + return [1, 0, 0, 0]; + } + + let mut result = [1u64, 0, 0, 0]; // 1 + for bit_pos in (0..highest_bit).rev() { + // Square + result = mul_mod(&result, &result, m); + // Multiply if bit is set + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + if (exp[limb_idx] >> bit_idx) & 1 == 1 { + result = mul_mod(&result, base, m); + } + } + + result +} + +/// Integer division with remainder: dividend = quotient * divisor + remainder, +/// where 0 <= remainder < divisor. Uses bit-by-bit long division. 
+pub fn divmod(dividend: &[u64; 4], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + // Find the highest set bit in dividend + let mut highest_bit = 0; + for i in (0..4).rev() { + if dividend[i] != 0 { + highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return ([0u64; 4], [0u64; 4]); + } + + let mut quotient = [0u64; 4]; + let mut remainder = [0u64; 4]; + + for bit_pos in (0..highest_bit).rev() { + // Left-shift remainder by 1 + let carry = shift_left_one(&mut remainder); + debug_assert_eq!(carry, 0, "remainder overflow during shift"); + + // Bring in the next bit from dividend + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1; + + // If remainder >= divisor, subtract and set quotient bit + if cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less { + sub_4limb_inplace(&mut remainder, divisor); + quotient[limb_idx] |= 1u64 << bit_idx; + } + } + + (quotient, remainder) +} + +/// Subtract a small u64 value from a 4-limb number. Assumes a >= small. 
+pub fn sub_u64(a: &[u64; 4], small: u64) -> [u64; 4] { + let mut result = *a; + let (diff, borrow) = result[0].overflowing_sub(small); + result[0] = diff; + if borrow { + for limb in result[1..].iter_mut() { + let (d, b) = limb.overflowing_sub(1); + *limb = d; + if !b { + break; + } + } + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_widening_mul_small() { + // 3 * 7 = 21 + let a = [3, 0, 0, 0]; + let b = [7, 0, 0, 0]; + let result = widening_mul(&a, &b); + assert_eq!(result[0], 21); + assert_eq!(result[1..], [0; 7]); + } + + #[test] + fn test_widening_mul_overflow() { + // u64::MAX * u64::MAX = (2^64-1)^2 = 2^128 - 2^65 + 1 + let a = [u64::MAX, 0, 0, 0]; + let b = [u64::MAX, 0, 0, 0]; + let result = widening_mul(&a, &b); + // (2^64-1)^2 = 0xFFFFFFFFFFFFFFFE_0000000000000001 + assert_eq!(result[0], 1); + assert_eq!(result[1], u64::MAX - 1); + assert_eq!(result[2..], [0; 6]); + } + + #[test] + fn test_reduce_wide_no_reduction() { + // 5 mod 7 = 5 + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [5, 0, 0, 0]); + } + + #[test] + fn test_reduce_wide_basic() { + // 10 mod 7 = 3 + let wide = [10, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [3, 0, 0, 0]); + } + + #[test] + fn test_mul_mod_small() { + // (5 * 3) mod 7 = 15 mod 7 = 1 + let a = [5, 0, 0, 0]; + let b = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mul_mod(&a, &b, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_small() { + // 3^4 mod 7 = 81 mod 7 = 4 + let base = [3, 0, 0, 0]; + let exp = [4, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [4, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_small() { + // Inverse of 3 mod 7: 3^{7-2} = 3^5 mod 7 = 243 mod 7 = 5 + // Check: 3 * 5 = 15 = 2*7 + 1 ≡ 1 (mod 7) ✓ + let a = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + let exp = sub_u64(&m, 2); // m - 2 = 5 + let inv = mod_pow(&a, &exp, 
&m); + assert_eq!(inv, [5, 0, 0, 0]); + // Verify: a * inv mod m = 1 + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_prime_23() { + // Inverse of 5 mod 23: 5^{21} mod 23 + // 5^{-1} mod 23 = 14 (because 5*14 = 70 = 3*23 + 1) + let a = [5, 0, 0, 0]; + let m = [23, 0, 0, 0]; + let exp = sub_u64(&m, 2); + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, [14, 0, 0, 0]); + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_basic() { + assert_eq!(sub_u64(&[10, 0, 0, 0], 3), [7, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_borrow() { + // [0, 1, 0, 0] = 2^64; subtract 1 → [u64::MAX, 0, 0, 0] + assert_eq!(sub_u64(&[0, 1, 0, 0], 1), [u64::MAX, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_large_prime() { + // Use a 128-bit prime: p = 2^127 - 1 = 170141183460469231731687303715884105727 + // In limbs: [u64::MAX, 2^63 - 1, 0, 0] + let p = [u64::MAX, (1u64 << 63) - 1, 0, 0]; + + // a = 42 + let a = [42, 0, 0, 0]; + let exp = sub_u64(&p, 2); + let inv = mod_pow(&a, &exp, &p); + + // Verify: a * inv mod p = 1 + assert_eq!(mul_mod(&a, &inv, &p), [1, 0, 0, 0]); + } + + #[test] + fn test_cmp_wide_narrow() { + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let narrow = [5, 0, 0, 0]; + assert_eq!(cmp_wide_narrow(&wide, &narrow), std::cmp::Ordering::Equal); + + let wide_greater = [0, 0, 0, 0, 1, 0, 0, 0]; + assert_eq!( + cmp_wide_narrow(&wide_greater, &narrow), + std::cmp::Ordering::Greater + ); + } + + #[test] + fn test_mod_pow_zero_exp() { + // a^0 mod m = 1 + let base = [42, 0, 0, 0]; + let exp = [0, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_one_exp() { + // a^1 mod m = a mod m + let base = [10, 0, 0, 0]; + let exp = [1, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_exact() { + // 21 / 7 = 3 remainder 0 + let (q, r) = divmod(&[21, 0, 0, 0], &[7, 0, 0, 0]); 
+ assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_with_remainder() { + // 17 / 7 = 2 remainder 3 + let (q, r) = divmod(&[17, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [2, 0, 0, 0]); + assert_eq!(r, [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_smaller_dividend() { + // 5 / 7 = 0 remainder 5 + let (q, r) = divmod(&[5, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_zero_dividend() { + let (q, r) = divmod(&[0, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_large() { + // 2^64 / 3 = 6148914691236517205 remainder 1 + // 2^64 in limbs: [0, 1, 0, 0] + let (q, r) = divmod(&[0, 1, 0, 0], &[3, 0, 0, 0]); + assert_eq!(q, [6148914691236517205, 0, 0, 0]); + assert_eq!(r, [1, 0, 0, 0]); + // Verify: q * 3 + 1 = 2^64 + assert_eq!(6148914691236517205u64 * 3 + 1, 0u64); // wraps to 0 in u64 = + // 2^64 + } + + #[test] + fn test_divmod_consistency() { + // Verify dividend = quotient * divisor + remainder for various inputs + let cases: Vec<([u64; 4], [u64; 4])> = vec![ + ([100, 0, 0, 0], [7, 0, 0, 0]), + ([u64::MAX, 0, 0, 0], [1000, 0, 0, 0]), + ([0, 1, 0, 0], [u64::MAX, 0, 0, 0]), // 2^64 / (2^64 - 1) + ]; + for (dividend, divisor) in cases { + let (q, r) = divmod(÷nd, &divisor); + // Verify: q * divisor + r = dividend + let product = widening_mul(&q, &divisor); + // Add remainder to product + let mut sum = product; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + // sum should equal dividend (zero-extended to 8 limbs) + let mut expected = [0u64; 8]; + expected[..4].copy_from_slice(÷nd); + assert_eq!(sum, expected, "dividend={dividend:?} divisor={divisor:?}"); + } + } +} diff --git 
a/provekit/prover/src/witness/mod.rs b/provekit/prover/src/witness/mod.rs index 5f5de8f0b..fb5072440 100644 --- a/provekit/prover/src/witness/mod.rs +++ b/provekit/prover/src/witness/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod bigint_mod; mod digits; mod ram; pub(crate) mod witness_builder; diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index db91e5e0a..115637da6 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -65,6 +65,24 @@ impl WitnessBuilderSolver for WitnessBuilder { "Inverse/LogUpInverse should not be called - handled by batch inversion" ) } + WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => { + let a = witness[*operand_idx].unwrap(); + let a_limbs = a.into_bigint().0; + let m_limbs = modulus.into_bigint().0; + // Fermat's little theorem: a^{-1} = a^{m-2} mod m + let exp = crate::witness::bigint_mod::sub_u64(&m_limbs, 2); + let result_limbs = crate::witness::bigint_mod::mod_pow(&a_limbs, &exp, &m_limbs); + witness[*witness_idx] = + Some(FieldElement::from_bigint(ark_ff::BigInt(result_limbs)).unwrap()); + } + WitnessBuilder::IntegerQuotient(witness_idx, dividend_idx, divisor) => { + let dividend = witness[*dividend_idx].unwrap(); + let d_limbs = dividend.into_bigint().0; + let m_limbs = divisor.into_bigint().0; + let (quotient, _remainder) = crate::witness::bigint_mod::divmod(&d_limbs, &m_limbs); + witness[*witness_idx] = + Some(FieldElement::from_bigint(ark_ff::BigInt(quotient)).unwrap()); + } WitnessBuilder::IndexedLogUpDenominator( witness_idx, sz_challenge, diff --git a/provekit/r1cs-compiler/src/lib.rs b/provekit/r1cs-compiler/src/lib.rs index 8faedda8b..f2874fd0e 100644 --- a/provekit/r1cs-compiler/src/lib.rs +++ b/provekit/r1cs-compiler/src/lib.rs @@ -1,6 +1,7 @@ mod binops; mod digits; mod memory; +mod msm; mod noir_proof_scheme; mod noir_to_r1cs; mod poseidon2; diff --git 
a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs new file mode 100644 index 000000000..cbfd1bf25 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -0,0 +1,71 @@ +use provekit_common::FieldElement; + +// TODO : remove Option<> form both the params if comes in use +// otherwise we delete the params from struct +pub struct CurveParams { + pub field_modulus_p: FieldElement, + pub curve_order_n: FieldElement, + pub curve_a: FieldElement, + pub curve_b: FieldElement, + pub generator: (FieldElement, FieldElement), + pub coordinate_bits: Option, +} + +pub fn secp256r1_params() -> CurveParams { + CurveParams { + field_modulus_p: FieldElement::from_sign_and_limbs( + true, + [ + 0xffffffffffffffff_u64, + 0xffffffff_u64, + 0x0_u64, + 0xffffffff00000001_u64, + ] + .as_slice(), + ), + curve_order_n: FieldElement::from_sign_and_limbs( + true, + [ + 0xf3b9cac2fc632551_u64, + 0xbce6faada7179e84_u64, + 0xffffffffffffffff_u64, + 0xffffffff00000000_u64, + ] + .as_slice(), + ), + curve_a: FieldElement::from(-3), + curve_b: FieldElement::from_sign_and_limbs( + true, + [ + 0x3bce3c3e27d2604b_u64, + 0x651d06b0cc53b0f6_u64, + 0xb3ebbd55769886bc_u64, + 0x5ac635d8aa3a93e7_u64, + ] + .as_slice(), + ), + generator: ( + FieldElement::from_sign_and_limbs( + true, + [ + 0xf4a13945d898c296_u64, + 0x77037d812deb33a0_u64, + 0xf8bce6e563a440f2_u64, + 0x6b17d1f2e12c4247_u64, + ] + .as_slice(), + ), + FieldElement::from_sign_and_limbs( + true, + [ + 0xcbb6406837bf51f5_u64, + 0x2bce33576b315ece_u64, + 0x8ee7eb4a7c0f9e16_u64, + 0x4fe342e2fe1a7f9b_u64, + ] + .as_slice(), + ), + ), + coordinate_bits: None, + } +} diff --git a/provekit/r1cs-compiler/src/msm/ec_ops.rs b/provekit/r1cs-compiler/src/msm/ec_ops.rs new file mode 100644 index 000000000..1bd96e69f --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_ops.rs @@ -0,0 +1,305 @@ +use { + crate::{msm::curve::CurveParams, noir_to_r1cs::NoirToR1CSCompiler}, + ark_ff::{AdditiveGroup, BigInteger, Field, 
PrimeField}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Reduce the value to given modulus +pub fn reduce_mod( + r1cs_compiler: &mut NoirToR1CSCompiler, + value: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + // Reduce mod algorithm : + // v = k * m + result, where 0 <= result < m + // k = floor(v / m) (integer division) + // result = v - k * m + + // Computing k = floor(v / m) + // ----------------------------------------------------------- + // computing m (constant witness for use in constraints) + let m = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(m, modulus), + )); + // computing k via integer division + let k = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus)); + + // Computing result = v - k * m + // ----------------------------------------------------------- + // computing k * m + let k_mul_m = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m)); + // constraint: k * m = k_mul_m + r1cs_compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[( + FieldElement::ONE, + k_mul_m, + )]); + // computing result = v - k * m + let result = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![ + SumTerm(Some(FieldElement::ONE), value), + SumTerm(Some(-FieldElement::ONE), k_mul_m), + ])); + // constraint: 1 * (k_mul_m + result) = value + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)], + &[(FieldElement::ONE, value)], + ); + // range check to prove 0 <= result < m + let modulus_bits = modulus.into_bigint().num_bits(); + range_checks + .entry(modulus_bits) + 
.or_insert_with(Vec::new) + .push(result); + + result +} + +/// a * b mod m +pub fn compute_field_mul( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_mul_b = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b)); + // constraint: a * b = a_mul_b + r1cs_compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( + FieldElement::ONE, + a_mul_b, + )]); + reduce_mod(r1cs_compiler, a_mul_b, modulus, range_checks) +} + +/// (a - b) mod m +pub fn compute_field_sub( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_sub_b = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(-FieldElement::ONE), b), + ])); + // constraint: 1 * (a - b) = a_sub_b + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, a_sub_b)], + ); + reduce_mod(r1cs_compiler, a_sub_b, modulus, range_checks) +} + +/// a^(-1) mod m +/// +/// CRITICAL: secp256r1's field_modulus_p (~2^256) > BN254 scalar field +/// (~2^254). Coordinates and the modulus do not fit in a single +/// FieldElement. Either use multi-limb representation or target a +/// curve that fits (e.g. Grumpkin, BabyJubJub). 
+pub fn compute_field_inv( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + // Computing a^(-1) mod m + // ----------------------------------------------------------- + // computing a_inv (the F_m inverse of a) via Fermat's little theorem + let a_inv = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus)); + + // Verifying a * a_inv mod m = 1 + // ----------------------------------------------------------- + // computing a * a_inv + let product_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product(product_raw, a, a_inv)); + // constraint: a * a_inv = product_raw + r1cs_compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, a_inv)], &[ + (FieldElement::ONE, product_raw), + ]); + // reducing a * a_inv mod m — should give 1 if a_inv is correct + let reduced = reduce_mod(r1cs_compiler, product_raw, modulus, range_checks); + + // constraint: reduced = 1 + // (reduced - 1) * 1 = 0 + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[ + (FieldElement::ONE, reduced), + (-FieldElement::ONE, r1cs_compiler.witness_one()), + ], + &[(FieldElement::ZERO, r1cs_compiler.witness_one())], + ); + + // range check: a_inv in [0, 2^bits(m)) + let mod_bits = modulus.into_bigint().num_bits(); + range_checks + .entry(mod_bits) + .or_insert_with(Vec::new) + .push(a_inv); + + a_inv +} + +/// Point doubling on y^2 = x^3 + ax + b (mod p) using affine lambda formula. +/// +/// Given P = (x1, y1), computes 2P = (x3, y3): +/// lambda = (3 * x1^2 + a) / (2 * y1) (mod p) +/// x3 = lambda^2 - 2 * x1 (mod p) +/// y3 = lambda * (x1 - x3) - y1 (mod p) +/// +/// Edge case — y1 = 0 (point of order 2): +/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. +/// The result should be the point at infinity (identity element). 
+/// This function does NOT handle that case — the constraint system will +/// be unsatisfiable if y1 = 0 (compute_field_inv will fail to verify +/// 0 * inv = 1 mod p). The caller must check y1 = 0 using +/// compute_is_zero and conditionally select the point-at-infinity +/// result before calling this function. +pub fn point_double( + r1cs_compiler: &mut NoirToR1CSCompiler, + x1: usize, + y1: usize, + curve_params: &CurveParams, + range_checks: &mut BTreeMap>, +) -> (usize, usize) { + let p = curve_params.field_modulus_p; + + // Computing numerator = 3 * x1^2 + a (mod p) + // ----------------------------------------------------------- + // computing x1^2 mod p + let x1_sq = compute_field_mul(r1cs_compiler, x1, x1, p, range_checks); + // computing 3 * x1_sq + a + let a_witness = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(a_witness, curve_params.curve_a), + )); + let num_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(num_raw, vec![ + SumTerm(Some(FieldElement::from(3u64)), x1_sq), + SumTerm(Some(FieldElement::ONE), a_witness), + ])); + // constraint: 1 * (3 * x1_sq + a) = num_raw + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[ + (FieldElement::from(3u64), x1_sq), + (FieldElement::ONE, a_witness), + ], + &[(FieldElement::ONE, num_raw)], + ); + let numerator = reduce_mod(r1cs_compiler, num_raw, p, range_checks); + + // Computing denominator = 2 * y1 (mod p) + // ----------------------------------------------------------- + // computing 2 * y1 + let denom_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(denom_raw, vec![SumTerm( + Some(FieldElement::from(2u64)), + y1, + )])); + // constraint: 1 * (2 * y1) = denom_raw + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::from(2u64), y1)], 
+ &[(FieldElement::ONE, denom_raw)], + ); + let denominator = reduce_mod(r1cs_compiler, denom_raw, p, range_checks); + + // Computing lambda = numerator * denominator^(-1) (mod p) + // ----------------------------------------------------------- + // computing denominator^(-1) mod p + let denom_inv = compute_field_inv(r1cs_compiler, denominator, p, range_checks); + // computing lambda = numerator * denom_inv mod p + let lambda = compute_field_mul(r1cs_compiler, numerator, denom_inv, p, range_checks); + + // Computing x3 = lambda^2 - 2 * x1 (mod p) + // ----------------------------------------------------------- + // computing lambda^2 mod p + let lambda_sq = compute_field_mul(r1cs_compiler, lambda, lambda, p, range_checks); + // computing lambda^2 - 2 * x1 + let x3_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(x3_raw, vec![ + SumTerm(Some(FieldElement::ONE), lambda_sq), + SumTerm(Some(-FieldElement::from(2u64)), x1), + ])); + // constraint: 1 * (lambda^2 - 2 * x1) = x3_raw + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[ + (FieldElement::ONE, lambda_sq), + (-FieldElement::from(2u64), x1), + ], + &[(FieldElement::ONE, x3_raw)], + ); + let x3 = reduce_mod(r1cs_compiler, x3_raw, p, range_checks); + + // Computing y3 = lambda * (x1 - x3) - y1 (mod p) + // ----------------------------------------------------------- + // computing x1 - x3 mod p + let x1_minus_x3 = compute_field_sub(r1cs_compiler, x1, x3, p, range_checks); + // computing lambda * (x1 - x3) mod p + let lambda_dx = compute_field_mul(r1cs_compiler, lambda, x1_minus_x3, p, range_checks); + // computing lambda * (x1 - x3) - y1 mod p + let y3 = compute_field_sub(r1cs_compiler, lambda_dx, y1, p, range_checks); + + (x3, y3) +} + +/// checks if value is zero or not +pub fn compute_is_zero(r1cs_compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { + // calculating v^(-1) + let value_inv = 
r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Inverse(value_inv, value)); + // calculating v * v^(-1) + let value_mul_value_inv = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product( + value_mul_value_inv, + value, + value_inv, + )); + // calculate is_zero = 1 - (v * v^(-1)) + let is_zero = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(provekit_common::witness::WitnessBuilder::Sum( + is_zero, + vec![ + provekit_common::witness::SumTerm(Some(FieldElement::ONE), r1cs_compiler.witness_one()), + provekit_common::witness::SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ], + )); + // constraint: v × v^(-1) = 1 - is_zero + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, value_inv)], + &[ + (FieldElement::ONE, r1cs_compiler.witness_one()), + (-FieldElement::ONE, is_zero), + ], + ); + // constraint: v × is_zero = 0 + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, is_zero)], + &[(FieldElement::ZERO, r1cs_compiler.witness_one())], + ); + is_zero +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs new file mode 100644 index 000000000..3844d1466 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -0,0 +1,2 @@ +pub mod curve; +pub mod ec_ops; diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 189eb4693..2d4245636 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -16,6 +16,7 @@ use { Circuit, Opcode, }, native_types::{Expression, Witness as NoirWitness}, + BlackBoxFunc, }, anyhow::{bail, Result}, ark_ff::PrimeField, @@ -627,6 +628,11 @@ impl NoirToR1CSCompiler { output_witnesses, )); } + BlackBoxFuncCall::MultiScalarMul { + points, + scalars, + outputs, + } => {} _ => { unimplemented!("Other black box function: {:?}", 
black_box_func_call); } From bf2f7a692c6d396f7cab238f7cce1d84d54c2fb6 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 28 Feb 2026 06:36:12 +0530 Subject: [PATCH 2/5] feat : added wide field ops for ec operations and added trait generics --- .../src/witness/scheduling/dependency.rs | 30 + .../common/src/witness/scheduling/remapper.rs | 54 ++ .../common/src/witness/witness_builder.rs | 59 ++ provekit/prover/src/witness/bigint_mod.rs | 437 +++++++++++++- .../prover/src/witness/witness_builder.rs | 194 ++++++ provekit/r1cs-compiler/src/msm/curve.rs | 165 +++-- provekit/r1cs-compiler/src/msm/ec_ops.rs | 165 ++--- provekit/r1cs-compiler/src/msm/ec_points.rs | 101 ++++ provekit/r1cs-compiler/src/msm/mod.rs | 141 +++++ provekit/r1cs-compiler/src/msm/wide_ops.rs | 563 ++++++++++++++++++ 10 files changed, 1714 insertions(+), 195 deletions(-) create mode 100644 provekit/r1cs-compiler/src/msm/ec_points.rs create mode 100644 provekit/r1cs-compiler/src/msm/wide_ops.rs diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index f2d480377..956f79b56 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -154,6 +154,28 @@ impl DependencyInfo { } v } + WitnessBuilder::MulModHint { + a_lo, + a_hi, + b_lo, + b_hi, + .. + } => vec![*a_lo, *a_hi, *b_lo, *b_hi], + WitnessBuilder::WideModularInverse { a_lo, a_hi, .. } => vec![*a_lo, *a_hi], + WitnessBuilder::WideAddQuotient { + a_lo, + a_hi, + b_lo, + b_hi, + .. + } => vec![*a_lo, *a_hi, *b_lo, *b_hi], + WitnessBuilder::WideSubBorrow { + a_lo, + a_hi, + b_lo, + b_hi, + .. + } => vec![*a_lo, *a_hi, *b_lo, *b_hi], WitnessBuilder::BytePartition { x, .. } => vec![*x], WitnessBuilder::U32AdditionMulti(_, _, inputs) => inputs @@ -286,6 +308,14 @@ impl DependencyInfo { let n = 1usize << *num_bits; (*start..*start + n).collect() } + WitnessBuilder::MulModHint { output_start, .. 
} => { + (*output_start..*output_start + 20).collect() + } + WitnessBuilder::WideModularInverse { output_start, .. } => { + (*output_start..*output_start + 2).collect() + } + WitnessBuilder::WideAddQuotient { output, .. } => vec![*output], + WitnessBuilder::WideSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) => { vec![*result_idx, *carry_idx] } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 76c30bd23..47490a6ce 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -221,6 +221,60 @@ impl WitnessIndexRemapper { .collect(), ) } + WitnessBuilder::MulModHint { + output_start, + a_lo, + a_hi, + b_lo, + b_hi, + modulus, + } => WitnessBuilder::MulModHint { + output_start: self.remap(*output_start), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + b_lo: self.remap(*b_lo), + b_hi: self.remap(*b_hi), + modulus: *modulus, + }, + WitnessBuilder::WideModularInverse { + output_start, + a_lo, + a_hi, + modulus, + } => WitnessBuilder::WideModularInverse { + output_start: self.remap(*output_start), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + modulus: *modulus, + }, + WitnessBuilder::WideAddQuotient { + output, + a_lo, + a_hi, + b_lo, + b_hi, + modulus, + } => WitnessBuilder::WideAddQuotient { + output: self.remap(*output), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + b_lo: self.remap(*b_lo), + b_hi: self.remap(*b_hi), + modulus: *modulus, + }, + WitnessBuilder::WideSubBorrow { + output, + a_lo, + a_hi, + b_lo, + b_hi, + } => WitnessBuilder::WideSubBorrow { + output: self.remap(*output), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + b_lo: self.remap(*b_lo), + b_hi: self.remap(*b_hi), + }, WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { lo: self.remap(*lo), hi: self.remap(*hi), diff --git 
a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 5719ee99e..6d11d17cf 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -201,6 +201,63 @@ pub enum WitnessBuilder { /// Inverse of combined lookup table entry denominator (constant operands). /// Computes: 1 / (sz - lhs - rs*rhs - rs²*and_out - rs³*xor_out) CombinedTableEntryInverse(CombinedTableEntryInverseData), + /// Prover hint for multi-limb modular multiplication: (a * b) mod p. + /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// and a constant 256-bit modulus p, computes quotient q, remainder r, + /// their 86-bit decompositions, and carry witnesses. + /// + /// Outputs 20 witnesses starting at output_start: + /// [0..2) q_lo, q_hi (quotient in 128-bit limbs) + /// [2..4) r_lo, r_hi (remainder in 128-bit limbs) + /// [4..7) a_86_0, a_86_1, a_86_2 (a in 86-bit limbs) + /// [7..10) b_86_0, b_86_1, b_86_2 (b in 86-bit limbs) + /// [10..13) q_86_0, q_86_1, q_86_2 (q in 86-bit limbs) + /// [13..16) r_86_0, r_86_1, r_86_2 (r in 86-bit limbs) + /// [16..20) c0, c1, c2, c3 (carry witnesses, unsigned-offset) + MulModHint { + output_start: usize, + a_lo: usize, + a_hi: usize, + b_lo: usize, + b_hi: usize, + modulus: [u64; 4], + }, + /// Prover hint for wide modular inverse: a^{-1} mod p. + /// Given input a = (a_lo, a_hi) as 128-bit limbs and constant modulus p, + /// computes the inverse via Fermat's little theorem (a^{p-2} mod p). + /// + /// Outputs 2 witnesses at output_start: inv_lo, inv_hi (128-bit limbs). + WideModularInverse { + output_start: usize, + a_lo: usize, + a_hi: usize, + modulus: [u64; 4], + }, + /// Prover hint for wide addition quotient: q = floor((a + b) / p). + /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// and a constant 256-bit modulus p, computes q ∈ {0, 1}. + /// + /// Outputs 1 witness at output: q. 
+ WideAddQuotient { + output: usize, + a_lo: usize, + a_hi: usize, + b_lo: usize, + b_hi: usize, + modulus: [u64; 4], + }, + /// Prover hint for wide subtraction borrow: q = (a < b) ? 1 : 0. + /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// computes q ∈ {0, 1} indicating whether a borrow (adding p) is needed. + /// + /// Outputs 1 witness at output: q. + WideSubBorrow { + output: usize, + a_lo: usize, + a_hi: usize, + b_lo: usize, + b_hi: usize, + }, /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: /// packed = c0 + c1 * 2^b0 + c2 * 2^(b0+b1) + ... @@ -272,6 +329,8 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, + WitnessBuilder::MulModHint { .. } => 20, + WitnessBuilder::WideModularInverse { .. } => 2, _ => 1, } diff --git a/provekit/prover/src/witness/bigint_mod.rs b/provekit/prover/src/witness/bigint_mod.rs index 3252aeea2..a41f47ff3 100644 --- a/provekit/prover/src/witness/bigint_mod.rs +++ b/provekit/prover/src/witness/bigint_mod.rs @@ -6,7 +6,7 @@ /// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → /// 512-bit). -fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { +pub fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { let mut result = [0u64; 8]; for i in 0..4 { let mut carry = 0u128; @@ -90,7 +90,7 @@ fn shift_left_one(a: &mut [u64; 4]) -> u64 { } /// Compare two 4-limb numbers. 
-fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering { +pub fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering { for i in (0..4).rev() { match a[i].cmp(&b[i]) { std::cmp::Ordering::Equal => continue, @@ -203,6 +203,235 @@ pub fn sub_u64(a: &[u64; 4], small: u64) -> [u64; 4] { result } +/// Add two 4-limb (256-bit) numbers, returning a 5-limb result with carry. +pub fn add_4limb(a: &[u64; 4], b: &[u64; 4]) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut carry = 0u64; + for i in 0..4 { + let (s1, c1) = a[i].overflowing_add(b[i]); + let (s2, c2) = s1.overflowing_add(carry); + result[i] = s2; + carry = (c1 as u64) + (c2 as u64); + } + result[4] = carry; + result +} + +/// Offset added to signed carries to make them non-negative for range checking. +/// Carries are bounded by |c| < 2^88, so adding 2^88 ensures c_unsigned >= 0. +pub const CARRY_OFFSET: u128 = 1u128 << 88; + +/// Integer division of a 512-bit dividend by a 256-bit divisor. +/// Returns (quotient, remainder) where both fit in 256 bits. +/// Panics if the quotient would exceed 256 bits. +pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + let mut highest_bit = 0; + for i in (0..8).rev() { + if dividend[i] != 0 { + highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return ([0u64; 4], [0u64; 4]); + } + + let mut quotient = [0u64; 4]; + let mut remainder = [0u64; 4]; + + for bit_pos in (0..highest_bit).rev() { + let shift_carry = shift_left_one(&mut remainder); + + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1; + + // If shift_carry is set, the effective remainder is 2^256 + remainder, + // which is always > any 256-bit divisor, so we must subtract. 
+ if shift_carry != 0 || cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less { + // Subtract divisor with inline borrow tracking (handles the case + // where remainder < divisor but shift_carry provides the extra bit). + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = remainder[i].overflowing_sub(divisor[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + remainder[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + // When shift_carry was set, the borrow absorbs it (they cancel out). + debug_assert_eq!( + borrow, shift_carry, + "unexpected borrow in divmod_wide at bit_pos {}", + bit_pos + ); + + assert!(bit_pos < 256, "quotient exceeds 256 bits"); + quotient[bit_pos / 64] |= 1u64 << (bit_pos % 64); + } + } + + (quotient, remainder) +} + +/// Split a 256-bit value into two 128-bit halves: (lo, hi). +pub fn decompose_128(val: &[u64; 4]) -> (u128, u128) { + let lo = val[0] as u128 | ((val[1] as u128) << 64); + let hi = val[2] as u128 | ((val[3] as u128) << 64); + (lo, hi) +} + +/// Split a 256-bit value into three 86-bit limbs: (l0, l1, l2). +/// l0 = bits [0..86), l1 = bits [86..172), l2 = bits [172..256). +pub fn decompose_86(val: &[u64; 4]) -> (u128, u128, u128) { + let mask_86: u128 = (1u128 << 86) - 1; + let lo128 = val[0] as u128 | ((val[1] as u128) << 64); + let hi128 = val[2] as u128 | ((val[3] as u128) << 64); + + let l0 = lo128 & mask_86; + // l1 spans bits [86..172): 42 bits from lo128, 44 bits from hi128 + let l1 = ((lo128 >> 86) | (hi128 << 42)) & mask_86; + // l2 = bits [172..256): 84 bits from hi128 + let l2 = hi128 >> 44; + + (l0, l1, l2) +} + +/// Compute carry values c0..c3 from the 86-bit schoolbook column equations +/// for the identity a*b = p*q + r (base W = 2^86). 
+/// +/// Column equations: +/// col0: a0*b0 - p0*q0 - r0 = c0*W +/// col1: a0*b1 + a1*b0 - p0*q1 - p1*q0 - r1 + c0 = c1*W +/// col2: a0*b2 + a1*b1 + a2*b0 - p0*q2 - p1*q1 - p2*q0 - r2 + c1 = c2*W +/// col3: a1*b2 + a2*b1 - p1*q2 - p2*q1 + c2 = c3*W +/// col4: a2*b2 - p2*q2 + c3 = 0 +pub fn compute_carries_86( + a: [u128; 3], + b: [u128; 3], + p: [u128; 3], + q: [u128; 3], + r: [u128; 3], +) -> [i128; 4] { + // Helper: convert u128 to [u64; 4] + fn to4(v: u128) -> [u64; 4] { + [v as u64, (v >> 64) as u64, 0, 0] + } + + // Helper: multiply two 86-bit values → [u64; 4] (result < 2^172) + fn mul86(x: u128, y: u128) -> [u64; 4] { + let w = widening_mul(&to4(x), &to4(y)); + [w[0], w[1], w[2], w[3]] + } + + // Helper: add two [u64; 4] values + fn add4(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { + let mut r = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + let s = a[i] as u128 + b[i] as u128 + carry; + r[i] = s as u64; + carry = s >> 64; + } + r + } + + // Helper: subtract two [u64; 4] values (assumes a >= b) + fn sub4(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { + let mut r = [0u64; 4]; + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + r[i] = d2; + borrow = b1 as u64 + b2 as u64; + } + r + } + + // Helper: right-shift [u64; 4] by 86 bits (= 64 + 22) + fn shr86(a: [u64; 4]) -> [u64; 4] { + let s = [a[1], a[2], a[3], 0u64]; + [ + (s[0] >> 22) | (s[1] << 42), + (s[1] >> 22) | (s[2] << 42), + s[2] >> 22, + 0, + ] + } + + // Positive column sums (a_i * b_j terms) + let pos = [ + mul86(a[0], b[0]), + add4(mul86(a[0], b[1]), mul86(a[1], b[0])), + add4( + add4(mul86(a[0], b[2]), mul86(a[1], b[1])), + mul86(a[2], b[0]), + ), + add4(mul86(a[1], b[2]), mul86(a[2], b[1])), + mul86(a[2], b[2]), + ]; + + // Negative column sums (p_i * q_j + r_i terms) + let neg = [ + add4(mul86(p[0], q[0]), to4(r[0])), + add4(add4(mul86(p[0], q[1]), mul86(p[1], q[0])), to4(r[1])), + add4( + add4( + add4(mul86(p[0], 
q[2]), mul86(p[1], q[1])), + mul86(p[2], q[0]), + ), + to4(r[2]), + ), + add4(mul86(p[1], q[2]), mul86(p[2], q[1])), + mul86(p[2], q[2]), + ]; + + let mut carries = [0i128; 4]; + let mut carry_pos = [0u64; 4]; + let mut carry_neg = [0u64; 4]; + + for col in 0..4 { + let total_pos = add4(pos[col], carry_pos); + let total_neg = add4(neg[col], carry_neg); + + let (is_neg, diff) = if cmp_4limb(&total_pos, &total_neg) != std::cmp::Ordering::Less { + (false, sub4(total_pos, total_neg)) + } else { + (true, sub4(total_neg, total_pos)) + }; + + // Lower 86 bits must be zero (divisibility check) + let mask_86 = (1u128 << 86) - 1; + let low86 = (diff[0] as u128 | ((diff[1] as u128) << 64)) & mask_86; + debug_assert_eq!(low86, 0, "column {} not divisible by W=2^86", col); + + let carry_mag = shr86(diff); + debug_assert_eq!(carry_mag[2], 0, "carry overflow in column {}", col); + debug_assert_eq!(carry_mag[3], 0, "carry overflow in column {}", col); + + let carry_val = carry_mag[0] as i128 | ((carry_mag[1] as i128) << 64); + carries[col] = if is_neg { -carry_val } else { carry_val }; + + if is_neg { + carry_pos = [0; 4]; + carry_neg = carry_mag; + } else { + carry_pos = carry_mag; + carry_neg = [0; 4]; + } + } + + // Verify column 4 balances + let final_pos = add4(pos[4], carry_pos); + let final_neg = add4(neg[4], carry_neg); + debug_assert_eq!( + final_pos, final_neg, + "column 4 should balance: a2*b2 - p2*q2 + c3 = 0" + ); + + carries +} + #[cfg(test)] mod tests { use super::*; @@ -384,8 +613,8 @@ mod tests { assert_eq!(q, [6148914691236517205, 0, 0, 0]); assert_eq!(r, [1, 0, 0, 0]); // Verify: q * 3 + 1 = 2^64 - assert_eq!(6148914691236517205u64 * 3 + 1, 0u64); // wraps to 0 in u64 = - // 2^64 + assert_eq!(6148914691236517205u64.wrapping_mul(3).wrapping_add(1), 0u64); + // wraps to 0 in u64 = 2^64 } #[test] @@ -419,4 +648,204 @@ mod tests { assert_eq!(sum, expected, "dividend={dividend:?} divisor={divisor:?}"); } } + + #[test] + fn test_divmod_wide_small() { + // 21 / 7 = 3 
remainder 0 (512-bit dividend) + let dividend = [21, 0, 0, 0, 0, 0, 0, 0]; + let divisor = [7, 0, 0, 0]; + let (q, r) = divmod_wide(÷nd, &divisor); + assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_large() { + // Compute a * b where a, b are 256-bit, then divide by a + // Should give quotient = b, remainder = 0 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; // secp256r1 p + let b = [42, 0, 0, 0]; + let product = widening_mul(&a, &b); + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_with_remainder() { + // (a * b + 5) / a = b remainder 5 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let b = [100, 0, 0, 0]; + let mut product = widening_mul(&a, &b); + // Add 5 + let (sum, overflow) = product[0].overflowing_add(5); + product[0] = sum; + if overflow { + for i in 1..8 { + let (s, o) = product[i].overflowing_add(1); + product[i] = s; + if !o { + break; + } + } + } + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_consistency() { + // Verify: q * divisor + r = dividend + let a = [ + 0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let b = [0xaabbccdd, 0x11223344, 0x55667788, 0x99001122]; + let product = widening_mul(&a, &b); + let divisor = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let (q, r) = divmod_wide(&product, &divisor); + + // Verify: q * divisor + r = product + let qd = widening_mul(&q, &divisor); + let mut sum = qd; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + assert_eq!(sum, product); + } + + #[test] + fn test_decompose_128_roundtrip() { + let val = [ + 
0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let (lo, hi) = decompose_128(&val); + // Roundtrip + assert_eq!(lo as u64, val[0]); + assert_eq!((lo >> 64) as u64, val[1]); + assert_eq!(hi as u64, val[2]); + assert_eq!((hi >> 64) as u64, val[3]); + } + + #[test] + fn test_decompose_86_roundtrip() { + let val = [ + 0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let (l0, l1, l2) = decompose_86(&val); + + // Each limb should be < 2^86 + assert!(l0 < (1u128 << 86)); + assert!(l1 < (1u128 << 86)); + // l2 has at most 84 bits (256 - 172) + assert!(l2 < (1u128 << 84)); + + // Roundtrip: l0 + l1 * 2^86 + l2 * 2^172 should equal val + // Build from limbs back to [u64; 4] + let mut reconstructed = [0u128; 2]; // lo128, hi128 + reconstructed[0] = l0; + // l1 starts at bit 86 + reconstructed[0] |= (l1 & ((1u128 << 42) - 1)) << 86; // lower 42 bits of l1 into lo128 + reconstructed[1] = l1 >> 42; // upper 44 bits of l1 + // l2 starts at bit 172 = 128 + 44 + reconstructed[1] |= l2 << 44; + + assert_eq!(reconstructed[0] as u64, val[0]); + assert_eq!((reconstructed[0] >> 64) as u64, val[1]); + assert_eq!(reconstructed[1] as u64, val[2]); + assert_eq!((reconstructed[1] >> 64) as u64, val[3]); + } + + #[test] + fn test_decompose_86_secp256r1_p() { + // secp256r1 field modulus + let p = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let (l0, l1, l2) = decompose_86(&p); + assert!(l0 < (1u128 << 86)); + assert!(l1 < (1u128 << 86)); + assert!(l2 < (1u128 << 84)); + } + + #[test] + fn test_compute_carries_86_simple() { + // Test with small values: a=3, b=5, p=7 + // a*b = 15, 15 / 7 = 2 remainder 1 + // So q=2, r=1 + let a_val = [3u64, 0, 0, 0]; + let b_val = [5, 0, 0, 0]; + let p_val = [7, 0, 0, 0]; + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, &p_val); + assert_eq!(q_val, [2, 0, 0, 0]); + assert_eq!(r_val, [1, 0, 0, 0]); + + let (a0, a1, 
a2) = decompose_86(&a_val); + let (b0, b1, b2) = decompose_86(&b_val); + let (p0, p1, p2) = decompose_86(&p_val); + let (q0, q1, q2) = decompose_86(&q_val); + let (r0, r1, r2) = decompose_86(&r_val); + + let carries = compute_carries_86([a0, a1, a2], [b0, b1, b2], [p0, p1, p2], [q0, q1, q2], [ + r0, r1, r2, + ]); + // For small values, all carries should be 0 + assert_eq!(carries, [0, 0, 0, 0]); + } + + #[test] + fn test_compute_carries_86_secp256r1() { + // Test with secp256r1-sized values + let p = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let a_val = [0x123456789abcdef0, 0xfedcba9876543210, 0x0, 0x0]; // < p + let b_val = [0xaabbccddeeff0011, 0x1122334455667788, 0x0, 0x0]; // < p + + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, &p); + + // Verify a*b = p*q + r + let pq = widening_mul(&p, &q_val); + let mut sum = pq; + let mut carry = 0u128; + for i in 0..4 { + let s = sum[i] as u128 + r_val[i] as u128 + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = sum[i] as u128 + carry; + sum[i] = s as u64; + carry = s >> 64; + } + assert_eq!(sum, product); + + // Compute 86-bit decompositions + let (a0, a1, a2) = decompose_86(&a_val); + let (b0, b1, b2) = decompose_86(&b_val); + let (p0, p1, p2) = decompose_86(&p); + let (q0, q1, q2) = decompose_86(&q_val); + let (r0, r1, r2) = decompose_86(&r_val); + + // This should not panic + let _carries = + compute_carries_86([a0, a1, a2], [b0, b1, b2], [p0, p1, p2], [q0, q1, q2], [ + r0, r1, r2, + ]); + } } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 115637da6..ae49cfcd5 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -337,6 +337,200 @@ impl WitnessBuilderSolver for WitnessBuilder { lh_val.into_bigint() ^ rh_val.into_bigint(), )); } + WitnessBuilder::MulModHint { + output_start, + a_lo, + a_hi, + b_lo, + 
b_hi, + modulus, + } => { + use crate::witness::bigint_mod::{ + compute_carries_86, decompose_128, decompose_86, divmod_wide, widening_mul, + CARRY_OFFSET, + }; + + // Read inputs: a and b as 128-bit limb pairs + let a_lo_fe = witness[*a_lo].unwrap(); + let a_hi_fe = witness[*a_hi].unwrap(); + let b_lo_fe = witness[*b_lo].unwrap(); + let b_hi_fe = witness[*b_hi].unwrap(); + + // Reconstruct a, b as [u64; 4] + let a_lo_limbs = a_lo_fe.into_bigint().0; + let a_hi_limbs = a_hi_fe.into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + let b_lo_limbs = b_lo_fe.into_bigint().0; + let b_hi_limbs = b_hi_fe.into_bigint().0; + let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + + // Compute product and divmod + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, modulus); + + // Decompose into 128-bit limbs + let (q_lo, q_hi) = decompose_128(&q_val); + let (r_lo, r_hi) = decompose_128(&r_val); + + // Decompose into 86-bit limbs + let (a86_0, a86_1, a86_2) = decompose_86(&a_val); + let (b86_0, b86_1, b86_2) = decompose_86(&b_val); + let (q86_0, q86_1, q86_2) = decompose_86(&q_val); + let (r86_0, r86_1, r86_2) = decompose_86(&r_val); + + // Compute carries + let carries = compute_carries_86( + [a86_0, a86_1, a86_2], + [b86_0, b86_1, b86_2], + { + let (p0, p1, p2) = decompose_86(modulus); + [p0, p1, p2] + }, + [q86_0, q86_1, q86_2], + [r86_0, r86_1, r86_2], + ); + + // Helper: convert u128 to FieldElement + let u128_to_fe = |val: u128| -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([ + val as u64, + (val >> 64) as u64, + 0, + 0, + ])) + .unwrap() + }; + + // Write outputs: [0..2) q_lo, q_hi + witness[*output_start] = Some(u128_to_fe(q_lo)); + witness[*output_start + 1] = Some(u128_to_fe(q_hi)); + // [2..4) r_lo, r_hi + witness[*output_start + 2] = Some(u128_to_fe(r_lo)); + witness[*output_start + 3] = Some(u128_to_fe(r_hi)); + // [4..7) a_86 limbs + 
witness[*output_start + 4] = Some(u128_to_fe(a86_0)); + witness[*output_start + 5] = Some(u128_to_fe(a86_1)); + witness[*output_start + 6] = Some(u128_to_fe(a86_2)); + // [7..10) b_86 limbs + witness[*output_start + 7] = Some(u128_to_fe(b86_0)); + witness[*output_start + 8] = Some(u128_to_fe(b86_1)); + witness[*output_start + 9] = Some(u128_to_fe(b86_2)); + // [10..13) q_86 limbs + witness[*output_start + 10] = Some(u128_to_fe(q86_0)); + witness[*output_start + 11] = Some(u128_to_fe(q86_1)); + witness[*output_start + 12] = Some(u128_to_fe(q86_2)); + // [13..16) r_86 limbs + witness[*output_start + 13] = Some(u128_to_fe(r86_0)); + witness[*output_start + 14] = Some(u128_to_fe(r86_1)); + witness[*output_start + 15] = Some(u128_to_fe(r86_2)); + // [16..20) carries (unsigned-offset) + for i in 0..4 { + let c_unsigned = (carries[i] + CARRY_OFFSET as i128) as u128; + witness[*output_start + 16 + i] = Some(u128_to_fe(c_unsigned)); + } + } + WitnessBuilder::WideModularInverse { + output_start, + a_lo, + a_hi, + modulus, + } => { + use crate::witness::bigint_mod::{decompose_128, mod_pow, sub_u64}; + + // Read input a as 128-bit limb pair + let a_lo_fe = witness[*a_lo].unwrap(); + let a_hi_fe = witness[*a_hi].unwrap(); + + let a_lo_limbs = a_lo_fe.into_bigint().0; + let a_hi_limbs = a_hi_fe.into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + // Compute inverse: a^{p-2} mod p (Fermat's little theorem) + let exp = sub_u64(modulus, 2); + let inv = mod_pow(&a_val, &exp, modulus); + + // Decompose into 128-bit limbs + let (inv_lo, inv_hi) = decompose_128(&inv); + + let u128_to_fe = |val: u128| -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([ + val as u64, + (val >> 64) as u64, + 0, + 0, + ])) + .unwrap() + }; + + witness[*output_start] = Some(u128_to_fe(inv_lo)); + witness[*output_start + 1] = Some(u128_to_fe(inv_hi)); + } + WitnessBuilder::WideAddQuotient { + output, + a_lo, + a_hi, + b_lo, + b_hi, + modulus, + } => { + 
use crate::witness::bigint_mod::{add_4limb, cmp_4limb}; + + let a_lo_fe = witness[*a_lo].unwrap(); + let a_hi_fe = witness[*a_hi].unwrap(); + let b_lo_fe = witness[*b_lo].unwrap(); + let b_hi_fe = witness[*b_hi].unwrap(); + + let a_lo_limbs = a_lo_fe.into_bigint().0; + let a_hi_limbs = a_hi_fe.into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + let b_lo_limbs = b_lo_fe.into_bigint().0; + let b_hi_limbs = b_hi_fe.into_bigint().0; + let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + + let sum = add_4limb(&a_val, &b_val); + // q = 1 if sum >= p, else 0 + let q = if sum[4] > 0 { + // sum > 2^256 > any 256-bit modulus + 1u64 + } else { + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if cmp_4limb(&sum4, modulus) != std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + } + }; + + witness[*output] = Some(FieldElement::from(q)); + } + WitnessBuilder::WideSubBorrow { + output, + a_lo, + a_hi, + b_lo, + b_hi, + } => { + use crate::witness::bigint_mod::cmp_4limb; + + let a_lo_limbs = witness[*a_lo].unwrap().into_bigint().0; + let a_hi_limbs = witness[*a_hi].unwrap().into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + let b_lo_limbs = witness[*b_lo].unwrap().into_bigint().0; + let b_hi_limbs = witness[*b_hi].unwrap().into_bigint().0; + let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + + // q = 1 if a < b (need to add p to make result non-negative) + let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + }; + + witness[*output] = Some(FieldElement::from(q)); + } WitnessBuilder::BytePartition { lo, hi, x, k } => { let x_val = witness[*x].unwrap().into_bigint().0[0]; debug_assert!(x_val < 256, "BytePartition input must be 8-bit"); diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index cbfd1bf25..d4d0d247b 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs 
+++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,71 +1,116 @@ -use provekit_common::FieldElement; +use { + crate::noir_to_r1cs::NoirToR1CSCompiler, + provekit_common::{ + witness::{ConstantTerm, WitnessBuilder}, + FieldElement, + }, +}; -// TODO : remove Option<> form both the params if comes in use -// otherwise we delete the params from struct pub struct CurveParams { - pub field_modulus_p: FieldElement, - pub curve_order_n: FieldElement, - pub curve_a: FieldElement, - pub curve_b: FieldElement, - pub generator: (FieldElement, FieldElement), - pub coordinate_bits: Option, + pub field_modulus_p: [u64; 4], + pub curve_order_n: [u64; 4], + pub curve_a: [u64; 4], + pub curve_b: [u64; 4], + pub generator: ([u64; 4], [u64; 4]), +} + +impl CurveParams { + pub fn p_lo_fe(&self) -> FieldElement { + decompose_128(self.field_modulus_p).0 + } + pub fn p_hi_fe(&self) -> FieldElement { + decompose_128(self.field_modulus_p).1 + } + pub fn p_86_limbs(&self) -> [FieldElement; 3] { + let mask_86: u128 = (1u128 << 86) - 1; + let lo128 = self.field_modulus_p[0] as u128 | ((self.field_modulus_p[1] as u128) << 64); + let hi128 = self.field_modulus_p[2] as u128 | ((self.field_modulus_p[3] as u128) << 64); + let l0 = lo128 & mask_86; + // l1 spans bits [86..172): 42 bits from lo128, 44 bits from hi128 + let l1 = ((lo128 >> 86) | (hi128 << 42)) & mask_86; + // l2 = bits [172..256): 84 bits from hi128 + let l2 = hi128 >> 44; + [ + FieldElement::from(l0), + FieldElement::from(l1), + FieldElement::from(l2), + ] + } + pub fn p_native_fe(&self) -> FieldElement { + curve_native_point_fe(&self.field_modulus_p) + } +} + +/// Splits a 256-bit value ([u64; 4]) into two 128-bit field elements (lo, hi). +fn decompose_128(val: [u64; 4]) -> (FieldElement, FieldElement) { + ( + FieldElement::from((val[0] as u128) | ((val[1] as u128) << 64)), + FieldElement::from((val[2] as u128) | ((val[3] as u128) << 64)), + ) +} + +/// Converts a 256-bit value ([u64; 4]) into a single native field element. 
+pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { + FieldElement::from_sign_and_limbs(true, val) +} + +#[derive(Clone, Copy, Debug)] +pub struct Limb2 { + pub lo: usize, + pub hi: usize, +} + +pub fn limb2_constant(r1cs_compiler: &mut NoirToR1CSCompiler, value: [u64; 4]) -> Limb2 { + let (lo, hi) = decompose_128(value); + let lo_idx = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(lo_idx, lo))); + let hi_idx = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(hi_idx, hi))); + Limb2 { + lo: lo_idx, + hi: hi_idx, + } } pub fn secp256r1_params() -> CurveParams { CurveParams { - field_modulus_p: FieldElement::from_sign_and_limbs( - true, - [ - 0xffffffffffffffff_u64, - 0xffffffff_u64, - 0x0_u64, - 0xffffffff00000001_u64, - ] - .as_slice(), - ), - curve_order_n: FieldElement::from_sign_and_limbs( - true, + field_modulus_p: [ + 0xffffffffffffffff_u64, + 0xffffffff_u64, + 0x0_u64, + 0xffffffff00000001_u64, + ], + curve_order_n: [ + 0xf3b9cac2fc632551_u64, + 0xbce6faada7179e84_u64, + 0xffffffffffffffff_u64, + 0xffffffff00000000_u64, + ], + curve_a: [ + 0xfffffffffffffffc_u64, + 0x00000000ffffffff_u64, + 0x0000000000000000_u64, + 0xffffffff00000001_u64, + ], + curve_b: [ + 0x3bce3c3e27d2604b_u64, + 0x651d06b0cc53b0f6_u64, + 0xb3ebbd55769886bc_u64, + 0x5ac635d8aa3a93e7_u64, + ], + generator: ( [ - 0xf3b9cac2fc632551_u64, - 0xbce6faada7179e84_u64, - 0xffffffffffffffff_u64, - 0xffffffff00000000_u64, - ] - .as_slice(), - ), - curve_a: FieldElement::from(-3), - curve_b: FieldElement::from_sign_and_limbs( - true, + 0xf4a13945d898c296_u64, + 0x77037d812deb33a0_u64, + 0xf8bce6e563a440f2_u64, + 0x6b17d1f2e12c4247_u64, + ], [ - 0x3bce3c3e27d2604b_u64, - 0x651d06b0cc53b0f6_u64, - 0xb3ebbd55769886bc_u64, - 0x5ac635d8aa3a93e7_u64, - ] - .as_slice(), - ), - generator: ( - FieldElement::from_sign_and_limbs( - true, - [ - 0xf4a13945d898c296_u64, - 
0x77037d812deb33a0_u64, - 0xf8bce6e563a440f2_u64, - 0x6b17d1f2e12c4247_u64, - ] - .as_slice(), - ), - FieldElement::from_sign_and_limbs( - true, - [ - 0xcbb6406837bf51f5_u64, - 0x2bce33576b315ece_u64, - 0x8ee7eb4a7c0f9e16_u64, - 0x4fe342e2fe1a7f9b_u64, - ] - .as_slice(), - ), + 0xcbb6406837bf51f5_u64, + 0x2bce33576b315ece_u64, + 0x8ee7eb4a7c0f9e16_u64, + 0x4fe342e2fe1a7f9b_u64, + ], ), - coordinate_bits: None, } } diff --git a/provekit/r1cs-compiler/src/msm/ec_ops.rs b/provekit/r1cs-compiler/src/msm/ec_ops.rs index 1bd96e69f..985937821 100644 --- a/provekit/r1cs-compiler/src/msm/ec_ops.rs +++ b/provekit/r1cs-compiler/src/msm/ec_ops.rs @@ -1,5 +1,5 @@ use { - crate::{msm::curve::CurveParams, noir_to_r1cs::NoirToR1CSCompiler}, + crate::noir_to_r1cs::NoirToR1CSCompiler, ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, provekit_common::{ witness::{SumTerm, WitnessBuilder}, @@ -9,7 +9,7 @@ use { }; /// Reduce the value to given modulus -pub fn reduce_mod( +pub fn reduce_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, value: usize, modulus: FieldElement, @@ -65,8 +65,30 @@ pub fn reduce_mod( result } -/// a * b mod m -pub fn compute_field_mul( +/// a + b mod p +pub fn add_mod_p( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_add_b = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(FieldElement::ONE), b), + ])); + // constraint: a + b = a_add_b + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, a), (FieldElement::ONE, b)], + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::ONE, a_add_b)], + ); + reduce_mod_p(r1cs_compiler, a_add_b, modulus, range_checks) +} + +/// a * b mod p +pub fn mul_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, a: usize, b: usize, @@ -82,11 +104,11 @@ pub fn compute_field_mul( FieldElement::ONE, a_mul_b, )]); 
- reduce_mod(r1cs_compiler, a_mul_b, modulus, range_checks) + reduce_mod_p(r1cs_compiler, a_mul_b, modulus, range_checks) } -/// (a - b) mod m -pub fn compute_field_sub( +/// (a - b) mod p +pub fn sub_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, a: usize, b: usize, @@ -104,16 +126,11 @@ pub fn compute_field_sub( &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], &[(FieldElement::ONE, a_sub_b)], ); - reduce_mod(r1cs_compiler, a_sub_b, modulus, range_checks) + reduce_mod_p(r1cs_compiler, a_sub_b, modulus, range_checks) } -/// a^(-1) mod m -/// -/// CRITICAL: secp256r1's field_modulus_p (~2^256) > BN254 scalar field -/// (~2^254). Coordinates and the modulus do not fit in a single -/// FieldElement. Either use multi-limb representation or target a -/// curve that fits (e.g. Grumpkin, BabyJubJub). -pub fn compute_field_inv( +/// a^(-1) mod p +pub fn inv_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, a: usize, modulus: FieldElement, @@ -127,17 +144,8 @@ pub fn compute_field_inv( // Verifying a * a_inv mod m = 1 // ----------------------------------------------------------- - // computing a * a_inv - let product_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product(product_raw, a, a_inv)); - // constraint: a * a_inv = product_raw - r1cs_compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, a_inv)], &[ - (FieldElement::ONE, product_raw), - ]); - // reducing a * a_inv mod m — should give 1 if a_inv is correct - let reduced = reduce_mod(r1cs_compiler, product_raw, modulus, range_checks); + // computing a * a_inv mod m + let reduced = mul_mod_p(r1cs_compiler, a, a_inv, modulus, range_checks); // constraint: reduced = 1 // (reduced - 1) * 1 = 0 @@ -160,111 +168,6 @@ pub fn compute_field_inv( a_inv } -/// Point doubling on y^2 = x^3 + ax + b (mod p) using affine lambda formula. 
-/// -/// Given P = (x1, y1), computes 2P = (x3, y3): -/// lambda = (3 * x1^2 + a) / (2 * y1) (mod p) -/// x3 = lambda^2 - 2 * x1 (mod p) -/// y3 = lambda * (x1 - x3) - y1 (mod p) -/// -/// Edge case — y1 = 0 (point of order 2): -/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. -/// The result should be the point at infinity (identity element). -/// This function does NOT handle that case — the constraint system will -/// be unsatisfiable if y1 = 0 (compute_field_inv will fail to verify -/// 0 * inv = 1 mod p). The caller must check y1 = 0 using -/// compute_is_zero and conditionally select the point-at-infinity -/// result before calling this function. -pub fn point_double( - r1cs_compiler: &mut NoirToR1CSCompiler, - x1: usize, - y1: usize, - curve_params: &CurveParams, - range_checks: &mut BTreeMap>, -) -> (usize, usize) { - let p = curve_params.field_modulus_p; - - // Computing numerator = 3 * x1^2 + a (mod p) - // ----------------------------------------------------------- - // computing x1^2 mod p - let x1_sq = compute_field_mul(r1cs_compiler, x1, x1, p, range_checks); - // computing 3 * x1_sq + a - let a_witness = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( - provekit_common::witness::ConstantTerm(a_witness, curve_params.curve_a), - )); - let num_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(num_raw, vec![ - SumTerm(Some(FieldElement::from(3u64)), x1_sq), - SumTerm(Some(FieldElement::ONE), a_witness), - ])); - // constraint: 1 * (3 * x1_sq + a) = num_raw - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[ - (FieldElement::from(3u64), x1_sq), - (FieldElement::ONE, a_witness), - ], - &[(FieldElement::ONE, num_raw)], - ); - let numerator = reduce_mod(r1cs_compiler, num_raw, p, range_checks); - - // Computing denominator = 2 * y1 (mod p) - // 
----------------------------------------------------------- - // computing 2 * y1 - let denom_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(denom_raw, vec![SumTerm( - Some(FieldElement::from(2u64)), - y1, - )])); - // constraint: 1 * (2 * y1) = denom_raw - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::from(2u64), y1)], - &[(FieldElement::ONE, denom_raw)], - ); - let denominator = reduce_mod(r1cs_compiler, denom_raw, p, range_checks); - - // Computing lambda = numerator * denominator^(-1) (mod p) - // ----------------------------------------------------------- - // computing denominator^(-1) mod p - let denom_inv = compute_field_inv(r1cs_compiler, denominator, p, range_checks); - // computing lambda = numerator * denom_inv mod p - let lambda = compute_field_mul(r1cs_compiler, numerator, denom_inv, p, range_checks); - - // Computing x3 = lambda^2 - 2 * x1 (mod p) - // ----------------------------------------------------------- - // computing lambda^2 mod p - let lambda_sq = compute_field_mul(r1cs_compiler, lambda, lambda, p, range_checks); - // computing lambda^2 - 2 * x1 - let x3_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(x3_raw, vec![ - SumTerm(Some(FieldElement::ONE), lambda_sq), - SumTerm(Some(-FieldElement::from(2u64)), x1), - ])); - // constraint: 1 * (lambda^2 - 2 * x1) = x3_raw - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[ - (FieldElement::ONE, lambda_sq), - (-FieldElement::from(2u64), x1), - ], - &[(FieldElement::ONE, x3_raw)], - ); - let x3 = reduce_mod(r1cs_compiler, x3_raw, p, range_checks); - - // Computing y3 = lambda * (x1 - x3) - y1 (mod p) - // ----------------------------------------------------------- - // computing x1 - x3 mod p - let x1_minus_x3 = compute_field_sub(r1cs_compiler, x1, x3, p, range_checks); - // computing lambda * (x1 - x3) 
mod p - let lambda_dx = compute_field_mul(r1cs_compiler, lambda, x1_minus_x3, p, range_checks); - // computing lambda * (x1 - x3) - y1 mod p - let y3 = compute_field_sub(r1cs_compiler, lambda_dx, y1, p, range_checks); - - (x3, y3) -} - /// checks if value is zero or not pub fn compute_is_zero(r1cs_compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { // calculating v^(-1) diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs new file mode 100644 index 000000000..d607d25ff --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -0,0 +1,101 @@ +use super::FieldOps; + +/// Generic point doubling on y^2 = x^3 + ax + b. +/// +/// Given P = (x1, y1), computes 2P = (x3, y3): +/// lambda = (3 * x1^2 + a) / (2 * y1) +/// x3 = lambda^2 - 2 * x1 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge case — y1 = 0 (point of order 2): +/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. +/// The result should be the point at infinity (identity element). +/// This function does NOT handle that case — the constraint system will +/// be unsatisfiable if y1 = 0 (the inverse verification will fail to +/// verify 0 * inv = 1 mod p). The caller must check y1 = 0 using +/// compute_is_zero and conditionally select the point-at-infinity +/// result before calling this function. 
+pub fn point_double(ops: &mut F, x1: F::Elem, y1: F::Elem) -> (F::Elem, F::Elem) { + let a = ops.curve_a(); + + // Computing numerator = 3 * x1^2 + a + let x1_sq = ops.mul(x1, x1); + let two_x1_sq = ops.add(x1_sq, x1_sq); + let three_x1_sq = ops.add(two_x1_sq, x1_sq); + let numerator = ops.add(three_x1_sq, a); + + // Computing denominator = 2 * y1 + let denominator = ops.add(y1, y1); + + // Computing lambda = numerator * denominator^(-1) + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - 2 * x1 + let lambda_sq = ops.mul(lambda, lambda); + let two_x1 = ops.add(x1, x1); + let x3 = ops.sub(lambda_sq, two_x1); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Generic point addition on y^2 = x^3 + ax + b. +/// +/// Given P1 = (x1, y1) and P2 = (x2, y2), computes P1 + P2 = (x3, y3): +/// lambda = (y2 - y1) / (x2 - x1) +/// x3 = lambda^2 - x1 - x2 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge cases — x1 = x2: +/// When x1 = x2, the denominator (x2 - x1) = 0 and the inverse does +/// not exist. This covers two cases: +/// - P1 = P2 (same point): use `point_double` instead. +/// - P1 = -P2 (y1 = -y2): the result is the point at infinity. +/// This function does NOT handle either case — the constraint system +/// will be unsatisfiable if x1 = x2. The caller must detect this +/// and branch accordingly. 
+pub fn point_add( + ops: &mut F, + x1: F::Elem, + y1: F::Elem, + x2: F::Elem, + y2: F::Elem, +) -> (F::Elem, F::Elem) { + // Computing lambda = (y2 - y1) / (x2 - x1) + let numerator = ops.sub(y2, y1); + let denominator = ops.sub(x2, x1); + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - x1 - x2 + let lambda_sq = ops.mul(lambda, lambda); + let x1_plus_x2 = ops.add(x1, x2); + let x3 = ops.sub(lambda_sq, x1_plus_x2); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Conditional point select: returns `on_true` if `flag` is 1, `on_false` if +/// `flag` is 0. +/// +/// Constrains `flag` to be boolean (`flag * flag = flag`). +pub fn point_select( + ops: &mut F, + flag: usize, + on_false: (F::Elem, F::Elem), + on_true: (F::Elem, F::Elem), +) -> (F::Elem, F::Elem) { + let x = ops.select(flag, on_false.0, on_true.0); + let y = ops.select(flag, on_false.1, on_true.1); + (x, y) +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 3844d1466..a155a6def 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -1,2 +1,143 @@ pub mod curve; pub mod ec_ops; +pub mod ec_points; +pub mod wide_ops; + +use { + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::Field, + curve::{curve_native_point_fe, limb2_constant, CurveParams, Limb2}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +pub trait FieldOps { + type Elem: Copy; + + fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn inv(&mut self, a: Self::Elem) -> Self::Elem; + fn curve_a(&mut self) -> Self::Elem; + + /// 
Conditional select: returns `on_true` if `flag` is 1, `on_false` if + /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). + fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem; +} + +/// Narrow field operations for curves where p fits in BN254's scalar field. +/// Operates on single witness indices (`usize`). +pub struct NarrowOps<'a> { + pub compiler: &'a mut NoirToR1CSCompiler, + pub range_checks: &'a mut BTreeMap>, + pub modulus: FieldElement, + pub params: &'a CurveParams, +} + +impl FieldOps for NarrowOps<'_> { + type Elem = usize; + + fn add(&mut self, a: usize, b: usize) -> usize { + ec_ops::add_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + } + + fn sub(&mut self, a: usize, b: usize) -> usize { + ec_ops::sub_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + } + + fn mul(&mut self, a: usize, b: usize) -> usize { + ec_ops::mul_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + } + + fn inv(&mut self, a: usize) -> usize { + ec_ops::inv_mod_p(self.compiler, a, self.modulus, self.range_checks) + } + + fn curve_a(&mut self) -> usize { + let a_fe = curve_native_point_fe(&self.params.curve_a); + let w = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, a_fe))); + w + } + + fn select(&mut self, flag: usize, on_false: usize, on_true: usize) -> usize { + constrain_boolean(self.compiler, flag); + select_witness(self.compiler, flag, on_false, on_true) + } +} + +/// Wide field operations for curves where p > BN254_r (e.g. secp256r1). +/// Operates on `Limb2` (two 128-bit limbs). 
+pub struct WideOps<'a> { + pub compiler: &'a mut NoirToR1CSCompiler, + pub range_checks: &'a mut BTreeMap>, + pub params: &'a CurveParams, +} + +impl FieldOps for WideOps<'_> { + type Elem = Limb2; + + fn add(&mut self, a: Limb2, b: Limb2) -> Limb2 { + wide_ops::add_mod_p(self.compiler, self.range_checks, a, b, self.params) + } + + fn sub(&mut self, a: Limb2, b: Limb2) -> Limb2 { + wide_ops::sub_mod_p(self.compiler, self.range_checks, a, b, self.params) + } + + fn mul(&mut self, a: Limb2, b: Limb2) -> Limb2 { + wide_ops::mul_mod_p(self.compiler, self.range_checks, a, b, self.params) + } + + fn inv(&mut self, a: Limb2) -> Limb2 { + wide_ops::inv_mod_p(self.compiler, self.range_checks, a, self.params) + } + + fn curve_a(&mut self) -> Limb2 { + limb2_constant(self.compiler, self.params.curve_a) + } + + fn select(&mut self, flag: usize, on_false: Limb2, on_true: Limb2) -> Limb2 { + constrain_boolean(self.compiler, flag); + Limb2 { + lo: select_witness(self.compiler, flag, on_false.lo, on_true.lo), + hi: select_witness(self.compiler, flag, on_false.hi, on_true.hi), + } + } +} + +// --------------------------------------------------------------------------- +// Private helpers +// --------------------------------------------------------------------------- + +/// Constrains `flag` to be boolean: `flag * flag = flag`. +fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + ); +} + +/// Single-witness conditional select: `out = on_false + flag * (on_true - +/// on_false)`. +/// +/// Produces 3 witnesses and 3 R1CS constraints (diff, flag*diff, out). +/// Does NOT constrain `flag` to be boolean — caller must do that separately. 
+fn select_witness( + compiler: &mut NoirToR1CSCompiler, + flag: usize, + on_false: usize, + on_true: usize, +) -> usize { + let diff = compiler.add_sum(vec![ + SumTerm(None, on_true), + SumTerm(Some(-FieldElement::ONE), on_false), + ]); + let flag_diff = compiler.add_product(flag, diff); + compiler.add_sum(vec![SumTerm(None, on_false), SumTerm(None, flag_diff)]) +} diff --git a/provekit/r1cs-compiler/src/msm/wide_ops.rs b/provekit/r1cs-compiler/src/msm/wide_ops.rs new file mode 100644 index 000000000..167d6f986 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/wide_ops.rs @@ -0,0 +1,563 @@ +use { + crate::{ + msm::curve::{CurveParams, Limb2}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::Field, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// (a + b) mod p for 256-bit values in two 128-bit limbs. +/// +/// Equation: a + b = q * p + r, where q ∈ {0, 1}, 0 ≤ r < p. +/// +/// Uses the offset trick to avoid negative intermediate values: +/// v_offset = a_lo + b_lo + 2^128 - q * p_lo (always ≥ 0) +/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} +/// r_lo = v_offset - carry_offset * 2^128 +/// r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi +/// +/// Less-than-p check (proves r < p): +/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) +/// +/// Constraints (7 total): +/// 1. q is boolean: q * q = q +/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * +/// 2^128 +/// 4. Column 1: r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi +/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 +/// 7. 
LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi +/// +/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) +pub fn add_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limb2, + b: Limb2, + params: &CurveParams, +) -> Limb2 { + let two_128 = FieldElement::from(2u64).pow([128u64]); + let p_lo_fe = params.p_lo_fe(); + let p_hi_fe = params.p_hi_fe(); + let w1 = compiler.witness_one(); + + // Witness: q = floor((a + b) / p) ∈ {0, 1} + // ----------------------------------------------------------- + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::WideAddQuotient { + output: q, + a_lo: a.lo, + a_hi: a.hi, + b_lo: b.lo, + b_hi: b.hi, + modulus: params.field_modulus_p, + }); + // constraining q to be boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + // Computing r_lo: lower 128 bits of result + // ----------------------------------------------------------- + // v_offset = a_lo + b_lo + 2^128 - q * p_lo + // (2^128 offset ensures v_offset is always non-negative) + let v_offset = compiler.add_sum(vec![ + SumTerm(None, a.lo), + SumTerm(None, b.lo), + SumTerm(Some(two_128), w1), + SumTerm(Some(-p_lo_fe), q), + ]); + // computing carry_offset = floor(v_offset / 2^128) + let carry_offset = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + carry_offset, + v_offset, + two_128, + )); + // computing r_lo = v_offset - carry_offset * 2^128 + let r_lo = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_128), carry_offset), + ]); + + // Computing r_hi: upper 128 bits of result + // ----------------------------------------------------------- + // r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi + // (-1 compensates for the 2^128 offset added in the low column) + let r_hi = compiler.add_sum(vec![ + SumTerm(None, a.hi), + SumTerm(None, b.hi), + SumTerm(None, carry_offset), 
+ SumTerm(Some(-FieldElement::ONE), w1), + SumTerm(Some(-p_hi_fe), q), + ]); + + less_than_p_check(compiler, range_checks, r_lo, r_hi, params); + + Limb2 { lo: r_lo, hi: r_hi } +} + +/// (a - b) mod p for 256-bit values in two 128-bit limbs. +/// +/// Equation: a - b + q * p = r, where q ∈ {0, 1}, 0 ≤ r < p. +/// q = 0 if a ≥ b (result is non-negative without correction) +/// q = 1 if a < b (add p to make result non-negative) +/// +/// Uses the offset trick to avoid negative intermediate values: +/// v_offset = a_lo - b_lo + q * p_lo + 2^128 (always ≥ 0) +/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} +/// r_lo = v_offset - carry_offset * 2^128 +/// r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 +/// +/// Less-than-p check (proves r < p): +/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) +/// +/// Constraints (7 total): +/// 1. q is boolean: q * q = q +/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * +/// 2^128 +/// 4. Column 1: r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 +/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 +/// 7. LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi +/// +/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) +pub fn sub_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limb2, + b: Limb2, + params: &CurveParams, +) -> Limb2 { + let two_128 = FieldElement::from(2u64).pow([128u64]); + let p_lo_fe = params.p_lo_fe(); + let p_hi_fe = params.p_hi_fe(); + let w1 = compiler.witness_one(); + + // Witness: q = (a < b) ? 
1 : 0 + // ----------------------------------------------------------- + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::WideSubBorrow { + output: q, + a_lo: a.lo, + a_hi: a.hi, + b_lo: b.lo, + b_hi: b.hi, + }); + // constraining q to be boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + // Computing r_lo: lower 128 bits of result + // ----------------------------------------------------------- + // v_offset = a_lo - b_lo + q * p_lo + 2^128 + // (2^128 offset ensures v_offset is always non-negative) + let v_offset = compiler.add_sum(vec![ + SumTerm(None, a.lo), + SumTerm(Some(-FieldElement::ONE), b.lo), + SumTerm(Some(p_lo_fe), q), + SumTerm(Some(two_128), w1), + ]); + // computing carry_offset = floor(v_offset / 2^128) + let carry_offset = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + carry_offset, + v_offset, + two_128, + )); + // computing r_lo = v_offset - carry_offset * 2^128 + let r_lo = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_128), carry_offset), + ]); + + // Computing r_hi: upper 128 bits of result + // ----------------------------------------------------------- + // r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 + // (-1 compensates for the 2^128 offset added in the low column) + let r_hi = compiler.add_sum(vec![ + SumTerm(None, a.hi), + SumTerm(Some(-FieldElement::ONE), b.hi), + SumTerm(Some(p_hi_fe), q), + SumTerm(None, carry_offset), + SumTerm(Some(-FieldElement::ONE), w1), + ]); + + less_than_p_check(compiler, range_checks, r_lo, r_hi, params); + + Limb2 { lo: r_lo, hi: r_hi } +} + +/// (a × b) mod p for 256-bit values in two 128-bit limbs. 
+/// +/// Verifies the integer identity `a * b = p * q + r` using schoolbook +/// multiplication in base W = 2^86 (86-bit limbs ensure all column +/// products < 2^172 ≪ BN254_r ≈ 2^254, so field equations = integer equations). +/// +/// Three layers of verification: +/// 1. Decomposition links: prove 86-bit witnesses match the 128-bit +/// inputs/outputs +/// 2. Column equations: prove a86 * b86 = p86 * q86 + r86 (integer) +/// 3. Less-than-p check: prove r < p +/// +/// Witness layout (MulModHint, 20 witnesses at output_start): +/// [0..2) q_lo, q_hi — quotient 128-bit limbs (unconstrained) +/// [2..4) r_lo, r_hi — remainder 128-bit limbs (OUTPUT) +/// [4..7) a86_0..2 — a in 86-bit limbs +/// [7..10) b86_0..2 — b in 86-bit limbs +/// [10..13) q86_0..2 — q in 86-bit limbs +/// [13..16) r86_0..2 — r in 86-bit limbs +/// [16..20) c0u..c3u — unsigned-offset carries (c_signed + 2^88) +/// +/// Constraints (26 total): +/// 9 decomposition links (a, b, r × 3 each) +/// 9 product witnesses (a_i × b_j) +/// 5 column equations +/// 3 less-than-p check +/// +/// Range checks (23 total): +/// 128-bit: r_lo, r_hi, d_lo, d_hi +/// 86-bit: a86_0, a86_1, b86_0, b86_1, q86_0, q86_1, r86_0, r86_1 +/// 84-bit: a86_2, b86_2, q86_2, r86_2 +/// 89-bit: c0u, c1u, c2u, c3u +/// 44-bit: carry_a, carry_b, carry_r +pub fn mul_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limb2, + b: Limb2, + params: &CurveParams, +) -> Limb2 { + let two_44 = FieldElement::from(2u64).pow([44u64]); + let two_86 = FieldElement::from(2u64).pow([86u64]); + let two_128 = FieldElement::from(2u64).pow([128u64]); + let offset_fe = FieldElement::from(2u64).pow([88u64]); // CARRY_OFFSET + let offset_w = FieldElement::from(2u64).pow([174u64]); // 2^88 * 2^86 + let offset_w_minus_1 = offset_w - offset_fe; // 2^88 * (2^86 - 1) + let [p0, p1, p2] = params.p_86_limbs(); + let w1 = compiler.witness_one(); + + // Step 1: Allocate MulModHint (20 witnesses) + // 
----------------------------------------------------------- + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MulModHint { + output_start: os, + a_lo: a.lo, + a_hi: a.hi, + b_lo: b.lo, + b_hi: b.hi, + modulus: params.field_modulus_p, + }); + + // Witness indices + let r_lo = os + 2; + let r_hi = os + 3; + let a86 = [os + 4, os + 5, os + 6]; + let b86 = [os + 7, os + 8, os + 9]; + let q86 = [os + 10, os + 11, os + 12]; + let r86 = [os + 13, os + 14, os + 15]; + let cu = [os + 16, os + 17, os + 18, os + 19]; + + // Step 2: Decomposition consistency for a, b, r + // ----------------------------------------------------------- + decompose_check( + compiler, + range_checks, + a.lo, + a.hi, + a86, + two_86, + two_44, + two_128, + w1, + ); + decompose_check( + compiler, + range_checks, + b.lo, + b.hi, + b86, + two_86, + two_44, + two_128, + w1, + ); + decompose_check( + compiler, + range_checks, + r_lo, + r_hi, + r86, + two_86, + two_44, + two_128, + w1, + ); + + // Step 3: Product witnesses (9 R1CS constraints) + // ----------------------------------------------------------- + let ab00 = compiler.add_product(a86[0], b86[0]); + let ab01 = compiler.add_product(a86[0], b86[1]); + let ab10 = compiler.add_product(a86[1], b86[0]); + let ab02 = compiler.add_product(a86[0], b86[2]); + let ab11 = compiler.add_product(a86[1], b86[1]); + let ab20 = compiler.add_product(a86[2], b86[0]); + let ab12 = compiler.add_product(a86[1], b86[2]); + let ab21 = compiler.add_product(a86[2], b86[1]); + let ab22 = compiler.add_product(a86[2], b86[2]); + + // Step 4: Column equations (5 R1CS constraints) + // ----------------------------------------------------------- + // Identity: a*b = p*q + r in base W=2^86. + // Carries stored with unsigned offset: cu_i = c_i + 2^88. 
+ // + // col0: ab00 + 2^174 = p0*q0 + r0 + W*cu0 + // col1: ab01 + ab10 + cu0 + (2^174-2^88) = p0*q1 + p1*q0 + r1 + W*cu1 + // col2: ab02+ab11+ab20 + cu1 + (2^174-2^88) = p0*q2+p1*q1+p2*q0 + r2 + W*cu2 + // col3: ab12 + ab21 + cu2 + (2^174-2^88) = p1*q2 + p2*q1 + W*cu3 + // col4: ab22 + cu3 = p2*q2 + 2^88 + + // col0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, ab00), (offset_w, w1)], + &[(FieldElement::ONE, w1)], + &[(p0, q86[0]), (FieldElement::ONE, r86[0]), (two_86, cu[0])], + ); + + // col1 + compiler.r1cs.add_constraint( + &[ + (FieldElement::ONE, ab01), + (FieldElement::ONE, ab10), + (FieldElement::ONE, cu[0]), + (offset_w_minus_1, w1), + ], + &[(FieldElement::ONE, w1)], + &[ + (p0, q86[1]), + (p1, q86[0]), + (FieldElement::ONE, r86[1]), + (two_86, cu[1]), + ], + ); + + // col2 + compiler.r1cs.add_constraint( + &[ + (FieldElement::ONE, ab02), + (FieldElement::ONE, ab11), + (FieldElement::ONE, ab20), + (FieldElement::ONE, cu[1]), + (offset_w_minus_1, w1), + ], + &[(FieldElement::ONE, w1)], + &[ + (p0, q86[2]), + (p1, q86[1]), + (p2, q86[0]), + (FieldElement::ONE, r86[2]), + (two_86, cu[2]), + ], + ); + + // col3 + compiler.r1cs.add_constraint( + &[ + (FieldElement::ONE, ab12), + (FieldElement::ONE, ab21), + (FieldElement::ONE, cu[2]), + (offset_w_minus_1, w1), + ], + &[(FieldElement::ONE, w1)], + &[(p1, q86[2]), (p2, q86[1]), (two_86, cu[3])], + ); + + // col4 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, ab22), (FieldElement::ONE, cu[3])], + &[(FieldElement::ONE, w1)], + &[(p2, q86[2]), (offset_fe, w1)], + ); + + // Step 5: Less-than-p check (r < p) + 128-bit range checks on r_lo, r_hi + // ----------------------------------------------------------- + less_than_p_check(compiler, range_checks, r_lo, r_hi, params); + + // Step 6: Range checks (mul-specific) + // ----------------------------------------------------------- + // 86-bit: limbs 0 and 1 of a, b, q, r + for &idx in &[ + a86[0], a86[1], b86[0], b86[1], q86[0], q86[1], r86[0], 
r86[1], + ] { + range_checks.entry(86).or_default().push(idx); + } + + // 84-bit: limb 2 of a, b, q, r (bits [172..256) = 84 bits) + for &idx in &[a86[2], b86[2], q86[2], r86[2]] { + range_checks.entry(84).or_default().push(idx); + } + + // 89-bit: unsigned-offset carries (|c_signed| < 2^88, so c_unsigned ∈ [0, + // 2^89)) + for &idx in &cu { + range_checks.entry(89).or_default().push(idx); + } + + Limb2 { lo: r_lo, hi: r_hi } +} + +/// a^(-1) mod p for 256-bit values in two 128-bit limbs. +/// +/// Hint-and-verify pattern: +/// 1. Prover computes inv = a^(p-2) mod p (Fermat's little theorem) +/// 2. Circuit verifies a * inv mod p = 1 +/// +/// Constraints: 26 from mul_mod_p + 2 equality checks = 28 total. +pub fn inv_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + value: Limb2, + params: &CurveParams, +) -> Limb2 { + // Witness: inv = a^(-1) mod p (2 witnesses: lo, hi) + // ----------------------------------------------------------- + let value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::WideModularInverse { + output_start: value_inv, + a_lo: value.lo, + a_hi: value.hi, + modulus: params.field_modulus_p, + }); + let inv = Limb2 { + lo: value_inv, + hi: value_inv + 1, + }; + + // Verifying a * inv mod p = 1 + // ----------------------------------------------------------- + // computing product = value * inv mod p + let product = mul_mod_p(compiler, range_checks, value, inv, params); + // constraining product_lo = 1 (because 1 = 1 + 0 * 2^128) + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product.lo)], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, compiler.witness_one())], + ); + // constraining product_hi = 0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product.hi)], + &[(FieldElement::ONE, compiler.witness_one())], + &[], + ); + + inv +} + +/// Verify that 128-bit limbs (v_lo, v_hi) decompose into 86-bit limbs (v86). 
+/// +/// Equations: +/// v_lo = v86_0 + v86_1 * 2^86 - carry * 2^128 +/// v_hi = carry + v86_2 * 2^44 +/// +/// All intermediate values < 2^172 ≪ BN254_r, so field equations = integer +/// equations. +/// +/// Creates: 1 intermediate witness (v_sum), 1 carry witness (IntegerQuotient). +/// Adds: 3 R1CS constraints (v_sum definition + 2 decomposition checks). +/// Range checks: carry (44-bit). +/// Proves r < p by decomposing (p - 1) - r into non-negative 128-bit limbs. +/// +/// If d_lo, d_hi >= 0 then (p - 1) - r >= 0, i.e. r <= p - 1 < p. +/// Uses the 2^128 offset trick to avoid negative intermediate values. +/// +/// Range checks r_lo, r_hi, d_lo, d_hi (128-bit each). +fn less_than_p_check( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + r_lo: usize, + r_hi: usize, + params: &CurveParams, +) { + let two_128 = FieldElement::from(2u64).pow([128u64]); + let p_lo_fe = params.p_lo_fe(); + let p_hi_fe = params.p_hi_fe(); + let w1 = compiler.witness_one(); + + // v_diff = (p_lo - 1) + 2^128 - r_lo + // (2^128 offset ensures v_diff is always non-negative) + let p_lo_minus_1_plus_offset = p_lo_fe - FieldElement::ONE + two_128; + let v_diff = compiler.add_sum(vec![ + SumTerm(Some(p_lo_minus_1_plus_offset), w1), + SumTerm(Some(-FieldElement::ONE), r_lo), + ]); + // borrow_compl = floor(v_diff / 2^128) + let borrow_compl = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + borrow_compl, + v_diff, + two_128, + )); + // d_lo = v_diff - borrow_compl * 2^128 + let d_lo = compiler.add_sum(vec![ + SumTerm(None, v_diff), + SumTerm(Some(-two_128), borrow_compl), + ]); + // d_hi = (p_hi - 1) + borrow_compl - r_hi + let d_hi = compiler.add_sum(vec![ + SumTerm(Some(p_hi_fe - FieldElement::ONE), w1), + SumTerm(None, borrow_compl), + SumTerm(Some(-FieldElement::ONE), r_hi), + ]); + + // Range checks (128-bit) + range_checks.entry(128).or_default().push(r_lo); + range_checks.entry(128).or_default().push(r_hi); + 
range_checks.entry(128).or_default().push(d_lo); + range_checks.entry(128).or_default().push(d_hi); +} + +fn decompose_check( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + v_lo: usize, + v_hi: usize, + v86: [usize; 3], + two_86: FieldElement, + two_44: FieldElement, + two_128: FieldElement, + w1: usize, +) { + // v_sum = v86_0 + v86_1 * 2^86 (intermediate for IntegerQuotient) + let v_sum = compiler.add_sum(vec![SumTerm(None, v86[0]), SumTerm(Some(two_86), v86[1])]); + + // carry = floor(v_sum / 2^128) ∈ [0, 2^44) + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_sum, two_128)); + + // Low check: v_sum - carry * 2^128 = v_lo + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, v_sum), (-two_128, carry)], + &[(FieldElement::ONE, w1)], + &[(FieldElement::ONE, v_lo)], + ); + + // High check: carry + v86_2 * 2^44 = v_hi + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, carry), (two_44, v86[2])], + &[(FieldElement::ONE, w1)], + &[(FieldElement::ONE, v_hi)], + ); + + // Range check carry (44-bit) + range_checks.entry(44).or_default().push(carry); +} From 60254988d51e1a12d7b331c6007a6aab35945bff Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Tue, 3 Mar 2026 04:13:48 +0530 Subject: [PATCH 3/5] feat : added dynamic multi limb approach with cost model for msm black box --- noir-examples/embedded_curve_msm/Nargo.toml | 7 + noir-examples/embedded_curve_msm/Prover.toml | 5 + noir-examples/embedded_curve_msm/src/main.nr | 51 ++ noir-examples/native_msm/Nargo.toml | 7 + noir-examples/native_msm/Prover.toml | 5 + noir-examples/native_msm/src/main.nr | 104 +++ .../src/witness/scheduling/dependency.rs | 67 +- .../common/src/witness/scheduling/remapper.rs | 85 +-- .../common/src/witness/witness_builder.rs | 87 +-- provekit/prover/src/lib.rs | 52 ++ .../prover/src/witness/witness_builder.rs | 353 +++++++---- provekit/r1cs-compiler/src/msm/cost_model.rs | 362 +++++++++++ 
provekit/r1cs-compiler/src/msm/curve.rs | 163 +++-- provekit/r1cs-compiler/src/msm/ec_ops.rs | 208 ------ provekit/r1cs-compiler/src/msm/ec_points.rs | 183 ++++++ provekit/r1cs-compiler/src/msm/mod.rs | 510 ++++++++++++--- .../r1cs-compiler/src/msm/multi_limb_arith.rs | 591 ++++++++++++++++++ .../r1cs-compiler/src/msm/multi_limb_ops.rs | 275 ++++++++ provekit/r1cs-compiler/src/msm/wide_ops.rs | 563 ----------------- provekit/r1cs-compiler/src/noir_to_r1cs.rs | 39 +- 20 files changed, 2600 insertions(+), 1117 deletions(-) create mode 100644 noir-examples/embedded_curve_msm/Nargo.toml create mode 100644 noir-examples/embedded_curve_msm/Prover.toml create mode 100644 noir-examples/embedded_curve_msm/src/main.nr create mode 100644 noir-examples/native_msm/Nargo.toml create mode 100644 noir-examples/native_msm/Prover.toml create mode 100644 noir-examples/native_msm/src/main.nr create mode 100644 provekit/r1cs-compiler/src/msm/cost_model.rs delete mode 100644 provekit/r1cs-compiler/src/msm/ec_ops.rs create mode 100644 provekit/r1cs-compiler/src/msm/multi_limb_arith.rs create mode 100644 provekit/r1cs-compiler/src/msm/multi_limb_ops.rs delete mode 100644 provekit/r1cs-compiler/src/msm/wide_ops.rs diff --git a/noir-examples/embedded_curve_msm/Nargo.toml b/noir-examples/embedded_curve_msm/Nargo.toml new file mode 100644 index 000000000..ec9891616 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "embedded_curve_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml new file mode 100644 index 000000000..58c6933da --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -0,0 +1,5 @@ +# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +scalar1_lo = "1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" diff --git a/noir-examples/embedded_curve_msm/src/main.nr 
b/noir-examples/embedded_curve_msm/src/main.nr new file mode 100644 index 000000000..cf0704211 --- /dev/null +++ b/noir-examples/embedded_curve_msm/src/main.nr @@ -0,0 +1,51 @@ +use std::embedded_curve_ops::{ + EmbeddedCurvePoint, + EmbeddedCurveScalar, + multi_scalar_mul, +}; + +/// Exercises the MultiScalarMul ACIR blackbox with 2 Grumpkin points. +/// Computes s1 * G + s2 * G where G is the Grumpkin generator. +fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + // Grumpkin generator + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + + let s1 = EmbeddedCurveScalar { lo: scalar1_lo, hi: scalar1_hi }; + let s2 = EmbeddedCurveScalar { lo: scalar2_lo, hi: scalar2_hi }; + + // MSM: result = s1 * G + s2 * G + let result = multi_scalar_mul([g, g], [s1, s2]); + + // Prevent dead-code elimination - forces the blackbox to be retained + assert(!result.is_infinite); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify by computing independently: 3*G should match + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + let s3 = EmbeddedCurveScalar { lo: 3, hi: 0 }; + let check = multi_scalar_mul([g], [s3]); + + assert(check.x == expected_x); + assert(check.y == expected_y); +} diff --git a/noir-examples/native_msm/Nargo.toml b/noir-examples/native_msm/Nargo.toml new file mode 100644 index 000000000..5ff116db7 --- /dev/null +++ b/noir-examples/native_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "native_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/native_msm/Prover.toml 
b/noir-examples/native_msm/Prover.toml new file mode 100644 index 000000000..58c6933da --- /dev/null +++ b/noir-examples/native_msm/Prover.toml @@ -0,0 +1,5 @@ +# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +scalar1_lo = "1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" diff --git a/noir-examples/native_msm/src/main.nr b/noir-examples/native_msm/src/main.nr new file mode 100644 index 000000000..80cfd3d0f --- /dev/null +++ b/noir-examples/native_msm/src/main.nr @@ -0,0 +1,104 @@ +// Grumpkin generator y-coordinate +global GRUMPKIN_GEN_Y: Field = 17631683881184975370165255887551781615748388533673675138860; + +struct Point { + x: Field, + y: Field, + is_infinite: bool, +} + +fn point_double(p: Point) -> Point { + if p.is_infinite | (p.y == 0) { + Point { x: 0, y: 0, is_infinite: true } + } else { + // Grumpkin has a=0, so lambda = 3*x1^2 / (2*y1) + let lambda = (3 * p.x * p.x) / (2 * p.y); + let x3 = lambda * lambda - 2 * p.x; + let y3 = lambda * (p.x - x3) - p.y; + Point { x: x3, y: y3, is_infinite: false } + } +} + +fn point_add(p1: Point, p2: Point) -> Point { + if p1.is_infinite { + p2 + } else if p2.is_infinite { + p1 + } else if (p1.x == p2.x) & (p1.y == p2.y) { + point_double(p1) + } else if (p1.x == p2.x) & (p1.y == (0 - p2.y)) { + Point { x: 0, y: 0, is_infinite: true } + } else { + let lambda = (p2.y - p1.y) / (p2.x - p1.x); + let x3 = lambda * lambda - p1.x - p2.x; + let y3 = lambda * (p1.x - x3) - p1.y; + Point { x: x3, y: y3, is_infinite: false } + } +} + +fn scalar_mul(p: Point, scalar_lo: Field, scalar_hi: Field) -> Point { + let lo_bits: [u1; 128] = scalar_lo.to_le_bits(); + let hi_bits: [u1; 128] = scalar_hi.to_le_bits(); + + // Combine into a single 256-bit array (lo first, then hi) + let mut bits: [u1; 256] = [0; 256]; + for i in 0..128 { + bits[i] = lo_bits[i]; + bits[128 + i] = hi_bits[i]; + } + + // Find the highest set bit + let mut top = 0; + for i in 0..256 { + if bits[i] == 1 { + top = i; + } + } + + // Double-and-add from MSB 
down to bit 0 + let mut acc = Point { x: 0, y: 0, is_infinite: true }; + for j in 0..256 { + let i = 255 - j; + acc = point_double(acc); + if bits[i] == 1 { + acc = point_add(acc, p); + } + } + + acc +} + +/// Native MSM: computes s1 * G + s2 * G using pure Noir field operations. +/// No blackbox functions -- all EC arithmetic is done natively over Grumpkin's +/// base field (= BN254 scalar field = Noir's native Field). +fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; + + let r1 = scalar_mul(g, scalar1_lo, scalar1_hi); + let r2 = scalar_mul(g, scalar2_lo, scalar2_hi); + let result = point_add(r1, r2); + + // Prevent dead-code elimination + assert(!result.is_infinite); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin (known coordinates) + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify 1*G + 2*G = 3*G by computing 3*G directly + let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; + let three_g = scalar_mul(g, 3, 0); + + assert(three_g.x == expected_x); + assert(three_g.y == expected_y); +} diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index 956f79b56..9f92afd75 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -79,6 +79,7 @@ impl DependencyInfo { WitnessBuilder::Product(_, a, b) => vec![*a, *b], WitnessBuilder::MultiplicitiesForRange(_, _, values) => values.clone(), WitnessBuilder::Inverse(_, x) + | WitnessBuilder::SafeInverse(_, x) | WitnessBuilder::ModularInverse(_, x, _) | WitnessBuilder::IntegerQuotient(_, x, _) => vec![*x], WitnessBuilder::IndexedLogUpDenominator( @@ -154,28 +155,34 @@ impl DependencyInfo { 
} v } - WitnessBuilder::MulModHint { - a_lo, - a_hi, - b_lo, - b_hi, + WitnessBuilder::MultiLimbMulModHint { + a_limbs, + b_limbs, .. - } => vec![*a_lo, *a_hi, *b_lo, *b_hi], - WitnessBuilder::WideModularInverse { a_lo, a_hi, .. } => vec![*a_lo, *a_hi], - WitnessBuilder::WideAddQuotient { - a_lo, - a_hi, - b_lo, - b_hi, + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbModularInverse { a_limbs, .. } => a_limbs.clone(), + WitnessBuilder::MultiLimbAddQuotient { + a_limbs, + b_limbs, .. - } => vec![*a_lo, *a_hi, *b_lo, *b_hi], - WitnessBuilder::WideSubBorrow { - a_lo, - a_hi, - b_lo, - b_hi, + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbSubBorrow { + a_limbs, + b_limbs, .. - } => vec![*a_lo, *a_hi, *b_lo, *b_hi], + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } WitnessBuilder::BytePartition { x, .. } => vec![*x], WitnessBuilder::U32AdditionMulti(_, _, inputs) => inputs @@ -264,6 +271,7 @@ impl DependencyInfo { | WitnessBuilder::Challenge(idx) | WitnessBuilder::IndexedLogUpDenominator(idx, ..) | WitnessBuilder::Inverse(idx, _) + | WitnessBuilder::SafeInverse(idx, _) | WitnessBuilder::ModularInverse(idx, ..) | WitnessBuilder::IntegerQuotient(idx, ..) | WitnessBuilder::ProductLinearOperation(idx, ..) @@ -308,14 +316,21 @@ impl DependencyInfo { let n = 1usize << *num_bits; (*start..*start + n).collect() } - WitnessBuilder::MulModHint { output_start, .. } => { - (*output_start..*output_start + 20).collect() - } - WitnessBuilder::WideModularInverse { output_start, .. } => { - (*output_start..*output_start + 2).collect() + WitnessBuilder::MultiLimbMulModHint { + output_start, + num_limbs, + .. + } => { + let count = (4 * *num_limbs - 2) as usize; + (*output_start..*output_start + count).collect() } - WitnessBuilder::WideAddQuotient { output, .. } => vec![*output], - WitnessBuilder::WideSubBorrow { output, .. 
} => vec![*output], + WitnessBuilder::MultiLimbModularInverse { + output_start, + num_limbs, + .. + } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], + WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) => { vec![*result_idx, *carry_idx] } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 47490a6ce..334b5f401 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -115,6 +115,9 @@ impl WitnessIndexRemapper { WitnessBuilder::Inverse(idx, operand) => { WitnessBuilder::Inverse(self.remap(*idx), self.remap(*operand)) } + WitnessBuilder::SafeInverse(idx, operand) => { + WitnessBuilder::SafeInverse(self.remap(*idx), self.remap(*operand)) + } WitnessBuilder::ModularInverse(idx, operand, modulus) => { WitnessBuilder::ModularInverse(self.remap(*idx), self.remap(*operand), *modulus) } @@ -221,59 +224,63 @@ impl WitnessIndexRemapper { .collect(), ) } - WitnessBuilder::MulModHint { + WitnessBuilder::MultiLimbMulModHint { output_start, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, - } => WitnessBuilder::MulModHint { + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbMulModHint { output_start: self.remap(*output_start), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), - b_lo: self.remap(*b_lo), - b_hi: self.remap(*b_hi), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, - WitnessBuilder::WideModularInverse { + WitnessBuilder::MultiLimbModularInverse { output_start, - a_lo, - a_hi, + a_limbs, modulus, - } => WitnessBuilder::WideModularInverse { + limb_bits, + num_limbs, + } => 
WitnessBuilder::MultiLimbModularInverse { output_start: self.remap(*output_start), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, - WitnessBuilder::WideAddQuotient { + WitnessBuilder::MultiLimbAddQuotient { output, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, - } => WitnessBuilder::WideAddQuotient { - output: self.remap(*output), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), - b_lo: self.remap(*b_lo), - b_hi: self.remap(*b_hi), - modulus: *modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbAddQuotient { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, - WitnessBuilder::WideSubBorrow { + WitnessBuilder::MultiLimbSubBorrow { output, - a_lo, - a_hi, - b_lo, - b_hi, - } => WitnessBuilder::WideSubBorrow { - output: self.remap(*output), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), - b_lo: self.remap(*b_lo), - b_hi: self.remap(*b_hi), + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbSubBorrow { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { lo: self.remap(*lo), diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 6d11d17cf..28d6d775c 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -88,6 +88,11 @@ pub enum WitnessBuilder { /// The inverse of the value at a specified witness index /// 
(witness index, operand witness index) Inverse(usize, usize), + /// Safe inverse: like Inverse but handles zero by outputting 0. + /// Used by compute_is_zero where the input may be zero. Solved in the + /// Other layer (not batch-inverted), so zero inputs don't poison the batch. + /// (witness index, operand witness index) + SafeInverse(usize, usize), /// The modular inverse of the value at a specified witness index, modulo /// a given prime modulus. Computes a^{-1} mod m using Fermat's little /// theorem (a^{m-2} mod m). Unlike Inverse (BN254 field inverse), this @@ -202,61 +207,59 @@ pub enum WitnessBuilder { /// Computes: 1 / (sz - lhs - rs*rhs - rs²*and_out - rs³*xor_out) CombinedTableEntryInverse(CombinedTableEntryInverseData), /// Prover hint for multi-limb modular multiplication: (a * b) mod p. - /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// Given inputs a and b as N-limb vectors (each limb `limb_bits` wide), /// and a constant 256-bit modulus p, computes quotient q, remainder r, - /// their 86-bit decompositions, and carry witnesses. + /// and carry witnesses for schoolbook column verification. 
/// - /// Outputs 20 witnesses starting at output_start: - /// [0..2) q_lo, q_hi (quotient in 128-bit limbs) - /// [2..4) r_lo, r_hi (remainder in 128-bit limbs) - /// [4..7) a_86_0, a_86_1, a_86_2 (a in 86-bit limbs) - /// [7..10) b_86_0, b_86_1, b_86_2 (b in 86-bit limbs) - /// [10..13) q_86_0, q_86_1, q_86_2 (q in 86-bit limbs) - /// [13..16) r_86_0, r_86_1, r_86_2 (r in 86-bit limbs) - /// [16..20) c0, c1, c2, c3 (carry witnesses, unsigned-offset) - MulModHint { + /// Outputs (4*num_limbs - 2) witnesses starting at output_start: + /// [0..N) q limbs (quotient) + /// [N..2N) r limbs (remainder) — OUTPUT + /// [2N..4N-2) carry witnesses (unsigned-offset) + MultiLimbMulModHint { output_start: usize, - a_lo: usize, - a_hi: usize, - b_lo: usize, - b_hi: usize, + a_limbs: Vec, + b_limbs: Vec, modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, - /// Prover hint for wide modular inverse: a^{-1} mod p. - /// Given input a = (a_lo, a_hi) as 128-bit limbs and constant modulus p, + /// Prover hint for multi-limb modular inverse: a^{-1} mod p. + /// Given input a as N-limb vector and constant modulus p, /// computes the inverse via Fermat's little theorem (a^{p-2} mod p). /// - /// Outputs 2 witnesses at output_start: inv_lo, inv_hi (128-bit limbs). - WideModularInverse { + /// Outputs num_limbs witnesses at output_start: inv limbs. + MultiLimbModularInverse { output_start: usize, - a_lo: usize, - a_hi: usize, + a_limbs: Vec, modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, - /// Prover hint for wide addition quotient: q = floor((a + b) / p). - /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, - /// and a constant 256-bit modulus p, computes q ∈ {0, 1}. + /// Prover hint for multi-limb addition quotient: q = floor((a + b) / p). + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1}. /// /// Outputs 1 witness at output: q. 
- WideAddQuotient { - output: usize, - a_lo: usize, - a_hi: usize, - b_lo: usize, - b_hi: usize, - modulus: [u64; 4], + MultiLimbAddQuotient { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, - /// Prover hint for wide subtraction borrow: q = (a < b) ? 1 : 0. - /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// Prover hint for multi-limb subtraction borrow: q = (a < b) ? 1 : 0. + /// Given inputs a and b as N-limb vectors, and a constant modulus p, /// computes q ∈ {0, 1} indicating whether a borrow (adding p) is needed. /// /// Outputs 1 witness at output: q. - WideSubBorrow { - output: usize, - a_lo: usize, - a_hi: usize, - b_lo: usize, - b_hi: usize, + MultiLimbSubBorrow { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: @@ -329,8 +332,10 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, - WitnessBuilder::MulModHint { .. } => 20, - WitnessBuilder::WideModularInverse { .. } => 2, + WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => { + (4 * *num_limbs - 2) as usize + } + WitnessBuilder::MultiLimbModularInverse { num_limbs, .. 
} => *num_limbs as usize, _ => 1, } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index a17fc1265..f6a3e653f 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -192,6 +192,58 @@ impl Prove for Prover { .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; + // DEBUG: Check R1CS constraint satisfaction with ALL witnesses solved + { + use ark_ff::Zero; + let debug_r1cs = r1cs.clone(); + let interner = &debug_r1cs.interner; + let ha = debug_r1cs.a.hydrate(interner); + let hb = debug_r1cs.b.hydrate(interner); + let hc = debug_r1cs.c.hydrate(interner); + let mut fail_count = 0usize; + for row in 0..debug_r1cs.num_constraints() { + let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, r: usize| -> FieldElement { + let mut sum = FieldElement::zero(); + for (col, coeff) in hm.iter_row(r) { + sum += coeff * full_witness[col]; + } + sum + }; + let a_val = eval(&ha, row); + let b_val = eval(&hb, row); + let c_val = eval(&hc, row); + if a_val * b_val != c_val { + if fail_count < 10 { + eprintln!( + "CONSTRAINT {} FAILED: A={:?} B={:?} C={:?} A*B={:?}", + row, a_val, b_val, c_val, a_val * b_val + ); + eprint!(" A terms:"); + for (col, coeff) in ha.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + eprint!(" B terms:"); + for (col, coeff) in hb.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + eprint!(" C terms:"); + for (col, coeff) in hc.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + } + fail_count += 1; + } + } + if fail_count > 0 { + eprintln!("TOTAL FAILING CONSTRAINTS: {fail_count} / {}", debug_r1cs.num_constraints()); + } else { + eprintln!("ALL {} CONSTRAINTS SATISFIED", debug_r1cs.num_constraints()); + } + } + let whir_r1cs_proof = self .whir_for_witness .prove(merlin, r1cs, commitments, full_witness, &public_inputs) 
diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index ae49cfcd5..d3479331b 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -1,7 +1,7 @@ use { crate::witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, acir::native_types::WitnessMap, - ark_ff::{BigInteger, PrimeField}, + ark_ff::{BigInteger, Field, PrimeField}, ark_std::Zero, provekit_common::{ utils::noir_to_native, @@ -65,6 +65,14 @@ impl WitnessBuilderSolver for WitnessBuilder { "Inverse/LogUpInverse should not be called - handled by batch inversion" ) } + WitnessBuilder::SafeInverse(witness_idx, operand_idx) => { + let val = witness[*operand_idx].unwrap(); + witness[*witness_idx] = Some(if val == FieldElement::zero() { + FieldElement::zero() + } else { + val.inverse().unwrap() + }); + } WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => { let a = witness[*operand_idx].unwrap(); let a_limbs = a.into_bigint().0; @@ -337,61 +345,135 @@ impl WitnessBuilderSolver for WitnessBuilder { lh_val.into_bigint() ^ rh_val.into_bigint(), )); } - WitnessBuilder::MulModHint { + WitnessBuilder::MultiLimbMulModHint { output_start, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, + limb_bits, + num_limbs, } => { - use crate::witness::bigint_mod::{ - compute_carries_86, decompose_128, decompose_86, divmod_wide, widening_mul, - CARRY_OFFSET, + use crate::witness::bigint_mod::{divmod_wide, widening_mul}; + let n = *num_limbs as usize; + let w = *limb_bits; + let limb_mask: u128 = if w >= 128 { + u128::MAX + } else { + (1u128 << w) - 1 }; - // Read inputs: a and b as 128-bit limb pairs - let a_lo_fe = witness[*a_lo].unwrap(); - let a_hi_fe = witness[*a_hi].unwrap(); - let b_lo_fe = witness[*b_lo].unwrap(); - let b_hi_fe = witness[*b_hi].unwrap(); - - // Reconstruct a, b as [u64; 4] - let a_lo_limbs = a_lo_fe.into_bigint().0; - let a_hi_limbs = 
a_hi_fe.into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + // Reconstruct a, b as [u64; 4] from N limbs + let reconstruct = |limbs: &[usize]| -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + // Place into val at bit_offset + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + if word_start + 2 < 4 && bit_within > 0 { + let upper = limb_u128 >> (128 - bit_within); + if upper > 0 { + val[word_start + 2] |= upper as u64; + } + } + } + bit_offset += w; + } + val + }; - let b_lo_limbs = b_lo_fe.into_bigint().0; - let b_hi_limbs = b_hi_fe.into_bigint().0; - let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + let a_val = reconstruct(a_limbs); + let b_val = reconstruct(b_limbs); // Compute product and divmod let product = widening_mul(&a_val, &b_val); let (q_val, r_val) = divmod_wide(&product, modulus); - // Decompose into 128-bit limbs - let (q_lo, q_hi) = decompose_128(&q_val); - let (r_lo, r_hi) = decompose_128(&r_val); + // Decompose a [u64;4] into N limbs of limb_bits width. 
+ let decompose_n_from_u64 = |val: &[u64; 4]| -> Vec { + let mut limbs = Vec::with_capacity(n); + let mut remaining = *val; + for _ in 0..n { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + limbs.push(lo & limb_mask); + // Shift right by w bits + if w >= 256 { + remaining = [0; 4]; + } else { + let mut shifted = [0u64; 4]; + let word_shift = (w / 64) as usize; + let bit_shift = w % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = remaining[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= + remaining[i + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } + } + limbs + }; + + let q_limbs_vals = decompose_n_from_u64(&q_val); + let r_limbs_vals = decompose_n_from_u64(&r_val); - // Decompose into 86-bit limbs - let (a86_0, a86_1, a86_2) = decompose_86(&a_val); - let (b86_0, b86_1, b86_2) = decompose_86(&b_val); - let (q86_0, q86_1, q86_2) = decompose_86(&q_val); - let (r86_0, r86_1, r86_2) = decompose_86(&r_val); + // Compute carries for schoolbook verification: + // a·b = p·q + r in base W = 2^limb_bits + // For each column k (0..2N-2): + // lhs_k = Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + // rhs_k = Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W + let p_limbs_vals = decompose_n_from_u64(modulus); + let a_limbs_vals = decompose_n_from_u64(&a_val); + let b_limbs_vals = decompose_n_from_u64(&b_val); - // Compute carries - let carries = compute_carries_86( - [a86_0, a86_1, a86_2], - [b86_0, b86_1, b86_2], - { - let (p0, p1, p2) = decompose_86(modulus); - [p0, p1, p2] - }, - [q86_0, q86_1, q86_2], - [r86_0, r86_1, r86_2], - ); + let w_val = 1u128 << w; + let num_carries = 2 * n - 2; + let carry_offset = 1u128 << (w + ((n as f64).log2().ceil() as u32) + 1); + let mut carries = Vec::with_capacity(num_carries); + let mut running: i128 = 0; + + for k in 0..(2 * n - 1) { + // Sum a[i]*b[j] for i+j=k + let mut ab_sum: i128 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j 
>= 0 && (j as usize) < n { + ab_sum += + a_limbs_vals[i] as i128 * b_limbs_vals[j as usize] as i128; + } + } + // Sum p[i]*q[j] for i+j=k + let mut pq_sum: i128 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + pq_sum += + p_limbs_vals[i] as i128 * q_limbs_vals[j as usize] as i128; + } + } + let r_k = if k < n { r_limbs_vals[k] as i128 } else { 0 }; + + // column: ab_sum + carry_prev = pq_sum + r_k + carry_next * W + // carry_next = (ab_sum + carry_prev - pq_sum - r_k) / W + running += ab_sum - pq_sum - r_k; + if k < 2 * n - 2 { + let carry = running / w_val as i128; + carries.push(carry); + running -= carry * w_val as i128; + } + } - // Helper: convert u128 to FieldElement let u128_to_fe = |val: u128| -> FieldElement { FieldElement::from_bigint(ark_ff::BigInt([ val as u64, @@ -402,57 +484,59 @@ impl WitnessBuilderSolver for WitnessBuilder { .unwrap() }; - // Write outputs: [0..2) q_lo, q_hi - witness[*output_start] = Some(u128_to_fe(q_lo)); - witness[*output_start + 1] = Some(u128_to_fe(q_hi)); - // [2..4) r_lo, r_hi - witness[*output_start + 2] = Some(u128_to_fe(r_lo)); - witness[*output_start + 3] = Some(u128_to_fe(r_hi)); - // [4..7) a_86 limbs - witness[*output_start + 4] = Some(u128_to_fe(a86_0)); - witness[*output_start + 5] = Some(u128_to_fe(a86_1)); - witness[*output_start + 6] = Some(u128_to_fe(a86_2)); - // [7..10) b_86 limbs - witness[*output_start + 7] = Some(u128_to_fe(b86_0)); - witness[*output_start + 8] = Some(u128_to_fe(b86_1)); - witness[*output_start + 9] = Some(u128_to_fe(b86_2)); - // [10..13) q_86 limbs - witness[*output_start + 10] = Some(u128_to_fe(q86_0)); - witness[*output_start + 11] = Some(u128_to_fe(q86_1)); - witness[*output_start + 12] = Some(u128_to_fe(q86_2)); - // [13..16) r_86 limbs - witness[*output_start + 13] = Some(u128_to_fe(r86_0)); - witness[*output_start + 14] = Some(u128_to_fe(r86_1)); - witness[*output_start + 15] = Some(u128_to_fe(r86_2)); - // [16..20) carries 
(unsigned-offset) - for i in 0..4 { - let c_unsigned = (carries[i] + CARRY_OFFSET as i128) as u128; - witness[*output_start + 16 + i] = Some(u128_to_fe(c_unsigned)); + // Write q limbs + for i in 0..n { + witness[*output_start + i] = Some(u128_to_fe(q_limbs_vals[i])); + } + // Write r limbs + for i in 0..n { + witness[*output_start + n + i] = Some(u128_to_fe(r_limbs_vals[i])); + } + // Write carries (unsigned-offset) + for i in 0..num_carries { + let c_unsigned = (carries[i] + carry_offset as i128) as u128; + witness[*output_start + 2 * n + i] = Some(u128_to_fe(c_unsigned)); } } - WitnessBuilder::WideModularInverse { + WitnessBuilder::MultiLimbModularInverse { output_start, - a_lo, - a_hi, + a_limbs, modulus, + limb_bits, + num_limbs, } => { - use crate::witness::bigint_mod::{decompose_128, mod_pow, sub_u64}; - - // Read input a as 128-bit limb pair - let a_lo_fe = witness[*a_lo].unwrap(); - let a_hi_fe = witness[*a_hi].unwrap(); + use crate::witness::bigint_mod::{mod_pow, sub_u64}; + let n = *num_limbs as usize; + let w = *limb_bits; + let limb_mask: u128 = if w >= 128 { + u128::MAX + } else { + (1u128 << w) - 1 + }; - let a_lo_limbs = a_lo_fe.into_bigint().0; - let a_hi_limbs = a_hi_fe.into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + // Reconstruct a as [u64; 4] from N limbs + let mut a_val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in a_limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + a_val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + a_val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + } + bit_offset += w; + } - // Compute inverse: a^{p-2} mod p (Fermat's little theorem) + // Compute inverse: a^{p-2} mod p let exp = sub_u64(modulus, 2); let inv = 
mod_pow(&a_val, &exp, modulus); - // Decompose into 128-bit limbs - let (inv_lo, inv_hi) = decompose_128(&inv); - + // Decompose into N limbs + let mut remaining = inv; let u128_to_fe = |val: u128| -> FieldElement { FieldElement::from_bigint(ark_ff::BigInt([ val as u64, @@ -462,37 +546,60 @@ impl WitnessBuilderSolver for WitnessBuilder { ])) .unwrap() }; - - witness[*output_start] = Some(u128_to_fe(inv_lo)); - witness[*output_start + 1] = Some(u128_to_fe(inv_hi)); + for i in 0..n { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + witness[*output_start + i] = Some(u128_to_fe(lo & limb_mask)); + // Shift right by w bits + let mut shifted = [0u64; 4]; + let word_shift = (w / 64) as usize; + let bit_shift = w % 64; + for j in 0..4 { + if j + word_shift < 4 { + shifted[j] = remaining[j + word_shift] >> bit_shift; + if bit_shift > 0 && j + word_shift + 1 < 4 { + shifted[j] |= remaining[j + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } } - WitnessBuilder::WideAddQuotient { + WitnessBuilder::MultiLimbAddQuotient { output, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, + limb_bits, + .. 
} => { use crate::witness::bigint_mod::{add_4limb, cmp_4limb}; + let w = *limb_bits; - let a_lo_fe = witness[*a_lo].unwrap(); - let a_hi_fe = witness[*a_hi].unwrap(); - let b_lo_fe = witness[*b_lo].unwrap(); - let b_hi_fe = witness[*b_hi].unwrap(); - - let a_lo_limbs = a_lo_fe.into_bigint().0; - let a_hi_limbs = a_hi_fe.into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + // Reconstruct from N limbs + let reconstruct = |limbs: &[usize]| -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + } + bit_offset += w; + } + val + }; - let b_lo_limbs = b_lo_fe.into_bigint().0; - let b_hi_limbs = b_hi_fe.into_bigint().0; - let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + let a_val = reconstruct(a_limbs); + let b_val = reconstruct(b_limbs); let sum = add_4limb(&a_val, &b_val); - // q = 1 if sum >= p, else 0 let q = if sum[4] > 0 { - // sum > 2^256 > any 256-bit modulus 1u64 } else { let sum4 = [sum[0], sum[1], sum[2], sum[3]]; @@ -505,24 +612,38 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*output] = Some(FieldElement::from(q)); } - WitnessBuilder::WideSubBorrow { + WitnessBuilder::MultiLimbSubBorrow { output, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, + limb_bits, + .. 
} => { use crate::witness::bigint_mod::cmp_4limb; + let w = *limb_bits; - let a_lo_limbs = witness[*a_lo].unwrap().into_bigint().0; - let a_hi_limbs = witness[*a_hi].unwrap().into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + let reconstruct = |limbs: &[usize]| -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + } + bit_offset += w; + } + val + }; - let b_lo_limbs = witness[*b_lo].unwrap().into_bigint().0; - let b_hi_limbs = witness[*b_hi].unwrap().into_bigint().0; - let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + let a_val = reconstruct(a_limbs); + let b_val = reconstruct(b_limbs); - // q = 1 if a < b (need to add p to make result non-negative) let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { 1u64 } else { diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs new file mode 100644 index 000000000..234623a31 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -0,0 +1,362 @@ +//! Analytical cost model for MSM parameter optimization. +//! +//! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): +//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, window_size). + +/// Type of field operation for cost estimation. +#[derive(Clone, Copy)] +pub enum FieldOpType { + Add, + Sub, + Mul, + Inv, +} + +/// Count field ops in scalar_mul for given parameters. +/// Traces through ec_points::scalar_mul logic analytically. 
+/// +/// Returns (n_add, n_sub, n_mul, n_inv) per single scalar multiplication. +fn count_scalar_mul_field_ops(scalar_bits: usize, window_size: usize) -> (usize, usize, usize, usize) { + let w = window_size; + let table_size = 1 << w; + let num_windows = (scalar_bits + w - 1) / w; + + // Build point table: T[0]=P (free), T[1]=P (free), T[2]=2P (1 double), + // T[3..table_size] = point_add each + let table_doubles = if table_size > 2 { 1 } else { 0 }; + let table_adds = if table_size > 2 { table_size - 3 } else { 0 }; + + // point_double costs: 5 mul, 4 add, 2 sub, 1 inv + let double_ops = (4usize, 2usize, 5usize, 1usize); // (add, sub, mul, inv) + // point_add costs: 2 add, 2 sub, 3 mul, 1 inv + let add_ops = (2usize, 2usize, 3usize, 1usize); + + // Table construction + let mut total_add = table_doubles * double_ops.0 + table_adds * add_ops.0; + let mut total_sub = table_doubles * double_ops.1 + table_adds * add_ops.1; + let mut total_mul = table_doubles * double_ops.2 + table_adds * add_ops.2; + let mut total_inv = table_doubles * double_ops.3 + table_adds * add_ops.3; + + // Table lookups: each uses (2^w - 1) point_selects + // point_select = 2 selects = 2 * (3 witnesses: diff, flag*diff, out) per coordinate + // But select is not a field op — it's cheaper (just `select` calls) + // We count it as 2 selects per point_select = 2 sub + 2 mul per select + // Actually select = flag*(on_true - on_false) + on_false: 1 sub, 1 mul, 1 add per elem + // Per point (x,y): 2 sub, 2 mul, 2 add for select + let selects_per_lookup = table_size - 1; // 2^w - 1 point_selects + let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); // (add, sub, mul, inv) + + // MSB window: 1 table lookup (possibly smaller table) + let msb_bits = scalar_bits - (num_windows - 1) * w; + let msb_table_size = 1 << msb_bits; + let msb_selects = msb_table_size - 1; + total_add += msb_selects * select_ops_per_point.0; + total_sub += msb_selects * select_ops_per_point.1; + total_mul += msb_selects * 
select_ops_per_point.2; + + // Remaining windows: for each of (num_windows - 1) windows: + // - w doublings + // - 1 pack_bits (cheap) + // - 1 is_zero (1 inv + some adds) + // - 1 table lookup + // - 1 sub (for denom) + // - 1 elem_is_zero + // - 1 point_double (for x_eq case) + // - 1 safe_point_add (like point_add but with select on denom) + // - 2 point_selects (x_eq and digit_is_zero) + let remaining = if num_windows > 1 { num_windows - 1 } else { 0 }; + + for _ in 0..remaining { + // w doublings + total_add += w * double_ops.0; + total_sub += w * double_ops.1; + total_mul += w * double_ops.2; + total_inv += w * double_ops.3; + + // table lookup + total_add += selects_per_lookup * select_ops_per_point.0; + total_sub += selects_per_lookup * select_ops_per_point.1; + total_mul += selects_per_lookup * select_ops_per_point.2; + + // denom = sub(looked_up.x, acc.x) + total_sub += 1; + + // elem_is_zero(denom) = is_zero per limb + products + // For N limbs: N * (1 inv + some arith) + (N-1) products + // Simplified: 1 inv + 3 witnesses + total_inv += 1; + total_add += 1; + total_mul += 1; + + // point_double for x_eq case + total_add += double_ops.0; + total_sub += double_ops.1; + total_mul += double_ops.2; + total_inv += double_ops.3; + + // safe_point_add: like point_add + 1 select on denom + total_add += add_ops.0 + select_ops_per_point.0 / 2; // 1 select + total_sub += add_ops.1 + select_ops_per_point.1 / 2; + total_mul += add_ops.2 + select_ops_per_point.2 / 2; + total_inv += add_ops.3; + + // 2 point_selects + total_add += 2 * select_ops_per_point.0; + total_sub += 2 * select_ops_per_point.1; + total_mul += 2 * select_ops_per_point.2; + + // is_zero(digit) + total_inv += 1; + total_add += 1; + total_mul += 1; + } + + (total_add, total_sub, total_mul, total_inv) +} + +/// Witnesses per single N-limb field operation. 
+fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize { + if is_native { + // Native: no range checks, just standard R1CS witnesses + match op { + FieldOpType::Add => 1, // sum witness + FieldOpType::Sub => 1, // sum witness + FieldOpType::Mul => 1, // product witness + FieldOpType::Inv => 1, // inverse witness + } + } else if num_limbs == 1 { + // Single-limb non-native: reduce_mod_p pattern + match op { + FieldOpType::Add => 5, // a+b, m const, k, k*m, result + FieldOpType::Sub => 5, // same + FieldOpType::Mul => 5, // a*b, m const, k, k*m, result + FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check + } + } else { + // Multi-limb: N-limb operations + let n = num_limbs; + match op { + // add/sub: q + N*(v_offset, carry, r_limb) + N*(v_diff, borrow, d_limb) + FieldOpType::Add | FieldOpType::Sub => 1 + 3 * n + 3 * n, + // mul: hint(4N-2) + N² products + 2N-1 column constraints + lt_check + FieldOpType::Mul => (4 * n - 2) + n * n + 3 * n, + // inv: hint(N) + mul costs + FieldOpType::Inv => n + (4 * n - 2) + n * n + 3 * n, + } + } +} + +/// Total estimated witness cost for one scalar_mul. 
+pub fn calculate_msm_witness_cost( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + window_size: usize, + limb_bits: u32, +) -> usize { + let is_native = curve_modulus_bits == native_field_bits; + let num_limbs = if is_native { + 1 + } else { + ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize) + }; + + let (n_add, n_sub, n_mul, n_inv) = count_scalar_mul_field_ops(scalar_bits, window_size); + + let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, is_native); + let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, is_native); + let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, is_native); + let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, is_native); + + let per_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + + // Scalar decomposition: 256 bits (bit witnesses + digital decomposition overhead) + let scalar_decomp = 256 + 10; + + // Point accumulation: (n_points - 1) point_adds + let accum_per_point = if n_points > 1 { + let accum_adds = n_points - 1; + accum_adds * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Sub, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3 + + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native)) + } else { + 0 + }; + + n_points * (per_scalarmul + scalar_decomp) + accum_per_point +} + +/// Check whether schoolbook column equation values fit in the native field. +/// +/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` via +/// column equations that include product sums, carry offsets, and outgoing carries. +/// Both sides of each column equation must evaluate to less than the native field +/// modulus as **integers** — if they overflow, the field's modular reduction makes +/// `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking soundness. 
+/// +/// The maximum integer value across either side of any column equation is bounded by: +/// +/// `2^(2W + ceil(log2(N)) + 3)` +/// +/// where `W = limb_bits` and `N = num_limbs`. This accounts for: +/// - Up to N cross-products per column, each < `2^(2W)` +/// - The carry offset `2^(2W + ceil(log2(N)) + 1)` (dominant term) +/// - Outgoing carry term `2^W * offset_carry` on the RHS +/// +/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, the +/// conservative soundness condition is: +/// +/// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` +pub fn column_equation_fits_native_field( + native_field_bits: u32, + limb_bits: u32, + num_limbs: usize, +) -> bool { + if num_limbs <= 1 { + return true; // Single-limb path has no column equations. + } + let ceil_log2_n = (num_limbs as f64).log2().ceil() as u32; + // Max column value < 2^(2*limb_bits + ceil_log2_n + 3). + // Need this < p_native >= 2^(native_field_bits - 1). + 2 * limb_bits + ceil_log2_n + 3 < native_field_bits +} + +/// Search for optimal (limb_bits, window_size) minimizing witness cost. +/// +/// Searches limb_bits ∈ [8..max] and window_size ∈ [2..8]. +/// Each candidate is checked for column equation soundness: the schoolbook +/// multiplication's intermediate values must fit in the native field without +/// modular wraparound (see [`column_equation_fits_native_field`]). +pub fn get_optimal_msm_params( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, +) -> (u32, usize) { + let is_native = curve_modulus_bits == native_field_bits; + if is_native { + // For native field, limb_bits doesn't matter (no multi-limb decomposition). + // Just optimize window_size. 
+ let mut best_cost = usize::MAX; + let mut best_window = 4; + for ws in 2..=8 { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + native_field_bits, + ); + if cost < best_cost { + best_cost = cost; + best_window = ws; + } + } + return (native_field_bits, best_window); + } + + // Upper bound on search: even with N=2 (best case), we need + // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) / 2. + // The per-candidate soundness check below is the actual gate. + let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; + let mut best_cost = usize::MAX; + let mut best_limb_bits = max_limb_bits.min(86); + let mut best_window = 4; + + // Search space + for lb in (8..=max_limb_bits).step_by(2) { + let num_limbs = + ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); + if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { + continue; + } + for ws in 2..=8usize { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + lb, + ); + if cost < best_cost { + best_cost = cost; + best_limb_bits = lb; + best_window = ws; + } + } + } + + (best_limb_bits, best_window) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_optimal_params_bn254_native() { + // Grumpkin over BN254: native field + let (limb_bits, window_size) = get_optimal_msm_params(254, 254, 1, 256); + assert_eq!(limb_bits, 254); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_optimal_params_secp256r1() { + // secp256r1 over BN254: 256-bit modulus, non-native + let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256); + let num_limbs = ((256 + limb_bits - 1) / limb_bits) as usize; + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && 
window_size <= 8); + } + + #[test] + fn test_optimal_params_goldilocks() { + // Hypothetical 64-bit field over BN254 + let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64); + let num_limbs = ((64 + limb_bits - 1) / limb_bits) as usize; + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_count_field_ops_sanity() { + let (add, sub, mul, inv) = count_scalar_mul_field_ops(256, 4); + assert!(add > 0); + assert!(sub > 0); + assert!(mul > 0); + assert!(inv > 0); + } + + #[test] + fn test_column_equation_soundness_boundary() { + // For BN254 (254 bits) with N=3: max safe limb_bits is 124. + // 2*124 + ceil(log2(3)) + 3 = 248 + 2 + 3 = 253 < 254 ✓ + assert!(column_equation_fits_native_field(254, 124, 3)); + // 2*125 + ceil(log2(3)) + 3 = 250 + 2 + 3 = 255 ≥ 254 ✗ + assert!(!column_equation_fits_native_field(254, 125, 3)); + // 2*126 + ceil(log2(3)) + 3 = 252 + 2 + 3 = 257 ≥ 254 ✗ + assert!(!column_equation_fits_native_field(254, 126, 3)); + } + + #[test] + fn test_secp256r1_limb_bits_not_126() { + // Regression: limb_bits=126 with N=3 causes offset_w = 2^255 > p_BN254, + // making the schoolbook column equations unsound. 
+ let (limb_bits, _) = get_optimal_msm_params(254, 256, 1, 256); + assert!( + limb_bits <= 124, + "secp256r1 limb_bits={limb_bits} exceeds safe maximum 124" + ); + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index d4d0d247b..53a1340f8 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,9 +1,6 @@ use { - crate::noir_to_r1cs::NoirToR1CSCompiler, - provekit_common::{ - witness::{ConstantTerm, WitnessBuilder}, - FieldElement, - }, + ark_ff::{BigInteger, PrimeField}, + provekit_common::FieldElement, }; pub struct CurveParams { @@ -15,38 +12,91 @@ pub struct CurveParams { } impl CurveParams { - pub fn p_lo_fe(&self) -> FieldElement { - decompose_128(self.field_modulus_p).0 + /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` width each. + pub fn p_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.field_modulus_p, limb_bits, num_limbs) } - pub fn p_hi_fe(&self) -> FieldElement { - decompose_128(self.field_modulus_p).1 + + /// Decompose (p - 1) into `num_limbs` limbs of `limb_bits` width each. + pub fn p_minus_1_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + let p_minus_1 = sub_one_u64_4(&self.field_modulus_p); + decompose_to_limbs(&p_minus_1, limb_bits, num_limbs) + } + + /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` width. 
+ pub fn curve_a_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_a, limb_bits, num_limbs) } - pub fn p_86_limbs(&self) -> [FieldElement; 3] { - let mask_86: u128 = (1u128 << 86) - 1; - let lo128 = self.field_modulus_p[0] as u128 | ((self.field_modulus_p[1] as u128) << 64); - let hi128 = self.field_modulus_p[2] as u128 | ((self.field_modulus_p[3] as u128) << 64); - let l0 = lo128 & mask_86; - // l1 spans bits [86..172): 42 bits from lo128, 44 bits from hi128 - let l1 = ((lo128 >> 86) | (hi128 << 42)) & mask_86; - // l2 = bits [172..256): 84 bits from hi128 - let l2 = hi128 >> 44; - [ - FieldElement::from(l0), - FieldElement::from(l1), - FieldElement::from(l2), - ] + + /// Number of bits in the field modulus. + pub fn modulus_bits(&self) -> u32 { + if self.is_native_field() { + // p mod p = 0 as a field element, so we use the constant directly. + FieldElement::MODULUS_BIT_SIZE + } else { + let fe = curve_native_point_fe(&self.field_modulus_p); + fe.into_bigint().num_bits() + } + } + + /// Returns true if the curve's base field modulus equals the native BN254 + /// scalar field modulus. + pub fn is_native_field(&self) -> bool { + let native_mod = FieldElement::MODULUS; + self.field_modulus_p == native_mod.0 } + + /// Convert modulus to a native field element (only valid when p < native modulus). pub fn p_native_fe(&self) -> FieldElement { curve_native_point_fe(&self.field_modulus_p) } } -/// Splits a 256-bit value ([u64; 4]) into two 128-bit field elements (lo, hi). -fn decompose_128(val: [u64; 4]) -> (FieldElement, FieldElement) { - ( - FieldElement::from((val[0] as u128) | ((val[1] as u128) << 64)), - FieldElement::from((val[2] as u128) | ((val[3] as u128) << 64)), - ) +/// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width each, +/// returned as FieldElements. 
+fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { + let mask: u128 = if limb_bits >= 128 { + u128::MAX + } else { + (1u128 << limb_bits) - 1 + }; + let mut result = vec![FieldElement::from(0u64); num_limbs]; + let mut remaining = *val; + for item in result.iter_mut() { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + *item = FieldElement::from(lo & mask); + // Shift remaining right by limb_bits + if limb_bits >= 256 { + remaining = [0; 4]; + } else { + let mut shifted = [0u64; 4]; + let word_shift = (limb_bits / 64) as usize; + let bit_shift = limb_bits % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = remaining[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= remaining[i + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } + } + result +} + +/// Subtract 1 from a [u64; 4] value. +fn sub_one_u64_4(val: &[u64; 4]) -> [u64; 4] { + let mut result = *val; + for limb in result.iter_mut() { + if *limb > 0 { + *limb -= 1; + return result; + } + *limb = u64::MAX; // borrow + } + result } /// Converts a 256-bit value ([u64; 4]) into a single native field element. @@ -54,21 +104,46 @@ pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { FieldElement::from_sign_and_limbs(true, val) } -#[derive(Clone, Copy, Debug)] -pub struct Limb2 { - pub lo: usize, - pub hi: usize, -} - -pub fn limb2_constant(r1cs_compiler: &mut NoirToR1CSCompiler, value: [u64; 4]) -> Limb2 { - let (lo, hi) = decompose_128(value); - let lo_idx = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(lo_idx, lo))); - let hi_idx = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(hi_idx, hi))); - Limb2 { - lo: lo_idx, - hi: hi_idx, +/// Grumpkin curve parameters. 
+/// +/// Grumpkin is a cycle-companion curve for BN254: its base field is the BN254 +/// scalar field, and its order is the BN254 base field order. +/// +/// Equation: y² = x³ − 17 (a = 0, b = −17 mod p) +pub fn grumpkin_params() -> CurveParams { + CurveParams { + // BN254 scalar field modulus + field_modulus_p: [ + 0x43e1f593f0000001_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // BN254 base field modulus + curve_order_n: [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + curve_a: [0; 4], + // b = −17 mod p + curve_b: [ + 0x43e1f593effffff0_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // Generator G = (1, sqrt(−16) mod p) + generator: ( + [1, 0, 0, 0], + [ + 0x833fc48d823f272c_u64, + 0x2d270d45f1181294_u64, + 0xcf135e7506a45d63_u64, + 0x0000000000000002_u64, + ], + ), } } diff --git a/provekit/r1cs-compiler/src/msm/ec_ops.rs b/provekit/r1cs-compiler/src/msm/ec_ops.rs deleted file mode 100644 index 985937821..000000000 --- a/provekit/r1cs-compiler/src/msm/ec_ops.rs +++ /dev/null @@ -1,208 +0,0 @@ -use { - crate::noir_to_r1cs::NoirToR1CSCompiler, - ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, - provekit_common::{ - witness::{SumTerm, WitnessBuilder}, - FieldElement, - }, - std::collections::BTreeMap, -}; - -/// Reduce the value to given modulus -pub fn reduce_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - value: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - // Reduce mod algorithm : - // v = k * m + result, where 0 <= result < m - // k = floor(v / m) (integer division) - // result = v - k * m - - // Computing k = floor(v / m) - // ----------------------------------------------------------- - // computing m (constant witness for use in constraints) - let m = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( - 
provekit_common::witness::ConstantTerm(m, modulus), - )); - // computing k via integer division - let k = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus)); - - // Computing result = v - k * m - // ----------------------------------------------------------- - // computing k * m - let k_mul_m = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m)); - // constraint: k * m = k_mul_m - r1cs_compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[( - FieldElement::ONE, - k_mul_m, - )]); - // computing result = v - k * m - let result = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![ - SumTerm(Some(FieldElement::ONE), value), - SumTerm(Some(-FieldElement::ONE), k_mul_m), - ])); - // constraint: 1 * (k_mul_m + result) = value - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)], - &[(FieldElement::ONE, value)], - ); - // range check to prove 0 <= result < m - let modulus_bits = modulus.into_bigint().num_bits(); - range_checks - .entry(modulus_bits) - .or_insert_with(Vec::new) - .push(result); - - result -} - -/// a + b mod p -pub fn add_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - b: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - let a_add_b = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![ - SumTerm(Some(FieldElement::ONE), a), - SumTerm(Some(FieldElement::ONE), b), - ])); - // constraint: a + b = a_add_b - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, a), (FieldElement::ONE, b)], - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::ONE, a_add_b)], - ); - reduce_mod_p(r1cs_compiler, a_add_b, modulus, range_checks) -} - -/// a * b 
mod p -pub fn mul_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - b: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - let a_mul_b = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b)); - // constraint: a * b = a_mul_b - r1cs_compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( - FieldElement::ONE, - a_mul_b, - )]); - reduce_mod_p(r1cs_compiler, a_mul_b, modulus, range_checks) -} - -/// (a - b) mod p -pub fn sub_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - b: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - let a_sub_b = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![ - SumTerm(Some(FieldElement::ONE), a), - SumTerm(Some(-FieldElement::ONE), b), - ])); - // constraint: 1 * (a - b) = a_sub_b - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], - &[(FieldElement::ONE, a_sub_b)], - ); - reduce_mod_p(r1cs_compiler, a_sub_b, modulus, range_checks) -} - -/// a^(-1) mod p -pub fn inv_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - // Computing a^(-1) mod m - // ----------------------------------------------------------- - // computing a_inv (the F_m inverse of a) via Fermat's little theorem - let a_inv = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus)); - - // Verifying a * a_inv mod m = 1 - // ----------------------------------------------------------- - // computing a * a_inv mod m - let reduced = mul_mod_p(r1cs_compiler, a, a_inv, modulus, range_checks); - - // constraint: reduced = 1 - // (reduced - 1) * 1 = 0 - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, 
r1cs_compiler.witness_one())], - &[ - (FieldElement::ONE, reduced), - (-FieldElement::ONE, r1cs_compiler.witness_one()), - ], - &[(FieldElement::ZERO, r1cs_compiler.witness_one())], - ); - - // range check: a_inv in [0, 2^bits(m)) - let mod_bits = modulus.into_bigint().num_bits(); - range_checks - .entry(mod_bits) - .or_insert_with(Vec::new) - .push(a_inv); - - a_inv -} - -/// checks if value is zero or not -pub fn compute_is_zero(r1cs_compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { - // calculating v^(-1) - let value_inv = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Inverse(value_inv, value)); - // calculating v * v^(-1) - let value_mul_value_inv = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product( - value_mul_value_inv, - value, - value_inv, - )); - // calculate is_zero = 1 - (v * v^(-1)) - let is_zero = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(provekit_common::witness::WitnessBuilder::Sum( - is_zero, - vec![ - provekit_common::witness::SumTerm(Some(FieldElement::ONE), r1cs_compiler.witness_one()), - provekit_common::witness::SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), - ], - )); - // constraint: v × v^(-1) = 1 - is_zero - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, value)], - &[(FieldElement::ONE, value_inv)], - &[ - (FieldElement::ONE, r1cs_compiler.witness_one()), - (-FieldElement::ONE, is_zero), - ], - ); - // constraint: v × is_zero = 0 - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, value)], - &[(FieldElement::ONE, is_zero)], - &[(FieldElement::ZERO, r1cs_compiler.witness_one())], - ); - is_zero -} diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index d607d25ff..14712c78c 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -99,3 +99,186 @@ pub fn point_select( let y = ops.select(flag, 
on_false.1, on_true.1); (x, y) } + +/// Point addition with safe denominator for the `x1 = x2` edge case. +/// +/// When `x_eq = 1`, the denominator `(x2 - x1)` is zero and cannot be +/// inverted. This function replaces it with 1, producing a satisfiable +/// but meaningless result. The caller MUST discard this result via +/// `point_select` when `x_eq = 1`. +/// +/// The `denom` parameter is the precomputed `x2 - x1`. +fn safe_point_add( + ops: &mut F, + x1: F::Elem, + y1: F::Elem, + x2: F::Elem, + y2: F::Elem, + denom: F::Elem, + x_eq: usize, +) -> (F::Elem, F::Elem) { + let numerator = ops.sub(y2, y1); + + // When x_eq=1 (denom=0), substitute with 1 to keep inv satisfiable + let one = ops.constant_one(); + let safe_denom = ops.select(x_eq, denom, one); + + let denom_inv = ops.inv(safe_denom); + let lambda = ops.mul(numerator, denom_inv); + + let lambda_sq = ops.mul(lambda, lambda); + let x1_plus_x2 = ops.add(x1, x2); + let x3 = ops.sub(lambda_sq, x1_plus_x2); + + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Builds a point table for windowed scalar multiplication. +/// +/// T[0] = P (dummy entry, used when window digit = 0) +/// T[1] = P, T[2] = 2P, T[i] = T[i-1] + P for i >= 3. +fn build_point_table( + ops: &mut F, + px: F::Elem, + py: F::Elem, + table_size: usize, +) -> Vec<(F::Elem, F::Elem)> { + assert!(table_size >= 2); + let mut table = Vec::with_capacity(table_size); + table.push((px, py)); // T[0] = P (dummy) + table.push((px, py)); // T[1] = P + if table_size > 2 { + table.push(point_double(ops, px, py)); // T[2] = 2P + for i in 3..table_size { + let prev = table[i - 1]; + table.push(point_add(ops, prev.0, prev.1, px, py)); + } + } + table +} + +/// Selects T[d] from a point table using bit witnesses, where `d = Σ bits[i] * 2^i`. +/// +/// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, +/// halving the candidate set at each level. 
Total: `(2^w - 1)` point selects +/// for a table of `2^w` entries. +fn table_lookup( + ops: &mut F, + table: &[(F::Elem, F::Elem)], + bits: &[usize], +) -> (F::Elem, F::Elem) { + assert_eq!(table.len(), 1 << bits.len()); + let mut current: Vec<(F::Elem, F::Elem)> = table.to_vec(); + // Process bits from MSB to LSB + for &bit in bits.iter().rev() { + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(point_select(ops, bit, current[i], current[i + half])); + } + current = next; + } + current[0] +} + +/// Windowed scalar multiplication: computes `[scalar] * P`. +/// +/// Takes pre-decomposed scalar bits (LSB first, `scalar_bits[0]` is the +/// least significant bit) and a window size `w`. Precomputes a table of +/// `2^w` point multiples and processes the scalar in `w`-bit windows from +/// MSB to LSB. +/// +/// Handles two edge cases: +/// 1. **MSB window digit = 0**: The accumulator is initialized from T[0] +/// (a dummy copy of P). An `acc_is_identity` flag tracks that no real +/// point has been accumulated yet. When the first non-zero window digit +/// is encountered, the looked-up point becomes the new accumulator. +/// 2. **x-coordinate collision** (`acc.x == looked_up.x`): Uses +/// `point_double` instead of `point_add`, with `safe_point_add` +/// guarding the zero denominator. +/// +/// The inverse-point case (`acc = -looked_up`, result is infinity) cannot +/// be represented in affine coordinates and remains unsupported — this has +/// negligible probability (~2^{-256}) for random scalars. 
+pub fn scalar_mul( + ops: &mut F, + px: F::Elem, + py: F::Elem, + scalar_bits: &[usize], + window_size: usize, +) -> (F::Elem, F::Elem) { + let n = scalar_bits.len(); + let w = window_size; + let table_size = 1 << w; + + // Build point table: T[i] = [i]P, with T[0] = P as dummy + let table = build_point_table(ops, px, py, table_size); + + // Number of windows (ceiling division) + let num_windows = (n + w - 1) / w; + + // Process MSB window first (may be shorter than w bits if n % w != 0) + let msb_start = (num_windows - 1) * w; + let msb_bits = &scalar_bits[msb_start..n]; + let msb_table = &table[..1 << msb_bits.len()]; + let mut acc = table_lookup(ops, msb_table, msb_bits); + + // Track whether acc represents the identity (no real point yet). + // When MSB digit = 0, T[0] = P is loaded as a dummy — we must not + // double or add it until the first non-zero window digit appears. + let msb_digit = ops.pack_bits(msb_bits); + let mut acc_is_identity = ops.is_zero(msb_digit); + + // Process remaining windows from MSB-1 down to LSB + for i in (0..num_windows - 1).rev() { + // w doublings — only meaningful when acc is a real point. + // When acc_is_identity=1, the doubling result is garbage but will + // be discarded by the point_select below. 
+ let mut doubled_acc = acc; + for _ in 0..w { + doubled_acc = point_double(ops, doubled_acc.0, doubled_acc.1); + } + // If acc is identity, keep dummy; otherwise use doubled result + acc = point_select(ops, acc_is_identity, doubled_acc, acc); + + // Table lookup for this window's digit + let window_bits = &scalar_bits[i * w..(i + 1) * w]; + let digit = ops.pack_bits(window_bits); + let digit_is_zero = ops.is_zero(digit); + + let looked_up = table_lookup(ops, &table, window_bits); + + // Detect x-coordinate collision: acc.x == looked_up.x + let denom = ops.sub(looked_up.0, acc.0); + let x_eq = ops.elem_is_zero(denom); + + // point_double handles the acc == looked_up case (same point) + let doubled = point_double(ops, acc.0, acc.1); + + // Safe point_add (substitutes denominator when x_eq=1) + let added = safe_point_add( + ops, acc.0, acc.1, looked_up.0, looked_up.1, denom, x_eq, + ); + + // x_eq=0 => use add result, x_eq=1 => use double result + let combined = point_select(ops, x_eq, added, doubled); + + // Four cases based on (acc_is_identity, digit_is_zero): + // (0, 0) => combined — normal add/double + // (0, 1) => acc — keep accumulator + // (1, 0) => looked_up — first real point + // (1, 1) => acc — still identity + let normal_result = point_select(ops, digit_is_zero, combined, acc); + let identity_result = point_select(ops, digit_is_zero, looked_up, acc); + acc = point_select(ops, acc_is_identity, normal_result, identity_result); + + // Update: acc is identity only if it was identity AND digit is zero + acc_is_identity = ops.bool_and(acc_is_identity, digit_is_zero); + } + + acc +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index a155a6def..dda1e064a 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -1,113 +1,166 @@ +pub mod cost_model; pub mod curve; -pub mod ec_ops; pub mod ec_points; -pub mod wide_ops; +pub mod multi_limb_arith; +pub mod multi_limb_ops; use { - 
crate::noir_to_r1cs::NoirToR1CSCompiler, - ark_ff::Field, - curve::{curve_native_point_fe, limb2_constant, CurveParams, Limb2}, + crate::{ + digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field}, + curve::CurveParams, + multi_limb_ops::{MultiLimbOps, MultiLimbParams}, provekit_common::{ - witness::{ConstantTerm, SumTerm, WitnessBuilder}, + witness::{ConstantOrR1CSWitness, ConstantTerm, SumTerm, WitnessBuilder}, FieldElement, }, std::collections::BTreeMap, }; -pub trait FieldOps { - type Elem: Copy; +// --------------------------------------------------------------------------- +// Limbs: fixed-capacity, Copy array of witness indices +// --------------------------------------------------------------------------- - fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn inv(&mut self, a: Self::Elem) -> Self::Elem; - fn curve_a(&mut self) -> Self::Elem; +/// Maximum number of limbs supported. Covers all practical field sizes +/// (e.g. a 512-bit modulus with 16-bit limbs = 32 limbs). +pub const MAX_LIMBS: usize = 32; - /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if - /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). - fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem; -} - -/// Narrow field operations for curves where p fits in BN254's scalar field. -/// Operates on single witness indices (`usize`). -pub struct NarrowOps<'a> { - pub compiler: &'a mut NoirToR1CSCompiler, - pub range_checks: &'a mut BTreeMap>, - pub modulus: FieldElement, - pub params: &'a CurveParams, +/// A fixed-capacity array of witness indices, indexed by limb position. 
+/// +/// This type is `Copy`, so it can be used as `FieldOps::Elem` without +/// requiring const generics or dispatch macros. The runtime `len` field +/// tracks how many limbs are actually in use. +#[derive(Clone, Copy)] +pub struct Limbs { + data: [usize; MAX_LIMBS], + len: usize, } -impl FieldOps for NarrowOps<'_> { - type Elem = usize; +impl Limbs { + /// Sentinel value for uninitialized limb slots. Using `usize::MAX` + /// ensures accidental use of an unfilled slot indexes an absurdly + /// large witness, causing an immediate out-of-bounds panic. + const UNINIT: usize = usize::MAX; - fn add(&mut self, a: usize, b: usize) -> usize { - ec_ops::add_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + /// Create a new `Limbs` with `len` limbs, all initialized to `UNINIT`. + pub fn new(len: usize) -> Self { + assert!( + len > 0 && len <= MAX_LIMBS, + "limb count must be 1..={MAX_LIMBS}, got {len}" + ); + Self { + data: [Self::UNINIT; MAX_LIMBS], + len, + } } - fn sub(&mut self, a: usize, b: usize) -> usize { - ec_ops::sub_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + /// Create a single-limb `Limbs` wrapping one witness index. + pub fn single(value: usize) -> Self { + let mut l = Self { + data: [Self::UNINIT; MAX_LIMBS], + len: 1, + }; + l.data[0] = value; + l } - fn mul(&mut self, a: usize, b: usize) -> usize { - ec_ops::mul_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + /// Create `Limbs` from a slice of witness indices. + pub fn from_slice(s: &[usize]) -> Self { + assert!( + !s.is_empty() && s.len() <= MAX_LIMBS, + "slice length must be 1..={MAX_LIMBS}, got {}", + s.len() + ); + let mut data = [Self::UNINIT; MAX_LIMBS]; + data[..s.len()].copy_from_slice(s); + Self { data, len: s.len() } } - fn inv(&mut self, a: usize) -> usize { - ec_ops::inv_mod_p(self.compiler, a, self.modulus, self.range_checks) + /// View the active limbs as a slice. 
+ pub fn as_slice(&self) -> &[usize] { + &self.data[..self.len] } - fn curve_a(&mut self) -> usize { - let a_fe = curve_native_point_fe(&self.params.curve_a); - let w = self.compiler.num_witnesses(); - self.compiler - .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, a_fe))); - w + /// Number of active limbs. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.len } +} - fn select(&mut self, flag: usize, on_false: usize, on_true: usize) -> usize { - constrain_boolean(self.compiler, flag); - select_witness(self.compiler, flag, on_false, on_true) +impl std::fmt::Debug for Limbs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list().entries(self.as_slice().iter()).finish() } } -/// Wide field operations for curves where p > BN254_r (e.g. secp256r1). -/// Operates on `Limb2` (two 128-bit limbs). -pub struct WideOps<'a> { - pub compiler: &'a mut NoirToR1CSCompiler, - pub range_checks: &'a mut BTreeMap>, - pub params: &'a CurveParams, +impl PartialEq for Limbs { + fn eq(&self, other: &Self) -> bool { + self.len == other.len && self.data[..self.len] == other.data[..other.len] + } } +impl Eq for Limbs {} -impl FieldOps for WideOps<'_> { - type Elem = Limb2; - - fn add(&mut self, a: Limb2, b: Limb2) -> Limb2 { - wide_ops::add_mod_p(self.compiler, self.range_checks, a, b, self.params) +impl std::ops::Index for Limbs { + type Output = usize; + fn index(&self, i: usize) -> &usize { + debug_assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &self.data[i] } +} - fn sub(&mut self, a: Limb2, b: Limb2) -> Limb2 { - wide_ops::sub_mod_p(self.compiler, self.range_checks, a, b, self.params) +impl std::ops::IndexMut for Limbs { + fn index_mut(&mut self, i: usize) -> &mut usize { + debug_assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &mut self.data[i] } +} - fn mul(&mut self, a: Limb2, b: Limb2) -> Limb2 { - 
wide_ops::mul_mod_p(self.compiler, self.range_checks, a, b, self.params) - } +// --------------------------------------------------------------------------- +// FieldOps trait +// --------------------------------------------------------------------------- - fn inv(&mut self, a: Limb2) -> Limb2 { - wide_ops::inv_mod_p(self.compiler, self.range_checks, a, self.params) - } +pub trait FieldOps { + type Elem: Copy; - fn curve_a(&mut self) -> Limb2 { - limb2_constant(self.compiler, self.params.curve_a) - } + fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn inv(&mut self, a: Self::Elem) -> Self::Elem; + fn curve_a(&mut self) -> Self::Elem; - fn select(&mut self, flag: usize, on_false: Limb2, on_true: Limb2) -> Limb2 { - constrain_boolean(self.compiler, flag); - Limb2 { - lo: select_witness(self.compiler, flag, on_false.lo, on_true.lo), - hi: select_witness(self.compiler, flag, on_false.hi, on_true.hi), - } - } + /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if + /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). + fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem; + + /// Checks if a BN254 native witness value is zero. + /// Returns a boolean witness: 1 if zero, 0 if non-zero. + fn is_zero(&mut self, value: usize) -> usize; + + /// Packs bit witnesses into a single digit witness: `d = Σ bits[i] * 2^i`. + /// Does NOT constrain bits to be boolean — caller must ensure that. + fn pack_bits(&mut self, bits: &[usize]) -> usize; + + /// Checks if a field element (in the curve's base field) is zero. + /// Returns a boolean witness: 1 if zero, 0 if non-zero. + fn elem_is_zero(&mut self, value: Self::Elem) -> usize; + + /// Returns the constant field element 1. 
+ fn constant_one(&mut self) -> Self::Elem; + + /// Computes a * b for two boolean (0/1) native witnesses. + /// Used for boolean AND on flags in scalar_mul. + fn bool_and(&mut self, a: usize, b: usize) -> usize; } // --------------------------------------------------------------------------- @@ -115,7 +168,7 @@ impl FieldOps for WideOps<'_> { // --------------------------------------------------------------------------- /// Constrains `flag` to be boolean: `flag * flag = flag`. -fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { +pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { compiler.r1cs.add_constraint( &[(FieldElement::ONE, flag)], &[(FieldElement::ONE, flag)], @@ -125,15 +178,18 @@ fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { /// Single-witness conditional select: `out = on_false + flag * (on_true - /// on_false)`. -/// -/// Produces 3 witnesses and 3 R1CS constraints (diff, flag*diff, out). -/// Does NOT constrain `flag` to be boolean — caller must do that separately. -fn select_witness( +pub(crate) fn select_witness( compiler: &mut NoirToR1CSCompiler, flag: usize, on_false: usize, on_true: usize, ) -> usize { + // When both branches are the same witness, result is trivially that witness. + // Avoids duplicate column indices in R1CS from `on_true - on_false` when + // both share the same witness index. + if on_false == on_true { + return on_false; + } let diff = compiler.add_sum(vec![ SumTerm(None, on_true), SumTerm(Some(-FieldElement::ONE), on_false), @@ -141,3 +197,301 @@ fn select_witness( let flag_diff = compiler.add_product(flag, diff); compiler.add_sum(vec![SumTerm(None, on_false), SumTerm(None, flag_diff)]) } + +/// Packs bit witnesses into a digit: `d = Σ bits[i] * 2^i`. 
+pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { + let terms: Vec = bits + .iter() + .enumerate() + .map(|(i, &bit)| SumTerm(Some(FieldElement::from(1u128 << i)), bit)) + .collect(); + compiler.add_sum(terms) +} + +// --------------------------------------------------------------------------- +// Params builder (runtime num_limbs, no const generics) +// --------------------------------------------------------------------------- + +/// Build `MultiLimbParams` for a given runtime `num_limbs`. +fn build_params(num_limbs: usize, limb_bits: u32, curve: &CurveParams) -> MultiLimbParams { + let is_native = curve.is_native_field(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if !is_native { + Some(curve.p_native_fe()) + } else { + None + }; + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: curve.p_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.field_modulus_p, + curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + modulus_bits: curve.modulus_bits(), + is_native, + modulus_fe, + } +} + +// --------------------------------------------------------------------------- +// MSM entry point +// --------------------------------------------------------------------------- + +/// Processes all deferred MSM operations. 
+/// +/// Each entry is `(points, scalars, (out_x, out_y, out_inf))` where: +/// - `points` has layout `[x1, y1, inf1, x2, y2, inf2, ...]` (3 per point) +/// - `scalars` has layout `[s1_lo, s1_hi, s2_lo, s2_hi, ...]` (2 per scalar) +/// - outputs are the R1CS witness indices for the result point +pub fn add_msm( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + (usize, usize, usize), + )>, + limb_bits: u32, + window_size: usize, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + for (points, scalars, outputs) in msm_ops { + add_single_msm( + compiler, + &points, + &scalars, + outputs, + limb_bits, + window_size, + range_checks, + curve, + ); + } +} + +/// Processes a single MSM operation. +fn add_single_msm( + compiler: &mut NoirToR1CSCompiler, + points: &[ConstantOrR1CSWitness], + scalars: &[ConstantOrR1CSWitness], + outputs: (usize, usize, usize), + limb_bits: u32, + window_size: usize, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + assert!( + points.len() % 3 == 0, + "points length must be a multiple of 3" + ); + let n = points.len() / 3; + assert_eq!( + scalars.len(), + 2 * n, + "scalars length must be 2x the number of points" + ); + + // Resolve all inputs to witness indices + let point_wits: Vec = points.iter().map(|p| resolve_input(compiler, p)).collect(); + let scalar_wits: Vec = scalars.iter().map(|s| resolve_input(compiler, s)).collect(); + + let is_native = curve.is_native_field(); + let num_limbs = if is_native { + 1 + } else { + (curve.modulus_bits() as usize + limb_bits as usize - 1) / limb_bits as usize + }; + + process_single_msm( + compiler, + &point_wits, + &scalar_wits, + outputs, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); +} + +/// Process a full single-MSM with runtime `num_limbs`. +/// +/// Handles coordinate decomposition, scalar_mul, accumulation, and +/// output constraining. 
+fn process_single_msm<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: (usize, usize, usize), + num_limbs: usize, + limb_bits: u32, + window_size: usize, + mut range_checks: &'a mut BTreeMap>, + curve: &CurveParams, +) { + let n_points = point_wits.len() / 3; + let mut acc: Option<(Limbs, Limbs)> = None; + + for i in 0..n_points { + let px_witness = point_wits[3 * i]; + let py_witness = point_wits[3 * i + 1]; + + let s_lo = scalar_wits[2 * i]; + let s_hi = scalar_wits[2 * i + 1]; + let scalar_bits = decompose_scalar_bits(compiler, s_lo, s_hi); + + // Build coordinates as Limbs + let (px, py) = if num_limbs == 1 { + // Single-limb: wrap witness directly + (Limbs::single(px_witness), Limbs::single(py_witness)) + } else { + // Multi-limb: decompose single witness into num_limbs limbs + let px_limbs = decompose_witness_to_limbs( + compiler, + px_witness, + limb_bits, + num_limbs, + range_checks, + ); + let py_limbs = decompose_witness_to_limbs( + compiler, + py_witness, + limb_bits, + num_limbs, + range_checks, + ); + (px_limbs, py_limbs) + }; + + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + let result = ec_points::scalar_mul(&mut ops, px, py, &scalar_bits, window_size); + compiler = ops.compiler; + range_checks = ops.range_checks; + + acc = Some(match acc { + None => result, + Some((ax, ay)) => { + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + let sum = ec_points::point_add(&mut ops, ax, ay, result.0, result.1); + compiler = ops.compiler; + range_checks = ops.range_checks; + sum + } + }); + } + + let (computed_x, computed_y) = acc.expect("MSM must have at least one point"); + let (out_x, out_y, out_inf) = outputs; + + if num_limbs == 1 { + constrain_equal(compiler, out_x, computed_x[0]); + constrain_equal(compiler, out_y, computed_y[0]); 
+ } else { + let recomposed_x = recompose_limbs(compiler, computed_x.as_slice(), limb_bits); + let recomposed_y = recompose_limbs(compiler, computed_y.as_slice(), limb_bits); + constrain_equal(compiler, out_x, recomposed_x); + constrain_equal(compiler, out_y, recomposed_y); + } + constrain_zero(compiler, out_inf); +} + +/// Decompose a single witness into `num_limbs` limbs using digital +/// decomposition. +fn decompose_witness_to_limbs( + compiler: &mut NoirToR1CSCompiler, + witness: usize, + limb_bits: u32, + num_limbs: usize, + range_checks: &mut BTreeMap>, +) -> Limbs { + let log_bases = vec![limb_bits as usize; num_limbs]; + let dd = add_digital_decomposition(compiler, log_bases, vec![witness]); + let mut limbs = Limbs::new(num_limbs); + for i in 0..num_limbs { + limbs[i] = dd.get_digit_witness_index(i, 0); + // Range-check each decomposed limb to [0, 2^limb_bits). + // add_digital_decomposition constrains the recomposition but does + // NOT range-check individual digits. + range_checks.entry(limb_bits).or_default().push(limbs[i]); + } + limbs +} + +/// Recompose limbs back into a single witness: val = Σ limb[i] * +/// 2^(i*limb_bits) +fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits: u32) -> usize { + let terms: Vec = limbs + .iter() + .enumerate() + .map(|(i, &limb)| { + let coeff = FieldElement::from(2u64).pow([(i as u64) * (limb_bits as u64)]); + SumTerm(Some(coeff), limb) + }) + .collect(); + compiler.add_sum(terms) +} + +/// Resolves a `ConstantOrR1CSWitness` to a witness index. +fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitness) -> usize { + match input { + ConstantOrR1CSWitness::Witness(idx) => *idx, + ConstantOrR1CSWitness::Constant(value) => { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, *value))); + w + } + } +} + +/// Decomposes a scalar given as two 128-bit limbs into 256 bit witnesses (LSB +/// first). 
+fn decompose_scalar_bits( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, +) -> Vec { + let log_bases_128 = vec![1usize; 128]; + + let dd_lo = add_digital_decomposition(compiler, log_bases_128.clone(), vec![s_lo]); + let dd_hi = add_digital_decomposition(compiler, log_bases_128, vec![s_hi]); + + let mut bits = Vec::with_capacity(256); + for bit_idx in 0..128 { + bits.push(dd_lo.get_digit_witness_index(bit_idx, 0)); + } + for bit_idx in 0..128 { + bits.push(dd_hi.get_digit_witness_index(bit_idx, 0)); + } + bits +} + +/// Constrains two witnesses to be equal: `a - b = 0`. +fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} + +/// Constrains a witness to be zero: `w = 0`. +fn constrain_zero(compiler: &mut NoirToR1CSCompiler, w: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs new file mode 100644 index 000000000..ab84fc9b7 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -0,0 +1,591 @@ +//! N-limb modular arithmetic for EC field operations. +//! +//! Replaces both `ec_ops.rs` (N=1 path) and `wide_ops.rs` (N>1 path) with +//! unified multi-limb operations using `Limbs` (runtime-sized, Copy). 
+ +use { + super::Limbs, + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +// --------------------------------------------------------------------------- +// N=1 single-limb path (moved from ec_ops.rs) +// --------------------------------------------------------------------------- + +/// Reduce the value to given modulus (N=1 path). +/// Computes v = k*m + result, where 0 <= result < m. +pub fn reduce_mod_p( + compiler: &mut NoirToR1CSCompiler, + value: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let m = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(m, modulus), + )); + let k = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus)); + + let k_mul_m = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[( + FieldElement::ONE, + k_mul_m, + )]); + + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![ + SumTerm(Some(FieldElement::ONE), value), + SumTerm(Some(-FieldElement::ONE), k_mul_m), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)], + &[(FieldElement::ONE, value)], + ); + + let modulus_bits = modulus.into_bigint().num_bits(); + range_checks + .entry(modulus_bits) + .or_default() + .push(result); + + result +} + +/// a + b mod p (N=1 path) +pub fn add_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_add_b = compiler.num_witnesses(); + 
compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(FieldElement::ONE), b), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, a), (FieldElement::ONE, b)], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a_add_b)], + ); + reduce_mod_p(compiler, a_add_b, modulus, range_checks) +} + +/// a * b mod p (N=1 path) +pub fn mul_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_mul_b = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( + FieldElement::ONE, + a_mul_b, + )]); + reduce_mod_p(compiler, a_mul_b, modulus, range_checks) +} + +/// (a - b) mod p (N=1 path) +pub fn sub_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_sub_b = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(-FieldElement::ONE), b), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, a_sub_b)], + ); + reduce_mod_p(compiler, a_sub_b, modulus, range_checks) +} + +/// a^(-1) mod p (N=1 path) +pub fn inv_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus)); + + let reduced = mul_mod_p_single(compiler, a, a_inv, modulus, range_checks); + + // Constrain reduced = 1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[ + 
(FieldElement::ONE, reduced), + (-FieldElement::ONE, compiler.witness_one()), + ], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + + let mod_bits = modulus.into_bigint().num_bits(); + range_checks.entry(mod_bits).or_default().push(a_inv); + + a_inv +} + +/// Checks if value is zero or not (used by all N values). +/// Returns a boolean witness: 1 if zero, 0 if non-zero. +/// +/// Uses SafeInverse (not Inverse) because the input value may be zero. +/// SafeInverse outputs 0 when the input is 0, and is solved in the Other +/// layer (not batch-inverted), so zero inputs don't poison the batch. +pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { + let value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SafeInverse(value_inv, value)); + + let value_mul_value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product( + value_mul_value_inv, + value, + value_inv, + )); + + let is_zero = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum( + is_zero, + vec![ + SumTerm(Some(FieldElement::ONE), compiler.witness_one()), + SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ], + )); + + // v × v^(-1) = 1 - is_zero + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, value_inv)], + &[ + (FieldElement::ONE, compiler.witness_one()), + (-FieldElement::ONE, is_zero), + ], + ); + // v × is_zero = 0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, is_zero)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + + is_zero +} + +// --------------------------------------------------------------------------- +// N≥2 multi-limb path (generalization of wide_ops.rs) +// --------------------------------------------------------------------------- + +/// (a + b) mod p for multi-limb values. 
+/// Per limb i: v_i = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1}, with
+/// carry_i = floor(v_i / 2^W) and r[i] = v_i - carry_i * 2^W.
+/// TODO(review): the final limb's true carry (carry_{n-1} - 1) is never
+/// constrained to 0, so a+b = q*p + r + c*2^(n*W) admits c != 0 — verify.
+pub fn add_mod_p_multi(
+    compiler: &mut NoirToR1CSCompiler,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    a: Limbs,
+    b: Limbs,
+    p_limbs: &[FieldElement],
+    p_minus_1_limbs: &[FieldElement],
+    two_pow_w: FieldElement,
+    limb_bits: u32,
+    modulus_raw: &[u64; 4],
+) -> Limbs {
+    let n = a.len();
+    assert!(n >= 2, "add_mod_p_multi requires n >= 2, got n={n}");
+    let w1 = compiler.witness_one();
+
+    // Witness: q = floor((a + b) / p) ∈ {0, 1}
+    let q = compiler.num_witnesses();
+    compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient {
+        output: q,
+        a_limbs: a.as_slice().to_vec(),
+        b_limbs: b.as_slice().to_vec(),
+        modulus: *modulus_raw,
+        limb_bits,
+        num_limbs: n as u32,
+    });
+    // q is boolean
+    compiler.r1cs.add_constraint(
+        &[(FieldElement::ONE, q)],
+        &[(FieldElement::ONE, q)],
+        &[(FieldElement::ONE, q)],
+    );
+
+    let mut r = Limbs::new(n);
+    let mut carry_prev: Option<usize> = None;
+
+    for i in 0..n {
+        // v_offset = a[i] + b[i] + 2^W - q*p[i] (+ carry_{i-1} - 1 after limb 0)
+        let mut terms = vec![
+            SumTerm(None, a[i]),
+            SumTerm(None, b[i]),
+            SumTerm(Some(two_pow_w), w1),
+            SumTerm(Some(-p_limbs[i]), q),
+        ];
+        if let Some(carry) = carry_prev {
+            terms.push(SumTerm(None, carry));
+            // Compensate for previous 2^W offset
+            terms.push(SumTerm(Some(-FieldElement::ONE), w1));
+        }
+        let v_offset = compiler.add_sum(terms);
+
+        // carry = floor(v_offset / 2^W)
+        let carry = compiler.num_witnesses();
+        compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(
+            carry, v_offset, two_pow_w,
+        ));
+        // r[i] = v_offset - carry * 2^W
+        r[i] = compiler.add_sum(vec![
+            SumTerm(None, v_offset),
+            SumTerm(Some(-two_pow_w), carry),
+        ]);
+        carry_prev = Some(carry);
+    }
+
+    less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits);
+
+    r
+}
+
+/// (a - b) mod p for multi-limb values.
+pub fn sub_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "sub_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + // Witness: q = (a < b) ? 1 : 0 + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + // q is boolean + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, q)], + &[(FieldElement::ONE, q)], + &[(FieldElement::ONE, q)], + ); + + let mut r = Limbs::new(n); + let mut carry_prev: Option = None; + + for i in 0..n { + // v_offset = a[i] - b[i] + q*p[i] + 2^W + carry_{i-1} + let mut terms = vec![ + SumTerm(None, a[i]), + SumTerm(Some(-FieldElement::ONE), b[i]), + SumTerm(Some(p_limbs[i]), q), + SumTerm(Some(two_pow_w), w1), + ]; + if let Some(carry) = carry_prev { + terms.push(SumTerm(None, carry)); + terms.push(SumTerm(Some(-FieldElement::ONE), w1)); + } + let v_offset = compiler.add_sum(terms); + + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + carry, v_offset, two_pow_w, + )); + r[i] = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_pow_w), carry), + ]); + carry_prev = Some(carry); + } + + less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits); + + r +} + +/// (a * b) mod p for multi-limb values using schoolbook multiplication. +/// +/// Verifies: a·b = p·q + r in base W = 2^limb_bits. 
+/// Column k: Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + OFFSET +/// = Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W +pub fn mul_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "mul_mod_p_multi requires n >= 2, got n={n}"); + + // Soundness check: column equation values must not overflow the native field. + // The maximum value across either side of any column equation is bounded by + // 2^(2*limb_bits + ceil(log2(n)) + 3). This must be strictly less than the + // native field modulus p >= 2^(MODULUS_BIT_SIZE - 1). + { + let ceil_log2_n = (n as f64).log2().ceil() as u32; + let max_bits = 2 * limb_bits + ceil_log2_n + 3; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs \ + requires {max_bits} bits, but native field is only {} bits. 
\ + Use smaller limb_bits.", + FieldElement::MODULUS_BIT_SIZE, + ); + } + + let w1 = compiler.witness_one(); + let num_carries = 2 * n - 2; + // Carry offset: 2^(limb_bits + ceil(log2(n)) + 1) + let extra_bits = ((n as f64).log2().ceil() as u32) + 1; + let carry_offset_bits = limb_bits + extra_bits; + let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); + // offset_w = carry_offset * 2^limb_bits + let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); + // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits - 1) + let offset_w_minus_carry = offset_w - carry_offset_fe; + + // Step 1: Allocate hint witnesses (q limbs, r limbs, carries) + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbMulModHint { + output_start: os, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + + // q[0..n), r[n..2n), carries[2n..4n-2) + let q: Vec = (0..n).map(|i| os + i).collect(); + let r_indices: Vec = (0..n).map(|i| os + n + i).collect(); + let cu: Vec = (0..num_carries).map(|i| os + 2 * n + i).collect(); + + // Step 2: Product witnesses for a[i]*b[j] (n² R1CS constraints) + let mut ab_products = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + ab_products[i][j] = compiler.add_product(a[i], b[j]); + } + } + + // Step 3: Column equations (2n-1 R1CS constraints) + for k in 0..(2 * n - 1) { + // LHS: Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + OFFSET + let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((FieldElement::ONE, ab_products[i][j_val as usize])); + } + } + // Add carry_{k-1} + if k > 0 { + lhs_terms.push((FieldElement::ONE, cu[k - 1])); + // Add offset_w - carry_offset for subsequent columns + lhs_terms.push((offset_w_minus_carry, w1)); + } else 
{ + // First column: add offset_w + lhs_terms.push((offset_w, w1)); + } + + // RHS: Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W + let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + rhs_terms.push((p_limbs[i], q[j_val as usize])); + } + } + if k < n { + rhs_terms.push((FieldElement::ONE, r_indices[k])); + } + if k < 2 * n - 2 { + rhs_terms.push((two_pow_w, cu[k])); + } else { + // Last column: RHS includes offset_w to balance the LHS offset + // LHS has: carry[k-1] + offset_w_minus_carry = true_carry + offset_w + // RHS needs: sum_pq[k] + offset_w (no outgoing carry at last column) + rhs_terms.push((offset_w, w1)); + } + + compiler + .r1cs + .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms); + } + + // Step 4: less-than-p check and range checks on r + let mut r_limbs = Limbs::new(n); + for (i, &ri) in r_indices.iter().enumerate() { + r_limbs[i] = ri; + } + less_than_p_check_multi(compiler, range_checks, r_limbs, p_minus_1_limbs, two_pow_w, limb_bits); + + // Step 5: Range checks for q limbs and carries + for i in 0..n { + range_checks.entry(limb_bits).or_default().push(q[i]); + } + // Carry range: limb_bits + extra_bits + 1 (carry_offset_bits + 1) + let carry_range_bits = carry_offset_bits + 1; + for &c in &cu { + range_checks.entry(carry_range_bits).or_default().push(c); + } + + r_limbs +} + +/// a^(-1) mod p for multi-limb values. +/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, ..., 0]. 
+pub fn inv_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "inv_mod_p_multi requires n >= 2, got n={n}"); + + // Hint: compute inverse + let inv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbModularInverse { + output_start: inv_start, + a_limbs: a.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + let mut inv = Limbs::new(n); + for i in 0..n { + inv[i] = inv_start + i; + } + + // Verify: a * inv mod p = [1, 0, ..., 0] + let product = mul_mod_p_multi( + compiler, + range_checks, + a, + inv, + p_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + modulus_raw, + ); + + // Constrain product[0] = 1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product[0])], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, compiler.witness_one())], + ); + // Constrain product[1..n] = 0 + for i in 1..n { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product[i])], + &[(FieldElement::ONE, compiler.witness_one())], + &[], + ); + } + + inv +} + +/// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. 
+/// Uses borrow propagation: d[i] = (p-1)[i] - r[i] + borrow_in - borrow_out * 2^W
+fn less_than_p_check_multi(
+    compiler: &mut NoirToR1CSCompiler,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    r: Limbs,
+    p_minus_1_limbs: &[FieldElement],
+    two_pow_w: FieldElement,
+    limb_bits: u32,
+) {
+    let n = r.len();
+    let w1 = compiler.witness_one();
+    let mut borrow_prev: Option<usize> = None;
+
+    for i in 0..n {
+        // v_diff = (p-1)[i] + 2^W - r[i] + borrow_prev
+        let p_minus_1_plus_offset = p_minus_1_limbs[i] + two_pow_w;
+        let mut terms = vec![
+            SumTerm(Some(p_minus_1_plus_offset), w1),
+            SumTerm(Some(-FieldElement::ONE), r[i]),
+        ];
+        if let Some(borrow) = borrow_prev {
+            terms.push(SumTerm(None, borrow));
+            terms.push(SumTerm(Some(-FieldElement::ONE), w1));
+        }
+        let v_diff = compiler.add_sum(terms);
+
+        // borrow = floor(v_diff / 2^W)
+        let borrow = compiler.num_witnesses();
+        compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(
+            borrow, v_diff, two_pow_w,
+        ));
+        // d[i] = v_diff - borrow * 2^W
+        let d_i = compiler.add_sum(vec![
+            SumTerm(None, v_diff),
+            SumTerm(Some(-two_pow_w), borrow),
+        ]);
+
+        // Range check r[i] and d[i]
+        range_checks.entry(limb_bits).or_default().push(r[i]);
+        range_checks.entry(limb_bits).or_default().push(d_i);
+
+        borrow_prev = Some(borrow);
+    }
+
+    // Constrain final borrow = 0: if borrow_out != 0, then r > p-1 (i.e. r >= p),
+    // which would mean the result is not properly reduced.
+    if let Some(final_borrow) = borrow_prev {
+        compiler.r1cs.add_constraint(
+            &[(FieldElement::ONE, compiler.witness_one())],
+            &[(FieldElement::ONE, final_borrow)],
+            &[(FieldElement::ZERO, compiler.witness_one())],
+        );
+    }
+}
diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs
new file mode 100644
index 000000000..4f1c45448
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs
@@ -0,0 +1,275 @@
+//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime limb count.
+//!
+//! Uses `Limbs` (a fixed-capacity Copy type) as `FieldOps::Elem`, enabling
+//! arbitrary limb counts without const generics or dispatch macros.
+
+use {
+    super::{
+        multi_limb_arith,
+        Limbs,
+        FieldOps,
+    },
+    crate::noir_to_r1cs::NoirToR1CSCompiler,
+    ark_ff::{AdditiveGroup, Field},
+    provekit_common::{
+        witness::{ConstantTerm, SumTerm, WitnessBuilder},
+        FieldElement,
+    },
+    std::collections::BTreeMap,
+};
+
+/// Parameters for multi-limb field arithmetic.
+pub struct MultiLimbParams {
+    pub num_limbs: usize,
+    pub limb_bits: u32,
+    pub p_limbs: Vec<FieldElement>,
+    pub p_minus_1_limbs: Vec<FieldElement>,
+    pub two_pow_w: FieldElement,
+    pub modulus_raw: [u64; 4],
+    pub curve_a_limbs: Vec<FieldElement>,
+    pub modulus_bits: u32,
+    /// p = native field → skip mod reduction
+    pub is_native: bool,
+    /// For N=1 non-native: the modulus as a single FieldElement
+    pub modulus_fe: Option<FieldElement>,
+}
+
+/// Unified field operations struct parameterized by runtime limb count.
+pub struct MultiLimbOps<'a> {
+    pub compiler: &'a mut NoirToR1CSCompiler,
+    pub range_checks: &'a mut BTreeMap<u32, Vec<usize>>,
+    pub params: MultiLimbParams,
+}
+
+impl MultiLimbOps<'_> {
+    fn is_native_single(&self) -> bool {
+        self.params.num_limbs == 1 && self.params.is_native
+    }
+
+    fn is_non_native_single(&self) -> bool {
+        self.params.num_limbs == 1 && !self.params.is_native
+    }
+
+    fn n(&self) -> usize {
+        self.params.num_limbs
+    }
+}
+
+impl FieldOps for MultiLimbOps<'_> {
+    type Elem = Limbs;
+
+    fn add(&mut self, a: Limbs, b: Limbs) -> Limbs {
+        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
+        debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs");
+        if self.is_native_single() {
+            // When both operands are the same witness, merge into a single
+            // term with coefficient 2 to avoid duplicate column indices in
+            // the R1CS sparse matrix (set overwrites on duplicate (row,col)).
+ let r = if a[0] == b[0] { + self.compiler.add_sum(vec![ + SumTerm(Some(FieldElement::from(2u64)), a[0]), + ]) + } else { + self.compiler.add_sum(vec![ + SumTerm(None, a[0]), + SumTerm(None, b[0]), + ]) + }; + Limbs::single(r) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + let r = multi_limb_arith::add_mod_p_single( + self.compiler, a[0], b[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::add_mod_p_multi( + self.compiler, + self.range_checks, + a, + b, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn sub(&mut self, a: Limbs, b: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); + if self.is_native_single() { + // When both operands are the same witness, a - a = 0. Use a + // single zero-coefficient term to avoid duplicate column indices. 
+ let r = if a[0] == b[0] { + self.compiler.add_sum(vec![ + SumTerm(Some(FieldElement::ZERO), a[0]), + ]) + } else { + self.compiler.add_sum(vec![ + SumTerm(None, a[0]), + SumTerm(Some(-FieldElement::ONE), b[0]), + ]) + }; + Limbs::single(r) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + let r = multi_limb_arith::sub_mod_p_single( + self.compiler, a[0], b[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::sub_mod_p_multi( + self.compiler, + self.range_checks, + a, + b, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn mul(&mut self, a: Limbs, b: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); + if self.is_native_single() { + let r = self.compiler.add_product(a[0], b[0]); + Limbs::single(r) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + let r = multi_limb_arith::mul_mod_p_single( + self.compiler, a[0], b[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::mul_mod_p_multi( + self.compiler, + self.range_checks, + a, + b, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn inv(&mut self, a: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + if self.is_native_single() { + let a_inv = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Inverse(a_inv, a[0])); + // a * a_inv = 1 + self.compiler.r1cs.add_constraint( + &[(FieldElement::ONE, a[0])], + &[(FieldElement::ONE, a_inv)], + &[(FieldElement::ONE, self.compiler.witness_one())], + ); + Limbs::single(a_inv) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + 
let r = multi_limb_arith::inv_mod_p_single( + self.compiler, a[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::inv_mod_p_multi( + self.compiler, + self.range_checks, + a, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn curve_a(&mut self) -> Limbs { + let n = self.n(); + let mut out = Limbs::new(n); + for i in 0..n { + let w = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + self.params.curve_a_limbs[i], + ))); + out[i] = w; + } + out + } + + fn select( + &mut self, + flag: usize, + on_false: Limbs, + on_true: Limbs, + ) -> Limbs { + super::constrain_boolean(self.compiler, flag); + let n = self.n(); + let mut out = Limbs::new(n); + for i in 0..n { + out[i] = super::select_witness(self.compiler, flag, on_false[i], on_true[i]); + } + out + } + + fn is_zero(&mut self, value: usize) -> usize { + multi_limb_arith::compute_is_zero(self.compiler, value) + } + + fn pack_bits(&mut self, bits: &[usize]) -> usize { + super::pack_bits_helper(self.compiler, bits) + } + + fn elem_is_zero(&mut self, value: Limbs) -> usize { + let n = self.n(); + if n == 1 { + multi_limb_arith::compute_is_zero(self.compiler, value[0]) + } else { + // Check each limb is zero and AND the results together + let mut result = multi_limb_arith::compute_is_zero(self.compiler, value[0]); + for i in 1..n { + let limb_zero = multi_limb_arith::compute_is_zero(self.compiler, value[i]); + result = self.compiler.add_product(result, limb_zero); + } + result + } + } + + fn constant_one(&mut self) -> Limbs { + let n = self.n(); + let mut out = Limbs::new(n); + // limb[0] = 1 + let w0 = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w0, FieldElement::ONE))); + out[0] = w0; + // limb[1..n] = 0 + for i in 1..n { + let w = 
self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + FieldElement::ZERO, + ))); + out[i] = w; + } + out + } + + fn bool_and(&mut self, a: usize, b: usize) -> usize { + self.compiler.add_product(a, b) + } +} diff --git a/provekit/r1cs-compiler/src/msm/wide_ops.rs b/provekit/r1cs-compiler/src/msm/wide_ops.rs deleted file mode 100644 index 167d6f986..000000000 --- a/provekit/r1cs-compiler/src/msm/wide_ops.rs +++ /dev/null @@ -1,563 +0,0 @@ -use { - crate::{ - msm::curve::{CurveParams, Limb2}, - noir_to_r1cs::NoirToR1CSCompiler, - }, - ark_ff::Field, - provekit_common::{ - witness::{SumTerm, WitnessBuilder}, - FieldElement, - }, - std::collections::BTreeMap, -}; - -/// (a + b) mod p for 256-bit values in two 128-bit limbs. -/// -/// Equation: a + b = q * p + r, where q ∈ {0, 1}, 0 ≤ r < p. -/// -/// Uses the offset trick to avoid negative intermediate values: -/// v_offset = a_lo + b_lo + 2^128 - q * p_lo (always ≥ 0) -/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} -/// r_lo = v_offset - carry_offset * 2^128 -/// r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi -/// -/// Less-than-p check (proves r < p): -/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) -/// -/// Constraints (7 total): -/// 1. q is boolean: q * q = q -/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * -/// 2^128 -/// 4. Column 1: r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi -/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 -/// 7. 
LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi -/// -/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) -pub fn add_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - a: Limb2, - b: Limb2, - params: &CurveParams, -) -> Limb2 { - let two_128 = FieldElement::from(2u64).pow([128u64]); - let p_lo_fe = params.p_lo_fe(); - let p_hi_fe = params.p_hi_fe(); - let w1 = compiler.witness_one(); - - // Witness: q = floor((a + b) / p) ∈ {0, 1} - // ----------------------------------------------------------- - let q = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::WideAddQuotient { - output: q, - a_lo: a.lo, - a_hi: a.hi, - b_lo: b.lo, - b_hi: b.hi, - modulus: params.field_modulus_p, - }); - // constraining q to be boolean - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( - FieldElement::ONE, - q, - )]); - - // Computing r_lo: lower 128 bits of result - // ----------------------------------------------------------- - // v_offset = a_lo + b_lo + 2^128 - q * p_lo - // (2^128 offset ensures v_offset is always non-negative) - let v_offset = compiler.add_sum(vec![ - SumTerm(None, a.lo), - SumTerm(None, b.lo), - SumTerm(Some(two_128), w1), - SumTerm(Some(-p_lo_fe), q), - ]); - // computing carry_offset = floor(v_offset / 2^128) - let carry_offset = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry_offset, - v_offset, - two_128, - )); - // computing r_lo = v_offset - carry_offset * 2^128 - let r_lo = compiler.add_sum(vec![ - SumTerm(None, v_offset), - SumTerm(Some(-two_128), carry_offset), - ]); - - // Computing r_hi: upper 128 bits of result - // ----------------------------------------------------------- - // r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi - // (-1 compensates for the 2^128 offset added in the low column) - let r_hi = compiler.add_sum(vec![ - SumTerm(None, a.hi), - SumTerm(None, b.hi), - SumTerm(None, carry_offset), 
- SumTerm(Some(-FieldElement::ONE), w1), - SumTerm(Some(-p_hi_fe), q), - ]); - - less_than_p_check(compiler, range_checks, r_lo, r_hi, params); - - Limb2 { lo: r_lo, hi: r_hi } -} - -/// (a - b) mod p for 256-bit values in two 128-bit limbs. -/// -/// Equation: a - b + q * p = r, where q ∈ {0, 1}, 0 ≤ r < p. -/// q = 0 if a ≥ b (result is non-negative without correction) -/// q = 1 if a < b (add p to make result non-negative) -/// -/// Uses the offset trick to avoid negative intermediate values: -/// v_offset = a_lo - b_lo + q * p_lo + 2^128 (always ≥ 0) -/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} -/// r_lo = v_offset - carry_offset * 2^128 -/// r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 -/// -/// Less-than-p check (proves r < p): -/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) -/// -/// Constraints (7 total): -/// 1. q is boolean: q * q = q -/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * -/// 2^128 -/// 4. Column 1: r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 -/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 -/// 7. LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi -/// -/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) -pub fn sub_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - a: Limb2, - b: Limb2, - params: &CurveParams, -) -> Limb2 { - let two_128 = FieldElement::from(2u64).pow([128u64]); - let p_lo_fe = params.p_lo_fe(); - let p_hi_fe = params.p_hi_fe(); - let w1 = compiler.witness_one(); - - // Witness: q = (a < b) ? 
1 : 0 - // ----------------------------------------------------------- - let q = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::WideSubBorrow { - output: q, - a_lo: a.lo, - a_hi: a.hi, - b_lo: b.lo, - b_hi: b.hi, - }); - // constraining q to be boolean - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( - FieldElement::ONE, - q, - )]); - - // Computing r_lo: lower 128 bits of result - // ----------------------------------------------------------- - // v_offset = a_lo - b_lo + q * p_lo + 2^128 - // (2^128 offset ensures v_offset is always non-negative) - let v_offset = compiler.add_sum(vec![ - SumTerm(None, a.lo), - SumTerm(Some(-FieldElement::ONE), b.lo), - SumTerm(Some(p_lo_fe), q), - SumTerm(Some(two_128), w1), - ]); - // computing carry_offset = floor(v_offset / 2^128) - let carry_offset = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry_offset, - v_offset, - two_128, - )); - // computing r_lo = v_offset - carry_offset * 2^128 - let r_lo = compiler.add_sum(vec![ - SumTerm(None, v_offset), - SumTerm(Some(-two_128), carry_offset), - ]); - - // Computing r_hi: upper 128 bits of result - // ----------------------------------------------------------- - // r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 - // (-1 compensates for the 2^128 offset added in the low column) - let r_hi = compiler.add_sum(vec![ - SumTerm(None, a.hi), - SumTerm(Some(-FieldElement::ONE), b.hi), - SumTerm(Some(p_hi_fe), q), - SumTerm(None, carry_offset), - SumTerm(Some(-FieldElement::ONE), w1), - ]); - - less_than_p_check(compiler, range_checks, r_lo, r_hi, params); - - Limb2 { lo: r_lo, hi: r_hi } -} - -/// (a × b) mod p for 256-bit values in two 128-bit limbs. 
-/// -/// Verifies the integer identity `a * b = p * q + r` using schoolbook -/// multiplication in base W = 2^86 (86-bit limbs ensure all column -/// products < 2^172 ≪ BN254_r ≈ 2^254, so field equations = integer equations). -/// -/// Three layers of verification: -/// 1. Decomposition links: prove 86-bit witnesses match the 128-bit -/// inputs/outputs -/// 2. Column equations: prove a86 * b86 = p86 * q86 + r86 (integer) -/// 3. Less-than-p check: prove r < p -/// -/// Witness layout (MulModHint, 20 witnesses at output_start): -/// [0..2) q_lo, q_hi — quotient 128-bit limbs (unconstrained) -/// [2..4) r_lo, r_hi — remainder 128-bit limbs (OUTPUT) -/// [4..7) a86_0..2 — a in 86-bit limbs -/// [7..10) b86_0..2 — b in 86-bit limbs -/// [10..13) q86_0..2 — q in 86-bit limbs -/// [13..16) r86_0..2 — r in 86-bit limbs -/// [16..20) c0u..c3u — unsigned-offset carries (c_signed + 2^88) -/// -/// Constraints (26 total): -/// 9 decomposition links (a, b, r × 3 each) -/// 9 product witnesses (a_i × b_j) -/// 5 column equations -/// 3 less-than-p check -/// -/// Range checks (23 total): -/// 128-bit: r_lo, r_hi, d_lo, d_hi -/// 86-bit: a86_0, a86_1, b86_0, b86_1, q86_0, q86_1, r86_0, r86_1 -/// 84-bit: a86_2, b86_2, q86_2, r86_2 -/// 89-bit: c0u, c1u, c2u, c3u -/// 44-bit: carry_a, carry_b, carry_r -pub fn mul_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - a: Limb2, - b: Limb2, - params: &CurveParams, -) -> Limb2 { - let two_44 = FieldElement::from(2u64).pow([44u64]); - let two_86 = FieldElement::from(2u64).pow([86u64]); - let two_128 = FieldElement::from(2u64).pow([128u64]); - let offset_fe = FieldElement::from(2u64).pow([88u64]); // CARRY_OFFSET - let offset_w = FieldElement::from(2u64).pow([174u64]); // 2^88 * 2^86 - let offset_w_minus_1 = offset_w - offset_fe; // 2^88 * (2^86 - 1) - let [p0, p1, p2] = params.p_86_limbs(); - let w1 = compiler.witness_one(); - - // Step 1: Allocate MulModHint (20 witnesses) - // 
----------------------------------------------------------- - let os = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::MulModHint { - output_start: os, - a_lo: a.lo, - a_hi: a.hi, - b_lo: b.lo, - b_hi: b.hi, - modulus: params.field_modulus_p, - }); - - // Witness indices - let r_lo = os + 2; - let r_hi = os + 3; - let a86 = [os + 4, os + 5, os + 6]; - let b86 = [os + 7, os + 8, os + 9]; - let q86 = [os + 10, os + 11, os + 12]; - let r86 = [os + 13, os + 14, os + 15]; - let cu = [os + 16, os + 17, os + 18, os + 19]; - - // Step 2: Decomposition consistency for a, b, r - // ----------------------------------------------------------- - decompose_check( - compiler, - range_checks, - a.lo, - a.hi, - a86, - two_86, - two_44, - two_128, - w1, - ); - decompose_check( - compiler, - range_checks, - b.lo, - b.hi, - b86, - two_86, - two_44, - two_128, - w1, - ); - decompose_check( - compiler, - range_checks, - r_lo, - r_hi, - r86, - two_86, - two_44, - two_128, - w1, - ); - - // Step 3: Product witnesses (9 R1CS constraints) - // ----------------------------------------------------------- - let ab00 = compiler.add_product(a86[0], b86[0]); - let ab01 = compiler.add_product(a86[0], b86[1]); - let ab10 = compiler.add_product(a86[1], b86[0]); - let ab02 = compiler.add_product(a86[0], b86[2]); - let ab11 = compiler.add_product(a86[1], b86[1]); - let ab20 = compiler.add_product(a86[2], b86[0]); - let ab12 = compiler.add_product(a86[1], b86[2]); - let ab21 = compiler.add_product(a86[2], b86[1]); - let ab22 = compiler.add_product(a86[2], b86[2]); - - // Step 4: Column equations (5 R1CS constraints) - // ----------------------------------------------------------- - // Identity: a*b = p*q + r in base W=2^86. - // Carries stored with unsigned offset: cu_i = c_i + 2^88. 
- // - // col0: ab00 + 2^174 = p0*q0 + r0 + W*cu0 - // col1: ab01 + ab10 + cu0 + (2^174-2^88) = p0*q1 + p1*q0 + r1 + W*cu1 - // col2: ab02+ab11+ab20 + cu1 + (2^174-2^88) = p0*q2+p1*q1+p2*q0 + r2 + W*cu2 - // col3: ab12 + ab21 + cu2 + (2^174-2^88) = p1*q2 + p2*q1 + W*cu3 - // col4: ab22 + cu3 = p2*q2 + 2^88 - - // col0 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, ab00), (offset_w, w1)], - &[(FieldElement::ONE, w1)], - &[(p0, q86[0]), (FieldElement::ONE, r86[0]), (two_86, cu[0])], - ); - - // col1 - compiler.r1cs.add_constraint( - &[ - (FieldElement::ONE, ab01), - (FieldElement::ONE, ab10), - (FieldElement::ONE, cu[0]), - (offset_w_minus_1, w1), - ], - &[(FieldElement::ONE, w1)], - &[ - (p0, q86[1]), - (p1, q86[0]), - (FieldElement::ONE, r86[1]), - (two_86, cu[1]), - ], - ); - - // col2 - compiler.r1cs.add_constraint( - &[ - (FieldElement::ONE, ab02), - (FieldElement::ONE, ab11), - (FieldElement::ONE, ab20), - (FieldElement::ONE, cu[1]), - (offset_w_minus_1, w1), - ], - &[(FieldElement::ONE, w1)], - &[ - (p0, q86[2]), - (p1, q86[1]), - (p2, q86[0]), - (FieldElement::ONE, r86[2]), - (two_86, cu[2]), - ], - ); - - // col3 - compiler.r1cs.add_constraint( - &[ - (FieldElement::ONE, ab12), - (FieldElement::ONE, ab21), - (FieldElement::ONE, cu[2]), - (offset_w_minus_1, w1), - ], - &[(FieldElement::ONE, w1)], - &[(p1, q86[2]), (p2, q86[1]), (two_86, cu[3])], - ); - - // col4 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, ab22), (FieldElement::ONE, cu[3])], - &[(FieldElement::ONE, w1)], - &[(p2, q86[2]), (offset_fe, w1)], - ); - - // Step 5: Less-than-p check (r < p) + 128-bit range checks on r_lo, r_hi - // ----------------------------------------------------------- - less_than_p_check(compiler, range_checks, r_lo, r_hi, params); - - // Step 6: Range checks (mul-specific) - // ----------------------------------------------------------- - // 86-bit: limbs 0 and 1 of a, b, q, r - for &idx in &[ - a86[0], a86[1], b86[0], b86[1], q86[0], q86[1], r86[0], 
r86[1], - ] { - range_checks.entry(86).or_default().push(idx); - } - - // 84-bit: limb 2 of a, b, q, r (bits [172..256) = 84 bits) - for &idx in &[a86[2], b86[2], q86[2], r86[2]] { - range_checks.entry(84).or_default().push(idx); - } - - // 89-bit: unsigned-offset carries (|c_signed| < 2^88, so c_unsigned ∈ [0, - // 2^89)) - for &idx in &cu { - range_checks.entry(89).or_default().push(idx); - } - - Limb2 { lo: r_lo, hi: r_hi } -} - -/// a^(-1) mod p for 256-bit values in two 128-bit limbs. -/// -/// Hint-and-verify pattern: -/// 1. Prover computes inv = a^(p-2) mod p (Fermat's little theorem) -/// 2. Circuit verifies a * inv mod p = 1 -/// -/// Constraints: 26 from mul_mod_p + 2 equality checks = 28 total. -pub fn inv_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - value: Limb2, - params: &CurveParams, -) -> Limb2 { - // Witness: inv = a^(-1) mod p (2 witnesses: lo, hi) - // ----------------------------------------------------------- - let value_inv = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::WideModularInverse { - output_start: value_inv, - a_lo: value.lo, - a_hi: value.hi, - modulus: params.field_modulus_p, - }); - let inv = Limb2 { - lo: value_inv, - hi: value_inv + 1, - }; - - // Verifying a * inv mod p = 1 - // ----------------------------------------------------------- - // computing product = value * inv mod p - let product = mul_mod_p(compiler, range_checks, value, inv, params); - // constraining product_lo = 1 (because 1 = 1 + 0 * 2^128) - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, product.lo)], - &[(FieldElement::ONE, compiler.witness_one())], - &[(FieldElement::ONE, compiler.witness_one())], - ); - // constraining product_hi = 0 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, product.hi)], - &[(FieldElement::ONE, compiler.witness_one())], - &[], - ); - - inv -} - -/// Verify that 128-bit limbs (v_lo, v_hi) decompose into 86-bit limbs (v86). 
-/// -/// Equations: -/// v_lo = v86_0 + v86_1 * 2^86 - carry * 2^128 -/// v_hi = carry + v86_2 * 2^44 -/// -/// All intermediate values < 2^172 ≪ BN254_r, so field equations = integer -/// equations. -/// -/// Creates: 1 intermediate witness (v_sum), 1 carry witness (IntegerQuotient). -/// Adds: 3 R1CS constraints (v_sum definition + 2 decomposition checks). -/// Range checks: carry (44-bit). -/// Proves r < p by decomposing (p - 1) - r into non-negative 128-bit limbs. -/// -/// If d_lo, d_hi >= 0 then (p - 1) - r >= 0, i.e. r <= p - 1 < p. -/// Uses the 2^128 offset trick to avoid negative intermediate values. -/// -/// Range checks r_lo, r_hi, d_lo, d_hi (128-bit each). -fn less_than_p_check( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - r_lo: usize, - r_hi: usize, - params: &CurveParams, -) { - let two_128 = FieldElement::from(2u64).pow([128u64]); - let p_lo_fe = params.p_lo_fe(); - let p_hi_fe = params.p_hi_fe(); - let w1 = compiler.witness_one(); - - // v_diff = (p_lo - 1) + 2^128 - r_lo - // (2^128 offset ensures v_diff is always non-negative) - let p_lo_minus_1_plus_offset = p_lo_fe - FieldElement::ONE + two_128; - let v_diff = compiler.add_sum(vec![ - SumTerm(Some(p_lo_minus_1_plus_offset), w1), - SumTerm(Some(-FieldElement::ONE), r_lo), - ]); - // borrow_compl = floor(v_diff / 2^128) - let borrow_compl = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - borrow_compl, - v_diff, - two_128, - )); - // d_lo = v_diff - borrow_compl * 2^128 - let d_lo = compiler.add_sum(vec![ - SumTerm(None, v_diff), - SumTerm(Some(-two_128), borrow_compl), - ]); - // d_hi = (p_hi - 1) + borrow_compl - r_hi - let d_hi = compiler.add_sum(vec![ - SumTerm(Some(p_hi_fe - FieldElement::ONE), w1), - SumTerm(None, borrow_compl), - SumTerm(Some(-FieldElement::ONE), r_hi), - ]); - - // Range checks (128-bit) - range_checks.entry(128).or_default().push(r_lo); - range_checks.entry(128).or_default().push(r_hi); - 
range_checks.entry(128).or_default().push(d_lo); - range_checks.entry(128).or_default().push(d_hi); -} - -fn decompose_check( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - v_lo: usize, - v_hi: usize, - v86: [usize; 3], - two_86: FieldElement, - two_44: FieldElement, - two_128: FieldElement, - w1: usize, -) { - // v_sum = v86_0 + v86_1 * 2^86 (intermediate for IntegerQuotient) - let v_sum = compiler.add_sum(vec![SumTerm(None, v86[0]), SumTerm(Some(two_86), v86[1])]); - - // carry = floor(v_sum / 2^128) ∈ [0, 2^44) - let carry = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_sum, two_128)); - - // Low check: v_sum - carry * 2^128 = v_lo - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, v_sum), (-two_128, carry)], - &[(FieldElement::ONE, w1)], - &[(FieldElement::ONE, v_lo)], - ); - - // High check: carry + v86_2 * 2^44 = v_hi - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, carry), (two_44, v86[2])], - &[(FieldElement::ONE, w1)], - &[(FieldElement::ONE, v_hi)], - ); - - // Range check carry (44-bit) - range_checks.entry(44).or_default().push(carry); -} diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 2d4245636..18bc22ddc 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -3,6 +3,7 @@ use { binops::add_combined_binop_constraints, digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, memory::{add_ram_checking, add_rom_checking, MemoryBlock, MemoryOperation}, + msm::add_msm, poseidon2::add_poseidon2_permutation, range_check::add_range_checks, sha256_compression::add_sha256_compression, @@ -16,7 +17,6 @@ use { Circuit, Opcode, }, native_types::{Expression, Witness as NoirWitness}, - BlackBoxFunc, }, anyhow::{bail, Result}, ark_ff::PrimeField, @@ -89,6 +89,11 @@ pub struct R1CSBreakdown { pub poseidon2_constraints: usize, /// Witnesses from Poseidon2 
permutation pub poseidon2_witnesses: usize, + + /// Constraints from multi-scalar multiplication + pub msm_constraints: usize, + /// Witnesses from multi-scalar multiplication + pub msm_witnesses: usize, } /// Compiles an ACIR circuit into an [R1CS] instance, comprising of the A, B, @@ -458,6 +463,7 @@ impl NoirToR1CSCompiler { let mut xor_ops = vec![]; let mut sha256_compression_ops = vec![]; let mut poseidon2_ops = vec![]; + let mut msm_ops = vec![]; let mut breakdown = R1CSBreakdown::default(); @@ -632,7 +638,20 @@ impl NoirToR1CSCompiler { points, scalars, outputs, - } => {} + } => { + let point_wits: Vec = points + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let scalar_wits: Vec = scalars + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let out_x = self.fetch_r1cs_witness_index(outputs.0); + let out_y = self.fetch_r1cs_witness_index(outputs.1); + let out_inf = self.fetch_r1cs_witness_index(outputs.2); + msm_ops.push((point_wits, scalar_wits, (out_x, out_y, out_inf))); + } _ => { unimplemented!("Other black box function: {:?}", black_box_func_call); } @@ -724,6 +743,22 @@ impl NoirToR1CSCompiler { breakdown.poseidon2_constraints = self.r1cs.num_constraints() - constraints_before_poseidon; breakdown.poseidon2_witnesses = self.num_witnesses() - witnesses_before_poseidon; + let constraints_before_msm = self.r1cs.num_constraints(); + let witnesses_before_msm = self.num_witnesses(); + // Cost model: pick optimal (limb_bits, window_size) for MSM + let curve = crate::msm::curve::grumpkin_params(); + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let (msm_limb_bits, msm_window_size) = if !msm_ops.is_empty() { + let n_points: usize = msm_ops.iter().map(|(pts, _, _)| pts.len() / 3).sum(); + crate::msm::cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256) + } else { + (native_bits, 4) + }; + add_msm(self, msm_ops, 
msm_limb_bits, msm_window_size, &mut range_checks, &curve); + breakdown.msm_constraints = self.r1cs.num_constraints() - constraints_before_msm; + breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; + breakdown.range_ops_total = range_checks.values().map(|v| v.len()).sum(); let constraints_before_range = self.r1cs.num_constraints(); let witnesses_before_range = self.num_witnesses(); From 40d72b8f9ca68d913ada1f4c5c82ce667b250fb4 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 5 Mar 2026 04:48:33 +0530 Subject: [PATCH 4/5] feat : added gnark optimisations for msm --- .../src/witness/scheduling/dependency.rs | 22 +- .../common/src/witness/scheduling/remapper.rs | 28 + .../common/src/witness/witness_builder.rs | 37 +- provekit/prover/src/lib.rs | 15 +- provekit/prover/src/witness/bigint_mod.rs | 392 ++++++++++++- .../prover/src/witness/witness_builder.rs | 66 ++- provekit/r1cs-compiler/src/digits.rs | 4 +- provekit/r1cs-compiler/src/msm/cost_model.rs | 221 ++++---- provekit/r1cs-compiler/src/msm/curve.rs | 488 +++++++++++++++- provekit/r1cs-compiler/src/msm/ec_points.rs | 185 +++--- provekit/r1cs-compiler/src/msm/mod.rs | 529 +++++++++++++++--- .../r1cs-compiler/src/msm/multi_limb_arith.rs | 163 +++--- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 115 ++-- provekit/r1cs-compiler/src/noir_to_r1cs.rs | 11 +- tooling/provekit-bench/tests/compiler.rs | 1 + 15 files changed, 1796 insertions(+), 481 deletions(-) diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index 9f92afd75..98ae1368f 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -156,9 +156,7 @@ impl DependencyInfo { v } WitnessBuilder::MultiLimbMulModHint { - a_limbs, - b_limbs, - .. + a_limbs, b_limbs, .. 
} => { let mut v = a_limbs.clone(); v.extend(b_limbs); @@ -166,18 +164,14 @@ impl DependencyInfo { } WitnessBuilder::MultiLimbModularInverse { a_limbs, .. } => a_limbs.clone(), WitnessBuilder::MultiLimbAddQuotient { - a_limbs, - b_limbs, - .. + a_limbs, b_limbs, .. } => { let mut v = a_limbs.clone(); v.extend(b_limbs); v } WitnessBuilder::MultiLimbSubBorrow { - a_limbs, - b_limbs, - .. + a_limbs, b_limbs, .. } => { let mut v = a_limbs.clone(); v.extend(b_limbs); @@ -229,6 +223,10 @@ impl DependencyInfo { data.rs_cubed, ] } + WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], + WitnessBuilder::EcScalarMulHint { + px, py, s_lo, s_hi, .. + } => vec![*px, *py, *s_lo, *s_hi], WitnessBuilder::ChunkDecompose { packed, .. } => vec![*packed], WitnessBuilder::SpreadWitness(_, input) => vec![*input], WitnessBuilder::SpreadBitExtract { sum_terms, .. } => { @@ -329,6 +327,12 @@ impl DependencyInfo { num_limbs, .. } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::FakeGLVHint { + output_start, .. + } => (*output_start..*output_start + 4).collect(), + WitnessBuilder::EcScalarMulHint { + output_start, .. + } => (*output_start..*output_start + 2).collect(), WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) 
=> { diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 334b5f401..696144113 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -366,6 +366,34 @@ impl WitnessIndexRemapper { }, ) } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => WitnessBuilder::FakeGLVHint { + output_start: self.remap(*output_start), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_order: *curve_order, + }, + WitnessBuilder::EcScalarMulHint { + output_start, + px, + py, + s_lo, + s_hi, + curve_a, + field_modulus_p, + } => WitnessBuilder::EcScalarMulHint { + output_start: self.remap(*output_start), + px: self.remap(*px), + py: self.remap(*py), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + }, WitnessBuilder::ChunkDecompose { output_start, packed, diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 28d6d775c..2b7cd2f30 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -270,6 +270,37 @@ pub enum WitnessBuilder { packed: usize, chunk_bits: Vec, }, + /// Prover hint for FakeGLV scalar decomposition. + /// Given scalar s (from s_lo + s_hi * 2^128) and curve order n, + /// computes half_gcd(s, n) → (|s1|, |s2|, neg1, neg2) such that: + /// (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n) + /// + /// Outputs 4 witnesses starting at output_start: + /// [0] |s1| (128-bit field element) + /// [1] |s2| (128-bit field element) + /// [2] neg1 (boolean: 0 or 1) + /// [3] neg2 (boolean: 0 or 1) + FakeGLVHint { + output_start: usize, + s_lo: usize, + s_hi: usize, + curve_order: [u64; 4], + }, + /// Prover hint for EC scalar multiplication: computes R = [s]P. 
+ /// Given point P = (px, py) and scalar s = s_lo + s_hi * 2^128, + /// computes R = [s]P on the curve with parameter `curve_a` and + /// field modulus `field_modulus_p`. + /// + /// Outputs 2 witnesses at output_start: R_x, R_y. + EcScalarMulHint { + output_start: usize, + px: usize, + py: usize, + s_lo: usize, + s_hi: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + }, /// Computes spread(input): interleave bits with zeros. /// Output: 0 b_{n-1} 0 b_{n-2} ... 0 b_1 0 b_0 /// (witness index of output, witness index of input) @@ -332,10 +363,10 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, - WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => { - (4 * *num_limbs - 2) as usize - } + WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, WitnessBuilder::MultiLimbModularInverse { num_limbs, .. } => *num_limbs as usize, + WitnessBuilder::FakeGLVHint { .. } => 4, + WitnessBuilder::EcScalarMulHint { .. 
} => 2, _ => 1, } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index f6a3e653f..07ad6b420 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -202,7 +202,9 @@ impl Prove for Prover { let hc = debug_r1cs.c.hydrate(interner); let mut fail_count = 0usize; for row in 0..debug_r1cs.num_constraints() { - let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, r: usize| -> FieldElement { + let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, + r: usize| + -> FieldElement { let mut sum = FieldElement::zero(); for (col, coeff) in hm.iter_row(r) { sum += coeff * full_witness[col]; @@ -216,7 +218,11 @@ impl Prove for Prover { if fail_count < 10 { eprintln!( "CONSTRAINT {} FAILED: A={:?} B={:?} C={:?} A*B={:?}", - row, a_val, b_val, c_val, a_val * b_val + row, + a_val, + b_val, + c_val, + a_val * b_val ); eprint!(" A terms:"); for (col, coeff) in ha.iter_row(row) { @@ -238,7 +244,10 @@ impl Prove for Prover { } } if fail_count > 0 { - eprintln!("TOTAL FAILING CONSTRAINTS: {fail_count} / {}", debug_r1cs.num_constraints()); + eprintln!( + "TOTAL FAILING CONSTRAINTS: {fail_count} / {}", + debug_r1cs.num_constraints() + ); } else { eprintln!("ALL {} CONSTRAINTS SATISFIED", debug_r1cs.num_constraints()); } diff --git a/provekit/prover/src/witness/bigint_mod.rs b/provekit/prover/src/witness/bigint_mod.rs index a41f47ff3..2874d49a3 100644 --- a/provekit/prover/src/witness/bigint_mod.rs +++ b/provekit/prover/src/witness/bigint_mod.rs @@ -217,9 +217,317 @@ pub fn add_4limb(a: &[u64; 4], b: &[u64; 4]) -> [u64; 5] { result } -/// Offset added to signed carries to make them non-negative for range checking. -/// Carries are bounded by |c| < 2^88, so adding 2^88 ensures c_unsigned >= 0. -pub const CARRY_OFFSET: u128 = 1u128 << 88; +/// Add two 4-limb numbers in-place: a += b. Returns the carry-out. 
+pub fn add_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) -> u64 { + let mut carry = 0u64; + for i in 0..4 { + let (s1, c1) = a[i].overflowing_add(b[i]); + let (s2, c2) = s1.overflowing_add(carry); + a[i] = s2; + carry = (c1 as u64) + (c2 as u64); + } + carry +} + +/// Subtract b from a in-place, returning true if a >= b (no underflow). +/// If a < b, the result is a += 2^256 - b (wrapping subtraction) and returns false. +pub fn sub_4limb_checked(a: &mut [u64; 4], b: &[u64; 4]) -> bool { + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + a[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + borrow == 0 +} + +/// Returns true if val == 0. +pub fn is_zero(val: &[u64; 4]) -> bool { + val[0] == 0 && val[1] == 0 && val[2] == 0 && val[3] == 0 +} + +/// Compute the number of bits needed for the half-GCD sub-scalars. +/// Returns `ceil(order_bits / 2)` where `order_bits` is the bit length of `n`. +pub fn half_gcd_bits(n: &[u64; 4]) -> u32 { + let mut order_bits = 0u32; + for i in (0..4).rev() { + if n[i] != 0 { + order_bits = (i as u32) * 64 + (64 - n[i].leading_zeros()); + break; + } + } + (order_bits + 1) / 2 +} + +/// Build the threshold value `2^half_bits` as a `[u64; 4]`. +fn build_threshold(half_bits: u32) -> [u64; 4] { + assert!(half_bits <= 255, "half_bits must be <= 255"); + let mut threshold = [0u64; 4]; + let word = (half_bits / 64) as usize; + let bit = half_bits % 64; + threshold[word] = 1u64 << bit; + threshold +} + +/// Half-GCD scalar decomposition for FakeGLV. +/// +/// Given scalar `s` and curve order `n`, finds `(|s1|, |s2|, neg1, neg2)` such that: +/// `(-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n)` +/// +/// Uses the extended GCD on `(n, s)`, stopping when the remainder drops below +/// `2^half_bits` where `half_bits = ceil(order_bits / 2)`. +/// Returns `(val1, val2, neg1, neg2)` where both fit in `half_bits` bits. 
+pub fn half_gcd( + s: &[u64; 4], + n: &[u64; 4], +) -> ([u64; 4], [u64; 4], bool, bool) { + // Extended GCD on (n, s): + // We track: r_{i} = r_{i-2} - q_i * r_{i-1} + // t_{i} = t_{i-2} - q_i * t_{i-1} + // Starting: r_0 = n, r_1 = s, t_0 = 0, t_1 = 1 + // + // We want: t_i * s ≡ r_i (mod n) [up to sign] + // More precisely: t_i * s ≡ (-1)^{i+1} * r_i (mod n) + // + // The relation we verify is: sign_r * |r_i| + sign_t * |t_i| * s ≡ 0 (mod n) + + // Threshold: 2^half_bits where half_bits = ceil(order_bits / 2) + let half_bits = half_gcd_bits(n); + let threshold = build_threshold(half_bits); + + // r_prev = n, r_curr = s + let mut r_prev = *n; + let mut r_curr = *s; + + // t_prev = 0, t_curr = 1 + let mut t_prev = [0u64; 4]; + let mut t_curr = [1u64, 0, 0, 0]; + + // Track sign of t: t_prev_neg=false (t_0=0, positive), t_curr_neg=false (t_1=1, positive) + let mut t_prev_neg = false; + let mut t_curr_neg = false; + + let mut iteration = 0u32; + + loop { + // Check if r_curr < threshold + if cmp_4limb(&r_curr, &threshold) == std::cmp::Ordering::Less { + break; + } + + if is_zero(&r_curr) { + break; + } + + // q = r_prev / r_curr, new_r = r_prev % r_curr + let (q, new_r) = divmod(&r_prev, &r_curr); + + // new_t = t_prev + q * t_curr (in terms of absolute values and signs) + // Since the GCD recurrence is: t_{i} = t_{i-2} - q_i * t_{i-1} + // In terms of absolute values with sign tracking: + // If t_prev and q*t_curr have the same sign → subtract magnitudes + // If they have different signs → add magnitudes + // But actually: new_t = |t_prev| +/- q * |t_curr|, with sign flips each iteration. + // + // The standard extended GCD recurrence gives: + // t_i = t_{i-2} - q_i * t_{i-1} + // We track magnitudes and sign bits separately. + + // Compute q * t_curr + let qt = mul_mod_no_reduce(&q, &t_curr); + + // new_t magnitude and sign: + // In the standard recurrence: new_t_val = t_prev_val - q * t_curr_val + // where t_prev_val = (-1)^t_prev_neg * |t_prev|, etc. 
+ // + // But it's simpler to just track: alternating signs. + // In the half-GCD: t values alternate in sign. So: + // new_t = t_prev + q * t_curr (absolute addition since signs alternate) + let mut new_t = qt; + add_4limb_inplace(&mut new_t, &t_prev); + let new_t_neg = !t_curr_neg; + + r_prev = r_curr; + r_curr = new_r; + t_prev = t_curr; + t_prev_neg = t_curr_neg; + t_curr = new_t; + t_curr_neg = new_t_neg; + iteration += 1; + } + + // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD property) + // The relation is: (-1)^(iteration) * r_curr + t_curr * s ≡ 0 (mod n) + // Or equivalently: r_curr ≡ (-1)^(iteration+1) * t_curr * s (mod n) + + let val1 = r_curr; // |s1| = |r_i| + let val2 = t_curr; // |s2| = |t_i| + + // Determine signs: + // We need: neg1 * val1 + neg2 * val2 * s ≡ 0 (mod n) + // From the extended GCD: r_i = (-1)^i * (... some relation with t_i * s mod n) + // The exact sign relationship: + // t_i * s ≡ (-1)^(i+1) * r_i (mod n) + // So: (-1)^(i+1) * r_i + t_i * s ≡ 0 (mod n) + // + // If iteration is even: (-1)^(even+1) = -1, so: -r_i + t_i * s ≡ 0 + // → neg1=true (negate r_i), neg2=t_curr_neg + // If iteration is odd: (-1)^(odd+1) = 1, so: r_i + t_i * s ≡ 0 + // → neg1=false, neg2=t_curr_neg + + let neg1 = iteration % 2 == 0; // negate val1 when iteration is even + let neg2 = t_curr_neg; + + (val1, val2, neg1, neg2) +} + +/// Multiply two 4-limb values without modular reduction. +/// Returns the lower 4 limbs (ignoring overflow beyond 256 bits). +/// Used internally by half_gcd for q * t_curr where the result is known to fit. +fn mul_mod_no_reduce(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let wide = widening_mul(a, b); + [wide[0], wide[1], wide[2], wide[3]] +} + +// --------------------------------------------------------------------------- +// Modular arithmetic helpers for EC operations (prover-side) +// --------------------------------------------------------------------------- + +/// Modular addition: (a + b) mod p. 
+pub fn mod_add(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let sum = add_4limb(a, b); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if sum[4] > 0 || cmp_4limb(&sum4, p) != std::cmp::Ordering::Less { + // sum >= p, subtract p + let mut result = sum4; + sub_4limb_inplace(&mut result, p); + result + } else { + sum4 + } +} + +/// Modular subtraction: (a - b) mod p. +pub fn mod_sub(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let mut result = *a; + let no_borrow = sub_4limb_checked(&mut result, b); + if no_borrow { + result + } else { + // a < b, add p to get (a - b + p) + add_4limb_inplace(&mut result, p); + result + } +} + +/// Modular inverse: a^{p-2} mod p (Fermat's little theorem). +pub fn mod_inverse(a: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let exp = sub_u64(p, 2); + mod_pow(a, &exp, p) +} + +/// EC point doubling in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = 2*(px, py). +pub fn ec_point_double( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mul_mod(px, px, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let numerator = mod_add(&three_x_sq, a, p); + let two_y = mod_add(py, py, p); + let denom_inv = mod_inverse(&two_y, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mul_mod(&lambda, &lambda, p); + let two_x = mod_add(px, px, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(px, &x3, p); + let lambda_dx = mul_mod(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, py, p); + + (x3, y3) +} + +/// EC point addition in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = (p1x, p1y) + (p2x, p2y). Requires p1x != p2x. 
+pub fn ec_point_add( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // lambda = (y2 - y1) / (x2 - x1) + let numerator = mod_sub(p2y, p1y, p); + let denominator = mod_sub(p2x, p1x, p); + let denom_inv = mod_inverse(&denominator, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + // x3 = lambda^2 - x1 - x2 + let lambda_sq = mul_mod(&lambda, &lambda, p); + let x1_plus_x2 = mod_add(p1x, p2x, p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); + + // y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = mod_sub(p1x, &x3, p); + let lambda_dx = mul_mod(&lambda, &x1_minus_x3, p); + let y3 = mod_sub(&lambda_dx, p1y, p); + + (x3, y3) +} + +/// EC scalar multiplication via double-and-add: returns [scalar]*P. +pub fn ec_scalar_mul( + px: &[u64; 4], + py: &[u64; 4], + scalar: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // Find highest set bit in scalar + let mut highest_bit = 0; + for i in (0..4).rev() { + if scalar[i] != 0 { + highest_bit = i * 64 + (64 - scalar[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + // scalar == 0 → point at infinity (not representable in affine) + panic!("ec_scalar_mul: scalar is zero"); + } + + // Start from the MSB-1 and double-and-add + let mut rx = *px; + let mut ry = *py; + + for bit_pos in (0..highest_bit - 1).rev() { + // Double + let (dx, dy) = ec_point_double(&rx, &ry, a, p); + rx = dx; + ry = dy; + + // Add if bit is set + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + if (scalar[limb_idx] >> bit_idx) & 1 == 1 { + let (ax, ay) = ec_point_add(&rx, &ry, px, py, p); + rx = ax; + ry = ay; + } + } + + (rx, ry) +} /// Integer division of a 512-bit dividend by a 256-bit divisor. /// Returns (quotient, remainder) where both fit in 256 bits. 
@@ -848,4 +1156,82 @@ mod tests { r0, r1, r2, ]); } + + #[test] + fn test_half_gcd_small() { + // s = 42, n = 101 + let s = [42, 0, 0, 0]; + let n = [101, 0, 0, 0]; + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + let sign1: i128 = if neg1 { -1 } else { 1 }; + let sign2: i128 = if neg2 { -1 } else { 1 }; + let v1 = val1[0] as i128; + let v2 = val2[0] as i128; + let s_val = s[0] as i128; + let n_val = n[0] as i128; + let lhs = ((sign1 * v1 + sign2 * v2 * s_val) % n_val + n_val) % n_val; + assert_eq!(lhs, 0, "half_gcd relation failed for small values"); + } + + #[test] + fn test_half_gcd_grumpkin_order() { + // Grumpkin curve order (BN254 base field order) + let n = [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ]; + // Some scalar + let s = [ + 0x123456789abcdef0_u64, + 0xfedcba9876543210_u64, + 0x1111111111111111_u64, + 0x2222222222222222_u64, + ]; + + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // val1 and val2 should be < 2^128 + assert_eq!(val1[2], 0, "val1 should be < 2^128"); + assert_eq!(val1[3], 0, "val1 should be < 2^128"); + assert_eq!(val2[2], 0, "val2 should be < 2^128"); + assert_eq!(val2[3], 0, "val2 should be < 2^128"); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + // Use big integer arithmetic + let term2_full = widening_mul(&val2, &s); + let (_, term2_mod_n) = divmod_wide(&term2_full, &n); + + // Compute: sign1 * val1 + sign2 * term2_mod_n (mod n) + let effective1 = if neg1 { + // n - val1 + let mut result = n; + sub_4limb_checked(&mut result, &val1); + result + } else { + val1 + }; + let effective2 = if neg2 { + let mut result = n; + sub_4limb_checked(&mut result, &term2_mod_n); + result + } else { + term2_mod_n + }; + + let sum = add_4limb(&effective1, &effective2); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + // sum might be >= n, so reduce + let (_, remainder) = if sum[4] > 0 { + 
// Sum overflows 256 bits, need wide divmod + let wide = [sum[0], sum[1], sum[2], sum[3], sum[4], 0, 0, 0]; + divmod_wide(&wide, &n) + } else { + divmod(&sum4, &n) + }; + assert_eq!(remainder, [0, 0, 0, 0], "half_gcd relation failed for Grumpkin order"); + } } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index d3479331b..ec3d61b06 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -449,8 +449,7 @@ impl WitnessBuilderSolver for WitnessBuilder { for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - ab_sum += - a_limbs_vals[i] as i128 * b_limbs_vals[j as usize] as i128; + ab_sum += a_limbs_vals[i] as i128 * b_limbs_vals[j as usize] as i128; } } // Sum p[i]*q[j] for i+j=k @@ -458,8 +457,7 @@ impl WitnessBuilderSolver for WitnessBuilder { for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - pq_sum += - p_limbs_vals[i] as i128 * q_limbs_vals[j as usize] as i128; + pq_sum += p_limbs_vals[i] as i128 * q_limbs_vals[j as usize] as i128; } } let r_k = if k < n { r_limbs_vals[k] as i128 } else { 0 }; @@ -663,6 +661,66 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*lo] = Some(FieldElement::from(lo_val)); witness[*hi] = Some(FieldElement::from(hi_val)); } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => { + // Reconstruct s = s_lo + s_hi * 2^128 + let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; + let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; + let s_val: [u64; 4] = [ + s_lo_val[0], + s_lo_val[1], + s_hi_val[0], + s_hi_val[1], + ]; + + let (val1, val2, neg1, neg2) = + crate::witness::bigint_mod::half_gcd(&s_val, curve_order); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(val1)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(val2)).unwrap()); + 
witness[*output_start + 2] = + Some(FieldElement::from(neg1 as u64)); + witness[*output_start + 3] = + Some(FieldElement::from(neg2 as u64)); + } + WitnessBuilder::EcScalarMulHint { + output_start, + px, + py, + s_lo, + s_hi, + curve_a, + field_modulus_p, + } => { + // Reconstruct scalar s = s_lo + s_hi * 2^128 + let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; + let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; + let scalar: [u64; 4] = [s_lo_val[0], s_lo_val[1], s_hi_val[0], s_hi_val[1]]; + + // Reconstruct point P + let px_val = witness[*px].unwrap().into_bigint().0; + let py_val = witness[*py].unwrap().into_bigint().0; + + // Compute R = [s]P + let (rx, ry) = crate::witness::bigint_mod::ec_scalar_mul( + &px_val, + &py_val, + &scalar, + curve_a, + field_modulus_p, + ); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(rx)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(ry)).unwrap()); + } WitnessBuilder::CombinedTableEntryInverse(..) 
=> { unreachable!( "CombinedTableEntryInverse should not be called - handled by batch inversion" diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 91c4e4128..3d8917d55 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -1,5 +1,6 @@ use { crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::Field, ark_std::One, provekit_common::{ witness::{DigitalDecompositionWitnesses, WitnessBuilder}, @@ -66,7 +67,8 @@ pub(crate) fn add_digital_decomposition( // Add the constraints for the digital recomposition let mut digit_multipliers = vec![FieldElement::one()]; for log_base in log_bases[..log_bases.len() - 1].iter() { - let multiplier = *digit_multipliers.last().unwrap() * FieldElement::from(1u64 << *log_base); + let multiplier = *digit_multipliers.last().unwrap() + * FieldElement::from(2u64).pow([*log_base as u64]); digit_multipliers.push(multiplier); } dd_struct diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index 234623a31..b896dc043 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -1,7 +1,8 @@ //! Analytical cost model for MSM parameter optimization. //! //! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): -//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, window_size). +//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, +//! window_size). /// Type of field operation for cost estimation. #[derive(Clone, Copy)] @@ -12,105 +13,75 @@ pub enum FieldOpType { Inv, } -/// Count field ops in scalar_mul for given parameters. -/// Traces through ec_points::scalar_mul logic analytically. +/// Count field ops in scalar_mul_glv for given parameters. /// -/// Returns (n_add, n_sub, n_mul, n_inv) per single scalar multiplication. 
-fn count_scalar_mul_field_ops(scalar_bits: usize, window_size: usize) -> (usize, usize, usize, usize) { +/// The GLV approach does interleaved two-point scalar mul with half-width scalars. +/// Per window: w shared doubles + 2 table lookups + 2 point_adds + 2 is_zero + 2 point_selects +/// Plus: 2 table builds, on-curve check, scalar relation overhead. +fn count_glv_field_ops( + scalar_bits: usize, // half_bits = ceil(order_bits / 2) + window_size: usize, +) -> (usize, usize, usize, usize) { let w = window_size; let table_size = 1 << w; let num_windows = (scalar_bits + w - 1) / w; - // Build point table: T[0]=P (free), T[1]=P (free), T[2]=2P (1 double), - // T[3..table_size] = point_add each + let double_ops = (4usize, 2usize, 5usize, 1usize); + let add_ops = (2usize, 2usize, 3usize, 1usize); + let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); + + // Two tables (one for P, one for R) let table_doubles = if table_size > 2 { 1 } else { 0 }; let table_adds = if table_size > 2 { table_size - 3 } else { 0 }; - // point_double costs: 5 mul, 4 add, 2 sub, 1 inv - let double_ops = (4usize, 2usize, 5usize, 1usize); // (add, sub, mul, inv) - // point_add costs: 2 add, 2 sub, 3 mul, 1 inv - let add_ops = (2usize, 2usize, 3usize, 1usize); + let mut total_add = 2 * (table_doubles * double_ops.0 + table_adds * add_ops.0); + let mut total_sub = 2 * (table_doubles * double_ops.1 + table_adds * add_ops.1); + let mut total_mul = 2 * (table_doubles * double_ops.2 + table_adds * add_ops.2); + let mut total_inv = 2 * (table_doubles * double_ops.3 + table_adds * add_ops.3); + + for win_idx in (0..num_windows).rev() { + let bit_start = win_idx * w; + let bit_end = std::cmp::min(bit_start + w, scalar_bits); + let actual_w = bit_end - bit_start; + let actual_selects = (1 << actual_w) - 1; - // Table construction - let mut total_add = table_doubles * double_ops.0 + table_adds * add_ops.0; - let mut total_sub = table_doubles * double_ops.1 + table_adds * add_ops.1; - let mut 
total_mul = table_doubles * double_ops.2 + table_adds * add_ops.2; - let mut total_inv = table_doubles * double_ops.3 + table_adds * add_ops.3; - - // Table lookups: each uses (2^w - 1) point_selects - // point_select = 2 selects = 2 * (3 witnesses: diff, flag*diff, out) per coordinate - // But select is not a field op — it's cheaper (just `select` calls) - // We count it as 2 selects per point_select = 2 sub + 2 mul per select - // Actually select = flag*(on_true - on_false) + on_false: 1 sub, 1 mul, 1 add per elem - // Per point (x,y): 2 sub, 2 mul, 2 add for select - let selects_per_lookup = table_size - 1; // 2^w - 1 point_selects - let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); // (add, sub, mul, inv) - - // MSB window: 1 table lookup (possibly smaller table) - let msb_bits = scalar_bits - (num_windows - 1) * w; - let msb_table_size = 1 << msb_bits; - let msb_selects = msb_table_size - 1; - total_add += msb_selects * select_ops_per_point.0; - total_sub += msb_selects * select_ops_per_point.1; - total_mul += msb_selects * select_ops_per_point.2; - - // Remaining windows: for each of (num_windows - 1) windows: - // - w doublings - // - 1 pack_bits (cheap) - // - 1 is_zero (1 inv + some adds) - // - 1 table lookup - // - 1 sub (for denom) - // - 1 elem_is_zero - // - 1 point_double (for x_eq case) - // - 1 safe_point_add (like point_add but with select on denom) - // - 2 point_selects (x_eq and digit_is_zero) - let remaining = if num_windows > 1 { num_windows - 1 } else { 0 }; - - for _ in 0..remaining { - // w doublings + // w shared doublings total_add += w * double_ops.0; total_sub += w * double_ops.1; total_mul += w * double_ops.2; total_inv += w * double_ops.3; - // table lookup - total_add += selects_per_lookup * select_ops_per_point.0; - total_sub += selects_per_lookup * select_ops_per_point.1; - total_mul += selects_per_lookup * select_ops_per_point.2; - - // denom = sub(looked_up.x, acc.x) - total_sub += 1; - - // elem_is_zero(denom) = 
is_zero per limb + products - // For N limbs: N * (1 inv + some arith) + (N-1) products - // Simplified: 1 inv + 3 witnesses - total_inv += 1; - total_add += 1; - total_mul += 1; - - // point_double for x_eq case - total_add += double_ops.0; - total_sub += double_ops.1; - total_mul += double_ops.2; - total_inv += double_ops.3; - - // safe_point_add: like point_add + 1 select on denom - total_add += add_ops.0 + select_ops_per_point.0 / 2; // 1 select - total_sub += add_ops.1 + select_ops_per_point.1 / 2; - total_mul += add_ops.2 + select_ops_per_point.2 / 2; - total_inv += add_ops.3; - - // 2 point_selects - total_add += 2 * select_ops_per_point.0; - total_sub += 2 * select_ops_per_point.1; - total_mul += 2 * select_ops_per_point.2; - - // is_zero(digit) - total_inv += 1; - total_add += 1; - total_mul += 1; + // Two table lookups + two point_adds + two is_zeros + two point_selects + for _ in 0..2 { + total_add += actual_selects * select_ops_per_point.0; + total_sub += actual_selects * select_ops_per_point.1; + total_mul += actual_selects * select_ops_per_point.2; + + total_add += add_ops.0; + total_sub += add_ops.1; + total_mul += add_ops.2; + total_inv += add_ops.3; + + total_inv += 1; // is_zero + total_add += 1; + total_mul += 1; + + total_add += select_ops_per_point.0; + total_sub += select_ops_per_point.1; + total_mul += select_ops_per_point.2; + } } + // On-curve checks for P and R: each needs 1 mul (y^2), 2 mul (x^2, x^3), 1 mul (a*x), 2 add + total_mul += 8; + total_add += 4; + + // Conditional y-negation: 2 sub + 2 select (for P.y and R.y) + total_sub += 2; + total_add += 2 * select_ops_per_point.0; + total_sub += 2 * select_ops_per_point.1; + total_mul += 2 * select_ops_per_point.2; + (total_add, total_sub, total_mul, total_inv) } @@ -119,18 +90,18 @@ fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize if is_native { // Native: no range checks, just standard R1CS witnesses match op { - FieldOpType::Add => 1, // sum witness - 
FieldOpType::Sub => 1, // sum witness - FieldOpType::Mul => 1, // product witness - FieldOpType::Inv => 1, // inverse witness + FieldOpType::Add => 1, // sum witness + FieldOpType::Sub => 1, // sum witness + FieldOpType::Mul => 1, // product witness + FieldOpType::Inv => 1, // inverse witness } } else if num_limbs == 1 { // Single-limb non-native: reduce_mod_p pattern match op { - FieldOpType::Add => 5, // a+b, m const, k, k*m, result - FieldOpType::Sub => 5, // same - FieldOpType::Mul => 5, // a*b, m const, k, k*m, result - FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check + FieldOpType::Add => 5, // a+b, m const, k, k*m, result + FieldOpType::Sub => 5, // same + FieldOpType::Mul => 5, // a*b, m const, k, k*m, result + FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check } } else { // Multi-limb: N-limb operations @@ -162,41 +133,53 @@ pub fn calculate_msm_witness_cost( ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize) }; - let (n_add, n_sub, n_mul, n_inv) = count_scalar_mul_field_ops(scalar_bits, window_size); - let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, is_native); let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, is_native); let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, is_native); let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, is_native); - let per_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + // FakeGLV path for ALL points: half-width interleaved scalar mul + let half_bits = (scalar_bits + 1) / 2; + let (n_add, n_sub, n_mul, n_inv) = count_glv_field_ops(half_bits, window_size); + let glv_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + + // Per-point overhead: scalar decomposition (2 × half_bits for s1, s2) + + // scalar relation (~150 witnesses) + FakeGLVHint (4 witnesses) + let scalar_decomp = 2 * half_bits + 10; + let scalar_relation = 150; + let glv_hint = 4; - // Scalar 
decomposition: 256 bits (bit witnesses + digital decomposition overhead) - let scalar_decomp = 256 + 10; + // EcScalarMulHint: 2 witnesses per point (only for n_points > 1) + let ec_hint = if n_points > 1 { 2 } else { 0 }; + + let per_point = glv_scalarmul + scalar_decomp + scalar_relation + glv_hint + ec_hint; // Point accumulation: (n_points - 1) point_adds - let accum_per_point = if n_points > 1 { + let accum = if n_points > 1 { let accum_adds = n_points - 1; - accum_adds * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 - + witnesses_per_op(num_limbs, FieldOpType::Sub, is_native) * 2 - + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3 - + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native)) + accum_adds + * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Sub, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3 + + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native)) } else { 0 }; - n_points * (per_scalarmul + scalar_decomp) + accum_per_point + n_points * per_point + accum } /// Check whether schoolbook column equation values fit in the native field. /// -/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` via -/// column equations that include product sums, carry offsets, and outgoing carries. -/// Both sides of each column equation must evaluate to less than the native field -/// modulus as **integers** — if they overflow, the field's modular reduction makes -/// `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking soundness. +/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` +/// via column equations that include product sums, carry offsets, and outgoing +/// carries. 
Both sides of each column equation must evaluate to less than the +/// native field modulus as **integers** — if they overflow, the field's modular +/// reduction makes `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking +/// soundness. /// -/// The maximum integer value across either side of any column equation is bounded by: +/// The maximum integer value across either side of any column equation is +/// bounded by: /// /// `2^(2W + ceil(log2(N)) + 3)` /// @@ -205,8 +188,8 @@ pub fn calculate_msm_witness_cost( /// - The carry offset `2^(2W + ceil(log2(N)) + 1)` (dominant term) /// - Outgoing carry term `2^W * offset_carry` on the RHS /// -/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, the -/// conservative soundness condition is: +/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, +/// the conservative soundness condition is: /// /// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` pub fn column_equation_fits_native_field( @@ -259,8 +242,8 @@ pub fn get_optimal_msm_params( } // Upper bound on search: even with N=2 (best case), we need - // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) / 2. - // The per-candidate soundness check below is the actual gate. + // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) + // / 2. The per-candidate soundness check below is the actual gate. 
let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; let mut best_cost = usize::MAX; let mut best_limb_bits = max_limb_bits.min(86); @@ -268,8 +251,7 @@ pub fn get_optimal_msm_params( // Search space for lb in (8..=max_limb_bits).step_by(2) { - let num_limbs = - ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); + let num_limbs = ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { continue; } @@ -329,15 +311,6 @@ mod tests { assert!(window_size >= 2 && window_size <= 8); } - #[test] - fn test_count_field_ops_sanity() { - let (add, sub, mul, inv) = count_scalar_mul_field_ops(256, 4); - assert!(add > 0); - assert!(sub > 0); - assert!(mul > 0); - assert!(inv > 0); - } - #[test] fn test_column_equation_soundness_boundary() { // For BN254 (254 bits) with N=3: max safe limb_bits is 124. diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index 53a1340f8..07c53891a 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,5 +1,5 @@ use { - ark_ff::{BigInteger, PrimeField}, + ark_ff::{BigInteger, Field, PrimeField}, provekit_common::FieldElement, }; @@ -9,10 +9,15 @@ pub struct CurveParams { pub curve_a: [u64; 4], pub curve_b: [u64; 4], pub generator: ([u64; 4], [u64; 4]), + /// A known non-identity point on the curve, used as the accumulator offset + /// in `scalar_mul_glv`. Must be deterministic and unrelated to typical + /// table entries (we use [2]G). + pub offset_point: ([u64; 4], [u64; 4]), } impl CurveParams { - /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` width each. + /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` + /// width each. 
pub fn p_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { decompose_to_limbs(&self.field_modulus_p, limb_bits, num_limbs) } @@ -23,7 +28,8 @@ impl CurveParams { decompose_to_limbs(&p_minus_1, limb_bits, num_limbs) } - /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` width. + /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` + /// width. pub fn curve_a_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { decompose_to_limbs(&self.curve_a, limb_bits, num_limbs) } @@ -46,15 +52,124 @@ impl CurveParams { self.field_modulus_p == native_mod.0 } - /// Convert modulus to a native field element (only valid when p < native modulus). + /// Convert modulus to a native field element (only valid when p < native + /// modulus). pub fn p_native_fe(&self) -> FieldElement { curve_native_point_fe(&self.field_modulus_p) } + + /// Returns the curve parameter b as a native field element. + pub fn curve_b_fe(&self) -> FieldElement { + curve_native_point_fe(&self.curve_b) + } + + /// Decompose the curve order n into `num_limbs` limbs of `limb_bits` width + /// each. + pub fn curve_order_n_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_order_n, limb_bits, num_limbs) + } + + /// Decompose (curve_order_n - 1) into `num_limbs` limbs of `limb_bits` + /// width each. + pub fn curve_order_n_minus_1_limbs( + &self, + limb_bits: u32, + num_limbs: usize, + ) -> Vec { + let n_minus_1 = sub_one_u64_4(&self.curve_order_n); + decompose_to_limbs(&n_minus_1, limb_bits, num_limbs) + } + + /// Number of bits in the curve order n. + pub fn curve_order_bits(&self) -> u32 { + // Compute bit length directly from raw limbs to avoid reduction + // mod the native field (curve_order_n may exceed the native modulus). 
+ let n = &self.curve_order_n; + for i in (0..4).rev() { + if n[i] != 0 { + return (i as u32) * 64 + (64 - n[i].leading_zeros()); + } + } + 0 + } + + /// Number of bits for the GLV half-scalar: `ceil(order_bits / 2)`. + /// This determines the bit width of the sub-scalars s1, s2 from half-GCD. + pub fn glv_half_bits(&self) -> u32 { + (self.curve_order_bits() + 1) / 2 + } + + /// Decompose the offset point x-coordinate into limbs. + pub fn offset_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point.0, limb_bits, num_limbs) + } + + /// Decompose the offset point y-coordinate into limbs. + pub fn offset_y_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point.1, limb_bits, num_limbs) + } + + /// Compute `[2^n_doublings] * offset_point` on the curve (compile-time + /// only). + /// + /// Used to compute the accumulated offset after the scalar_mul_glv loop: + /// since the accumulator starts at R and gets doubled n times total, the + /// offset to subtract is `[2^n]*R`, not just `R`. + pub fn accumulated_offset(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + if self.is_native_field() { + self.accumulated_offset_native(n_doublings) + } else { + self.accumulated_offset_generic(n_doublings) + } + } + + /// Compute accumulated offset using FieldElement arithmetic (native field). 
+ fn accumulated_offset_native(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let mut x = curve_native_point_fe(&self.offset_point.0); + let mut y = curve_native_point_fe(&self.offset_point.1); + let a = curve_native_point_fe(&self.curve_a); + + for _ in 0..n_doublings { + let x_sq = x * x; + let num = x_sq + x_sq + x_sq + a; + let denom_inv = (y + y).inverse().unwrap(); + let lambda = num * denom_inv; + let x3 = lambda * lambda - x - x; + let y3 = lambda * (x - x3) - y; + x = x3; + y = y3; + } + + (x.into_bigint().0, y.into_bigint().0) + } + + /// Compute accumulated offset using generic 256-bit arithmetic (non-native + /// field). + fn accumulated_offset_generic(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let p = &self.field_modulus_p; + let mut x = self.offset_point.0; + let mut y = self.offset_point.1; + let a = &self.curve_a; + + for _ in 0..n_doublings { + let (x3, y3) = u256_arith::ec_point_double(&x, &y, a, p); + x = x3; + y = y3; + } + + (x, y) + } } /// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width each, /// returned as FieldElements. -fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { +pub fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { + // Special case: when a single limb needs > 128 bits, FieldElement::from(u128) + // would truncate. Use from_sign_and_limbs to preserve the full value. + if num_limbs == 1 && limb_bits > 128 { + return vec![curve_native_point_fe(val)]; + } + let mask: u128 = if limb_bits >= 128 { u128::MAX } else { @@ -104,6 +219,25 @@ pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { FieldElement::from_sign_and_limbs(true, val) } +/// Negate a field element: compute `-val mod p` (i.e., `p - val`). +/// Returns `[0; 4]` when `val` is zero. 
+pub fn negate_field_element(val: &[u64; 4], modulus: &[u64; 4]) -> [u64; 4] { + if *val == [0u64; 4] { + return [0u64; 4]; + } + // val is in [1, p-1], so p - val is in [1, p-1] — no borrow. + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = modulus[i].overflowing_sub(val[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + debug_assert!(!borrow, "negate_field_element: val >= modulus"); + result +} + /// Grumpkin curve parameters. /// /// Grumpkin is a cycle-companion curve for BN254: its base field is the BN254 @@ -120,33 +254,344 @@ pub fn grumpkin_params() -> CurveParams { 0x30644e72e131a029_u64, ], // BN254 base field modulus - curve_order_n: [ + curve_order_n: [ 0x3c208c16d87cfd47_u64, 0x97816a916871ca8d_u64, 0xb85045b68181585d_u64, 0x30644e72e131a029_u64, ], - curve_a: [0; 4], + curve_a: [0; 4], // b = −17 mod p - curve_b: [ + curve_b: [ 0x43e1f593effffff0_u64, 0x2833e84879b97091_u64, 0xb85045b68181585d_u64, 0x30644e72e131a029_u64, ], // Generator G = (1, sqrt(−16) mod p) - generator: ( - [1, 0, 0, 0], + generator: ([1, 0, 0, 0], [ + 0x833fc48d823f272c_u64, + 0x2d270d45f1181294_u64, + 0xcf135e7506a45d63_u64, + 0x0000000000000002_u64, + ]), + // Offset point = [2]G + offset_point: ( + [ + 0x6d8bc688cdbffffe_u64, + 0x19a74caa311e13d4_u64, + 0xddeb49cdaa36306d_u64, + 0x06ce1b0827aafa85_u64, + ], [ - 0x833fc48d823f272c_u64, - 0x2d270d45f1181294_u64, - 0xcf135e7506a45d63_u64, - 0x0000000000000002_u64, + 0x467be7e7a43f80ac_u64, + 0xc93faf6fa1a788bf_u64, + 0x909ede0ba2a6855f_u64, + 0x1c122f81a3a14964_u64, ], ), } } +/// 256-bit modular arithmetic for compile-time EC point computations. +/// Only used to precompute accumulated offset points; not performance-critical. +mod u256_arith { + type U256 = [u64; 4]; + + /// Returns true if a >= b. 
+ fn gte(a: &U256, b: &U256) -> bool { + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true // equal + } + + /// a + b, returns (result, carry). + fn add(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + carry += a[i] as u128 + b[i] as u128; + result[i] = carry as u64; + carry >>= 64; + } + (result, carry != 0) + } + + /// a - b, returns (result, borrow). + fn sub(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + (result, borrow) + } + + /// (a + b) mod p. + pub fn mod_add(a: &U256, b: &U256, p: &U256) -> U256 { + let (s, overflow) = add(a, b); + if overflow || gte(&s, p) { + sub(&s, p).0 + } else { + s + } + } + + /// (a - b) mod p. + fn mod_sub(a: &U256, b: &U256, p: &U256) -> U256 { + let (d, borrow) = sub(a, b); + if borrow { + add(&d, p).0 + } else { + d + } + } + + /// Schoolbook multiplication producing 512-bit result. + fn mul_wide(a: &U256, b: &U256) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[i] as u128) * (b[j] as u128) + result[i + j] as u128 + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 4] = result[i + 4].wrapping_add(carry as u64); + } + result + } + + /// Reduce a 512-bit value mod a 256-bit prime using bit-by-bit long + /// division. 
+ fn mod_reduce_wide(a: &[u64; 8], p: &U256) -> U256 { + let mut total_bits = 0; + for i in (0..8).rev() { + if a[i] != 0 { + total_bits = i * 64 + (64 - a[i].leading_zeros() as usize); + break; + } + } + if total_bits == 0 { + return [0; 4]; + } + + let mut r = [0u64; 4]; + for bit_idx in (0..total_bits).rev() { + // Left shift r by 1 + let overflow = r[3] >> 63; + for j in (1..4).rev() { + r[j] = (r[j] << 1) | (r[j - 1] >> 63); + } + r[0] <<= 1; + + // Insert current bit of a + let word = bit_idx / 64; + let bit = bit_idx % 64; + r[0] |= (a[word] >> bit) & 1; + + // If r >= p (or overflow from shift), subtract p + if overflow != 0 || gte(&r, p) { + r = sub(&r, p).0; + } + } + r + } + + /// (a * b) mod p. + pub fn mod_mul(a: &U256, b: &U256, p: &U256) -> U256 { + let wide = mul_wide(a, b); + mod_reduce_wide(&wide, p) + } + + /// a^exp mod p using square-and-multiply. + fn mod_pow(base: &U256, exp: &U256, p: &U256) -> U256 { + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [1, 0, 0, 0]; + } + + let mut result: U256 = [1, 0, 0, 0]; + let mut base = *base; + for bit_idx in 0..highest_bit { + let word = bit_idx / 64; + let bit = bit_idx % 64; + if (exp[word] >> bit) & 1 == 1 { + result = mod_mul(&result, &base, p); + } + base = mod_mul(&base, &base, p); + } + result + } + + /// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p. + fn mod_inv(a: &U256, p: &U256) -> U256 { + let two: U256 = [2, 0, 0, 0]; + let exp = sub(p, &two).0; + mod_pow(a, &exp, p) + } + + /// EC point addition on y^2 = x^3 + ax + b. + /// Computes (x1,y1) + (x2,y2). Requires x1 != x2. 
+ pub fn ec_point_add(x1: &U256, y1: &U256, x2: &U256, y2: &U256, p: &U256) -> (U256, U256) { + // lambda = (y2 - y1) / (x2 - x1) + let num = mod_sub(y2, y1, p); + let denom = mod_sub(x2, x1, p); + let denom_inv = mod_inv(&denom, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - x1 - x2 + let lambda_sq = mod_mul(&lambda, &lambda, p); + let x1_plus_x2 = mod_add(x1, x2, p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); + + // y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = mod_sub(x1, &x3, p); + let lambda_dx = mod_mul(&lambda, &x1_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y1, p); + + (x3, y3) + } + + /// EC point doubling on y^2 = x^3 + ax + b. + pub fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mod_mul(x, x, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let num = mod_add(&three_x_sq, a, p); + let two_y = mod_add(y, y, p); + let denom_inv = mod_inv(&two_y, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mod_mul(&lambda, &lambda, p); + let two_x = mod_add(x, x, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(x, &x3, p); + let lambda_dx = mod_mul(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y, p); + + (x3, y3) + } +} + +#[cfg(test)] +mod tests { + use {super::*, ark_ff::Field}; + + #[test] + fn test_offset_point_on_curve_grumpkin() { + let c = grumpkin_params(); + let x = curve_native_point_fe(&c.offset_point.0); + let y = curve_native_point_fe(&c.offset_point.1); + let b = curve_native_point_fe(&c.curve_b); + // Grumpkin: y^2 = x^3 + b (a=0) + assert_eq!(y * y, x * x * x + b, "offset point not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_single_double_grumpkin() { + let c = grumpkin_params(); + let (x4, y4) = c.accumulated_offset(1); + let x = curve_native_point_fe(&x4); + let y = 
curve_native_point_fe(&y4); + let b = curve_native_point_fe(&c.curve_b); + // Should still be on curve + assert_eq!(y * y, x * x * x + b, "[4]G not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_native_vs_generic() { + let c = grumpkin_params(); + // Both paths should give the same result + let native = c.accumulated_offset_native(10); + let generic = c.accumulated_offset_generic(10); + assert_eq!(native, generic, "native vs generic mismatch for n=10"); + } + + #[test] + fn test_accumulated_offset_256_on_curve() { + let c = grumpkin_params(); + let (x, y) = c.accumulated_offset(256); + let xfe = curve_native_point_fe(&x); + let yfe = curve_native_point_fe(&y); + let b = curve_native_point_fe(&c.curve_b); + assert_eq!(yfe * yfe, xfe * xfe * xfe + b, "[2^257]G not on Grumpkin"); + } + + #[test] + fn test_offset_point_on_curve_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let x = &c.offset_point.0; + let y = &c.offset_point.1; + let a = &c.curve_a; + let b = &c.curve_b; + // y^2 = x^3 + a*x + b (mod p) + let y_sq = u256_arith::mod_mul(y, y, p); + let x_sq = u256_arith::mod_mul(x, x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, x, p); + let ax = u256_arith::mod_mul(a, x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "offset point not on secp256r1"); + } + + #[test] + fn test_accumulated_offset_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let a = &c.curve_a; + let b = &c.curve_b; + let (x, y) = c.accumulated_offset(256); + // Verify the accumulated offset is on the curve + let y_sq = u256_arith::mod_mul(&y, &y, p); + let x_sq = u256_arith::mod_mul(&x, &x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, &x, p); + let ax = u256_arith::mod_mul(a, &x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "accumulated offset not 
on secp256r1"); + } + + #[test] + fn test_fe_roundtrip() { + // Verify from_sign_and_limbs / into_bigint roundtrip + let val: [u64; 4] = [42, 0, 0, 0]; + let fe = curve_native_point_fe(&val); + let back = fe.into_bigint().0; + assert_eq!(val, back, "roundtrip failed for small value"); + + let val2: [u64; 4] = [ + 0x6d8bc688cdbffffe, + 0x19a74caa311e13d4, + 0xddeb49cdaa36306d, + 0x06ce1b0827aafa85, + ]; + let fe2 = curve_native_point_fe(&val2); + let back2 = fe2.into_bigint().0; + assert_eq!(val2, back2, "roundtrip failed for offset x"); + } +} + pub fn secp256r1_params() -> CurveParams { CurveParams { field_modulus_p: [ @@ -187,5 +632,20 @@ pub fn secp256r1_params() -> CurveParams { 0x4fe342e2fe1a7f9b_u64, ], ), + // Offset point = [2]G + offset_point: ( + [ + 0xa60b48fc47669978_u64, + 0xc08969e277f21b35_u64, + 0x8a52380304b51ac3_u64, + 0x7cf27b188d034f7e_u64, + ], + [ + 0x9e04b79d227873d1_u64, + 0xba7dade63ce98229_u64, + 0x293d9ac69f7430db_u64, + 0x07775510db8ed040_u64, + ], + ), } } diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 14712c78c..8f7172897 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -100,43 +100,6 @@ pub fn point_select( (x, y) } -/// Point addition with safe denominator for the `x1 = x2` edge case. -/// -/// When `x_eq = 1`, the denominator `(x2 - x1)` is zero and cannot be -/// inverted. This function replaces it with 1, producing a satisfiable -/// but meaningless result. The caller MUST discard this result via -/// `point_select` when `x_eq = 1`. -/// -/// The `denom` parameter is the precomputed `x2 - x1`. 
-fn safe_point_add( - ops: &mut F, - x1: F::Elem, - y1: F::Elem, - x2: F::Elem, - y2: F::Elem, - denom: F::Elem, - x_eq: usize, -) -> (F::Elem, F::Elem) { - let numerator = ops.sub(y2, y1); - - // When x_eq=1 (denom=0), substitute with 1 to keep inv satisfiable - let one = ops.constant_one(); - let safe_denom = ops.select(x_eq, denom, one); - - let denom_inv = ops.inv(safe_denom); - let lambda = ops.mul(numerator, denom_inv); - - let lambda_sq = ops.mul(lambda, lambda); - let x1_plus_x2 = ops.add(x1, x2); - let x3 = ops.sub(lambda_sq, x1_plus_x2); - - let x1_minus_x3 = ops.sub(x1, x3); - let lambda_dx = ops.mul(lambda, x1_minus_x3); - let y3 = ops.sub(lambda_dx, y1); - - (x3, y3) -} - /// Builds a point table for windowed scalar multiplication. /// /// T[0] = P (dummy entry, used when window digit = 0) @@ -161,7 +124,8 @@ fn build_point_table( table } -/// Selects T[d] from a point table using bit witnesses, where `d = Σ bits[i] * 2^i`. +/// Selects T[d] from a point table using bit witnesses, where `d = Σ bits[i] * +/// 2^i`. /// /// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, /// halving the candidate set at each level. Total: `(2^w - 1)` point selects @@ -185,100 +149,103 @@ fn table_lookup( current[0] } -/// Windowed scalar multiplication: computes `[scalar] * P`. +/// Interleaved two-point scalar multiplication for FakeGLV. /// -/// Takes pre-decomposed scalar bits (LSB first, `scalar_bits[0]` is the -/// least significant bit) and a window size `w`. Precomputes a table of -/// `2^w` point multiples and processes the scalar in `w`-bit windows from -/// MSB to LSB. +/// Computes `[s1]P + [s2]R` using shared doublings, where s1 and s2 are +/// half-width scalars (typically ~128-bit for 256-bit curves). The +/// accumulator starts at an offset point and the caller checks equality +/// with the accumulated offset to verify the constraint `[s1]P + [s2]R = O`. /// -/// Handles two edge cases: -/// 1. 
**MSB window digit = 0**: The accumulator is initialized from T[0] -/// (a dummy copy of P). An `acc_is_identity` flag tracks that no real -/// point has been accumulated yet. When the first non-zero window digit -/// is encountered, the looked-up point becomes the new accumulator. -/// 2. **x-coordinate collision** (`acc.x == looked_up.x`): Uses -/// `point_double` instead of `point_add`, with `safe_point_add` -/// guarding the zero denominator. +/// Structure per window (from MSB to LSB): +/// 1. `w` shared doublings on accumulator +/// 2. Table lookup in T_P[d1] for s1's window digit +/// 3. point_add(acc, T_P[d1]) + is_zero(d1) + point_select +/// 4. Table lookup in T_R[d2] for s2's window digit +/// 5. point_add(acc, T_R[d2]) + is_zero(d2) + point_select /// -/// The inverse-point case (`acc = -looked_up`, result is infinity) cannot -/// be represented in affine coordinates and remains unsupported — this has -/// negligible probability (~2^{-256}) for random scalars. -pub fn scalar_mul( +/// Returns the final accumulator (x, y). 
+pub fn scalar_mul_glv(
     ops: &mut F,
+    // Point P (table 1)
     px: F::Elem,
     py: F::Elem,
-    scalar_bits: &[usize],
+    s1_bits: &[usize], // half-width (glv_half_bits) bit witnesses for |s1|
+    // Point R (table 2) — the claimed output
+    rx: F::Elem,
+    ry: F::Elem,
+    s2_bits: &[usize], // half-width (glv_half_bits) bit witnesses for |s2|
+    // Shared parameters
     window_size: usize,
+    offset_x: F::Elem,
+    offset_y: F::Elem,
 ) -> (F::Elem, F::Elem) {
-    let n = scalar_bits.len();
+    let n1 = s1_bits.len();
+    let n2 = s2_bits.len();
+    assert_eq!(n1, n2, "s1 and s2 must have the same number of bits");
+    let n = n1;
     let w = window_size;
     let table_size = 1 << w;
 
-    // Build point table: T[i] = [i]P, with T[0] = P as dummy
-    let table = build_point_table(ops, px, py, table_size);
+    // Build point tables: T_P[i] = [i]P, T_R[i] = [i]R
+    let table_p = build_point_table(ops, px, py, table_size);
+    let table_r = build_point_table(ops, rx, ry, table_size);
 
-    // Number of windows (ceiling division)
     let num_windows = (n + w - 1) / w;
 
-    // Process MSB window first (may be shorter than w bits if n % w != 0)
-    let msb_start = (num_windows - 1) * w;
-    let msb_bits = &scalar_bits[msb_start..n];
-    let msb_table = &table[..1 << msb_bits.len()];
-    let mut acc = table_lookup(ops, msb_table, msb_bits);
+    // Initialize accumulator with the offset point
+    let mut acc = (offset_x, offset_y);
 
-    // Track whether acc represents the identity (no real point yet).
-    // When MSB digit = 0, T[0] = P is loaded as a dummy — we must not
-    // double or add it until the first non-zero window digit appears.
-    let msb_digit = ops.pack_bits(msb_bits);
-    let mut acc_is_identity = ops.is_zero(msb_digit);
+    // Process all windows from MSB down to LSB
+    for i in (0..num_windows).rev() {
+        let bit_start = i * w;
+        let bit_end = std::cmp::min(bit_start + w, n);
+        let actual_w = bit_end - bit_start;
 
-    // Process remaining windows from MSB-1 down to LSB
-    for i in (0..num_windows - 1).rev() {
-        // w doublings — only meaningful when acc is a real point.
- // When acc_is_identity=1, the doubling result is garbage but will - // be discarded by the point_select below. + // w shared doublings on the accumulator let mut doubled_acc = acc; for _ in 0..w { doubled_acc = point_double(ops, doubled_acc.0, doubled_acc.1); } - // If acc is identity, keep dummy; otherwise use doubled result - acc = point_select(ops, acc_is_identity, doubled_acc, acc); - - // Table lookup for this window's digit - let window_bits = &scalar_bits[i * w..(i + 1) * w]; - let digit = ops.pack_bits(window_bits); - let digit_is_zero = ops.is_zero(digit); - let looked_up = table_lookup(ops, &table, window_bits); - - // Detect x-coordinate collision: acc.x == looked_up.x - let denom = ops.sub(looked_up.0, acc.0); - let x_eq = ops.elem_is_zero(denom); - - // point_double handles the acc == looked_up case (same point) - let doubled = point_double(ops, acc.0, acc.1); - - // Safe point_add (substitutes denominator when x_eq=1) - let added = safe_point_add( - ops, acc.0, acc.1, looked_up.0, looked_up.1, denom, x_eq, + // --- Process P's window digit (s1) --- + let s1_window_bits = &s1_bits[bit_start..bit_end]; + let lookup_table_p = if actual_w < w { + &table_p[..1 << actual_w] + } else { + &table_p[..] 
+ }; + let looked_up_p = table_lookup(ops, lookup_table_p, s1_window_bits); + let added_p = point_add( + ops, + doubled_acc.0, + doubled_acc.1, + looked_up_p.0, + looked_up_p.1, ); - - // x_eq=0 => use add result, x_eq=1 => use double result - let combined = point_select(ops, x_eq, added, doubled); - - // Four cases based on (acc_is_identity, digit_is_zero): - // (0, 0) => combined — normal add/double - // (0, 1) => acc — keep accumulator - // (1, 0) => looked_up — first real point - // (1, 1) => acc — still identity - let normal_result = point_select(ops, digit_is_zero, combined, acc); - let identity_result = point_select(ops, digit_is_zero, looked_up, acc); - acc = point_select(ops, acc_is_identity, normal_result, identity_result); - - // Update: acc is identity only if it was identity AND digit is zero - acc_is_identity = ops.bool_and(acc_is_identity, digit_is_zero); + let digit_p = ops.pack_bits(s1_window_bits); + let digit_p_is_zero = ops.is_zero(digit_p); + let after_p = point_select(ops, digit_p_is_zero, added_p, doubled_acc); + + // --- Process R's window digit (s2) --- + let s2_window_bits = &s2_bits[bit_start..bit_end]; + let lookup_table_r = if actual_w < w { + &table_r[..1 << actual_w] + } else { + &table_r[..] 
+ }; + let looked_up_r = table_lookup(ops, lookup_table_r, s2_window_bits); + let added_r = point_add( + ops, + after_p.0, + after_p.1, + looked_up_r.0, + looked_up_r.1, + ); + let digit_r = ops.pack_bits(s2_window_bits); + let digit_r_is_zero = ops.is_zero(digit_r); + acc = point_select(ops, digit_r_is_zero, added_r, after_p); } acc } + diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index dda1e064a..826d381c4 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -10,7 +10,7 @@ use { noir_to_r1cs::NoirToR1CSCompiler, }, ark_ff::{AdditiveGroup, Field}, - curve::CurveParams, + curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, multi_limb_ops::{MultiLimbOps, MultiLimbParams}, provekit_common::{ witness::{ConstantOrR1CSWitness, ConstantTerm, SumTerm, WitnessBuilder}, @@ -151,16 +151,8 @@ pub trait FieldOps { /// Does NOT constrain bits to be boolean — caller must ensure that. fn pack_bits(&mut self, bits: &[usize]) -> usize; - /// Checks if a field element (in the curve's base field) is zero. - /// Returns a boolean witness: 1 if zero, 0 if non-zero. - fn elem_is_zero(&mut self, value: Self::Elem) -> usize; - - /// Returns the constant field element 1. - fn constant_one(&mut self) -> Self::Elem; - - /// Computes a * b for two boolean (0/1) native witnesses. - /// Used for boolean AND on flags in scalar_mul. - fn bool_and(&mut self, a: usize, b: usize) -> usize; + /// Returns a constant field element from its limb decomposition. + fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Self::Elem; } // --------------------------------------------------------------------------- @@ -319,8 +311,12 @@ fn add_single_msm( /// Process a full single-MSM with runtime `num_limbs`. /// -/// Handles coordinate decomposition, scalar_mul, accumulation, and -/// output constraining. 
+/// Uses FakeGLV for ALL points: each point P_i with scalar s_i is verified +/// using scalar decomposition and half-width interleaved scalar mul. +/// +/// For `n_points == 1`, R = (out_x, out_y) is the ACIR output. +/// For `n_points > 1`, R_i = EcScalarMulHint witnesses, accumulated via +/// point_add and constrained against the ACIR output. fn process_single_msm<'a>( mut compiler: &'a mut NoirToR1CSCompiler, point_wits: &[usize], @@ -333,79 +329,319 @@ fn process_single_msm<'a>( curve: &CurveParams, ) { let n_points = point_wits.len() / 3; - let mut acc: Option<(Limbs, Limbs)> = None; + let (out_x, out_y, out_inf) = outputs; - for i in 0..n_points { - let px_witness = point_wits[3 * i]; - let py_witness = point_wits[3 * i + 1]; + if n_points == 1 { + // Single-point: R is the ACIR output directly + let px_witness = point_wits[0]; + let py_witness = point_wits[1]; + // Constrain input infinity flag to 0 (affine coordinates cannot represent infinity) + constrain_zero(compiler, point_wits[2]); + let s_lo = scalar_wits[0]; + let s_hi = scalar_wits[1]; + + // Decompose P into limbs + let (px, py) = decompose_point_to_limbs( + compiler, + px_witness, + py_witness, + num_limbs, + limb_bits, + range_checks, + ); + // R = ACIR output, decompose into limbs + let (rx, ry) = decompose_point_to_limbs( + compiler, out_x, out_y, num_limbs, limb_bits, range_checks, + ); - let s_lo = scalar_wits[2 * i]; - let s_hi = scalar_wits[2 * i + 1]; - let scalar_bits = decompose_scalar_bits(compiler, s_lo, s_hi); + (compiler, range_checks) = verify_point_fakeglv( + compiler, + range_checks, + px, + py, + rx, + ry, + s_lo, + s_hi, + num_limbs, + limb_bits, + window_size, + curve, + ); - // Build coordinates as Limbs - let (px, py) = if num_limbs == 1 { - // Single-limb: wrap witness directly - (Limbs::single(px_witness), Limbs::single(py_witness)) - } else { - // Multi-limb: decompose single witness into num_limbs limbs - let px_limbs = decompose_witness_to_limbs( + 
constrain_zero(compiler, out_inf); + } else { + // Multi-point: compute R_i = [s_i]P_i via hints, verify each with FakeGLV, + // then accumulate R_i's and constrain against ACIR output. + let mut acc: Option<(Limbs, Limbs)> = None; + + for i in 0..n_points { + let px_witness = point_wits[3 * i]; + let py_witness = point_wits[3 * i + 1]; + // Constrain input infinity flag to 0 (affine coordinates cannot represent infinity) + constrain_zero(compiler, point_wits[3 * i + 2]); + let s_lo = scalar_wits[2 * i]; + let s_hi = scalar_wits[2 * i + 1]; + + // Add EcScalarMulHint → R_i = [s_i]P_i + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { + output_start: hint_start, + px: px_witness, + py: py_witness, + s_lo, + s_hi, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let rx_witness = hint_start; + let ry_witness = hint_start + 1; + + // Decompose P_i into limbs + let (px, py) = decompose_point_to_limbs( compiler, px_witness, - limb_bits, + py_witness, num_limbs, + limb_bits, range_checks, ); - let py_limbs = decompose_witness_to_limbs( + // Decompose R_i into limbs + let (rx, ry) = decompose_point_to_limbs( compiler, - py_witness, - limb_bits, + rx_witness, + ry_witness, num_limbs, + limb_bits, range_checks, ); - (px_limbs, py_limbs) - }; + // Verify R_i = [s_i]P_i using FakeGLV + (compiler, range_checks) = verify_point_fakeglv( + compiler, + range_checks, + px, + py, + rx, + ry, + s_lo, + s_hi, + num_limbs, + limb_bits, + window_size, + curve, + ); + + // Accumulate R_i via point_add + acc = Some(match acc { + None => (rx, ry), + Some((ax, ay)) => { + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + let sum = ec_points::point_add(&mut ops, ax, ay, rx, ry); + compiler = ops.compiler; + range_checks = ops.range_checks; + sum + } + }); + } + + let (computed_x, computed_y) = acc.expect("MSM must have at least 
one point"); + + if num_limbs == 1 { + constrain_equal(compiler, out_x, computed_x[0]); + constrain_equal(compiler, out_y, computed_y[0]); + } else { + let recomposed_x = recompose_limbs(compiler, computed_x.as_slice(), limb_bits); + let recomposed_y = recompose_limbs(compiler, computed_y.as_slice(), limb_bits); + constrain_equal(compiler, out_x, recomposed_x); + constrain_equal(compiler, out_y, recomposed_y); + } + constrain_zero(compiler, out_inf); + } +} + +/// Decompose a point (px_witness, py_witness) into Limbs. +fn decompose_point_to_limbs( + compiler: &mut NoirToR1CSCompiler, + px_witness: usize, + py_witness: usize, + num_limbs: usize, + limb_bits: u32, + range_checks: &mut BTreeMap>, +) -> (Limbs, Limbs) { + if num_limbs == 1 { + (Limbs::single(px_witness), Limbs::single(py_witness)) + } else { + let px_limbs = + decompose_witness_to_limbs(compiler, px_witness, limb_bits, num_limbs, range_checks); + let py_limbs = + decompose_witness_to_limbs(compiler, py_witness, limb_bits, num_limbs, range_checks); + (px_limbs, py_limbs) + } +} + +/// FakeGLV verification for a single point: verifies R = [s]P. +/// +/// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies +/// [s1]P + [s2]R = O using interleaved windowed scalar mul with +/// half-width scalars. +/// +/// Returns the mutable references back to the caller for continued use. 
+fn verify_point_fakeglv<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + mut range_checks: &'a mut BTreeMap>, + px: Limbs, + py: Limbs, + rx: Limbs, + ry: Limbs, + s_lo: usize, + s_hi: usize, + num_limbs: usize, + limb_bits: u32, + window_size: usize, + curve: &CurveParams, +) -> ( + &'a mut NoirToR1CSCompiler, + &'a mut BTreeMap>, +) { + // --- Step 1: On-curve checks for P and R --- + { let params = build_params(num_limbs, limb_bits, curve); let mut ops = MultiLimbOps { compiler, range_checks, params, }; - let result = ec_points::scalar_mul(&mut ops, px, py, &scalar_bits, window_size); + + let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); + + verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); + verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); + compiler = ops.compiler; range_checks = ops.range_checks; + } + + // --- Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 --- + let glv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::FakeGLVHint { + output_start: glv_start, + s_lo, + s_hi, + curve_order: curve.curve_order_n, + }); + let s1_witness = glv_start; + let s2_witness = glv_start + 1; + let neg1_witness = glv_start + 2; + let neg2_witness = glv_start + 3; + + // neg1 and neg2 are constrained to be boolean by the `select` calls + // in Step 4 below (MultiLimbOps::select calls constrain_boolean internally). 
+ + // --- Step 3: Decompose |s1|, |s2| into half_bits bits each --- + let half_bits = curve.glv_half_bits() as usize; + let s1_bits = decompose_half_scalar_bits(compiler, s1_witness, half_bits); + let s2_bits = decompose_half_scalar_bits(compiler, s2_witness, half_bits); + + // --- Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity check --- + { + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + + // Compute negated y-coordinates: neg_y = 0 - y (mod p) + let zero_limbs = vec![FieldElement::from(0u64); num_limbs]; + let zero = ops.constant_limbs(&zero_limbs); + + let neg_py = ops.sub(zero, py); + let neg_ry = ops.sub(zero, ry); + + // Select: if neg1=1, use neg_py; else use py + let py_effective = ops.select(neg1_witness, py, neg_py); + // Select: if neg2=1, use neg_ry; else use ry + let ry_effective = ops.select(neg2_witness, ry, neg_ry); + + // GLV scalar mul + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + let offset_x = ops.constant_limbs(&offset_x_values); + let offset_y = ops.constant_limbs(&offset_y_values); + + let glv_acc = ec_points::scalar_mul_glv( + &mut ops, + px, + py_effective, + &s1_bits, + rx, + ry_effective, + &s2_bits, + window_size, + offset_x, + offset_y, + ); + + // Identity check: acc should equal [2^(num_windows * window_size)] * offset_point + let glv_num_windows = (half_bits + window_size - 1) / window_size; + let glv_n_doublings = glv_num_windows * window_size; + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); + + let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); + let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); + let expected_x = ops.constant_limbs(&acc_off_x_values); + let expected_y = ops.constant_limbs(&acc_off_y_values); + + for i in 
0..num_limbs { + constrain_equal(ops.compiler, glv_acc.0[i], expected_x[i]); + constrain_equal(ops.compiler, glv_acc.1[i], expected_y[i]); + } - acc = Some(match acc { - None => result, - Some((ax, ay)) => { - let params = build_params(num_limbs, limb_bits, curve); - let mut ops = MultiLimbOps { - compiler, - range_checks, - params, - }; - let sum = ec_points::point_add(&mut ops, ax, ay, result.0, result.1); - compiler = ops.compiler; - range_checks = ops.range_checks; - sum - } - }); + compiler = ops.compiler; + range_checks = ops.range_checks; } - let (computed_x, computed_y) = acc.expect("MSM must have at least one point"); - let (out_x, out_y, out_inf) = outputs; + // --- Step 5: Scalar relation verification --- + verify_scalar_relation( + compiler, + range_checks, + s_lo, + s_hi, + s1_witness, + s2_witness, + neg1_witness, + neg2_witness, + curve, + ); - if num_limbs == 1 { - constrain_equal(compiler, out_x, computed_x[0]); - constrain_equal(compiler, out_y, computed_y[0]); - } else { - let recomposed_x = recompose_limbs(compiler, computed_x.as_slice(), limb_bits); - let recomposed_y = recompose_limbs(compiler, computed_y.as_slice(), limb_bits); - constrain_equal(compiler, out_x, recomposed_x); - constrain_equal(compiler, out_y, recomposed_y); + (compiler, range_checks) +} + +/// On-curve check: verifies y^2 = x^3 + a*x + b for a single point. 
+fn verify_on_curve( + ops: &mut MultiLimbOps, + x: Limbs, + y: Limbs, + b_limb_values: &[FieldElement], + num_limbs: usize, +) { + let y_sq = ops.mul(y, y); + let x_sq = ops.mul(x, x); + let x_cubed = ops.mul(x_sq, x); + let a = ops.curve_a(); + let ax = ops.mul(a, x); + let x3_plus_ax = ops.add(x_cubed, ax); + let b = ops.constant_limbs(b_limb_values); + let rhs = ops.add(x3_plus_ax, b); + for i in 0..num_limbs { + constrain_equal(ops.compiler, y_sq[i], rhs[i]); } - constrain_zero(compiler, out_inf); } /// Decompose a single witness into `num_limbs` limbs using digital @@ -456,26 +692,171 @@ fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitnes } } -/// Decomposes a scalar given as two 128-bit limbs into 256 bit witnesses (LSB -/// first). -fn decompose_scalar_bits( +/// Decomposes a half-scalar witness into `half_bits` bit witnesses (LSB first). +fn decompose_half_scalar_bits( compiler: &mut NoirToR1CSCompiler, - s_lo: usize, - s_hi: usize, + scalar: usize, + half_bits: usize, ) -> Vec { - let log_bases_128 = vec![1usize; 128]; + let log_bases = vec![1usize; half_bits]; + let dd = add_digital_decomposition(compiler, log_bases, vec![scalar]); + let mut bits = Vec::with_capacity(half_bits); + for bit_idx in 0..half_bits { + bits.push(dd.get_digit_witness_index(bit_idx, 0)); + } + bits +} + +/// Builds `MultiLimbParams` for scalar relation verification (mod +/// curve_order_n). +fn build_scalar_relation_params( + num_limbs: usize, + limb_bits: u32, + curve: &CurveParams, +) -> MultiLimbParams { + // Scalar relation uses curve_order_n as the modulus. + // This is always non-native (curve_order_n ≠ BN254 scalar field modulus, + // except for Grumpkin where they're very close but still different). 
+ let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let n_limbs = curve.curve_order_n_limbs(limb_bits, num_limbs); + let n_minus_1_limbs = curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs); - let dd_lo = add_digital_decomposition(compiler, log_bases_128.clone(), vec![s_lo]); - let dd_hi = add_digital_decomposition(compiler, log_bases_128, vec![s_hi]); + // For N=1 non-native, we need the modulus as a FieldElement + let modulus_fe = if num_limbs == 1 { + Some(curve::curve_native_point_fe(&curve.curve_order_n)) + } else { + None + }; - let mut bits = Vec::with_capacity(256); - for bit_idx in 0..128 { - bits.push(dd_lo.get_digit_witness_index(bit_idx, 0)); + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: n_limbs, + p_minus_1_limbs: n_minus_1_limbs, + two_pow_w, + modulus_raw: curve.curve_order_n, + curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + modulus_bits: curve.curve_order_bits(), + is_native: false, // always non-native + modulus_fe, } - for bit_idx in 0..128 { - bits.push(dd_hi.get_digit_witness_index(bit_idx, 0)); +} + +/// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 +/// (mod n). +/// +/// Uses multi-limb arithmetic with curve_order_n as the modulus. +/// The sub-scalars s1, s2 have `half_bits = ceil(order_bits/2)` bits; +/// the full scalar s has up to `order_bits` bits. +fn verify_scalar_relation( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + s_lo: usize, + s_hi: usize, + s1_witness: usize, + s2_witness: usize, + neg1_witness: usize, + neg2_witness: usize, + curve: &CurveParams, +) { + // Use 64-bit limbs. Number of limbs covers the full curve order. 
+ let sr_limb_bits: u32 = 64; + let order_bits = curve.curve_order_bits() as usize; + let sr_num_limbs = (order_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let half_bits = curve.glv_half_bits() as usize; + // Number of 64-bit limbs the half-scalar occupies + let half_limbs = (half_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + + let params = build_scalar_relation_params(sr_num_limbs, sr_limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + + // Decompose s into sr_num_limbs × 64-bit limbs from (s_lo, s_hi) + // s_lo contains bits [0..128), s_hi contains bits [128..256) + let s_limbs = { + let dd_lo = add_digital_decomposition(ops.compiler, vec![64, 64], vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, vec![64, 64], vec![s_hi]); + let mut limbs = Limbs::new(sr_num_limbs); + // s_lo provides limbs 0,1; s_hi provides limbs 2,3 (for sr_num_limbs=4) + let lo_n = 2.min(sr_num_limbs); + for i in 0..lo_n { + limbs[i] = dd_lo.get_digit_witness_index(i, 0); + ops.range_checks.entry(64).or_default().push(limbs[i]); + } + let hi_n = sr_num_limbs - lo_n; + for i in 0..hi_n { + limbs[lo_n + i] = dd_hi.get_digit_witness_index(i, 0); + ops.range_checks + .entry(64) + .or_default() + .push(limbs[lo_n + i]); + } + limbs + }; + + // Helper: decompose a half-scalar witness into sr_num_limbs × 64-bit limbs. + // The half-scalar has `half_bits` bits → occupies `half_limbs` 64-bit limbs. + // Upper limbs (half_limbs..sr_num_limbs) are zero-padded. 
+ let decompose_half_scalar = |ops: &mut MultiLimbOps, witness: usize| -> Limbs { + let dd_bases: Vec = (0..half_limbs) + .map(|i| { + let remaining = half_bits as u32 - (i as u32 * 64); + remaining.min(64) as usize + }) + .collect(); + let dd = add_digital_decomposition(ops.compiler, dd_bases, vec![witness]); + let mut limbs = Limbs::new(sr_num_limbs); + for i in 0..half_limbs { + limbs[i] = dd.get_digit_witness_index(i, 0); + let remaining_bits = (half_bits as u32) - (i as u32 * 64); + let this_limb_bits = remaining_bits.min(64); + ops.range_checks + .entry(this_limb_bits) + .or_default() + .push(limbs[i]); + } + // Zero-pad upper limbs + for i in half_limbs..sr_num_limbs { + let w = ops.compiler.num_witnesses(); + ops.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + FieldElement::from(0u64), + ))); + limbs[i] = w; + constrain_zero(ops.compiler, limbs[i]); + } + limbs + }; + + let s1_limbs = decompose_half_scalar(&mut ops, s1_witness); + let s2_limbs = decompose_half_scalar(&mut ops, s2_witness); + + // Compute product = s2 * s (mod n) + let product = ops.mul(s2_limbs, s_limbs); + + // Handle signs: compute effective values + // If neg2 is set: neg_product = n - product (mod n), i.e. 0 - product + let zero_limbs_vals = vec![FieldElement::from(0u64); sr_num_limbs]; + let zero = ops.constant_limbs(&zero_limbs_vals); + let neg_product = ops.sub(zero, product); + // Select: if neg2=1, use neg_product; else use product + let effective_product = ops.select(neg2_witness, product, neg_product); + + // If neg1 is set: neg_s1 = n - s1 (mod n), i.e. 
0 - s1 + let neg_s1 = ops.sub(zero, s1_limbs); + let effective_s1 = ops.select(neg1_witness, s1_limbs, neg_s1); + + // Sum: effective_s1 + effective_product (mod n) should be 0 + let sum = ops.add(effective_s1, effective_product); + + // Constrain sum == 0: all limbs must be zero + for i in 0..sr_num_limbs { + constrain_zero(ops.compiler, sum[i]); } - bits } /// Constrains two witnesses to be equal: `a - b = 0`. diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs index ab84fc9b7..12c30b382 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -54,10 +54,7 @@ pub fn reduce_mod_p( ); let modulus_bits = modulus.into_bigint().num_bits(); - range_checks - .entry(modulus_bits) - .or_default() - .push(result); + range_checks.entry(modulus_bits).or_default().push(result); result } @@ -169,13 +166,10 @@ pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize )); let is_zero = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::Sum( - is_zero, - vec![ - SumTerm(Some(FieldElement::ONE), compiler.witness_one()), - SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), - ], - )); + compiler.add_witness_builder(WitnessBuilder::Sum(is_zero, vec![ + SumTerm(Some(FieldElement::ONE), compiler.witness_one()), + SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ])); // v × v^(-1) = 1 - is_zero compiler.r1cs.add_constraint( @@ -223,43 +217,47 @@ pub fn add_mod_p_multi( // Witness: q = floor((a + b) / p) ∈ {0, 1} let q = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { - output: q, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, num_limbs: n as u32, }); // q is boolean - 
compiler.r1cs.add_constraint( - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - ); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); let mut r = Limbs::new(n); let mut carry_prev: Option = None; for i in 0..n { // v_offset = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1} + // When carry_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). + let w1_coeff = if carry_prev.is_some() { + two_pow_w - FieldElement::ONE + } else { + two_pow_w + }; let mut terms = vec![ SumTerm(None, a[i]), SumTerm(None, b[i]), - SumTerm(Some(two_pow_w), w1), + SumTerm(Some(w1_coeff), w1), SumTerm(Some(-p_limbs[i]), q), ]; if let Some(carry) = carry_prev { terms.push(SumTerm(None, carry)); - // Compensate for previous 2^W offset - terms.push(SumTerm(Some(-FieldElement::ONE), w1)); } let v_offset = compiler.add_sum(terms); // carry = floor(v_offset / 2^W) let carry = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry, v_offset, two_pow_w, - )); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); // r[i] = v_offset - carry * 2^W r[i] = compiler.add_sum(vec![ SumTerm(None, v_offset), @@ -268,7 +266,14 @@ pub fn add_mod_p_multi( carry_prev = Some(carry); } - less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits); + less_than_p_check_multi( + compiler, + range_checks, + r, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); r } @@ -292,41 +297,46 @@ pub fn sub_mod_p_multi( // Witness: q = (a < b) ? 
1 : 0 let q = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { - output: q, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, num_limbs: n as u32, }); // q is boolean - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - ); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); let mut r = Limbs::new(n); let mut carry_prev: Option = None; for i in 0..n { // v_offset = a[i] - b[i] + q*p[i] + 2^W + carry_{i-1} + // When carry_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). + let w1_coeff = if carry_prev.is_some() { + two_pow_w - FieldElement::ONE + } else { + two_pow_w + }; let mut terms = vec![ SumTerm(None, a[i]), SumTerm(Some(-FieldElement::ONE), b[i]), SumTerm(Some(p_limbs[i]), q), - SumTerm(Some(two_pow_w), w1), + SumTerm(Some(w1_coeff), w1), ]; if let Some(carry) = carry_prev { terms.push(SumTerm(None, carry)); - terms.push(SumTerm(Some(-FieldElement::ONE), w1)); } let v_offset = compiler.add_sum(terms); let carry = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry, v_offset, two_pow_w, - )); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); r[i] = compiler.add_sum(vec![ SumTerm(None, v_offset), SumTerm(Some(-two_pow_w), carry), @@ -334,7 +344,14 @@ pub fn sub_mod_p_multi( carry_prev = Some(carry); } - less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits); + less_than_p_check_multi( + compiler, + range_checks, + r, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); r } @@ -367,9 +384,8 @@ pub fn mul_mod_p_multi( 
let max_bits = 2 * limb_bits + ceil_log2_n + 3; assert!( max_bits < FieldElement::MODULUS_BIT_SIZE, - "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs \ - requires {max_bits} bits, but native field is only {} bits. \ - Use smaller limb_bits.", + "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs requires \ + {max_bits} bits, but native field is only {} bits. Use smaller limb_bits.", FieldElement::MODULUS_BIT_SIZE, ); } @@ -382,18 +398,19 @@ pub fn mul_mod_p_multi( let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); // offset_w = carry_offset * 2^limb_bits let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); - // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits - 1) + // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits + // - 1) let offset_w_minus_carry = offset_w - carry_offset_fe; // Step 1: Allocate hint witnesses (q limbs, r limbs, carries) let os = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbMulModHint { output_start: os, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, - num_limbs: n as u32, + num_limbs: n as u32, }); // q[0..n), r[n..2n), carries[2n..4n-2) @@ -459,7 +476,14 @@ pub fn mul_mod_p_multi( for (i, &ri) in r_indices.iter().enumerate() { r_limbs[i] = ri; } - less_than_p_check_multi(compiler, range_checks, r_limbs, p_minus_1_limbs, two_pow_w, limb_bits); + less_than_p_check_multi( + compiler, + range_checks, + r_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); // Step 5: Range checks for q limbs and carries for i in 0..n { @@ -475,7 +499,8 @@ pub fn mul_mod_p_multi( } /// a^(-1) mod p for multi-limb values. -/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, ..., 0]. 
+/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, +/// ..., 0]. pub fn inv_mod_p_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, @@ -493,14 +518,15 @@ pub fn inv_mod_p_multi( let inv_start = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbModularInverse { output_start: inv_start, - a_limbs: a.as_slice().to_vec(), - modulus: *modulus_raw, + a_limbs: a.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, - num_limbs: n as u32, + num_limbs: n as u32, }); let mut inv = Limbs::new(n); for i in 0..n { inv[i] = inv_start + i; + range_checks.entry(limb_bits).or_default().push(inv[i]); } // Verify: a * inv mod p = [1, 0, ..., 0] @@ -535,7 +561,8 @@ pub fn inv_mod_p_multi( } /// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. -/// Uses borrow propagation: d[i] = (p-1)[i] - r[i] + borrow_in - borrow_out * 2^W +/// Uses borrow propagation: d[i] = (p-1)[i] - r[i] + borrow_in - borrow_out * +/// 2^W fn less_than_p_check_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, @@ -547,25 +574,27 @@ fn less_than_p_check_multi( let n = r.len(); let w1 = compiler.witness_one(); let mut borrow_prev: Option = None; - for i in 0..n { // v_diff = (p-1)[i] + 2^W - r[i] + borrow_prev - let p_minus_1_plus_offset = p_minus_1_limbs[i] + two_pow_w; + // When borrow_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
+ let w1_coeff = if borrow_prev.is_some() { + p_minus_1_limbs[i] + two_pow_w - FieldElement::ONE + } else { + p_minus_1_limbs[i] + two_pow_w + }; let mut terms = vec![ - SumTerm(Some(p_minus_1_plus_offset), w1), + SumTerm(Some(w1_coeff), w1), SumTerm(Some(-FieldElement::ONE), r[i]), ]; if let Some(borrow) = borrow_prev { terms.push(SumTerm(None, borrow)); - terms.push(SumTerm(Some(-FieldElement::ONE), w1)); } let v_diff = compiler.add_sum(terms); // borrow = floor(v_diff / 2^W) let borrow = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - borrow, v_diff, two_pow_w, - )); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(borrow, v_diff, two_pow_w)); // d[i] = v_diff - borrow * 2^W let d_i = compiler.add_sum(vec![ SumTerm(None, v_diff), @@ -579,13 +608,15 @@ fn less_than_p_check_multi( borrow_prev = Some(borrow); } - // Constrain final borrow = 0: if borrow_out != 0, then r > p-1 (i.e. r >= p), - // which would mean the result is not properly reduced. + // Constrain final carry = 1: the 2^W offset at each limb propagates + // a carry of 1 through the chain. For valid r < p, the final carry + // must be exactly 1. If r >= p, the carry chain underflows and the + // final carry is 0. if let Some(final_borrow) = borrow_prev { compiler.r1cs.add_constraint( &[(FieldElement::ONE, compiler.witness_one())], &[(FieldElement::ONE, final_borrow)], - &[(FieldElement::ZERO, compiler.witness_one())], + &[(FieldElement::ONE, compiler.witness_one())], ); } } diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index 4f1c45448..9b1d9db45 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -1,14 +1,11 @@ -//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime limb count. +//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime +//! limb count. //! //! 
Uses `Limbs` (a fixed-capacity Copy type) as `FieldOps::Elem`, enabling //! arbitrary limb counts without const generics or dispatch macros. use { - super::{ - multi_limb_arith, - Limbs, - FieldOps, - }, + super::{multi_limb_arith, FieldOps, Limbs}, crate::noir_to_r1cs::NoirToR1CSCompiler, ark_ff::{AdditiveGroup, Field}, provekit_common::{ @@ -20,18 +17,18 @@ use { /// Parameters for multi-limb field arithmetic. pub struct MultiLimbParams { - pub num_limbs: usize, - pub limb_bits: u32, - pub p_limbs: Vec, - pub p_minus_1_limbs: Vec, - pub two_pow_w: FieldElement, - pub modulus_raw: [u64; 4], - pub curve_a_limbs: Vec, - pub modulus_bits: u32, + pub num_limbs: usize, + pub limb_bits: u32, + pub p_limbs: Vec, + pub p_minus_1_limbs: Vec, + pub two_pow_w: FieldElement, + pub modulus_raw: [u64; 4], + pub curve_a_limbs: Vec, + pub modulus_bits: u32, /// p = native field → skip mod reduction - pub is_native: bool, + pub is_native: bool, /// For N=1 non-native: the modulus as a single FieldElement - pub modulus_fe: Option, + pub modulus_fe: Option, } /// Unified field operations struct parameterized by runtime limb count. @@ -66,20 +63,21 @@ impl FieldOps for MultiLimbOps<'_> { // term with coefficient 2 to avoid duplicate column indices in // the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
let r = if a[0] == b[0] { - self.compiler.add_sum(vec![ - SumTerm(Some(FieldElement::from(2u64)), a[0]), - ]) + self.compiler + .add_sum(vec![SumTerm(Some(FieldElement::from(2u64)), a[0])]) } else { - self.compiler.add_sum(vec![ - SumTerm(None, a[0]), - SumTerm(None, b[0]), - ]) + self.compiler + .add_sum(vec![SumTerm(None, a[0]), SumTerm(None, b[0])]) }; Limbs::single(r) } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); let r = multi_limb_arith::add_mod_p_single( - self.compiler, a[0], b[0], modulus, self.range_checks, + self.compiler, + a[0], + b[0], + modulus, + self.range_checks, ); Limbs::single(r) } else { @@ -104,9 +102,8 @@ impl FieldOps for MultiLimbOps<'_> { // When both operands are the same witness, a - a = 0. Use a // single zero-coefficient term to avoid duplicate column indices. let r = if a[0] == b[0] { - self.compiler.add_sum(vec![ - SumTerm(Some(FieldElement::ZERO), a[0]), - ]) + self.compiler + .add_sum(vec![SumTerm(Some(FieldElement::ZERO), a[0])]) } else { self.compiler.add_sum(vec![ SumTerm(None, a[0]), @@ -117,7 +114,11 @@ impl FieldOps for MultiLimbOps<'_> { } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); let r = multi_limb_arith::sub_mod_p_single( - self.compiler, a[0], b[0], modulus, self.range_checks, + self.compiler, + a[0], + b[0], + modulus, + self.range_checks, ); Limbs::single(r) } else { @@ -144,7 +145,11 @@ impl FieldOps for MultiLimbOps<'_> { } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); let r = multi_limb_arith::mul_mod_p_single( - self.compiler, a[0], b[0], modulus, self.range_checks, + self.compiler, + a[0], + b[0], + modulus, + self.range_checks, ); Limbs::single(r) } else { @@ -177,9 +182,8 @@ impl FieldOps for MultiLimbOps<'_> { Limbs::single(a_inv) } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); - let r = multi_limb_arith::inv_mod_p_single( - self.compiler, a[0], 
modulus, self.range_checks, - ); + let r = + multi_limb_arith::inv_mod_p_single(self.compiler, a[0], modulus, self.range_checks); Limbs::single(r) } else { multi_limb_arith::inv_mod_p_multi( @@ -210,12 +214,7 @@ impl FieldOps for MultiLimbOps<'_> { out } - fn select( - &mut self, - flag: usize, - on_false: Limbs, - on_true: Limbs, - ) -> Limbs { + fn select(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { super::constrain_boolean(self.compiler, flag); let n = self.n(); let mut out = Limbs::new(n); @@ -233,43 +232,21 @@ impl FieldOps for MultiLimbOps<'_> { super::pack_bits_helper(self.compiler, bits) } - fn elem_is_zero(&mut self, value: Limbs) -> usize { - let n = self.n(); - if n == 1 { - multi_limb_arith::compute_is_zero(self.compiler, value[0]) - } else { - // Check each limb is zero and AND the results together - let mut result = multi_limb_arith::compute_is_zero(self.compiler, value[0]); - for i in 1..n { - let limb_zero = multi_limb_arith::compute_is_zero(self.compiler, value[i]); - result = self.compiler.add_product(result, limb_zero); - } - result - } - } - - fn constant_one(&mut self) -> Limbs { + fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Limbs { let n = self.n(); + assert_eq!( + limbs.len(), + n, + "constant_limbs: expected {n} limbs, got {}", + limbs.len() + ); let mut out = Limbs::new(n); - // limb[0] = 1 - let w0 = self.compiler.num_witnesses(); - self.compiler - .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w0, FieldElement::ONE))); - out[0] = w0; - // limb[1..n] = 0 - for i in 1..n { + for i in 0..n { let w = self.compiler.num_witnesses(); self.compiler - .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( - w, - FieldElement::ZERO, - ))); + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, limbs[i]))); out[i] = w; } out } - - fn bool_and(&mut self, a: usize, b: usize) -> usize { - self.compiler.add_product(a, b) - } } diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs 
b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 18bc22ddc..2475ddf8c 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -750,12 +750,19 @@ impl NoirToR1CSCompiler { let native_bits = FieldElement::MODULUS_BIT_SIZE; let curve_bits = curve.modulus_bits(); let (msm_limb_bits, msm_window_size) = if !msm_ops.is_empty() { - let n_points: usize = msm_ops.iter().map(|(pts, _, _)| pts.len() / 3).sum(); + let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); crate::msm::cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256) } else { (native_bits, 4) }; - add_msm(self, msm_ops, msm_limb_bits, msm_window_size, &mut range_checks, &curve); + add_msm( + self, + msm_ops, + msm_limb_bits, + msm_window_size, + &mut range_checks, + &curve, + ); breakdown.msm_constraints = self.r1cs.num_constraints() - constraints_before_msm; breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs index 8643bcfba..b7f39ba6c 100644 --- a/tooling/provekit-bench/tests/compiler.rs +++ b/tooling/provekit-bench/tests/compiler.rs @@ -83,6 +83,7 @@ pub fn compile_workspace(workspace_path: impl AsRef) -> Result #[test_case("../../noir-examples/noir-r1cs-test-programs/bounded-vec")] #[test_case("../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained")] #[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check"; "complete_age_check")] +#[test_case("../../noir-examples/embedded_curve_msm"; "embedded_curve_msm")] fn case(path: &str) { test_compiler(path); } From 8591d1b8d382fdecb956afcfdb8f64bb3af89877 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 5 Mar 2026 05:11:17 +0530 Subject: [PATCH 5/5] opt: added unchecked select for already constrained values in scalar relation verification --- provekit/r1cs-compiler/src/msm/ec_points.rs | 25 ++++++++++++++++--- 
provekit/r1cs-compiler/src/msm/mod.rs | 23 ++++++++++++++--- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 5 +++- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 8f7172897..1b591ed8a 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -100,6 +100,19 @@ pub fn point_select( (x, y) } +/// Conditional point select without boolean constraint on `flag`. +/// Caller must ensure `flag` is already constrained boolean. +fn point_select_unchecked( + ops: &mut F, + flag: usize, + on_false: (F::Elem, F::Elem), + on_true: (F::Elem, F::Elem), +) -> (F::Elem, F::Elem) { + let x = ops.select_unchecked(flag, on_false.0, on_true.0); + let y = ops.select_unchecked(flag, on_false.1, on_true.1); + (x, y) +} + /// Builds a point table for windowed scalar multiplication. /// /// T[0] = P (dummy entry, used when window digit = 0) @@ -130,6 +143,9 @@ fn build_point_table( /// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, /// halving the candidate set at each level. Total: `(2^w - 1)` point selects /// for a table of `2^w` entries. +/// +/// Each bit is constrained boolean exactly once, then all subsequent selects +/// on that bit use the unchecked variant. 
fn table_lookup<F: FieldOps>( ops: &mut F, table: &[(F::Elem, F::Elem)], @@ -139,10 +155,11 @@ let mut current: Vec<(F::Elem, F::Elem)> = table.to_vec(); // Process bits from MSB to LSB for &bit in bits.iter().rev() { + ops.constrain_flag(bit); // constrain boolean once per bit let half = current.len() / 2; let mut next = Vec::with_capacity(half); for i in 0..half { - next.push(point_select(ops, bit, current[i], current[i + half])); + next.push(point_select_unchecked(ops, bit, current[i], current[i + half])); } current = next; } @@ -224,7 +241,8 @@ pub fn scalar_mul_glv( ); let digit_p = ops.pack_bits(s1_window_bits); let digit_p_is_zero = ops.is_zero(digit_p); - let after_p = point_select(ops, digit_p_is_zero, added_p, doubled_acc); + // is_zero already constrains its output boolean; skip redundant check + let after_p = point_select_unchecked(ops, digit_p_is_zero, added_p, doubled_acc); // --- Process R's window digit (s2) --- let s2_window_bits = &s2_bits[bit_start..bit_end]; @@ -243,7 +261,8 @@ ); let digit_r = ops.pack_bits(s2_window_bits); let digit_r_is_zero = ops.is_zero(digit_r); - acc = point_select(ops, digit_r_is_zero, added_r, after_p); + // is_zero already constrains its output boolean; skip redundant check + acc = point_select_unchecked(ops, digit_r_is_zero, added_r, after_p); } acc diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 826d381c4..6bc96b8f9 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -139,9 +139,24 @@ pub trait FieldOps { fn inv(&mut self, a: Self::Elem) -> Self::Elem; fn curve_a(&mut self) -> Self::Elem; + /// Constrains `flag` to be boolean (`flag * flag = flag`). + fn constrain_flag(&mut self, flag: usize); + + /// Conditional select without boolean constraint on `flag`. + /// Caller must ensure `flag` is already constrained boolean. 
+ fn select_unchecked( + &mut self, + flag: usize, + on_false: Self::Elem, + on_true: Self::Elem, + ) -> Self::Elem; + /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). - fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem; + fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem { + self.constrain_flag(flag); + self.select_unchecked(flag, on_false, on_true) + } /// Checks if a BN254 native witness value is zero. /// Returns a boolean witness: 1 if zero, 0 if non-zero. @@ -844,11 +859,13 @@ fn verify_scalar_relation( let zero = ops.constant_limbs(&zero_limbs_vals); let neg_product = ops.sub(zero, product); // Select: if neg2=1, use neg_product; else use product - let effective_product = ops.select(neg2_witness, product, neg_product); + // neg2 already constrained boolean in verify_point_fakeglv + let effective_product = ops.select_unchecked(neg2_witness, product, neg_product); // If neg1 is set: neg_s1 = n - s1 (mod n), i.e. 
0 - s1 let neg_s1 = ops.sub(zero, s1_limbs); - let effective_s1 = ops.select(neg1_witness, s1_limbs, neg_s1); + // neg1 already constrained boolean in verify_point_fakeglv + let effective_s1 = ops.select_unchecked(neg1_witness, s1_limbs, neg_s1); // Sum: effective_s1 + effective_product (mod n) should be 0 let sum = ops.add(effective_s1, effective_product); diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index 9b1d9db45..7ac8d78ac 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -214,8 +214,11 @@ impl FieldOps for MultiLimbOps<'_> { out } - fn select(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { + fn constrain_flag(&mut self, flag: usize) { super::constrain_boolean(self.compiler, flag); + } + + fn select_unchecked(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { let n = self.n(); let mut out = Limbs::new(n); for i in 0..n {