diff --git a/noir-examples/embedded_curve_msm/Nargo.toml b/noir-examples/embedded_curve_msm/Nargo.toml new file mode 100644 index 00000000..ec989161 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "embedded_curve_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml new file mode 100644 index 00000000..da0b3529 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -0,0 +1,71 @@ +# ============================================================ +# MSM test vectors: result = s1 * G + s2 * G +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# Uncomment ONE test case at a time to run. +# ============================================================ + +# === Test 1: Small scalars (1*G + 2*G = 3*G) === +scalar1_lo = "1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" + +# === Test 2: All-zero scalars (0*G + 0*G = point at infinity) === +# scalar1_lo = "0" +# scalar1_hi = "0" +# scalar2_lo = "0" +# scalar2_hi = "0" + +# === Test 3: One zero, one non-zero (0*G + 5*G = 5*G) === +# scalar1_lo = "0" +# scalar1_hi = "0" +# scalar2_lo = "5" +# scalar2_hi = "0" + +# === Test 4: Large lo, small hi (diff ≠ 2^128) === +# scalar1_lo = "64323764613183177041862057485226039389" +# scalar1_hi = "1" +# scalar2_lo = "99999999999999999999999999999999999999" +# scalar2_hi = "3" + +# === Test 5: Small lo, large hi === +# scalar1_lo = "1" +# scalar1_hi = "64323764613183177041862057485226039389" +# scalar2_lo = "2" +# scalar2_hi = "64323764613183177041862057485226039389" + +# === Test 6: Near-max scalars (n-10 and n-20) === +# scalar1_lo = "201385395114098847380338600778089168189" +# scalar1_hi = "64323764613183177041862057485226039389" +# scalar2_lo = "201385395114098847380338600778089168179" +# scalar2_hi = 
"64323764613183177041862057485226039389" + +# === Test 7: Powers of 2 (2^100 and 2^200) === +# scalar1_lo = "1267650600228229401496703205376" +# scalar1_hi = "0" +# scalar2_lo = "0" +# scalar2_hi = "4722366482869645213696" + +# === Test 8: Half curve order (n/2) and 1 === +# scalar1_lo = "270833881017518655421856604104928689827" +# scalar1_hi = "32161882306591588520931028742613019694" +# scalar2_lo = "1" +# scalar2_hi = "0" + +# === Test 9: Large mixed scalars === +# scalar1_lo = "340282366920938463463374607431768211455" +# scalar1_hi = "0" +# scalar2_lo = "170141183460469231731687303715884105727" +# scalar2_hi = "3" + +# === Test 10: Both scalars equal, ~2n/3 === +# scalar1_lo = "247684385716378719408017269662648849284" +# scalar1_hi = "42882509742122118027908038323484026259" +# scalar2_lo = "247684385716378719408017269662648849284" +# scalar2_hi = "42882509742122118027908038323484026259" + +# === Test 11: n - 2, n - 3 (previously failing with [2]G offset) === +# scalar1_lo = "201385395114098847380338600778089168197" +# scalar1_hi = "64323764613183177041862057485226039389" +# scalar2_lo = "201385395114098847380338600778089168196" +# scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/src/main.nr b/noir-examples/embedded_curve_msm/src/main.nr new file mode 100644 index 00000000..19a19318 --- /dev/null +++ b/noir-examples/embedded_curve_msm/src/main.nr @@ -0,0 +1,52 @@ +use std::embedded_curve_ops::{ + EmbeddedCurvePoint, + EmbeddedCurveScalar, + multi_scalar_mul, +}; + +/// Exercises the MultiScalarMul ACIR blackbox with 2 Grumpkin points. +/// Computes s1 * G + s2 * G where G is the Grumpkin generator. 
+fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + // Grumpkin generator + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + + let s1 = EmbeddedCurveScalar { lo: scalar1_lo, hi: scalar1_hi }; + let s2 = EmbeddedCurveScalar { lo: scalar2_lo, hi: scalar2_hi }; + + // MSM: result = s1 * G + s2 * G + let result = multi_scalar_mul([g, g], [s1, s2]); + + // Prevent dead-code elimination - forces the blackbox to be retained + // Using is_infinite as return value ensures the MSM is computed + assert(result.is_infinite == (scalar1_lo + scalar1_hi + scalar2_lo + scalar2_hi == 0)); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify by computing independently: 3*G should match + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + let s3 = EmbeddedCurveScalar { lo: 3, hi: 0 }; + let check = multi_scalar_mul([g], [s3]); + + assert(check.x == expected_x); + assert(check.y == expected_y); +} diff --git a/noir-examples/native_msm/Nargo.toml b/noir-examples/native_msm/Nargo.toml new file mode 100644 index 00000000..5ff116db --- /dev/null +++ b/noir-examples/native_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "native_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/native_msm/Prover.toml b/noir-examples/native_msm/Prover.toml new file mode 100644 index 00000000..58c6933d --- /dev/null +++ b/noir-examples/native_msm/Prover.toml @@ -0,0 +1,5 @@ +# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +scalar1_lo = "1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" diff --git 
a/noir-examples/native_msm/src/main.nr b/noir-examples/native_msm/src/main.nr new file mode 100644 index 00000000..80cfd3d0 --- /dev/null +++ b/noir-examples/native_msm/src/main.nr @@ -0,0 +1,104 @@ +// Grumpkin generator y-coordinate +global GRUMPKIN_GEN_Y: Field = 17631683881184975370165255887551781615748388533673675138860; + +struct Point { + x: Field, + y: Field, + is_infinite: bool, +} + +fn point_double(p: Point) -> Point { + if p.is_infinite | (p.y == 0) { + Point { x: 0, y: 0, is_infinite: true } + } else { + // Grumpkin has a=0, so lambda = 3*x1^2 / (2*y1) + let lambda = (3 * p.x * p.x) / (2 * p.y); + let x3 = lambda * lambda - 2 * p.x; + let y3 = lambda * (p.x - x3) - p.y; + Point { x: x3, y: y3, is_infinite: false } + } +} + +fn point_add(p1: Point, p2: Point) -> Point { + if p1.is_infinite { + p2 + } else if p2.is_infinite { + p1 + } else if (p1.x == p2.x) & (p1.y == p2.y) { + point_double(p1) + } else if (p1.x == p2.x) & (p1.y == (0 - p2.y)) { + Point { x: 0, y: 0, is_infinite: true } + } else { + let lambda = (p2.y - p1.y) / (p2.x - p1.x); + let x3 = lambda * lambda - p1.x - p2.x; + let y3 = lambda * (p1.x - x3) - p1.y; + Point { x: x3, y: y3, is_infinite: false } + } +} + +fn scalar_mul(p: Point, scalar_lo: Field, scalar_hi: Field) -> Point { + let lo_bits: [u1; 128] = scalar_lo.to_le_bits(); + let hi_bits: [u1; 128] = scalar_hi.to_le_bits(); + + // Combine into a single 256-bit array (lo first, then hi) + let mut bits: [u1; 256] = [0; 256]; + for i in 0..128 { + bits[i] = lo_bits[i]; + bits[128 + i] = hi_bits[i]; + } + + // Find the highest set bit + let mut top = 0; + for i in 0..256 { + if bits[i] == 1 { + top = i; + } + } + + // Double-and-add from MSB down to bit 0 + let mut acc = Point { x: 0, y: 0, is_infinite: true }; + for j in 0..256 { + let i = 255 - j; + acc = point_double(acc); + if bits[i] == 1 { + acc = point_add(acc, p); + } + } + + acc +} + +/// Native MSM: computes s1 * G + s2 * G using pure Noir field operations. 
+/// No blackbox functions -- all EC arithmetic is done natively over Grumpkin's +/// base field (= BN254 scalar field = Noir's native Field). +fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; + + let r1 = scalar_mul(g, scalar1_lo, scalar1_hi); + let r2 = scalar_mul(g, scalar2_lo, scalar2_hi); + let result = point_add(r1, r2); + + // Prevent dead-code elimination + assert(!result.is_infinite); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin (known coordinates) + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify 1*G + 2*G = 3*G by computing 3*G directly + let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; + let three_g = scalar_mul(g, 3, 0); + + assert(three_g.x == expected_x); + assert(three_g.y == expected_y); +} diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index a5cbaefd..68bc7b6e 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -78,7 +78,10 @@ impl DependencyInfo { WitnessBuilder::Sum(_, ops) => ops.iter().map(|SumTerm(_, idx)| *idx).collect(), WitnessBuilder::Product(_, a, b) => vec![*a, *b], WitnessBuilder::MultiplicitiesForRange(_, _, values) => values.clone(), - WitnessBuilder::Inverse(_, x) => vec![*x], + WitnessBuilder::Inverse(_, x) + | WitnessBuilder::SafeInverse(_, x) + | WitnessBuilder::ModularInverse(_, x, _) + | WitnessBuilder::IntegerQuotient(_, x, _) => vec![*x], WitnessBuilder::IndexedLogUpDenominator( _, sz, @@ -152,6 +155,28 @@ impl DependencyInfo { } v } + WitnessBuilder::MultiLimbMulModHint { + a_limbs, b_limbs, .. 
+ } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbModularInverse { a_limbs, .. } => a_limbs.clone(), + WitnessBuilder::MultiLimbAddQuotient { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbSubBorrow { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } WitnessBuilder::BytePartition { x, .. } => vec![*x], WitnessBuilder::U32AdditionMulti(_, _, inputs) => inputs @@ -198,6 +223,10 @@ impl DependencyInfo { data.rs_cubed, ] } + WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], + WitnessBuilder::EcScalarMulHint { + px, py, s_lo, s_hi, .. + } => vec![*px, *py, *s_lo, *s_hi], WitnessBuilder::ChunkDecompose { packed, .. } => vec![*packed], WitnessBuilder::SpreadWitness(_, input) => vec![*input], WitnessBuilder::SpreadBitExtract { sum_terms, .. } => { @@ -240,6 +269,9 @@ impl DependencyInfo { | WitnessBuilder::Challenge(idx) | WitnessBuilder::IndexedLogUpDenominator(idx, ..) | WitnessBuilder::Inverse(idx, _) + | WitnessBuilder::SafeInverse(idx, _) + | WitnessBuilder::ModularInverse(idx, ..) + | WitnessBuilder::IntegerQuotient(idx, ..) | WitnessBuilder::ProductLinearOperation(idx, ..) | WitnessBuilder::LogUpDenominator(idx, ..) | WitnessBuilder::LogUpInverse(idx, ..) @@ -282,6 +314,27 @@ impl DependencyInfo { let n = 1usize << *num_bits; (*start..*start + n).collect() } + WitnessBuilder::MultiLimbMulModHint { + output_start, + num_limbs, + .. + } => { + let count = (4 * *num_limbs - 2) as usize; + (*output_start..*output_start + count).collect() + } + WitnessBuilder::MultiLimbModularInverse { + output_start, + num_limbs, + .. + } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::FakeGLVHint { output_start, .. } => { + (*output_start..*output_start + 4).collect() + } + WitnessBuilder::EcScalarMulHint { output_start, .. 
} => { + (*output_start..*output_start + 2).collect() + } + WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], + WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) => { vec![*result_idx, *carry_idx] } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 9503847a..69614411 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -115,6 +115,15 @@ impl WitnessIndexRemapper { WitnessBuilder::Inverse(idx, operand) => { WitnessBuilder::Inverse(self.remap(*idx), self.remap(*operand)) } + WitnessBuilder::SafeInverse(idx, operand) => { + WitnessBuilder::SafeInverse(self.remap(*idx), self.remap(*operand)) + } + WitnessBuilder::ModularInverse(idx, operand, modulus) => { + WitnessBuilder::ModularInverse(self.remap(*idx), self.remap(*operand), *modulus) + } + WitnessBuilder::IntegerQuotient(idx, dividend, divisor) => { + WitnessBuilder::IntegerQuotient(self.remap(*idx), self.remap(*dividend), *divisor) + } WitnessBuilder::ProductLinearOperation( idx, ProductLinearTerm(x, a, b), @@ -215,6 +224,64 @@ impl WitnessIndexRemapper { .collect(), ) } + WitnessBuilder::MultiLimbMulModHint { + output_start, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbMulModHint { + output_start: self.remap(*output_start), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbModularInverse { + output_start, + a_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbModularInverse { + output_start: self.remap(*output_start), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: 
*num_limbs, + }, + WitnessBuilder::MultiLimbAddQuotient { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbAddQuotient { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbSubBorrow { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbSubBorrow { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { lo: self.remap(*lo), hi: self.remap(*hi), @@ -299,6 +366,34 @@ impl WitnessIndexRemapper { }, ) } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => WitnessBuilder::FakeGLVHint { + output_start: self.remap(*output_start), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_order: *curve_order, + }, + WitnessBuilder::EcScalarMulHint { + output_start, + px, + py, + s_lo, + s_hi, + curve_a, + field_modulus_p, + } => WitnessBuilder::EcScalarMulHint { + output_start: self.remap(*output_start), + px: self.remap(*px), + py: self.remap(*py), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + }, WitnessBuilder::ChunkDecompose { output_start, packed, diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 0628fc2e..353f2575 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -88,6 +88,23 @@ pub enum WitnessBuilder { /// The inverse of the value at a specified witness index /// (witness index, operand 
witness index) Inverse(usize, usize), + /// Safe inverse: like Inverse but handles zero by outputting 0. + /// Used by compute_is_zero where the input may be zero. Solved in the + /// Other layer (not batch-inverted), so zero inputs don't poison the batch. + /// (witness index, operand witness index) + SafeInverse(usize, usize), + /// The modular inverse of the value at a specified witness index, modulo + /// a given prime modulus. Computes a^{-1} mod m using Fermat's little + /// theorem (a^{m-2} mod m). Unlike Inverse (BN254 field inverse), this + /// operates as integer modular arithmetic. + /// (witness index, operand witness index, modulus) + ModularInverse(usize, usize, #[serde(with = "serde_ark")] FieldElement), + /// The integer quotient floor(dividend / divisor). Used by reduce_mod to + /// compute k = floor(v / m) so that v = k*m + result with 0 <= result < m. + /// Unlike field multiplication by the inverse, this performs true integer + /// division on the BigInteger representation. + /// (witness index, dividend witness index, divisor constant) + IntegerQuotient(usize, usize, #[serde(with = "serde_ark")] FieldElement), /// Products with linear operations on the witness indices. /// Fields are ProductLinearOperation(witness_idx, (index, a, b), (index, c, /// d)) such that we wish to compute (ax + b) * (cx + d). @@ -189,6 +206,61 @@ pub enum WitnessBuilder { /// Inverse of combined lookup table entry denominator (constant operands). /// Computes: 1 / (sz - lhs - rs*rhs - rs²*and_out - rs³*xor_out) CombinedTableEntryInverse(CombinedTableEntryInverseData), + /// Prover hint for multi-limb modular multiplication: (a * b) mod p. + /// Given inputs a and b as N-limb vectors (each limb `limb_bits` wide), + /// and a constant 256-bit modulus p, computes quotient q, remainder r, + /// and carry witnesses for schoolbook column verification. 
+ /// + /// Outputs (4*num_limbs - 2) witnesses starting at output_start: + /// [0..N) q limbs (quotient) + /// [N..2N) r limbs (remainder) — OUTPUT + /// [2N..4N-2) carry witnesses (unsigned-offset) + MultiLimbMulModHint { + output_start: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb modular inverse: a^{-1} mod p. + /// Given input a as N-limb vector and constant modulus p, + /// computes the inverse via Fermat's little theorem (a^{p-2} mod p). + /// + /// Outputs num_limbs witnesses at output_start: inv limbs. + MultiLimbModularInverse { + output_start: usize, + a_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb addition quotient: q = floor((a + b) / p). + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1}. + /// + /// Outputs 1 witness at output: q. + MultiLimbAddQuotient { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb subtraction borrow: q = (a < b) ? 1 : 0. + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1} indicating whether a borrow (adding p) is needed. + /// + /// Outputs 1 witness at output: q. + MultiLimbSubBorrow { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: /// packed = c0 + c1 * 2^b0 + c2 * 2^(b0+b1) + ... @@ -198,6 +270,37 @@ pub enum WitnessBuilder { packed: usize, chunk_bits: Vec, }, + /// Prover hint for FakeGLV scalar decomposition. 
+ /// Given scalar s (from s_lo + s_hi * 2^128) and curve order n, + /// computes half_gcd(s, n) → (|s1|, |s2|, neg1, neg2) such that: + /// (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n) + /// + /// Outputs 4 witnesses starting at output_start: + /// \[0\] |s1| (128-bit field element) + /// \[1\] |s2| (128-bit field element) + /// \[2\] neg1 (boolean: 0 or 1) + /// \[3\] neg2 (boolean: 0 or 1) + FakeGLVHint { + output_start: usize, + s_lo: usize, + s_hi: usize, + curve_order: [u64; 4], + }, + /// Prover hint for EC scalar multiplication: computes R = \[s\]P. + /// Given point P = (px, py) and scalar s = s_lo + s_hi * 2^128, + /// computes R = \[s\]P on the curve with parameter `curve_a` and + /// field modulus `field_modulus_p`. + /// + /// Outputs 2 witnesses at output_start: R_x, R_y. + EcScalarMulHint { + output_start: usize, + px: usize, + py: usize, + s_lo: usize, + s_hi: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + }, /// Computes spread(input): interleave bits with zeros. /// Output: 0 b_{n-1} 0 b_{n-2} ... 0 b_1 0 b_0 /// (witness index of output, witness index of input) @@ -260,6 +363,10 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, + WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, + WitnessBuilder::MultiLimbModularInverse { num_limbs, .. } => *num_limbs as usize, + WitnessBuilder::FakeGLVHint { .. } => 4, + WitnessBuilder::EcScalarMulHint { .. } => 2, _ => 1, } diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs new file mode 100644 index 00000000..e4ea1fea --- /dev/null +++ b/provekit/prover/src/bigint_mod.rs @@ -0,0 +1,1095 @@ +/// BigInteger modular arithmetic on [u64; 4] limbs (256-bit). 
+/// +/// These helpers compute modular inverse via Fermat's little theorem: +/// a^{-1} = a^{m-2} mod m, using schoolbook multiplication and +/// square-and-multiply exponentiation. + +/// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → +/// 512-bit). +pub fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let product = (a[i] as u128) * (b[j] as u128) + (result[i + j] as u128) + carry; + result[i + j] = product as u64; + carry = product >> 64; + } + result[i + 4] = carry as u64; + } + result +} + +/// Compare 8-limb value with 4-limb value (zero-extended to 8 limbs). +/// Returns Ordering::Greater if wide > narrow, etc. +#[cfg(test)] +fn cmp_wide_narrow(wide: &[u64; 8], narrow: &[u64; 4]) -> std::cmp::Ordering { + // Check high limbs of wide (must all be zero for equality/less) + for i in (4..8).rev() { + if wide[i] != 0 { + return std::cmp::Ordering::Greater; + } + } + // Compare the low 4 limbs + for i in (0..4).rev() { + match wide[i].cmp(&narrow[i]) { + std::cmp::Ordering::Equal => continue, + other => return other, + } + } + std::cmp::Ordering::Equal +} + +/// Modular reduction of a 512-bit value by a 256-bit modulus. +/// Uses bit-by-bit long division. 
+fn reduce_wide(wide: &[u64; 8], modulus: &[u64; 4]) -> [u64; 4] { + // Find the highest set bit in wide + let mut highest_bit = 0; + for i in (0..8).rev() { + if wide[i] != 0 { + highest_bit = i * 64 + (64 - wide[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [0u64; 4]; + } + + // Bit-by-bit long division + // remainder starts at 0, we shift in bits from the dividend + let mut remainder = [0u64; 4]; + for bit_pos in (0..highest_bit).rev() { + // Left-shift remainder by 1 + let carry = shift_left_one(&mut remainder); + debug_assert_eq!(carry, 0, "remainder overflow during shift"); + + // Bring in the next bit from wide + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + let bit = (wide[limb_idx] >> bit_idx) & 1; + remainder[0] |= bit; + + // If remainder >= modulus, subtract + if cmp_4limb(&remainder, modulus) != std::cmp::Ordering::Less { + sub_4limb_inplace(&mut remainder, modulus); + } + } + + remainder +} + +/// Left-shift a 4-limb number by 1 bit. Returns the carry-out bit. +fn shift_left_one(a: &mut [u64; 4]) -> u64 { + let mut carry = 0u64; + for limb in a.iter_mut() { + let new_carry = *limb >> 63; + *limb = (*limb << 1) | carry; + carry = new_carry; + } + carry +} + +/// Compare two 4-limb numbers. +pub fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering { + for i in (0..4).rev() { + match a[i].cmp(&b[i]) { + std::cmp::Ordering::Equal => continue, + other => return other, + } + } + std::cmp::Ordering::Equal +} + +/// Subtract b from a in-place (a -= b). Assumes a >= b. +fn sub_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) { + let mut borrow = 0u64; + for i in 0..4 { + let (diff, borrow1) = a[i].overflowing_sub(b[i]); + let (diff2, borrow2) = diff.overflowing_sub(borrow); + a[i] = diff2; + borrow = (borrow1 as u64) + (borrow2 as u64); + } + debug_assert_eq!(borrow, 0, "subtraction underflow: a < b"); +} + +/// Modular multiplication: (a * b) mod m. 
+pub fn mul_mod(a: &[u64; 4], b: &[u64; 4], m: &[u64; 4]) -> [u64; 4] { + let wide = widening_mul(a, b); + reduce_wide(&wide, m) +} + +/// Modular exponentiation: base^exp mod m using square-and-multiply. +pub fn mod_pow(base: &[u64; 4], exp: &[u64; 4], m: &[u64; 4]) -> [u64; 4] { + // Find highest set bit in exp + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + // exp == 0 → result = 1 (for m > 1) + return [1, 0, 0, 0]; + } + + let mut result = [1u64, 0, 0, 0]; // 1 + for bit_pos in (0..highest_bit).rev() { + // Square + result = mul_mod(&result, &result, m); + // Multiply if bit is set + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + if (exp[limb_idx] >> bit_idx) & 1 == 1 { + result = mul_mod(&result, base, m); + } + } + + result +} + +/// Integer division with remainder: dividend = quotient * divisor + remainder, +/// where 0 <= remainder < divisor. Uses bit-by-bit long division. 
+pub fn divmod(dividend: &[u64; 4], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + // Find the highest set bit in dividend + let mut highest_bit = 0; + for i in (0..4).rev() { + if dividend[i] != 0 { + highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return ([0u64; 4], [0u64; 4]); + } + + let mut quotient = [0u64; 4]; + let mut remainder = [0u64; 4]; + + for bit_pos in (0..highest_bit).rev() { + // Left-shift remainder by 1 + let carry = shift_left_one(&mut remainder); + debug_assert_eq!(carry, 0, "remainder overflow during shift"); + + // Bring in the next bit from dividend + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1; + + // If remainder >= divisor, subtract and set quotient bit + if cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less { + sub_4limb_inplace(&mut remainder, divisor); + quotient[limb_idx] |= 1u64 << bit_idx; + } + } + + (quotient, remainder) +} + +/// Subtract a small u64 value from a 4-limb number. Assumes a >= small. +pub fn sub_u64(a: &[u64; 4], small: u64) -> [u64; 4] { + let mut result = *a; + let (diff, borrow) = result[0].overflowing_sub(small); + result[0] = diff; + if borrow { + for limb in result[1..].iter_mut() { + let (d, b) = limb.overflowing_sub(1); + *limb = d; + if !b { + break; + } + } + } + result +} + +/// Add two 4-limb (256-bit) numbers, returning a 5-limb result with carry. +pub fn add_4limb(a: &[u64; 4], b: &[u64; 4]) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut carry = 0u64; + for i in 0..4 { + let (s1, c1) = a[i].overflowing_add(b[i]); + let (s2, c2) = s1.overflowing_add(carry); + result[i] = s2; + carry = (c1 as u64) + (c2 as u64); + } + result[4] = carry; + result +} + +/// Add two 4-limb numbers in-place: a += b. Returns the carry-out. 
+pub fn add_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) -> u64 { + let mut carry = 0u64; + for i in 0..4 { + let (s1, c1) = a[i].overflowing_add(b[i]); + let (s2, c2) = s1.overflowing_add(carry); + a[i] = s2; + carry = (c1 as u64) + (c2 as u64); + } + carry +} + +/// Subtract b from a in-place, returning true if a >= b (no underflow). +/// If a < b, the result is a += 2^256 - b (wrapping subtraction) and returns +/// false. +pub fn sub_4limb_checked(a: &mut [u64; 4], b: &[u64; 4]) -> bool { + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + a[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + borrow == 0 +} + +/// Returns true if val == 0. +pub fn is_zero(val: &[u64; 4]) -> bool { + val[0] == 0 && val[1] == 0 && val[2] == 0 && val[3] == 0 +} + +/// Compute the bit mask for a limb of the given width. +pub fn limb_mask(limb_bits: u32) -> u128 { + if limb_bits >= 128 { + u128::MAX + } else { + (1u128 << limb_bits) - 1 + } +} + +/// Right-shift a 4-limb (256-bit) value by `bits` positions. +pub fn shr_256(val: &[u64; 4], bits: u32) -> [u64; 4] { + if bits >= 256 { + return [0; 4]; + } + let mut shifted = [0u64; 4]; + let word_shift = (bits / 64) as usize; + let bit_shift = bits % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = val[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= val[i + word_shift + 1] << (64 - bit_shift); + } + } + } + shifted +} + +/// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width. +/// Returns u128 limb values (each < 2^limb_bits). 
+pub fn decompose_to_u128_limbs(val: &[u64; 4], num_limbs: usize, limb_bits: u32) -> Vec { + let mask = limb_mask(limb_bits); + let mut limbs = Vec::with_capacity(num_limbs); + let mut remaining = *val; + for _ in 0..num_limbs { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + limbs.push(lo & mask); + remaining = shr_256(&remaining, limb_bits); + } + limbs +} + +/// Reconstruct a 256-bit value from u128 limb values packed at `limb_bits` +/// boundaries. +pub fn reconstruct_from_u128_limbs(limb_values: &[u128], limb_bits: u32) -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_u128 in limb_values.iter() { + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + if word_start + 2 < 4 && bit_within > 0 { + let upper = limb_u128 >> (128 - bit_within); + if upper > 0 { + val[word_start + 2] |= upper as u64; + } + } + } + bit_offset += limb_bits; + } + val +} + +/// Compute schoolbook carries for a*b = p*q + r verification in base +/// 2^limb_bits. Returns unsigned-offset carries ready to be written as +/// witnesses. 
+pub fn compute_mul_mod_carries( + a_limbs: &[u128], + b_limbs: &[u128], + p_limbs: &[u128], + q_limbs: &[u128], + r_limbs: &[u128], + limb_bits: u32, +) -> Vec { + let n = a_limbs.len(); + let w = limb_bits; + let num_carries = 2 * n - 2; + let carry_offset = 1u128 << (w + ((n as f64).log2().ceil() as u32) + 1); + let mut carries = Vec::with_capacity(num_carries); + let mut carry: i128 = 0; + + for k in 0..(2 * n - 1) { + let mut ab_lo: u128 = 0; + let mut ab_hi: u64 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + let prod = a_limbs[i] * b_limbs[j as usize]; + let (new_lo, ov) = ab_lo.overflowing_add(prod); + ab_lo = new_lo; + if ov { + ab_hi += 1; + } + } + } + let mut pq_lo: u128 = 0; + let mut pq_hi: u64 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + let prod = p_limbs[i] * q_limbs[j as usize]; + let (new_lo, ov) = pq_lo.overflowing_add(prod); + pq_lo = new_lo; + if ov { + pq_hi += 1; + } + } + } + if k < n { + let (new_lo, ov) = pq_lo.overflowing_add(r_limbs[k]); + pq_lo = new_lo; + if ov { + pq_hi += 1; + } + } + + let diff_lo = ab_lo.wrapping_sub(pq_lo); + let borrow = if ab_lo < pq_lo { 1i64 } else { 0 }; + let diff_hi = ab_hi as i64 - pq_hi as i64 - borrow; + + let carry_lo = carry as u128; + let carry_hi: i64 = if carry < 0 { -1 } else { 0 }; + let (total_lo, ov) = diff_lo.overflowing_add(carry_lo); + let total_hi = diff_hi + carry_hi + if ov { 1i64 } else { 0 }; + + if k < 2 * n - 2 { + debug_assert_eq!( + total_lo & ((1u128 << w) - 1), + 0, + "non-zero remainder at column {k}" + ); + carry = total_hi as i128 * (1i128 << (128 - w)) + (total_lo >> w) as i128; + carries.push((carry + carry_offset as i128) as u128); + } + } + + carries +} + +/// Compute the number of bits needed for the half-GCD sub-scalars. +/// Returns `ceil(order_bits / 2)` where `order_bits` is the bit length of `n`. 
+pub fn half_gcd_bits(n: &[u64; 4]) -> u32 { + let mut order_bits = 0u32; + for i in (0..4).rev() { + if n[i] != 0 { + order_bits = (i as u32) * 64 + (64 - n[i].leading_zeros()); + break; + } + } + (order_bits + 1) / 2 +} + +/// Build the threshold value `2^half_bits` as a `[u64; 4]`. +fn build_threshold(half_bits: u32) -> [u64; 4] { + assert!(half_bits <= 255, "half_bits must be <= 255"); + let mut threshold = [0u64; 4]; + let word = (half_bits / 64) as usize; + let bit = half_bits % 64; + threshold[word] = 1u64 << bit; + threshold +} + +/// Half-GCD scalar decomposition for FakeGLV. +/// +/// Given scalar `s` and curve order `n`, finds `(|s1|, |s2|, neg1, neg2)` such +/// that: `(-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n)` +/// +/// Uses the extended GCD on `(n, s)`, stopping when the remainder drops below +/// `2^half_bits` where `half_bits = ceil(order_bits / 2)`. +/// Returns `(val1, val2, neg1, neg2)` where both fit in `half_bits` bits. +pub fn half_gcd(s: &[u64; 4], n: &[u64; 4]) -> ([u64; 4], [u64; 4], bool, bool) { + // Extended GCD on (n, s): + // We track: r_{i} = r_{i-2} - q_i * r_{i-1} + // t_{i} = t_{i-2} - q_i * t_{i-1} + // Starting: r_0 = n, r_1 = s, t_0 = 0, t_1 = 1 + // + // We want: t_i * s ≡ r_i (mod n) [up to sign] + // More precisely: t_i * s ≡ (-1)^{i+1} * r_i (mod n) + // + // The relation we verify is: sign_r * |r_i| + sign_t * |t_i| * s ≡ 0 (mod n) + + // Threshold: 2^half_bits where half_bits = ceil(order_bits / 2) + let half_bits = half_gcd_bits(n); + let threshold = build_threshold(half_bits); + + // r_prev = n, r_curr = s + let mut r_prev = *n; + let mut r_curr = *s; + + // t_prev = 0, t_curr = 1 + let mut t_prev = [0u64; 4]; + let mut t_curr = [1u64, 0, 0, 0]; + + // Track sign of t: t_prev_neg=false (t_0=0, positive), t_curr_neg=false (t_1=1, + // positive) + let mut t_prev_neg = false; + let mut t_curr_neg = false; + + let mut iteration = 0u32; + + loop { + // Check if r_curr < threshold + if cmp_4limb(&r_curr, &threshold) 
== std::cmp::Ordering::Less { + break; + } + + if is_zero(&r_curr) { + break; + } + + // q = r_prev / r_curr, new_r = r_prev % r_curr + let (q, new_r) = divmod(&r_prev, &r_curr); + + // new_t = t_prev + q * t_curr (in terms of absolute values and signs) + // Since the GCD recurrence is: t_{i} = t_{i-2} - q_i * t_{i-1} + // In terms of absolute values with sign tracking: + // If t_prev and q*t_curr have the same sign → subtract magnitudes + // If they have different signs → add magnitudes + // new_t = |t_prev| +/- q * |t_curr|, with sign flips each + // iteration. + // + // The standard extended GCD recurrence gives: + // t_i = t_{i-2} - q_i * t_{i-1} + // We track magnitudes and sign bits separately. + + // Compute q * t_curr + let qt = mul_mod_no_reduce(&q, &t_curr); + + // new_t magnitude and sign: + // In the standard recurrence: new_t_val = t_prev_val - q * t_curr_val + // where t_prev_val = (-1)^t_prev_neg * |t_prev|, etc. + // + // But it's simpler to just track: alternating signs. + // In the half-GCD: t values alternate in sign. So: + // new_t = t_prev + q * t_curr (absolute addition since signs alternate) + let mut new_t = qt; + add_4limb_inplace(&mut new_t, &t_prev); + let new_t_neg = !t_curr_neg; + + r_prev = r_curr; + r_curr = new_r; + t_prev = t_curr; + t_prev_neg = t_curr_neg; + t_curr = new_t; + t_curr_neg = new_t_neg; + iteration += 1; + } + + // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD + // property). + // + // From the extended GCD identity: t_i * s ≡ r_i (mod n) + // Rearranging: -r_i + t_i * s ≡ 0 (mod n) + // + // The circuit checks: (-1)^neg1 * |r_i| + (-1)^neg2 * |t_i| * s ≡ 0 (mod n) + // Since r_i is always non-negative, neg1 must always be true (negate r_i). + // neg2 must match the actual sign of t_i so that (-1)^neg2 * |t_i| = t_i. 
+ + let val1 = r_curr; // |s1| = |r_i| + let val2 = t_curr; // |s2| = |t_i| + + let neg1 = true; // always negate r_i: -r_i + t_i * s ≡ 0 (mod n) + let neg2 = t_curr_neg; + + (val1, val2, neg1, neg2) +} + +/// Multiply two 4-limb values without modular reduction. +/// Returns the lower 4 limbs (ignoring overflow beyond 256 bits). +/// Used internally by half_gcd for q * t_curr where the result is known to fit. +fn mul_mod_no_reduce(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let wide = widening_mul(a, b); + [wide[0], wide[1], wide[2], wide[3]] +} + +// --------------------------------------------------------------------------- +// Modular arithmetic helpers for EC operations (prover-side) +// --------------------------------------------------------------------------- + +/// Modular addition: (a + b) mod p. +pub fn mod_add(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let sum = add_4limb(a, b); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if sum[4] > 0 || cmp_4limb(&sum4, p) != std::cmp::Ordering::Less { + // sum >= p, subtract p + let mut result = sum4; + sub_4limb_inplace(&mut result, p); + result + } else { + sum4 + } +} + +/// Modular subtraction: (a - b) mod p. +pub fn mod_sub(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let mut result = *a; + let no_borrow = sub_4limb_checked(&mut result, b); + if no_borrow { + result + } else { + // a < b, add p to get (a - b + p) + add_4limb_inplace(&mut result, p); + result + } +} + +/// Modular inverse: a^{p-2} mod p (Fermat's little theorem). +pub fn mod_inverse(a: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let exp = sub_u64(p, 2); + mod_pow(a, &exp, p) +} + +/// EC point doubling in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = 2*(px, py). 
+pub fn ec_point_double( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mul_mod(px, px, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let numerator = mod_add(&three_x_sq, a, p); + let two_y = mod_add(py, py, p); + let denom_inv = mod_inverse(&two_y, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mul_mod(&lambda, &lambda, p); + let two_x = mod_add(px, px, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(px, &x3, p); + let lambda_dx = mul_mod(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, py, p); + + (x3, y3) +} + +/// EC point addition in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = (p1x, p1y) + (p2x, p2y). Requires p1x != p2x. +pub fn ec_point_add( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // lambda = (y2 - y1) / (x2 - x1) + let numerator = mod_sub(p2y, p1y, p); + let denominator = mod_sub(p2x, p1x, p); + let denom_inv = mod_inverse(&denominator, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + // x3 = lambda^2 - x1 - x2 + let lambda_sq = mul_mod(&lambda, &lambda, p); + let x1_plus_x2 = mod_add(p1x, p2x, p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); + + // y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = mod_sub(p1x, &x3, p); + let lambda_dx = mul_mod(&lambda, &x1_minus_x3, p); + let y3 = mod_sub(&lambda_dx, p1y, p); + + (x3, y3) +} + +/// EC scalar multiplication via double-and-add: returns \[scalar\]*P. 
+pub fn ec_scalar_mul( + px: &[u64; 4], + py: &[u64; 4], + scalar: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // Find highest set bit in scalar + let mut highest_bit = 0; + for i in (0..4).rev() { + if scalar[i] != 0 { + highest_bit = i * 64 + (64 - scalar[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + // scalar == 0 → point at infinity (not representable in affine) + panic!("ec_scalar_mul: scalar is zero"); + } + + // Start from the MSB-1 and double-and-add + let mut rx = *px; + let mut ry = *py; + + for bit_pos in (0..highest_bit - 1).rev() { + // Double + let (dx, dy) = ec_point_double(&rx, &ry, a, p); + rx = dx; + ry = dy; + + // Add if bit is set + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + if (scalar[limb_idx] >> bit_idx) & 1 == 1 { + let (ax, ay) = ec_point_add(&rx, &ry, px, py, p); + rx = ax; + ry = ay; + } + } + + (rx, ry) +} + +/// Integer division of a 512-bit dividend by a 256-bit divisor. +/// Returns (quotient, remainder) where both fit in 256 bits. +/// Panics if the quotient would exceed 256 bits. +pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + let mut highest_bit = 0; + for i in (0..8).rev() { + if dividend[i] != 0 { + highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return ([0u64; 4], [0u64; 4]); + } + + let mut quotient = [0u64; 4]; + let mut remainder = [0u64; 4]; + + for bit_pos in (0..highest_bit).rev() { + let shift_carry = shift_left_one(&mut remainder); + + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1; + + // If shift_carry is set, the effective remainder is 2^256 + remainder, + // which is always > any 256-bit divisor, so we must subtract. 
+ if shift_carry != 0 || cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less { + // Subtract divisor with inline borrow tracking (handles the case + // where remainder < divisor but shift_carry provides the extra bit). + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = remainder[i].overflowing_sub(divisor[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + remainder[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + // When shift_carry was set, the borrow absorbs it (they cancel out). + debug_assert_eq!( + borrow, shift_carry, + "unexpected borrow in divmod_wide at bit_pos {}", + bit_pos + ); + + assert!(bit_pos < 256, "quotient exceeds 256 bits"); + quotient[bit_pos / 64] |= 1u64 << (bit_pos % 64); + } + } + + (quotient, remainder) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_widening_mul_small() { + // 3 * 7 = 21 + let a = [3, 0, 0, 0]; + let b = [7, 0, 0, 0]; + let result = widening_mul(&a, &b); + assert_eq!(result[0], 21); + assert_eq!(result[1..], [0; 7]); + } + + #[test] + fn test_widening_mul_overflow() { + // u64::MAX * u64::MAX = (2^64-1)^2 = 2^128 - 2^65 + 1 + let a = [u64::MAX, 0, 0, 0]; + let b = [u64::MAX, 0, 0, 0]; + let result = widening_mul(&a, &b); + // (2^64-1)^2 = 0xFFFFFFFFFFFFFFFE_0000000000000001 + assert_eq!(result[0], 1); + assert_eq!(result[1], u64::MAX - 1); + assert_eq!(result[2..], [0; 6]); + } + + #[test] + fn test_reduce_wide_no_reduction() { + // 5 mod 7 = 5 + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [5, 0, 0, 0]); + } + + #[test] + fn test_reduce_wide_basic() { + // 10 mod 7 = 3 + let wide = [10, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [3, 0, 0, 0]); + } + + #[test] + fn test_mul_mod_small() { + // (5 * 3) mod 7 = 15 mod 7 = 1 + let a = [5, 0, 0, 0]; + let b = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mul_mod(&a, &b, &m), [1, 0, 0, 0]); + } + + #[test] + fn 
test_mod_pow_small() { + // 3^4 mod 7 = 81 mod 7 = 4 + let base = [3, 0, 0, 0]; + let exp = [4, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [4, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_small() { + // Inverse of 3 mod 7: 3^{7-2} = 3^5 mod 7 = 243 mod 7 = 5 + // Check: 3 * 5 = 15 = 2*7 + 1 ≡ 1 (mod 7) ✓ + let a = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + let exp = sub_u64(&m, 2); // m - 2 = 5 + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, [5, 0, 0, 0]); + // Verify: a * inv mod m = 1 + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_prime_23() { + // Inverse of 5 mod 23: 5^{21} mod 23 + // 5^{-1} mod 23 = 14 (because 5*14 = 70 = 3*23 + 1) + let a = [5, 0, 0, 0]; + let m = [23, 0, 0, 0]; + let exp = sub_u64(&m, 2); + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, [14, 0, 0, 0]); + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_basic() { + assert_eq!(sub_u64(&[10, 0, 0, 0], 3), [7, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_borrow() { + // [0, 1, 0, 0] = 2^64; subtract 1 → [u64::MAX, 0, 0, 0] + assert_eq!(sub_u64(&[0, 1, 0, 0], 1), [u64::MAX, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_large_prime() { + // Use a 128-bit prime: p = 2^127 - 1 = 170141183460469231731687303715884105727 + // In limbs: [u64::MAX, 2^63 - 1, 0, 0] + let p = [u64::MAX, (1u64 << 63) - 1, 0, 0]; + + // a = 42 + let a = [42, 0, 0, 0]; + let exp = sub_u64(&p, 2); + let inv = mod_pow(&a, &exp, &p); + + // Verify: a * inv mod p = 1 + assert_eq!(mul_mod(&a, &inv, &p), [1, 0, 0, 0]); + } + + #[test] + fn test_cmp_wide_narrow() { + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let narrow = [5, 0, 0, 0]; + assert_eq!(cmp_wide_narrow(&wide, &narrow), std::cmp::Ordering::Equal); + + let wide_greater = [0, 0, 0, 0, 1, 0, 0, 0]; + assert_eq!( + cmp_wide_narrow(&wide_greater, &narrow), + std::cmp::Ordering::Greater + ); + } + + #[test] + fn test_mod_pow_zero_exp() { + // a^0 mod m = 
1
+        let base = [42, 0, 0, 0];
+        let exp = [0, 0, 0, 0];
+        let m = [7, 0, 0, 0];
+        assert_eq!(mod_pow(&base, &exp, &m), [1, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_mod_pow_one_exp() {
+        // a^1 mod m = a mod m
+        let base = [10, 0, 0, 0];
+        let exp = [1, 0, 0, 0];
+        let m = [7, 0, 0, 0];
+        assert_eq!(mod_pow(&base, &exp, &m), [3, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_exact() {
+        // 21 / 7 = 3 remainder 0
+        let (q, r) = divmod(&[21, 0, 0, 0], &[7, 0, 0, 0]);
+        assert_eq!(q, [3, 0, 0, 0]);
+        assert_eq!(r, [0, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_with_remainder() {
+        // 17 / 7 = 2 remainder 3
+        let (q, r) = divmod(&[17, 0, 0, 0], &[7, 0, 0, 0]);
+        assert_eq!(q, [2, 0, 0, 0]);
+        assert_eq!(r, [3, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_smaller_dividend() {
+        // 5 / 7 = 0 remainder 5
+        let (q, r) = divmod(&[5, 0, 0, 0], &[7, 0, 0, 0]);
+        assert_eq!(q, [0, 0, 0, 0]);
+        assert_eq!(r, [5, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_zero_dividend() {
+        let (q, r) = divmod(&[0, 0, 0, 0], &[7, 0, 0, 0]);
+        assert_eq!(q, [0, 0, 0, 0]);
+        assert_eq!(r, [0, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_large() {
+        // 2^64 / 3 = 6148914691236517205 remainder 1
+        // 2^64 in limbs: [0, 1, 0, 0]
+        let (q, r) = divmod(&[0, 1, 0, 0], &[3, 0, 0, 0]);
+        assert_eq!(q, [6148914691236517205, 0, 0, 0]);
+        assert_eq!(r, [1, 0, 0, 0]);
+        // Verify: q * 3 + 1 = 2^64
+        assert_eq!(6148914691236517205u64.wrapping_mul(3).wrapping_add(1), 0u64);
+        // wraps to 0 in u64 = 2^64
+    }
+
+    #[test]
+    fn test_divmod_consistency() {
+        // Verify dividend = quotient * divisor + remainder for various inputs
+        let cases: Vec<([u64; 4], [u64; 4])> = vec![
+            ([100, 0, 0, 0], [7, 0, 0, 0]),
+            ([u64::MAX, 0, 0, 0], [1000, 0, 0, 0]),
+            ([0, 1, 0, 0], [u64::MAX, 0, 0, 0]), // 2^64 / (2^64 - 1)
+        ];
+        for (dividend, divisor) in cases {
+            let (q, r) = divmod(&dividend, &divisor);
+            // Verify: q * divisor + r = dividend
+            let product = widening_mul(&q, &divisor);
+            // Add remainder to product
+            let mut sum = product;
+            let mut carry = 0u128;
+            for i in 0..4 {
+                let s = (sum[i] as u128) + (r[i] as u128) + carry;
+                sum[i] = s as u64;
+                carry = s >> 64;
+            }
+            for i in 4..8 {
+                let s = (sum[i] as u128) + carry;
+                sum[i] = s as u64;
+                carry = s >> 64;
+            }
+            // sum should equal dividend (zero-extended to 8 limbs)
+            let mut expected = [0u64; 8];
+            expected[..4].copy_from_slice(&dividend);
+            assert_eq!(sum, expected, "dividend={dividend:?} divisor={divisor:?}");
+        }
+    }
+
+    #[test]
+    fn test_divmod_wide_small() {
+        // 21 / 7 = 3 remainder 0 (512-bit dividend)
+        let dividend = [21, 0, 0, 0, 0, 0, 0, 0];
+        let divisor = [7, 0, 0, 0];
+        let (q, r) = divmod_wide(&dividend, &divisor);
+        assert_eq!(q, [3, 0, 0, 0]);
+        assert_eq!(r, [0, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_wide_large() {
+        // Compute a * b where a, b are 256-bit, then divide by a
+        // Should give quotient = b, remainder = 0
+        let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; // secp256r1 p
+        let b = [42, 0, 0, 0];
+        let product = widening_mul(&a, &b);
+        let (q, r) = divmod_wide(&product, &a);
+        assert_eq!(q, b);
+        assert_eq!(r, [0, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_wide_with_remainder() {
+        // (a * b + 5) / a = b remainder 5
+        let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001];
+        let b = [100, 0, 0, 0];
+        let mut product = widening_mul(&a, &b);
+        // Add 5
+        let (sum, overflow) = product[0].overflowing_add(5);
+        product[0] = sum;
+        if overflow {
+            for i in 1..8 {
+                let (s, o) = product[i].overflowing_add(1);
+                product[i] = s;
+                if !o {
+                    break;
+                }
+            }
+        }
+        let (q, r) = divmod_wide(&product, &a);
+        assert_eq!(q, b);
+        assert_eq!(r, [5, 0, 0, 0]);
+    }
+
+    #[test]
+    fn test_divmod_wide_consistency() {
+        // Verify: q * divisor + r = dividend
+        let a = [
+            0x123456789abcdef0,
+            0xfedcba9876543210,
+            0x1111111111111111,
+            0x2222222222222222,
+        ];
+        let b = [0xaabbccdd, 0x11223344, 0x55667788, 0x99001122];
+        let product = widening_mul(&a, &b);
+        let divisor = [0xffffffffffffffff, 0xffffffff,
0x0, 0xffffffff00000001]; + let (q, r) = divmod_wide(&product, &divisor); + + // Verify: q * divisor + r = product + let qd = widening_mul(&q, &divisor); + let mut sum = qd; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + assert_eq!(sum, product); + } + + #[test] + fn test_half_gcd_small() { + // s = 42, n = 101 + let s = [42, 0, 0, 0]; + let n = [101, 0, 0, 0]; + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + let sign1: i128 = if neg1 { -1 } else { 1 }; + let sign2: i128 = if neg2 { -1 } else { 1 }; + let v1 = val1[0] as i128; + let v2 = val2[0] as i128; + let s_val = s[0] as i128; + let n_val = n[0] as i128; + let lhs = ((sign1 * v1 + sign2 * v2 * s_val) % n_val + n_val) % n_val; + assert_eq!(lhs, 0, "half_gcd relation failed for small values"); + } + + #[test] + fn test_half_gcd_grumpkin_order() { + // Grumpkin curve order (BN254 base field order) + let n = [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ]; + // Some scalar + let s = [ + 0x123456789abcdef0_u64, + 0xfedcba9876543210_u64, + 0x1111111111111111_u64, + 0x2222222222222222_u64, + ]; + + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // val1 and val2 should be < 2^128 + assert_eq!(val1[2], 0, "val1 should be < 2^128"); + assert_eq!(val1[3], 0, "val1 should be < 2^128"); + assert_eq!(val2[2], 0, "val2 should be < 2^128"); + assert_eq!(val2[3], 0, "val2 should be < 2^128"); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + // Use big integer arithmetic + let term2_full = widening_mul(&val2, &s); + let (_, term2_mod_n) = divmod_wide(&term2_full, &n); + + // Compute: sign1 * val1 + sign2 * term2_mod_n (mod n) + let effective1 = if neg1 { + // n - val1 + let mut 
result = n; + sub_4limb_checked(&mut result, &val1); + result + } else { + val1 + }; + let effective2 = if neg2 { + let mut result = n; + sub_4limb_checked(&mut result, &term2_mod_n); + result + } else { + term2_mod_n + }; + + let sum = add_4limb(&effective1, &effective2); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + // sum might be >= n, so reduce + let (_, remainder) = if sum[4] > 0 { + // Sum overflows 256 bits, need wide divmod + let wide = [sum[0], sum[1], sum[2], sum[3], sum[4], 0, 0, 0]; + divmod_wide(&wide, &n) + } else { + divmod(&sum4, &n) + }; + assert_eq!( + remainder, + [0, 0, 0, 0], + "half_gcd relation failed for Grumpkin order" + ); + } +} diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 44cd0ca0..0fa9133a 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -21,6 +21,7 @@ use { whir::transcript::{codecs::Empty, ProverState, VerifierMessage}, }; +pub(crate) mod bigint_mod; pub mod input_utils; mod r1cs; mod whir_r1cs; @@ -196,6 +197,67 @@ impl Prove for NoirProver { .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; + // DEBUG: Check R1CS constraint satisfaction with ALL witnesses solved + { + use ark_ff::Zero; + let debug_r1cs = r1cs.clone(); + let interner = &debug_r1cs.interner; + let ha = debug_r1cs.a.hydrate(interner); + let hb = debug_r1cs.b.hydrate(interner); + let hc = debug_r1cs.c.hydrate(interner); + let mut fail_count = 0usize; + for row in 0..debug_r1cs.num_constraints() { + let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, + r: usize| + -> FieldElement { + let mut sum = FieldElement::zero(); + for (col, coeff) in hm.iter_row(r) { + sum += coeff * full_witness[col]; + } + sum + }; + let a_val = eval(&ha, row); + let b_val = eval(&hb, row); + let c_val = eval(&hc, row); + if a_val * b_val != c_val { + if fail_count < 10 { + eprintln!( + "CONSTRAINT {} FAILED: A={:?} B={:?} C={:?} A*B={:?}", + row, + a_val, + b_val, 
+ c_val, + a_val * b_val + ); + eprint!(" A terms:"); + for (col, coeff) in ha.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + eprint!(" B terms:"); + for (col, coeff) in hb.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + eprint!(" C terms:"); + for (col, coeff) in hc.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + } + fail_count += 1; + } + } + if fail_count > 0 { + eprintln!( + "TOTAL FAILING CONSTRAINTS: {fail_count} / {}", + debug_r1cs.num_constraints() + ); + } else { + eprintln!("ALL {} CONSTRAINTS SATISFIED", debug_r1cs.num_constraints()); + } + } + let whir_r1cs_proof = self .whir_for_witness .prove_noir(merlin, r1cs, commitments, full_witness, &public_inputs) diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index db91e5e0..87b1105e 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -1,7 +1,7 @@ use { crate::witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, acir::native_types::WitnessMap, - ark_ff::{BigInteger, PrimeField}, + ark_ff::{BigInteger, Field, PrimeField}, ark_std::Zero, provekit_common::{ utils::noir_to_native, @@ -23,6 +23,42 @@ pub trait WitnessBuilderSolver { ); } +/// Resolve a ConstantOrR1CSWitness to its FieldElement value. +fn resolve(witness: &[Option], v: &ConstantOrR1CSWitness) -> FieldElement { + match v { + ConstantOrR1CSWitness::Constant(c) => *c, + ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), + } +} + +/// Convert a u128 value to a FieldElement. +fn u128_to_fe(val: u128) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([val as u64, (val >> 64) as u64, 0, 0])).unwrap() +} + +/// Read witness limbs and reconstruct as [u64; 4]. 
+fn read_witness_limbs(
+    witness: &[Option<FieldElement>],
+    indices: &[usize],
+    limb_bits: u32,
+) -> [u64; 4] {
+    let limb_values: Vec<u128> = indices
+        .iter()
+        .map(|&idx| {
+            let bigint = witness[idx].unwrap().into_bigint().0;
+            bigint[0] as u128 | ((bigint[1] as u128) << 64)
+        })
+        .collect();
+    crate::bigint_mod::reconstruct_from_u128_limbs(&limb_values, limb_bits)
+}
+
+/// Write u128 limb values as FieldElement witnesses starting at `start`.
+fn write_limbs(witness: &mut [Option<FieldElement>], start: usize, vals: &[u128]) {
+    for (i, &val) in vals.iter().enumerate() {
+        witness[start + i] = Some(u128_to_fe(val));
+    }
+}
+
 impl WitnessBuilderSolver for WitnessBuilder {
     fn solve(
         &self,
@@ -65,6 +101,32 @@ impl WitnessBuilderSolver for WitnessBuilder {
                     "Inverse/LogUpInverse should not be called - handled by batch inversion"
                 )
             }
+            WitnessBuilder::SafeInverse(witness_idx, operand_idx) => {
+                let val = witness[*operand_idx].unwrap();
+                witness[*witness_idx] = Some(if val == FieldElement::zero() {
+                    FieldElement::zero()
+                } else {
+                    val.inverse().unwrap()
+                });
+            }
+            WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => {
+                let a = witness[*operand_idx].unwrap();
+                let a_limbs = a.into_bigint().0;
+                let m_limbs = modulus.into_bigint().0;
+                // Fermat's little theorem: a^{-1} = a^{m-2} mod m
+                let exp = crate::bigint_mod::sub_u64(&m_limbs, 2);
+                let result_limbs = crate::bigint_mod::mod_pow(&a_limbs, &exp, &m_limbs);
+                witness[*witness_idx] =
+                    Some(FieldElement::from_bigint(ark_ff::BigInt(result_limbs)).unwrap());
+            }
+            WitnessBuilder::IntegerQuotient(witness_idx, dividend_idx, divisor) => {
+                let dividend = witness[*dividend_idx].unwrap();
+                let d_limbs = dividend.into_bigint().0;
+                let m_limbs = divisor.into_bigint().0;
+                let (quotient, _remainder) = crate::bigint_mod::divmod(&d_limbs, &m_limbs);
+                witness[*witness_idx] =
+                    Some(FieldElement::from_bigint(ark_ff::BigInt(quotient)).unwrap());
+            }
             WitnessBuilder::IndexedLogUpDenominator(
                 witness_idx,
                 sz_challenge,
@@ -145,18
+207,9 @@ impl WitnessBuilderSolver for WitnessBuilder { rhs, output, ) => { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let output = match output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); + let output = resolve(witness, output); witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() - (lhs @@ -175,22 +228,10 @@ impl WitnessBuilderSolver for WitnessBuilder { and_output, xor_output, ) => { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let and_out = match and_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let xor_out = match xor_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); + let and_out = resolve(witness, and_output); + let xor_out = resolve(witness, xor_output); // Encoding: sz - (lhs + rs*rhs + rs²*and_out + rs³*xor_out) witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() @@ -203,18 +244,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::MultiplicitiesForBinOp(witness_idx, atomic_bits, operands) => { let mut multiplicities = vec![0u32; 2usize.pow(2 * *atomic_bits)]; for (lhs, rhs) in operands { - let 
lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); let index = (lhs.into_bigint().0[0] << *atomic_bits) + rhs.into_bigint().0[0]; multiplicities[index as usize] += 1; } @@ -223,14 +254,8 @@ impl WitnessBuilderSolver for WitnessBuilder { } } WitnessBuilder::U32Addition(result_witness_idx, carry_witness_idx, a, b) => { - let a_val = match a { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; - let b_val = match b { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; + let a_val = resolve(witness, a); + let b_val = resolve(witness, b); assert!( a_val.into_bigint().num_bits() <= 32, "a_val must be less than or equal to 32 bits, got {}", @@ -258,12 +283,7 @@ impl WitnessBuilderSolver for WitnessBuilder { // Sum all inputs as u64 to handle overflow. 
let mut sum: u64 = 0; for input in inputs { - let val = match input { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(idx) => { - witness[*idx].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, input).into_bigint().0[0]; assert!(val < (1u64 << 32), "input must be 32-bit"); sum += val; } @@ -274,14 +294,8 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*carry_witness_idx] = Some(FieldElement::from(quotient)); } WitnessBuilder::And(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, got {}", @@ -297,14 +311,8 @@ impl WitnessBuilderSolver for WitnessBuilder { )); } WitnessBuilder::Xor(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, got {}", @@ -319,6 +327,107 @@ impl WitnessBuilderSolver for WitnessBuilder { lh_val.into_bigint() ^ rh_val.into_bigint(), )); } + WitnessBuilder::MultiLimbMulModHint { + output_start, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => { + use crate::bigint_mod::{ + compute_mul_mod_carries, decompose_to_u128_limbs, divmod_wide, widening_mul, + 
}; + let n = *num_limbs as usize; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, modulus); + + let q_limbs_vals = decompose_to_u128_limbs(&q_val, n, w); + let r_limbs_vals = decompose_to_u128_limbs(&r_val, n, w); + + let carries = compute_mul_mod_carries( + &decompose_to_u128_limbs(&a_val, n, w), + &decompose_to_u128_limbs(&b_val, n, w), + &decompose_to_u128_limbs(modulus, n, w), + &q_limbs_vals, + &r_limbs_vals, + w, + ); + + write_limbs(witness, *output_start, &q_limbs_vals); + write_limbs(witness, *output_start + n, &r_limbs_vals); + write_limbs(witness, *output_start + 2 * n, &carries); + } + WitnessBuilder::MultiLimbModularInverse { + output_start, + a_limbs, + modulus, + limb_bits, + num_limbs, + } => { + use crate::bigint_mod::{decompose_to_u128_limbs, mod_pow, sub_u64}; + let n = *num_limbs as usize; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let exp = sub_u64(modulus, 2); + let inv = mod_pow(&a_val, &exp, modulus); + write_limbs(witness, *output_start, &decompose_to_u128_limbs(&inv, n, w)); + } + WitnessBuilder::MultiLimbAddQuotient { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + .. + } => { + use crate::bigint_mod::{add_4limb, cmp_4limb}; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let sum = add_4limb(&a_val, &b_val); + let q = if sum[4] > 0 { + 1u64 + } else { + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if cmp_4limb(&sum4, modulus) != std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + } + }; + + witness[*output] = Some(FieldElement::from(q)); + } + WitnessBuilder::MultiLimbSubBorrow { + output, + a_limbs, + b_limbs, + limb_bits, + .. 
+ } => { + use crate::bigint_mod::cmp_4limb; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + }; + + witness[*output] = Some(FieldElement::from(q)); + } WitnessBuilder::BytePartition { lo, hi, x, k } => { let x_val = witness[*x].unwrap().into_bigint().0[0]; debug_assert!(x_val < 256, "BytePartition input must be 8-bit"); @@ -330,6 +439,58 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*lo] = Some(FieldElement::from(lo_val)); witness[*hi] = Some(FieldElement::from(hi_val)); } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => { + // Reconstruct s = s_lo + s_hi * 2^128 + let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; + let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; + let s_val: [u64; 4] = [s_lo_val[0], s_lo_val[1], s_hi_val[0], s_hi_val[1]]; + + let (val1, val2, neg1, neg2) = crate::bigint_mod::half_gcd(&s_val, curve_order); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(val1)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(val2)).unwrap()); + witness[*output_start + 2] = Some(FieldElement::from(neg1 as u64)); + witness[*output_start + 3] = Some(FieldElement::from(neg2 as u64)); + } + WitnessBuilder::EcScalarMulHint { + output_start, + px, + py, + s_lo, + s_hi, + curve_a, + field_modulus_p, + } => { + // Reconstruct scalar s = s_lo + s_hi * 2^128 + let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; + let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; + let scalar: [u64; 4] = [s_lo_val[0], s_lo_val[1], s_hi_val[0], s_hi_val[1]]; + + // Reconstruct point P + let px_val = witness[*px].unwrap().into_bigint().0; + let py_val = witness[*py].unwrap().into_bigint().0; + + // Compute R = [s]P + let (rx, ry) = crate::bigint_mod::ec_scalar_mul( + 
&px_val, + &py_val, + &scalar, + curve_a, + field_modulus_p, + ); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(rx)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(ry)).unwrap()); + } WitnessBuilder::CombinedTableEntryInverse(..) => { unreachable!( "CombinedTableEntryInverse should not be called - handled by batch inversion" @@ -393,12 +554,7 @@ impl WitnessBuilderSolver for WitnessBuilder { let table_size = 1usize << *num_bits; let mut multiplicities = vec![0u32; table_size]; for query in queries { - let val = match query { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(w) => { - witness[*w].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, query).into_bigint().0[0]; multiplicities[val as usize] += 1; } for (i, count) in multiplicities.iter().enumerate() { @@ -408,14 +564,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::SpreadLookupDenominator(idx, sz, rs, input, spread_output) => { let sz_val = witness[*sz].unwrap(); let rs_val = witness[*rs].unwrap(); - let input_val = match input { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; - let spread_val = match spread_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; + let input_val = resolve(witness, input); + let spread_val = resolve(witness, spread_output); // sz - (input + rs * spread_output) witness[*idx] = Some(sz_val - (input_val + rs_val * spread_val)); } diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 91c4e412..657f7bd7 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -1,5 +1,6 @@ use { crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::Field, ark_std::One, provekit_common::{ witness::{DigitalDecompositionWitnesses, WitnessBuilder}, @@ 
-66,7 +67,8 @@ pub(crate) fn add_digital_decomposition( // Add the constraints for the digital recomposition let mut digit_multipliers = vec![FieldElement::one()]; for log_base in log_bases[..log_bases.len() - 1].iter() { - let multiplier = *digit_multipliers.last().unwrap() * FieldElement::from(1u64 << *log_base); + let multiplier = + *digit_multipliers.last().unwrap() * FieldElement::from(2u64).pow([*log_base as u64]); digit_multipliers.push(multiplier); } dd_struct diff --git a/provekit/r1cs-compiler/src/lib.rs b/provekit/r1cs-compiler/src/lib.rs index 7de8f899..a1ee6b1c 100644 --- a/provekit/r1cs-compiler/src/lib.rs +++ b/provekit/r1cs-compiler/src/lib.rs @@ -1,6 +1,7 @@ mod binops; mod digits; mod memory; +mod msm; mod noir_proof_scheme; mod noir_to_r1cs; mod poseidon2; diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs new file mode 100644 index 00000000..8a42735b --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -0,0 +1,337 @@ +//! Analytical cost model for MSM parameter optimization. +//! +//! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): +//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, +//! window_size). + +/// Type of field operation for cost estimation. +#[derive(Clone, Copy)] +pub enum FieldOpType { + Add, + Sub, + Mul, + Inv, +} + +/// Count field ops in scalar_mul_glv for given parameters. +/// +/// The GLV approach does interleaved two-point scalar mul with half-width +/// scalars. Per window: w shared doubles + 2 table lookups + 2 point_adds + 2 +/// is_zero + 2 point_selects Plus: 2 table builds, on-curve check, scalar +/// relation overhead. 
+fn count_glv_field_ops( + scalar_bits: usize, // half_bits = ceil(order_bits / 2) + window_size: usize, +) -> (usize, usize, usize, usize) { + let w = window_size; + let table_size = 1 << w; + let num_windows = (scalar_bits + w - 1) / w; + + let double_ops = (4usize, 2usize, 5usize, 1usize); + let add_ops = (2usize, 2usize, 3usize, 1usize); + let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); + + // Two tables (one for P, one for R) + let table_doubles = if table_size > 2 { 1 } else { 0 }; + let table_adds = if table_size > 2 { table_size - 3 } else { 0 }; + + let mut total_add = 2 * (table_doubles * double_ops.0 + table_adds * add_ops.0); + let mut total_sub = 2 * (table_doubles * double_ops.1 + table_adds * add_ops.1); + let mut total_mul = 2 * (table_doubles * double_ops.2 + table_adds * add_ops.2); + let mut total_inv = 2 * (table_doubles * double_ops.3 + table_adds * add_ops.3); + + for win_idx in (0..num_windows).rev() { + let bit_start = win_idx * w; + let bit_end = std::cmp::min(bit_start + w, scalar_bits); + let actual_w = bit_end - bit_start; + let actual_selects = (1 << actual_w) - 1; + + // w shared doublings + total_add += w * double_ops.0; + total_sub += w * double_ops.1; + total_mul += w * double_ops.2; + total_inv += w * double_ops.3; + + // Two table lookups + two point_adds + two is_zeros + two point_selects + for _ in 0..2 { + total_add += actual_selects * select_ops_per_point.0; + total_sub += actual_selects * select_ops_per_point.1; + total_mul += actual_selects * select_ops_per_point.2; + + total_add += add_ops.0; + total_sub += add_ops.1; + total_mul += add_ops.2; + total_inv += add_ops.3; + + total_inv += 1; // is_zero + total_add += 1; + total_mul += 1; + + total_add += select_ops_per_point.0; + total_sub += select_ops_per_point.1; + total_mul += select_ops_per_point.2; + } + } + + // On-curve checks for P and R: each needs 1 mul (y^2), 2 mul (x^2, x^3), 1 mul + // (a*x), 2 add + total_mul += 8; + total_add += 4; + + // 
Conditional y-negation: 2 sub + 2 select (for P.y and R.y) + total_sub += 2; + total_add += 2 * select_ops_per_point.0; + total_sub += 2 * select_ops_per_point.1; + total_mul += 2 * select_ops_per_point.2; + + (total_add, total_sub, total_mul, total_inv) +} + +/// Witnesses per single N-limb field operation. +fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize { + if is_native { + // Native: no range checks, just standard R1CS witnesses + match op { + FieldOpType::Add => 1, // sum witness + FieldOpType::Sub => 1, // sum witness + FieldOpType::Mul => 1, // product witness + FieldOpType::Inv => 1, // inverse witness + } + } else if num_limbs == 1 { + // Single-limb non-native: reduce_mod_p pattern + match op { + FieldOpType::Add => 5, // a+b, m const, k, k*m, result + FieldOpType::Sub => 5, // same + FieldOpType::Mul => 5, // a*b, m const, k, k*m, result + FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check + } + } else { + // Multi-limb: N-limb operations + let n = num_limbs; + match op { + // add/sub: q + N*(v_offset, carry, r_limb) + N*(v_diff, borrow, d_limb) + FieldOpType::Add | FieldOpType::Sub => 1 + 3 * n + 3 * n, + // mul: hint(4N-2) + N² products + 2N-1 column constraints + lt_check + FieldOpType::Mul => (4 * n - 2) + n * n + 3 * n, + // inv: hint(N) + mul costs + FieldOpType::Inv => n + (4 * n - 2) + n * n + 3 * n, + } + } +} + +/// Total estimated witness cost for one scalar_mul. 
+pub fn calculate_msm_witness_cost( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + window_size: usize, + limb_bits: u32, +) -> usize { + let is_native = curve_modulus_bits == native_field_bits; + let num_limbs = if is_native { + 1 + } else { + ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize) + }; + + let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, is_native); + let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, is_native); + let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, is_native); + let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, is_native); + + // FakeGLV path for ALL points: half-width interleaved scalar mul + let half_bits = (scalar_bits + 1) / 2; + let (n_add, n_sub, n_mul, n_inv) = count_glv_field_ops(half_bits, window_size); + let glv_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + + // Per-point overhead: scalar decomposition (2 × half_bits for s1, s2) + + // scalar relation (~150 witnesses) + FakeGLVHint (4 witnesses) + let scalar_decomp = 2 * half_bits + 10; + let scalar_relation = 150; + let glv_hint = 4; + + // EcScalarMulHint: 2 witnesses per point (only for n_points > 1) + let ec_hint = if n_points > 1 { 2 } else { 0 }; + + let per_point = glv_scalarmul + scalar_decomp + scalar_relation + glv_hint + ec_hint; + + // Point accumulation: (n_points - 1) point_adds + let accum = if n_points > 1 { + let accum_adds = n_points - 1; + accum_adds + * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Sub, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3 + + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native)) + } else { + 0 + }; + + n_points * per_point + accum +} + +/// Check whether schoolbook column equation values fit in the native field. 
+/// +/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` +/// via column equations that include product sums, carry offsets, and outgoing +/// carries. Both sides of each column equation must evaluate to less than the +/// native field modulus as **integers** — if they overflow, the field's modular +/// reduction makes `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking +/// soundness. +/// +/// The maximum integer value across either side of any column equation is +/// bounded by: +/// +/// `2^(2W + ceil(log2(N)) + 3)` +/// +/// where `W = limb_bits` and `N = num_limbs`. This accounts for: +/// - Up to N cross-products per column, each < `2^(2W)` +/// - The carry offset `2^(2W + ceil(log2(N)) + 1)` (dominant term) +/// - Outgoing carry term `2^W * offset_carry` on the RHS +/// +/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, +/// the conservative soundness condition is: +/// +/// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` +pub fn column_equation_fits_native_field( + native_field_bits: u32, + limb_bits: u32, + num_limbs: usize, +) -> bool { + if num_limbs <= 1 { + return true; // Single-limb path has no column equations. + } + let ceil_log2_n = (num_limbs as f64).log2().ceil() as u32; + // Max column value < 2^(2*limb_bits + ceil_log2_n + 3). + // Need this < p_native >= 2^(native_field_bits - 1). + 2 * limb_bits + ceil_log2_n + 3 < native_field_bits +} + +/// Search for optimal (limb_bits, window_size) minimizing witness cost. +/// +/// Searches limb_bits ∈ [8..max] and window_size ∈ [2..8]. +/// Each candidate is checked for column equation soundness: the schoolbook +/// multiplication's intermediate values must fit in the native field without +/// modular wraparound (see [`column_equation_fits_native_field`]). 
+pub fn get_optimal_msm_params( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, +) -> (u32, usize) { + let is_native = curve_modulus_bits == native_field_bits; + if is_native { + // For native field, limb_bits doesn't matter (no multi-limb decomposition). + // Just optimize window_size. + let mut best_cost = usize::MAX; + let mut best_window = 4; + for ws in 2..=8 { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + native_field_bits, + ); + if cost < best_cost { + best_cost = cost; + best_window = ws; + } + } + return (native_field_bits, best_window); + } + + // Upper bound on search: even with N=2 (best case), we need + // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) + // / 2. The per-candidate soundness check below is the actual gate. + let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; + let mut best_cost = usize::MAX; + let mut best_limb_bits = max_limb_bits.min(86); + let mut best_window = 4; + + // Search space + for lb in (8..=max_limb_bits).step_by(2) { + let num_limbs = ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); + if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { + continue; + } + for ws in 2..=8usize { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + lb, + ); + if cost < best_cost { + best_cost = cost; + best_limb_bits = lb; + best_window = ws; + } + } + } + + (best_limb_bits, best_window) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_optimal_params_bn254_native() { + // Grumpkin over BN254: native field + let (limb_bits, window_size) = get_optimal_msm_params(254, 254, 1, 256); + assert_eq!(limb_bits, 254); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_optimal_params_secp256r1() { + // secp256r1 over BN254: 256-bit modulus, 
non-native + let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256); + let num_limbs = ((256 + limb_bits - 1) / limb_bits) as usize; + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_optimal_params_goldilocks() { + // Hypothetical 64-bit field over BN254 + let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64); + let num_limbs = ((64 + limb_bits - 1) / limb_bits) as usize; + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_column_equation_soundness_boundary() { + // For BN254 (254 bits) with N=3: max safe limb_bits is 124. + // 2*124 + ceil(log2(3)) + 3 = 248 + 2 + 3 = 253 < 254 ✓ + assert!(column_equation_fits_native_field(254, 124, 3)); + // 2*125 + ceil(log2(3)) + 3 = 250 + 2 + 3 = 255 ≥ 254 ✗ + assert!(!column_equation_fits_native_field(254, 125, 3)); + // 2*126 + ceil(log2(3)) + 3 = 252 + 2 + 3 = 257 ≥ 254 ✗ + assert!(!column_equation_fits_native_field(254, 126, 3)); + } + + #[test] + fn test_secp256r1_limb_bits_not_126() { + // Regression: limb_bits=126 with N=3 causes offset_w = 2^255 > p_BN254, + // making the schoolbook column equations unsound. 
+ let (limb_bits, _) = get_optimal_msm_params(254, 256, 1, 256); + assert!( + limb_bits <= 124, + "secp256r1 limb_bits={limb_bits} exceeds safe maximum 124" + ); + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs new file mode 100644 index 00000000..0876bfad --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -0,0 +1,627 @@ +use { + ark_ff::{BigInteger, Field, PrimeField}, + provekit_common::FieldElement, +}; + +pub struct CurveParams { + pub field_modulus_p: [u64; 4], + pub curve_order_n: [u64; 4], + pub curve_a: [u64; 4], + pub curve_b: [u64; 4], + pub generator: ([u64; 4], [u64; 4]), + pub offset_point: ([u64; 4], [u64; 4]), +} + +impl CurveParams { + /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` + /// width each. + pub fn p_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.field_modulus_p, limb_bits, num_limbs) + } + + /// Decompose (p - 1) into `num_limbs` limbs of `limb_bits` width each. + pub fn p_minus_1_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + let p_minus_1 = sub_one_u64_4(&self.field_modulus_p); + decompose_to_limbs(&p_minus_1, limb_bits, num_limbs) + } + + /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` + /// width. + pub fn curve_a_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_a, limb_bits, num_limbs) + } + + /// Number of bits in the field modulus. + pub fn modulus_bits(&self) -> u32 { + if self.is_native_field() { + // p mod p = 0 as a field element, so we use the constant directly. + FieldElement::MODULUS_BIT_SIZE + } else { + let fe = curve_native_point_fe(&self.field_modulus_p); + fe.into_bigint().num_bits() + } + } + + /// Returns true if the curve's base field modulus equals the native BN254 + /// scalar field modulus. 
+ pub fn is_native_field(&self) -> bool { + let native_mod = FieldElement::MODULUS; + self.field_modulus_p == native_mod.0 + } + + /// Convert modulus to a native field element (only valid when p < native + /// modulus). + pub fn p_native_fe(&self) -> FieldElement { + curve_native_point_fe(&self.field_modulus_p) + } + + /// Decompose the curve order n into `num_limbs` limbs of `limb_bits` width + /// each. + pub fn curve_order_n_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_order_n, limb_bits, num_limbs) + } + + /// Decompose (curve_order_n - 1) into `num_limbs` limbs of `limb_bits` + /// width each. + pub fn curve_order_n_minus_1_limbs( + &self, + limb_bits: u32, + num_limbs: usize, + ) -> Vec { + let n_minus_1 = sub_one_u64_4(&self.curve_order_n); + decompose_to_limbs(&n_minus_1, limb_bits, num_limbs) + } + + /// Number of bits in the curve order n. + pub fn curve_order_bits(&self) -> u32 { + // Compute bit length directly from raw limbs to avoid reduction + // mod the native field (curve_order_n may exceed the native modulus). + let n = &self.curve_order_n; + for i in (0..4).rev() { + if n[i] != 0 { + return (i as u32) * 64 + (64 - n[i].leading_zeros()); + } + } + 0 + } + + /// Number of bits for the GLV half-scalar: `ceil(order_bits / 2)`. + /// This determines the bit width of the sub-scalars s1, s2 from half-GCD. + pub fn glv_half_bits(&self) -> u32 { + (self.curve_order_bits() + 1) / 2 + } + + /// Decompose the generator x-coordinate into limbs. + pub fn generator_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.generator.0, limb_bits, num_limbs) + } + + /// Decompose the offset point x-coordinate into limbs. + pub fn offset_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point.0, limb_bits, num_limbs) + } + + /// Decompose the offset point y-coordinate into limbs. 
+ pub fn offset_y_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point.1, limb_bits, num_limbs) + } + + /// Compute `[2^n_doublings] * offset_point` on the curve (compile-time + /// only). + /// + /// Used to compute the accumulated offset after the scalar_mul_glv loop: + /// since the accumulator starts at R and gets doubled n times total, the + /// offset to subtract is `[2^n]*R`, not just `R`. + pub fn accumulated_offset(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + if self.is_native_field() { + self.accumulated_offset_native(n_doublings) + } else { + self.accumulated_offset_generic(n_doublings) + } + } + + /// Compute accumulated offset using FieldElement arithmetic (native field). + fn accumulated_offset_native(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let mut x = curve_native_point_fe(&self.offset_point.0); + let mut y = curve_native_point_fe(&self.offset_point.1); + let a = curve_native_point_fe(&self.curve_a); + + for _ in 0..n_doublings { + let x_sq = x * x; + let num = x_sq + x_sq + x_sq + a; + let denom_inv = (y + y).inverse().unwrap(); + let lambda = num * denom_inv; + let x3 = lambda * lambda - x - x; + let y3 = lambda * (x - x3) - y; + x = x3; + y = y3; + } + + (x.into_bigint().0, y.into_bigint().0) + } + + /// Compute accumulated offset using generic 256-bit arithmetic (non-native + /// field). + fn accumulated_offset_generic(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let p = &self.field_modulus_p; + let mut x = self.offset_point.0; + let mut y = self.offset_point.1; + let a = &self.curve_a; + + for _ in 0..n_doublings { + let (x3, y3) = u256_arith::ec_point_double(&x, &y, a, p); + x = x3; + y = y3; + } + + (x, y) + } +} + +/// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width each, +/// returned as FieldElements. 
+pub fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { + // Special case: when a single limb needs > 128 bits, FieldElement::from(u128) + // would truncate. Use from_sign_and_limbs to preserve the full value. + if num_limbs == 1 && limb_bits > 128 { + return vec![curve_native_point_fe(val)]; + } + + let mask: u128 = if limb_bits >= 128 { + u128::MAX + } else { + (1u128 << limb_bits) - 1 + }; + let mut result = vec![FieldElement::from(0u64); num_limbs]; + let mut remaining = *val; + for item in result.iter_mut() { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + *item = FieldElement::from(lo & mask); + // Shift remaining right by limb_bits + if limb_bits >= 256 { + remaining = [0; 4]; + } else { + let mut shifted = [0u64; 4]; + let word_shift = (limb_bits / 64) as usize; + let bit_shift = limb_bits % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = remaining[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= remaining[i + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } + } + result +} + +/// Subtract 1 from a [u64; 4] value. +fn sub_one_u64_4(val: &[u64; 4]) -> [u64; 4] { + let mut result = *val; + for limb in result.iter_mut() { + if *limb > 0 { + *limb -= 1; + return result; + } + *limb = u64::MAX; // borrow + } + result +} + +/// Converts a 256-bit value ([u64; 4]) into a single native field element. +pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { + FieldElement::from_sign_and_limbs(true, val) +} + +/// Negate a field element: compute `-val mod p` (i.e., `p - val`). +/// Returns `[0; 4]` when `val` is zero. +pub fn negate_field_element(val: &[u64; 4], modulus: &[u64; 4]) -> [u64; 4] { + if *val == [0u64; 4] { + return [0u64; 4]; + } + // val is in [1, p-1], so p - val is in [1, p-1] — no borrow. 
+ let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = modulus[i].overflowing_sub(val[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + debug_assert!(!borrow, "negate_field_element: val >= modulus"); + result +} + +/// Grumpkin curve parameters. +/// +/// Grumpkin is a cycle-companion curve for BN254: its base field is the BN254 +/// scalar field, and its order is the BN254 base field order. +/// +/// Equation: y² = x³ − 17 (a = 0, b = −17 mod p) +pub fn grumpkin_params() -> CurveParams { + CurveParams { + // BN254 scalar field modulus + field_modulus_p: [ + 0x43e1f593f0000001_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // BN254 base field modulus + curve_order_n: [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + curve_a: [0; 4], + // b = −17 mod p + curve_b: [ + 0x43e1f593effffff0_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // Generator G = (1, sqrt(−16) mod p) + generator: ([1, 0, 0, 0], [ + 0x833fc48d823f272c_u64, + 0x2d270d45f1181294_u64, + 0xcf135e7506a45d63_u64, + 0x0000000000000002_u64, + ]), + // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) + offset_point: ( + [ + 0x626578b496650e95_u64, + 0x8678dcf264df6c01_u64, + 0xf0b3eb7e6d02aba8_u64, + 0x223748a4c4edde75_u64, + ], + [ + 0xb75fb4c26bcd4f35_u64, + 0x4d4ba4d97d5f99d9_u64, + 0xccab35fdbf52368a_u64, + 0x25b41c5f56f8472b_u64, + ], + ), + } +} + +/// 256-bit modular arithmetic for compile-time EC point computations. +/// Only used to precompute accumulated offset points; not performance-critical. +mod u256_arith { + type U256 = [u64; 4]; + + /// Returns true if a >= b. 
+ fn gte(a: &U256, b: &U256) -> bool { + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true // equal + } + + /// a + b, returns (result, carry). + fn add(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + carry += a[i] as u128 + b[i] as u128; + result[i] = carry as u64; + carry >>= 64; + } + (result, carry != 0) + } + + /// a - b, returns (result, borrow). + fn sub(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + (result, borrow) + } + + /// (a + b) mod p. + pub fn mod_add(a: &U256, b: &U256, p: &U256) -> U256 { + let (s, overflow) = add(a, b); + if overflow || gte(&s, p) { + sub(&s, p).0 + } else { + s + } + } + + /// (a - b) mod p. + fn mod_sub(a: &U256, b: &U256, p: &U256) -> U256 { + let (d, borrow) = sub(a, b); + if borrow { + add(&d, p).0 + } else { + d + } + } + + /// Schoolbook multiplication producing 512-bit result. + fn mul_wide(a: &U256, b: &U256) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[i] as u128) * (b[j] as u128) + result[i + j] as u128 + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 4] = result[i + 4].wrapping_add(carry as u64); + } + result + } + + /// Reduce a 512-bit value mod a 256-bit prime using bit-by-bit long + /// division. 
+ fn mod_reduce_wide(a: &[u64; 8], p: &U256) -> U256 { + let mut total_bits = 0; + for i in (0..8).rev() { + if a[i] != 0 { + total_bits = i * 64 + (64 - a[i].leading_zeros() as usize); + break; + } + } + if total_bits == 0 { + return [0; 4]; + } + + let mut r = [0u64; 4]; + for bit_idx in (0..total_bits).rev() { + // Left shift r by 1 + let overflow = r[3] >> 63; + for j in (1..4).rev() { + r[j] = (r[j] << 1) | (r[j - 1] >> 63); + } + r[0] <<= 1; + + // Insert current bit of a + let word = bit_idx / 64; + let bit = bit_idx % 64; + r[0] |= (a[word] >> bit) & 1; + + // If r >= p (or overflow from shift), subtract p + if overflow != 0 || gte(&r, p) { + r = sub(&r, p).0; + } + } + r + } + + /// (a * b) mod p. + pub fn mod_mul(a: &U256, b: &U256, p: &U256) -> U256 { + let wide = mul_wide(a, b); + mod_reduce_wide(&wide, p) + } + + /// a^exp mod p using square-and-multiply. + fn mod_pow(base: &U256, exp: &U256, p: &U256) -> U256 { + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [1, 0, 0, 0]; + } + + let mut result: U256 = [1, 0, 0, 0]; + let mut base = *base; + for bit_idx in 0..highest_bit { + let word = bit_idx / 64; + let bit = bit_idx % 64; + if (exp[word] >> bit) & 1 == 1 { + result = mod_mul(&result, &base, p); + } + base = mod_mul(&base, &base, p); + } + result + } + + /// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p. + fn mod_inv(a: &U256, p: &U256) -> U256 { + let two: U256 = [2, 0, 0, 0]; + let exp = sub(p, &two).0; + mod_pow(a, &exp, p) + } + + /// EC point doubling on y^2 = x^3 + ax + b. 
+ pub fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mod_mul(x, x, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let num = mod_add(&three_x_sq, a, p); + let two_y = mod_add(y, y, p); + let denom_inv = mod_inv(&two_y, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mod_mul(&lambda, &lambda, p); + let two_x = mod_add(x, x, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(x, &x3, p); + let lambda_dx = mod_mul(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y, p); + + (x3, y3) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_offset_point_on_curve_grumpkin() { + let c = grumpkin_params(); + let x = curve_native_point_fe(&c.offset_point.0); + let y = curve_native_point_fe(&c.offset_point.1); + let b = curve_native_point_fe(&c.curve_b); + // Grumpkin: y^2 = x^3 + b (a=0) + assert_eq!(y * y, x * x * x + b, "offset point not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_single_double_grumpkin() { + let c = grumpkin_params(); + let (x4, y4) = c.accumulated_offset(1); + let x = curve_native_point_fe(&x4); + let y = curve_native_point_fe(&y4); + let b = curve_native_point_fe(&c.curve_b); + // Should still be on curve + assert_eq!(y * y, x * x * x + b, "[4]G not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_native_vs_generic() { + let c = grumpkin_params(); + // Both paths should give the same result + let native = c.accumulated_offset_native(10); + let generic = c.accumulated_offset_generic(10); + assert_eq!(native, generic, "native vs generic mismatch for n=10"); + } + + #[test] + fn test_accumulated_offset_256_on_curve() { + let c = grumpkin_params(); + let (x, y) = c.accumulated_offset(256); + let xfe = curve_native_point_fe(&x); + let yfe = curve_native_point_fe(&y); + let b = 
curve_native_point_fe(&c.curve_b); + assert_eq!(yfe * yfe, xfe * xfe * xfe + b, "[2^257]G not on Grumpkin"); + } + + #[test] + fn test_offset_point_on_curve_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let x = &c.offset_point.0; + let y = &c.offset_point.1; + let a = &c.curve_a; + let b = &c.curve_b; + // y^2 = x^3 + a*x + b (mod p) + let y_sq = u256_arith::mod_mul(y, y, p); + let x_sq = u256_arith::mod_mul(x, x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, x, p); + let ax = u256_arith::mod_mul(a, x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "offset point not on secp256r1"); + } + + #[test] + fn test_accumulated_offset_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let a = &c.curve_a; + let b = &c.curve_b; + let (x, y) = c.accumulated_offset(256); + // Verify the accumulated offset is on the curve + let y_sq = u256_arith::mod_mul(&y, &y, p); + let x_sq = u256_arith::mod_mul(&x, &x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, &x, p); + let ax = u256_arith::mod_mul(a, &x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "accumulated offset not on secp256r1"); + } + + #[test] + fn test_fe_roundtrip() { + // Verify from_sign_and_limbs / into_bigint roundtrip + let val: [u64; 4] = [42, 0, 0, 0]; + let fe = curve_native_point_fe(&val); + let back = fe.into_bigint().0; + assert_eq!(val, back, "roundtrip failed for small value"); + + let val2: [u64; 4] = [ + 0x6d8bc688cdbffffe, + 0x19a74caa311e13d4, + 0xddeb49cdaa36306d, + 0x06ce1b0827aafa85, + ]; + let fe2 = curve_native_point_fe(&val2); + let back2 = fe2.into_bigint().0; + assert_eq!(val2, back2, "roundtrip failed for offset x"); + } +} + +#[allow(dead_code)] +pub fn secp256r1_params() -> CurveParams { + CurveParams { + field_modulus_p: [ + 0xffffffffffffffff_u64, + 
0xffffffff_u64, + 0x0_u64, + 0xffffffff00000001_u64, + ], + curve_order_n: [ + 0xf3b9cac2fc632551_u64, + 0xbce6faada7179e84_u64, + 0xffffffffffffffff_u64, + 0xffffffff00000000_u64, + ], + curve_a: [ + 0xfffffffffffffffc_u64, + 0x00000000ffffffff_u64, + 0x0000000000000000_u64, + 0xffffffff00000001_u64, + ], + curve_b: [ + 0x3bce3c3e27d2604b_u64, + 0x651d06b0cc53b0f6_u64, + 0xb3ebbd55769886bc_u64, + 0x5ac635d8aa3a93e7_u64, + ], + generator: ( + [ + 0xf4a13945d898c296_u64, + 0x77037d812deb33a0_u64, + 0xf8bce6e563a440f2_u64, + 0x6b17d1f2e12c4247_u64, + ], + [ + 0xcbb6406837bf51f5_u64, + 0x2bce33576b315ece_u64, + 0x8ee7eb4a7c0f9e16_u64, + 0x4fe342e2fe1a7f9b_u64, + ], + ), + // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) + offset_point: ( + [ + 0x57c84fc9d789bd85_u64, + 0xfc35ff7dc297eac3_u64, + 0xfb982fd588c6766e_u64, + 0x447d739beedb5e67_u64, + ], + [ + 0x0c7e33c972e25b32_u64, + 0x3d349b95a7fae500_u64, + 0xe12e9d953a4aaff7_u64, + 0x2d4825ab834131ee_u64, + ], + ), + } +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs new file mode 100644 index 00000000..0138a571 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -0,0 +1,270 @@ +use super::FieldOps; + +/// Generic point doubling on y^2 = x^3 + ax + b. +/// +/// Given P = (x1, y1), computes 2P = (x3, y3): +/// lambda = (3 * x1^2 + a) / (2 * y1) +/// x3 = lambda^2 - 2 * x1 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge case — y1 = 0 (point of order 2): +/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. +/// The result should be the point at infinity (identity element). +/// This function does NOT handle that case — the constraint system will +/// be unsatisfiable if y1 = 0 (the inverse verification will fail to +/// verify 0 * inv = 1 mod p). The caller must check y1 = 0 using +/// compute_is_zero and conditionally select the point-at-infinity +/// result before calling this function. 
+pub fn point_double(ops: &mut F, x1: F::Elem, y1: F::Elem) -> (F::Elem, F::Elem) { + let a = ops.curve_a(); + + // Computing numerator = 3 * x1^2 + a + let x1_sq = ops.mul(x1, x1); + let two_x1_sq = ops.add(x1_sq, x1_sq); + let three_x1_sq = ops.add(two_x1_sq, x1_sq); + let numerator = ops.add(three_x1_sq, a); + + // Computing denominator = 2 * y1 + let denominator = ops.add(y1, y1); + + // Computing lambda = numerator * denominator^(-1) + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - 2 * x1 + let lambda_sq = ops.mul(lambda, lambda); + let two_x1 = ops.add(x1, x1); + let x3 = ops.sub(lambda_sq, two_x1); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Generic point addition on y^2 = x^3 + ax + b. +/// +/// Given P1 = (x1, y1) and P2 = (x2, y2), computes P1 + P2 = (x3, y3): +/// lambda = (y2 - y1) / (x2 - x1) +/// x3 = lambda^2 - x1 - x2 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge cases — x1 = x2: +/// When x1 = x2, the denominator (x2 - x1) = 0 and the inverse does +/// not exist. This covers two cases: +/// - P1 = P2 (same point): use `point_double` instead. +/// - P1 = -P2 (y1 = -y2): the result is the point at infinity. +/// This function does NOT handle either case — the constraint system +/// will be unsatisfiable if x1 = x2. The caller must detect this +/// and branch accordingly. 
+pub fn point_add( + ops: &mut F, + x1: F::Elem, + y1: F::Elem, + x2: F::Elem, + y2: F::Elem, +) -> (F::Elem, F::Elem) { + // Computing lambda = (y2 - y1) / (x2 - x1) + let numerator = ops.sub(y2, y1); + let denominator = ops.sub(x2, x1); + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - x1 - x2 + let lambda_sq = ops.mul(lambda, lambda); + let x1_plus_x2 = ops.add(x1, x2); + let x3 = ops.sub(lambda_sq, x1_plus_x2); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Conditional point select: returns `on_true` if `flag` is 1, `on_false` if +/// `flag` is 0. +/// +/// Constrains `flag` to be boolean (`flag * flag = flag`). +pub fn point_select( + ops: &mut F, + flag: usize, + on_false: (F::Elem, F::Elem), + on_true: (F::Elem, F::Elem), +) -> (F::Elem, F::Elem) { + let x = ops.select(flag, on_false.0, on_true.0); + let y = ops.select(flag, on_false.1, on_true.1); + (x, y) +} + +/// Conditional point select without boolean constraint on `flag`. +/// Caller must ensure `flag` is already constrained boolean. +fn point_select_unchecked( + ops: &mut F, + flag: usize, + on_false: (F::Elem, F::Elem), + on_true: (F::Elem, F::Elem), +) -> (F::Elem, F::Elem) { + let x = ops.select_unchecked(flag, on_false.0, on_true.0); + let y = ops.select_unchecked(flag, on_false.1, on_true.1); + (x, y) +} + +/// Builds a point table for windowed scalar multiplication. +/// +/// T\[0\] = P (dummy entry, used when window digit = 0) +/// T\[1\] = P, T\[2\] = 2P, T\[i\] = T\[i-1\] + P for i >= 3. 
+fn build_point_table( + ops: &mut F, + px: F::Elem, + py: F::Elem, + table_size: usize, +) -> Vec<(F::Elem, F::Elem)> { + assert!(table_size >= 2); + let mut table = Vec::with_capacity(table_size); + table.push((px, py)); // T[0] = P (dummy) + table.push((px, py)); // T[1] = P + if table_size > 2 { + table.push(point_double(ops, px, py)); // T[2] = 2P + for i in 3..table_size { + let prev = table[i - 1]; + table.push(point_add(ops, prev.0, prev.1, px, py)); + } + } + table +} + +/// Selects T\[d\] from a point table using bit witnesses, where `d = Σ +/// bits\[i\] * 2^i`. +/// +/// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, +/// halving the candidate set at each level. Total: `(2^w - 1)` point selects +/// for a table of `2^w` entries. +/// +/// Each bit is constrained boolean exactly once, then all subsequent selects +/// on that bit use the unchecked variant. +fn table_lookup( + ops: &mut F, + table: &[(F::Elem, F::Elem)], + bits: &[usize], +) -> (F::Elem, F::Elem) { + assert_eq!(table.len(), 1 << bits.len()); + let mut current: Vec<(F::Elem, F::Elem)> = table.to_vec(); + // Process bits from MSB to LSB + for &bit in bits.iter().rev() { + ops.constrain_flag(bit); // constrain boolean once per bit + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(point_select_unchecked( + ops, + bit, + current[i], + current[i + half], + )); + } + current = next; + } + current[0] +} + +/// Interleaved two-point scalar multiplication for FakeGLV. +/// +/// Computes `[s1]P + [s2]R` using shared doublings, where s1 and s2 are +/// half-width scalars (typically ~128-bit for 256-bit curves). The +/// accumulator starts at an offset point and the caller checks equality +/// with the accumulated offset to verify the constraint `[s1]P + [s2]R = O`. +/// +/// Structure per window (from MSB to LSB): +/// 1. `w` shared doublings on accumulator +/// 2. Table lookup in T_P\[d1\] for s1's window digit +/// 3. 
point_add(acc, T_P\[d1\]) + is_zero(d1) + point_select +/// 4. Table lookup in T_R\[d2\] for s2's window digit +/// 5. point_add(acc, T_R\[d2\]) + is_zero(d2) + point_select +/// +/// Returns the final accumulator (x, y). +pub fn scalar_mul_glv( + ops: &mut F, + // Point P (table 1) + px: F::Elem, + py: F::Elem, + s1_bits: &[usize], // 128 bit witnesses for |s1| + // Point R (table 2) — the claimed output + rx: F::Elem, + ry: F::Elem, + s2_bits: &[usize], // 128 bit witnesses for |s2| + // Shared parameters + window_size: usize, + offset_x: F::Elem, + offset_y: F::Elem, +) -> (F::Elem, F::Elem) { + let n1 = s1_bits.len(); + let n2 = s2_bits.len(); + assert_eq!(n1, n2, "s1 and s2 must have the same number of bits"); + let n = n1; + let w = window_size; + let table_size = 1 << w; + + // TODO : implement lazy overflow as used in gnark. + + // Build point tables: T_P[i] = [i]P, T_R[i] = [i]R + let table_p = build_point_table(ops, px, py, table_size); + let table_r = build_point_table(ops, rx, ry, table_size); + + let num_windows = (n + w - 1) / w; + + // Initialize accumulator with the offset point + let mut acc = (offset_x, offset_y); + + // Process all windows from MSB down to LSB + for i in (0..num_windows).rev() { + let bit_start = i * w; + let bit_end = std::cmp::min(bit_start + w, n); + let actual_w = bit_end - bit_start; + + // w shared doublings on the accumulator + let mut doubled_acc = acc; + for _ in 0..w { + doubled_acc = point_double(ops, doubled_acc.0, doubled_acc.1); + } + + // --- Process P's window digit (s1) --- + let s1_window_bits = &s1_bits[bit_start..bit_end]; + let lookup_table_p = if actual_w < w { + &table_p[..1 << actual_w] + } else { + &table_p[..] 
+ }; + let looked_up_p = table_lookup(ops, lookup_table_p, s1_window_bits); + let added_p = point_add( + ops, + doubled_acc.0, + doubled_acc.1, + looked_up_p.0, + looked_up_p.1, + ); + let digit_p = ops.pack_bits(s1_window_bits); + let digit_p_is_zero = ops.is_zero(digit_p); + // is_zero already constrains its output boolean; skip redundant check + let after_p = point_select_unchecked(ops, digit_p_is_zero, added_p, doubled_acc); + + // --- Process R's window digit (s2) --- + let s2_window_bits = &s2_bits[bit_start..bit_end]; + let lookup_table_r = if actual_w < w { + &table_r[..1 << actual_w] + } else { + &table_r[..] + }; + let looked_up_r = table_lookup(ops, lookup_table_r, s2_window_bits); + let added_r = point_add(ops, after_p.0, after_p.1, looked_up_r.0, looked_up_r.1); + let digit_r = ops.pack_bits(s2_window_bits); + let digit_r_is_zero = ops.is_zero(digit_r); + // is_zero already constrains its output boolean; skip redundant check + acc = point_select_unchecked(ops, digit_r_is_zero, added_r, after_p); + } + + acc +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs new file mode 100644 index 00000000..0c24e4f7 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -0,0 +1,1110 @@ +pub mod cost_model; +pub mod curve; +pub mod ec_points; +pub mod multi_limb_arith; +pub mod multi_limb_ops; + +use { + crate::{ + digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, + msm::multi_limb_arith::compute_is_zero, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field, PrimeField}, + curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, + multi_limb_ops::{MultiLimbOps, MultiLimbParams}, + provekit_common::{ + witness::{ConstantOrR1CSWitness, ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +// --------------------------------------------------------------------------- +// Limbs: fixed-capacity, Copy array of witness indices 
+// --------------------------------------------------------------------------- + +/// Maximum number of limbs supported. Covers all practical field sizes +/// (e.g. a 512-bit modulus with 16-bit limbs = 32 limbs). +pub const MAX_LIMBS: usize = 32; + +/// A fixed-capacity array of witness indices, indexed by limb position. +/// +/// This type is `Copy`, so it can be used as `FieldOps::Elem` without +/// requiring const generics or dispatch macros. The runtime `len` field +/// tracks how many limbs are actually in use. +#[derive(Clone, Copy)] +pub struct Limbs { + data: [usize; MAX_LIMBS], + len: usize, +} + +impl Limbs { + /// Sentinel value for uninitialized limb slots. Using `usize::MAX` + /// ensures accidental use of an unfilled slot indexes an absurdly + /// large witness, causing an immediate out-of-bounds panic. + const UNINIT: usize = usize::MAX; + + /// Create a new `Limbs` with `len` limbs, all initialized to `UNINIT`. + pub fn new(len: usize) -> Self { + assert!( + len > 0 && len <= MAX_LIMBS, + "limb count must be 1..={MAX_LIMBS}, got {len}" + ); + Self { + data: [Self::UNINIT; MAX_LIMBS], + len, + } + } + + /// Create a single-limb `Limbs` wrapping one witness index. + pub fn single(value: usize) -> Self { + let mut l = Self { + data: [Self::UNINIT; MAX_LIMBS], + len: 1, + }; + l.data[0] = value; + l + } + + /// View the active limbs as a slice. + pub fn as_slice(&self) -> &[usize] { + &self.data[..self.len] + } + + /// Number of active limbs. 
+ #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.len + } +} + +impl std::fmt::Debug for Limbs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list().entries(self.as_slice().iter()).finish() + } +} + +impl PartialEq for Limbs { + fn eq(&self, other: &Self) -> bool { + self.len == other.len && self.data[..self.len] == other.data[..other.len] + } +} +impl Eq for Limbs {} + +impl std::ops::Index for Limbs { + type Output = usize; + fn index(&self, i: usize) -> &usize { + debug_assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &self.data[i] + } +} + +impl std::ops::IndexMut for Limbs { + fn index_mut(&mut self, i: usize) -> &mut usize { + debug_assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &mut self.data[i] + } +} + +// --------------------------------------------------------------------------- +// FieldOps trait +// --------------------------------------------------------------------------- + +pub trait FieldOps { + type Elem: Copy; + + fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn inv(&mut self, a: Self::Elem) -> Self::Elem; + fn curve_a(&mut self) -> Self::Elem; + + /// Constrains `flag` to be boolean (`flag * flag = flag`). + fn constrain_flag(&mut self, flag: usize); + + /// Conditional select without boolean constraint on `flag`. + /// Caller must ensure `flag` is already constrained boolean. + fn select_unchecked( + &mut self, + flag: usize, + on_false: Self::Elem, + on_true: Self::Elem, + ) -> Self::Elem; + + /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if + /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). 
+ fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem { + self.constrain_flag(flag); + self.select_unchecked(flag, on_false, on_true) + } + + /// Checks if a native witness value is zero. + /// Returns a boolean witness: 1 if zero, 0 if non-zero. + fn is_zero(&mut self, value: usize) -> usize; + + /// Packs bit witnesses into a single digit witness: `d = Σ bits[i] * 2^i`. + /// Does NOT constrain bits to be boolean — caller must ensure that. + fn pack_bits(&mut self, bits: &[usize]) -> usize; + + /// Returns a constant field element from its limb decomposition. + fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Self::Elem; +} + +// --------------------------------------------------------------------------- +// Private helpers +// --------------------------------------------------------------------------- + +/// Constrains `flag` to be boolean: `flag * flag = flag`. +pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + ); +} + +/// Single-witness conditional select: `out = on_false + flag * (on_true - +/// on_false)`. +pub(crate) fn select_witness( + compiler: &mut NoirToR1CSCompiler, + flag: usize, + on_false: usize, + on_true: usize, +) -> usize { + // When both branches are the same witness, result is trivially that witness. + // Avoids duplicate column indices in R1CS from `on_true - on_false` when + // both share the same witness index. + if on_false == on_true { + return on_false; + } + let diff = compiler.add_sum(vec![ + SumTerm(None, on_true), + SumTerm(Some(-FieldElement::ONE), on_false), + ]); + let flag_diff = compiler.add_product(flag, diff); + compiler.add_sum(vec![SumTerm(None, on_false), SumTerm(None, flag_diff)]) +} + +/// Packs bit witnesses into a digit: `d = Σ bits[i] * 2^i`. 
+pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { + let terms: Vec = bits + .iter() + .enumerate() + .map(|(i, &bit)| SumTerm(Some(FieldElement::from(1u128 << i)), bit)) + .collect(); + compiler.add_sum(terms) +} + +/// Computes `a OR b` for two boolean witnesses: `1 - (1 - a)(1 - b)`. +/// Does NOT constrain a or b to be boolean — caller must ensure that. +fn compute_boolean_or(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) -> usize { + let one = compiler.witness_one(); + let one_minus_a = compiler.add_sum(vec![ + SumTerm(None, one), + SumTerm(Some(-FieldElement::ONE), a), + ]); + let one_minus_b = compiler.add_sum(vec![ + SumTerm(None, one), + SumTerm(Some(-FieldElement::ONE), b), + ]); + let product = compiler.add_product(one_minus_a, one_minus_b); + compiler.add_sum(vec![ + SumTerm(None, one), + SumTerm(Some(-FieldElement::ONE), product), + ]) +} + +/// Detects whether a point-scalar pair is degenerate (scalar=0 or point at +/// infinity). Constrains `inf_flag` to boolean. Returns `is_skip` (1 if +/// degenerate). +fn detect_skip( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, + inf_flag: usize, +) -> usize { + constrain_boolean(compiler, inf_flag); + let is_zero_s_lo = compute_is_zero(compiler, s_lo); + let is_zero_s_hi = compute_is_zero(compiler, s_hi); + let s_is_zero = compiler.add_product(is_zero_s_lo, is_zero_s_hi); + compute_boolean_or(compiler, s_is_zero, inf_flag) +} + +/// Constrains `a * b = 0`. 
+fn constrain_product_zero(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( + FieldElement::ZERO, + compiler.witness_one(), + )]); +} + +// --------------------------------------------------------------------------- +// Params builder (runtime num_limbs, no const generics) +// --------------------------------------------------------------------------- + +/// Build `MultiLimbParams` for a given runtime `num_limbs`. +fn build_params(num_limbs: usize, limb_bits: u32, curve: &CurveParams) -> MultiLimbParams { + let is_native = curve.is_native_field(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if !is_native { + Some(curve.p_native_fe()) + } else { + None + }; + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: curve.p_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.field_modulus_p, + curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + is_native, + modulus_fe, + } +} + +// --------------------------------------------------------------------------- +// MSM entry point +// --------------------------------------------------------------------------- + +/// Processes all deferred MSM operations. +/// +/// Internally selects the optimal (limb_bits, window_size) via cost model +/// and uses Grumpkin curve parameters. +/// +/// Each entry is `(points, scalars, (out_x, out_y, out_inf))` where: +/// - `points` has layout `[x1, y1, inf1, x2, y2, inf2, ...]` (3 per point) +/// - `scalars` has layout `[s1_lo, s1_hi, s2_lo, s2_hi, ...]` (2 per scalar) +/// - outputs are the R1CS witness indices for the result point +/// Grumpkin-specific MSM entry point (used by the Noir `MultiScalarMul` black +/// box). 
+pub fn add_msm( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + (usize, usize, usize), + )>, + range_checks: &mut BTreeMap>, +) { + let curve = curve::grumpkin_params(); + add_msm_with_curve(compiler, msm_ops, range_checks, &curve); +} + +/// Curve-agnostic MSM: compiles MSM operations for any curve described by +/// `curve`. +pub fn add_msm_with_curve( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + (usize, usize, usize), + )>, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + if msm_ops.is_empty() { + return; + } + + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); + let (limb_bits, window_size) = + cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256); + + for (points, scalars, outputs) in msm_ops { + add_single_msm( + compiler, + &points, + &scalars, + outputs, + limb_bits, + window_size, + range_checks, + curve, + ); + } +} + +/// Processes a single MSM operation. 
+fn add_single_msm( + compiler: &mut NoirToR1CSCompiler, + points: &[ConstantOrR1CSWitness], + scalars: &[ConstantOrR1CSWitness], + outputs: (usize, usize, usize), + limb_bits: u32, + window_size: usize, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + assert!( + points.len() % 3 == 0, + "points length must be a multiple of 3" + ); + let n = points.len() / 3; + assert_eq!( + scalars.len(), + 2 * n, + "scalars length must be 2x the number of points" + ); + + // Resolve all inputs to witness indices + let point_wits: Vec = points.iter().map(|p| resolve_input(compiler, p)).collect(); + let scalar_wits: Vec = scalars.iter().map(|s| resolve_input(compiler, s)).collect(); + + let is_native = curve.is_native_field(); + let num_limbs = if is_native { + 1 + } else { + (curve.modulus_bits() as usize + limb_bits as usize - 1) / limb_bits as usize + }; + + process_single_msm( + compiler, + &point_wits, + &scalar_wits, + outputs, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); +} + +/// Process a full single-MSM with runtime `num_limbs`. +/// +/// Uses FakeGLV for ALL points: each point P_i with scalar s_i is verified +/// using scalar decomposition and half-width interleaved scalar mul. +/// +/// For `n_points == 1`, R = (out_x, out_y) is the ACIR output. +/// For `n_points > 1`, R_i = EcScalarMulHint witnesses, accumulated via +/// point_add and constrained against the ACIR output. 
+fn process_single_msm<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: (usize, usize, usize), + num_limbs: usize, + limb_bits: u32, + window_size: usize, + mut range_checks: &'a mut BTreeMap>, + curve: &CurveParams, +) { + let n_points = point_wits.len() / 3; + let (out_x, out_y, out_inf) = outputs; + + if n_points == 1 { + // Single-point: R is the ACIR output directly + let px_witness = point_wits[0]; + let py_witness = point_wits[1]; + let inf_flag = point_wits[2]; + let s_lo = scalar_wits[0]; + let s_hi = scalar_wits[1]; + + // --- Detect degenerate case: is_skip = (scalar == 0) OR (point is infinity) + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + + // --- Sanitize inputs: swap in generator G and scalar=1 when is_skip --- + let one = compiler.witness_one(); + let gen_x_witness = + add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.0)); + let gen_y_witness = + add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.1)); + + let sanitized_px = select_witness(compiler, is_skip, px_witness, gen_x_witness); + let sanitized_py = select_witness(compiler, is_skip, py_witness, gen_y_witness); + + // When is_skip=1, use scalar=(1, 0) so FakeGLV computes [1]*G = G + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + let sanitized_s_lo = select_witness(compiler, is_skip, s_lo, one); + let sanitized_s_hi = select_witness(compiler, is_skip, s_hi, zero_witness); + + // Sanitize R (output point): when is_skip=1, R must be G (since [1]*G = G) + let sanitized_rx = select_witness(compiler, is_skip, out_x, gen_x_witness); + let sanitized_ry = select_witness(compiler, is_skip, out_y, gen_y_witness); + + // Decompose sanitized P into limbs + let (px, py) = decompose_point_to_limbs( + compiler, + sanitized_px, + sanitized_py, + num_limbs, + limb_bits, + range_checks, + ); + // Decompose sanitized R into limbs + let (rx, ry) = 
decompose_point_to_limbs( + compiler, + sanitized_rx, + sanitized_ry, + num_limbs, + limb_bits, + range_checks, + ); + + // Run FakeGLV on sanitized values (always satisfiable) + (compiler, range_checks) = verify_point_fakeglv( + compiler, + range_checks, + px, + py, + rx, + ry, + sanitized_s_lo, + sanitized_s_hi, + num_limbs, + limb_bits, + window_size, + curve, + ); + + // --- Mask output: when is_skip, output must be (0, 0, 1) --- + constrain_equal(compiler, out_inf, is_skip); + constrain_product_zero(compiler, is_skip, out_x); + constrain_product_zero(compiler, is_skip, out_y); + } else { + // Multi-point: compute R_i = [s_i]P_i via hints, verify each with FakeGLV, + // then accumulate R_i's with offset-based accumulation and skip handling. + let one = compiler.witness_one(); + + // Generator constants for sanitization + let gen_x_fe = curve::curve_native_point_fe(&curve.generator.0); + let gen_y_fe = curve::curve_native_point_fe(&curve.generator.1); + let gen_x_witness = add_constant_witness(compiler, gen_x_fe); + let gen_y_witness = add_constant_witness(compiler, gen_y_fe); + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + + // Build params once for all multi-limb ops in the multi-point path + let params = build_params(num_limbs, limb_bits, curve); + + // Offset point as limbs for accumulation + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + + // Start accumulator at offset_point + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + let mut acc_x = ops.constant_limbs(&offset_x_values); + let mut acc_y = ops.constant_limbs(&offset_y_values); + compiler = ops.compiler; + range_checks = ops.range_checks; + + // Track all_skipped = product of all is_skip flags + let mut all_skipped: Option = None; + + for i in 0..n_points { + let px_witness = point_wits[3 * i]; + let py_witness = point_wits[3 * i + 1]; + let inf_flag = 
point_wits[3 * i + 2]; + let s_lo = scalar_wits[2 * i]; + let s_hi = scalar_wits[2 * i + 1]; + + // --- Detect degenerate case --- + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + + // Track all_skipped + all_skipped = Some(match all_skipped { + None => is_skip, + Some(prev) => compiler.add_product(prev, is_skip), + }); + + // --- Sanitize inputs --- + let sanitized_px = select_witness(compiler, is_skip, px_witness, gen_x_witness); + let sanitized_py = select_witness(compiler, is_skip, py_witness, gen_y_witness); + let sanitized_s_lo = select_witness(compiler, is_skip, s_lo, one); + let sanitized_s_hi = select_witness(compiler, is_skip, s_hi, zero_witness); + + // EcScalarMulHint uses sanitized inputs + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { + output_start: hint_start, + px: sanitized_px, + py: sanitized_py, + s_lo: sanitized_s_lo, + s_hi: sanitized_s_hi, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let rx_witness = hint_start; + let ry_witness = hint_start + 1; + + // When is_skip=1, R should be G (since [1]*G = G) + let sanitized_rx = select_witness(compiler, is_skip, rx_witness, gen_x_witness); + let sanitized_ry = select_witness(compiler, is_skip, ry_witness, gen_y_witness); + + // Decompose sanitized P_i into limbs + let (px, py) = decompose_point_to_limbs( + compiler, + sanitized_px, + sanitized_py, + num_limbs, + limb_bits, + range_checks, + ); + // Decompose sanitized R_i into limbs + let (rx, ry) = decompose_point_to_limbs( + compiler, + sanitized_rx, + sanitized_ry, + num_limbs, + limb_bits, + range_checks, + ); + + // Verify R_i = [s_i]P_i using FakeGLV (on sanitized values) + (compiler, range_checks) = verify_point_fakeglv( + compiler, + range_checks, + px, + py, + rx, + ry, + sanitized_s_lo, + sanitized_s_hi, + num_limbs, + limb_bits, + window_size, + curve, + ); + + // --- Offset-based accumulation with conditional select --- + // Compute 
candidate = point_add(acc, R_i) + // Then select: if is_skip, keep acc unchanged; else use candidate + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, acc_y, rx, ry); + let (new_acc_x, new_acc_y) = + ec_points::point_select(&mut ops, is_skip, (cand_x, cand_y), (acc_x, acc_y)); + acc_x = new_acc_x; + acc_y = new_acc_y; + compiler = ops.compiler; + range_checks = ops.range_checks; + } + + let all_skipped = all_skipped.expect("MSM must have at least one point"); + + // Subtract offset: result = point_add(acc, -offset) + // Negated offset = (offset_x, -offset_y) + let neg_offset_y_raw = + curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p); + let neg_offset_y_values = + curve::decompose_to_limbs(&neg_offset_y_raw, limb_bits, num_limbs); + + // When all_skipped, acc == offset_point, so subtracting offset would be + // point_add(O, -O) which fails (x1 == x2). Use generator G as the + // subtraction target instead; the result won't matter since we'll mask it. 
+ let gen_x_limb_values = curve.generator_x_limbs(limb_bits, num_limbs); + let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p); + let neg_gen_y_values = curve::decompose_to_limbs(&neg_gen_y_raw, limb_bits, num_limbs); + + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + + // Select subtraction point: if all_skipped, use -G; else use -offset + let sub_x = { + let off_x = ops.constant_limbs(&offset_x_values); + let g_x = ops.constant_limbs(&gen_x_limb_values); + ops.select(all_skipped, off_x, g_x) + }; + let sub_y = { + let neg_off_y = ops.constant_limbs(&neg_offset_y_values); + let neg_g_y = ops.constant_limbs(&neg_gen_y_values); + ops.select(all_skipped, neg_off_y, neg_g_y) + }; + + let (result_x, result_y) = ec_points::point_add(&mut ops, acc_x, acc_y, sub_x, sub_y); + compiler = ops.compiler; + range_checks = ops.range_checks; + + // --- Constrain output --- + // When all_skipped: output is (0, 0, 1) + // Otherwise: output matches the computed result with inf=0 + if num_limbs == 1 { + // Mask result with all_skipped: when all_skipped=1, out must be 0 + let masked_result_x = select_witness(compiler, all_skipped, result_x[0], zero_witness); + let masked_result_y = select_witness(compiler, all_skipped, result_y[0], zero_witness); + constrain_equal(compiler, out_x, masked_result_x); + constrain_equal(compiler, out_y, masked_result_y); + } else { + let recomposed_x = recompose_limbs(compiler, result_x.as_slice(), limb_bits); + let recomposed_y = recompose_limbs(compiler, result_y.as_slice(), limb_bits); + let masked_result_x = select_witness(compiler, all_skipped, recomposed_x, zero_witness); + let masked_result_y = select_witness(compiler, all_skipped, recomposed_y, zero_witness); + constrain_equal(compiler, out_x, masked_result_x); + constrain_equal(compiler, out_y, masked_result_y); + } + constrain_equal(compiler, out_inf, all_skipped); + } +} + +/// Decompose a point (px_witness, py_witness) 
into Limbs. +fn decompose_point_to_limbs( + compiler: &mut NoirToR1CSCompiler, + px_witness: usize, + py_witness: usize, + num_limbs: usize, + limb_bits: u32, + range_checks: &mut BTreeMap>, +) -> (Limbs, Limbs) { + if num_limbs == 1 { + (Limbs::single(px_witness), Limbs::single(py_witness)) + } else { + let px_limbs = + decompose_witness_to_limbs(compiler, px_witness, limb_bits, num_limbs, range_checks); + let py_limbs = + decompose_witness_to_limbs(compiler, py_witness, limb_bits, num_limbs, range_checks); + (px_limbs, py_limbs) + } +} + +/// FakeGLV verification for a single point: verifies R = \[s\]P. +/// +/// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies +/// \[s1\]P + \[s2\]R = O using interleaved windowed scalar mul with +/// half-width scalars. +/// +/// Returns the mutable references back to the caller for continued use. +fn verify_point_fakeglv<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + mut range_checks: &'a mut BTreeMap>, + px: Limbs, + py: Limbs, + rx: Limbs, + ry: Limbs, + s_lo: usize, + s_hi: usize, + num_limbs: usize, + limb_bits: u32, + window_size: usize, + curve: &CurveParams, +) -> ( + &'a mut NoirToR1CSCompiler, + &'a mut BTreeMap>, +) { + // --- Steps 1-4: On-curve checks, FakeGLV decomposition, and GLV scalar mul + // --- + let s1_witness; + let s2_witness; + let neg1_witness; + let neg2_witness; + { + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + + // Step 1: On-curve checks for P and R + let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); + verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); + verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); + + // Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 + let glv_start = ops.compiler.num_witnesses(); + ops.compiler + .add_witness_builder(WitnessBuilder::FakeGLVHint { + output_start: glv_start, + s_lo, + s_hi, + curve_order: 
curve.curve_order_n, + }); + s1_witness = glv_start; + s2_witness = glv_start + 1; + neg1_witness = glv_start + 2; + neg2_witness = glv_start + 3; + + // Step 3: Decompose |s1|, |s2| into half_bits bits each + let half_bits = curve.glv_half_bits() as usize; + let s1_bits = decompose_half_scalar_bits(ops.compiler, s1_witness, half_bits); + let s2_bits = decompose_half_scalar_bits(ops.compiler, s2_witness, half_bits); + + // Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity + // check + + // Compute negated y-coordinates: neg_y = 0 - y (mod p) + let neg_py = ops.negate(py); + let neg_ry = ops.negate(ry); + + // Select: if neg1=1, use neg_py; else use py + // neg1 and neg2 are constrained to be boolean by ops.select internally. + let py_effective = ops.select(neg1_witness, py, neg_py); + // Select: if neg2=1, use neg_ry; else use ry + let ry_effective = ops.select(neg2_witness, ry, neg_ry); + + // GLV scalar mul + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + let offset_x = ops.constant_limbs(&offset_x_values); + let offset_y = ops.constant_limbs(&offset_y_values); + + let glv_acc = ec_points::scalar_mul_glv( + &mut ops, + px, + py_effective, + &s1_bits, + rx, + ry_effective, + &s2_bits, + window_size, + offset_x, + offset_y, + ); + + // Identity check: acc should equal [2^(num_windows * window_size)] * + // offset_point + let glv_num_windows = (half_bits + window_size - 1) / window_size; + let glv_n_doublings = glv_num_windows * window_size; + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); + + let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); + let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); + let expected_x = ops.constant_limbs(&acc_off_x_values); + let expected_y = ops.constant_limbs(&acc_off_y_values); + + for i in 0..num_limbs { + 
constrain_equal(ops.compiler, glv_acc.0[i], expected_x[i]); + constrain_equal(ops.compiler, glv_acc.1[i], expected_y[i]); + } + + compiler = ops.compiler; + range_checks = ops.range_checks; + } + + // --- Step 5: Scalar relation verification --- + verify_scalar_relation( + compiler, + range_checks, + s_lo, + s_hi, + s1_witness, + s2_witness, + neg1_witness, + neg2_witness, + curve, + ); + + (compiler, range_checks) +} + +/// On-curve check: verifies y^2 = x^3 + a*x + b for a single point. +fn verify_on_curve( + ops: &mut MultiLimbOps, + x: Limbs, + y: Limbs, + b_limb_values: &[FieldElement], + num_limbs: usize, +) { + let y_sq = ops.mul(y, y); + let x_sq = ops.mul(x, x); + let x_cubed = ops.mul(x_sq, x); + let a = ops.curve_a(); + let ax = ops.mul(a, x); + let x3_plus_ax = ops.add(x_cubed, ax); + let b = ops.constant_limbs(b_limb_values); + let rhs = ops.add(x3_plus_ax, b); + for i in 0..num_limbs { + constrain_equal(ops.compiler, y_sq[i], rhs[i]); + } +} + +/// Decompose a single witness into `num_limbs` limbs using digital +/// decomposition. +fn decompose_witness_to_limbs( + compiler: &mut NoirToR1CSCompiler, + witness: usize, + limb_bits: u32, + num_limbs: usize, + range_checks: &mut BTreeMap>, +) -> Limbs { + let log_bases = vec![limb_bits as usize; num_limbs]; + let dd = add_digital_decomposition(compiler, log_bases, vec![witness]); + let mut limbs = Limbs::new(num_limbs); + for i in 0..num_limbs { + limbs[i] = dd.get_digit_witness_index(i, 0); + // Range-check each decomposed limb to [0, 2^limb_bits). + // add_digital_decomposition constrains the recomposition but does + // NOT range-check individual digits. 
+ range_checks.entry(limb_bits).or_default().push(limbs[i]); + } + limbs +} + +/// Recompose limbs back into a single witness: val = Σ limb\[i\] * +/// 2^(i*limb_bits) +fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits: u32) -> usize { + let terms: Vec = limbs + .iter() + .enumerate() + .map(|(i, &limb)| { + let coeff = FieldElement::from(2u64).pow([(i as u64) * (limb_bits as u64)]); + SumTerm(Some(coeff), limb) + }) + .collect(); + compiler.add_sum(terms) +} + +/// Resolves a `ConstantOrR1CSWitness` to a witness index. +fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitness) -> usize { + match input { + ConstantOrR1CSWitness::Witness(idx) => *idx, + ConstantOrR1CSWitness::Constant(value) => { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, *value))); + w + } + } +} + +/// Decomposes a half-scalar witness into `half_bits` bit witnesses (LSB first). +fn decompose_half_scalar_bits( + compiler: &mut NoirToR1CSCompiler, + scalar: usize, + half_bits: usize, +) -> Vec { + let log_bases = vec![1usize; half_bits]; + let dd = add_digital_decomposition(compiler, log_bases, vec![scalar]); + let mut bits = Vec::with_capacity(half_bits); + for bit_idx in 0..half_bits { + bits.push(dd.get_digit_witness_index(bit_idx, 0)); + } + bits +} + +/// Builds `MultiLimbParams` for scalar relation verification (mod +/// curve_order_n). +fn build_scalar_relation_params( + num_limbs: usize, + limb_bits: u32, + curve: &CurveParams, +) -> MultiLimbParams { + // Scalar relation uses curve_order_n as the modulus. + // This is always non-native (curve_order_n ≠ BN254 scalar field modulus, + // except for Grumpkin where they're very close but still different). 
+ let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let n_limbs = curve.curve_order_n_limbs(limb_bits, num_limbs); + let n_minus_1_limbs = curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs); + + // For N=1 non-native, we need the modulus as a FieldElement + let modulus_fe = if num_limbs == 1 { + Some(curve::curve_native_point_fe(&curve.curve_order_n)) + } else { + None + }; + + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: n_limbs, + p_minus_1_limbs: n_minus_1_limbs, + two_pow_w, + modulus_raw: curve.curve_order_n, + curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + is_native: false, // always non-native + modulus_fe, + } +} + +/// Picks the largest limb size for the scalar-relation multi-limb arithmetic +/// that fits inside the native field without overflow. +/// +/// The schoolbook multiplication column equations require: +/// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` +/// +/// We start at 64 bits (the ideal case — inputs are 128-bit half-scalars) and +/// search downward until the soundness check passes. For BN254 (254-bit native +/// field) this resolves to 64; smaller fields like M31 (31 bits) will get a +/// proportionally smaller limb size. +/// +/// Panics if the native field is too small (< ~12 bits) to support any valid +/// limb decomposition. +fn scalar_relation_limb_bits(order_bits: usize) -> u32 { + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let mut limb_bits: u32 = 64.min((native_bits.saturating_sub(4)) / 2); + loop { + let num_limbs = (order_bits + limb_bits as usize - 1) / limb_bits as usize; + if cost_model::column_equation_fits_native_field(native_bits, limb_bits, num_limbs) { + break; + } + limb_bits -= 1; + assert!( + limb_bits >= 4, + "native field too small for scalar relation verification" + ); + } + limb_bits +} + +/// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 +/// (mod n). 
+/// +/// Uses multi-limb arithmetic with curve_order_n as the modulus. +/// The sub-scalars s1, s2 have `half_bits = ceil(order_bits/2)` bits; +/// the full scalar s has up to `order_bits` bits. +fn verify_scalar_relation( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + s_lo: usize, + s_hi: usize, + s1_witness: usize, + s2_witness: usize, + neg1_witness: usize, + neg2_witness: usize, + curve: &CurveParams, +) { + let order_bits = curve.curve_order_bits() as usize; + let sr_limb_bits = scalar_relation_limb_bits(order_bits); + let sr_num_limbs = (order_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let half_bits = curve.glv_half_bits() as usize; + // Number of limbs the half-scalar occupies + let half_limbs = (half_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + + let params = build_scalar_relation_params(sr_num_limbs, sr_limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + + // Decompose s into sr_num_limbs limbs from (s_lo, s_hi). + // s_lo contains bits [0..128), s_hi contains bits [128..256). 
+ let s_limbs = { + let limbs_per_half = (128 + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let dd_bases_128: Vec = (0..limbs_per_half) + .map(|i| { + let remaining = 128u32 - (i as u32 * sr_limb_bits); + remaining.min(sr_limb_bits) as usize + }) + .collect(); + let dd_lo = add_digital_decomposition(ops.compiler, dd_bases_128.clone(), vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, dd_bases_128, vec![s_hi]); + let mut limbs = Limbs::new(sr_num_limbs); + let lo_n = limbs_per_half.min(sr_num_limbs); + for i in 0..lo_n { + limbs[i] = dd_lo.get_digit_witness_index(i, 0); + let remaining = 128u32 - (i as u32 * sr_limb_bits); + ops.range_checks + .entry(remaining.min(sr_limb_bits)) + .or_default() + .push(limbs[i]); + } + let hi_n = sr_num_limbs - lo_n; + for i in 0..hi_n { + limbs[lo_n + i] = dd_hi.get_digit_witness_index(i, 0); + let remaining = 128u32 - (i as u32 * sr_limb_bits); + ops.range_checks + .entry(remaining.min(sr_limb_bits)) + .or_default() + .push(limbs[lo_n + i]); + } + limbs + }; + + // Helper: decompose a half-scalar witness into sr_num_limbs limbs. + // The half-scalar has `half_bits` bits → occupies `half_limbs` limbs. + // Upper limbs (half_limbs..sr_num_limbs) are zero-padded. 
+ let decompose_half_scalar = |ops: &mut MultiLimbOps, witness: usize| -> Limbs { + let dd_bases: Vec = (0..half_limbs) + .map(|i| { + let remaining = half_bits as u32 - (i as u32 * sr_limb_bits); + remaining.min(sr_limb_bits) as usize + }) + .collect(); + let dd = add_digital_decomposition(ops.compiler, dd_bases, vec![witness]); + let mut limbs = Limbs::new(sr_num_limbs); + for i in 0..half_limbs { + limbs[i] = dd.get_digit_witness_index(i, 0); + let remaining_bits = (half_bits as u32) - (i as u32 * sr_limb_bits); + let this_limb_bits = remaining_bits.min(sr_limb_bits); + ops.range_checks + .entry(this_limb_bits) + .or_default() + .push(limbs[i]); + } + // Zero-pad upper limbs + for i in half_limbs..sr_num_limbs { + let w = ops.compiler.num_witnesses(); + ops.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + FieldElement::from(0u64), + ))); + limbs[i] = w; + constrain_zero(ops.compiler, limbs[i]); + } + limbs + }; + + let s1_limbs = decompose_half_scalar(&mut ops, s1_witness); + let s2_limbs = decompose_half_scalar(&mut ops, s2_witness); + + // Compute product = s2 * s (mod n) + let product = ops.mul(s2_limbs, s_limbs); + + // Handle signs: compute effective values + // If neg2 is set: neg_product = n - product (mod n), i.e. 0 - product + let neg_product = ops.negate(product); + // neg2 already constrained boolean in verify_point_fakeglv + let effective_product = ops.select_unchecked(neg2_witness, product, neg_product); + + // If neg1 is set: neg_s1 = n - s1 (mod n), i.e. 
0 - s1 + let neg_s1 = ops.negate(s1_limbs); + // neg1 already constrained boolean in verify_point_fakeglv + let effective_s1 = ops.select_unchecked(neg1_witness, s1_limbs, neg_s1); + + // Sum: effective_s1 + effective_product (mod n) should be 0 + let sum = ops.add(effective_s1, effective_product); + + // Constrain sum == 0: all limbs must be zero + for i in 0..sr_num_limbs { + constrain_zero(ops.compiler, sum[i]); + } +} + +/// Creates a constant witness with the given value. +fn add_constant_witness(compiler: &mut NoirToR1CSCompiler, value: FieldElement) -> usize { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + w +} + +/// Constrains two witnesses to be equal: `a - b = 0`. +fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} + +/// Constrains a witness to be zero: `w = 0`. +fn constrain_zero(compiler: &mut NoirToR1CSCompiler, w: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs new file mode 100644 index 00000000..840f8081 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -0,0 +1,622 @@ +//! N-limb modular arithmetic for EC field operations. +//! +//! Replaces both `ec_ops.rs` (N=1 path) and `wide_ops.rs` (N>1 path) with +//! unified multi-limb operations using `Limbs` (runtime-sized, Copy). 
use {
    super::Limbs,
    crate::noir_to_r1cs::NoirToR1CSCompiler,
    ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField},
    provekit_common::{
        witness::{SumTerm, WitnessBuilder},
        FieldElement,
    },
    std::collections::BTreeMap,
};

// ---------------------------------------------------------------------------
// N=1 single-limb path (moved from ec_ops.rs)
// ---------------------------------------------------------------------------

/// Reduce the value to given modulus (N=1 path).
/// Computes v = k*m + result, where 0 <= result < m.
///
/// Emits two R1CS constraints: `k * m = k_mul_m` and
/// `k_mul_m + result = value`. The remainder is computed by the
/// `IntegerQuotient` hint and bound by a range check (see NOTE below).
pub fn reduce_mod_p(
    compiler: &mut NoirToR1CSCompiler,
    value: usize,
    modulus: FieldElement,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
) -> usize {
    // Constant witness holding the modulus m.
    let m = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Constant(
        provekit_common::witness::ConstantTerm(m, modulus),
    ));
    // Hint witness: k = floor(value / m).
    let k = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus));

    // k_mul_m = k * m, enforced by the product constraint below.
    let k_mul_m = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m));
    compiler
        .r1cs
        .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[(
            FieldElement::ONE,
            k_mul_m,
        )]);

    // result = value - k*m, then enforce k*m + result = value.
    let result = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![
        SumTerm(Some(FieldElement::ONE), value),
        SumTerm(Some(-FieldElement::ONE), k_mul_m),
    ]));
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, compiler.witness_one())],
        &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)],
        &[(FieldElement::ONE, value)],
    );

    // NOTE(review): this bounds result to [0, 2^num_bits(m)), which is weaker
    // than result < m (the multi-limb path runs an explicit less-than-p
    // check). Confirm callers tolerate the non-canonical remainder range.
    let modulus_bits = modulus.into_bigint().num_bits();
    range_checks.entry(modulus_bits).or_default().push(result);

    result
}

/// a + b mod p (N=1 path)
///
/// Constrains `a + b = a_add_b` over the native field, then reduces the
/// sum modulo `modulus` via [`reduce_mod_p`].
pub fn add_mod_p_single(
    compiler: &mut NoirToR1CSCompiler,
    a: usize,
    b: usize,
    modulus: FieldElement,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
) -> usize {
    let a_add_b = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![
        SumTerm(Some(FieldElement::ONE), a),
        SumTerm(Some(FieldElement::ONE), b),
    ]));
    // (a + b) * 1 = a_add_b
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, a), (FieldElement::ONE, b)],
        &[(FieldElement::ONE, compiler.witness_one())],
        &[(FieldElement::ONE, a_add_b)],
    );
    reduce_mod_p(compiler, a_add_b, modulus, range_checks)
}

/// a * b mod p (N=1 path)
///
/// Constrains `a * b = a_mul_b` over the native field, then reduces the
/// product modulo `modulus` via [`reduce_mod_p`].
pub fn mul_mod_p_single(
    compiler: &mut NoirToR1CSCompiler,
    a: usize,
    b: usize,
    modulus: FieldElement,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
) -> usize {
    let a_mul_b = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b));
    compiler
        .r1cs
        .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[(
            FieldElement::ONE,
            a_mul_b,
        )]);
    reduce_mod_p(compiler, a_mul_b, modulus, range_checks)
}

/// (a - b) mod p (N=1 path)
///
/// Constrains `a - b = a_sub_b` over the native field, then reduces the
/// difference modulo `modulus` via [`reduce_mod_p`].
pub fn sub_mod_p_single(
    compiler: &mut NoirToR1CSCompiler,
    a: usize,
    b: usize,
    modulus: FieldElement,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
) -> usize {
    let a_sub_b = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![
        SumTerm(Some(FieldElement::ONE), a),
        SumTerm(Some(-FieldElement::ONE), b),
    ]));
    // 1 * (a - b) = a_sub_b
    // NOTE(review): when a < b the difference wraps modulo the native field
    // before reduction — confirm IntegerQuotient / reduce_mod_p handle the
    // wrapped representative as intended for this path.
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, compiler.witness_one())],
        &[(FieldElement::ONE, a), (-FieldElement::ONE, b)],
        &[(FieldElement::ONE, a_sub_b)],
    );
    reduce_mod_p(compiler, a_sub_b, modulus, range_checks)
}

/// a^(-1) mod p (N=1 path)
///
/// The inverse is supplied by the `ModularInverse` hint and verified by
/// multiplying back: `a * a_inv mod p` must reduce to 1.
pub fn inv_mod_p_single(
    compiler: &mut NoirToR1CSCompiler,
    a: usize,
    modulus: FieldElement,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
) -> usize {
    let a_inv = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus));

    // reduced = a * a_inv mod p (constrained inside mul_mod_p_single).
    let reduced = mul_mod_p_single(compiler, a, a_inv, modulus, range_checks);

    // Constrain reduced = 1
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, compiler.witness_one())],
        &[
            (FieldElement::ONE, reduced),
            (-FieldElement::ONE, compiler.witness_one()),
        ],
        &[(FieldElement::ZERO, compiler.witness_one())],
    );

    // Bound the hinted inverse to the modulus bit width (same caveat as
    // reduce_mod_p: this is a bit-width bound, not a strict < p proof).
    let mod_bits = modulus.into_bigint().num_bits();
    range_checks.entry(mod_bits).or_default().push(a_inv);

    a_inv
}

