From 7c9bcdfbd50c794c997d9259100a5e51ce047cdd Mon Sep 17 00:00:00 2001 From: raizo07 Date: Fri, 18 Apr 2025 02:54:24 +0100 Subject: [PATCH] feat: LessThan operator implementation --- src/eval.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++- src/lib.rs | 66 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/src/eval.rs b/src/eval.rs index 8a7693a..91a6ccb 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,5 +1,5 @@ use crate::{HALF_P, SCALE_FACTOR}; -use num_traits::One; +use num_traits::{One, Zero}; use stwo_prover::{constraint_framework::EvalAtRow, core::fields::m31::M31}; /// Extension trait for EvalAtRow to support fixed-point arithmetic constraint evaluation @@ -78,6 +78,42 @@ pub trait EvalFixedPoint: EvalAtRow { // Enforce the constraint: out^2 + rem = input * SCALE_FACTOR self.add_constraint((out.clone() * out) + rem.clone() - (input * SCALE_FACTOR)); } + + /// Evaluates constraints for less-than comparison. + /// Adds constraints to verify that: + /// 1. result is binary (0 or 1) + /// 2. result = 1 if and only if a < b + /// + /// # Parameters + /// - `a`: The trace column value representing the first operand. + /// - `b`: The trace column value representing the second operand. + /// - `result`: The trace column value representing the boolean result (1 if a < b, 0 otherwise). + fn eval_fixed_lt(&mut self, a: Self::F, b: Self::F, result: Self::F) { + // Ensure result is binary (0 or 1) + self.add_constraint(result.clone() * (result.clone() - Self::F::one())); + + // Compute difference b - a + let diff = self.add_intermediate(b.clone() - a.clone()); + + // Compute b - a - 1, which will be >= 0 if and only if a < b + let diff_minus_one = self.add_intermediate(diff - Self::F::one()); + + // For a < b to be true (result = 1), diff_minus_one must be >= 0 + // For a >= b to be false (result = 0), diff_minus_one must be < 0 + + // We need an auxiliary variable to enforce the relationship + let aux = self.add_intermediate(Self::F::zero()); + + // Constraint 1: If result = 0 (a >= b), then diff_minus_one must be < 0, + // which means diff_minus_one + aux = 0 for some positive aux + self.add_constraint( + (Self::F::one() - result.clone()) * (diff_minus_one.clone() + aux.clone()), + ); + + // Constraint 2: If result = 1 (a < b), then diff_minus_one must be >= 0, + // and aux = 0 + self.add_constraint(result * aux); + } } // Blanket implementation for any type that implements EvalAtRow @@ -120,6 +156,7 @@ mod tests { Mul, Recip, Sqrt, + Lt, } impl FrameworkEval for TestEval { @@ -164,6 +201,12 @@ mod tests { let rem = eval.next_trace_mask(); eval.eval_fixed_sqrt(input, out, rem) } + Op::Lt => { + let lhs = eval.next_trace_mask(); + let rhs = eval.next_trace_mask(); + let out = eval.next_trace_mask(); + eval.eval_fixed_lt(lhs, rhs, out) + } } eval } @@ -359,4 +402,36 @@ mod tests { test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1); } } + + #[test] + fn test_eval_lt() { + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..100 { + let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let (result, _) = a.lt(&b); + + test_op(Op::Lt, vec![a, b], vec![result], 2); + } + + // Test specific cases + let special_cases = vec![ + (0.0, 1.0), // 0 < 1 (true) + (1.0, 0.0), // 1 < 0 (false) + (0.0, 0.0), // 0 < 0 (false) + (-1.0, 0.0), // -1 < 0 (true) + (0.0, -1.0), // 0 < -1 (false) + (-2.0, -1.0), // -2 < -1 (true) + (0.5, 1.0), // 0.5 < 1 (true) + (1.5, 1.0), // 1.5 < 1 (false) + ]; + + for (a, b) in special_cases { + let fixed_a = Fixed::from_f64(a); + let fixed_b = Fixed::from_f64(b); + let (result, _) = fixed_a.lt(&fixed_b); + + test_op(Op::Lt, vec![fixed_a, fixed_b], vec![result], 2); + } + } } diff --git a/src/lib.rs b/src/lib.rs index 02c5455..6afb892 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -115,6 +115,20 @@ impl Fixed { (Self(sqrt_val as i64), Self(remainder as i64)) } + + /// Computes the less-than comparison between self and another Fixed value. + /// + /// Returns a tuple of (result, remainder) where: + /// - result is 1 if self < other, 0 otherwise + /// - remainder is always 0 (to maintain consistency with other operations that return remainders) + #[inline] + pub fn lt(&self, other: &Self) -> (Self, Self) { + if self.0 < other.0 { + (Self(1), Self(0)) // True case: self < other + } else { + (Self(0), Self(0)) // False case: self >= other + } + } } /// Returns the floor of the square root of `n`. @@ -329,4 +343,56 @@ mod tests { assert_near(result_f64, input.sqrt()); } } + + #[test] + fn test_lt() { + // Test cases for less-than comparison + let test_cases = vec![ + (0.0, 1.0, true), // 0 < 1 + (1.0, 0.0, false), // 1 < 0 + (1.0, 1.0, false), // 1 < 1 (equality) + (-1.0, 0.0, true), // -1 < 0 + (0.0, -1.0, false), // 0 < -1 + (-2.0, -1.0, true), // -2 < -1 + (0.5, 1.0, true), // Fractional < Integer + (1.5, 1.0, false), // Fractional > Integer + (-1.5, -1.0, true), // Negative fractional comparisons + ]; + + for (a, b, expected) in test_cases { + let fixed_a = Fixed::from_f64(a); + let fixed_b = Fixed::from_f64(b); + let (result, remainder) = fixed_a.lt(&fixed_b); + + // Check correctness + let expected_result = if expected { 1 } else { 0 }; + assert_eq!( + result.0, expected_result, + "{} < {} should be {}", + a, b, expected + ); + assert_eq!(remainder.0, 0, "Remainder should always be 0"); + } + + // Random test cases + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..1000 { + let a = (rng.gen::() - 0.5) * 200.0; + let b = (rng.gen::() - 0.5) * 200.0; + + let fixed_a = Fixed::from_f64(a); + let fixed_b = Fixed::from_f64(b); + let (result, remainder) = fixed_a.lt(&fixed_b); + + let expected = a < b; + let expected_result = if expected { 1 } else { 0 }; + + assert_eq!( + result.0, expected_result, + "{} < {} should be {}", + a, b, expected + ); + assert_eq!(remainder.0, 0, "Remainder should always be 0"); + } + } }