Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/eval.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::{HALF_P, SCALE_FACTOR};
use num_traits::{One, Zero};
use crate::HALF_P;
use num_traits::One;
use stwo::core::fields::m31::M31;
Expand Down Expand Up @@ -99,6 +101,42 @@ pub trait EvalFixedPoint: EvalAtRow {
// Enforce the constraint: out^2 + rem = input * scale_factor
self.add_constraint((out.clone() * out) + rem.clone() - (input * scale_factor));
}

/// Evaluates constraints for less-than comparison.
/// Adds constraints to verify that:
/// 1. result is binary (0 or 1)
/// 2. result = 1 if and only if a < b
///
/// # Parameters
/// - `a`: The trace column value representing the first operand.
/// - `b`: The trace column value representing the second operand.
/// - `result`: The trace column value representing the boolean result (1 if a < b, 0 otherwise).
fn eval_fixed_lt(&mut self, a: Self::F, b: Self::F, result: Self::F) {
// Ensure result is binary (0 or 1)
self.add_constraint(result.clone() * (result.clone() - Self::F::one()));

// Compute difference b - a
let diff = self.add_intermediate(b.clone() - a.clone());

// Compute b - a - 1, which will be >= 0 if and only if a < b
let diff_minus_one = self.add_intermediate(diff - Self::F::one());

// For a < b to be true (result = 1), diff_minus_one must be >= 0
// For a >= b to be false (result = 0), diff_minus_one must be < 0

// We need an auxiliary variable to enforce the relationship
let aux = self.add_intermediate(Self::F::zero());

// Constraint 1: If result = 0 (a >= b), then diff_minus_one must be < 0,
// which means diff_minus_one + aux = 0 for some positive aux
self.add_constraint(
(Self::F::one() - result.clone()) * (diff_minus_one.clone() + aux.clone()),
);

// Constraint 2: If result = 1 (a < b), then diff_minus_one must be >= 0,
// and aux = 0
self.add_constraint(result * aux);
}
}

// Blanket implementation for any type that implements EvalAtRow
Expand Down Expand Up @@ -161,6 +199,7 @@ mod tests {
Rem,
Recip,
Sqrt,
Lt,
}

impl FrameworkEval for TestEval {
Expand Down Expand Up @@ -214,6 +253,12 @@ mod tests {
let rem = eval.next_trace_mask();
eval.eval_fixed_sqrt(input, out, rem, scale_factor)
}
Op::Lt => {
let lhs = eval.next_trace_mask();
let rhs = eval.next_trace_mask();
let out = eval.next_trace_mask();
eval.eval_fixed_lt(lhs, rhs, out)
}
}
eval
}
Expand Down Expand Up @@ -457,4 +502,36 @@ mod tests {
test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1);
}
}

#[test]
fn test_eval_lt() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..100 {
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let (result, _) = a.lt(&b);

test_op(Op::Lt, vec![a, b], vec![result], 2);
}

// Test specific cases
let special_cases = vec![
(0.0, 1.0), // 0 < 1 (true)
(1.0, 0.0), // 1 < 0 (false)
(0.0, 0.0), // 0 < 0 (false)
(-1.0, 0.0), // -1 < 0 (true)
(0.0, -1.0), // 0 < -1 (false)
(-2.0, -1.0), // -2 < -1 (true)
(0.5, 1.0), // 0.5 < 1 (true)
(1.5, 1.0), // 1.5 < 1 (false)
];

for (a, b) in special_cases {
let fixed_a = Fixed::from_f64(a);
let fixed_b = Fixed::from_f64(b);
let (result, _) = fixed_a.lt(&fixed_b);

test_op(Op::Lt, vec![fixed_a, fixed_b], vec![result], 2);
}
}
}