/// Checks if value is zero or not (used by all N values).
/// Returns a boolean witness: 1 if zero, 0 if non-zero.
///
/// Uses SafeInverse (not Inverse) because the input value may be zero.
/// SafeInverse outputs 0 when the input is 0, and is solved in the Other
/// layer (not batch-inverted), so zero inputs don't poison the batch.
pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize {
    // Hint: value_inv = value^(-1) when value != 0, else 0.
    let value_inv = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::SafeInverse(value_inv, value));

    let value_mul_value_inv = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Product(
        value_mul_value_inv,
        value,
        value_inv,
    ));

    // is_zero = 1 - value * value_inv.
    let is_zero = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::Sum(is_zero, vec![
        SumTerm(Some(FieldElement::ONE), compiler.witness_one()),
        SumTerm(Some(-FieldElement::ONE), value_mul_value_inv),
    ]));

    // v × v^(-1) = 1 - is_zero
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, value)],
        &[(FieldElement::ONE, value_inv)],
        &[
            (FieldElement::ONE, compiler.witness_one()),
            (-FieldElement::ONE, is_zero),
        ],
    );
    // v × is_zero = 0
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, value)],
        &[(FieldElement::ONE, is_zero)],
        &[(FieldElement::ZERO, compiler.witness_one())],
    );

    is_zero
}

