diff --git a/src/eval.rs b/src/eval.rs index 88a9a77..c9571e7 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,3 +1,5 @@ +use crate::{HALF_P, SCALE_FACTOR}; +use num_traits::{One, Zero}; use crate::HALF_P; use num_traits::One; use stwo::core::fields::m31::M31; @@ -99,6 +101,42 @@ pub trait EvalFixedPoint: EvalAtRow { // Enforce the constraint: out^2 + rem = input * scale_factor self.add_constraint((out.clone() * out) + rem.clone() - (input * scale_factor)); } + + /// Evaluates constraints for less-than comparison. + /// Adds constraints to verify that: + /// 1. result is binary (0 or 1) + /// 2. result = 1 if and only if a < b + /// + /// # Parameters + /// - `a`: The trace column value representing the first operand. + /// - `b`: The trace column value representing the second operand. + /// - `result`: The trace column value representing the boolean result (1 if a < b, 0 otherwise). + fn eval_fixed_lt(&mut self, a: Self::F, b: Self::F, result: Self::F) { + // Ensure result is binary (0 or 1) + self.add_constraint(result.clone() * (result.clone() - Self::F::one())); + + // Compute difference b - a + let diff = self.add_intermediate(b.clone() - a.clone()); + + // Compute b - a - 1, which will be >= 0 if and only if a < b + let diff_minus_one = self.add_intermediate(diff - Self::F::one()); + + // For a < b to be true (result = 1), diff_minus_one must be >= 0 + // For a >= b to be false (result = 0), diff_minus_one must be < 0 + + // We need an auxiliary variable to enforce the relationship + let aux = self.add_intermediate(Self::F::zero()); + + // Constraint 1: If result = 0 (a >= b), then diff_minus_one must be < 0, + // which means diff_minus_one + aux = 0 for some positive aux + self.add_constraint( + (Self::F::one() - result.clone()) * (diff_minus_one.clone() + aux.clone()), + ); + + // Constraint 2: If result = 1 (a < b), then diff_minus_one must be >= 0, + // and aux = 0 + self.add_constraint(result * aux); + } } // Blanket implementation for any type that implements EvalAtRow @@ -161,6 +199,7 @@ mod tests { Rem, Recip, Sqrt, + Lt, } impl FrameworkEval for TestEval { @@ -214,6 +253,12 @@ mod tests { let rem = eval.next_trace_mask(); eval.eval_fixed_sqrt(input, out, rem, scale_factor) } + Op::Lt => { + let lhs = eval.next_trace_mask(); + let rhs = eval.next_trace_mask(); + let out = eval.next_trace_mask(); + eval.eval_fixed_lt(lhs, rhs, out) + } } eval } @@ -457,4 +502,36 @@ mod tests { test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1); } } + + #[test] + fn test_eval_lt() { + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..100 { + let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let (result, _) = a.lt(&b); + + test_op(Op::Lt, vec![a, b], vec![result], 2); + } + + // Test specific cases + let special_cases = vec![ + (0.0, 1.0), // 0 < 1 (true) + (1.0, 0.0), // 1 < 0 (false) + (0.0, 0.0), // 0 < 0 (false) + (-1.0, 0.0), // -1 < 0 (true) + (0.0, -1.0), // 0 < -1 (false) + (-2.0, -1.0), // -2 < -1 (true) + (0.5, 1.0), // 0.5 < 1 (true) + (1.5, 1.0), // 1.5 < 1 (false) + ]; + + for (a, b) in special_cases { + let fixed_a = Fixed::from_f64(a); + let fixed_b = Fixed::from_f64(b); + let (result, _) = fixed_a.lt(&fixed_b); + + test_op(Op::Lt, vec![fixed_a, fixed_b], vec![result], 2); + } + } }