diff --git a/relation/src/gadgets/cmp.rs b/relation/src/gadgets/cmp.rs index 1d4267f86..b8ecb0e7e 100644 --- a/relation/src/gadgets/cmp.rs +++ b/relation/src/gadgets/cmp.rs @@ -85,6 +85,93 @@ impl PlonkCircuit { let c = self.is_lt_internal(a, b)?; self.logic_neg(c) } + + /// Returns a `BoolVar` indicating whether the variable `a` is less than a + /// given constant `val`. + pub fn is_lt_constant(&mut self, a: Variable, val: F) -> Result + where + F: PrimeField, + { + self.check_var_bound(a)?; + let b = self.create_constant_variable(val)?; + self.is_lt(a, b) + } + + /// Returns a `BoolVar` indicating whether the variable `a` is less than or + /// equal to a given constant `val`. + pub fn is_leq_constant(&mut self, a: Variable, val: F) -> Result + where + F: PrimeField, + { + self.check_var_bound(a)?; + let b = self.create_constant_variable(val)?; + self.is_leq(a, b) + } + + /// Returns a `BoolVar` indicating whether the variable `a` is greater than + /// a given constant `val`. + pub fn is_gt_constant(&mut self, a: Variable, val: F) -> Result + where + F: PrimeField, + { + self.check_var_bound(a)?; + self.is_gt_constant_internal(a, &val) + } + + /// Returns a `BoolVar` indicating whether the variable `a` is greater than + /// or equal a given constant `val`. + pub fn is_geq_constant(&mut self, a: Variable, val: F) -> Result + where + F: PrimeField, + { + self.check_var_bound(a)?; + let b = self.create_constant_variable(val)?; + self.is_geq(a, b) + } + + /// Enforce the variable `a` to be less than a + /// given constant `val`. + pub fn enforce_lt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError> + where + F: PrimeField, + { + self.check_var_bound(a)?; + let b = self.create_constant_variable(val)?; + self.enforce_lt(a, b) + } + + /// Enforce the variable `a` to be less than or + /// equal to a given constant `val`. + pub fn enforce_leq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError> + where + F: PrimeField, + { + self.check_var_bound(a)?; + let b = self.create_constant_variable(val)?; + self.enforce_leq(a, b) + } + + /// Enforce the variable `a` to be greater than + /// a given constant `val`. + pub fn enforce_gt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError> + where + F: PrimeField, + { + self.check_var_bound(a)?; + let b = self.create_constant_variable(val)?; + self.enforce_gt(a, b) + } + + /// Enforce the variable `a` to be greater than + /// or equal a given constant `val`. + pub fn enforce_geq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError> + where + F: PrimeField, + { + self.check_var_bound(a)?; + let b = self.create_constant_variable(val)?; + self.enforce_geq(a, b) + } } /// Private helper functions for comparison gate @@ -171,12 +258,14 @@ impl PlonkCircuit { #[cfg(test)] mod test { - use crate::{errors::CircuitError, Circuit, PlonkCircuit}; + use crate::{errors::CircuitError, BoolVar, Circuit, PlonkCircuit}; use ark_bls12_377::Fq as Fq377; use ark_ed_on_bls12_377::Fq as FqEd377; use ark_ed_on_bls12_381::Fq as FqEd381; use ark_ed_on_bn254::Fq as FqEd254; use ark_ff::PrimeField; + use ark_std::cmp::Ordering; + use itertools::multizip; #[test] fn test_cmp_gates() -> Result<(), CircuitError> { @@ -199,116 +288,135 @@ mod test { F::from(F::modulus_minus_one_div_two()).mul(F::from(2u32)), ), ]; - list.iter() - .try_for_each(|(a, b)| -> Result<(), CircuitError> { - test_is_le(a, b)?; - test_is_leq(a, b)?; - test_is_ge(a, b)?; - test_is_geq(a, b)?; - test_enforce_le(a, b)?; - test_enforce_leq(a, b)?; - test_enforce_ge(a, b)?; - test_enforce_geq(a, b)?; - test_is_le(b, a)?; - test_is_leq(b, a)?; - test_is_ge(b, a)?; - test_is_geq(b, a)?; - test_enforce_le(b, a)?; - test_enforce_leq(b, a)?; - test_enforce_ge(b, a)?; - test_enforce_geq(b, a) - }) + multizip(( + list, + [Ordering::Less, Ordering::Greater], + [false, true], + [false, true], + )).into_iter() + .try_for_each( + |((a, b), ordering, should_also_check_equality, + is_b_constant)| + -> Result<(), CircuitError> { + test_enforce_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?; + test_enforce_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant)?; + test_is_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?; + test_is_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant) + }, + ) } - fn test_is_le(a: &F, b: &F) -> Result<(), CircuitError> { + fn test_is_cmp_helper( + a: &F, + b: &F, + ordering: Ordering, + should_also_check_equality: bool, + is_b_constant: bool, + ) -> Result<(), CircuitError> { let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = if a < b { F::one() } else { F::zero() }; - let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - - let c = circuit.is_lt(a, b)?; - assert!(circuit.witness(c.into())?.eq(&expected_result)); - assert!(circuit.check_circuit_satisfiability(&[]).is_ok()); - Ok(()) - } - fn test_is_leq(a: &F, b: &F) -> Result<(), CircuitError> { - let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = if a <= b { F::one() } else { F::zero() }; - let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - - let c = circuit.is_leq(a, b)?; - assert!(circuit.witness(c.into())?.eq(&expected_result)); - assert!(circuit.check_circuit_satisfiability(&[]).is_ok()); - Ok(()) - } - fn test_is_ge(a: &F, b: &F) -> Result<(), CircuitError> { - let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = if a > b { F::one() } else { F::zero() }; - let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - - let c = circuit.is_gt(a, b)?; - assert!(circuit.witness(c.into())?.eq(&expected_result)); - assert!(circuit.check_circuit_satisfiability(&[]).is_ok()); - Ok(()) - } - fn test_is_geq(a: &F, b: &F) -> Result<(), CircuitError> { - let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = if a >= b { F::one() } else { F::zero() }; - let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - - let c = circuit.is_geq(a, b)?; - assert!(circuit.witness(c.into())?.eq(&expected_result)); - assert!(circuit.check_circuit_satisfiability(&[]).is_ok()); - Ok(()) - } - fn test_enforce_le(a: &F, b: &F) -> Result<(), CircuitError> { - let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = a < b; - let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - circuit.enforce_lt(a, b)?; - if expected_result { - assert!(circuit.check_circuit_satisfiability(&[]).is_ok()) + let expected_result = if a.cmp(b) == ordering + || (a.cmp(b) == Ordering::Equal && should_also_check_equality) + { + F::one() } else { - assert!(circuit.check_circuit_satisfiability(&[]).is_err()); - } - Ok(()) - } - fn test_enforce_leq(a: &F, b: &F) -> Result<(), CircuitError> { - let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = a <= b; + F::zero() + }; let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - circuit.enforce_leq(a, b)?; - if expected_result { - assert!(circuit.check_circuit_satisfiability(&[]).is_ok()) + let c: BoolVar = if is_b_constant { + match ordering { + Ordering::Less => { + if should_also_check_equality { + circuit.is_leq_constant(a, *b)? + } else { + circuit.is_lt_constant(a, *b)? + } + }, + Ordering::Greater => { + if should_also_check_equality { + circuit.is_geq_constant(a, *b)? + } else { + circuit.is_gt_constant(a, *b)? + } + }, + // Equality test will be handled elsewhere, comparison gate test will not enter here + Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?, + } } else { - assert!(circuit.check_circuit_satisfiability(&[]).is_err()); - } + let b = circuit.create_variable(*b)?; + match ordering { + Ordering::Less => { + if should_also_check_equality { + circuit.is_leq(a, b)? + } else { + circuit.is_lt(a, b)? + } + }, + Ordering::Greater => { + if should_also_check_equality { + circuit.is_geq(a, b)? + } else { + circuit.is_gt(a, b)? + } + }, + // Equality test will be handled elsewhere, comparison gate test will not enter here + Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?, + } + }; + assert!(circuit.witness(c.into())?.eq(&expected_result)); + assert!(circuit.check_circuit_satisfiability(&[]).is_ok()); Ok(()) } - fn test_enforce_ge(a: &F, b: &F) -> Result<(), CircuitError> { + fn test_enforce_cmp_helper( + a: &F, + b: &F, + ordering: Ordering, + should_also_check_equality: bool, + is_b_constant: bool, + ) -> Result<(), CircuitError> { let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = a > b; + let expected_result = + a.cmp(b) == ordering || (a.cmp(b) == Ordering::Equal && should_also_check_equality); let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - circuit.enforce_gt(a, b)?; - if expected_result { - assert!(circuit.check_circuit_satisfiability(&[]).is_ok()) + if is_b_constant { + match ordering { + Ordering::Less => { + if should_also_check_equality { + circuit.enforce_leq_constant(a, *b)? + } else { + circuit.enforce_lt_constant(a, *b)? + } + }, + Ordering::Greater => { + if should_also_check_equality { + circuit.enforce_geq_constant(a, *b)? + } else { + circuit.enforce_gt_constant(a, *b)? + } + }, + // Equality test will be handled elsewhere, comparison gate test will not enter here + Ordering::Equal => (), + } } else { - assert!(circuit.check_circuit_satisfiability(&[]).is_err()); - } - Ok(()) - } - fn test_enforce_geq(a: &F, b: &F) -> Result<(), CircuitError> { - let mut circuit = PlonkCircuit::::new_turbo_plonk(); - let expected_result = a >= b; - let a = circuit.create_variable(*a)?; - let b = circuit.create_variable(*b)?; - circuit.enforce_geq(a, b)?; + let b = circuit.create_variable(*b)?; + match ordering { + Ordering::Less => { + if should_also_check_equality { + circuit.enforce_leq(a, b)? + } else { + circuit.enforce_lt(a, b)? + } + }, + Ordering::Greater => { + if should_also_check_equality { + circuit.enforce_geq(a, b)? + } else { + circuit.enforce_gt(a, b)? + } + }, + // Equality test will be handled elsewhere, comparison gate test will not enter here + Ordering::Equal => (), + } + }; if expected_result { assert!(circuit.check_circuit_satisfiability(&[]).is_ok()) } else {