// ---------------------------------------------------------------------------
// N≥2 multi-limb path (generalization of wide_ops.rs)
// ---------------------------------------------------------------------------

/// (a + b) mod p for multi-limb values.
+/// +/// Per limb i: v_i = a\[i\] + b\[i\] + 2^W - q*p\[i\] + carry_{i-1} +/// carry_i = floor(v_i / 2^W) +/// r\[i\] = v_i - carry_i * 2^W +pub fn add_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "add_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + // Witness: q = floor((a + b) / p) ∈ {0, 1} + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + // q is boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + let mut r = Limbs::new(n); + let mut carry_prev: Option = None; + + for i in 0..n { + // v_offset = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1} + // When carry_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
+ let w1_coeff = if carry_prev.is_some() { + two_pow_w - FieldElement::ONE + } else { + two_pow_w + }; + let mut terms = vec![ + SumTerm(None, a[i]), + SumTerm(None, b[i]), + SumTerm(Some(w1_coeff), w1), + SumTerm(Some(-p_limbs[i]), q), + ]; + if let Some(carry) = carry_prev { + terms.push(SumTerm(None, carry)); + } + let v_offset = compiler.add_sum(terms); + + // carry = floor(v_offset / 2^W) + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); + // r[i] = v_offset - carry * 2^W + r[i] = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_pow_w), carry), + ]); + carry_prev = Some(carry); + } + + less_than_p_check_multi( + compiler, + range_checks, + r, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); + + r +} + +/// (a - b) mod p for multi-limb values. +pub fn sub_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "sub_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + // Witness: q = (a < b) ? 1 : 0 + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + // q is boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + let mut r = Limbs::new(n); + let mut carry_prev: Option = None; + + for i in 0..n { + // v_offset = a[i] - b[i] + q*p[i] + 2^W + carry_{i-1} + // When carry_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
+ let w1_coeff = if carry_prev.is_some() { + two_pow_w - FieldElement::ONE + } else { + two_pow_w + }; + let mut terms = vec![ + SumTerm(None, a[i]), + SumTerm(Some(-FieldElement::ONE), b[i]), + SumTerm(Some(p_limbs[i]), q), + SumTerm(Some(w1_coeff), w1), + ]; + if let Some(carry) = carry_prev { + terms.push(SumTerm(None, carry)); + } + let v_offset = compiler.add_sum(terms); + + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); + r[i] = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_pow_w), carry), + ]); + carry_prev = Some(carry); + } + + less_than_p_check_multi( + compiler, + range_checks, + r, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); + + r +} + +/// (a * b) mod p for multi-limb values using schoolbook multiplication. +/// +/// Verifies: a·b = p·q + r in base W = 2^limb_bits. +/// Column k: Σ_{i+j=k} a\[i\]*b\[j\] + carry_{k-1} + OFFSET +/// = Σ_{i+j=k} p\[i\]*q\[j\] + r\[k\] + carry_k * W +pub fn mul_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "mul_mod_p_multi requires n >= 2, got n={n}"); + + // Soundness check: column equation values must not overflow the native field. + // The maximum value across either side of any column equation is bounded by + // 2^(2*limb_bits + ceil(log2(n)) + 3). This must be strictly less than the + // native field modulus p >= 2^(MODULUS_BIT_SIZE - 1). + { + let ceil_log2_n = (n as f64).log2().ceil() as u32; + let max_bits = 2 * limb_bits + ceil_log2_n + 3; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs requires \ + {max_bits} bits, but native field is only {} bits. 
Use smaller limb_bits.", + FieldElement::MODULUS_BIT_SIZE, + ); + } + + let w1 = compiler.witness_one(); + let num_carries = 2 * n - 2; + // Carry offset: 2^(limb_bits + ceil(log2(n)) + 1) + let extra_bits = ((n as f64).log2().ceil() as u32) + 1; + let carry_offset_bits = limb_bits + extra_bits; + let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); + // offset_w = carry_offset * 2^limb_bits + let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); + // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits + // - 1) + let offset_w_minus_carry = offset_w - carry_offset_fe; + + // Step 1: Allocate hint witnesses (q limbs, r limbs, carries) + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbMulModHint { + output_start: os, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + + // q[0..n), r[n..2n), carries[2n..4n-2) + let q: Vec = (0..n).map(|i| os + i).collect(); + let r_indices: Vec = (0..n).map(|i| os + n + i).collect(); + let cu: Vec = (0..num_carries).map(|i| os + 2 * n + i).collect(); + + // Step 2: Product witnesses for a[i]*b[j] (n² R1CS constraints) + let mut ab_products = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + ab_products[i][j] = compiler.add_product(a[i], b[j]); + } + } + + // Step 3: Column equations (2n-1 R1CS constraints) + for k in 0..(2 * n - 1) { + // LHS: Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + OFFSET + let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((FieldElement::ONE, ab_products[i][j_val as usize])); + } + } + // Add carry_{k-1} + if k > 0 { + lhs_terms.push((FieldElement::ONE, cu[k - 1])); + // Add offset_w - carry_offset for subsequent columns + lhs_terms.push((offset_w_minus_carry, w1)); + } 
else { + // First column: add offset_w + lhs_terms.push((offset_w, w1)); + } + + // RHS: Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W + let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + rhs_terms.push((p_limbs[i], q[j_val as usize])); + } + } + if k < n { + rhs_terms.push((FieldElement::ONE, r_indices[k])); + } + if k < 2 * n - 2 { + rhs_terms.push((two_pow_w, cu[k])); + } else { + // Last column: RHS includes offset_w to balance the LHS offset + // LHS has: carry[k-1] + offset_w_minus_carry = true_carry + offset_w + // RHS needs: sum_pq[k] + offset_w (no outgoing carry at last column) + rhs_terms.push((offset_w, w1)); + } + + compiler + .r1cs + .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms); + } + + // Step 4: less-than-p check and range checks on r + let mut r_limbs = Limbs::new(n); + for (i, &ri) in r_indices.iter().enumerate() { + r_limbs[i] = ri; + } + less_than_p_check_multi( + compiler, + range_checks, + r_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); + + // Step 5: Range checks for q limbs and carries + for i in 0..n { + range_checks.entry(limb_bits).or_default().push(q[i]); + } + // Carry range: limb_bits + extra_bits + 1 (carry_offset_bits + 1) + let carry_range_bits = carry_offset_bits + 1; + for &c in &cu { + range_checks.entry(carry_range_bits).or_default().push(c); + } + + r_limbs +} + +/// a^(-1) mod p for multi-limb values. +/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, +/// ..., 0]. 
pub fn inv_mod_p_multi(
    compiler: &mut NoirToR1CSCompiler,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
    a: Limbs,
    p_limbs: &[FieldElement],
    p_minus_1_limbs: &[FieldElement],
    two_pow_w: FieldElement,
    limb_bits: u32,
    modulus_raw: &[u64; 4],
) -> Limbs {
    let n = a.len();
    // The multi-limb path is only meaningful for >= 2 limbs; single-limb
    // inversion goes through inv_mod_p_single / the native-field path.
    assert!(n >= 2, "inv_mod_p_multi requires n >= 2, got n={n}");

    // Hint: compute inverse out-of-circuit. The hinted limbs are
    // unconstrained until the multiplication check below ties them to `a`.
    let inv_start = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::MultiLimbModularInverse {
        output_start: inv_start,
        a_limbs: a.as_slice().to_vec(),
        modulus: *modulus_raw,
        limb_bits,
        num_limbs: n as u32,
    });
    let mut inv = Limbs::new(n);
    for i in 0..n {
        inv[i] = inv_start + i;
        // Every hinted limb must be range-checked to limb_bits bits.
        range_checks.entry(limb_bits).or_default().push(inv[i]);
    }

    // Verify: a * inv mod p = [1, 0, ..., 0]
    let product = mul_mod_p_multi(
        compiler,
        range_checks,
        a,
        inv,
        p_limbs,
        p_minus_1_limbs,
        two_pow_w,
        limb_bits,
        modulus_raw,
    );

    // Constrain product[0] = 1  (A = product[0], B = 1, C = 1).
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, product[0])],
        &[(FieldElement::ONE, compiler.witness_one())],
        &[(FieldElement::ONE, compiler.witness_one())],
    );
    // Constrain product[1..n] = 0 (an empty C side encodes zero).
    for i in 1..n {
        compiler.r1cs.add_constraint(
            &[(FieldElement::ONE, product[i])],
            &[(FieldElement::ONE, compiler.witness_one())],
            &[],
        );
    }

    inv
}

/// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values.
/// Uses borrow propagation: d\[i\] = (p-1)\[i\] - r\[i\] + borrow_in -
/// borrow_out * 2^W
fn less_than_p_check_multi(
    compiler: &mut NoirToR1CSCompiler,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
    r: Limbs,
    p_minus_1_limbs: &[FieldElement],
    two_pow_w: FieldElement,
    limb_bits: u32,
) {
    let n = r.len();
    let w1 = compiler.witness_one();
    // Witness index of the carry/borrow produced by the previous limb, if any.
    let mut borrow_prev: Option<usize> = None;
    for i in 0..n {
        // v_diff = (p-1)[i] + 2^W - r[i] + borrow_prev
        // When borrow_prev exists, combine w1 terms to avoid duplicate column
        // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)).
        // The -1 folds the previous carry into the constant term: with
        // borrow_prev ∈ {0,1}, (w1_coeff + borrow_prev) equals the intended
        // (p-1)[i] + 2^W minus an incoming borrow of (1 - borrow_prev).
        let w1_coeff = if borrow_prev.is_some() {
            p_minus_1_limbs[i] + two_pow_w - FieldElement::ONE
        } else {
            p_minus_1_limbs[i] + two_pow_w
        };
        let mut terms = vec![
            SumTerm(Some(w1_coeff), w1),
            SumTerm(Some(-FieldElement::ONE), r[i]),
        ];
        if let Some(borrow) = borrow_prev {
            terms.push(SumTerm(None, borrow));
        }
        let v_diff = compiler.add_sum(terms);

        // borrow = floor(v_diff / 2^W), hinted and then fixed by the range
        // check on the remainder d[i] below.
        let borrow = compiler.num_witnesses();
        compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(borrow, v_diff, two_pow_w));
        // d[i] = v_diff - borrow * 2^W
        let d_i = compiler.add_sum(vec![
            SumTerm(None, v_diff),
            SumTerm(Some(-two_pow_w), borrow),
        ]);

        // Range check r[i] and d[i] to limb_bits bits each.
        range_checks.entry(limb_bits).or_default().push(r[i]);
        range_checks.entry(limb_bits).or_default().push(d_i);

        borrow_prev = Some(borrow);
    }

    // Constrain final carry = 1: the 2^W offset at each limb propagates
    // a carry of 1 through the chain. For valid r < p, the final carry
    // must be exactly 1. If r >= p, the carry chain underflows and the
    // final carry is 0.
    if let Some(final_borrow) = borrow_prev {
        compiler.r1cs.add_constraint(
            &[(FieldElement::ONE, compiler.witness_one())],
            &[(FieldElement::ONE, final_borrow)],
            &[(FieldElement::ONE, compiler.witness_one())],
        );
    }
}
diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs
new file mode 100644
index 00000000..34275b89
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs
@@ -0,0 +1,261 @@
//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime
//! limb count.
//!
//! Uses `Limbs` (a fixed-capacity Copy type) as `FieldOps::Elem`, enabling
//! arbitrary limb counts without const generics or dispatch macros.

use {
    super::{multi_limb_arith, FieldOps, Limbs},
    crate::noir_to_r1cs::NoirToR1CSCompiler,
    ark_ff::{AdditiveGroup, Field},
    provekit_common::{
        witness::{ConstantTerm, SumTerm, WitnessBuilder},
        FieldElement,
    },
    std::collections::BTreeMap,
};

/// Parameters for multi-limb field arithmetic.
pub struct MultiLimbParams {
    /// Number of limbs each element is split into.
    pub num_limbs: usize,
    /// Bit width W of one limb.
    pub limb_bits: u32,
    /// Limb decomposition of the modulus p.
    pub p_limbs: Vec<FieldElement>,
    /// Limb decomposition of p - 1 (used by less-than-p checks).
    pub p_minus_1_limbs: Vec<FieldElement>,
    /// 2^W as a native field element.
    pub two_pow_w: FieldElement,
    /// p as four u64 words, passed to witness-builder hints.
    pub modulus_raw: [u64; 4],
    /// Limb decomposition of the curve's `a` coefficient.
    pub curve_a_limbs: Vec<FieldElement>,
    /// p = native field → skip mod reduction
    pub is_native: bool,
    /// For N=1 non-native: the modulus as a single FieldElement
    pub modulus_fe: Option<FieldElement>,
}

/// Unified field operations struct parameterized by runtime limb count.
pub struct MultiLimbOps<'a, 'p> {
    pub compiler: &'a mut NoirToR1CSCompiler,
    pub range_checks: &'a mut BTreeMap<u32, Vec<usize>>,
    pub params: &'p MultiLimbParams,
}

impl MultiLimbOps<'_, '_> {
    /// Single limb over the native field: plain R1CS arithmetic, no
    /// modular reduction required.
    fn is_native_single(&self) -> bool {
        self.params.num_limbs == 1 && self.params.is_native
    }

    /// Single limb over a non-native field: reduce via `modulus_fe`.
    fn is_non_native_single(&self) -> bool {
        self.params.num_limbs == 1 && !self.params.is_native
    }

    /// Configured limb count.
    fn n(&self) -> usize {
        self.params.num_limbs
    }

    /// Negate a multi-limb value: computes `0 - value (mod p)`.
    pub fn negate(&mut self, value: Limbs) -> Limbs {
        // Negation is implemented as a modular subtraction from the constant
        // zero element, so all three dispatch paths of `sub` are reused.
        let zero_vals = vec![FieldElement::from(0u64); self.params.num_limbs];
        let zero = self.constant_limbs(&zero_vals);
        self.sub(zero, value)
    }
}

impl FieldOps for MultiLimbOps<'_, '_> {
    type Elem = Limbs;

    /// Modular addition: a + b (mod p), dispatching on limb configuration.
    fn add(&mut self, a: Limbs, b: Limbs) -> Limbs {
        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
        debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs");
        if self.is_native_single() {
            // When both operands are the same witness, merge into a single
            // term with coefficient 2 to avoid duplicate column indices in
            // the R1CS sparse matrix (set overwrites on duplicate (row,col)).
            let r = if a[0] == b[0] {
                self.compiler
                    .add_sum(vec![SumTerm(Some(FieldElement::from(2u64)), a[0])])
            } else {
                self.compiler
                    .add_sum(vec![SumTerm(None, a[0]), SumTerm(None, b[0])])
            };
            Limbs::single(r)
        } else if self.is_non_native_single() {
            let modulus = self.params.modulus_fe.unwrap();
            let r = multi_limb_arith::add_mod_p_single(
                self.compiler,
                a[0],
                b[0],
                modulus,
                self.range_checks,
            );
            Limbs::single(r)
        } else {
            multi_limb_arith::add_mod_p_multi(
                self.compiler,
                self.range_checks,
                a,
                b,
                &self.params.p_limbs,
                &self.params.p_minus_1_limbs,
                self.params.two_pow_w,
                self.params.limb_bits,
                &self.params.modulus_raw,
            )
        }
    }

    /// Modular subtraction: a - b (mod p), dispatching on limb configuration.
    fn sub(&mut self, a: Limbs, b: Limbs) -> Limbs {
        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
        debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs");
        if self.is_native_single() {
            // When both operands are the same witness, a - a = 0. Use a
            // single zero-coefficient term to avoid duplicate column indices.
            let r = if a[0] == b[0] {
                self.compiler
                    .add_sum(vec![SumTerm(Some(FieldElement::ZERO), a[0])])
            } else {
                self.compiler.add_sum(vec![
                    SumTerm(None, a[0]),
                    SumTerm(Some(-FieldElement::ONE), b[0]),
                ])
            };
            Limbs::single(r)
        } else if self.is_non_native_single() {
            let modulus = self.params.modulus_fe.unwrap();
            let r = multi_limb_arith::sub_mod_p_single(
                self.compiler,
                a[0],
                b[0],
                modulus,
                self.range_checks,
            );
            Limbs::single(r)
        } else {
            multi_limb_arith::sub_mod_p_multi(
                self.compiler,
                self.range_checks,
                a,
                b,
                &self.params.p_limbs,
                &self.params.p_minus_1_limbs,
                self.params.two_pow_w,
                self.params.limb_bits,
                &self.params.modulus_raw,
            )
        }
    }

    /// Modular multiplication: a * b (mod p), dispatching on limb
    /// configuration.
    fn mul(&mut self, a: Limbs, b: Limbs) -> Limbs {
        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
        debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs");
        if self.is_native_single() {
            // Native field: a single R1CS product constraint suffices.
            let r = self.compiler.add_product(a[0], b[0]);
            Limbs::single(r)
        } else if self.is_non_native_single() {
            let modulus = self.params.modulus_fe.unwrap();
            let r = multi_limb_arith::mul_mod_p_single(
                self.compiler,
                a[0],
                b[0],
                modulus,
                self.range_checks,
            );
            Limbs::single(r)
        } else {
            multi_limb_arith::mul_mod_p_multi(
                self.compiler,
                self.range_checks,
                a,
                b,
                &self.params.p_limbs,
                &self.params.p_minus_1_limbs,
                self.params.two_pow_w,
                self.params.limb_bits,
                &self.params.modulus_raw,
            )
        }
    }

    /// Modular inverse: a^(-1) (mod p), dispatching on limb configuration.
    fn inv(&mut self, a: Limbs) -> Limbs {
        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
        if self.is_native_single() {
            // Hint the inverse, then constrain a * a_inv = 1 in a single
            // R1CS row.
            let a_inv = self.compiler.num_witnesses();
            self.compiler
                .add_witness_builder(WitnessBuilder::Inverse(a_inv, a[0]));
            // a * a_inv = 1
            self.compiler.r1cs.add_constraint(
                &[(FieldElement::ONE, a[0])],
                &[(FieldElement::ONE, a_inv)],
                &[(FieldElement::ONE, self.compiler.witness_one())],
            );
            Limbs::single(a_inv)
        } else if self.is_non_native_single() {
            let modulus = self.params.modulus_fe.unwrap();
            let r =
                multi_limb_arith::inv_mod_p_single(self.compiler, a[0], modulus, self.range_checks);
            Limbs::single(r)
        } else {
            multi_limb_arith::inv_mod_p_multi(
                self.compiler,
                self.range_checks,
                a,
                &self.params.p_limbs,
                &self.params.p_minus_1_limbs,
                self.params.two_pow_w,
                self.params.limb_bits,
                &self.params.modulus_raw,
            )
        }
    }

    /// Materializes the curve's `a` coefficient as fresh constant witnesses,
    /// one per limb (a new set is allocated on every call).
    fn curve_a(&mut self) -> Limbs {
        let n = self.n();
        let mut out = Limbs::new(n);
        for i in 0..n {
            let w = self.compiler.num_witnesses();
            self.compiler
                .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(
                    w,
                    self.params.curve_a_limbs[i],
                )));
            out[i] = w;
        }
        out
    }

    /// Constrains `flag` to be boolean (0 or 1).
    fn constrain_flag(&mut self, flag: usize) {
        super::constrain_boolean(self.compiler, flag);
    }

    /// Limb-wise conditional select: returns `on_true` when `flag` is 1,
    /// `on_false` when 0. Does NOT constrain `flag` to be boolean — callers
    /// must pair this with `constrain_flag` (hence "unchecked").
    fn select_unchecked(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs {
        let n = self.n();
        let mut out = Limbs::new(n);
        for i in 0..n {
            out[i] = super::select_witness(self.compiler, flag, on_false[i], on_true[i]);
        }
        out
    }

    /// Returns a witness that is 1 iff the given witness equals zero.
    fn is_zero(&mut self, value: usize) -> usize {
        multi_limb_arith::compute_is_zero(self.compiler, value)
    }

    /// Packs boolean witnesses into a single field element witness.
    fn pack_bits(&mut self, bits: &[usize]) -> usize {
        super::pack_bits_helper(self.compiler, bits)
    }

    /// Allocates constant witnesses for the given limb values.
    fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Limbs {
        let n = self.n();
        assert_eq!(
            limbs.len(),
            n,
            "constant_limbs: expected {n} limbs, got {}",
            limbs.len()
        );
        let mut out = Limbs::new(n);
        for i in 0..n {
            let w = self.compiler.num_witnesses();
            self.compiler
                .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, limbs[i])));
            out[i] = w;
        }
        out
    }
}
diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs
index 189eb469..14473043 100644
--- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs
+++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs
@@ -3,6 +3,7 @@ use {
     binops::add_combined_binop_constraints,
    digits::{add_digital_decomposition,
DigitalDecompositionWitnessesBuilder}, memory::{add_ram_checking, add_rom_checking, MemoryBlock, MemoryOperation}, + msm::add_msm, poseidon2::add_poseidon2_permutation, range_check::add_range_checks, sha256_compression::add_sha256_compression, @@ -88,6 +89,11 @@ pub struct R1CSBreakdown { pub poseidon2_constraints: usize, /// Witnesses from Poseidon2 permutation pub poseidon2_witnesses: usize, + + /// Constraints from multi-scalar multiplication + pub msm_constraints: usize, + /// Witnesses from multi-scalar multiplication + pub msm_witnesses: usize, } /// Compiles an ACIR circuit into an [R1CS] instance, comprising of the A, B, @@ -457,6 +463,7 @@ impl NoirToR1CSCompiler { let mut xor_ops = vec![]; let mut sha256_compression_ops = vec![]; let mut poseidon2_ops = vec![]; + let mut msm_ops = vec![]; let mut breakdown = R1CSBreakdown::default(); @@ -627,6 +634,24 @@ impl NoirToR1CSCompiler { output_witnesses, )); } + BlackBoxFuncCall::MultiScalarMul { + points, + scalars, + outputs, + } => { + let point_wits: Vec = points + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let scalar_wits: Vec = scalars + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let out_x = self.fetch_r1cs_witness_index(outputs.0); + let out_y = self.fetch_r1cs_witness_index(outputs.1); + let out_inf = self.fetch_r1cs_witness_index(outputs.2); + msm_ops.push((point_wits, scalar_wits, (out_x, out_y, out_inf))); + } _ => { unimplemented!("Other black box function: {:?}", black_box_func_call); } @@ -718,6 +743,12 @@ impl NoirToR1CSCompiler { breakdown.poseidon2_constraints = self.r1cs.num_constraints() - constraints_before_poseidon; breakdown.poseidon2_witnesses = self.num_witnesses() - witnesses_before_poseidon; + let constraints_before_msm = self.r1cs.num_constraints(); + let witnesses_before_msm = self.num_witnesses(); + add_msm(self, msm_ops, &mut range_checks); + breakdown.msm_constraints = 
self.r1cs.num_constraints() - constraints_before_msm; + breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; + breakdown.range_ops_total = range_checks.values().map(|v| v.len()).sum(); let constraints_before_range = self.r1cs.num_constraints(); let witnesses_before_range = self.num_witnesses(); diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs index 3b74c8e9..8212290b 100644 --- a/tooling/provekit-bench/tests/compiler.rs +++ b/tooling/provekit-bench/tests/compiler.rs @@ -81,6 +81,7 @@ pub fn compile_workspace(workspace_path: impl AsRef) -> Result #[test_case("../../noir-examples/noir-r1cs-test-programs/bounded-vec")] #[test_case("../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained")] #[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check"; "complete_age_check")] +#[test_case("../../noir-examples/embedded_curve_msm"; "embedded_curve_msm")] fn case_noir(path: &str) { test_noir_compiler(path); }