From 3dc860fa7dc5e35da0b902651d77779e27de5b4f Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 19 Jan 2023 15:27:03 -0800 Subject: [PATCH 01/29] WIP --- src/bits/uint.rs | 1055 +++++++++++++++++++++++----------------------- 1 file changed, 519 insertions(+), 536 deletions(-) diff --git a/src/bits/uint.rs b/src/bits/uint.rs index c2fa303c..e937e867 100644 --- a/src/bits/uint.rs +++ b/src/bits/uint.rs @@ -1,567 +1,550 @@ -macro_rules! make_uint { - ($name:ident, $size:expr, $native:ident, $mod_name:ident, $r1cs_doc_name:expr, $native_doc_name:expr, $num_bits_doc:expr) => { - #[doc = "This module contains the "] - #[doc = $r1cs_doc_name] - #[doc = "type, which is the R1CS equivalent of the "] - #[doc = $native_doc_name] - #[doc = " type."] - pub mod $mod_name { - use ark_ff::{Field, One, PrimeField, Zero}; - use core::{borrow::Borrow, convert::TryFrom}; - use num_bigint::BigUint; - use num_traits::cast::ToPrimitive; - - use ark_relations::r1cs::{ - ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, - }; - - use crate::{ - boolean::{AllocatedBool, Boolean}, - prelude::*, - Assignment, Vec, - }; +use ark_ff::{Field, One, PrimeField, Zero}; +use core::{borrow::Borrow, convert::TryFrom}; +use num_bigint::BigUint; +use num_traits::cast::ToPrimitive; + +use ark_relations::r1cs::{ + ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, +}; + +use crate::{ + boolean::{AllocatedBool, Boolean}, + prelude::*, + Assignment, Vec, +}; + +/// This struct represent an unsigned `N` bit integer as a sequence of `N` [`Boolean`]s. +#[derive(Clone, Debug)] +pub struct UInt{ + bits: [Boolean; N], + value: BigUint, +} - #[doc = "This struct represent an unsigned"] - #[doc = $num_bits_doc] - #[doc = " bit integer as a sequence of "] - #[doc = $num_bits_doc] - #[doc = " `Boolean`s. \n"] - #[doc = "This is the R1CS equivalent of the native "] - #[doc = $native_doc_name] - #[doc = " unsigned integer type."] - #[derive(Clone, Debug)] - pub struct $name { - // Least significant bit first - bits: [Boolean; $size], - value: Option<$native>, +impl UInt { + /// Construct a constant [`UInt`] from the native unsigned integer type. + pub fn constant(value: impl Into) -> Self { + let mut bits = [Boolean::FALSE; N]; + + let mut tmp = value.into(); + for i in 0..N { + bits[i] = Boolean::constant(tmp.is_one()); + tmp >>= 1; + } + + Self { + bits, + value, + } + } + + + /// Turns `self` into the underlying little-endian bits. + pub fn to_bits_le(&self) -> Vec> { + self.0.to_vec() + } + + /// Construct `Self` from a slice of [`Boolean`]s. + /// + /// # Panics + /// + /// This method panics if `bits.len() != N`. + pub fn from_bits_le(bits: &[Boolean]) -> Self { + assert_eq!(bits.len(), N); + let bits = <&[Boolean; N]>::try_from(bits).unwrap().clone(); + let mut value = Some(0.into()); + for b in bits.iter().rev() { + value.as_mut().map(|v| *v <<= 1); + + match *b { + Boolean::Constant(b) => value.as_mut().map(|v| *v |= u8::from(b)), + Boolean::Is(ref b) => match b.value() { + Ok(b) => value.as_mut().map(|v| *v |= u8::from(b)), + Err(_) => value = None, + }, + Boolean::Not(ref b) => match b.value() { + Ok(b) => value.as_mut().map(|v| *v |= u8::from(!b)), + Err(_) => value = None, + }, } + } - impl R1CSVar for $name { - type Value = $native; - - fn cs(&self) -> ConstraintSystemRef { - self.bits.as_ref().cs() - } - - fn value(&self) -> Result { - let mut value = None; - for (i, bit) in self.bits.iter().enumerate() { - let b = $native::from(bit.value()?); - value = match value { - Some(value) => Some(value + (b << i)), - None => Some(b << i), - }; - } - debug_assert_eq!(self.value, value); - value.get() - } + Self { + bits, + value, + } + } + + /// Rotates `self` to the right by `by` steps, wrapping around. + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn rotate_right(&self, by: usize) -> Self { + let mut result = self.clone(); + let by = by % N; + + let new_bits = self.0.iter().skip(by).chain(&self.bits).take(N); + + for (res, new) in result.0.iter_mut().zip(new_bits) { + *res = new.clone(); + } + result.value = self + .value + .cloned() + .map(|v| v.rotate_right(u32::try_from(by).unwrap())); + result + } + + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this + /// method *does not* create any constraints or variables. + #[tracing::instrument(target = "r1cs", skip(self, other))] + pub fn xor(&self, other: &Self) -> Result { + let mut result = self.clone(); + result.value.as_mut().and_then(|v| { + *v^= other.value?; + }); + + let new_bits = self.0.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); + + for (res, new) in result.0.iter_mut().zip(new_bits) { + *res = new?; + } + + Ok(result) + } + + /// Perform modular addition of `operands`. + /// + /// The user must ensure that overflow does not occur. + #[tracing::instrument(target = "r1cs", skip(operands))] + pub fn addmany(operands: &[Self]) -> Result + where + F: PrimeField, + { + // Make some arbitrary bounds for ourselves to avoid overflows + // in the scalar field + assert!(F::MODULUS_BIT_SIZE >= 2 * N); + + assert!(operands.len() >= 1); + assert!(N * operands.len() <= F::MODULUS_BIT_SIZE as usize); + + if operands.len() == 1 { + return Ok(operands[0].clone()); + } + + // Compute the maximum value of the sum so we allocate enough bits for + // the result + let mut max_value = + BigUint::from($native::max_value()) * BigUint::from(operands.len()); + + // Keep track of the resulting value + let mut result_value = Some(BigUint::zero()); + + // This is a linear combination that we will enforce to be "zero" + let mut lc = LinearCombination::zero(); + + let mut all_constants = true; + + // Iterate over the operands + for op in operands { + // Accumulate the value + match op.value { + Some(val) => { + result_value.as_mut().map(|v| *v += BigUint::from(val)); + }, + + None => { + // If any of our operands have unknown value, we won't + // know the value of the result + result_value = None; + }, } - - impl $name { - #[doc = "Construct a constant "] - #[doc = $r1cs_doc_name] - #[doc = " from the native "] - #[doc = $native_doc_name] - #[doc = " type."] - pub fn constant(value: $native) -> Self { - let mut bits = [Boolean::FALSE; $size]; - - let mut tmp = value; - for i in 0..$size { - bits[i] = Boolean::constant((tmp & 1) == 1); - tmp >>= 1; - } - - $name { - bits, - value: Some(value), - } - } - - /// Turns `self` into the underlying little-endian bits. - pub fn to_bits_le(&self) -> Vec> { - self.bits.to_vec() - } - - /// Construct `Self` from a slice of `Boolean`s. - /// - /// # Panics - #[doc = "This method panics if `bits.len() != "] - #[doc = $num_bits_doc] - #[doc = "`."] - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), $size); - - let bits = <&[Boolean; $size]>::try_from(bits).unwrap().clone(); - - let mut value = Some(0); - for b in bits.iter().rev() { - value.as_mut().map(|v| *v <<= 1); - - match *b { - Boolean::Constant(b) => { - value.as_mut().map(|v| *v |= $native::from(b)); - }, - Boolean::Is(ref b) => match b.value() { - Ok(b) => { - value.as_mut().map(|v| *v |= $native::from(b)); - }, - Err(_) => value = None, - }, - Boolean::Not(ref b) => match b.value() { - Ok(b) => { - value.as_mut().map(|v| *v |= $native::from(!b)); - }, - Err(_) => value = None, - }, - } - } - - Self { value, bits } - } - - /// Rotates `self` to the right by `by` steps, wrapping around. - #[tracing::instrument(target = "r1cs", skip(self))] - pub fn rotr(&self, by: usize) -> Self { - let mut result = self.clone(); - let by = by % $size; - - let new_bits = self.bits.iter().skip(by).chain(&self.bits).take($size); - - for (res, new) in result.bits.iter_mut().zip(new_bits) { - *res = new.clone(); - } - - result.value = self - .value - .map(|v| v.rotate_right(u32::try_from(by).unwrap())); - result - } - - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this - /// method *does not* create any constraints or variables. - #[tracing::instrument(target = "r1cs", skip(self, other))] - pub fn xor(&self, other: &Self) -> Result { - let mut result = self.clone(); - result.value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a ^ b), - _ => None, - }; - - let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); - - for (res, new) in result.bits.iter_mut().zip(new_bits) { - *res = new?; - } - - Ok(result) - } - - /// Perform modular addition of `operands`. - /// - /// The user must ensure that overflow does not occur. - #[tracing::instrument(target = "r1cs", skip(operands))] - pub fn addmany(operands: &[Self]) -> Result - where - F: PrimeField, - { - // Make some arbitrary bounds for ourselves to avoid overflows - // in the scalar field - assert!(F::MODULUS_BIT_SIZE >= 2 * $size); - - // Support up to 128 - assert!($size <= 128); - - assert!(operands.len() >= 1); - assert!($size * operands.len() <= F::MODULUS_BIT_SIZE as usize); - - if operands.len() == 1 { - return Ok(operands[0].clone()); - } - - // Compute the maximum value of the sum so we allocate enough bits for - // the result - let mut max_value = - BigUint::from($native::max_value()) * BigUint::from(operands.len()); - - // Keep track of the resulting value - let mut result_value = Some(BigUint::zero()); - - // This is a linear combination that we will enforce to be "zero" - let mut lc = LinearCombination::zero(); - - let mut all_constants = true; - - // Iterate over the operands - for op in operands { - // Accumulate the value - match op.value { - Some(val) => { - result_value.as_mut().map(|v| *v += BigUint::from(val)); - }, - - None => { - // If any of our operands have unknown value, we won't - // know the value of the result - result_value = None; - }, + + // Iterate over each bit_gadget of the operand and add the operand to + // the linear combination + let mut coeff = F::one(); + for bit in &op.bits { + match *bit { + Boolean::Is(ref bit) => { + all_constants = false; + + // Add coeff * bit_gadget + lc += (coeff, bit.variable()); + }, + Boolean::Not(ref bit) => { + all_constants = false; + + // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * + // bit_gadget + lc = lc + (coeff, Variable::One) - (coeff, bit.variable()); + }, + Boolean::Constant(bit) => { + if bit { + lc += (coeff, Variable::One); } - - // Iterate over each bit_gadget of the operand and add the operand to - // the linear combination - let mut coeff = F::one(); - for bit in &op.bits { - match *bit { - Boolean::Is(ref bit) => { - all_constants = false; - - // Add coeff * bit_gadget - lc += (coeff, bit.variable()); - }, - Boolean::Not(ref bit) => { - all_constants = false; - - // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * - // bit_gadget - lc = lc + (coeff, Variable::One) - (coeff, bit.variable()); - }, - Boolean::Constant(bit) => { - if bit { - lc += (coeff, Variable::One); - } - }, - } - - coeff.double_in_place(); - } - } - - // The value of the actual result is modulo 2^$size - let modular_value = result_value.clone().map(|v| { - let modulus = BigUint::from(1u64) << ($size as u32); - (v % modulus).to_u128().unwrap() as $native - }); - - if all_constants && modular_value.is_some() { - // We can just return a constant, rather than - // unpacking the result into allocated bits. - - return Ok($name::constant(modular_value.unwrap())); - } - let cs = operands.cs(); - - // Storage area for the resulting bits - let mut result_bits = vec![]; - - // Allocate each bit_gadget of the result - let mut coeff = F::one(); - let mut i = 0; - while max_value != BigUint::zero() { - // Allocate the bit_gadget - let b = AllocatedBool::new_witness(cs.clone(), || { - result_value - .clone() - .map(|v| (v >> i) & BigUint::one() == BigUint::one()) - .get() - })?; - - // Subtract this bit_gadget from the linear combination to ensure the sums - // balance out - lc = lc - (coeff, b.variable()); - - result_bits.push(b.into()); - - max_value >>= 1; - i += 1; - coeff.double_in_place(); - } - - // Enforce that the linear combination equals zero - cs.enforce_constraint(lc!(), lc!(), lc)?; - - // Discard carry bits that we don't care about - result_bits.truncate($size); - let bits = TryFrom::try_from(result_bits).unwrap(); - - Ok($name { - bits, - value: modular_value, - }) + }, } + + coeff.double_in_place(); } + } + + // The value of the actual result is modulo 2^$size + let modular_value = result_value.clone().map(|v| { + let modulus = BigUint::from(1u64) << ($size as u32); + (v % modulus).to_u128().unwrap() as $native + }); + + if all_constants && modular_value.is_some() { + // We can just return a constant, rather than + // unpacking the result into allocated bits. + + return Ok($name::constant(modular_value.unwrap())); + } + let cs = operands.cs(); + + // Storage area for the resulting bits + let mut result_bits = vec![]; + + // Allocate each bit_gadget of the result + let mut coeff = F::one(); + let mut i = 0; + while max_value != BigUint::zero() { + // Allocate the bit_gadget + let b = AllocatedBool::new_witness(cs.clone(), || { + result_value + .clone() + .map(|v| (v >> i) & BigUint::one() == BigUint::one()) + .get() + })?; + + // Subtract this bit_gadget from the linear combination to ensure the sums + // balance out + lc = lc - (coeff, b.variable()); + + result_bits.push(b.into()); + + max_value >>= 1; + i += 1; + coeff.double_in_place(); + } + + // Enforce that the linear combination equals zero + cs.enforce_constraint(lc!(), lc!(), lc)?; + + // Discard carry bits that we don't care about + result_bits.truncate($size); + let bits = TryFrom::try_from(result_bits).unwrap(); + + Ok($name { + bits, + value: modular_value, + }) + } +} - impl ToBytesGadget for $name { - #[tracing::instrument(target = "r1cs", skip(self))] - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self - .to_bits_le() - .chunks(8) - .map(UInt8::from_bits_le) - .collect()) - } +impl ToBytesGadget for $name { + #[tracing::instrument(target = "r1cs", skip(self))] + fn to_bytes(&self) -> Result>, SynthesisError> { + Ok(self + .to_bits_le() + .chunks(8) + .map(UInt8::from_bits_le) + .collect()) + } + } + + impl EqGadget for $name { + #[tracing::instrument(target = "r1cs", skip(self))] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + self.bits.as_ref().is_eq(&other.bits) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.bits.conditional_enforce_equal(&other.bits, condition) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.bits + .conditional_enforce_not_equal(&other.bits, condition) + } + } + + impl CondSelectGadget for $name { + #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let selected_bits = true_value + .bits + .iter() + .zip(&false_value.bits) + .map(|(t, f)| cond.select(t, f)); + let mut bits = [Boolean::FALSE; $size]; + for (result, new) in bits.iter_mut().zip(selected_bits) { + *result = new?; } - - impl EqGadget for $name { - #[tracing::instrument(target = "r1cs", skip(self))] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_ref().is_eq(&other.bits) + + let value = cond.value().ok().and_then(|cond| { + if cond { + true_value.value().ok() + } else { + false_value.value().ok() } - - #[tracing::instrument(target = "r1cs", skip(self))] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits.conditional_enforce_equal(&other.bits, condition) + }); + Ok(Self { bits, value }) + } + } + + impl AllocVar<$native, ConstraintF> for $name { + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let value = f().map(|f| *f.borrow()).ok(); + + let mut values = [None; $size]; + if let Some(val) = value { + values + .iter_mut() + .enumerate() + .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); + } + + let mut bits = [Boolean::FALSE; $size]; + for (b, v) in bits.iter_mut().zip(&values) { + *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; + } + Ok(Self { bits, value }) + } + } + + #[cfg(test)] + mod test { + use super::$name; + use crate::{bits::boolean::Boolean, prelude::*, Vec}; + use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; + use ark_std::rand::Rng; + use ark_test_curves::mnt4_753::Fr; + + #[test] + fn test_from_bits() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let v = (0..$size) + .map(|_| Boolean::constant(rng.gen())) + .collect::>>(); + + let b = $name::from_bits_le(&v); + + for (i, bit) in b.bits.iter().enumerate() { + match bit { + &Boolean::Constant(bit) => { + assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); + }, + _ => unreachable!(), + } } - - #[tracing::instrument(target = "r1cs", skip(self))] - fn conditional_enforce_not_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(&other.bits, condition) + + let expected_to_be_same = b.to_bits_le(); + + for x in v.iter().zip(expected_to_be_same.iter()) { + match x { + (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, + (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, + _ => unreachable!(), + } } } - - impl CondSelectGadget for $name { - #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] - fn conditionally_select( - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value - .bits - .iter() - .zip(&false_value.bits) - .map(|(t, f)| cond.select(t, f)); - let mut bits = [Boolean::FALSE; $size]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; + Ok(()) + } + + #[test] + fn test_xor() -> Result<(), SynthesisError> { + use Boolean::*; + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let cs = ConstraintSystem::::new_ref(); + + let a: $native = rng.gen(); + let b: $native = rng.gen(); + let c: $native = rng.gen(); + + let mut expected = a ^ b ^ c; + + let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; + let b_bit = $name::constant(b); + let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; + + let r = a_bit.xor(&b_bit).unwrap(); + let r = r.xor(&c_bit).unwrap(); + + assert!(cs.is_satisfied().unwrap()); + + assert!(r.value == Some(expected)); + + for b in r.bits.iter() { + match b { + Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), + Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), + Constant(b) => assert_eq!(*b, (expected & 1 == 1)), } - - let value = cond.value().ok().and_then(|cond| { - if cond { - true_value.value().ok() - } else { - false_value.value().ok() - } - }); - Ok(Self { bits, value }) + + expected >>= 1; } } - - impl AllocVar<$native, ConstraintF> for $name { - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - let value = f().map(|f| *f.borrow()).ok(); - - let mut values = [None; $size]; - if let Some(val) = value { - values - .iter_mut() - .enumerate() - .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); - } - - let mut bits = [Boolean::FALSE; $size]; - for (b, v) in bits.iter_mut().zip(&values) { - *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; + Ok(()) + } + + #[test] + fn test_addmany_constants() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let cs = ConstraintSystem::::new_ref(); + + let a: $native = rng.gen(); + let b: $native = rng.gen(); + let c: $native = rng.gen(); + + let a_bit = $name::new_constant(cs.clone(), a)?; + let b_bit = $name::new_constant(cs.clone(), b)?; + let c_bit = $name::new_constant(cs.clone(), c)?; + + let mut expected = a.wrapping_add(b).wrapping_add(c); + + let r = $name::addmany(&[a_bit, b_bit, c_bit]).unwrap(); + + assert!(r.value == Some(expected)); + + for b in r.bits.iter() { + match b { + Boolean::Is(_) => unreachable!(), + Boolean::Not(_) => unreachable!(), + Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), } - Ok(Self { bits, value }) + + expected >>= 1; } } - - #[cfg(test)] - mod test { - use super::$name; - use crate::{bits::boolean::Boolean, prelude::*, Vec}; - use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; - use ark_std::rand::Rng; - use ark_test_curves::mnt4_753::Fr; - - #[test] - fn test_from_bits() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let v = (0..$size) - .map(|_| Boolean::constant(rng.gen())) - .collect::>>(); - - let b = $name::from_bits_le(&v); - - for (i, bit) in b.bits.iter().enumerate() { - match bit { - &Boolean::Constant(bit) => { - assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); - }, - _ => unreachable!(), - } - } - - let expected_to_be_same = b.to_bits_le(); - - for x in v.iter().zip(expected_to_be_same.iter()) { - match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, - (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, - _ => unreachable!(), - } - } + Ok(()) + } + + #[test] + fn test_addmany() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let cs = ConstraintSystem::::new_ref(); + + let a: $native = rng.gen(); + let b: $native = rng.gen(); + let c: $native = rng.gen(); + let d: $native = rng.gen(); + + let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); + + let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; + let b_bit = $name::constant(b); + let c_bit = $name::constant(c); + let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; + + let r = a_bit.xor(&b_bit).unwrap(); + let r = $name::addmany(&[r, c_bit, d_bit]).unwrap(); + + assert!(cs.is_satisfied().unwrap()); + assert!(r.value == Some(expected)); + + for b in r.bits.iter() { + match b { + Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), + Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), + Boolean::Constant(_) => unreachable!(), } - Ok(()) + + expected >>= 1; } - - #[test] - fn test_xor() -> Result<(), SynthesisError> { - use Boolean::*; - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); - - let mut expected = a ^ b ^ c; - - let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; - let b_bit = $name::constant(b); - let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; - - let r = a_bit.xor(&b_bit).unwrap(); - let r = r.xor(&c_bit).unwrap(); - - assert!(cs.is_satisfied().unwrap()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), - Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), - Constant(b) => assert_eq!(*b, (expected & 1 == 1)), - } - - expected >>= 1; - } + } + Ok(()) + } + + #[test] + fn test_rotr() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + + let mut num = rng.gen(); + + let a: $name = $name::constant(num); + + for i in 0..$size { + let b = a.rotr(i); + + assert!(b.value.unwrap() == num); + + let mut tmp = num; + for b in &b.bits { + match b { + Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), + _ => unreachable!(), } - Ok(()) + + tmp >>= 1; } - - #[test] - fn test_addmany_constants() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); - - let a_bit = $name::new_constant(cs.clone(), a)?; - let b_bit = $name::new_constant(cs.clone(), b)?; - let c_bit = $name::new_constant(cs.clone(), c)?; - - let mut expected = a.wrapping_add(b).wrapping_add(c); - - let r = $name::addmany(&[a_bit, b_bit, c_bit]).unwrap(); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - Boolean::Is(_) => unreachable!(), - Boolean::Not(_) => unreachable!(), - Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), - } - - expected >>= 1; - } - } - Ok(()) + + num = num.rotate_right(1); + } + Ok(()) + } + } + +macro_rules! make_uint { + ($name:ident, $native:ident, $mod_name:ident, $r1cs_doc_name:expr, $native_doc_name:expr, $num_bits_doc:expr) => { + #[doc = "This module contains the "] + #[doc = $r1cs_doc_name] + #[doc = "type, which is the R1CS equivalent of the "] + #[doc = $native_doc_name] + #[doc = " type."] + pub mod $mod_name { + impl R1CSVar for UInt { + type Value = $native; + + fn cs(&self) -> ConstraintSystemRef { + self.bits.as_ref().cs() } - - #[test] - fn test_addmany() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); - let d: $native = rng.gen(); - - let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - - let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; - let b_bit = $name::constant(b); - let c_bit = $name::constant(c); - let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; - - let r = a_bit.xor(&b_bit).unwrap(); - let r = $name::addmany(&[r, c_bit, d_bit]).unwrap(); - - assert!(cs.is_satisfied().unwrap()); - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), - Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), - Boolean::Constant(_) => unreachable!(), - } - - expected >>= 1; - } + + fn value(&self) -> Result { + let mut value = None; + for (i, bit) in self.0.iter().enumerate() { + let b = $native::from(bit.value()?); + value = match value { + Some(value) => Some(value + (b << i)), + None => Some(b << i), + }; } - Ok(()) + debug_assert_eq!(self.value, value); + value.get() } + } - #[test] - fn test_rotr() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - let mut num = rng.gen(); - - let a: $name = $name::constant(num); - - for i in 0..$size { - let b = a.rotr(i); - - assert!(b.value.unwrap() == num); - - let mut tmp = num; - for b in &b.bits { - match b { - Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), - _ => unreachable!(), - } - - tmp >>= 1; - } + } - num = num.rotate_right(1); - } - Ok(()) - } - } + impl UInt { } }; } From e11a061d18a8a8238e372c09f3bfee1f8f497482 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 19 Jan 2023 23:44:57 -0800 Subject: [PATCH 02/29] WIP --- src/bits/mod.rs | 16 +- src/bits/uint.rs | 474 ++++++++++++++++++++++------------------------- 2 files changed, 235 insertions(+), 255 deletions(-) diff --git a/src/bits/mod.rs b/src/bits/mod.rs index 94ad5fab..9dcabec5 100644 --- a/src/bits/mod.rs +++ b/src/bits/mod.rs @@ -14,10 +14,18 @@ pub mod uint8; #[macro_use] pub mod uint; -make_uint!(UInt16, 16, u16, uint16, "`U16`", "`u16`", "16"); -make_uint!(UInt32, 32, u32, uint32, "`U32`", "`u32`", "32"); -make_uint!(UInt64, 64, u64, uint64, "`U64`", "`u64`", "64"); -make_uint!(UInt128, 128, u128, uint128, "`U128`", "`u128`", "128"); +pub mod uint16 { + pub type UInt16 = super::uint::UInt<16, u16, F>; +} +pub mod uint32 { + pub type UInt32 = super::uint::UInt<32, u32, F>; +} +pub mod uint64 { + pub type UInt64 = super::uint::UInt<64, u64, F>; +} +pub mod uint128 { + pub type UInt128 = super::uint::UInt<128, u128, F>; +} /// Specifies constraints for conversion to a little-endian bit representation /// of `self`. diff --git a/src/bits/uint.rs b/src/bits/uint.rs index e937e867..bc22397d 100644 --- a/src/bits/uint.rs +++ b/src/bits/uint.rs @@ -1,7 +1,7 @@ use ark_ff::{Field, One, PrimeField, Zero}; -use core::{borrow::Borrow, convert::TryFrom}; +use core::{borrow::Borrow, convert::TryFrom, fmt::Debug}; use num_bigint::BigUint; -use num_traits::cast::ToPrimitive; +use num_traits::PrimInt; use ark_relations::r1cs::{ ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, @@ -15,32 +15,48 @@ use crate::{ /// This struct represent an unsigned `N` bit integer as a sequence of `N` [`Boolean`]s. #[derive(Clone, Debug)] -pub struct UInt{ +pub struct UInt{ bits: [Boolean; N], - value: BigUint, + value: Option, } -impl UInt { +impl R1CSVar for UInt { + type Value = T; + + fn cs(&self) -> ConstraintSystemRef { + self.bits.as_ref().cs() + } + + fn value(&self) -> Result { + let mut value = BigUint::zero(); + for (i, bit) in self.bits.iter().enumerate() { + value.set_bit(i as u64, bit.value()?); + } + debug_assert_eq!(self.value, Some(value)); + Ok(value) + } +} + +impl UInt { /// Construct a constant [`UInt`] from the native unsigned integer type. - pub fn constant(value: impl Into) -> Self { + pub fn constant(mut value: T) -> Self { let mut bits = [Boolean::FALSE; N]; - let mut tmp = value.into(); for i in 0..N { - bits[i] = Boolean::constant(tmp.is_one()); - tmp >>= 1; + bits[i] = Boolean::constant(value.is_one()); + value = value >> 1; } Self { bits, - value, + value: Some(value), } } /// Turns `self` into the underlying little-endian bits. pub fn to_bits_le(&self) -> Vec> { - self.0.to_vec() + self.bits.to_vec() } /// Construct `Self` from a slice of [`Boolean`]s. @@ -51,23 +67,15 @@ impl UInt { pub fn from_bits_le(bits: &[Boolean]) -> Self { assert_eq!(bits.len(), N); let bits = <&[Boolean; N]>::try_from(bits).unwrap().clone(); - let mut value = Some(0.into()); - for b in bits.iter().rev() { - value.as_mut().map(|v| *v <<= 1); - - match *b { - Boolean::Constant(b) => value.as_mut().map(|v| *v |= u8::from(b)), - Boolean::Is(ref b) => match b.value() { - Ok(b) => value.as_mut().map(|v| *v |= u8::from(b)), - Err(_) => value = None, - }, - Boolean::Not(ref b) => match b.value() { - Ok(b) => value.as_mut().map(|v| *v |= u8::from(!b)), - Err(_) => value = None, - }, - } - } - + let value_exists = bits.iter().all(|b| b.value().is_ok()); + let mut value = T::zero(); + bits.iter().enumerate().filter_map(|(i, b)| { + b.value().ok().map(|b| { + value = value << 1; + value = value | T::from(b as u8).unwrap(); + }) + }); + let value = value_exists.then_some(value); Self { bits, value, @@ -80,15 +88,14 @@ impl UInt { let mut result = self.clone(); let by = by % N; - let new_bits = self.0.iter().skip(by).chain(&self.bits).take(N); + let new_bits = self.bits.iter().skip(by).chain(&self.bits).take(N); - for (res, new) in result.0.iter_mut().zip(new_bits) { + for (res, new) in result.bits.iter_mut().zip(new_bits) { *res = new.clone(); } result.value = self .value - .cloned() - .map(|v| v.rotate_right(u32::try_from(by).unwrap())); + .map(|v| v.rotate_right(by as u32)); result } @@ -100,13 +107,12 @@ impl UInt { pub fn xor(&self, other: &Self) -> Result { let mut result = self.clone(); result.value.as_mut().and_then(|v| { - *v^= other.value?; + *v = *v ^ other.value?; + Some(()) }); - let new_bits = self.0.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); - - for (res, new) in result.0.iter_mut().zip(new_bits) { - *res = new?; + for (self_bit, other_bit) in result.bits.iter_mut().zip(&other.bits) { + *self_bit = self_bit.xor(other_bit)? } Ok(result) @@ -122,7 +128,7 @@ impl UInt { { // Make some arbitrary bounds for ourselves to avoid overflows // in the scalar field - assert!(F::MODULUS_BIT_SIZE >= 2 * N); + assert!(F::MODULUS_BIT_SIZE as usize >= 2 * N); assert!(operands.len() >= 1); assert!(N * operands.len() <= F::MODULUS_BIT_SIZE as usize); @@ -149,7 +155,7 @@ impl UInt { // Accumulate the value match op.value { Some(val) => { - result_value.as_mut().map(|v| *v += BigUint::from(val)); + result_value.as_mut().map(|v| *v += BigUint::from(val.to_u128().unwrap())); }, None => { @@ -190,15 +196,15 @@ impl UInt { // The value of the actual result is modulo 2^$size let modular_value = result_value.clone().map(|v| { - let modulus = BigUint::from(1u64) << ($size as u32); - (v % modulus).to_u128().unwrap() as $native + let modulus = BigUint::from(1u64) << (N as u32); + v % modulus }); if all_constants && modular_value.is_some() { // We can just return a constant, rather than // unpacking the result into allocated bits. - return Ok($name::constant(modular_value.unwrap())); + return Ok(UInt::constant(modular_value.unwrap())); } let cs = operands.cs(); @@ -232,54 +238,55 @@ impl UInt { cs.enforce_constraint(lc!(), lc!(), lc)?; // Discard carry bits that we don't care about - result_bits.truncate($size); + result_bits.truncate(N); let bits = TryFrom::try_from(result_bits).unwrap(); - Ok($name { + Ok(UInt { bits, value: modular_value, }) } } -impl ToBytesGadget for $name { +impl ToBytesGadget for UInt { #[tracing::instrument(target = "r1cs", skip(self))] fn to_bytes(&self) -> Result>, SynthesisError> { Ok(self .to_bits_le() .chunks(8) .map(UInt8::from_bits_le) - .collect()) - } + .collect() + ) } +} - impl EqGadget for $name { - #[tracing::instrument(target = "r1cs", skip(self))] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_ref().is_eq(&other.bits) - } - - #[tracing::instrument(target = "r1cs", skip(self))] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits.conditional_enforce_equal(&other.bits, condition) - } - - #[tracing::instrument(target = "r1cs", skip(self))] - fn conditional_enforce_not_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(&other.bits, condition) - } +impl EqGadget for UInt { + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + self.bits.as_ref().is_eq(&other.bits) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.bits.conditional_enforce_equal(&other.bits, condition) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.bits + .conditional_enforce_not_equal(&other.bits, condition) } +} - impl CondSelectGadget for $name { + impl CondSelectGadget for UInt { #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] fn conditionally_select( cond: &Boolean, @@ -291,7 +298,7 @@ impl ToBytesGadget for $name { .iter() .zip(&false_value.bits) .map(|(t, f)| cond.select(t, f)); - let mut bits = [Boolean::FALSE; $size]; + let mut bits = [Boolean::FALSE; N]; for (result, new) in bits.iter_mut().zip(selected_bits) { *result = new?; } @@ -307,25 +314,25 @@ impl ToBytesGadget for $name { } } - impl AllocVar<$native, ConstraintF> for $name { - fn new_variable>( + impl AllocVar for UInt { + fn new_variable>( cs: impl Into>, - f: impl FnOnce() -> Result, + f: impl FnOnce() -> Result, mode: AllocationMode, ) -> Result { let ns = cs.into(); let cs = ns.cs(); let value = f().map(|f| *f.borrow()).ok(); - let mut values = [None; $size]; + let mut values = [None; N]; if let Some(val) = value { values .iter_mut() .enumerate() - .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); + .for_each(|(i, v)| *v = Some(val.bit(i as u64))); } - let mut bits = [Boolean::FALSE; $size]; + let mut bits = [Boolean::FALSE; N]; for (b, v) in bits.iter_mut().zip(&values) { *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; } @@ -333,218 +340,183 @@ impl ToBytesGadget for $name { } } - #[cfg(test)] - mod test { - use super::$name; - use crate::{bits::boolean::Boolean, prelude::*, Vec}; - use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; - use ark_std::rand::Rng; - use ark_test_curves::mnt4_753::Fr; +// #[cfg(test)] +// mod test { +// use super::UInt; +// use crate::{bits::boolean::Boolean, prelude::*, Vec}; +// use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; +// use ark_std::rand::Rng; +// use ark_test_curves::mnt4_753::Fr; - #[test] - fn test_from_bits() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); +// #[test] +// fn test_from_bits() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); - for _ in 0..1000 { - let v = (0..$size) - .map(|_| Boolean::constant(rng.gen())) - .collect::>>(); +// for _ in 0..1000 { +// let v = (0..$size) +// .map(|_| Boolean::constant(rng.gen())) +// .collect::>>(); - let b = $name::from_bits_le(&v); +// let b = UInt::from_bits_le(&v); - for (i, bit) in b.bits.iter().enumerate() { - match bit { - &Boolean::Constant(bit) => { - assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); - }, - _ => unreachable!(), - } - } +// for (i, bit) in b.bits.iter().enumerate() { +// match bit { +// &Boolean::Constant(bit) => { +// assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); +// }, +// _ => unreachable!(), +// } +// } - let expected_to_be_same = b.to_bits_le(); +// let expected_to_be_same = b.to_bits_le(); - for x in v.iter().zip(expected_to_be_same.iter()) { - match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, - (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, - _ => unreachable!(), - } - } - } - Ok(()) - } +// for x in v.iter().zip(expected_to_be_same.iter()) { +// match x { +// (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, +// (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, +// _ => unreachable!(), +// } +// } +// } +// Ok(()) +// } - #[test] - fn test_xor() -> Result<(), SynthesisError> { - use Boolean::*; - let mut rng = ark_std::test_rng(); +// #[test] +// fn test_xor() -> Result<(), SynthesisError> { +// use Boolean::*; +// let mut rng = ark_std::test_rng(); - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); +// for _ in 0..1000 { +// let cs = ConstraintSystem::::new_ref(); - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); +// let a: $native = rng.gen(); +// let b: $native = rng.gen(); +// let c: $native = rng.gen(); - let mut expected = a ^ b ^ c; +// let mut expected = a ^ b ^ c; - let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; - let b_bit = $name::constant(b); - let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; +// let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; +// let b_bit = $name::constant(b); +// let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; - let r = a_bit.xor(&b_bit).unwrap(); - let r = r.xor(&c_bit).unwrap(); +// let r = a_bit.xor(&b_bit).unwrap(); +// let r = r.xor(&c_bit).unwrap(); - assert!(cs.is_satisfied().unwrap()); +// assert!(cs.is_satisfied().unwrap()); - assert!(r.value == Some(expected)); +// assert!(r.value == Some(expected)); - for b in r.bits.iter() { - match b { - Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), - Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), - Constant(b) => assert_eq!(*b, (expected & 1 == 1)), - } +// for b in r.bits.iter() { +// match b { +// Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), +// Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), +// Constant(b) => assert_eq!(*b, (expected & 1 == 1)), +// } - expected >>= 1; - } - } - Ok(()) - } +// expected >>= 1; +// } +// } +// Ok(()) +// } - #[test] - fn test_addmany_constants() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); +// #[test] +// fn test_addmany_constants() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); +// for _ in 0..1000 { +// let cs = ConstraintSystem::::new_ref(); - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); +// let a: $native = rng.gen(); +// let b: $native = rng.gen(); +// let c: $native = rng.gen(); - let a_bit = $name::new_constant(cs.clone(), a)?; - let b_bit = $name::new_constant(cs.clone(), b)?; - let c_bit = $name::new_constant(cs.clone(), c)?; +// let a_bit = $name::new_constant(cs.clone(), a)?; +// let b_bit = $name::new_constant(cs.clone(), b)?; +// let c_bit = $name::new_constant(cs.clone(), c)?; - let mut expected = a.wrapping_add(b).wrapping_add(c); +// let mut expected = a.wrapping_add(b).wrapping_add(c); - let r = $name::addmany(&[a_bit, b_bit, c_bit]).unwrap(); +// let r = $name::addmany(&[a_bit, b_bit, c_bit]).unwrap(); - assert!(r.value == Some(expected)); +// assert!(r.value == Some(expected)); - for b in r.bits.iter() { - match b { - Boolean::Is(_) => unreachable!(), - Boolean::Not(_) => unreachable!(), - Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), - } +// for b in r.bits.iter() { +// match b { +// Boolean::Is(_) => unreachable!(), +// Boolean::Not(_) => unreachable!(), +// Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), +// } - expected >>= 1; - } - } - Ok(()) - } +// expected >>= 1; +// } +// } +// Ok(()) +// } - #[test] - fn test_addmany() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); +// #[test] +// fn test_addmany() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); +// for _ in 0..1000 { +// let cs = ConstraintSystem::::new_ref(); - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); - let d: $native = rng.gen(); +// let a: $native = rng.gen(); +// let b: $native = rng.gen(); +// let c: $native = rng.gen(); +// let d: $native = rng.gen(); - let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); +// let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; - let b_bit = $name::constant(b); - let c_bit = $name::constant(c); - let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; +// let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; +// let b_bit = $name::constant(b); +// let c_bit = $name::constant(c); +// let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; - let r = a_bit.xor(&b_bit).unwrap(); - let r = $name::addmany(&[r, c_bit, d_bit]).unwrap(); +// let r = a_bit.xor(&b_bit).unwrap(); +// let r = $name::addmany(&[r, c_bit, d_bit]).unwrap(); - assert!(cs.is_satisfied().unwrap()); - assert!(r.value == Some(expected)); +// assert!(cs.is_satisfied().unwrap()); +// assert!(r.value == Some(expected)); - for b in r.bits.iter() { - match b { - Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), - Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), - Boolean::Constant(_) => unreachable!(), - } +// for b in r.bits.iter() { +// match b { +// Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), +// Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), +// Boolean::Constant(_) => unreachable!(), +// } - expected >>= 1; - } - } - Ok(()) - } +// expected >>= 1; +// } +// } +// Ok(()) +// } - #[test] - fn test_rotr() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); +// #[test] +// fn test_rotr() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); - let mut num = rng.gen(); +// let mut num = rng.gen(); - let a: $name = $name::constant(num); +// let a: $name = $name::constant(num); - for i in 0..$size { - let b = a.rotr(i); +// for i in 0..$size { +// let b = a.rotr(i); - assert!(b.value.unwrap() == num); +// assert!(b.value.unwrap() == num); - let mut tmp = num; - for b in &b.bits { - match b { - Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), - _ => unreachable!(), - } +// let mut tmp = num; +// for b in &b.bits { +// match b { +// Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), +// _ => unreachable!(), +// } - tmp >>= 1; - } +// tmp >>= 1; +// } - num = num.rotate_right(1); - } - Ok(()) - } - } +// num = num.rotate_right(1); +// } +// Ok(()) +// } +// } -macro_rules! make_uint { - ($name:ident, $native:ident, $mod_name:ident, $r1cs_doc_name:expr, $native_doc_name:expr, $num_bits_doc:expr) => { - #[doc = "This module contains the "] - #[doc = $r1cs_doc_name] - #[doc = "type, which is the R1CS equivalent of the "] - #[doc = $native_doc_name] - #[doc = " type."] - pub mod $mod_name { - impl R1CSVar for UInt { - type Value = $native; - - fn cs(&self) -> ConstraintSystemRef { - self.bits.as_ref().cs() - } - - fn value(&self) -> Result { - let mut value = None; - for (i, bit) in self.0.iter().enumerate() { - let b = $native::from(bit.value()?); - value = match value { - Some(value) => Some(value + (b << i)), - None => Some(b << i), - }; - } - debug_assert_eq!(self.value, value); - value.get() - } - } - - } - - impl UInt { - } - }; -} From 765cff952c9e3cf847ab91784d4b2260ccef29cb Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Tue, 9 May 2023 17:05:17 -0700 Subject: [PATCH 03/29] Work --- src/bits/uint.rs | 498 ++++++++++++++++++++++++++--------------- src/bits/uint8.rs | 290 ++---------------------- src/fields/fp/mod.rs | 4 +- src/poly/domain/mod.rs | 7 + 4 files changed, 338 insertions(+), 461 deletions(-) diff --git a/src/bits/uint.rs b/src/bits/uint.rs index bc22397d..926fe367 100644 --- a/src/bits/uint.rs +++ b/src/bits/uint.rs @@ -1,7 +1,7 @@ use ark_ff::{Field, One, PrimeField, Zero}; use core::{borrow::Borrow, convert::TryFrom, fmt::Debug}; use num_bigint::BigUint; -use num_traits::PrimInt; +use num_traits::{NumCast, PrimInt}; use ark_relations::r1cs::{ ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, @@ -15,22 +15,24 @@ use crate::{ /// This struct represent an unsigned `N` bit integer as a sequence of `N` [`Boolean`]s. #[derive(Clone, Debug)] -pub struct UInt{ - bits: [Boolean; N], - value: Option, +pub struct UInt { + #[doc(hidden)] + pub bits: [Boolean; N], + #[doc(hidden)] + pub value: Option, } impl R1CSVar for UInt { type Value = T; - + fn cs(&self) -> ConstraintSystemRef { self.bits.as_ref().cs() } - + fn value(&self) -> Result { - let mut value = BigUint::zero(); + let mut value = T::zero(); for (i, bit) in self.bits.iter().enumerate() { - value.set_bit(i as u64, bit.value()?); + value = value + (T::from(bit.value()? as u8).unwrap() << i); } debug_assert_eq!(self.value, Some(value)); Ok(value) @@ -39,132 +41,224 @@ impl R1CSVar for UInt impl UInt { /// Construct a constant [`UInt`] from the native unsigned integer type. - pub fn constant(mut value: T) -> Self { + /// + /// This *does not* create new variables or constraints. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let var = UInt8::new_witness(cs.clone(), || Ok(2))?; + /// + /// let constant = UInt8::constant(2); + /// var.enforce_equal(&constant)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + pub fn constant(value: T) -> Self { let mut bits = [Boolean::FALSE; N]; - + let mut bit_values = value; + for i in 0..N { - bits[i] = Boolean::constant(value.is_one()); - value = value >> 1; + bits[i] = Boolean::constant((bit_values & T::one()) == T::one()); + bit_values = bit_values >> 1; } - + Self { bits, value: Some(value), } } + /// Construct a constant vector of [`UInt`] from a vector of the native type + /// + /// This *does not* create any new variables or constraints. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let var = vec![UInt8::new_witness(cs.clone(), || Ok(2))?]; + /// + /// let constant = UInt8::constant_vec(&[2]); + /// var.enforce_equal(&constant)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + pub fn constant_vec(values: &[T]) -> Vec { + values.iter().map(|v| Self::constant(*v)).collect() + } + + /// Allocates a slice of `uN`'s as private witnesses. + pub fn new_witness_vec( + cs: impl Into>, + values: &[impl Into> + Copy], + ) -> Result, SynthesisError> { + let ns = cs.into(); + let cs = ns.cs(); + let mut output_vec = Vec::with_capacity(values.len()); + for value in values { + let byte: Option = Into::into(*value); + output_vec.push(Self::new_witness(cs.clone(), || byte.get())?); + } + Ok(output_vec) + } /// Turns `self` into the underlying little-endian bits. pub fn to_bits_le(&self) -> Vec> { self.bits.to_vec() } - - /// Construct `Self` from a slice of [`Boolean`]s. + + /// Converts a little-endian byte order representation of bits into a + /// `UInt`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let var = UInt8::new_witness(cs.clone(), || Ok(128))?; + /// + /// let f = Boolean::FALSE; + /// let t = Boolean::TRUE; /// - /// # Panics + /// // Construct [0, 0, 0, 0, 0, 0, 0, 1] + /// let mut bits = vec![f.clone(); 7]; + /// bits.push(t); /// - /// This method panics if `bits.len() != N`. + /// let mut c = UInt8::from_bits_le(&bits); + /// var.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] pub fn from_bits_le(bits: &[Boolean]) -> Self { assert_eq!(bits.len(), N); let bits = <&[Boolean; N]>::try_from(bits).unwrap().clone(); let value_exists = bits.iter().all(|b| b.value().is_ok()); let mut value = T::zero(); - bits.iter().enumerate().filter_map(|(i, b)| { - b.value().ok().map(|b| { - value = value << 1; - value = value | T::from(b as u8).unwrap(); - }) - }); - let value = value_exists.then_some(value); - Self { - bits, - value, + for (i, b) in bits.iter().enumerate() { + if let Ok(b) = b.value() { + value = value + (T::from(b as u8).unwrap() << i); + } } + let value = value_exists.then_some(value); + Self { bits, value } } - + /// Rotates `self` to the right by `by` steps, wrapping around. #[tracing::instrument(target = "r1cs", skip(self))] pub fn rotate_right(&self, by: usize) -> Self { let mut result = self.clone(); let by = by % N; - + let new_bits = self.bits.iter().skip(by).chain(&self.bits).take(N); - + for (res, new) in result.bits.iter_mut().zip(new_bits) { *res = new.clone(); } - result.value = self - .value - .map(|v| v.rotate_right(by as u32)); + result.value = self.value.map(|v| v.rotate_right(by as u32)); result } - + /// Outputs `self ^ other`. /// - /// If at least one of `self` and `other` are constants, then this - /// method *does not* create any constraints or variables. + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` #[tracing::instrument(target = "r1cs", skip(self, other))] pub fn xor(&self, other: &Self) -> Result { let mut result = self.clone(); - result.value.as_mut().and_then(|v| { - *v = *v ^ other.value?; - Some(()) - }); - - for (self_bit, other_bit) in result.bits.iter_mut().zip(&other.bits) { - *self_bit = self_bit.xor(other_bit)? + for (a, b) in result.bits.iter_mut().zip(&other.bits) { + *a = a.xor(b)? } - + result.value = self.value.and_then(|a| Some(a ^ other.value?)); + dbg!(result.value); Ok(result) } - + /// Perform modular addition of `operands`. /// /// The user must ensure that overflow does not occur. #[tracing::instrument(target = "r1cs", skip(operands))] - pub fn addmany(operands: &[Self]) -> Result + pub fn add_many(operands: &[Self]) -> Result where F: PrimeField, { // Make some arbitrary bounds for ourselves to avoid overflows // in the scalar field assert!(F::MODULUS_BIT_SIZE as usize >= 2 * N); - + assert!(operands.len() >= 1); assert!(N * operands.len() <= F::MODULUS_BIT_SIZE as usize); - + if operands.len() == 1 { return Ok(operands[0].clone()); } - + // Compute the maximum value of the sum so we allocate enough bits for // the result - let mut max_value = - BigUint::from($native::max_value()) * BigUint::from(operands.len()); - + let mut max_value = T::max_value() + .checked_mul( + &T::from(ark_std::log2(operands.len())).ok_or(SynthesisError::Unsatisfiable)?, + ) + .ok_or(SynthesisError::Unsatisfiable)?; + // Keep track of the resulting value let mut result_value = Some(BigUint::zero()); - + // This is a linear combination that we will enforce to be "zero" let mut lc = LinearCombination::zero(); - + let mut all_constants = true; - + // Iterate over the operands for op in operands { // Accumulate the value match op.value { Some(val) => { - result_value.as_mut().map(|v| *v += BigUint::from(val.to_u128().unwrap())); + result_value + .as_mut() + .map(|v| *v += BigUint::from(val.to_u128().unwrap())); }, - + None => { // If any of our operands have unknown value, we won't // know the value of the result result_value = None; }, } - + // Iterate over each bit_gadget of the operand and add the operand to // the linear combination let mut coeff = F::one(); @@ -172,13 +266,13 @@ impl UInt { match *bit { Boolean::Is(ref bit) => { all_constants = false; - + // Add coeff * bit_gadget lc += (coeff, bit.variable()); }, Boolean::Not(ref bit) => { all_constants = false; - + // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * // bit_gadget lc = lc + (coeff, Variable::One) - (coeff, bit.variable()); @@ -189,58 +283,60 @@ impl UInt { } }, } - + coeff.double_in_place(); } } - + // The value of the actual result is modulo 2^$size - let modular_value = result_value.clone().map(|v| { + let modular_value = result_value.clone().and_then(|v| { let modulus = BigUint::from(1u64) << (N as u32); - v % modulus + NumCast::from(v % modulus) }); - + if all_constants && modular_value.is_some() { // We can just return a constant, rather than // unpacking the result into allocated bits. - - return Ok(UInt::constant(modular_value.unwrap())); + + return modular_value + .map(UInt::constant) + .ok_or(SynthesisError::AssignmentMissing); } let cs = operands.cs(); - + // Storage area for the resulting bits let mut result_bits = vec![]; - + // Allocate each bit_gadget of the result let mut coeff = F::one(); let mut i = 0; - while max_value != BigUint::zero() { + while max_value != T::zero() { // Allocate the bit_gadget let b = AllocatedBool::new_witness(cs.clone(), || { result_value - .clone() - .map(|v| (v >> i) & BigUint::one() == BigUint::one()) - .get() + .clone() + .map(|v| (v >> i) & BigUint::one() == BigUint::one()) + .get() })?; - + // Subtract this bit_gadget from the linear combination to ensure the sums // balance out lc = lc - (coeff, b.variable()); - + result_bits.push(b.into()); - - max_value >>= 1; + + max_value = max_value >> 1; i += 1; coeff.double_in_place(); } - + // Enforce that the linear combination equals zero cs.enforce_constraint(lc!(), lc!(), lc)?; - + // Discard carry bits that we don't care about result_bits.truncate(N); let bits = TryFrom::try_from(result_bits).unwrap(); - + Ok(UInt { bits, value: modular_value, @@ -248,98 +344,135 @@ impl UInt { } } -impl ToBytesGadget for UInt { +impl ToBytesGadget + for UInt +{ #[tracing::instrument(target = "r1cs", skip(self))] fn to_bytes(&self) -> Result>, SynthesisError> { Ok(self .to_bits_le() .chunks(8) .map(UInt8::from_bits_le) - .collect() - ) + .collect()) } } - -impl EqGadget for UInt { + +impl EqGadget + for UInt +{ #[tracing::instrument(target = "r1cs", skip(self, other))] fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_ref().is_eq(&other.bits) + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + let chunks_are_eq = self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + .map(|(a, b)| { + let a = Boolean::le_bits_to_fp_var(a)?; + let b = Boolean::le_bits_to_fp_var(b)?; + a.is_eq(&b) + }) + .collect::, _>>()?; + Boolean::kary_and(&chunks_are_eq) } - + #[tracing::instrument(target = "r1cs", skip(self, other))] fn conditional_enforce_equal( &self, other: &Self, condition: &Boolean, ) -> Result<(), SynthesisError> { - self.bits.conditional_enforce_equal(&other.bits, condition) + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + for (a, b) in self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + { + let a = Boolean::le_bits_to_fp_var(a)?; + let b = Boolean::le_bits_to_fp_var(b)?; + a.conditional_enforce_equal(&b, condition)?; + } + Ok(()) } - + #[tracing::instrument(target = "r1cs", skip(self, other))] fn conditional_enforce_not_equal( &self, other: &Self, condition: &Boolean, ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(&other.bits, condition) + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + for (a, b) in self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + { + let a = Boolean::le_bits_to_fp_var(a)?; + let b = Boolean::le_bits_to_fp_var(b)?; + a.conditional_enforce_not_equal(&b, condition)?; + } + Ok(()) } } - - impl CondSelectGadget for UInt { - #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] - fn conditionally_select( - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value + +impl CondSelectGadget + for UInt +{ + #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let selected_bits = true_value .bits .iter() .zip(&false_value.bits) .map(|(t, f)| cond.select(t, f)); - let mut bits = [Boolean::FALSE; N]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; - } - - let value = cond.value().ok().and_then(|cond| { - if cond { - true_value.value().ok() - } else { - false_value.value().ok() - } - }); - Ok(Self { bits, value }) + let mut bits = [Boolean::FALSE; N]; + for (result, new) in bits.iter_mut().zip(selected_bits) { + *result = new?; } + + let value = cond.value().ok().and_then(|cond| { + if cond { + true_value.value().ok() + } else { + false_value.value().ok() + } + }); + Ok(Self { bits, value }) } - - impl AllocVar for UInt { - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - let value = f().map(|f| *f.borrow()).ok(); - - let mut values = [None; N]; - if let Some(val) = value { - values +} + +impl AllocVar + for UInt +{ + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let value = f().map(|f| *f.borrow()).ok(); + + let mut values = [None; N]; + if let Some(val) = value { + values .iter_mut() .enumerate() - .for_each(|(i, v)| *v = Some(val.bit(i as u64))); - } - - let mut bits = [Boolean::FALSE; N]; - for (b, v) in bits.iter_mut().zip(&values) { - *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; - } - Ok(Self { bits, value }) + .for_each(|(i, v)| *v = Some(((val >> i) & T::one()) == T::one())); } + + let mut bits = [Boolean::FALSE; N]; + for (b, v) in bits.iter_mut().zip(&values) { + *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; + } + Ok(Self { bits, value }) } - +} + // #[cfg(test)] // mod test { // use super::UInt; @@ -347,18 +480,18 @@ impl EqGadget Result<(), SynthesisError> { // let mut rng = ark_std::test_rng(); - + // for _ in 0..1000 { // let v = (0..$size) // .map(|_| Boolean::constant(rng.gen())) // .collect::>>(); - + // let b = UInt::from_bits_le(&v); - + // for (i, bit) in b.bits.iter().enumerate() { // match bit { // &Boolean::Constant(bit) => { @@ -367,9 +500,9 @@ impl EqGadget unreachable!(), // } // } - + // let expected_to_be_same = b.to_bits_le(); - + // for x in v.iter().zip(expected_to_be_same.iter()) { // match x { // (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, @@ -380,143 +513,142 @@ impl EqGadget Result<(), SynthesisError> { // use Boolean::*; // let mut rng = ark_std::test_rng(); - + // for _ in 0..1000 { // let cs = ConstraintSystem::::new_ref(); - + // let a: $native = rng.gen(); // let b: $native = rng.gen(); // let c: $native = rng.gen(); - + // let mut expected = a ^ b ^ c; - + // let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; // let b_bit = $name::constant(b); // let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; - + // let r = a_bit.xor(&b_bit).unwrap(); // let r = r.xor(&c_bit).unwrap(); - + // assert!(cs.is_satisfied().unwrap()); - + // assert!(r.value == Some(expected)); - + // for b in r.bits.iter() { // match b { // Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), // Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), // Constant(b) => assert_eq!(*b, (expected & 1 == 1)), // } - + // expected >>= 1; // } // } // Ok(()) // } - + // #[test] -// fn test_addmany_constants() -> Result<(), SynthesisError> { +// fn test_add_many_constants() -> Result<(), SynthesisError> { // let mut rng = ark_std::test_rng(); - + // for _ in 0..1000 { // let cs = ConstraintSystem::::new_ref(); - + // let a: $native = rng.gen(); // let b: $native = rng.gen(); // let c: $native = rng.gen(); - + // let a_bit = $name::new_constant(cs.clone(), a)?; // let b_bit = $name::new_constant(cs.clone(), b)?; // let c_bit = $name::new_constant(cs.clone(), c)?; - + // let mut expected = a.wrapping_add(b).wrapping_add(c); - -// let r = $name::addmany(&[a_bit, b_bit, c_bit]).unwrap(); - + +// let r = $name::add_many(&[a_bit, b_bit, c_bit]).unwrap(); + // assert!(r.value == Some(expected)); - + // for b in r.bits.iter() { // match b { // Boolean::Is(_) => unreachable!(), // Boolean::Not(_) => unreachable!(), // Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), // } - + // expected >>= 1; // } // } // Ok(()) // } - + // #[test] -// fn test_addmany() -> Result<(), SynthesisError> { +// fn test_add_many() -> Result<(), SynthesisError> { // let mut rng = ark_std::test_rng(); - + // for _ in 0..1000 { // let cs = ConstraintSystem::::new_ref(); - + // let a: $native = rng.gen(); // let b: $native = rng.gen(); // let c: $native = rng.gen(); // let d: $native = rng.gen(); - + // let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - + // let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; // let b_bit = $name::constant(b); // let c_bit = $name::constant(c); // let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; - + // let r = a_bit.xor(&b_bit).unwrap(); -// let r = $name::addmany(&[r, c_bit, d_bit]).unwrap(); - +// let r = $name::add_many(&[r, c_bit, d_bit]).unwrap(); + // assert!(cs.is_satisfied().unwrap()); // assert!(r.value == Some(expected)); - + // for b in r.bits.iter() { // match b { // Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), // Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), // Boolean::Constant(_) => unreachable!(), // } - + // expected >>= 1; // } // } // Ok(()) // } - + // #[test] // fn test_rotr() -> Result<(), SynthesisError> { // let mut rng = ark_std::test_rng(); - + // let mut num = rng.gen(); - + // let a: $name = $name::constant(num); - + // for i in 0..$size { // let b = a.rotr(i); - + // assert!(b.value.unwrap() == num); - + // let mut tmp = num; // for b in &b.bits { // match b { // Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), // _ => unreachable!(), // } - + // tmp >>= 1; // } - + // num = num.rotate_right(1); // } // Ok(()) // } // } - diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index f1612719..863415ab 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -1,123 +1,16 @@ use ark_ff::{Field, PrimeField, ToConstraintField}; -use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_relations::r1cs::{Namespace, SynthesisError}; use crate::{ fields::fp::{AllocatedFp, FpVar}, prelude::*, - Assignment, ToConstraintFieldGadget, Vec, + ToConstraintFieldGadget, Vec, }; -use core::{borrow::Borrow, convert::TryFrom}; - -/// Represents an interpretation of 8 `Boolean` objects as an -/// unsigned integer. -#[derive(Clone, Debug)] -pub struct UInt8 { - /// Little-endian representation: least significant bit first - pub(crate) bits: [Boolean; 8], - pub(crate) value: Option, -} - -impl R1CSVar for UInt8 { - type Value = u8; - - fn cs(&self) -> ConstraintSystemRef { - self.bits.as_ref().cs() - } - fn value(&self) -> Result { - let mut value = None; - for (i, bit) in self.bits.iter().enumerate() { - let b = u8::from(bit.value()?); - value = match value { - Some(value) => Some(value + (b << i)), - None => Some(b << i), - }; - } - debug_assert_eq!(self.value, value); - value.get() - } -} +pub type UInt8 = super::uint::UInt<8, u8, F>; impl UInt8 { - /// Construct a constant vector of `UInt8` from a vector of `u8` - /// - /// This *does not* create any new variables or constraints. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let var = vec![UInt8::new_witness(cs.clone(), || Ok(2))?]; - /// - /// let constant = UInt8::constant_vec(&[2]); - /// var.enforce_equal(&constant)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn constant_vec(values: &[u8]) -> Vec { - let mut result = Vec::new(); - for value in values { - result.push(UInt8::constant(*value)); - } - result - } - - /// Construct a constant `UInt8` from a `u8` - /// - /// This *does not* create new variables or constraints. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let var = UInt8::new_witness(cs.clone(), || Ok(2))?; - /// - /// let constant = UInt8::constant(2); - /// var.enforce_equal(&constant)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn constant(value: u8) -> Self { - let mut bits = [Boolean::FALSE; 8]; - - let mut tmp = value; - for i in 0..8 { - // If last bit is one, push one. - bits[i] = Boolean::constant((tmp & 1) == 1); - tmp >>= 1; - } - - Self { - bits, - value: Some(value), - } - } - - /// Allocates a slice of `u8`'s as private witnesses. - pub fn new_witness_vec( - cs: impl Into>, - values: &[impl Into> + Copy], - ) -> Result, SynthesisError> { - let ns = cs.into(); - let cs = ns.cs(); - let mut output_vec = Vec::with_capacity(values.len()); - for value in values { - let byte: Option = Into::into(*value); - output_vec.push(Self::new_witness(cs.clone(), || byte.get())?); - } - Ok(output_vec) - } - /// Allocates a slice of `u8`'s as public inputs by first packing them into /// elements of `F`, (thus reducing the number of input allocations), /// allocating these elements as public inputs, and then converting @@ -175,167 +68,6 @@ impl UInt8 { .map(Self::from_bits_le) .collect()) } - - /// Converts a little-endian byte order representation of bits into a - /// `UInt8`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let var = UInt8::new_witness(cs.clone(), || Ok(128))?; - /// - /// let f = Boolean::FALSE; - /// let t = Boolean::TRUE; - /// - /// // Construct [0, 0, 0, 0, 0, 0, 0, 1] - /// let mut bits = vec![f.clone(); 7]; - /// bits.push(t); - /// - /// let mut c = UInt8::from_bits_le(&bits); - /// var.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), 8); - let bits = <&[Boolean; 8]>::try_from(bits).unwrap().clone(); - - let mut value = Some(0u8); - for (i, b) in bits.iter().enumerate() { - value = match b.value().ok() { - Some(b) => value.map(|v| v + (u8::from(b) << i)), - None => None, - } - } - - Self { value, bits } - } - - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.xor(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn xor(&self, other: &Self) -> Result { - let mut result = self.clone(); - result.value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a ^ b), - _ => None, - }; - - let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); - - for (res, new) in result.bits.iter_mut().zip(new_bits) { - *res = new?; - } - - Ok(result) - } -} - -impl EqGadget for UInt8 { - #[tracing::instrument(target = "r1cs")] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_ref().is_eq(&other.bits) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits.conditional_enforce_equal(&other.bits, condition) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_not_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(&other.bits, condition) - } -} - -impl CondSelectGadget for UInt8 { - #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] - fn conditionally_select( - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value - .bits - .iter() - .zip(&false_value.bits) - .map(|(t, f)| cond.select(t, f)); - let mut bits = [Boolean::FALSE; 8]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; - } - - let value = cond.value().ok().and_then(|cond| { - if cond { - true_value.value().ok() - } else { - false_value.value().ok() - } - }); - Ok(Self { bits, value }) - } -} - -impl AllocVar for UInt8 { - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - let value = f().map(|f| *f.borrow()).ok(); - - let mut values = [None; 8]; - if let Some(val) = value { - values - .iter_mut() - .enumerate() - .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); - } - - let mut bits = [Boolean::FALSE; 8]; - for (b, v) in bits.iter_mut().zip(&values) { - *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; - } - Ok(Self { bits, value }) - } } /// Parses the `Vec>` in fixed-sized @@ -382,7 +114,7 @@ mod test { let byte_val = 0b01110001; let byte = UInt8::new_witness(ark_relations::ns!(cs, "alloc value"), || Ok(byte_val)).unwrap(); - let bits = byte.to_bits_le()?; + let bits = byte.to_bits_le(); for (i, bit) in bits.iter().enumerate() { assert_eq!(bit.value()?, (byte_val >> i) & 1 == 1) } @@ -397,7 +129,7 @@ mod test { UInt8::new_input_vec(ark_relations::ns!(cs, "alloc value"), &byte_vals).unwrap(); dbg!(bytes.value())?; for (native, variable) in byte_vals.into_iter().zip(bytes) { - let bits = variable.to_bits_le()?; + let bits = variable.to_bits_le(); for (i, bit) in bits.iter().enumerate() { assert_eq!( bit.value()?, @@ -422,14 +154,15 @@ mod test { let val = UInt8::from_bits_le(&v); + let value = val.value()?; for (i, bit) in val.bits.iter().enumerate() { match bit { - Boolean::Constant(b) => assert!(*b == ((val.value()? >> i) & 1 == 1)), + Boolean::Constant(b) => assert_eq!(*b, ((value >> i) & 1 == 1)), _ => unreachable!(), } } - let expected_to_be_same = val.to_bits_le()?; + let expected_to_be_same = val.to_bits_le(); for x in v.iter().zip(expected_to_be_same.iter()) { match x { @@ -458,13 +191,18 @@ mod test { let a_bit = UInt8::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a)).unwrap(); let b_bit = UInt8::constant(b); let c_bit = UInt8::new_witness(ark_relations::ns!(cs, "c_bit"), || Ok(c)).unwrap(); + dbg!(a_bit.value().unwrap()); + dbg!(b_bit.value().unwrap()); + dbg!(c_bit.value().unwrap()); let r = a_bit.xor(&b_bit).unwrap(); let r = r.xor(&c_bit).unwrap(); assert!(cs.is_satisfied().unwrap()); - assert!(r.value == Some(expected)); + dbg!(Some(expected)); + dbg!(r.value().unwrap()); + assert_eq!(r.value, Some(expected)); for b in r.bits.iter() { match b { diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index 8241aa60..c5eeb74c 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -129,7 +129,7 @@ impl AllocatedFp { /// /// This does not create any constraints and only creates one linear /// combination. - pub fn addmany<'a, I: Iterator>(iter: I) -> Self { + pub fn add_many<'a, I: Iterator>(iter: I) -> Self { let mut cs = ConstraintSystemRef::None; let mut has_value = true; let mut value = F::zero(); @@ -1050,7 +1050,7 @@ impl AllocVar for FpVar { impl<'a, F: PrimeField> Sum<&'a FpVar> for FpVar { fn sum>>(iter: I) -> FpVar { let mut sum_constants = F::zero(); - let sum_variables = FpVar::Var(AllocatedFp::::addmany(iter.filter_map(|x| match x { + let sum_variables = FpVar::Var(AllocatedFp::::add_many(iter.filter_map(|x| match x { FpVar::Constant(c) => { sum_constants += c; None diff --git a/src/poly/domain/mod.rs b/src/poly/domain/mod.rs index 32411867..8c450702 100644 --- a/src/poly/domain/mod.rs +++ b/src/poly/domain/mod.rs @@ -145,6 +145,13 @@ mod tests { let num_cosets = 1 << (COSET_DIM - LOCALIZATION); let coset_index = rng.gen_range(0..num_cosets); + println!("{:0b}", coset_index); + dbg!(UInt32::new_witness(cs.clone(), || Ok(coset_index)) + .unwrap() + .to_bits_le() + .iter() + .map(|x| x.value().unwrap() as u8) + .collect::>()); let coset_index_var = UInt32::new_witness(cs.clone(), || Ok(coset_index)) .unwrap() .to_bits_le() From 454ffceb7ad9295bd74fc5811bdc6d7cd611788e Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 10 May 2023 11:36:32 -0700 Subject: [PATCH 04/29] Improve tests --- Cargo.toml | 2 +- src/bits/uint/and.rs | 202 +++++++++++++++++ src/bits/uint/cmp.rs | 1 + src/bits/uint/eq.rs | 68 ++++++ src/bits/{uint.rs => uint/mod.rs} | 357 ++++++++++++------------------ src/bits/uint/not.rs | 75 +++++++ src/bits/uint/or.rs | 198 +++++++++++++++++ src/bits/uint/test_utils.rs | 105 +++++++++ src/bits/uint/xor.rs | 251 +++++++++++++++++++++ src/bits/uint8.rs | 6 +- src/lib.rs | 6 + src/test_utils.rs | 20 ++ 12 files changed, 1068 insertions(+), 223 deletions(-) create mode 100644 src/bits/uint/and.rs create mode 100644 src/bits/uint/cmp.rs create mode 100644 src/bits/uint/eq.rs rename src/bits/{uint.rs => uint/mod.rs} (58%) create mode 100644 src/bits/uint/not.rs create mode 100644 src/bits/uint/or.rs create mode 100644 src/bits/uint/test_utils.rs create mode 100644 src/bits/uint/xor.rs create mode 100644 src/test_utils.rs diff --git a/Cargo.toml b/Cargo.toml index 0c63a8ca..edcecf4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ keywords = ["zero-knowledge", "cryptography", "zkSNARK", "SNARK", "r1cs"] categories = ["cryptography"] include = ["Cargo.toml", "src", "README.md", "LICENSE-APACHE", "LICENSE-MIT"] license = "MIT/Apache-2.0" -edition = "2018" +edition = "2021" [dependencies] ark-ff = { version = "0.4.0", default-features = false } diff --git a/src/bits/uint/and.rs b/src/bits/uint/and.rs new file mode 100644 index 00000000..fc3d0676 --- /dev/null +++ b/src/bits/uint/and.rs @@ -0,0 +1,202 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{fmt::Debug, ops::BitAnd, ops::BitAndAssign}; +use num_traits::PrimInt; + +use super::UInt; + +impl UInt { + fn _and(&self, other: &Self) -> Result { + let mut result = self.clone(); + for (a, b) in result.bits.iter_mut().zip(&other.bits) { + *a = a.and(b)? + } + result.value = self.value.and_then(|a| Some(a & other.value?)); + dbg!(result.value); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd for &'a UInt { + type Output = UInt; + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd<&'a Self> for UInt { + type Output = UInt; + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: &Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd> for &'a UInt { + type Output = UInt; + + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: UInt) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAnd for UInt { + type Output = Self; + + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAndAssign for UInt { + /// Sets `self = self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// a &= &b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: Self) { + let result = self._and(&other).unwrap(); + *self = result; + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAndAssign<&'a Self> for UInt { + /// Sets `self = self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// a &= &b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: &'a Self) { + let result = self._and(other).unwrap(); + *self = result; + } +} diff --git a/src/bits/uint/cmp.rs b/src/bits/uint/cmp.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/bits/uint/cmp.rs @@ -0,0 +1 @@ + diff --git a/src/bits/uint/eq.rs b/src/bits/uint/eq.rs new file mode 100644 index 00000000..18fc76ea --- /dev/null +++ b/src/bits/uint/eq.rs @@ -0,0 +1,68 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::fmt::Debug; + +use num_traits::PrimInt; + +use crate::boolean::Boolean; +use crate::eq::EqGadget; + +use super::UInt; + +impl EqGadget + for UInt +{ + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + let chunks_are_eq = self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + .map(|(a, b)| { + let a = Boolean::le_bits_to_fp_var(a)?; + let b = Boolean::le_bits_to_fp_var(b)?; + a.is_eq(&b) + }) + .collect::, _>>()?; + Boolean::kary_and(&chunks_are_eq) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + for (a, b) in self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + { + let a = Boolean::le_bits_to_fp_var(a)?; + let b = Boolean::le_bits_to_fp_var(b)?; + a.conditional_enforce_equal(&b, condition)?; + } + Ok(()) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + for (a, b) in self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + { + let a = Boolean::le_bits_to_fp_var(a)?; + let b = Boolean::le_bits_to_fp_var(b)?; + a.conditional_enforce_not_equal(&b, condition)?; + } + Ok(()) + } +} diff --git a/src/bits/uint.rs b/src/bits/uint/mod.rs similarity index 58% rename from src/bits/uint.rs rename to src/bits/uint/mod.rs index 926fe367..48122491 100644 --- a/src/bits/uint.rs +++ b/src/bits/uint/mod.rs @@ -13,6 +13,16 @@ use crate::{ Assignment, Vec, }; +mod and; +mod cmp; +mod eq; +mod not; +mod or; +mod xor; + +#[cfg(test)] +pub(crate) mod test_utils; + /// This struct represent an unsigned `N` bit integer as a sequence of `N` [`Boolean`]s. #[derive(Clone, Debug)] pub struct UInt { @@ -174,39 +184,6 @@ impl UInt { result } - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.xor(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs", skip(self, other))] - pub fn xor(&self, other: &Self) -> Result { - let mut result = self.clone(); - for (a, b) in result.bits.iter_mut().zip(&other.bits) { - *a = a.xor(b)? - } - result.value = self.value.and_then(|a| Some(a ^ other.value?)); - dbg!(result.value); - Ok(result) - } - /// Perform modular addition of `operands`. /// /// The user must ensure that overflow does not occur. @@ -357,64 +334,6 @@ impl ToBytesGadget EqGadget - for UInt -{ - #[tracing::instrument(target = "r1cs", skip(self, other))] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); - let chunks_are_eq = self - .bits - .chunks(chunk_size) - .zip(other.bits.chunks(chunk_size)) - .map(|(a, b)| { - let a = Boolean::le_bits_to_fp_var(a)?; - let b = Boolean::le_bits_to_fp_var(b)?; - a.is_eq(&b) - }) - .collect::, _>>()?; - Boolean::kary_and(&chunks_are_eq) - } - - #[tracing::instrument(target = "r1cs", skip(self, other))] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); - for (a, b) in self - .bits - .chunks(chunk_size) - .zip(other.bits.chunks(chunk_size)) - { - let a = Boolean::le_bits_to_fp_var(a)?; - let b = Boolean::le_bits_to_fp_var(b)?; - a.conditional_enforce_equal(&b, condition)?; - } - Ok(()) - } - - #[tracing::instrument(target = "r1cs", skip(self, other))] - fn conditional_enforce_not_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); - for (a, b) in self - .bits - .chunks(chunk_size) - .zip(other.bits.chunks(chunk_size)) - { - let a = Boolean::le_bits_to_fp_var(a)?; - let b = Boolean::le_bits_to_fp_var(b)?; - a.conditional_enforce_not_equal(&b, condition)?; - } - Ok(()) - } -} - impl CondSelectGadget for UInt { @@ -473,182 +392,182 @@ impl AllocVar Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let v = (0..$size) -// .map(|_| Boolean::constant(rng.gen())) -// .collect::>>(); - -// let b = UInt::from_bits_le(&v); - -// for (i, bit) in b.bits.iter().enumerate() { -// match bit { -// &Boolean::Constant(bit) => { -// assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); -// }, -// _ => unreachable!(), -// } +// #[cfg(test)] +// mod test { +// use super::UInt; +// use crate::{bits::boolean::Boolean, prelude::*, Vec}; +// use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; +// use ark_std::rand::Rng; +// use ark_test_curves::mnt4_753::Fr; + +// #[test] +// fn test_from_bits() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); + +// for _ in 0..1000 { +// let v = (0..$size) +// .map(|_| Boolean::constant(rng.gen())) +// .collect::>>(); + +// let b = UInt::from_bits_le(&v); + +// for (i, bit) in b.bits.iter().enumerate() { +// match bit { +// &Boolean::Constant(bit) => { +// assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); +// }, +// _ => unreachable!(), // } +// } -// let expected_to_be_same = b.to_bits_le(); +// let expected_to_be_same = b.to_bits_le(); -// for x in v.iter().zip(expected_to_be_same.iter()) { -// match x { -// (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, -// (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, -// _ => unreachable!(), -// } +// for x in v.iter().zip(expected_to_be_same.iter()) { +// match x { +// (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, +// (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, +// _ => unreachable!(), // } // } -// Ok(()) // } +// Ok(()) +// } -// #[test] -// fn test_xor() -> Result<(), SynthesisError> { -// use Boolean::*; -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let cs = ConstraintSystem::::new_ref(); +// #[test] +// fn test_xor() -> Result<(), SynthesisError> { +// use Boolean::*; +// let mut rng = ark_std::test_rng(); -// let a: $native = rng.gen(); -// let b: $native = rng.gen(); -// let c: $native = rng.gen(); +// for _ in 0..1000 { +// let cs = ConstraintSystem::::new_ref(); -// let mut expected = a ^ b ^ c; +// let a: $native = rng.gen(); +// let b: $native = rng.gen(); +// let c: $native = rng.gen(); -// let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; -// let b_bit = $name::constant(b); -// let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; +// let mut expected = a ^ b ^ c; -// let r = a_bit.xor(&b_bit).unwrap(); -// let r = r.xor(&c_bit).unwrap(); +// let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; +// let b_bit = $name::constant(b); +// let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; -// assert!(cs.is_satisfied().unwrap()); +// let r = a_bit.xor(&b_bit).unwrap(); +// let r = r.xor(&c_bit).unwrap(); -// assert!(r.value == Some(expected)); +// assert!(cs.is_satisfied().unwrap()); -// for b in r.bits.iter() { -// match b { -// Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), -// Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), -// Constant(b) => assert_eq!(*b, (expected & 1 == 1)), -// } +// assert!(r.value == Some(expected)); -// expected >>= 1; +// for b in r.bits.iter() { +// match b { +// Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), +// Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), +// Constant(b) => assert_eq!(*b, (expected & 1 == 1)), // } + +// expected >>= 1; // } -// Ok(()) // } +// Ok(()) +// } -// #[test] -// fn test_add_many_constants() -> Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let cs = ConstraintSystem::::new_ref(); +// #[test] +// fn test_add_many_constants() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); -// let a: $native = rng.gen(); -// let b: $native = rng.gen(); -// let c: $native = rng.gen(); +// for _ in 0..1000 { +// let cs = ConstraintSystem::::new_ref(); -// let a_bit = $name::new_constant(cs.clone(), a)?; -// let b_bit = $name::new_constant(cs.clone(), b)?; -// let c_bit = $name::new_constant(cs.clone(), c)?; +// let a: $native = rng.gen(); +// let b: $native = rng.gen(); +// let c: $native = rng.gen(); -// let mut expected = a.wrapping_add(b).wrapping_add(c); +// let a_bit = $name::new_constant(cs.clone(), a)?; +// let b_bit = $name::new_constant(cs.clone(), b)?; +// let c_bit = $name::new_constant(cs.clone(), c)?; -// let r = $name::add_many(&[a_bit, b_bit, c_bit]).unwrap(); +// let mut expected = a.wrapping_add(b).wrapping_add(c); -// assert!(r.value == Some(expected)); +// let r = $name::add_many(&[a_bit, b_bit, c_bit]).unwrap(); -// for b in r.bits.iter() { -// match b { -// Boolean::Is(_) => unreachable!(), -// Boolean::Not(_) => unreachable!(), -// Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), -// } +// assert!(r.value == Some(expected)); -// expected >>= 1; +// for b in r.bits.iter() { +// match b { +// Boolean::Is(_) => unreachable!(), +// Boolean::Not(_) => unreachable!(), +// Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), // } + +// expected >>= 1; // } -// Ok(()) // } +// Ok(()) +// } -// #[test] -// fn test_add_many() -> Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let cs = ConstraintSystem::::new_ref(); +// #[test] +// fn test_add_many() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); -// let a: $native = rng.gen(); -// let b: $native = rng.gen(); -// let c: $native = rng.gen(); -// let d: $native = rng.gen(); +// for _ in 0..1000 { +// let cs = ConstraintSystem::::new_ref(); -// let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); +// let a: $native = rng.gen(); +// let b: $native = rng.gen(); +// let c: $native = rng.gen(); +// let d: $native = rng.gen(); -// let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; -// let b_bit = $name::constant(b); -// let c_bit = $name::constant(c); -// let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; +// let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); -// let r = a_bit.xor(&b_bit).unwrap(); -// let r = $name::add_many(&[r, c_bit, d_bit]).unwrap(); +// let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; +// let b_bit = $name::constant(b); +// let c_bit = $name::constant(c); +// let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; -// assert!(cs.is_satisfied().unwrap()); -// assert!(r.value == Some(expected)); +// let r = a_bit.xor(&b_bit).unwrap(); +// let r = $name::add_many(&[r, c_bit, d_bit]).unwrap(); -// for b in r.bits.iter() { -// match b { -// Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), -// Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), -// Boolean::Constant(_) => unreachable!(), -// } +// assert!(cs.is_satisfied().unwrap()); +// assert!(r.value == Some(expected)); -// expected >>= 1; +// for b in r.bits.iter() { +// match b { +// Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), +// Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), +// Boolean::Constant(_) => unreachable!(), // } + +// expected >>= 1; // } -// Ok(()) // } +// Ok(()) +// } -// #[test] -// fn test_rotr() -> Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// let mut num = rng.gen(); +// #[test] +// fn test_rotr() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); -// let a: $name = $name::constant(num); +// let mut num = rng.gen(); -// for i in 0..$size { -// let b = a.rotr(i); +// let a: $name = $name::constant(num); -// assert!(b.value.unwrap() == num); +// for i in 0..$size { +// let b = a.rotr(i); -// let mut tmp = num; -// for b in &b.bits { -// match b { -// Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), -// _ => unreachable!(), -// } +// assert!(b.value.unwrap() == num); -// tmp >>= 1; +// let mut tmp = num; +// for b in &b.bits { +// match b { +// Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), +// _ => unreachable!(), // } -// num = num.rotate_right(1); +// tmp >>= 1; // } -// Ok(()) + +// num = num.rotate_right(1); // } +// Ok(()) // } +// } diff --git a/src/bits/uint/not.rs b/src/bits/uint/not.rs new file mode 100644 index 00000000..37d91cb3 --- /dev/null +++ b/src/bits/uint/not.rs @@ -0,0 +1,75 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{fmt::Debug, ops::Not}; +use num_traits::PrimInt; + +use super::UInt; + +impl UInt { + fn _not(&self) -> Result { + let mut result = self.clone(); + for a in &mut result.bits { + *a = a.not() + } + result.value = self.value.map(Not::not); + dbg!(result.value); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> Not for &'a UInt { + type Output = UInt; + /// Outputs `!self`. + /// + /// If `self` is a constant, then this method *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(2))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(!2))?; + /// + /// (!a).enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> Not for UInt { + type Output = UInt; + + /// Outputs `!self`. + /// + /// If `self` is a constant, then this method *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(2))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(!2))?; + /// + /// (!a).enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} diff --git a/src/bits/uint/or.rs b/src/bits/uint/or.rs new file mode 100644 index 00000000..02bfecea --- /dev/null +++ b/src/bits/uint/or.rs @@ -0,0 +1,198 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{fmt::Debug, ops::BitOr, ops::BitOrAssign}; +use num_traits::PrimInt; + +use super::UInt; + +impl UInt { + fn _or(&self, other: &Self) -> Result { + let mut result = self.clone(); + for (a, b) in result.bits.iter_mut().zip(&other.bits) { + *a = a.or(b)? + } + result.value = self.value.and_then(|a| Some(a | other.value?)); + dbg!(result.value); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr for &'a UInt { + type Output = UInt; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr<&'a Self> for UInt { + type Output = UInt; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: &Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr> for &'a UInt { + type Output = UInt; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: UInt) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOr for UInt { + type Output = Self; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOrAssign for UInt { + /// Sets `self = self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: Self) { + let result = self._or(&other).unwrap(); + *self = result; + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOrAssign<&'a Self> for UInt { + /// Sets `self = self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: &'a Self) { + let result = self._or(other).unwrap(); + *self = result; + } +} diff --git a/src/bits/uint/test_utils.rs b/src/bits/uint/test_utils.rs new file mode 100644 index 00000000..3be52e22 --- /dev/null +++ b/src/bits/uint/test_utils.rs @@ -0,0 +1,105 @@ +use crate::test_utils::modes; + +use super::*; +use ark_relations::r1cs::SynthesisError; +use ark_std::UniformRand; +use num_traits::PrimInt; + +use std::ops::RangeInclusive; + +pub(crate) fn test_unary_op( + a: T, + mode: AllocationMode, + test: impl FnOnce(UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let b = UInt::::new_variable(cs.clone(), mode_b, || Ok(b)); + test(a) +} + +pub(crate) fn test_binary_op( + a: T, + b: T, + mode_a: AllocationMode, + mode_b: AllocationMode, + test: impl FnOnce(UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = UInt::::new_variable(cs.clone(), mode_a, || Ok(a)); + let b = UInt::::new_variable(cs.clone(), mode_b, || Ok(b)); + test(a, b) +} + +pub(crate) fn run_binary_random( + test: impl FnOnce(UInt, UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> +where + T: PrimInt + Debug + UniformRand, + F: PrimeField, +{ + let mut rng = ark_std::test_rng(); + + for i in 0..ITERATIONS { + for mode_a in modes() { + let a = T::rand(&mut rng); + for mode_b in modes() { + let b = T::rand(&mut rng); + test_binary_op(a, b, mode_a, mode_b, test)?; + } + } + } + Ok(()) +} + +pub(crate) fn run_binary_exhaustive( + test: impl FnOnce(UInt, UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> +where + T: PrimInt + Debug + UniformRand, + F: PrimeField, + RangeInclusive: Iterator, +{ + for (a, mode_a) in test_utils::combination(T::min_value()..=T::max_value()) { + for (b, mode_b) in test_utils::combination(T::min_value()..=T::max_value()) { + test_binary_op(a, b, mode_a, mode_b, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_unary_random( + test: impl FnOnce(UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> +where + T: PrimInt + Debug + UniformRand, + F: PrimeField, +{ + let mut rng = ark_std::test_rng(); + + for i in 0..ITERATIONS { + for mode_a in modes() { + let a = T::rand(&mut rng); + for mode_b in modes() { + let b = T::rand(&mut rng); + test_unary_op(a, mode_a, test)?; + } + } + } + Ok(()) +} + +fn run_unary_exhaustive( + test: impl FnOnce(UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> +where + T: PrimInt + Debug + UniformRand, + F: PrimeField, + RangeInclusive: Iterator, +{ + for (a, mode_a) in test_utils::combination(T::min_value()..=T::max_value()) { + for (b, mode_b) in test_utils::combination(T::min_value()..=T::max_value()) { + test_unary_op(a, mode_a, test)?; + } + } + Ok(()) +} diff --git a/src/bits/uint/xor.rs b/src/bits/uint/xor.rs new file mode 100644 index 00000000..04abad82 --- /dev/null +++ b/src/bits/uint/xor.rs @@ -0,0 +1,251 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{fmt::Debug, ops::BitXor, ops::BitXorAssign}; +use num_traits::PrimInt; + +use super::UInt; + +impl UInt { + fn _xor(&self, other: &Self) -> Result { + let mut result = self.clone(); + for (a, b) in result.bits.iter_mut().zip(&other.bits) { + *a = a.xor(b)? + } + result.value = self.value.and_then(|a| Some(a ^ other.value?)); + dbg!(result.value); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor for &'a UInt { + type Output = UInt; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor<&'a Self> for UInt { + type Output = UInt; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: &Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor> for &'a UInt { + type Output = UInt; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: UInt) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXor for UInt { + type Output = Self; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXorAssign for UInt { + /// Sets `self = self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: Self) { + let result = self._xor(&other).unwrap(); + *self = result; + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXorAssign<&'a Self> for UInt { + /// Sets `self = self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a.xor(&b)?.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: &'a Self) { + let result = self._xor(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::AllocVar, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_xor( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let computed = a ^ b; + let expected = UInt::new_witness(ark_relations::ns!(cs, "xor"), || { + Ok(a.value().unwrap() ^ b.value().unwrap()) + })?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + assert!(cs.is_satisfied().unwrap()); + Ok(()) + } + + #[test] + fn u8_xor() { + run_binary_exhaustive(uint_xor::).unwrap() + } + + #[test] + fn u16_xor() { + run_binary_random::<1000, 16, _, _>(uint_xor::).unwrap() + } + + #[test] + fn u32_xor() { + run_binary_random::<1000, 32, _, _>(uint_xor::).unwrap() + } + + #[test] + fn u64_xor() { + run_binary_random::<1000, 64, _, _>(uint_xor::).unwrap() + } + + #[test] + fn u128() { + run_binary_random::<1000, 128, _, _>(uint_xor::).unwrap() + } +} diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index 863415ab..6b977a32 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -195,12 +195,12 @@ mod test { dbg!(b_bit.value().unwrap()); dbg!(c_bit.value().unwrap()); - let r = a_bit.xor(&b_bit).unwrap(); - let r = r.xor(&c_bit).unwrap(); + let mut r = a_bit ^ b_bit; + r ^= &c_bit; assert!(cs.is_satisfied().unwrap()); - dbg!(Some(expected)); + dbg!(expected); dbg!(r.value().unwrap()); assert_eq!(r.value, Some(expected)); diff --git a/src/lib.rs b/src/lib.rs index 9c6b7019..9964c8c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,15 +50,21 @@ pub mod pairing; /// This module describes a trait for allocating new variables in a constraint /// system. pub mod alloc; + /// This module describes a trait for checking equality of variables. pub mod eq; + /// This module implements functions for manipulating polynomial variables over /// finite fields. pub mod poly; + /// This module describes traits for conditionally selecting a variable from a /// list of variables. pub mod select; +#[cfg(test)] +pub(crate) mod test_utils; + #[allow(missing_docs)] pub mod prelude { pub use crate::{ diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 00000000..7989c006 --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,20 @@ +use crate::alloc::AllocationMode; + +pub(crate) fn modes() -> impl Iterator { + [ + AllocationMode::Constant, + AllocationMode::Input, + AllocationMode::Witness, + ] + .into_iter() +} + +pub(crate) fn combination( + mut iter: impl Iterator, +) -> impl Iterator { + core::iter::from_fn(move || { + iter.next() + .map(move |t| modes().map(move |mode| (mode, t.clone()))) + }) + .flat_map(|x| x) +} From cedd4f8463a326764bbdb41af5f0d0404b82a74e Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 10 May 2023 12:53:46 -0700 Subject: [PATCH 05/29] Better tests --- src/bits/uint/and.rs | 63 +++++++++++++++++++++++++++++++ src/bits/uint/eq.rs | 62 +++++++++++++++++++++++++++++++ src/bits/uint/not.rs | 60 ++++++++++++++++++++++++++++++ src/bits/uint/or.rs | 74 ++++++++++++++++++++++++++++++++++--- src/bits/uint/test_utils.rs | 41 +++++++++----------- src/bits/uint/xor.rs | 24 ++++++++---- src/test_utils.rs | 21 ++++------- 7 files changed, 295 insertions(+), 50 deletions(-) diff --git a/src/bits/uint/and.rs b/src/bits/uint/and.rs index fc3d0676..38b31564 100644 --- a/src/bits/uint/and.rs +++ b/src/bits/uint/and.rs @@ -200,3 +200,66 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAndAssign<&'a Self> fo *self = result; } } + + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_and( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a & &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable(cs.clone(), + || Ok(a.value().unwrap() & b.value().unwrap()), + expected_mode + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_and() { + run_binary_exhaustive(uint_and::).unwrap() + } + + #[test] + fn u16_and() { + run_binary_random::<1000, 16, _, _>(uint_and::).unwrap() + } + + #[test] + fn u32_and() { + run_binary_random::<1000, 32, _, _>(uint_and::).unwrap() + } + + #[test] + fn u64_and() { + run_binary_random::<1000, 64, _, _>(uint_and::).unwrap() + } + + #[test] + fn u128_and() { + run_binary_random::<1000, 128, _, _>(uint_and::).unwrap() + } +} \ No newline at end of file diff --git a/src/bits/uint/eq.rs b/src/bits/uint/eq.rs index 18fc76ea..d6b5675a 100644 --- a/src/bits/uint/eq.rs +++ b/src/bits/uint/eq.rs @@ -66,3 +66,65 @@ impl EqGadget( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable(cs.clone(), + || Ok(a.value().unwrap() == b.value().unwrap()), + expected_mode + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_eq() { + run_binary_exhaustive(uint_eq::).unwrap() + } + + #[test] + fn u16_eq() { + run_binary_random::<1000, 16, _, _>(uint_eq::).unwrap() + } + + #[test] + fn u32_eq() { + run_binary_random::<1000, 32, _, _>(uint_eq::).unwrap() + } + + #[test] + fn u64_eq() { + run_binary_random::<1000, 64, _, _>(uint_eq::).unwrap() + } + + #[test] + fn u128_eq() { + run_binary_random::<1000, 128, _, _>(uint_eq::).unwrap() + } +} \ No newline at end of file diff --git a/src/bits/uint/not.rs b/src/bits/uint/not.rs index 37d91cb3..350ff9e8 100644 --- a/src/bits/uint/not.rs +++ b/src/bits/uint/not.rs @@ -73,3 +73,63 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> Not for UInt { self._not().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_unary_exhaustive, run_unary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_not( + a: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let computed = !&a; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable(cs.clone(), + || Ok(!a.value().unwrap()), + expected_mode + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_not() { + run_unary_exhaustive(uint_not::).unwrap() + } + + #[test] + fn u16_not() { + run_unary_random::<1000, 16, _, _>(uint_not::).unwrap() + } + + #[test] + fn u32_not() { + run_unary_random::<1000, 32, _, _>(uint_not::).unwrap() + } + + #[test] + fn u64_not() { + run_unary_random::<1000, 64, _, _>(uint_not::).unwrap() + } + + #[test] + fn u128() { + run_unary_random::<1000, 128, _, _>(uint_not::).unwrap() + } +} diff --git a/src/bits/uint/or.rs b/src/bits/uint/or.rs index 02bfecea..05fbffaa 100644 --- a/src/bits/uint/or.rs +++ b/src/bits/uint/or.rs @@ -36,7 +36,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr for &'a UInt< /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; /// - /// a.xor(&b)?.enforce_equal(&c)?; + /// a.or(&b)?.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -66,7 +66,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr<&'a Self> for UInt< /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; /// - /// a.xor(&b)?.enforce_equal(&c)?; + /// a.or(&b)?.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -96,7 +96,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr> for /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; /// - /// a.xor(&b)?.enforce_equal(&c)?; + /// a.or(&b)?.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -126,7 +126,7 @@ impl BitOr for UInt /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; /// - /// a.xor(&b)?.enforce_equal(&c)?; + /// a.or(&b)?.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -155,7 +155,7 @@ impl BitOrAssign for UInt BitOrAssign<&'a Self> for /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; /// - /// a.xor(&b)?.enforce_equal(&c)?; + /// a.or(&b)?.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -196,3 +196,65 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOrAssign<&'a Self> for *self = result; } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_or( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a | &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable(cs.clone(), + || Ok(a.value().unwrap() | b.value().unwrap()), + expected_mode + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_or() { + run_binary_exhaustive(uint_or::).unwrap() + } + + #[test] + fn u16_or() { + run_binary_random::<1000, 16, _, _>(uint_or::).unwrap() + } + + #[test] + fn u32_or() { + run_binary_random::<1000, 32, _, _>(uint_or::).unwrap() + } + + #[test] + fn u64_or() { + run_binary_random::<1000, 64, _, _>(uint_or::).unwrap() + } + + #[test] + fn u128_or() { + run_binary_random::<1000, 128, _, _>(uint_or::).unwrap() + } +} diff --git a/src/bits/uint/test_utils.rs b/src/bits/uint/test_utils.rs index 3be52e22..6686dc19 100644 --- a/src/bits/uint/test_utils.rs +++ b/src/bits/uint/test_utils.rs @@ -1,7 +1,7 @@ -use crate::test_utils::modes; +use crate::test_utils::{self, modes}; use super::*; -use ark_relations::r1cs::SynthesisError; +use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; use ark_std::UniformRand; use num_traits::PrimInt; @@ -13,7 +13,7 @@ pub(crate) fn test_unary_op( test: impl FnOnce(UInt) -> Result<(), SynthesisError>, ) -> Result<(), SynthesisError> { let cs = ConstraintSystem::::new_ref(); - let b = UInt::::new_variable(cs.clone(), mode_b, || Ok(b)); + let a = UInt::::new_variable(cs.clone(), || Ok(a), mode)?; test(a) } @@ -22,16 +22,16 @@ pub(crate) fn test_binary_op( b: T, mode_a: AllocationMode, mode_b: AllocationMode, - test: impl FnOnce(UInt) -> Result<(), SynthesisError>, + test: impl FnOnce(UInt, UInt) -> Result<(), SynthesisError>, ) -> Result<(), SynthesisError> { let cs = ConstraintSystem::::new_ref(); - let a = UInt::::new_variable(cs.clone(), mode_a, || Ok(a)); - let b = UInt::::new_variable(cs.clone(), mode_b, || Ok(b)); + let a = UInt::::new_variable(cs.clone(), || Ok(a), mode_a)?; + let b = UInt::::new_variable(cs.clone(), || Ok(b), mode_b)?; test(a, b) } pub(crate) fn run_binary_random( - test: impl FnOnce(UInt, UInt) -> Result<(), SynthesisError>, + test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where T: PrimInt + Debug + UniformRand, @@ -39,7 +39,7 @@ where { let mut rng = ark_std::test_rng(); - for i in 0..ITERATIONS { + for _ in 0..ITERATIONS { for mode_a in modes() { let a = T::rand(&mut rng); for mode_b in modes() { @@ -52,15 +52,15 @@ where } pub(crate) fn run_binary_exhaustive( - test: impl FnOnce(UInt, UInt) -> Result<(), SynthesisError>, + test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where T: PrimInt + Debug + UniformRand, F: PrimeField, RangeInclusive: Iterator, { - for (a, mode_a) in test_utils::combination(T::min_value()..=T::max_value()) { - for (b, mode_b) in test_utils::combination(T::min_value()..=T::max_value()) { + for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) { + for (mode_b, b) in test_utils::combination(T::min_value()..=T::max_value()) { test_binary_op(a, b, mode_a, mode_b, test)?; } } @@ -68,7 +68,7 @@ where } pub(crate) fn run_unary_random( - test: impl FnOnce(UInt) -> Result<(), SynthesisError>, + test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where T: PrimInt + Debug + UniformRand, @@ -76,30 +76,25 @@ where { let mut rng = ark_std::test_rng(); - for i in 0..ITERATIONS { + for _ in 0..ITERATIONS { for mode_a in modes() { let a = T::rand(&mut rng); - for mode_b in modes() { - let b = T::rand(&mut rng); - test_unary_op(a, mode_a, test)?; - } + test_unary_op(a, mode_a, test)?; } } Ok(()) } -fn run_unary_exhaustive( - test: impl FnOnce(UInt) -> Result<(), SynthesisError>, +pub(crate) fn run_unary_exhaustive( + test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where T: PrimInt + Debug + UniformRand, F: PrimeField, RangeInclusive: Iterator, { - for (a, mode_a) in test_utils::combination(T::min_value()..=T::max_value()) { - for (b, mode_b) in test_utils::combination(T::min_value()..=T::max_value()) { - test_unary_op(a, mode_a, test)?; - } + for (mode, a) in test_utils::combination(T::min_value()..=T::max_value()) { + test_unary_op(a, mode, test)?; } Ok(()) } diff --git a/src/bits/uint/xor.rs b/src/bits/uint/xor.rs index 04abad82..73c2722c 100644 --- a/src/bits/uint/xor.rs +++ b/src/bits/uint/xor.rs @@ -12,7 +12,6 @@ impl UInt { *a = a.xor(b)? } result.value = self.value.and_then(|a| Some(a ^ other.value?)); - dbg!(result.value); Ok(result) } } @@ -201,7 +200,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXorAssign<&'a Self> fo mod tests { use super::*; use crate::{ - alloc::AllocVar, + alloc::{AllocVar, AllocationMode}, prelude::EqGadget, uint::test_utils::{run_binary_exhaustive, run_binary_random}, R1CSVar, @@ -214,13 +213,22 @@ mod tests { b: UInt, ) -> Result<(), SynthesisError> { let cs = a.cs().or(b.cs()); - let computed = a ^ b; - let expected = UInt::new_witness(ark_relations::ns!(cs, "xor"), || { - Ok(a.value().unwrap() ^ b.value().unwrap()) - })?; + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a ^ &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable(cs.clone(), + || Ok(a.value().unwrap() ^ b.value().unwrap()), + expected_mode + )?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&expected)?; - assert!(cs.is_satisfied().unwrap()); + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } Ok(()) } @@ -245,7 +253,7 @@ mod tests { } #[test] - fn u128() { + fn u128_xor() { run_binary_random::<1000, 128, _, _>(uint_xor::).unwrap() } } diff --git a/src/test_utils.rs b/src/test_utils.rs index 7989c006..189eae69 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,20 +1,15 @@ +use core::iter; + use crate::alloc::AllocationMode; pub(crate) fn modes() -> impl Iterator { - [ - AllocationMode::Constant, - AllocationMode::Input, - AllocationMode::Witness, - ] - .into_iter() + use AllocationMode::*; + [Constant, Input, Witness].into_iter() } -pub(crate) fn combination( - mut iter: impl Iterator, +pub(crate) fn combination( + mut i: impl Iterator, ) -> impl Iterator { - core::iter::from_fn(move || { - iter.next() - .map(move |t| modes().map(move |mode| (mode, t.clone()))) - }) - .flat_map(|x| x) + iter::from_fn(move || i.next().map(|t| modes().map(move |mode| (mode, t.clone())))) + .flat_map(|x| x) } From 3144adc9cd3a18b78de29113dddfd3bdd0860275 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 10 May 2023 15:12:01 -0700 Subject: [PATCH 06/29] Clean up `Boolean` too --- src/bits/boolean.rs | 1821 ----------------- src/bits/boolean/and.rs | 136 ++ src/bits/boolean/cmp.rs | 1 + src/bits/boolean/eq.rs | 113 + src/bits/boolean/mod.rs | 1024 +++++++++ src/bits/boolean/not.rs | 99 + src/bits/boolean/or.rs | 135 ++ src/bits/boolean/test_utils.rs | 47 + src/bits/boolean/xor.rs | 135 ++ src/bits/uint/and.rs | 12 +- src/bits/uint/eq.rs | 9 +- src/bits/uint/mod.rs | 69 +- src/bits/uint/not.rs | 8 +- src/bits/uint/or.rs | 9 +- src/bits/uint/xor.rs | 9 +- src/bits/uint8.rs | 3 +- src/eq.rs | 2 +- src/fields/cubic_extension.rs | 6 +- src/fields/fp/mod.rs | 33 +- src/fields/nonnative/field_var.rs | 6 +- src/fields/quadratic_extension.rs | 6 +- .../curves/short_weierstrass/bls12/mod.rs | 2 +- src/groups/curves/short_weierstrass/mod.rs | 21 +- .../short_weierstrass/non_zero_affine.rs | 10 +- src/groups/curves/twisted_edwards/mod.rs | 10 +- 25 files changed, 1813 insertions(+), 1913 deletions(-) delete mode 100644 src/bits/boolean.rs create mode 100644 src/bits/boolean/and.rs create mode 100644 src/bits/boolean/cmp.rs create mode 100644 src/bits/boolean/eq.rs create mode 100644 src/bits/boolean/mod.rs create mode 100644 src/bits/boolean/not.rs create mode 100644 src/bits/boolean/or.rs create mode 100644 src/bits/boolean/test_utils.rs create mode 100644 src/bits/boolean/xor.rs diff --git a/src/bits/boolean.rs b/src/bits/boolean.rs deleted file mode 100644 index ff9e79ad..00000000 --- a/src/bits/boolean.rs +++ /dev/null @@ -1,1821 +0,0 @@ -use ark_ff::{BitIteratorBE, Field, PrimeField}; - -use crate::{fields::fp::FpVar, prelude::*, Assignment, ToConstraintFieldGadget, Vec}; -use ark_relations::r1cs::{ - ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, -}; -use core::borrow::Borrow; - -/// Represents a variable in the constraint system which is guaranteed -/// to be either zero or one. -/// -/// In general, one should prefer using `Boolean` instead of `AllocatedBool`, -/// as `Boolean` offers better support for constant values, and implements -/// more traits. -#[derive(Clone, Debug, Eq, PartialEq)] -#[must_use] -pub struct AllocatedBool { - variable: Variable, - cs: ConstraintSystemRef, -} - -pub(crate) fn bool_to_field(val: impl Borrow) -> F { - if *val.borrow() { - F::one() - } else { - F::zero() - } -} - -impl AllocatedBool { - /// Get the assigned value for `self`. - pub fn value(&self) -> Result { - let value = self.cs.assigned_value(self.variable).get()?; - if value.is_zero() { - Ok(false) - } else if value.is_one() { - Ok(true) - } else { - unreachable!("Incorrect value assigned: {:?}", value); - } - } - - /// Get the R1CS variable for `self`. - pub fn variable(&self) -> Variable { - self.variable - } - - /// Allocate a witness variable without a booleanity check. - fn new_witness_without_booleanity_check>( - cs: ConstraintSystemRef, - f: impl FnOnce() -> Result, - ) -> Result { - let variable = cs.new_witness_variable(|| f().map(bool_to_field))?; - Ok(Self { variable, cs }) - } - - /// Performs an XOR operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn xor(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? ^ b.value()?) - })?; - - // Constrain (a + a) * (b) = (a + b - c) - // Given that a and b are boolean constrained, if they - // are equal, the only solution for c is 0, and if they - // are different, the only solution for c is 1. - // - // ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b) = c - // (1 - (a * b)) * (1 - ((1 - a) * (1 - b))) = c - // (1 - ab) * (1 - (1 - a - b + ab)) = c - // (1 - ab) * (a + b - ab) = c - // a + b - ab - (a^2)b - (b^2)a + (a^2)(b^2) = c - // a + b - ab - ab - ab + ab = c - // a + b - 2ab = c - // -2a * b = c - a - b - // 2a * b = a + b - c - // (a + a) * b = a + b - c - self.cs.enforce_constraint( - lc!() + self.variable + self.variable, - lc!() + b.variable, - lc!() + self.variable + b.variable - result.variable, - )?; - - Ok(result) - } - - /// Performs an AND operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn and(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? & b.value()?) - })?; - - // Constrain (a) * (b) = (c), ensuring c is 1 iff - // a AND b are both 1. - self.cs.enforce_constraint( - lc!() + self.variable, - lc!() + b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } - - /// Performs an OR operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn or(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? | b.value()?) - })?; - - // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff - // a and b are both false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + Variable::One - self.variable, - lc!() + Variable::One - b.variable, - lc!() + Variable::One - result.variable, - )?; - - Ok(result) - } - - /// Calculates `a AND (NOT b)`. - #[tracing::instrument(target = "r1cs")] - pub fn and_not(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? & !b.value()?) - })?; - - // Constrain (a) * (1 - b) = (c), ensuring c is 1 iff - // a is true and b is false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + self.variable, - lc!() + Variable::One - b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } - - /// Calculates `(NOT a) AND (NOT b)`. - #[tracing::instrument(target = "r1cs")] - pub fn nor(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(!(self.value()? | b.value()?)) - })?; - - // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff - // a and b are both false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + Variable::One - self.variable, - lc!() + Variable::One - b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } -} - -impl AllocVar for AllocatedBool { - /// Produces a new variable of the appropriate kind - /// (instance or witness), with a booleanity check. - /// - /// N.B.: we could omit the booleanity check when allocating `self` - /// as a new public input, but that places an additional burden on - /// protocol designers. Better safe than sorry! - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - if mode == AllocationMode::Constant { - let variable = if *f()?.borrow() { - Variable::One - } else { - Variable::Zero - }; - Ok(Self { variable, cs }) - } else { - let variable = if mode == AllocationMode::Input { - cs.new_input_variable(|| f().map(bool_to_field))? - } else { - cs.new_witness_variable(|| f().map(bool_to_field))? - }; - - // Constrain: (1 - a) * a = 0 - // This constrains a to be either 0 or 1. - - cs.enforce_constraint(lc!() + Variable::One - variable, lc!() + variable, lc!())?; - - Ok(Self { variable, cs }) - } - } -} - -impl CondSelectGadget for AllocatedBool { - #[tracing::instrument(target = "r1cs")] - fn conditionally_select( - cond: &Boolean, - true_val: &Self, - false_val: &Self, - ) -> Result { - let res = Boolean::conditionally_select( - cond, - &true_val.clone().into(), - &false_val.clone().into(), - )?; - match res { - Boolean::Is(a) => Ok(a), - _ => unreachable!("Impossible"), - } - } -} - -/// Represents a boolean value in the constraint system which is guaranteed -/// to be either zero or one. -#[derive(Clone, Debug, Eq, PartialEq)] -#[must_use] -pub enum Boolean { - /// Existential view of the boolean variable. - Is(AllocatedBool), - /// Negated view of the boolean variable. - Not(AllocatedBool), - /// Constant (not an allocated variable). - Constant(bool), -} - -impl R1CSVar for Boolean { - type Value = bool; - - fn cs(&self) -> ConstraintSystemRef { - match self { - Self::Is(a) | Self::Not(a) => a.cs.clone(), - _ => ConstraintSystemRef::None, - } - } - - fn value(&self) -> Result { - match self { - Boolean::Constant(c) => Ok(*c), - Boolean::Is(ref v) => v.value(), - Boolean::Not(ref v) => v.value().map(|b| !b), - } - } -} - -impl Boolean { - /// The constant `true`. - pub const TRUE: Self = Boolean::Constant(true); - - /// The constant `false`. - pub const FALSE: Self = Boolean::Constant(false); - - /// Constructs a `LinearCombination` from `Self`'s variables according - /// to the following map. - /// - /// * `Boolean::Constant(true) => lc!() + Variable::One` - /// * `Boolean::Constant(false) => lc!()` - /// * `Boolean::Is(v) => lc!() + v.variable()` - /// * `Boolean::Not(v) => lc!() + Variable::One - v.variable()` - pub fn lc(&self) -> LinearCombination { - match self { - Boolean::Constant(false) => lc!(), - Boolean::Constant(true) => lc!() + Variable::One, - Boolean::Is(v) => v.variable().into(), - Boolean::Not(v) => lc!() + Variable::One - v.variable(), - } - } - - /// Constructs a `Boolean` vector from a slice of constant `u8`. - /// The `u8`s are decomposed in little-endian manner. - /// - /// This *does not* create any new variables or constraints. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let t = Boolean::::TRUE; - /// let f = Boolean::::FALSE; - /// - /// let bits = vec![f, t]; - /// let generated_bits = Boolean::constant_vec_from_bytes(&[2]); - /// bits[..2].enforce_equal(&generated_bits[..2])?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn constant_vec_from_bytes(values: &[u8]) -> Vec { - let mut bits = vec![]; - for byte in values { - for i in 0..8 { - bits.push(Self::Constant(((byte >> i) & 1u8) == 1u8)); - } - } - bits - } - - /// Constructs a constant `Boolean` with value `b`. - /// - /// This *does not* create any new variables or constraints. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_r1cs_std::prelude::*; - /// - /// let true_var = Boolean::::TRUE; - /// let false_var = Boolean::::FALSE; - /// - /// true_var.enforce_equal(&Boolean::constant(true))?; - /// false_var.enforce_equal(&Boolean::constant(false))?; - /// # Ok(()) - /// # } - /// ``` - pub fn constant(b: bool) -> Self { - Boolean::Constant(b) - } - - /// Negates `self`. - /// - /// This *does not* create any new variables or constraints. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.not().enforce_equal(&b)?; - /// b.not().enforce_equal(&a)?; - /// - /// a.not().enforce_equal(&Boolean::FALSE)?; - /// b.not().enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn not(&self) -> Self { - match *self { - Boolean::Constant(c) => Boolean::Constant(!c), - Boolean::Is(ref v) => Boolean::Not(v.clone()), - Boolean::Not(ref v) => Boolean::Is(v.clone()), - } - } - - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.xor(&b)?.enforce_equal(&Boolean::TRUE)?; - /// b.xor(&a)?.enforce_equal(&Boolean::TRUE)?; - /// - /// a.xor(&a)?.enforce_equal(&Boolean::FALSE)?; - /// b.xor(&b)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn xor<'a>(&'a self, other: &'a Self) -> Result { - use Boolean::*; - match (self, other) { - (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), - (&Constant(true), x) | (x, &Constant(true)) => Ok(x.not()), - // a XOR (NOT b) = NOT(a XOR b) - (is @ &Is(_), not @ &Not(_)) | (not @ &Not(_), is @ &Is(_)) => { - Ok(is.xor(¬.not())?.not()) - }, - // a XOR b = (NOT a) XOR (NOT b) - (&Is(ref a), &Is(ref b)) | (&Not(ref a), &Not(ref b)) => Ok(Is(a.xor(b)?)), - } - } - - /// Outputs `self | other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.or(&b)?.enforce_equal(&Boolean::TRUE)?; - /// b.or(&a)?.enforce_equal(&Boolean::TRUE)?; - /// - /// a.or(&a)?.enforce_equal(&Boolean::TRUE)?; - /// b.or(&b)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn or<'a>(&'a self, other: &'a Self) -> Result { - use Boolean::*; - match (self, other) { - (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), - (&Constant(true), _) | (_, &Constant(true)) => Ok(Constant(true)), - // a OR b = NOT ((NOT a) AND (NOT b)) - (a @ &Is(_), b @ &Not(_)) | (b @ &Not(_), a @ &Is(_)) | (b @ &Not(_), a @ &Not(_)) => { - Ok(a.not().and(&b.not())?.not()) - }, - (&Is(ref a), &Is(ref b)) => a.or(b).map(From::from), - } - } - - /// Outputs `self & other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.and(&a)?.enforce_equal(&Boolean::TRUE)?; - /// - /// a.and(&b)?.enforce_equal(&Boolean::FALSE)?; - /// b.and(&a)?.enforce_equal(&Boolean::FALSE)?; - /// b.and(&b)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn and<'a>(&'a self, other: &'a Self) -> Result { - use Boolean::*; - match (self, other) { - // false AND x is always false - (&Constant(false), _) | (_, &Constant(false)) => Ok(Constant(false)), - // true AND x is always x - (&Constant(true), x) | (x, &Constant(true)) => Ok(x.clone()), - // a AND (NOT b) - (&Is(ref is), &Not(ref not)) | (&Not(ref not), &Is(ref is)) => Ok(Is(is.and_not(not)?)), - // (NOT a) AND (NOT b) = a NOR b - (&Not(ref a), &Not(ref b)) => Ok(Is(a.nor(b)?)), - // a AND b - (&Is(ref a), &Is(ref b)) => Ok(Is(a.and(b)?)), - } - } - - /// Outputs `bits[0] & bits[1] & ... & bits.last().unwrap()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// Boolean::kary_and(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// Boolean::kary_and(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_and(bits: &[Self]) -> Result { - assert!(!bits.is_empty()); - let mut cur: Option = None; - for next in bits { - cur = if let Some(b) = cur { - Some(b.and(next)?) - } else { - Some(next.clone()) - }; - } - - Ok(cur.expect("should not be 0")) - } - - /// Outputs `bits[0] | bits[1] | ... | bits.last().unwrap()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// Boolean::kary_or(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_or(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_or(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_or(bits: &[Self]) -> Result { - assert!(!bits.is_empty()); - let mut cur: Option = None; - for next in bits { - cur = if let Some(b) = cur { - Some(b.or(next)?) - } else { - Some(next.clone()) - }; - } - - Ok(cur.expect("should not be 0")) - } - - /// Outputs `(bits[0] & bits[1] & ... & bits.last().unwrap()).not()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// Boolean::kary_nand(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_nand(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// Boolean::kary_nand(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_nand(bits: &[Self]) -> Result { - Ok(Self::kary_and(bits)?.not()) - } - - /// Enforces that `Self::kary_nand(bits).is_eq(&Boolean::TRUE)`. - /// - /// Informally, this means that at least one element in `bits` must be - /// `false`. - #[tracing::instrument(target = "r1cs")] - fn enforce_kary_nand(bits: &[Self]) -> Result<(), SynthesisError> { - use Boolean::*; - let r = Self::kary_nand(bits)?; - match r { - Constant(true) => Ok(()), - Constant(false) => Err(SynthesisError::AssignmentMissing), - Is(_) | Not(_) => { - r.cs() - .enforce_constraint(r.lc(), lc!() + Variable::One, lc!() + Variable::One) - }, - } - } - - /// Convert a little-endian bitwise representation of a field element to - /// `FpVar` - #[tracing::instrument(target = "r1cs", skip(bits))] - pub fn le_bits_to_fp_var(bits: &[Self]) -> Result, SynthesisError> - where - F: PrimeField, - { - // Compute the value of the `FpVar` variable via double-and-add. - let mut value = None; - let cs = bits.cs(); - // Assign a value only when `cs` is in setup mode, or if we are constructing - // a constant. - let should_construct_value = (!cs.is_in_setup_mode()) || bits.is_constant(); - if should_construct_value { - let bits = bits.iter().map(|b| b.value().unwrap()).collect::>(); - let bytes = bits - .chunks(8) - .map(|c| { - let mut value = 0u8; - for (i, &bit) in c.iter().enumerate() { - value += (bit as u8) << i; - } - value - }) - .collect::>(); - value = Some(F::from_le_bytes_mod_order(&bytes)); - } - - if bits.is_constant() { - Ok(FpVar::constant(value.unwrap())) - } else { - let mut power = F::one(); - // Compute a linear combination for the new field variable, again - // via double and add. - let mut combined_lc = LinearCombination::zero(); - bits.iter().for_each(|b| { - combined_lc = &combined_lc + (power, b.lc()); - power.double_in_place(); - }); - // Allocate the new variable as a SymbolicLc - let variable = cs.new_lc(combined_lc)?; - // If the number of bits is less than the size of the field, - // then we do not need to enforce that the element is less than - // the modulus. - if bits.len() >= F::MODULUS_BIT_SIZE as usize { - Self::enforce_in_field_le(bits)?; - } - Ok(crate::fields::fp::AllocatedFp::new(value, variable, cs.clone()).into()) - } - } - - /// Enforces that `bits`, when interpreted as a integer, is less than - /// `F::characteristic()`, That is, interpret bits as a little-endian - /// integer, and enforce that this integer is "in the field Z_p", where - /// `p = F::characteristic()` . - #[tracing::instrument(target = "r1cs")] - pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> { - // `bits` < F::characteristic() <==> `bits` <= F::characteristic() -1 - let mut b = F::characteristic().to_vec(); - assert_eq!(b[0] % 2, 1); - b[0] -= 1; // This works, because the LSB is one, so there's no borrows. - let run = Self::enforce_smaller_or_equal_than_le(bits, b)?; - - // We should always end in a "run" of zeros, because - // the characteristic is an odd prime. So, this should - // be empty. - assert!(run.is_empty()); - - Ok(()) - } - - /// Enforces that `bits` is less than or equal to `element`, - /// when both are interpreted as (little-endian) integers. - #[tracing::instrument(target = "r1cs", skip(element))] - pub fn enforce_smaller_or_equal_than_le<'a>( - bits: &[Self], - element: impl AsRef<[u64]>, - ) -> Result, SynthesisError> { - let b: &[u64] = element.as_ref(); - - let mut bits_iter = bits.iter().rev(); // Iterate in big-endian - - // Runs of ones in r - let mut last_run = Boolean::constant(true); - let mut current_run = vec![]; - - let mut element_num_bits = 0; - for _ in BitIteratorBE::without_leading_zeros(b) { - element_num_bits += 1; - } - - if bits.len() > element_num_bits { - let mut or_result = Boolean::constant(false); - for should_be_zero in &bits[element_num_bits..] { - or_result = or_result.or(should_be_zero)?; - let _ = bits_iter.next().unwrap(); - } - or_result.enforce_equal(&Boolean::constant(false))?; - } - - for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) { - if b { - // This is part of a run of ones. - current_run.push(a.clone()); - } else { - if !current_run.is_empty() { - // This is the start of a run of zeros, but we need - // to k-ary AND against `last_run` first. - - current_run.push(last_run.clone()); - last_run = Self::kary_and(¤t_run)?; - current_run.truncate(0); - } - - // If `last_run` is true, `a` must be false, or it would - // not be in the field. - // - // If `last_run` is false, `a` can be true or false. - // - // Ergo, at least one of `last_run` and `a` must be false. - Self::enforce_kary_nand(&[last_run.clone(), a.clone()])?; - } - } - assert!(bits_iter.next().is_none()); - - Ok(current_run) - } - - /// Conditionally selects one of `first` and `second` based on the value of - /// `self`: - /// - /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs - /// `second`. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?; - /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs", skip(first, second))] - pub fn select>( - &self, - first: &T, - second: &T, - ) -> Result { - T::conditionally_select(&self, first, second) - } -} - -impl From> for Boolean { - fn from(b: AllocatedBool) -> Self { - Boolean::Is(b) - } -} - -impl AllocVar for Boolean { - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - if mode == AllocationMode::Constant { - Ok(Boolean::Constant(*f()?.borrow())) - } else { - AllocatedBool::new_variable(cs, f, mode).map(Boolean::from) - } - } -} - -impl EqGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - // self | other | XNOR(self, other) | self == other - // -----|-------|-------------------|-------------- - // 0 | 0 | 1 | 1 - // 0 | 1 | 0 | 0 - // 1 | 0 | 0 | 0 - // 1 | 1 | 1 | 1 - Ok(self.xor(other)?.not()) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - use Boolean::*; - let one = Variable::One; - let difference = match (self, other) { - // 1 == 1; 0 == 0 - (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), - // false != true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), - // 1 - a - (Constant(true), Is(a)) | (Is(a), Constant(true)) => lc!() + one - a.variable(), - // a - 0 = a - (Constant(false), Is(a)) | (Is(a), Constant(false)) => lc!() + a.variable(), - // 1 - !a = 1 - (1 - a) = a - (Constant(true), Not(a)) | (Not(a), Constant(true)) => lc!() + a.variable(), - // !a - 0 = !a = 1 - a - (Constant(false), Not(a)) | (Not(a), Constant(false)) => lc!() + one - a.variable(), - // b - a, - (Is(a), Is(b)) => lc!() + b.variable() - a.variable(), - // !b - a = (1 - b) - a - (Is(a), Not(b)) | (Not(b), Is(a)) => lc!() + one - b.variable() - a.variable(), - // !b - !a = (1 - b) - (1 - a) = a - b, - (Not(a), Not(b)) => lc!() + a.variable() - b.variable(), - }; - - if condition != &Constant(false) { - let cs = self.cs().or(other.cs()).or(condition.cs()); - cs.enforce_constraint(lc!() + difference, condition.lc(), lc!())?; - } - Ok(()) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_not_equal( - &self, - other: &Self, - should_enforce: &Boolean, - ) -> Result<(), SynthesisError> { - use Boolean::*; - let one = Variable::One; - let difference = match (self, other) { - // 1 != 0; 0 != 1 - (Constant(true), Constant(false)) | (Constant(false), Constant(true)) => return Ok(()), - // false == false and true == true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), - // 1 - a - (Constant(true), Is(a)) | (Is(a), Constant(true)) => lc!() + one - a.variable(), - // a - 0 = a - (Constant(false), Is(a)) | (Is(a), Constant(false)) => lc!() + a.variable(), - // 1 - !a = 1 - (1 - a) = a - (Constant(true), Not(a)) | (Not(a), Constant(true)) => lc!() + a.variable(), - // !a - 0 = !a = 1 - a - (Constant(false), Not(a)) | (Not(a), Constant(false)) => lc!() + one - a.variable(), - // b - a, - (Is(a), Is(b)) => lc!() + b.variable() - a.variable(), - // !b - a = (1 - b) - a - (Is(a), Not(b)) | (Not(b), Is(a)) => lc!() + one - b.variable() - a.variable(), - // !b - !a = (1 - b) - (1 - a) = a - b, - (Not(a), Not(b)) => lc!() + a.variable() - b.variable(), - }; - - if should_enforce != &Constant(false) { - let cs = self.cs().or(other.cs()).or(should_enforce.cs()); - cs.enforce_constraint(difference, should_enforce.lc(), should_enforce.lc())?; - } - Ok(()) - } -} - -impl ToBytesGadget for Boolean { - /// Outputs `1u8` if `self` is true, and `0u8` otherwise. - #[tracing::instrument(target = "r1cs")] - fn to_bytes(&self) -> Result>, SynthesisError> { - let value = self.value().map(u8::from).ok(); - let mut bits = [Boolean::FALSE; 8]; - bits[0] = self.clone(); - Ok(vec![UInt8 { bits, value }]) - } -} - -impl ToConstraintFieldGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn to_constraint_field(&self) -> Result>, SynthesisError> { - let var = From::from(self.clone()); - Ok(vec![var]) - } -} - -impl CondSelectGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn conditionally_select( - cond: &Boolean, - true_val: &Self, - false_val: &Self, - ) -> Result { - use Boolean::*; - match cond { - Constant(true) => Ok(true_val.clone()), - Constant(false) => Ok(false_val.clone()), - cond @ Not(_) => Self::conditionally_select(&cond.not(), false_val, true_val), - cond @ Is(_) => match (true_val, false_val) { - (x, &Constant(false)) => cond.and(x), - (&Constant(false), x) => cond.not().and(x), - (&Constant(true), x) => cond.or(x), - (x, &Constant(true)) => cond.not().or(x), - (a, b) => { - let cs = cond.cs(); - let result: Boolean = - AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || { - let cond = cond.value()?; - Ok(if cond { a.value()? } else { b.value()? }) - })? - .into(); - // a = self; b = other; c = cond; - // - // r = c * a + (1 - c) * b - // r = b + c * (a - b) - // c * (a - b) = r - b - // - // If a, b, cond are all boolean, so is r. - // - // self | other | cond | result - // -----|-------|---------------- - // 0 | 0 | 1 | 0 - // 0 | 1 | 1 | 0 - // 1 | 0 | 1 | 1 - // 1 | 1 | 1 | 1 - // 0 | 0 | 0 | 0 - // 0 | 1 | 0 | 1 - // 1 | 0 | 0 | 0 - // 1 | 1 | 0 | 1 - cs.enforce_constraint( - cond.lc(), - lc!() + a.lc() - b.lc(), - lc!() + result.lc() - b.lc(), - )?; - - Ok(result) - }, - }, - } - } -} - -#[cfg(test)] -mod test { - use super::{AllocatedBool, Boolean}; - use crate::prelude::*; - use ark_ff::{BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand, Zero}; - use ark_relations::r1cs::{ConstraintSystem, Namespace, SynthesisError}; - use ark_test_curves::bls12_381::Fr; - - #[test] - fn test_boolean_to_byte() -> Result<(), SynthesisError> { - for val in [true, false].iter() { - let cs = ConstraintSystem::::new_ref(); - let a = Boolean::new_witness(cs.clone(), || Ok(*val))?; - let bytes = a.to_bytes()?; - assert_eq!(bytes.len(), 1); - let byte = &bytes[0]; - assert_eq!(byte.value()?, *val as u8); - - for (i, bit) in byte.bits.iter().enumerate() { - assert_eq!(bit.value()?, (byte.value()? >> i) & 1 == 1); - } - } - Ok(()) - } - - #[test] - fn test_xor() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::xor(&a, &b)?; - assert_eq!(c.value()?, a_val ^ b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val ^ b_val)); - } - } - Ok(()) - } - - #[test] - fn test_or() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::or(&a, &b)?; - assert_eq!(c.value()?, a_val | b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val | b_val)); - } - } - Ok(()) - } - - #[test] - fn test_and() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::and(&a, &b)?; - assert_eq!(c.value()?, a_val & b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val & b_val)); - } - } - Ok(()) - } - - #[test] - fn test_and_not() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::and_not(&a, &b)?; - assert_eq!(c.value()?, a_val & !b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val & !b_val)); - } - } - Ok(()) - } - - #[test] - fn test_nor() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::nor(&a, &b)?; - assert_eq!(c.value()?, !a_val & !b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (!a_val & !b_val)); - } - } - Ok(()) - } - - #[test] - fn test_enforce_equal() -> Result<(), SynthesisError> { - for a_bool in [false, true].iter().cloned() { - for b_bool in [false, true].iter().cloned() { - for a_neg in [false, true].iter().cloned() { - for b_neg in [false, true].iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let mut a = Boolean::new_witness(cs.clone(), || Ok(a_bool))?; - let mut b = Boolean::new_witness(cs.clone(), || Ok(b_bool))?; - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); - } - - a.enforce_equal(&b)?; - - assert_eq!( - cs.is_satisfied().unwrap(), - (a_bool ^ a_neg) == (b_bool ^ b_neg) - ); - } - } - } - } - Ok(()) - } - - #[test] - fn test_conditional_enforce_equal() -> Result<(), SynthesisError> { - for a_bool in [false, true].iter().cloned() { - for b_bool in [false, true].iter().cloned() { - for a_neg in [false, true].iter().cloned() { - for b_neg in [false, true].iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - // First test if constraint system is satisfied - // when we do want to enforce the condition. - let mut a = Boolean::new_witness(cs.clone(), || Ok(a_bool))?; - let mut b = Boolean::new_witness(cs.clone(), || Ok(b_bool))?; - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); - } - - a.conditional_enforce_equal(&b, &Boolean::constant(true))?; - - assert_eq!( - cs.is_satisfied().unwrap(), - (a_bool ^ a_neg) == (b_bool ^ b_neg) - ); - - // Now test if constraint system is satisfied even - // when we don't want to enforce the condition. - let cs = ConstraintSystem::::new_ref(); - - let mut a = Boolean::new_witness(cs.clone(), || Ok(a_bool))?; - let mut b = Boolean::new_witness(cs.clone(), || Ok(b_bool))?; - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); - } - - let false_cond = - Boolean::new_witness(ark_relations::ns!(cs, "cond"), || Ok(false))?; - a.conditional_enforce_equal(&b, &false_cond)?; - - assert!(cs.is_satisfied().unwrap()); - } - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_negation() -> Result<(), SynthesisError> { - let cs = ConstraintSystem::::new_ref(); - - let mut b = Boolean::new_witness(cs.clone(), || Ok(true))?; - assert!(matches!(b, Boolean::Is(_))); - - b = b.not(); - assert!(matches!(b, Boolean::Not(_))); - - b = b.not(); - assert!(matches!(b, Boolean::Is(_))); - - b = Boolean::Constant(true); - assert!(matches!(b, Boolean::Constant(true))); - - b = b.not(); - assert!(matches!(b, Boolean::Constant(false))); - - b = b.not(); - assert!(matches!(b, Boolean::Constant(true))); - Ok(()) - } - - #[derive(Eq, PartialEq, Copy, Clone, Debug)] - enum OpType { - True, - False, - AllocatedTrue, - AllocatedFalse, - NegatedAllocatedTrue, - NegatedAllocatedFalse, - } - - const VARIANTS: [OpType; 6] = [ - OpType::True, - OpType::False, - OpType::AllocatedTrue, - OpType::AllocatedFalse, - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - ]; - - fn construct( - ns: Namespace, - operand: OpType, - ) -> Result, SynthesisError> { - let cs = ns.cs(); - - let b = match operand { - OpType::True => Boolean::constant(true), - OpType::False => Boolean::constant(false), - OpType::AllocatedTrue => Boolean::new_witness(cs, || Ok(true))?, - OpType::AllocatedFalse => Boolean::new_witness(cs, || Ok(false))?, - OpType::NegatedAllocatedTrue => Boolean::new_witness(cs, || Ok(true))?.not(), - OpType::NegatedAllocatedFalse => Boolean::new_witness(cs, || Ok(false))?.not(), - }; - Ok(b) - } - - #[test] - fn test_boolean_xor() -> Result<(), SynthesisError> { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = Boolean::xor(&a, &b)?; - - assert!(cs.is_satisfied().unwrap()); - - match (first_operand, second_operand, c) { - (OpType::True, OpType::True, Boolean::Constant(false)) => (), - (OpType::True, OpType::False, Boolean::Constant(true)) => (), - (OpType::True, OpType::AllocatedTrue, Boolean::Not(_)) => (), - (OpType::True, OpType::AllocatedFalse, Boolean::Not(_)) => (), - (OpType::True, OpType::NegatedAllocatedTrue, Boolean::Is(_)) => (), - (OpType::True, OpType::NegatedAllocatedFalse, Boolean::Is(_)) => (), - - (OpType::False, OpType::True, Boolean::Constant(true)) => (), - (OpType::False, OpType::False, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedTrue, Boolean::Is(_)) => (), - (OpType::False, OpType::AllocatedFalse, Boolean::Is(_)) => (), - (OpType::False, OpType::NegatedAllocatedTrue, Boolean::Not(_)) => (), - (OpType::False, OpType::NegatedAllocatedFalse, Boolean::Not(_)) => (), - - (OpType::AllocatedTrue, OpType::True, Boolean::Not(_)) => (), - (OpType::AllocatedTrue, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedFalse, OpType::True, Boolean::Not(_)) => (), - (OpType::AllocatedFalse, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::AllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedTrue, OpType::True, Boolean::Is(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::NegatedAllocatedTrue, OpType::AllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - - (OpType::NegatedAllocatedFalse, OpType::True, Boolean::Is(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::AllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - _ => unreachable!(), - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_cond_select() -> Result<(), SynthesisError> { - for condition in VARIANTS.iter().cloned() { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let cond = construct(ark_relations::ns!(cs, "cond"), condition)?; - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = cond.select(&a, &b)?; - - assert!( - cs.is_satisfied().unwrap(), - "failed with operands: cond: {:?}, a: {:?}, b: {:?}", - condition, - first_operand, - second_operand, - ); - assert_eq!( - c.value()?, - if cond.value()? { - a.value()? - } else { - b.value()? - } - ); - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_or() -> Result<(), SynthesisError> { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = a.or(&b)?; - - assert!(cs.is_satisfied().unwrap()); - - match (first_operand, second_operand, c.clone()) { - (OpType::True, OpType::True, Boolean::Constant(true)) => (), - (OpType::True, OpType::False, Boolean::Constant(true)) => (), - (OpType::True, OpType::AllocatedTrue, Boolean::Constant(true)) => (), - (OpType::True, OpType::AllocatedFalse, Boolean::Constant(true)) => (), - (OpType::True, OpType::NegatedAllocatedTrue, Boolean::Constant(true)) => (), - (OpType::True, OpType::NegatedAllocatedFalse, Boolean::Constant(true)) => (), - - (OpType::False, OpType::True, Boolean::Constant(true)) => (), - (OpType::False, OpType::False, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedTrue, Boolean::Is(_)) => (), - (OpType::False, OpType::AllocatedFalse, Boolean::Is(_)) => (), - (OpType::False, OpType::NegatedAllocatedTrue, Boolean::Not(_)) => (), - (OpType::False, OpType::NegatedAllocatedFalse, Boolean::Not(_)) => (), - - (OpType::AllocatedTrue, OpType::True, Boolean::Constant(true)) => (), - (OpType::AllocatedTrue, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::AllocatedFalse, OpType::True, Boolean::Constant(true)) => (), - (OpType::AllocatedFalse, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::AllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedTrue, OpType::True, Boolean::Constant(true)) => (), - (OpType::NegatedAllocatedTrue, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::NegatedAllocatedTrue, OpType::AllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedTrue, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedFalse, OpType::True, Boolean::Constant(true)) => (), - (OpType::NegatedAllocatedFalse, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::AllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedTrue, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - - _ => panic!( - "this should never be encountered, in case: (a = {:?}, b = {:?}, c = {:?})", - a, b, c - ), - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_and() -> Result<(), SynthesisError> { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = a.and(&b)?; - - assert!(cs.is_satisfied().unwrap()); - - match (first_operand, second_operand, c) { - (OpType::True, OpType::True, Boolean::Constant(true)) => (), - (OpType::True, OpType::False, Boolean::Constant(false)) => (), - (OpType::True, OpType::AllocatedTrue, Boolean::Is(_)) => (), - (OpType::True, OpType::AllocatedFalse, Boolean::Is(_)) => (), - (OpType::True, OpType::NegatedAllocatedTrue, Boolean::Not(_)) => (), - (OpType::True, OpType::NegatedAllocatedFalse, Boolean::Not(_)) => (), - - (OpType::False, OpType::True, Boolean::Constant(false)) => (), - (OpType::False, OpType::False, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedTrue, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedFalse, Boolean::Constant(false)) => (), - (OpType::False, OpType::NegatedAllocatedTrue, Boolean::Constant(false)) => (), - (OpType::False, OpType::NegatedAllocatedFalse, Boolean::Constant(false)) => (), - - (OpType::AllocatedTrue, OpType::True, Boolean::Is(_)) => (), - (OpType::AllocatedTrue, OpType::False, Boolean::Constant(false)) => (), - (OpType::AllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - - (OpType::AllocatedFalse, OpType::True, Boolean::Is(_)) => (), - (OpType::AllocatedFalse, OpType::False, Boolean::Constant(false)) => (), - (OpType::AllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedTrue, OpType::True, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::False, Boolean::Constant(false)) => (), - (OpType::NegatedAllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::NegatedAllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedFalse, OpType::True, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::False, Boolean::Constant(false)) => (), - (OpType::NegatedAllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - (OpType::NegatedAllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - - _ => { - panic!( - "unexpected behavior at {:?} AND {:?}", - first_operand, second_operand - ); - }, - } - } - } - Ok(()) - } - - #[test] - fn test_smaller_than_or_equal_to() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - for _ in 0..1000 { - let mut r = Fr::rand(&mut rng); - let mut s = Fr::rand(&mut rng); - if r > s { - core::mem::swap(&mut r, &mut s) - } - - let cs = ConstraintSystem::::new_ref(); - - let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); - let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; - Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; - - assert!(cs.is_satisfied().unwrap()); - } - - for _ in 0..1000 { - let r = Fr::rand(&mut rng); - if r == -Fr::one() { - continue; - } - let s = r + Fr::one(); - let s2 = r.double(); - let cs = ConstraintSystem::::new_ref(); - - let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); - let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; - Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; - if r < s2 { - Boolean::enforce_smaller_or_equal_than_le(&bits, s2.into_bigint())?; - } - - assert!(cs.is_satisfied().unwrap()); - } - Ok(()) - } - - #[test] - fn test_enforce_in_field() -> Result<(), SynthesisError> { - { - let cs = ConstraintSystem::::new_ref(); - - let mut bits = vec![]; - for b in BitIteratorBE::new(Fr::characteristic()).skip(1) { - bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); - } - bits.reverse(); - - Boolean::enforce_in_field_le(&bits)?; - - assert!(!cs.is_satisfied().unwrap()); - } - - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let r = Fr::rand(&mut rng); - let cs = ConstraintSystem::::new_ref(); - - let mut bits = vec![]; - for b in BitIteratorBE::new(r.into_bigint()).skip(1) { - bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); - } - bits.reverse(); - - Boolean::enforce_in_field_le(&bits)?; - - assert!(cs.is_satisfied().unwrap()); - } - Ok(()) - } - - #[test] - fn test_enforce_nand() -> Result<(), SynthesisError> { - { - let cs = ConstraintSystem::::new_ref(); - - assert!( - Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), false)?]).is_ok() - ); - assert!( - Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), true)?]).is_err() - ); - } - - for i in 1..5 { - // with every possible assignment for them - for mut b in 0..(1 << i) { - // with every possible negation - for mut n in 0..(1 << i) { - let cs = ConstraintSystem::::new_ref(); - - let mut expected = true; - - let mut bits = vec![]; - for _ in 0..i { - expected &= b & 1 == 1; - - let bit = if n & 1 == 1 { - Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))? - } else { - Boolean::new_witness(cs.clone(), || Ok(b & 1 == 0))?.not() - }; - bits.push(bit); - - b >>= 1; - n >>= 1; - } - - let expected = !expected; - - Boolean::enforce_kary_nand(&bits)?; - - if expected { - assert!(cs.is_satisfied().unwrap()); - } else { - assert!(!cs.is_satisfied().unwrap()); - } - } - } - } - Ok(()) - } - - #[test] - fn test_kary_and() -> Result<(), SynthesisError> { - // test different numbers of operands - for i in 1..15 { - // with every possible assignment for them - for mut b in 0..(1 << i) { - let cs = ConstraintSystem::::new_ref(); - - let mut expected = true; - - let mut bits = vec![]; - for _ in 0..i { - expected &= b & 1 == 1; - bits.push(Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))?); - b >>= 1; - } - - let r = Boolean::kary_and(&bits)?; - - assert!(cs.is_satisfied().unwrap()); - - if let Boolean::Is(ref r) = r { - assert_eq!(r.value()?, expected); - } - } - } - Ok(()) - } - - #[test] - fn test_bits_to_fp() -> Result<(), SynthesisError> { - use AllocationMode::*; - let rng = &mut ark_std::test_rng(); - let cs = ConstraintSystem::::new_ref(); - - let modes = [Input, Witness, Constant]; - for &mode in modes.iter() { - for _ in 0..1000 { - let f = Fr::rand(rng); - let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); - let bits: Vec<_> = - AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; - let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; - let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; - claimed_f.enforce_equal(&f)?; - } - - for _ in 0..1000 { - let f = Fr::from(u64::rand(rng)); - let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); - let bits: Vec<_> = - AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; - let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; - let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; - claimed_f.enforce_equal(&f)?; - } - assert!(cs.is_satisfied().unwrap()); - } - - Ok(()) - } -} diff --git a/src/bits/boolean/and.rs b/src/bits/boolean/and.rs new file mode 100644 index 00000000..7c3263a3 --- /dev/null +++ b/src/bits/boolean/and.rs @@ -0,0 +1,136 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitAnd, ops::BitAndAssign}; + +use super::Boolean; + +impl Boolean { + fn _and(&self, other: &Self) -> Result { + use Boolean::*; + match (self, other) { + // false AND x is always false + (&Constant(false), _) | (_, &Constant(false)) => Ok(Constant(false)), + // true AND x is always x + (&Constant(true), x) | (x, &Constant(true)) => Ok(x.clone()), + (Var(ref x), Var(ref y)) => Ok(Var(x.and(y)?)), + } + } +} + +impl<'a, F: Field> BitAnd for &'a Boolean { + type Output = Boolean; + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// a.and(&a)?.enforce_equal(&Boolean::TRUE)?; + /// + /// a.and(&b)?.enforce_equal(&Boolean::FALSE)?; + /// b.and(&a)?.enforce_equal(&Boolean::FALSE)?; + /// b.and(&b)?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(other).unwrap() + } +} + +impl<'a, F: Field> BitAnd<&'a Self> for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: &Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl<'a, F: Field> BitAnd> for &'a Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Boolean) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAnd for Boolean { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAndAssign for Boolean { + /// Sets `self = self & other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: Self) { + let result = self._and(&other).unwrap(); + *self = result; + } +} + +impl<'a, F: Field> BitAndAssign<&'a Self> for Boolean { + /// Sets `self = self & other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: &'a Self) { + let result = self._and(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn and() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a & &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() & b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/bits/boolean/cmp.rs b/src/bits/boolean/cmp.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/bits/boolean/cmp.rs @@ -0,0 +1 @@ + diff --git a/src/bits/boolean/eq.rs b/src/bits/boolean/eq.rs new file mode 100644 index 00000000..5f7d30fc --- /dev/null +++ b/src/bits/boolean/eq.rs @@ -0,0 +1,113 @@ +use ark_relations::r1cs::SynthesisError; + +use crate::boolean::Boolean; +use crate::eq::EqGadget; + +use super::*; + +impl EqGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + // self | other | XNOR(self, other) | self == other + // -----|-------|-------------------|-------------- + // 0 | 0 | 1 | 1 + // 0 | 1 | 0 | 0 + // 1 | 0 | 0 | 0 + // 1 | 1 | 1 | 1 + Ok(!(self ^ other)) + } + + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + use Boolean::*; + let one = Variable::One; + let difference = match (self, other) { + // 1 == 1; 0 == 0 + (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), + // false != true + (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), + // 1 - a + (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc!() + one - a.variable(), + // a - 0 = a + (Constant(false), Var(a)) | (Var(a), Constant(false)) => lc!() + a.variable(), + // b - a, + (Var(a), Var(b)) => lc!() + b.variable() - a.variable(), + }; + + if condition != &Constant(false) { + let cs = self.cs().or(other.cs()).or(condition.cs()); + cs.enforce_constraint(lc!() + difference, condition.lc(), lc!())?; + } + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + use Boolean::*; + let one = Variable::One; + let difference = match (self, other) { + // 1 != 0; 0 != 1 + (Constant(true), Constant(false)) | (Constant(false), Constant(true)) => return Ok(()), + // false == false and true == true + (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), + // 1 - a + (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc!() + one - a.variable(), + // a - 0 = a + (Constant(false), Var(a)) | (Var(a), Constant(false)) => lc!() + a.variable(), + // b - a, + (Var(a), Var(b)) => lc!() + b.variable() - a.variable(), + }; + + if should_enforce != &Constant(false) { + let cs = self.cs().or(other.cs()).or(should_enforce.cs()); + cs.enforce_constraint(difference, should_enforce.lc(), should_enforce.lc())?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn or() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() == b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/bits/boolean/mod.rs b/src/bits/boolean/mod.rs new file mode 100644 index 00000000..ee0c3eb0 --- /dev/null +++ b/src/bits/boolean/mod.rs @@ -0,0 +1,1024 @@ +use ark_ff::{BitIteratorBE, Field, PrimeField}; + +use crate::{fields::fp::FpVar, prelude::*, Assignment, ToConstraintFieldGadget, Vec}; +use ark_relations::r1cs::{ + ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, +}; +use core::borrow::Borrow; + +mod and; +mod cmp; +mod eq; +mod not; +mod or; +mod xor; + +#[cfg(test)] +mod test_utils; + +/// Represents a variable in the constraint system which is guaranteed +/// to be either zero or one. +/// +/// In general, one should prefer using `Boolean` instead of `AllocatedBool`, +/// as `Boolean` offers better support for constant values, and implements +/// more traits. +#[derive(Clone, Debug, Eq, PartialEq)] +#[must_use] +pub struct AllocatedBool { + variable: Variable, + cs: ConstraintSystemRef, +} + +pub(crate) fn bool_to_field(val: impl Borrow) -> F { + if *val.borrow() { + F::one() + } else { + F::zero() + } +} + +impl AllocatedBool { + /// Get the assigned value for `self`. + pub fn value(&self) -> Result { + let value = self.cs.assigned_value(self.variable).get()?; + if value.is_zero() { + Ok(false) + } else if value.is_one() { + Ok(true) + } else { + unreachable!("Incorrect value assigned: {:?}", value); + } + } + + /// Get the R1CS variable for `self`. + pub fn variable(&self) -> Variable { + self.variable + } + + /// Allocate a witness variable without a booleanity check. + fn new_witness_without_booleanity_check>( + cs: ConstraintSystemRef, + f: impl FnOnce() -> Result, + ) -> Result { + let variable = cs.new_witness_variable(|| f().map(bool_to_field))?; + Ok(Self { variable, cs }) + } + + /// Performs an XOR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn not(&self) -> Result { + let variable = self.cs.new_lc(lc!() + Variable::One - self.variable)?; + Ok(Self { + variable, + cs: self.cs.clone(), + }) + } + + /// Performs an XOR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn xor(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? ^ b.value()?) + })?; + + // Constrain (a + a) * (b) = (a + b - c) + // Given that a and b are boolean constrained, if they + // are equal, the only solution for c is 0, and if they + // are different, the only solution for c is 1. + // + // ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b) = c + // (1 - (a * b)) * (1 - ((1 - a) * (1 - b))) = c + // (1 - ab) * (1 - (1 - a - b + ab)) = c + // (1 - ab) * (a + b - ab) = c + // a + b - ab - (a^2)b - (b^2)a + (a^2)(b^2) = c + // a + b - ab - ab - ab + ab = c + // a + b - 2ab = c + // -2a * b = c - a - b + // 2a * b = a + b - c + // (a + a) * b = a + b - c + self.cs.enforce_constraint( + lc!() + self.variable + self.variable, + lc!() + b.variable, + lc!() + self.variable + b.variable - result.variable, + )?; + + Ok(result) + } + + /// Performs an AND operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn and(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? & b.value()?) + })?; + + // Constrain (a) * (b) = (c), ensuring c is 1 iff + // a AND b are both 1. + self.cs.enforce_constraint( + lc!() + self.variable, + lc!() + b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } + + /// Performs an OR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn or(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? | b.value()?) + })?; + + // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff + // a and b are both false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + Variable::One - self.variable, + lc!() + Variable::One - b.variable, + lc!() + Variable::One - result.variable, + )?; + + Ok(result) + } + + /// Calculates `a AND (NOT b)`. + #[tracing::instrument(target = "r1cs")] + pub fn and_not(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? & !b.value()?) + })?; + + // Constrain (a) * (1 - b) = (c), ensuring c is 1 iff + // a is true and b is false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + self.variable, + lc!() + Variable::One - b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } + + /// Calculates `(NOT a) AND (NOT b)`. + #[tracing::instrument(target = "r1cs")] + pub fn nor(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(!(self.value()? | b.value()?)) + })?; + + // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff + // a and b are both false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + Variable::One - self.variable, + lc!() + Variable::One - b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } +} + +impl AllocVar for AllocatedBool { + /// Produces a new variable of the appropriate kind + /// (instance or witness), with a booleanity check. + /// + /// N.B.: we could omit the booleanity check when allocating `self` + /// as a new public input, but that places an additional burden on + /// protocol designers. Better safe than sorry! + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + if mode == AllocationMode::Constant { + let variable = if *f()?.borrow() { + Variable::One + } else { + Variable::Zero + }; + Ok(Self { variable, cs }) + } else { + let variable = if mode == AllocationMode::Input { + cs.new_input_variable(|| f().map(bool_to_field))? + } else { + cs.new_witness_variable(|| f().map(bool_to_field))? + }; + + // Constrain: (1 - a) * a = 0 + // This constrains a to be either 0 or 1. + + cs.enforce_constraint(lc!() + Variable::One - variable, lc!() + variable, lc!())?; + + Ok(Self { variable, cs }) + } + } +} + +impl CondSelectGadget for AllocatedBool { + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_val: &Self, + false_val: &Self, + ) -> Result { + let res = Boolean::conditionally_select( + cond, + &true_val.clone().into(), + &false_val.clone().into(), + )?; + match res { + Boolean::Var(a) => Ok(a), + _ => unreachable!("Impossible"), + } + } +} + +/// Represents a boolean value in the constraint system which is guaranteed +/// to be either zero or one. +#[derive(Clone, Debug, Eq, PartialEq)] +#[must_use] +pub enum Boolean { + Var(AllocatedBool), + Constant(bool), +} + +impl R1CSVar for Boolean { + type Value = bool; + + fn cs(&self) -> ConstraintSystemRef { + match self { + Self::Var(a) => a.cs.clone(), + _ => ConstraintSystemRef::None, + } + } + + fn value(&self) -> Result { + match self { + Boolean::Constant(c) => Ok(*c), + Boolean::Var(ref v) => v.value(), + } + } +} + +impl Boolean { + /// The constant `true`. + pub const TRUE: Self = Boolean::Constant(true); + + /// The constant `false`. + pub const FALSE: Self = Boolean::Constant(false); + + /// Constructs a `Boolean` vector from a slice of constant `u8`. + /// The `u8`s are decomposed in little-endian manner. + /// + /// This *does not* create any new variables or constraints. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let t = Boolean::::TRUE; + /// let f = Boolean::::FALSE; + /// + /// let bits = vec![f, t]; + /// let generated_bits = Boolean::constant_vec_from_bytes(&[2]); + /// bits[..2].enforce_equal(&generated_bits[..2])?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + pub fn constant_vec_from_bytes(values: &[u8]) -> Vec { + let mut bits = vec![]; + for byte in values { + for i in 0..8 { + bits.push(Self::Constant(((byte >> i) & 1u8) == 1u8)); + } + } + bits + } + + /// Constructs a constant `Boolean` with value `b`. + /// + /// This *does not* create any new variables or constraints. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_r1cs_std::prelude::*; + /// + /// let true_var = Boolean::::TRUE; + /// let false_var = Boolean::::FALSE; + /// + /// true_var.enforce_equal(&Boolean::constant(true))?; + /// false_var.enforce_equal(&Boolean::constant(false))?; + /// # Ok(()) + /// # } + /// ``` + pub fn constant(b: bool) -> Self { + Boolean::Constant(b) + } + + /// Constructs a `LinearCombination` from `Self`'s variables according + /// to the following map. + /// + /// * `Boolean::TRUE => lc!() + Variable::One` + /// * `Boolean::FALSE => lc!()` + /// * `Boolean::Var(v) => lc!() + v.variable()` + pub fn lc(&self) -> LinearCombination { + match self { + &Boolean::Constant(false) => lc!(), + &Boolean::Constant(true) => lc!() + Variable::One, + Boolean::Var(v) => v.variable().into(), + } + } + + /// Outputs `bits[0] & bits[1] & ... & bits.last().unwrap()`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// Boolean::kary_and(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// Boolean::kary_and(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_and(bits: &[Self]) -> Result { + assert!(!bits.is_empty()); + let mut cur: Option = None; + for next in bits { + cur = if let Some(b) = cur { + Some(b & next) + } else { + Some(next.clone()) + }; + } + + Ok(cur.expect("should not be 0")) + } + + /// Outputs `bits[0] | bits[1] | ... | bits.last().unwrap()`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// Boolean::kary_or(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_or(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_or(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_or(bits: &[Self]) -> Result { + assert!(!bits.is_empty()); + let mut cur: Option = None; + for next in bits { + cur = if let Some(b) = cur { + Some(b | next) + } else { + Some(next.clone()) + }; + } + + Ok(cur.expect("should not be 0")) + } + + /// Outputs `(bits[0] & bits[1] & ... & bits.last().unwrap()).not()`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// Boolean::kary_nand(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_nand(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// Boolean::kary_nand(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_nand(bits: &[Self]) -> Result { + Ok(!Self::kary_and(bits)?) + } + + /// Enforces that `Self::kary_nand(bits).is_eq(&Boolean::TRUE)`. + /// + /// Informally, this means that at least one element in `bits` must be + /// `false`. + #[tracing::instrument(target = "r1cs")] + fn enforce_kary_nand(bits: &[Self]) -> Result<(), SynthesisError> { + Self::kary_nand(bits)?.enforce_equal(&Boolean::TRUE) + } + + /// Convert a little-endian bitwise representation of a field element to + /// `FpVar` + #[tracing::instrument(target = "r1cs", skip(bits))] + pub fn le_bits_to_fp_var(bits: &[Self]) -> Result, SynthesisError> + where + F: PrimeField, + { + // Compute the value of the `FpVar` variable via double-and-add. + let mut value = None; + let cs = bits.cs(); + // Assign a value only when `cs` is in setup mode, or if we are constructing + // a constant. + let should_construct_value = (!cs.is_in_setup_mode()) || bits.is_constant(); + if should_construct_value { + let bits = bits.iter().map(|b| b.value().unwrap()).collect::>(); + let bytes = bits + .chunks(8) + .map(|c| { + let mut value = 0u8; + for (i, &bit) in c.iter().enumerate() { + value += (bit as u8) << i; + } + value + }) + .collect::>(); + value = Some(F::from_le_bytes_mod_order(&bytes)); + } + + if bits.is_constant() { + Ok(FpVar::constant(value.unwrap())) + } else { + let mut power = F::one(); + // Compute a linear combination for the new field variable, again + // via double and add. + + let combined = bits + .iter() + .map(|b| { + let result = FpVar::from(b.clone()) * power; + power.double_in_place(); + result + }) + .sum(); + // If the number of bits is less than the size of the field, + // then we do not need to enforce that the element is less than + // the modulus. + if bits.len() >= F::MODULUS_BIT_SIZE as usize { + Self::enforce_in_field_le(bits)?; + } + Ok(combined) + } + } + + /// Enforces that `bits`, when interpreted as a integer, is less than + /// `F::characteristic()`, That is, interpret bits as a little-endian + /// integer, and enforce that this integer is "in the field Z_p", where + /// `p = F::characteristic()` . + #[tracing::instrument(target = "r1cs")] + pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> { + // `bits` < F::characteristic() <==> `bits` <= F::characteristic() -1 + let mut b = F::characteristic().to_vec(); + assert_eq!(b[0] % 2, 1); + b[0] -= 1; // This works, because the LSB is one, so there's no borrows. + let run = Self::enforce_smaller_or_equal_than_le(bits, b)?; + + // We should always end in a "run" of zeros, because + // the characteristic is an odd prime. So, this should + // be empty. + assert!(run.is_empty()); + + Ok(()) + } + + /// Enforces that `bits` is less than or equal to `element`, + /// when both are interpreted as (little-endian) integers. + #[tracing::instrument(target = "r1cs", skip(element))] + pub fn enforce_smaller_or_equal_than_le<'a>( + bits: &[Self], + element: impl AsRef<[u64]>, + ) -> Result, SynthesisError> { + let b: &[u64] = element.as_ref(); + + let mut bits_iter = bits.iter().rev(); // Iterate in big-endian + + // Runs of ones in r + let mut last_run = Boolean::constant(true); + let mut current_run = vec![]; + + let mut element_num_bits = 0; + for _ in BitIteratorBE::without_leading_zeros(b) { + element_num_bits += 1; + } + + if bits.len() > element_num_bits { + let mut or_result = Boolean::constant(false); + for should_be_zero in &bits[element_num_bits..] { + or_result |= should_be_zero; + let _ = bits_iter.next().unwrap(); + } + or_result.enforce_equal(&Boolean::constant(false))?; + } + + for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) { + if b { + // This is part of a run of ones. + current_run.push(a.clone()); + } else { + if !current_run.is_empty() { + // This is the start of a run of zeros, but we need + // to k-ary AND against `last_run` first. + + current_run.push(last_run.clone()); + last_run = Self::kary_and(¤t_run)?; + current_run.truncate(0); + } + + // If `last_run` is true, `a` must be false, or it would + // not be in the field. + // + // If `last_run` is false, `a` can be true or false. + // + // Ergo, at least one of `last_run` and `a` must be false. + Self::enforce_kary_nand(&[last_run.clone(), a.clone()])?; + } + } + assert!(bits_iter.next().is_none()); + + Ok(current_run) + } + + /// Conditionally selects one of `first` and `second` based on the value of + /// `self`: + /// + /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs + /// `second`. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?; + /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(first, second))] + pub fn select>( + &self, + first: &T, + second: &T, + ) -> Result { + T::conditionally_select(&self, first, second) + } +} + +impl From> for Boolean { + fn from(b: AllocatedBool) -> Self { + Boolean::Var(b) + } +} + +impl AllocVar for Boolean { + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + if mode == AllocationMode::Constant { + Ok(Boolean::Constant(*f()?.borrow())) + } else { + AllocatedBool::new_variable(cs, f, mode).map(Boolean::Var) + } + } +} + +impl ToBytesGadget for Boolean { + /// Outputs `1u8` if `self` is true, and `0u8` otherwise. + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let value = self.value().map(u8::from).ok(); + let mut bits = [Boolean::FALSE; 8]; + bits[0] = self.clone(); + Ok(vec![UInt8 { bits, value }]) + } +} + +impl ToConstraintFieldGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let var = From::from(self.clone()); + Ok(vec![var]) + } +} + +impl CondSelectGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_val: &Self, + false_val: &Self, + ) -> Result { + use Boolean::*; + match cond { + Constant(true) => Ok(true_val.clone()), + Constant(false) => Ok(false_val.clone()), + cond @ Var(_) => match (true_val, false_val) { + (x, &Constant(false)) => Ok(cond & x), + (&Constant(false), x) => Ok((!cond) & x), + (&Constant(true), x) => Ok(cond | x), + (x, &Constant(true)) => Ok((!cond) | x), + (a, b) => { + let cs = cond.cs(); + let result: Boolean = + AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || { + let cond = cond.value()?; + Ok(if cond { a.value()? } else { b.value()? }) + })? + .into(); + // a = self; b = other; c = cond; + // + // r = c * a + (1 - c) * b + // r = b + c * (a - b) + // c * (a - b) = r - b + // + // If a, b, cond are all boolean, so is r. + // + // self | other | cond | result + // -----|-------|---------------- + // 0 | 0 | 1 | 0 + // 0 | 1 | 1 | 0 + // 1 | 0 | 1 | 1 + // 1 | 1 | 1 | 1 + // 0 | 0 | 0 | 0 + // 0 | 1 | 0 | 1 + // 1 | 0 | 0 | 0 + // 1 | 1 | 0 | 1 + cs.enforce_constraint( + cond.lc(), + lc!() + a.lc() - b.lc(), + lc!() + result.lc() - b.lc(), + )?; + + Ok(result) + }, + }, + } + } +} + +#[cfg(test)] +mod test { + use super::{AllocatedBool, Boolean}; + use crate::prelude::*; + use ark_ff::{BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand}; + use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn test_boolean_to_byte() -> Result<(), SynthesisError> { + for val in [true, false].iter() { + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::new_witness(cs.clone(), || Ok(*val))?; + let bytes = a.to_bytes()?; + assert_eq!(bytes.len(), 1); + let byte = &bytes[0]; + assert_eq!(byte.value()?, *val as u8); + + for (i, bit) in byte.bits.iter().enumerate() { + assert_eq!(bit.value()?, (byte.value()? >> i) & 1 == 1); + } + } + Ok(()) + } + + #[test] + fn allocated_xor() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::xor(&a, &b)?; + assert_eq!(c.value()?, a_val ^ b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val ^ b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_or() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::or(&a, &b)?; + assert_eq!(c.value()?, a_val | b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val | b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_and() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::and(&a, &b)?; + assert_eq!(c.value()?, a_val & b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val & b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_and_not() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::and_not(&a, &b)?; + assert_eq!(c.value()?, a_val & !b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val & !b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_nor() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::nor(&a, &b)?; + assert_eq!(c.value()?, !a_val & !b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (!a_val & !b_val)); + } + } + Ok(()) + } + + #[test] + fn test_smaller_than_or_equal_to() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + for _ in 0..1000 { + let mut r = Fr::rand(&mut rng); + let mut s = Fr::rand(&mut rng); + if r > s { + core::mem::swap(&mut r, &mut s) + } + + let cs = ConstraintSystem::::new_ref(); + + let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); + let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; + Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; + + assert!(cs.is_satisfied().unwrap()); + } + + for _ in 0..1000 { + let r = Fr::rand(&mut rng); + if r == -Fr::one() { + continue; + } + let s = r + Fr::one(); + let s2 = r.double(); + let cs = ConstraintSystem::::new_ref(); + + let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); + let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; + Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; + if r < s2 { + Boolean::enforce_smaller_or_equal_than_le(&bits, s2.into_bigint())?; + } + + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn test_enforce_in_field() -> Result<(), SynthesisError> { + { + let cs = ConstraintSystem::::new_ref(); + + let mut bits = vec![]; + for b in BitIteratorBE::new(Fr::characteristic()).skip(1) { + bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); + } + bits.reverse(); + + Boolean::enforce_in_field_le(&bits)?; + + assert!(!cs.is_satisfied().unwrap()); + } + + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let r = Fr::rand(&mut rng); + let cs = ConstraintSystem::::new_ref(); + + let mut bits = vec![]; + for b in BitIteratorBE::new(r.into_bigint()).skip(1) { + bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); + } + bits.reverse(); + + Boolean::enforce_in_field_le(&bits)?; + + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn test_enforce_nand() -> Result<(), SynthesisError> { + { + let cs = ConstraintSystem::::new_ref(); + + assert!( + Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), false)?]).is_ok() + ); + assert!( + Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), true)?]).is_err() + ); + } + + for i in 1..5 { + // with every possible assignment for them + for mut b in 0..(1 << i) { + // with every possible negation + for mut n in 0..(1 << i) { + let cs = ConstraintSystem::::new_ref(); + + let mut expected = true; + + let mut bits = vec![]; + for _ in 0..i { + expected &= b & 1 == 1; + + let bit = if n & 1 == 1 { + Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))? + } else { + !Boolean::new_witness(cs.clone(), || Ok(b & 1 == 0))? + }; + bits.push(bit); + + b >>= 1; + n >>= 1; + } + + let expected = !expected; + + Boolean::enforce_kary_nand(&bits)?; + + if expected { + assert!(cs.is_satisfied().unwrap()); + } else { + assert!(!cs.is_satisfied().unwrap()); + } + } + } + } + Ok(()) + } + + #[test] + fn test_kary_and() -> Result<(), SynthesisError> { + // test different numbers of operands + for i in 1..15 { + // with every possible assignment for them + for mut b in 0..(1 << i) { + let cs = ConstraintSystem::::new_ref(); + + let mut expected = true; + + let mut bits = vec![]; + for _ in 0..i { + expected &= b & 1 == 1; + bits.push(Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))?); + b >>= 1; + } + + let r = Boolean::kary_and(&bits)?; + + assert!(cs.is_satisfied().unwrap()); + + if let Boolean::Var(ref r) = r { + assert_eq!(r.value()?, expected); + } + } + } + Ok(()) + } + + #[test] + fn test_bits_to_fp() -> Result<(), SynthesisError> { + use AllocationMode::*; + let rng = &mut ark_std::test_rng(); + let cs = ConstraintSystem::::new_ref(); + + let modes = [Input, Witness, Constant]; + for &mode in modes.iter() { + for _ in 0..1000 { + let f = Fr::rand(rng); + let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); + let bits: Vec<_> = + AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; + let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; + let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; + claimed_f.enforce_equal(&f)?; + } + + for _ in 0..1000 { + let f = Fr::from(u64::rand(rng)); + let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); + let bits: Vec<_> = + AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; + let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; + let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; + claimed_f.enforce_equal(&f)?; + } + assert!(cs.is_satisfied().unwrap()); + } + + Ok(()) + } +} diff --git a/src/bits/boolean/not.rs b/src/bits/boolean/not.rs new file mode 100644 index 00000000..97a71468 --- /dev/null +++ b/src/bits/boolean/not.rs @@ -0,0 +1,99 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::ops::Not; + +use super::Boolean; + +impl Boolean { + fn _not(&self) -> Result { + match *self { + Boolean::Constant(c) => Ok(Boolean::Constant(!c)), + Boolean::Var(ref v) => Ok(Boolean::Var(v.not().unwrap())), + } + } +} + +impl<'a, F: Field> Not for &'a Boolean { + type Output = Boolean; + /// Negates `self`. + /// + /// This *does not* create any new variables or constraints. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// a.not().enforce_equal(&b)?; + /// b.not().enforce_equal(&a)?; + /// + /// a.not().enforce_equal(&Boolean::FALSE)?; + /// b.not().enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +impl<'a, F: Field> Not for &'a mut Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +impl Not for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_unary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn not() { + run_unary_exhaustive::(|a| { + let cs = a.cs(); + let computed = !&a; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(!a.value().unwrap()), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/bits/boolean/or.rs b/src/bits/boolean/or.rs new file mode 100644 index 00000000..fa7396d9 --- /dev/null +++ b/src/bits/boolean/or.rs @@ -0,0 +1,135 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitOr, ops::BitOrAssign}; + +use super::Boolean; + +impl Boolean { + fn _or(&self, other: &Self) -> Result { + use Boolean::*; + match (self, other) { + (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), + (&Constant(true), _) | (_, &Constant(true)) => Ok(Constant(true)), + (Var(ref x), Var(ref y)) => Ok(Var(x.or(y)?)), + } + } +} + +impl<'a, F: Field> BitOr for &'a Boolean { + type Output = Boolean; + + /// Outputs `self | other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// a.or(&b)?.enforce_equal(&Boolean::TRUE)?; + /// b.or(&a)?.enforce_equal(&Boolean::TRUE)?; + /// + /// a.or(&a)?.enforce_equal(&Boolean::TRUE)?; + /// b.or(&b)?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(other).unwrap() + } +} + +impl<'a, F: Field> BitOr<&'a Self> for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: &Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl<'a, F: Field> BitOr> for &'a Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Boolean) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOr for Boolean { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOrAssign for Boolean { + /// Sets `self = self | other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: Self) { + let result = self._or(&other).unwrap(); + *self = result; + } +} + +impl<'a, F: Field> BitOrAssign<&'a Self> for Boolean { + /// Sets `self = self | other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: &'a Self) { + let result = self._or(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn or() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a | &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() | b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/bits/boolean/test_utils.rs b/src/bits/boolean/test_utils.rs new file mode 100644 index 00000000..9577e82d --- /dev/null +++ b/src/bits/boolean/test_utils.rs @@ -0,0 +1,47 @@ +use crate::test_utils; + +use super::*; +use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; + +pub(crate) fn test_unary_op( + a: bool, + mode: AllocationMode, + test: impl FnOnce(Boolean) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::::new_variable(cs.clone(), || Ok(a), mode)?; + test(a) +} + +pub(crate) fn test_binary_op( + a: bool, + b: bool, + mode_a: AllocationMode, + mode_b: AllocationMode, + test: impl FnOnce(Boolean, Boolean) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::::new_variable(cs.clone(), || Ok(a), mode_a)?; + let b = Boolean::::new_variable(cs.clone(), || Ok(b), mode_b)?; + test(a, b) +} + +pub(crate) fn run_binary_exhaustive( + test: impl Fn(Boolean, Boolean) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> { + for (mode_a, a) in test_utils::combination([false, true].into_iter()) { + for (mode_b, b) in test_utils::combination([false, true].into_iter()) { + test_binary_op(a, b, mode_a, mode_b, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_unary_exhaustive( + test: impl Fn(Boolean) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> { + for (mode, a) in test_utils::combination([false, true].into_iter()) { + test_unary_op(a, mode, test)?; + } + Ok(()) +} diff --git a/src/bits/boolean/xor.rs b/src/bits/boolean/xor.rs new file mode 100644 index 00000000..4540b2e6 --- /dev/null +++ b/src/bits/boolean/xor.rs @@ -0,0 +1,135 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitXor, ops::BitXorAssign}; + +use super::Boolean; + +impl Boolean { + fn _xor(&self, other: &Self) -> Result { + use Boolean::*; + match (self, other) { + (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), + (&Constant(true), x) | (x, &Constant(true)) => Ok(!x), + (Var(ref x), Var(ref y)) => Ok(Var(x.xor(y)?)), + } + } +} + +impl<'a, F: Field> BitXor for &'a Boolean { + type Output = Boolean; + + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// a.xor(&b)?.enforce_equal(&Boolean::TRUE)?; + /// b.xor(&a)?.enforce_equal(&Boolean::TRUE)?; + /// + /// a.xor(&a)?.enforce_equal(&Boolean::FALSE)?; + /// b.xor(&b)?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(other).unwrap() + } +} + +impl<'a, F: Field> BitXor<&'a Self> for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: &Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl<'a, F: Field> BitXor> for &'a Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Boolean) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXor for Boolean { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXorAssign for Boolean { + /// Sets `self = self ^ other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: Self) { + let result = self._xor(&other).unwrap(); + *self = result; + } +} + +impl<'a, F: Field> BitXorAssign<&'a Self> for Boolean { + /// Sets `self = self ^ other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: &'a Self) { + let result = self._xor(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn xor() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a ^ &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() ^ b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/bits/uint/and.rs b/src/bits/uint/and.rs index 38b31564..154576b6 100644 --- a/src/bits/uint/and.rs +++ b/src/bits/uint/and.rs @@ -9,7 +9,7 @@ impl UInt { fn _and(&self, other: &Self) -> Result { let mut result = self.clone(); for (a, b) in result.bits.iter_mut().zip(&other.bits) { - *a = a.and(b)? + *a &= b; } result.value = self.value.and_then(|a| Some(a & other.value?)); dbg!(result.value); @@ -201,7 +201,6 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAndAssign<&'a Self> fo } } - #[cfg(test)] mod tests { use super::*; @@ -226,9 +225,10 @@ mod tests { } else { AllocationMode::Witness }; - let expected = UInt::::new_variable(cs.clone(), - || Ok(a.value().unwrap() & b.value().unwrap()), - expected_mode + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value().unwrap() & b.value().unwrap()), + expected_mode, )?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&expected)?; @@ -262,4 +262,4 @@ mod tests { fn u128_and() { run_binary_random::<1000, 128, _, _>(uint_and::).unwrap() } -} \ No newline at end of file +} diff --git a/src/bits/uint/eq.rs b/src/bits/uint/eq.rs index d6b5675a..94a1e0c1 100644 --- a/src/bits/uint/eq.rs +++ b/src/bits/uint/eq.rs @@ -91,9 +91,10 @@ mod tests { } else { AllocationMode::Witness }; - let expected = Boolean::new_variable(cs.clone(), - || Ok(a.value().unwrap() == b.value().unwrap()), - expected_mode + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() == b.value().unwrap()), + expected_mode, )?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&expected)?; @@ -127,4 +128,4 @@ mod tests { fn u128_eq() { run_binary_random::<1000, 128, _, _>(uint_eq::).unwrap() } -} \ No newline at end of file +} diff --git a/src/bits/uint/mod.rs b/src/bits/uint/mod.rs index 48122491..c2dd9597 100644 --- a/src/bits/uint/mod.rs +++ b/src/bits/uint/mod.rs @@ -1,4 +1,4 @@ -use ark_ff::{Field, One, PrimeField, Zero}; +use ark_ff::{BigInteger, Field, One, PrimeField, Zero}; use core::{borrow::Borrow, convert::TryFrom, fmt::Debug}; use num_bigint::BigUint; use num_traits::{NumCast, PrimInt}; @@ -9,6 +9,7 @@ use ark_relations::r1cs::{ use crate::{ boolean::{AllocatedBool, Boolean}, + fields::fp::FpVar, prelude::*, Assignment, Vec, }; @@ -184,6 +185,43 @@ impl UInt { result } + /// Converts a little-endian byte order representation of bits into a + /// field element. + + /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`. + pub fn to_fp_var(&self) -> Result, SynthesisError> + where + F: PrimeField, + { + assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); + + Boolean::le_bits_to_fp_var(&self.bits) + } + + /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`. + pub fn from_fp_var(other: &FpVar) -> Result + where + F: PrimeField, + { + assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); + let value = other.value()?.into_bigint().to_bits_le(); + let cs = other.cs(); + let mode = if other.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let lower_bits = value + .iter() + .take(N) + .map(|b| Boolean::new_variable(cs.clone(), || Ok(*b), mode)) + .collect::, _>>()?; + let result = Self::from_bits_le(&lower_bits); + let rest: FpVar = other - &result.to_fp_var()?; + rest.enforce_equal(&FpVar::zero())?; + Ok(result) + } + /// Perform modular addition of `operands`. /// /// The user must ensure that overflow does not occur. @@ -205,11 +243,7 @@ impl UInt { // Compute the maximum value of the sum so we allocate enough bits for // the result - let mut max_value = T::max_value() - .checked_mul( - &T::from(ark_std::log2(operands.len())).ok_or(SynthesisError::Unsatisfiable)?, - ) - .ok_or(SynthesisError::Unsatisfiable)?; + let mut max_value = T::max_value() * T::from(operands.len() as u128).unwrap(); // Keep track of the resulting value let mut result_value = Some(BigUint::zero()); @@ -241,24 +275,17 @@ impl UInt { let mut coeff = F::one(); for bit in &op.bits { match *bit { - Boolean::Is(ref bit) => { - all_constants = false; - - // Add coeff * bit_gadget - lc += (coeff, bit.variable()); - }, - Boolean::Not(ref bit) => { - all_constants = false; - - // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * - // bit_gadget - lc = lc + (coeff, Variable::One) - (coeff, bit.variable()); - }, Boolean::Constant(bit) => { if bit { lc += (coeff, Variable::One); } }, + _ => { + all_constants = false; + + // Add coeff * bit_gadget + lc = lc + (coeff, &bit.lc()); + }, } coeff.double_in_place(); @@ -424,8 +451,8 @@ impl AllocVar {}, -// (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, +// (&Boolean::TRUE, &Boolean::TRUE) => {}, +// (&Boolean::FALSE, &Boolean::FALSE) => {}, // _ => unreachable!(), // } // } diff --git a/src/bits/uint/not.rs b/src/bits/uint/not.rs index 350ff9e8..e35ec93e 100644 --- a/src/bits/uint/not.rs +++ b/src/bits/uint/not.rs @@ -9,7 +9,7 @@ impl UInt { fn _not(&self) -> Result { let mut result = self.clone(); for a in &mut result.bits { - *a = a.not() + *a = !&*a } result.value = self.value.map(Not::not); dbg!(result.value); @@ -96,10 +96,8 @@ mod tests { } else { AllocationMode::Witness }; - let expected = UInt::::new_variable(cs.clone(), - || Ok(!a.value().unwrap()), - expected_mode - )?; + let expected = + UInt::::new_variable(cs.clone(), || Ok(!a.value().unwrap()), expected_mode)?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&expected)?; if !a.is_constant() { diff --git a/src/bits/uint/or.rs b/src/bits/uint/or.rs index 05fbffaa..beeb897e 100644 --- a/src/bits/uint/or.rs +++ b/src/bits/uint/or.rs @@ -9,7 +9,7 @@ impl UInt { fn _or(&self, other: &Self) -> Result { let mut result = self.clone(); for (a, b) in result.bits.iter_mut().zip(&other.bits) { - *a = a.or(b)? + *a |= b; } result.value = self.value.and_then(|a| Some(a | other.value?)); dbg!(result.value); @@ -221,9 +221,10 @@ mod tests { } else { AllocationMode::Witness }; - let expected = UInt::::new_variable(cs.clone(), - || Ok(a.value().unwrap() | b.value().unwrap()), - expected_mode + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value().unwrap() | b.value().unwrap()), + expected_mode, )?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&expected)?; diff --git a/src/bits/uint/xor.rs b/src/bits/uint/xor.rs index 73c2722c..f2181ec4 100644 --- a/src/bits/uint/xor.rs +++ b/src/bits/uint/xor.rs @@ -9,7 +9,7 @@ impl UInt { fn _xor(&self, other: &Self) -> Result { let mut result = self.clone(); for (a, b) in result.bits.iter_mut().zip(&other.bits) { - *a = a.xor(b)? + *a ^= b; } result.value = self.value.and_then(|a| Some(a ^ other.value?)); Ok(result) @@ -220,9 +220,10 @@ mod tests { } else { AllocationMode::Witness }; - let expected = UInt::::new_variable(cs.clone(), - || Ok(a.value().unwrap() ^ b.value().unwrap()), - expected_mode + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value().unwrap() ^ b.value().unwrap()), + expected_mode, )?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&expected)?; diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index 6b977a32..282d7b89 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -206,8 +206,7 @@ mod test { for b in r.bits.iter() { match b { - Boolean::Is(b) => assert!(b.value()? == (expected & 1 == 1)), - Boolean::Not(b) => assert!(!b.value()? == (expected & 1 == 1)), + Boolean::Var(b) => assert!(b.value()? == (expected & 1 == 1)), Boolean::Constant(b) => assert!(*b == (expected & 1 == 1)), } diff --git a/src/eq.rs b/src/eq.rs index f1184619..7cdf91ca 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -14,7 +14,7 @@ pub trait EqGadget { /// /// By default, this is defined as `self.is_eq(other)?.not()`. fn is_neq(&self, other: &Self) -> Result, SynthesisError> { - Ok(self.is_eq(other)?.not()) + Ok(!self.is_eq(other)?) } /// If `should_enforce == true`, enforce that `self` and `other` are equal; diff --git a/src/fields/cubic_extension.rs b/src/fields/cubic_extension.rs index a4465819..12f330c5 100644 --- a/src/fields/cubic_extension.rs +++ b/src/fields/cubic_extension.rs @@ -372,7 +372,7 @@ where let b0 = self.c0.is_eq(&other.c0)?; let b1 = self.c1.is_eq(&other.c1)?; let b2 = self.c2.is_eq(&other.c2)?; - b0.and(&b1)?.and(&b2) + Ok(b0 & b1 & b2) } #[inline] @@ -396,9 +396,7 @@ where condition: &Boolean, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index c5eeb74c..878a1d4e 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -129,13 +129,14 @@ impl AllocatedFp { /// /// This does not create any constraints and only creates one linear /// combination. - pub fn add_many<'a, I: Iterator>(iter: I) -> Self { + pub fn add_many, I: Iterator>(iter: I) -> Self { let mut cs = ConstraintSystemRef::None; let mut has_value = true; let mut value = F::zero(); let mut new_lc = lc!(); for variable in iter { + let variable = variable.borrow(); if !variable.cs.is_none() { cs = cs.or(variable.cs.clone()); } @@ -323,7 +324,7 @@ impl AllocatedFp { /// This requires three constraints. #[tracing::instrument(target = "r1cs")] pub fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - Ok(self.is_neq(other)?.not()) + Ok(!self.is_neq(other)?) } /// Outputs the bit `self != other`. @@ -391,7 +392,7 @@ impl AllocatedFp { )?; self.cs.enforce_constraint( lc!() + self.variable - other.variable, - is_not_equal.not().lc(), + (!&is_not_equal).lc(), lc!(), )?; Ok(is_not_equal) @@ -554,8 +555,8 @@ impl CondSelectGadget for AllocatedFp { false_val: &Self, ) -> Result { match cond { - Boolean::Constant(true) => Ok(true_val.clone()), - Boolean::Constant(false) => Ok(false_val.clone()), + &Boolean::Constant(true) => Ok(true_val.clone()), + &Boolean::Constant(false) => Ok(false_val.clone()), _ => { let cs = cond.cs(); let result = Self::new_witness(cs.clone(), || { @@ -952,13 +953,13 @@ impl CondSelectGadget for FpVar { false_value: &Self, ) -> Result { match cond { - Boolean::Constant(true) => Ok(true_value.clone()), - Boolean::Constant(false) => Ok(false_value.clone()), + &Boolean::Constant(true) => Ok(true_value.clone()), + &Boolean::Constant(false) => Ok(false_value.clone()), _ => { match (true_value, false_value) { (Self::Constant(t), Self::Constant(f)) => { let is = AllocatedFp::from(cond.clone()); - let not = AllocatedFp::from(cond.not()); + let not = AllocatedFp::from(!cond); // cond * t + (1 - cond) * f Ok(is.mul_constant(*t).add(¬.mul_constant(*f)).into()) }, @@ -1063,6 +1064,22 @@ impl<'a, F: PrimeField> Sum<&'a FpVar> for FpVar { } } +impl<'a, F: PrimeField> Sum> for FpVar { + fn sum>>(iter: I) -> FpVar { + let mut sum_constants = F::zero(); + let sum_variables = FpVar::Var(AllocatedFp::::add_many(iter.filter_map(|x| match x { + FpVar::Constant(c) => { + sum_constants += c; + None + }, + FpVar::Var(v) => Some(v), + }))); + + let sum = sum_variables + sum_constants; + sum + } +} + #[cfg(test)] mod test { use crate::{ diff --git a/src/fields/nonnative/field_var.rs b/src/fields/nonnative/field_var.rs index fb2e4922..e37b6c84 100644 --- a/src/fields/nonnative/field_var.rs +++ b/src/fields/nonnative/field_var.rs @@ -229,7 +229,7 @@ impl EqGadget Boolean::new_witness(cs, || Ok(self.value()? == other.value()?))?; self.conditional_enforce_equal(other, &should_enforce_equal)?; - self.conditional_enforce_not_equal(other, &should_enforce_equal.not())?; + self.conditional_enforce_not_equal(other, &!&should_enforce_equal)?; Ok(should_enforce_equal) } @@ -341,8 +341,8 @@ impl CondSelectGadget false_value: &Self, ) -> R1CSResult { match cond { - Boolean::Constant(true) => Ok(true_value.clone()), - Boolean::Constant(false) => Ok(false_value.clone()), + &Boolean::Constant(true) => Ok(true_value.clone()), + &Boolean::Constant(false) => Ok(false_value.clone()), _ => { let cs = cond.cs(); let true_value = match true_value { diff --git a/src/fields/quadratic_extension.rs b/src/fields/quadratic_extension.rs index 5e665bb4..54949a10 100644 --- a/src/fields/quadratic_extension.rs +++ b/src/fields/quadratic_extension.rs @@ -377,7 +377,7 @@ where fn is_eq(&self, other: &Self) -> Result, SynthesisError> { let b0 = self.c0.is_eq(&other.c0)?; let b1 = self.c1.is_eq(&other.c1)?; - b0.and(&b1) + Ok(b0 & b1) } #[inline] @@ -400,9 +400,7 @@ where condition: &Boolean, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/groups/curves/short_weierstrass/bls12/mod.rs b/src/groups/curves/short_weierstrass/bls12/mod.rs index 263a1bd5..7ac9f9b8 100644 --- a/src/groups/curves/short_weierstrass/bls12/mod.rs +++ b/src/groups/curves/short_weierstrass/bls12/mod.rs @@ -201,7 +201,7 @@ impl G2PreparedVar

{ let q = q.to_affine()?; let two_inv = P::Fp::one().double().inverse().unwrap(); // Enforce that `q` is not the point at infinity. - q.infinity.enforce_not_equal(&Boolean::Constant(true))?; + q.infinity.enforce_not_equal(&Boolean::TRUE)?; let mut ell_coeffs = vec![]; let mut r = q.clone(); diff --git a/src/groups/curves/short_weierstrass/mod.rs b/src/groups/curves/short_weierstrass/mod.rs index 91eae941..e6448c10 100644 --- a/src/groups/curves/short_weierstrass/mod.rs +++ b/src/groups/curves/short_weierstrass/mod.rs @@ -182,7 +182,7 @@ where // `z_inv * self.z = 0` if `self.is_zero()`. // // Thus, `z_inv * self.z = !self.is_zero()`. - z_inv.mul_equals(&self.z, &F::from(infinity.not()))?; + z_inv.mul_equals(&self.z, &F::from(!&infinity))?; let non_zero_x = &self.x * &z_inv; let non_zero_y = &self.y * &z_inv; @@ -735,9 +735,9 @@ where ) -> Result::BasePrimeField>, SynthesisError> { let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; - let coordinates_equal = x_equal.and(&y_equal)?; - let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; - both_are_zero.or(&coordinates_equal) + let coordinates_equal = x_equal & y_equal; + let both_are_zero = self.is_zero()? & other.is_zero()?; + Ok(both_are_zero | coordinates_equal) } #[inline] @@ -749,12 +749,9 @@ where ) -> Result<(), SynthesisError> { let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; - let coordinates_equal = x_equal.and(&y_equal)?; - let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; - both_are_zero - .or(&coordinates_equal)? - .conditional_enforce_equal(&Boolean::Constant(true), condition)?; - Ok(()) + let coordinates_equal = x_equal & y_equal; + let both_are_zero = self.is_zero()? & other.is_zero()?; + (both_are_zero | coordinates_equal).conditional_enforce_equal(&Boolean::TRUE, condition) } #[inline] @@ -765,9 +762,7 @@ where condition: &Boolean<::BasePrimeField>, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/groups/curves/short_weierstrass/non_zero_affine.rs b/src/groups/curves/short_weierstrass/non_zero_affine.rs index 281b5627..d20d9a37 100644 --- a/src/groups/curves/short_weierstrass/non_zero_affine.rs +++ b/src/groups/curves/short_weierstrass/non_zero_affine.rs @@ -188,7 +188,7 @@ where ) -> Result::BasePrimeField>, SynthesisError> { let x_equal = self.x.is_eq(&other.x)?; let y_equal = self.y.is_eq(&other.y)?; - x_equal.and(&y_equal) + Ok(x_equal & y_equal) } #[inline] @@ -200,8 +200,8 @@ where ) -> Result<(), SynthesisError> { let x_equal = self.x.is_eq(&other.x)?; let y_equal = self.y.is_eq(&other.y)?; - let coordinates_equal = x_equal.and(&y_equal)?; - coordinates_equal.conditional_enforce_equal(&Boolean::Constant(true), condition)?; + let coordinates_equal = x_equal & y_equal; + coordinates_equal.conditional_enforce_equal(&Boolean::TRUE, condition)?; Ok(()) } @@ -221,9 +221,7 @@ where condition: &Boolean<::BasePrimeField>, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/groups/curves/twisted_edwards/mod.rs b/src/groups/curves/twisted_edwards/mod.rs index 789cb594..28b65ba1 100644 --- a/src/groups/curves/twisted_edwards/mod.rs +++ b/src/groups/curves/twisted_edwards/mod.rs @@ -358,7 +358,7 @@ where let x_coeffs = coords.iter().map(|p| p.0).collect::>(); let y_coeffs = coords.iter().map(|p| p.1).collect::>(); - let precomp = bits[0].and(&bits[1])?; + let precomp = &bits[0] & &bits[1]; let x = F::zero() + x_coeffs[0] @@ -423,7 +423,7 @@ where } fn is_zero(&self) -> Result::BasePrimeField>, SynthesisError> { - self.x.is_zero()?.and(&self.y.is_one()?) + Ok(self.x.is_zero()? & &self.y.is_one()?) } #[tracing::instrument(target = "r1cs", skip(cs, f))] @@ -848,7 +848,7 @@ where ) -> Result::BasePrimeField>, SynthesisError> { let x_equal = self.x.is_eq(&other.x)?; let y_equal = self.y.is_eq(&other.y)?; - x_equal.and(&y_equal) + Ok(x_equal & y_equal) } #[inline] @@ -870,9 +870,7 @@ where other: &Self, condition: &Boolean<::BasePrimeField>, ) -> Result<(), SynthesisError> { - self.is_eq(other)? - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (self.is_eq(other)? & condition).enforce_equal(&Boolean::FALSE) } } From c6a484c9b5058e3c88454dd5a60254560d218c0e Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 10 May 2023 15:17:22 -0700 Subject: [PATCH 07/29] Small clean up --- src/bits/uint/mod.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/bits/uint/mod.rs b/src/bits/uint/mod.rs index c2dd9597..ed33881a 100644 --- a/src/bits/uint/mod.rs +++ b/src/bits/uint/mod.rs @@ -185,11 +185,12 @@ impl UInt { result } - /// Converts a little-endian byte order representation of bits into a - /// field element. - - /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`. - pub fn to_fp_var(&self) -> Result, SynthesisError> + /// Converts `self` into a field element. The elements comprising `self` are + /// interpreted as a little-endian bit order representation of a field element. + /// + /// # Panics + /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. + pub fn to_fp(&self) -> Result, SynthesisError> where F: PrimeField, { @@ -198,8 +199,12 @@ impl UInt { Boolean::le_bits_to_fp_var(&self.bits) } - /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`. - pub fn from_fp_var(other: &FpVar) -> Result + /// Converts a field element into its little-endian bit order representation. + /// + /// # Panics + /// + /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. + pub fn from_fp(other: &FpVar) -> Result where F: PrimeField, { @@ -217,7 +222,7 @@ impl UInt { .map(|b| Boolean::new_variable(cs.clone(), || Ok(*b), mode)) .collect::, _>>()?; let result = Self::from_bits_le(&lower_bits); - let rest: FpVar = other - &result.to_fp_var()?; + let rest: FpVar = other - &result.to_fp()?; rest.enforce_equal(&FpVar::zero())?; Ok(result) } From 9638ace21ab00f07d8770574f85e1ca6677239f2 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 10 May 2023 20:30:13 -0700 Subject: [PATCH 08/29] Add code for comparison and for rotates --- src/bits/uint/add.rs | 0 src/bits/uint/cmp.rs | 239 +++++++++++++++++++++++++++++++++++++++ src/bits/uint/convert.rs | 117 +++++++++++++++++++ src/bits/uint/mod.rs | 121 +------------------- src/bits/uint/rotate.rs | 174 ++++++++++++++++++++++++++++ 5 files changed, 533 insertions(+), 118 deletions(-) create mode 100644 src/bits/uint/add.rs create mode 100644 src/bits/uint/convert.rs create mode 100644 src/bits/uint/rotate.rs diff --git a/src/bits/uint/add.rs b/src/bits/uint/add.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/bits/uint/cmp.rs b/src/bits/uint/cmp.rs index 8b137891..8736eca7 100644 --- a/src/bits/uint/cmp.rs +++ b/src/bits/uint/cmp.rs @@ -1 +1,240 @@ +use crate::fields::fp::FpVar; +use super::*; +impl UInt { + pub fn is_gt(&self, other: &Self) -> Result, SynthesisError> { + other.is_lt(self) + } + + pub fn is_ge(&self, other: &Self) -> Result, SynthesisError> { + if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) { + let a = self.to_fp()?; + let b = other.to_fp()?; + let (_, rest) = Self::from_fp_to_parts(&(a - b))?; + rest.is_eq(&FpVar::zero()) + } else { + unimplemented!("bit sizes larger than modulus size not yet supported") + } + } + + pub fn is_lt(&self, other: &Self) -> Result, SynthesisError> { + Ok(!self.is_ge(other)?) + } + + pub fn is_le(&self, other: &Self) -> Result, SynthesisError> { + other.is_ge(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_gt( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_gt(&b)?; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() > b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_lt( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_lt(&b)?; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() < b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_ge( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_ge(&b)?; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() >= b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_le( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_le(&b)?; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(a.value().unwrap() <= b.value().unwrap()), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_gt() { + run_binary_exhaustive(uint_gt::).unwrap() + } + + #[test] + fn u16_gt() { + run_binary_random::<1000, 16, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u32_gt() { + run_binary_random::<1000, 32, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u64_gt() { + run_binary_random::<1000, 64, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u128_gt() { + run_binary_random::<1000, 128, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u8_lt() { + run_binary_exhaustive(uint_lt::).unwrap() + } + + #[test] + fn u16_lt() { + run_binary_random::<1000, 16, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u32_lt() { + run_binary_random::<1000, 32, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u64_lt() { + run_binary_random::<1000, 64, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u128_lt() { + run_binary_random::<1000, 128, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u8_le() { + run_binary_exhaustive(uint_le::).unwrap() + } + + #[test] + fn u16_le() { + run_binary_random::<1000, 16, _, _>(uint_le::).unwrap() + } + + #[test] + fn u32_le() { + run_binary_random::<1000, 32, _, _>(uint_le::).unwrap() + } + + #[test] + fn u64_le() { + run_binary_random::<1000, 64, _, _>(uint_le::).unwrap() + } + + #[test] + fn u128_le() { + run_binary_random::<1000, 128, _, _>(uint_le::).unwrap() + } + + #[test] + fn u8_ge() { + run_binary_exhaustive(uint_ge::).unwrap() + } + + #[test] + fn u16_ge() { + run_binary_random::<1000, 16, _, _>(uint_ge::).unwrap() + } + + #[test] + fn u32_ge() { + run_binary_random::<1000, 32, _, _>(uint_ge::).unwrap() + } + + #[test] + fn u64_ge() { + run_binary_random::<1000, 64, _, _>(uint_ge::).unwrap() + } + + #[test] + fn u128_ge() { + run_binary_random::<1000, 128, _, _>(uint_ge::).unwrap() + } +} diff --git a/src/bits/uint/convert.rs b/src/bits/uint/convert.rs new file mode 100644 index 00000000..d107ecdb --- /dev/null +++ b/src/bits/uint/convert.rs @@ -0,0 +1,117 @@ +use ark_ff::BigInteger; + +use crate::fields::fp::FpVar; + +use super::*; + +impl UInt { + /// Converts `self` into a field element. The elements comprising `self` are + /// interpreted as a little-endian bit order representation of a field element. + /// + /// # Panics + /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. + pub fn to_fp(&self) -> Result, SynthesisError> + where + F: PrimeField, + { + assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); + + Boolean::le_bits_to_fp_var(&self.bits) + } + + pub(super) fn from_fp_to_parts(other: &FpVar) -> Result<(Self, FpVar), SynthesisError> + where + F: PrimeField, + { + assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); + let value = other.value()?.into_bigint().to_bits_le(); + let cs = other.cs(); + let mode = if other.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let lower_bits = value + .iter() + .take(N) + .map(|b| Boolean::new_variable(cs.clone(), || Ok(*b), mode)) + .collect::, _>>()?; + let result = Self::from_bits_le(&lower_bits); + let rest: FpVar = other - &result.to_fp()?; + (result.to_fp()? + &rest).enforce_equal(&other)?; + Ok((result, rest)) + } + + /// Converts a field element into its little-endian bit order representation. + /// + /// # Panics + /// + /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. + pub fn from_fp(other: &FpVar) -> Result + where + F: PrimeField, + { + let (result, rest) = Self::from_fp_to_parts(other)?; + rest.enforce_equal(&FpVar::zero())?; + Ok(result) + } + + /// Turns `self` into the underlying little-endian bits. + pub fn to_bits_le(&self) -> Vec> { + self.bits.to_vec() + } + + /// Converts a little-endian byte order representation of bits into a + /// `UInt`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let var = UInt8::new_witness(cs.clone(), || Ok(128))?; + /// + /// let f = Boolean::FALSE; + /// let t = Boolean::TRUE; + /// + /// // Construct [0, 0, 0, 0, 0, 0, 0, 1] + /// let mut bits = vec![f.clone(); 7]; + /// bits.push(t); + /// + /// let mut c = UInt8::from_bits_le(&bits); + /// var.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn from_bits_le(bits: &[Boolean]) -> Self { + assert_eq!(bits.len(), N); + let bits = <&[Boolean; N]>::try_from(bits).unwrap().clone(); + let value_exists = bits.iter().all(|b| b.value().is_ok()); + let mut value = T::zero(); + for (i, b) in bits.iter().enumerate() { + if let Ok(b) = b.value() { + value = value + (T::from(b as u8).unwrap() << i); + } + } + let value = value_exists.then_some(value); + Self { bits, value } + } +} + +impl ToBytesGadget + for UInt +{ + #[tracing::instrument(target = "r1cs", skip(self))] + fn to_bytes(&self) -> Result>, SynthesisError> { + Ok(self + .to_bits_le() + .chunks(8) + .map(UInt8::from_bits_le) + .collect()) + } +} diff --git a/src/bits/uint/mod.rs b/src/bits/uint/mod.rs index ed33881a..58a45982 100644 --- a/src/bits/uint/mod.rs +++ b/src/bits/uint/mod.rs @@ -1,4 +1,4 @@ -use ark_ff::{BigInteger, Field, One, PrimeField, Zero}; +use ark_ff::{Field, One, PrimeField, Zero}; use core::{borrow::Borrow, convert::TryFrom, fmt::Debug}; use num_bigint::BigUint; use num_traits::{NumCast, PrimInt}; @@ -9,16 +9,17 @@ use ark_relations::r1cs::{ use crate::{ boolean::{AllocatedBool, Boolean}, - fields::fp::FpVar, prelude::*, Assignment, Vec, }; mod and; mod cmp; +mod convert; mod eq; mod not; mod or; +mod rotate; mod xor; #[cfg(test)] @@ -124,109 +125,6 @@ impl UInt { Ok(output_vec) } - /// Turns `self` into the underlying little-endian bits. - pub fn to_bits_le(&self) -> Vec> { - self.bits.to_vec() - } - - /// Converts a little-endian byte order representation of bits into a - /// `UInt`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let var = UInt8::new_witness(cs.clone(), || Ok(128))?; - /// - /// let f = Boolean::FALSE; - /// let t = Boolean::TRUE; - /// - /// // Construct [0, 0, 0, 0, 0, 0, 0, 1] - /// let mut bits = vec![f.clone(); 7]; - /// bits.push(t); - /// - /// let mut c = UInt8::from_bits_le(&bits); - /// var.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), N); - let bits = <&[Boolean; N]>::try_from(bits).unwrap().clone(); - let value_exists = bits.iter().all(|b| b.value().is_ok()); - let mut value = T::zero(); - for (i, b) in bits.iter().enumerate() { - if let Ok(b) = b.value() { - value = value + (T::from(b as u8).unwrap() << i); - } - } - let value = value_exists.then_some(value); - Self { bits, value } - } - - /// Rotates `self` to the right by `by` steps, wrapping around. - #[tracing::instrument(target = "r1cs", skip(self))] - pub fn rotate_right(&self, by: usize) -> Self { - let mut result = self.clone(); - let by = by % N; - - let new_bits = self.bits.iter().skip(by).chain(&self.bits).take(N); - - for (res, new) in result.bits.iter_mut().zip(new_bits) { - *res = new.clone(); - } - result.value = self.value.map(|v| v.rotate_right(by as u32)); - result - } - - /// Converts `self` into a field element. The elements comprising `self` are - /// interpreted as a little-endian bit order representation of a field element. - /// - /// # Panics - /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. - pub fn to_fp(&self) -> Result, SynthesisError> - where - F: PrimeField, - { - assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); - - Boolean::le_bits_to_fp_var(&self.bits) - } - - /// Converts a field element into its little-endian bit order representation. - /// - /// # Panics - /// - /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. - pub fn from_fp(other: &FpVar) -> Result - where - F: PrimeField, - { - assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); - let value = other.value()?.into_bigint().to_bits_le(); - let cs = other.cs(); - let mode = if other.is_constant() { - AllocationMode::Constant - } else { - AllocationMode::Witness - }; - let lower_bits = value - .iter() - .take(N) - .map(|b| Boolean::new_variable(cs.clone(), || Ok(*b), mode)) - .collect::, _>>()?; - let result = Self::from_bits_le(&lower_bits); - let rest: FpVar = other - &result.to_fp()?; - rest.enforce_equal(&FpVar::zero())?; - Ok(result) - } - /// Perform modular addition of `operands`. /// /// The user must ensure that overflow does not occur. @@ -353,19 +251,6 @@ impl UInt { } } -impl ToBytesGadget - for UInt -{ - #[tracing::instrument(target = "r1cs", skip(self))] - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self - .to_bits_le() - .chunks(8) - .map(UInt8::from_bits_le) - .collect()) - } -} - impl CondSelectGadget for UInt { diff --git a/src/bits/uint/rotate.rs b/src/bits/uint/rotate.rs new file mode 100644 index 00000000..f1355328 --- /dev/null +++ b/src/bits/uint/rotate.rs @@ -0,0 +1,174 @@ +use super::*; + +impl UInt { + /// Rotates `self` to the right by `by` steps, wrapping around. + /// + /// # Examples + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?; + /// let b = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?; + /// + /// a.rotate_right(8)?.enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn rotate_right(&self, by: usize) -> Self { + let by = by % N; + let mut result = self.clone(); + // `[T]::rotate_left` corresponds to a `rotate_right` of the bits. + result.bits.rotate_left(by); + result.value = self.value.map(|v| v.rotate_right(by as u32)); + result + } + + /// Rotates `self` to the left by `by` steps, wrapping around. + /// + /// # Examples + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?; + /// let b = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?; + /// + /// a.rotate_left(8)?.enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn rotate_left(&self, by: usize) -> Self { + let by = by % N; + let mut result = self.clone(); + // `[T]::rotate_right` corresponds to a `rotate_left` of the bits. + result.bits.rotate_right(by); + result.value = self.value.map(|v| v.rotate_left(by as u32)); + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_unary_exhaustive, run_unary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_rotate_left( + a: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for shift in 0..N { + let computed = a.rotate_left(shift); + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value().unwrap().rotate_left(shift as u32)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + } + + fn uint_rotate_right( + a: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for shift in 0..N { + let computed = a.rotate_right(shift); + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value().unwrap().rotate_right(shift as u32)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&expected)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + } + + #[test] + fn u8_rotate_left() { + run_unary_exhaustive(uint_rotate_left::).unwrap() + } + + #[test] + fn u16_rotate_left() { + run_unary_random::<1000, 16, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u32_rotate_left() { + run_unary_random::<1000, 32, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u64_rotate_left() { + run_unary_random::<200, 64, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u128_rotate_left() { + run_unary_random::<100, 128, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u8_rotate_right() { + run_unary_exhaustive(uint_rotate_right::).unwrap() + } + + #[test] + fn u16_rotate_right() { + run_unary_random::<1000, 16, _, _>(uint_rotate_right::).unwrap() + } + + #[test] + fn u32_rotate_right() { + run_unary_random::<1000, 32, _, _>(uint_rotate_right::).unwrap() + } + + #[test] + fn u64_rotate_right() { + run_unary_random::<200, 64, _, _>(uint_rotate_right::).unwrap() + } + + #[test] + fn u128_rotate_right() { + run_unary_random::<100, 128, _, _>(uint_rotate_right::).unwrap() + } +} From 12d6371b40a75b1381a5e983952aaa10eb48386e Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 10 May 2023 20:58:43 -0700 Subject: [PATCH 09/29] Incorporate fix --- src/bits/uint/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bits/uint/mod.rs b/src/bits/uint/mod.rs index 58a45982..ca3b2d3d 100644 --- a/src/bits/uint/mod.rs +++ b/src/bits/uint/mod.rs @@ -138,7 +138,7 @@ impl UInt { assert!(F::MODULUS_BIT_SIZE as usize >= 2 * N); assert!(operands.len() >= 1); - assert!(N * operands.len() <= F::MODULUS_BIT_SIZE as usize); + assert!(N as u32 + ark_std::log2(operands.len()) <= F::MODULUS_BIT_SIZE); if operands.len() == 1 { return Ok(operands[0].clone()); From dda5aef4db1f697ee29435634280a3fdcb793d78 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 10 May 2023 21:05:12 -0700 Subject: [PATCH 10/29] Fix `no-std` --- src/bits/uint/and.rs | 1 - src/bits/uint/eq.rs | 2 +- src/bits/uint/not.rs | 1 - src/bits/uint/or.rs | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/bits/uint/and.rs b/src/bits/uint/and.rs index 154576b6..b2b7b0d6 100644 --- a/src/bits/uint/and.rs +++ b/src/bits/uint/and.rs @@ -12,7 +12,6 @@ impl UInt { *a &= b; } result.value = self.value.and_then(|a| Some(a & other.value?)); - dbg!(result.value); Ok(result) } } diff --git a/src/bits/uint/eq.rs b/src/bits/uint/eq.rs index 94a1e0c1..d1c475cf 100644 --- a/src/bits/uint/eq.rs +++ b/src/bits/uint/eq.rs @@ -1,6 +1,6 @@ use ark_ff::PrimeField; use ark_relations::r1cs::SynthesisError; -use ark_std::fmt::Debug; +use ark_std::{fmt::Debug, vec::Vec}; use num_traits::PrimInt; diff --git a/src/bits/uint/not.rs b/src/bits/uint/not.rs index e35ec93e..af0fe3e2 100644 --- a/src/bits/uint/not.rs +++ b/src/bits/uint/not.rs @@ -12,7 +12,6 @@ impl UInt { *a = !&*a } result.value = self.value.map(Not::not); - dbg!(result.value); Ok(result) } } diff --git a/src/bits/uint/or.rs b/src/bits/uint/or.rs index beeb897e..25167037 100644 --- a/src/bits/uint/or.rs +++ b/src/bits/uint/or.rs @@ -12,7 +12,6 @@ impl UInt { *a |= b; } result.value = self.value.and_then(|a| Some(a | other.value?)); - dbg!(result.value); Ok(result) } } From 700caee7c90b70e87bbac68f76957c28d0539bee Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 11 May 2023 08:59:50 -0700 Subject: [PATCH 11/29] Fix tests --- src/bits/boolean/and.rs | 9 +-- src/bits/boolean/eq.rs | 150 +++++++++++++++++++++++++++++++++++----- src/bits/boolean/not.rs | 5 +- src/bits/boolean/or.rs | 9 +-- src/bits/boolean/xor.rs | 9 +-- src/bits/uint/and.rs | 4 +- src/bits/uint/cmp.rs | 36 ++++------ src/bits/uint/eq.rs | 56 +++++++++++++-- src/bits/uint/not.rs | 4 +- src/bits/uint/or.rs | 4 +- src/bits/uint/rotate.rs | 8 +-- src/bits/uint/xor.rs | 4 +- src/bits/uint8.rs | 8 +-- src/eq.rs | 2 +- 14 files changed, 223 insertions(+), 85 deletions(-) diff --git a/src/bits/boolean/and.rs b/src/bits/boolean/and.rs index 7c3263a3..9b498e5b 100644 --- a/src/bits/boolean/and.rs +++ b/src/bits/boolean/and.rs @@ -119,13 +119,10 @@ mod tests { } else { AllocationMode::Witness }; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() & b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? & b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/boolean/eq.rs b/src/bits/boolean/eq.rs index 5f7d30fc..43bc7da2 100644 --- a/src/bits/boolean/eq.rs +++ b/src/bits/boolean/eq.rs @@ -25,11 +25,14 @@ impl EqGadget for Boolean { ) -> Result<(), SynthesisError> { use Boolean::*; let one = Variable::One; + // We will use the following trick: a == b <=> a - b == 0 + // This works because a - b == 0 if and only if a = 0 and b = 0, or a = 1 and b = 1, + // which is exactly the definition of a == b. let difference = match (self, other) { // 1 == 1; 0 == 0 (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), // false != true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), + (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable), // 1 - a (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc!() + one - a.variable(), // a - 0 = a @@ -53,22 +56,25 @@ impl EqGadget for Boolean { ) -> Result<(), SynthesisError> { use Boolean::*; let one = Variable::One; - let difference = match (self, other) { + // We will use the following trick: a != b <=> a + b == 1 + // This works because a + b == 1 if and only if a = 0 and b = 1, or a = 1 and b = 0, + // which is exactly the definition of a != b. + let sum = match (self, other) { // 1 != 0; 0 != 1 (Constant(true), Constant(false)) | (Constant(false), Constant(true)) => return Ok(()), // false == false and true == true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), - // 1 - a - (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc!() + one - a.variable(), - // a - 0 = a + (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable), + // 1 + a + (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc!() + one + a.variable(), + // a + 0 = a (Constant(false), Var(a)) | (Var(a), Constant(false)) => lc!() + a.variable(), - // b - a, - (Var(a), Var(b)) => lc!() + b.variable() - a.variable(), + // b + a, + (Var(a), Var(b)) => lc!() + b.variable() + a.variable(), }; if should_enforce != &Constant(false) { let cs = self.cs().or(other.cs()).or(should_enforce.cs()); - cs.enforce_constraint(difference, should_enforce.lc(), should_enforce.lc())?; + cs.enforce_constraint(sum, should_enforce.lc(), lc!() + one)?; } Ok(()) } @@ -79,14 +85,14 @@ mod tests { use super::*; use crate::{ alloc::{AllocVar, AllocationMode}, - boolean::test_utils::run_binary_exhaustive, + boolean::test_utils::{run_binary_exhaustive, run_unary_exhaustive}, prelude::EqGadget, R1CSVar, }; use ark_test_curves::bls12_381::Fr; #[test] - fn or() { + fn eq() { run_binary_exhaustive::(|a, b| { let cs = a.cs().or(b.cs()); let both_constant = a.is_constant() && b.is_constant(); @@ -96,13 +102,123 @@ mod tests { } else { AllocationMode::Witness }; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() == b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn neq() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_neq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn neq_and_eq_consistency() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let is_neq = &a.is_neq(&b)?; + let is_eq = &a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected_is_neq = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; + assert_eq!(expected_is_neq.value(), is_neq.value()); + assert_ne!(expected_is_neq.value(), is_eq.value()); + expected_is_neq.enforce_equal(is_neq)?; + expected_is_neq.enforce_equal(&!is_eq)?; + expected_is_neq.enforce_not_equal(is_eq)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn enforce_eq_and_enforce_neq_consistency() { + run_unary_exhaustive::(|a| { + let cs = a.cs(); + let not_a = !&a; + a.enforce_equal(&a)?; + not_a.enforce_equal(¬_a)?; + a.enforce_not_equal(¬_a)?; + not_a.enforce_not_equal(&a)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn eq_soundness() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; + assert_ne!(expected.value(), computed.value()); + expected.enforce_not_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn neq_soundness() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_neq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?; + assert_ne!(expected.value(), computed.value()); + expected.enforce_not_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/boolean/not.rs b/src/bits/boolean/not.rs index 97a71468..ebbce715 100644 --- a/src/bits/boolean/not.rs +++ b/src/bits/boolean/not.rs @@ -85,10 +85,9 @@ mod tests { } else { AllocationMode::Witness }; - let expected = - Boolean::new_variable(cs.clone(), || Ok(!a.value().unwrap()), expected_mode)?; + let expected = Boolean::new_variable(cs.clone(), || Ok(!a.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !a.is_constant() { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/boolean/or.rs b/src/bits/boolean/or.rs index fa7396d9..5898d8a5 100644 --- a/src/bits/boolean/or.rs +++ b/src/bits/boolean/or.rs @@ -118,13 +118,10 @@ mod tests { } else { AllocationMode::Witness }; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() | b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? | b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/boolean/xor.rs b/src/bits/boolean/xor.rs index 4540b2e6..ff880460 100644 --- a/src/bits/boolean/xor.rs +++ b/src/bits/boolean/xor.rs @@ -118,13 +118,10 @@ mod tests { } else { AllocationMode::Witness }; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() ^ b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? ^ b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/uint/and.rs b/src/bits/uint/and.rs index b2b7b0d6..3466f36c 100644 --- a/src/bits/uint/and.rs +++ b/src/bits/uint/and.rs @@ -226,11 +226,11 @@ mod tests { }; let expected = UInt::::new_variable( cs.clone(), - || Ok(a.value().unwrap() & b.value().unwrap()), + || Ok(a.value()? & b.value()?), expected_mode, )?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/uint/cmp.rs b/src/bits/uint/cmp.rs index 8736eca7..a138ab88 100644 --- a/src/bits/uint/cmp.rs +++ b/src/bits/uint/cmp.rs @@ -50,13 +50,10 @@ mod tests { AllocationMode::Witness }; let computed = a.is_gt(&b)?; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() > b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? > b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } @@ -75,13 +72,10 @@ mod tests { AllocationMode::Witness }; let computed = a.is_lt(&b)?; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() < b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? < b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } @@ -100,13 +94,10 @@ mod tests { AllocationMode::Witness }; let computed = a.is_ge(&b)?; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() >= b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? >= b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } @@ -125,13 +116,10 @@ mod tests { AllocationMode::Witness }; let computed = a.is_le(&b)?; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() <= b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? <= b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/uint/eq.rs b/src/bits/uint/eq.rs index d1c475cf..a8a3065c 100644 --- a/src/bits/uint/eq.rs +++ b/src/bits/uint/eq.rs @@ -91,13 +91,32 @@ mod tests { } else { AllocationMode::Witness }; - let expected = Boolean::new_variable( - cs.clone(), - || Ok(a.value().unwrap() == b.value().unwrap()), - expected_mode, - )?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_neq( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.is_neq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } @@ -128,4 +147,29 @@ mod tests { fn u128_eq() { run_binary_random::<1000, 128, _, _>(uint_eq::).unwrap() } + + #[test] + fn u8_neq() { + run_binary_exhaustive(uint_neq::).unwrap() + } + + #[test] + fn u16_neq() { + run_binary_random::<1000, 16, _, _>(uint_neq::).unwrap() + } + + #[test] + fn u32_neq() { + run_binary_random::<1000, 32, _, _>(uint_neq::).unwrap() + } + + #[test] + fn u64_neq() { + run_binary_random::<1000, 64, _, _>(uint_neq::).unwrap() + } + + #[test] + fn u128_neq() { + run_binary_random::<1000, 128, _, _>(uint_neq::).unwrap() + } } diff --git a/src/bits/uint/not.rs b/src/bits/uint/not.rs index af0fe3e2..db74cdb6 100644 --- a/src/bits/uint/not.rs +++ b/src/bits/uint/not.rs @@ -96,9 +96,9 @@ mod tests { AllocationMode::Witness }; let expected = - UInt::::new_variable(cs.clone(), || Ok(!a.value().unwrap()), expected_mode)?; + UInt::::new_variable(cs.clone(), || Ok(!a.value()?), expected_mode)?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !a.is_constant() { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/uint/or.rs b/src/bits/uint/or.rs index 25167037..3d0ae90e 100644 --- a/src/bits/uint/or.rs +++ b/src/bits/uint/or.rs @@ -222,11 +222,11 @@ mod tests { }; let expected = UInt::::new_variable( cs.clone(), - || Ok(a.value().unwrap() | b.value().unwrap()), + || Ok(a.value()? | b.value()?), expected_mode, )?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/uint/rotate.rs b/src/bits/uint/rotate.rs index f1355328..0e4333ec 100644 --- a/src/bits/uint/rotate.rs +++ b/src/bits/uint/rotate.rs @@ -85,11 +85,11 @@ mod tests { let computed = a.rotate_left(shift); let expected = UInt::::new_variable( cs.clone(), - || Ok(a.value().unwrap().rotate_left(shift as u32)), + || Ok(a.value()?.rotate_left(shift as u32)), expected_mode, )?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !a.is_constant() { assert!(cs.is_satisfied().unwrap()); } @@ -110,11 +110,11 @@ mod tests { let computed = a.rotate_right(shift); let expected = UInt::::new_variable( cs.clone(), - || Ok(a.value().unwrap().rotate_right(shift as u32)), + || Ok(a.value()?.rotate_right(shift as u32)), expected_mode, )?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !a.is_constant() { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/uint/xor.rs b/src/bits/uint/xor.rs index f2181ec4..109c0e0d 100644 --- a/src/bits/uint/xor.rs +++ b/src/bits/uint/xor.rs @@ -222,11 +222,11 @@ mod tests { }; let expected = UInt::::new_variable( cs.clone(), - || Ok(a.value().unwrap() ^ b.value().unwrap()), + || Ok(a.value()? ^ b.value()?), expected_mode, )?; assert_eq!(expected.value(), computed.value()); - expected.enforce_equal(&expected)?; + expected.enforce_equal(&computed)?; if !both_constant { assert!(cs.is_satisfied().unwrap()); } diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index 282d7b89..6474006e 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -191,9 +191,9 @@ mod test { let a_bit = UInt8::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a)).unwrap(); let b_bit = UInt8::constant(b); let c_bit = UInt8::new_witness(ark_relations::ns!(cs, "c_bit"), || Ok(c)).unwrap(); - dbg!(a_bit.value().unwrap()); - dbg!(b_bit.value().unwrap()); - dbg!(c_bit.value().unwrap()); + dbg!(a_bit.value()?); + dbg!(b_bit.value()?); + dbg!(c_bit.value()?); let mut r = a_bit ^ b_bit; r ^= &c_bit; @@ -201,7 +201,7 @@ mod test { assert!(cs.is_satisfied().unwrap()); dbg!(expected); - dbg!(r.value().unwrap()); + dbg!(r.value()?); assert_eq!(r.value, Some(expected)); for b in r.bits.iter() { diff --git a/src/eq.rs b/src/eq.rs index 7cdf91ca..71f10c3d 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -116,7 +116,7 @@ impl + R1CSVar, F: Field> EqGadget for [T] { assert_eq!(self.len(), other.len()); let some_are_different = self.is_neq(other)?; if [&some_are_different, should_enforce].is_constant() { - assert!(some_are_different.value().unwrap()); + assert!(some_are_different.value()?); Ok(()) } else { let cs = [&some_are_different, should_enforce].cs(); From 983bbab0fe6510ecbebbec9b1323d7c1d516c8bc Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 17 May 2023 17:19:43 -0400 Subject: [PATCH 12/29] `add_many` -> `wrapping_add` + fixes --- src/bits/boolean/mod.rs | 6 +- src/bits/uint/add.rs | 170 ++++++++++++++++++++++++++++++++ src/bits/uint/cmp.rs | 3 +- src/bits/uint/convert.rs | 37 ++----- src/bits/uint/eq.rs | 12 +-- src/bits/uint/mod.rs | 208 +-------------------------------------- src/bits/uint8.rs | 2 +- src/fields/fp/mod.rs | 36 ++++++- 8 files changed, 227 insertions(+), 247 deletions(-) diff --git a/src/bits/boolean/mod.rs b/src/bits/boolean/mod.rs index ee0c3eb0..0b67d54b 100644 --- a/src/bits/boolean/mod.rs +++ b/src/bits/boolean/mod.rs @@ -456,7 +456,7 @@ impl Boolean { /// Convert a little-endian bitwise representation of a field element to /// `FpVar` #[tracing::instrument(target = "r1cs", skip(bits))] - pub fn le_bits_to_fp_var(bits: &[Self]) -> Result, SynthesisError> + pub fn le_bits_to_fp(bits: &[Self]) -> Result, SynthesisError> where F: PrimeField, { @@ -1003,7 +1003,7 @@ mod test { let bits: Vec<_> = AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; - let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; + let claimed_f = Boolean::le_bits_to_fp(&bits)?; claimed_f.enforce_equal(&f)?; } @@ -1013,7 +1013,7 @@ mod test { let bits: Vec<_> = AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; - let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; + let claimed_f = Boolean::le_bits_to_fp(&bits)?; claimed_f.enforce_equal(&f)?; } assert!(cs.is_satisfied().unwrap()); diff --git a/src/bits/uint/add.rs b/src/bits/uint/add.rs index e69de29b..b6565bd1 100644 --- a/src/bits/uint/add.rs +++ b/src/bits/uint/add.rs @@ -0,0 +1,170 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::fmt::Debug; + +use num_traits::{PrimInt, WrappingAdd}; + +use crate::{boolean::Boolean, fields::fp::FpVar, R1CSVar}; + +use super::UInt; + +impl UInt { + /// Compute `*self = self.wrapping_add(other)`. + pub fn wrapping_add_in_place(&mut self, other: &Self) { + let result = Self::wrapping_add_many(&[self.clone(), other.clone()]).unwrap(); + *self = result; + } + + /// Compute `self.wrapping_add(other)`. + pub fn wrapping_add(&self, other: &Self) -> Self { + let mut result = self.clone(); + result.wrapping_add_in_place(other); + result + } + + /// Perform wrapping addition of `operands`. + /// Computes `operands[0].wrapping_add(operands[1]).wrapping_add(operands[2])...`. + /// + /// The user must ensure that overflow does not occur. + #[tracing::instrument(target = "r1cs", skip(operands))] + pub fn wrapping_add_many(operands: &[Self]) -> Result + where + F: PrimeField, + { + // Bounds on `N` to avoid overflows + + assert!(operands.len() >= 1); + let max_value_size = N as u32 + ark_std::log2(operands.len()); + assert!(max_value_size <= F::MODULUS_BIT_SIZE); + + if operands.len() == 1 { + return Ok(operands[0].clone()); + } + + // Compute the value of the result. + let mut value = Some(T::zero()); + for op in operands { + value = value.and_then(|v| Some(v.wrapping_add(&op.value?))); + } + if operands.is_constant() { + return Ok(UInt::constant(value.unwrap())); + } + + // Compute the full (non-wrapped) sum of the operands. + let result = operands + .iter() + .map(|op| Boolean::le_bits_to_fp(&op.bits).unwrap()) + .sum::>(); + let (mut result_bits, _) = result.to_bits_le_with_top_bits_zero(max_value_size as usize)?; + // Discard any carry bits, since these will get discarded by wrapping. + result_bits.truncate(N); + let bits = TryFrom::try_from(result_bits).unwrap(); + + Ok(UInt { bits, value }) + } +} + +// #[test] +// fn test_add_many() -> Result<(), SynthesisError> { +// let mut rng = ark_std::test_rng(); + +// for _ in 0..1000 { +// let cs = ConstraintSystem::::new_ref(); + +// let a: $native = rng.gen(); +// let b: $native = rng.gen(); +// let c: $native = rng.gen(); +// let d: $native = rng.gen(); + +// let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); + +// let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; +// let b_bit = $name::constant(b); +// let c_bit = $name::constant(c); +// let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; + +// let r = a_bit.xor(&b_bit).unwrap(); +// let r = $name::add_many(&[r, c_bit, d_bit]).unwrap(); + +// assert!(cs.is_satisfied().unwrap()); +// assert!(r.value == Some(expected)); + +// for b in r.bits.iter() { +// match b { +// Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), +// Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), +// Boolean::Constant(_) => unreachable!(), +// } + +// expected >>= 1; +// } +// } +// Ok(()) +// } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_wrapping_add( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.wrapping_add(&b); + let _ = dbg!(a.value()); + let _ = dbg!(b.value()); + dbg!(a.is_constant()); + dbg!(b.is_constant()); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::new_variable( + cs.clone(), + || Ok(a.value()?.wrapping_add(&b.value()?)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_wrapping_add() { + run_binary_exhaustive(uint_wrapping_add::).unwrap() + } + + #[test] + fn u16_wrapping_add() { + run_binary_random::<1000, 16, _, _>(uint_wrapping_add::).unwrap() + } + + #[test] + fn u32_wrapping_add() { + run_binary_random::<1000, 32, _, _>(uint_wrapping_add::).unwrap() + } + + #[test] + fn u64_wrapping_add() { + run_binary_random::<1000, 64, _, _>(uint_wrapping_add::).unwrap() + } + + #[test] + fn u128_wrapping_add() { + run_binary_random::<1000, 128, _, _>(uint_wrapping_add::).unwrap() + } +} diff --git a/src/bits/uint/cmp.rs b/src/bits/uint/cmp.rs index a138ab88..d2c16d75 100644 --- a/src/bits/uint/cmp.rs +++ b/src/bits/uint/cmp.rs @@ -1,6 +1,7 @@ use crate::fields::fp::FpVar; use super::*; + impl UInt { pub fn is_gt(&self, other: &Self) -> Result, SynthesisError> { other.is_lt(self) @@ -10,7 +11,7 @@ impl UInt { if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) { let a = self.to_fp()?; let b = other.to_fp()?; - let (_, rest) = Self::from_fp_to_parts(&(a - b))?; + let (_, rest) = FpVar::to_bits_le_with_top_bits_zero(&(a - b), N + 1)?; rest.is_eq(&FpVar::zero()) } else { unimplemented!("bit sizes larger than modulus size not yet supported") diff --git a/src/bits/uint/convert.rs b/src/bits/uint/convert.rs index d107ecdb..3308a145 100644 --- a/src/bits/uint/convert.rs +++ b/src/bits/uint/convert.rs @@ -1,5 +1,3 @@ -use ark_ff::BigInteger; - use crate::fields::fp::FpVar; use super::*; @@ -16,44 +14,21 @@ impl UInt { { assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); - Boolean::le_bits_to_fp_var(&self.bits) - } - - pub(super) fn from_fp_to_parts(other: &FpVar) -> Result<(Self, FpVar), SynthesisError> - where - F: PrimeField, - { - assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); - let value = other.value()?.into_bigint().to_bits_le(); - let cs = other.cs(); - let mode = if other.is_constant() { - AllocationMode::Constant - } else { - AllocationMode::Witness - }; - let lower_bits = value - .iter() - .take(N) - .map(|b| Boolean::new_variable(cs.clone(), || Ok(*b), mode)) - .collect::, _>>()?; - let result = Self::from_bits_le(&lower_bits); - let rest: FpVar = other - &result.to_fp()?; - (result.to_fp()? + &rest).enforce_equal(&other)?; - Ok((result, rest)) + Boolean::le_bits_to_fp(&self.bits) } /// Converts a field element into its little-endian bit order representation. /// /// # Panics /// - /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. - pub fn from_fp(other: &FpVar) -> Result + /// Assumes that `N` is at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. + pub fn from_fp(other: &FpVar) -> Result<(Self, FpVar), SynthesisError> where F: PrimeField, { - let (result, rest) = Self::from_fp_to_parts(other)?; - rest.enforce_equal(&FpVar::zero())?; - Ok(result) + let (bits, rest) = other.to_bits_le_with_top_bits_zero(N)?; + let result = Self::from_bits_le(&bits); + Ok((result, rest)) } /// Turns `self` into the underlying little-endian bits. diff --git a/src/bits/uint/eq.rs b/src/bits/uint/eq.rs index a8a3065c..472bcef4 100644 --- a/src/bits/uint/eq.rs +++ b/src/bits/uint/eq.rs @@ -20,8 +20,8 @@ impl EqGadget, _>>()?; @@ -40,8 +40,8 @@ impl EqGadget EqGadget UInt { } Ok(output_vec) } - - /// Perform modular addition of `operands`. - /// - /// The user must ensure that overflow does not occur. - #[tracing::instrument(target = "r1cs", skip(operands))] - pub fn add_many(operands: &[Self]) -> Result - where - F: PrimeField, - { - // Make some arbitrary bounds for ourselves to avoid overflows - // in the scalar field - assert!(F::MODULUS_BIT_SIZE as usize >= 2 * N); - - assert!(operands.len() >= 1); - assert!(N as u32 + ark_std::log2(operands.len()) <= F::MODULUS_BIT_SIZE); - - if operands.len() == 1 { - return Ok(operands[0].clone()); - } - - // Compute the maximum value of the sum so we allocate enough bits for - // the result - let mut max_value = T::max_value() * T::from(operands.len() as u128).unwrap(); - - // Keep track of the resulting value - let mut result_value = Some(BigUint::zero()); - - // This is a linear combination that we will enforce to be "zero" - let mut lc = LinearCombination::zero(); - - let mut all_constants = true; - - // Iterate over the operands - for op in operands { - // Accumulate the value - match op.value { - Some(val) => { - result_value - .as_mut() - .map(|v| *v += BigUint::from(val.to_u128().unwrap())); - }, - - None => { - // If any of our operands have unknown value, we won't - // know the value of the result - result_value = None; - }, - } - - // Iterate over each bit_gadget of the operand and add the operand to - // the linear combination - let mut coeff = F::one(); - for bit in &op.bits { - match *bit { - Boolean::Constant(bit) => { - if bit { - lc += (coeff, Variable::One); - } - }, - _ => { - all_constants = false; - - // Add coeff * bit_gadget - lc = lc + (coeff, &bit.lc()); - }, - } - - coeff.double_in_place(); - } - } - - // The value of the actual result is modulo 2^$size - let modular_value = result_value.clone().and_then(|v| { - let modulus = BigUint::from(1u64) << (N as u32); - NumCast::from(v % modulus) - }); - - if all_constants && modular_value.is_some() { - // We can just return a constant, rather than - // unpacking the result into allocated bits. - - return modular_value - .map(UInt::constant) - .ok_or(SynthesisError::AssignmentMissing); - } - let cs = operands.cs(); - - // Storage area for the resulting bits - let mut result_bits = vec![]; - - // Allocate each bit_gadget of the result - let mut coeff = F::one(); - let mut i = 0; - while max_value != T::zero() { - // Allocate the bit_gadget - let b = AllocatedBool::new_witness(cs.clone(), || { - result_value - .clone() - .map(|v| (v >> i) & BigUint::one() == BigUint::one()) - .get() - })?; - - // Subtract this bit_gadget from the linear combination to ensure the sums - // balance out - lc = lc - (coeff, b.variable()); - - result_bits.push(b.into()); - - max_value = max_value >> 1; - i += 1; - coeff.double_in_place(); - } - - // Enforce that the linear combination equals zero - cs.enforce_constraint(lc!(), lc!(), lc)?; - - // Discard carry bits that we don't care about - result_bits.truncate(N); - let bits = TryFrom::try_from(result_bits).unwrap(); - - Ok(UInt { - bits, - value: modular_value, - }) - } } impl CondSelectGadget @@ -421,70 +290,3 @@ impl AllocVar Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let cs = ConstraintSystem::::new_ref(); - -// let a: $native = rng.gen(); -// let b: $native = rng.gen(); -// let c: $native = rng.gen(); -// let d: $native = rng.gen(); - -// let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - -// let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; -// let b_bit = $name::constant(b); -// let c_bit = $name::constant(c); -// let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; - -// let r = a_bit.xor(&b_bit).unwrap(); -// let r = $name::add_many(&[r, c_bit, d_bit]).unwrap(); - -// assert!(cs.is_satisfied().unwrap()); -// assert!(r.value == Some(expected)); - -// for b in r.bits.iter() { -// match b { -// Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), -// Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), -// Boolean::Constant(_) => unreachable!(), -// } - -// expected >>= 1; -// } -// } -// Ok(()) -// } - -// #[test] -// fn test_rotr() -> Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// let mut num = rng.gen(); - -// let a: $name = $name::constant(num); - -// for i in 0..$size { -// let b = a.rotr(i); - -// assert!(b.value.unwrap() == num); - -// let mut tmp = num; -// for b in &b.bits { -// match b { -// Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), -// _ => unreachable!(), -// } - -// tmp >>= 1; -// } - -// num = num.rotate_right(1); -// } -// Ok(()) -// } -// } diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index 6474006e..d54fd985 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -80,7 +80,7 @@ impl ToConstraintFieldGadget for [UInt8 Result>, SynthesisError> { let max_size = ((ConstraintF::MODULUS_BIT_SIZE - 1) / 8) as usize; self.chunks(max_size) - .map(|chunk| Boolean::le_bits_to_fp_var(chunk.to_bits_le()?.as_slice())) + .map(|chunk| Boolean::le_bits_to_fp(chunk.to_bits_le()?.as_slice())) .collect::, SynthesisError>>() } } diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index 878a1d4e..1527cace 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -49,6 +49,35 @@ pub enum FpVar { Var(AllocatedFp), } +impl FpVar { + /// Decomposes `self` into a vector of `bits` and a remainder `rest` such that + /// * `bits.len() == size`, and + /// * `rest == 0`. + pub fn to_bits_le_with_top_bits_zero( + &self, + size: usize, + ) -> Result<(Vec>, Self), SynthesisError> { + assert!(size <= F::MODULUS_BIT_SIZE as usize - 1); + let cs = self.cs(); + let mode = if self.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + + let value = self.value().map(|f| f.into_bigint()); + let lower_bits = (0..size) + .map(|i| { + Boolean::new_variable(cs.clone(), || value.map(|v| v.get_bit(i as usize)), mode) + }) + .collect::, _>>()?; + let lower_bits_fp = Boolean::le_bits_to_fp(&lower_bits)?; + let rest = self - &lower_bits_fp; + rest.enforce_equal(&Self::zero())?; + Ok((lower_bits, rest)) + } +} + impl R1CSVar for FpVar { type Value = F; @@ -135,6 +164,7 @@ impl AllocatedFp { let mut value = F::zero(); let mut new_lc = lc!(); + let mut num_iters = 0; for variable in iter { let variable = variable.borrow(); if !variable.cs.is_none() { @@ -146,14 +176,16 @@ impl AllocatedFp { value += variable.value.unwrap(); } new_lc = new_lc + variable.variable; + num_iters += 1; } + assert_ne!(num_iters, 0); let variable = cs.new_lc(new_lc).unwrap(); if has_value { - AllocatedFp::new(Some(value), variable, cs.clone()) + AllocatedFp::new(Some(value), variable, cs) } else { - AllocatedFp::new(None, variable, cs.clone()) + AllocatedFp::new(None, variable, cs) } } From c5d8cf6263607777a002accf10d4b40c43c6c039 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 18 May 2023 09:54:32 -0400 Subject: [PATCH 13/29] Fix --- src/bits/uint/cmp.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/bits/uint/cmp.rs b/src/bits/uint/cmp.rs index d2c16d75..e0b5dfc5 100644 --- a/src/bits/uint/cmp.rs +++ b/src/bits/uint/cmp.rs @@ -1,8 +1,6 @@ -use crate::fields::fp::FpVar; - use super::*; -impl UInt { +impl> UInt { pub fn is_gt(&self, other: &Self) -> Result, SynthesisError> { other.is_lt(self) } @@ -11,8 +9,8 @@ impl UInt { if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) { let a = self.to_fp()?; let b = other.to_fp()?; - let (_, rest) = FpVar::to_bits_le_with_top_bits_zero(&(a - b), N + 1)?; - rest.is_eq(&FpVar::zero()) + let (bits, _) = (a - b + F::from(T::max_value()) + F::one()).to_bits_le_with_top_bits_zero(N + 1)?; + Ok(bits.last().unwrap().clone()) } else { unimplemented!("bit sizes larger than modulus size not yet supported") } @@ -39,7 +37,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_gt( + fn uint_gt>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { @@ -61,7 +59,7 @@ mod tests { Ok(()) } - fn uint_lt( + fn uint_lt>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { @@ -83,7 +81,7 @@ mod tests { Ok(()) } - fn uint_ge( + fn uint_ge>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { @@ -105,7 +103,7 @@ mod tests { Ok(()) } - fn uint_le( + fn uint_le>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { From da374f97fe420fb957077e35b818e7e95f9a5d6c Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 18 May 2023 13:41:09 -0400 Subject: [PATCH 14/29] Move and test `select` --- src/bits/boolean/mod.rs | 56 +--------------------- src/bits/boolean/select.rs | 98 ++++++++++++++++++++++++++++++++++++++ src/bits/uint/cmp.rs | 3 +- src/bits/uint/mod.rs | 31 +----------- src/bits/uint/select.rs | 98 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 200 insertions(+), 86 deletions(-) create mode 100644 src/bits/boolean/select.rs create mode 100644 src/bits/uint/select.rs diff --git a/src/bits/boolean/mod.rs b/src/bits/boolean/mod.rs index 0b67d54b..cb867539 100644 --- a/src/bits/boolean/mod.rs +++ b/src/bits/boolean/mod.rs @@ -11,6 +11,7 @@ mod cmp; mod eq; mod not; mod or; +mod select; mod xor; #[cfg(test)] @@ -658,61 +659,6 @@ impl ToConstraintFieldGadget for Boolean { } } -impl CondSelectGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn conditionally_select( - cond: &Boolean, - true_val: &Self, - false_val: &Self, - ) -> Result { - use Boolean::*; - match cond { - Constant(true) => Ok(true_val.clone()), - Constant(false) => Ok(false_val.clone()), - cond @ Var(_) => match (true_val, false_val) { - (x, &Constant(false)) => Ok(cond & x), - (&Constant(false), x) => Ok((!cond) & x), - (&Constant(true), x) => Ok(cond | x), - (x, &Constant(true)) => Ok((!cond) | x), - (a, b) => { - let cs = cond.cs(); - let result: Boolean = - AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || { - let cond = cond.value()?; - Ok(if cond { a.value()? } else { b.value()? }) - })? - .into(); - // a = self; b = other; c = cond; - // - // r = c * a + (1 - c) * b - // r = b + c * (a - b) - // c * (a - b) = r - b - // - // If a, b, cond are all boolean, so is r. - // - // self | other | cond | result - // -----|-------|---------------- - // 0 | 0 | 1 | 0 - // 0 | 1 | 1 | 0 - // 1 | 0 | 1 | 1 - // 1 | 1 | 1 | 1 - // 0 | 0 | 0 | 0 - // 0 | 1 | 0 | 1 - // 1 | 0 | 0 | 0 - // 1 | 1 | 0 | 1 - cs.enforce_constraint( - cond.lc(), - lc!() + a.lc() - b.lc(), - lc!() + result.lc() - b.lc(), - )?; - - Ok(result) - }, - }, - } - } -} - #[cfg(test)] mod test { use super::{AllocatedBool, Boolean}; diff --git a/src/bits/boolean/select.rs b/src/bits/boolean/select.rs new file mode 100644 index 00000000..7549caa8 --- /dev/null +++ b/src/bits/boolean/select.rs @@ -0,0 +1,98 @@ +use super::*; + +impl CondSelectGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_val: &Self, + false_val: &Self, + ) -> Result { + use Boolean::*; + match cond { + Constant(true) => Ok(true_val.clone()), + Constant(false) => Ok(false_val.clone()), + cond @ Var(_) => match (true_val, false_val) { + (x, &Constant(false)) => Ok(cond & x), + (&Constant(false), x) => Ok((!cond) & x), + (&Constant(true), x) => Ok(cond | x), + (x, &Constant(true)) => Ok((!cond) | x), + (a, b) => { + let cs = cond.cs(); + let result: Boolean = + AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || { + let cond = cond.value()?; + Ok(if cond { a.value()? } else { b.value()? }) + })? + .into(); + // a = self; b = other; c = cond; + // + // r = c * a + (1 - c) * b + // r = b + c * (a - b) + // c * (a - b) = r - b + // + // If a, b, cond are all boolean, so is r. + // + // self | other | cond | result + // -----|-------|---------------- + // 0 | 0 | 1 | 0 + // 0 | 1 | 1 | 0 + // 1 | 0 | 1 | 1 + // 1 | 1 | 1 | 1 + // 0 | 0 | 0 | 0 + // 0 | 1 | 0 | 1 + // 1 | 0 | 0 | 0 + // 1 | 1 | 0 | 1 + cs.enforce_constraint( + cond.lc(), + lc!() + a.lc() - b.lc(), + lc!() + result.lc() - b.lc(), + )?; + + Ok(result) + }, + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn or() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for cond in [true, false] { + let expected = Boolean::new_variable( + cs.clone(), + || Ok(if cond { a.value()? } else { b.value()? }), + expected_mode, + )?; + let cond = Boolean::new_variable(cs.clone(), || Ok(cond), expected_mode)?; + let computed = cond.select(&a, &b)?; + + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/bits/uint/cmp.rs b/src/bits/uint/cmp.rs index e0b5dfc5..864c85b8 100644 --- a/src/bits/uint/cmp.rs +++ b/src/bits/uint/cmp.rs @@ -9,7 +9,8 @@ impl> UInt if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) { let a = self.to_fp()?; let b = other.to_fp()?; - let (bits, _) = (a - b + F::from(T::max_value()) + F::one()).to_bits_le_with_top_bits_zero(N + 1)?; + let (bits, _) = (a - b + F::from(T::max_value()) + F::one()) + .to_bits_le_with_top_bits_zero(N + 1)?; Ok(bits.last().unwrap().clone()) } else { unimplemented!("bit sizes larger than modulus size not yet supported") diff --git a/src/bits/uint/mod.rs b/src/bits/uint/mod.rs index b18145fd..a2effea1 100644 --- a/src/bits/uint/mod.rs +++ b/src/bits/uint/mod.rs @@ -14,6 +14,7 @@ mod eq; mod not; mod or; mod rotate; +mod select; mod xor; #[cfg(test)] @@ -120,36 +121,6 @@ impl UInt { } } -impl CondSelectGadget - for UInt -{ - #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] - fn conditionally_select( - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value - .bits - .iter() - .zip(&false_value.bits) - .map(|(t, f)| cond.select(t, f)); - let mut bits = [Boolean::FALSE; N]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; - } - - let value = cond.value().ok().and_then(|cond| { - if cond { - true_value.value().ok() - } else { - false_value.value().ok() - } - }); - Ok(Self { bits, value }) - } -} - impl AllocVar for UInt { diff --git a/src/bits/uint/select.rs b/src/bits/uint/select.rs new file mode 100644 index 00000000..e50f0730 --- /dev/null +++ b/src/bits/uint/select.rs @@ -0,0 +1,98 @@ +use super::*; +use crate::select::CondSelectGadget; + +impl CondSelectGadget + for UInt +{ + #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let selected_bits = true_value + .bits + .iter() + .zip(&false_value.bits) + .map(|(t, f)| cond.select(t, f)); + let mut bits = [Boolean::FALSE; N]; + for (result, new) in bits.iter_mut().zip(selected_bits) { + *result = new?; + } + + let value = cond.value().ok().and_then(|cond| { + if cond { + true_value.value().ok() + } else { + false_value.value().ok() + } + }); + Ok(Self { bits, value }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_select( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for cond in [true, false] { + let expected = UInt::new_variable( + cs.clone(), + || Ok(if cond { a.value()? } else { b.value()? }), + expected_mode, + )?; + let cond = Boolean::new_variable(cs.clone(), || Ok(cond), expected_mode)?; + let computed = cond.select(&a, &b)?; + + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + } + + #[test] + fn u8_select() { + run_binary_exhaustive(uint_select::).unwrap() + } + + #[test] + fn u16_select() { + run_binary_random::<1000, 16, _, _>(uint_select::).unwrap() + } + + #[test] + fn u32_select() { + run_binary_random::<1000, 32, _, _>(uint_select::).unwrap() + } + + #[test] + fn u64_select() { + run_binary_random::<1000, 64, _, _>(uint_select::).unwrap() + } + + #[test] + fn u128_select() { + run_binary_random::<1000, 128, _, _>(uint_select::).unwrap() + } +} From f6922f77eb75eaa5f38ccb006296b6b3a33a23eb Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 7 Jun 2023 10:26:02 -0700 Subject: [PATCH 15/29] Refactor booleans --- src/bits/boolean/allocated.rs | 330 +++++++++++++++++++++++ src/bits/boolean/and.rs | 86 +++++- src/bits/boolean/mod.rs | 480 +--------------------------------- src/bits/boolean/or.rs | 38 +++ src/bits/boolean/select.rs | 36 +++ src/eq.rs | 4 +- 6 files changed, 505 insertions(+), 469 deletions(-) create mode 100644 src/bits/boolean/allocated.rs diff --git a/src/bits/boolean/allocated.rs b/src/bits/boolean/allocated.rs new file mode 100644 index 00000000..4849ddce --- /dev/null +++ b/src/bits/boolean/allocated.rs @@ -0,0 +1,330 @@ +use core::borrow::Borrow; + +use ark_ff::Field; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError, Variable}; + +use crate::{ + alloc::{AllocVar, AllocationMode}, + select::CondSelectGadget, + Assignment, +}; + +use super::{bool_to_field, Boolean}; + +/// Represents a variable in the constraint system which is guaranteed +/// to be either zero or one. +/// +/// In general, one should prefer using `Boolean` instead of `AllocatedBool`, +/// as `Boolean` offers better support for constant values, and implements +/// more traits. +#[derive(Clone, Debug, Eq, PartialEq)] +#[must_use] +pub struct AllocatedBool { + pub(super) variable: Variable, + pub(super) cs: ConstraintSystemRef, +} + +impl AllocatedBool { + /// Get the assigned value for `self`. + pub fn value(&self) -> Result { + let value = self.cs.assigned_value(self.variable).get()?; + if value.is_zero() { + Ok(false) + } else if value.is_one() { + Ok(true) + } else { + unreachable!("Incorrect value assigned: {:?}", value); + } + } + + /// Get the R1CS variable for `self`. + pub fn variable(&self) -> Variable { + self.variable + } + + /// Allocate a witness variable without a booleanity check. + #[doc(hidden)] + pub fn new_witness_without_booleanity_check>( + cs: ConstraintSystemRef, + f: impl FnOnce() -> Result, + ) -> Result { + let variable = cs.new_witness_variable(|| f().map(bool_to_field))?; + Ok(Self { variable, cs }) + } + + /// Performs an XOR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn not(&self) -> Result { + let variable = self.cs.new_lc(lc!() + Variable::One - self.variable)?; + Ok(Self { + variable, + cs: self.cs.clone(), + }) + } + + /// Performs an XOR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn xor(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? ^ b.value()?) + })?; + + // Constrain (a + a) * (b) = (a + b - c) + // Given that a and b are boolean constrained, if they + // are equal, the only solution for c is 0, and if they + // are different, the only solution for c is 1. + // + // ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b) = c + // (1 - (a * b)) * (1 - ((1 - a) * (1 - b))) = c + // (1 - ab) * (1 - (1 - a - b + ab)) = c + // (1 - ab) * (a + b - ab) = c + // a + b - ab - (a^2)b - (b^2)a + (a^2)(b^2) = c + // a + b - ab - ab - ab + ab = c + // a + b - 2ab = c + // -2a * b = c - a - b + // 2a * b = a + b - c + // (a + a) * b = a + b - c + self.cs.enforce_constraint( + lc!() + self.variable + self.variable, + lc!() + b.variable, + lc!() + self.variable + b.variable - result.variable, + )?; + + Ok(result) + } + + /// Performs an AND operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn and(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? & b.value()?) + })?; + + // Constrain (a) * (b) = (c), ensuring c is 1 iff + // a AND b are both 1. + self.cs.enforce_constraint( + lc!() + self.variable, + lc!() + b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } + + /// Performs an OR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn or(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? | b.value()?) + })?; + + // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff + // a and b are both false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + Variable::One - self.variable, + lc!() + Variable::One - b.variable, + lc!() + Variable::One - result.variable, + )?; + + Ok(result) + } + + /// Calculates `a AND (NOT b)`. + #[tracing::instrument(target = "r1cs")] + pub fn and_not(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? & !b.value()?) + })?; + + // Constrain (a) * (1 - b) = (c), ensuring c is 1 iff + // a is true and b is false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + self.variable, + lc!() + Variable::One - b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } + + /// Calculates `(NOT a) AND (NOT b)`. + #[tracing::instrument(target = "r1cs")] + pub fn nor(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(!(self.value()? | b.value()?)) + })?; + + // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff + // a and b are both false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + Variable::One - self.variable, + lc!() + Variable::One - b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } +} + +impl AllocVar for AllocatedBool { + /// Produces a new variable of the appropriate kind + /// (instance or witness), with a booleanity check. + /// + /// N.B.: we could omit the booleanity check when allocating `self` + /// as a new public input, but that places an additional burden on + /// protocol designers. Better safe than sorry! + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + if mode == AllocationMode::Constant { + let variable = if *f()?.borrow() { + Variable::One + } else { + Variable::Zero + }; + Ok(Self { variable, cs }) + } else { + let variable = if mode == AllocationMode::Input { + cs.new_input_variable(|| f().map(bool_to_field))? + } else { + cs.new_witness_variable(|| f().map(bool_to_field))? + }; + + // Constrain: (1 - a) * a = 0 + // This constrains a to be either 0 or 1. + + cs.enforce_constraint(lc!() + Variable::One - variable, lc!() + variable, lc!())?; + + Ok(Self { variable, cs }) + } + } +} + +impl CondSelectGadget for AllocatedBool { + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_val: &Self, + false_val: &Self, + ) -> Result { + let res = Boolean::conditionally_select( + cond, + &true_val.clone().into(), + &false_val.clone().into(), + )?; + match res { + Boolean::Var(a) => Ok(a), + _ => unreachable!("Impossible"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + use ark_relations::r1cs::ConstraintSystem; + use ark_test_curves::bls12_381::Fr; + #[test] + fn allocated_xor() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::xor(&a, &b)?; + assert_eq!(c.value()?, a_val ^ b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val ^ b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_or() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::or(&a, &b)?; + assert_eq!(c.value()?, a_val | b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val | b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_and() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::and(&a, &b)?; + assert_eq!(c.value()?, a_val & b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val & b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_and_not() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::and_not(&a, &b)?; + assert_eq!(c.value()?, a_val & !b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val & !b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_nor() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::nor(&a, &b)?; + assert_eq!(c.value()?, !a_val & !b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (!a_val & !b_val)); + } + } + Ok(()) + } +} diff --git a/src/bits/boolean/and.rs b/src/bits/boolean/and.rs index 9b498e5b..61cc70ff 100644 --- a/src/bits/boolean/and.rs +++ b/src/bits/boolean/and.rs @@ -1,7 +1,9 @@ -use ark_ff::Field; +use ark_ff::{Field, PrimeField}; use ark_relations::r1cs::SynthesisError; use ark_std::{ops::BitAnd, ops::BitAndAssign}; +use crate::{fields::fp::FpVar, prelude::EqGadget}; + use super::Boolean; impl Boolean { @@ -17,6 +19,88 @@ impl Boolean { } } +impl Boolean { + /// Outputs `bits[0] & bits[1] & ... & bits.last().unwrap()`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// Boolean::kary_and(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// Boolean::kary_and(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_and(bits: &[Self]) -> Result { + assert!(!bits.is_empty()); + if bits.len() <= 3 { + let mut cur: Option = None; + for next in bits { + cur = if let Some(b) = cur { + Some(b & next) + } else { + Some(next.clone()) + }; + } + + Ok(cur.expect("should not be 0")) + } else { + let sum_bits: FpVar<_> = bits.iter().map(|b| FpVar::from(b.clone())).sum(); + let num_bits = FpVar::Constant(F::from(bits.len() as u64)); + sum_bits.is_eq(&num_bits) + } + } + + /// Outputs `!(bits[0] & bits[1] & ... & bits.last().unwrap())`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// Boolean::kary_nand(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_nand(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// Boolean::kary_nand(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_nand(bits: &[Self]) -> Result { + Ok(!Self::kary_and(bits)?) + } + + /// Enforces that `Self::kary_nand(bits).is_eq(&Boolean::TRUE)`. + /// + /// Informally, this means that at least one element in `bits` must be + /// `false`. + #[tracing::instrument(target = "r1cs")] + pub fn enforce_kary_nand(bits: &[Self]) -> Result<(), SynthesisError> { + Self::kary_nand(bits)?.enforce_equal(&Boolean::TRUE) + } +} + impl<'a, F: Field> BitAnd for &'a Boolean { type Output = Boolean; /// Outputs `self & other`. diff --git a/src/bits/boolean/mod.rs b/src/bits/boolean/mod.rs index cb867539..148cc43f 100644 --- a/src/bits/boolean/mod.rs +++ b/src/bits/boolean/mod.rs @@ -1,11 +1,12 @@ use ark_ff::{BitIteratorBE, Field, PrimeField}; -use crate::{fields::fp::FpVar, prelude::*, Assignment, ToConstraintFieldGadget, Vec}; +use crate::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget, Vec}; use ark_relations::r1cs::{ ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, }; use core::borrow::Borrow; +mod allocated; mod and; mod cmp; mod eq; @@ -14,230 +15,13 @@ mod or; mod select; mod xor; +pub use allocated::AllocatedBool; + #[cfg(test)] mod test_utils; -/// Represents a variable in the constraint system which is guaranteed -/// to be either zero or one. -/// -/// In general, one should prefer using `Boolean` instead of `AllocatedBool`, -/// as `Boolean` offers better support for constant values, and implements -/// more traits. -#[derive(Clone, Debug, Eq, PartialEq)] -#[must_use] -pub struct AllocatedBool { - variable: Variable, - cs: ConstraintSystemRef, -} - pub(crate) fn bool_to_field(val: impl Borrow) -> F { - if *val.borrow() { - F::one() - } else { - F::zero() - } -} - -impl AllocatedBool { - /// Get the assigned value for `self`. - pub fn value(&self) -> Result { - let value = self.cs.assigned_value(self.variable).get()?; - if value.is_zero() { - Ok(false) - } else if value.is_one() { - Ok(true) - } else { - unreachable!("Incorrect value assigned: {:?}", value); - } - } - - /// Get the R1CS variable for `self`. - pub fn variable(&self) -> Variable { - self.variable - } - - /// Allocate a witness variable without a booleanity check. - fn new_witness_without_booleanity_check>( - cs: ConstraintSystemRef, - f: impl FnOnce() -> Result, - ) -> Result { - let variable = cs.new_witness_variable(|| f().map(bool_to_field))?; - Ok(Self { variable, cs }) - } - - /// Performs an XOR operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn not(&self) -> Result { - let variable = self.cs.new_lc(lc!() + Variable::One - self.variable)?; - Ok(Self { - variable, - cs: self.cs.clone(), - }) - } - - /// Performs an XOR operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn xor(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? ^ b.value()?) - })?; - - // Constrain (a + a) * (b) = (a + b - c) - // Given that a and b are boolean constrained, if they - // are equal, the only solution for c is 0, and if they - // are different, the only solution for c is 1. - // - // ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b) = c - // (1 - (a * b)) * (1 - ((1 - a) * (1 - b))) = c - // (1 - ab) * (1 - (1 - a - b + ab)) = c - // (1 - ab) * (a + b - ab) = c - // a + b - ab - (a^2)b - (b^2)a + (a^2)(b^2) = c - // a + b - ab - ab - ab + ab = c - // a + b - 2ab = c - // -2a * b = c - a - b - // 2a * b = a + b - c - // (a + a) * b = a + b - c - self.cs.enforce_constraint( - lc!() + self.variable + self.variable, - lc!() + b.variable, - lc!() + self.variable + b.variable - result.variable, - )?; - - Ok(result) - } - - /// Performs an AND operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn and(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? & b.value()?) - })?; - - // Constrain (a) * (b) = (c), ensuring c is 1 iff - // a AND b are both 1. - self.cs.enforce_constraint( - lc!() + self.variable, - lc!() + b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } - - /// Performs an OR operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn or(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? | b.value()?) - })?; - - // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff - // a and b are both false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + Variable::One - self.variable, - lc!() + Variable::One - b.variable, - lc!() + Variable::One - result.variable, - )?; - - Ok(result) - } - - /// Calculates `a AND (NOT b)`. - #[tracing::instrument(target = "r1cs")] - pub fn and_not(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? & !b.value()?) - })?; - - // Constrain (a) * (1 - b) = (c), ensuring c is 1 iff - // a is true and b is false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + self.variable, - lc!() + Variable::One - b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } - - /// Calculates `(NOT a) AND (NOT b)`. - #[tracing::instrument(target = "r1cs")] - pub fn nor(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(!(self.value()? | b.value()?)) - })?; - - // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff - // a and b are both false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + Variable::One - self.variable, - lc!() + Variable::One - b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } -} - -impl AllocVar for AllocatedBool { - /// Produces a new variable of the appropriate kind - /// (instance or witness), with a booleanity check. - /// - /// N.B.: we could omit the booleanity check when allocating `self` - /// as a new public input, but that places an additional burden on - /// protocol designers. Better safe than sorry! - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - if mode == AllocationMode::Constant { - let variable = if *f()?.borrow() { - Variable::One - } else { - Variable::Zero - }; - Ok(Self { variable, cs }) - } else { - let variable = if mode == AllocationMode::Input { - cs.new_input_variable(|| f().map(bool_to_field))? - } else { - cs.new_witness_variable(|| f().map(bool_to_field))? - }; - - // Constrain: (1 - a) * a = 0 - // This constrains a to be either 0 or 1. - - cs.enforce_constraint(lc!() + Variable::One - variable, lc!() + variable, lc!())?; - - Ok(Self { variable, cs }) - } - } -} - -impl CondSelectGadget for AllocatedBool { - #[tracing::instrument(target = "r1cs")] - fn conditionally_select( - cond: &Boolean, - true_val: &Self, - false_val: &Self, - ) -> Result { - let res = Boolean::conditionally_select( - cond, - &true_val.clone().into(), - &false_val.clone().into(), - )?; - match res { - Boolean::Var(a) => Ok(a), - _ => unreachable!("Impossible"), - } - } + F::from(*val.borrow()) } /// Represents a boolean value in the constraint system which is guaranteed @@ -342,118 +126,6 @@ impl Boolean { } } - /// Outputs `bits[0] & bits[1] & ... & bits.last().unwrap()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// Boolean::kary_and(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// Boolean::kary_and(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_and(bits: &[Self]) -> Result { - assert!(!bits.is_empty()); - let mut cur: Option = None; - for next in bits { - cur = if let Some(b) = cur { - Some(b & next) - } else { - Some(next.clone()) - }; - } - - Ok(cur.expect("should not be 0")) - } - - /// Outputs `bits[0] | bits[1] | ... | bits.last().unwrap()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// Boolean::kary_or(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_or(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_or(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_or(bits: &[Self]) -> Result { - assert!(!bits.is_empty()); - let mut cur: Option = None; - for next in bits { - cur = if let Some(b) = cur { - Some(b | next) - } else { - Some(next.clone()) - }; - } - - Ok(cur.expect("should not be 0")) - } - - /// Outputs `(bits[0] & bits[1] & ... & bits.last().unwrap()).not()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// Boolean::kary_nand(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_nand(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// Boolean::kary_nand(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_nand(bits: &[Self]) -> Result { - Ok(!Self::kary_and(bits)?) - } - - /// Enforces that `Self::kary_nand(bits).is_eq(&Boolean::TRUE)`. - /// - /// Informally, this means that at least one element in `bits` must be - /// `false`. - #[tracing::instrument(target = "r1cs")] - fn enforce_kary_nand(bits: &[Self]) -> Result<(), SynthesisError> { - Self::kary_nand(bits)?.enforce_equal(&Boolean::TRUE) - } - /// Convert a little-endian bitwise representation of a field element to /// `FpVar` #[tracing::instrument(target = "r1cs", skip(bits))] @@ -512,7 +184,10 @@ impl Boolean { /// integer, and enforce that this integer is "in the field Z_p", where /// `p = F::characteristic()` . #[tracing::instrument(target = "r1cs")] - pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> { + pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> + where + F: PrimeField, + { // `bits` < F::characteristic() <==> `bits` <= F::characteristic() -1 let mut b = F::characteristic().to_vec(); assert_eq!(b[0] % 2, 1); @@ -533,7 +208,10 @@ impl Boolean { pub fn enforce_smaller_or_equal_than_le<'a>( bits: &[Self], element: impl AsRef<[u64]>, - ) -> Result, SynthesisError> { + ) -> Result, SynthesisError> + where + F: PrimeField, + { let b: &[u64] = element.as_ref(); let mut bits_iter = bits.iter().rev(); // Iterate in big-endian @@ -583,41 +261,6 @@ impl Boolean { Ok(current_run) } - - /// Conditionally selects one of `first` and `second` based on the value of - /// `self`: - /// - /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs - /// `second`. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?; - /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs", skip(first, second))] - pub fn select>( - &self, - first: &T, - second: &T, - ) -> Result { - T::conditionally_select(&self, first, second) - } } impl From> for Boolean { @@ -661,7 +304,7 @@ impl ToConstraintFieldGadget for Boolean { #[cfg(test)] mod test { - use super::{AllocatedBool, Boolean}; + use super::Boolean; use crate::prelude::*; use ark_ff::{BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand}; use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; @@ -684,101 +327,6 @@ mod test { Ok(()) } - #[test] - fn allocated_xor() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::xor(&a, &b)?; - assert_eq!(c.value()?, a_val ^ b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val ^ b_val)); - } - } - Ok(()) - } - - #[test] - fn allocated_or() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::or(&a, &b)?; - assert_eq!(c.value()?, a_val | b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val | b_val)); - } - } - Ok(()) - } - - #[test] - fn allocated_and() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::and(&a, &b)?; - assert_eq!(c.value()?, a_val & b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val & b_val)); - } - } - Ok(()) - } - - #[test] - fn allocated_and_not() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::and_not(&a, &b)?; - assert_eq!(c.value()?, a_val & !b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val & !b_val)); - } - } - Ok(()) - } - - #[test] - fn allocated_nor() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::nor(&a, &b)?; - assert_eq!(c.value()?, !a_val & !b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (!a_val & !b_val)); - } - } - Ok(()) - } - #[test] fn test_smaller_than_or_equal_to() -> Result<(), SynthesisError> { let mut rng = ark_std::test_rng(); diff --git a/src/bits/boolean/or.rs b/src/bits/boolean/or.rs index 5898d8a5..a107579c 100644 --- a/src/bits/boolean/or.rs +++ b/src/bits/boolean/or.rs @@ -13,6 +13,44 @@ impl Boolean { (Var(ref x), Var(ref y)) => Ok(Var(x.or(y)?)), } } + + /// Outputs `bits[0] | bits[1] | ... | bits.last().unwrap()`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// Boolean::kary_or(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_or(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_or(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_or(bits: &[Self]) -> Result { + assert!(!bits.is_empty()); + let mut cur: Option = None; + for next in bits { + cur = if let Some(b) = cur { + Some(b | next) + } else { + Some(next.clone()) + }; + } + + Ok(cur.expect("should not be 0")) + } } impl<'a, F: Field> BitOr for &'a Boolean { diff --git a/src/bits/boolean/select.rs b/src/bits/boolean/select.rs index 7549caa8..b09955e8 100644 --- a/src/bits/boolean/select.rs +++ b/src/bits/boolean/select.rs @@ -1,5 +1,41 @@ use super::*; +impl Boolean { + /// Conditionally selects one of `first` and `second` based on the value of + /// `self`: + /// + /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs + /// `second`. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?; + /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(first, second))] + pub fn select>( + &self, + first: &T, + second: &T, + ) -> Result { + T::conditionally_select(&self, first, second) + } +} impl CondSelectGadget for Boolean { #[tracing::instrument(target = "r1cs")] fn conditionally_select( diff --git a/src/eq.rs b/src/eq.rs index 71f10c3d..4f2c066b 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -1,5 +1,5 @@ use crate::{prelude::*, Vec}; -use ark_ff::Field; +use ark_ff::{Field, PrimeField}; use ark_relations::r1cs::SynthesisError; /// Specifies how to generate constraints that check for equality for two @@ -82,7 +82,7 @@ pub trait EqGadget { } } -impl + R1CSVar, F: Field> EqGadget for [T] { +impl + R1CSVar, F: PrimeField> EqGadget for [T] { #[tracing::instrument(target = "r1cs", skip(self, other))] fn is_eq(&self, other: &Self) -> Result, SynthesisError> { assert_eq!(self.len(), other.len()); From d30395502264f9610d672f3a7b1d0fb67547eb43 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 7 Jun 2023 11:03:18 -0700 Subject: [PATCH 16/29] Move bit and byte converstions into their own files --- src/bits/mod.rs | 33 --------------------------- src/bits/uint/convert.rs | 48 +++++++++++++++++++++++++++++++++++----- src/bits/uint8.rs | 6 ++--- src/poly/domain/mod.rs | 2 ++ 4 files changed, 47 insertions(+), 42 deletions(-) diff --git a/src/bits/mod.rs b/src/bits/mod.rs index 9dcabec5..f44a221f 100644 --- a/src/bits/mod.rs +++ b/src/bits/mod.rs @@ -73,21 +73,6 @@ impl ToBitsGadget for [Boolean] { } } -impl ToBitsGadget for UInt8 { - fn to_bits_le(&self) -> Result>, SynthesisError> { - Ok(self.bits.to_vec()) - } -} - -impl ToBitsGadget for [UInt8] { - /// Interprets `self` as an integer, and outputs the little-endian - /// bit-wise decomposition of that integer. - fn to_bits_le(&self) -> Result>, SynthesisError> { - let bits = self.iter().flat_map(|b| &b.bits).cloned().collect(); - Ok(bits) - } -} - impl ToBitsGadget for Vec where [T]: ToBitsGadget, @@ -118,26 +103,8 @@ pub trait ToBytesGadget { } } -impl ToBytesGadget for [UInt8] { - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self.to_vec()) - } -} - -impl ToBytesGadget for Vec> { - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self.clone()) - } -} - impl<'a, F: Field, T: 'a + ToBytesGadget> ToBytesGadget for &'a T { fn to_bytes(&self) -> Result>, SynthesisError> { (*self).to_bytes() } } - -impl<'a, F: Field> ToBytesGadget for &'a [UInt8] { - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self.to_vec()) - } -} diff --git a/src/bits/uint/convert.rs b/src/bits/uint/convert.rs index 3308a145..bb618c87 100644 --- a/src/bits/uint/convert.rs +++ b/src/bits/uint/convert.rs @@ -31,11 +31,6 @@ impl UInt { Ok((result, rest)) } - /// Turns `self` into the underlying little-endian bits. - pub fn to_bits_le(&self) -> Vec> { - self.bits.to_vec() - } - /// Converts a little-endian byte order representation of bits into a /// `UInt`. /// @@ -78,15 +73,56 @@ impl UInt { } } +impl ToBitsGadget for UInt { + fn to_bits_le(&self) -> Result>, SynthesisError> { + Ok(self.bits.to_vec()) + } +} + +impl ToBitsGadget for [UInt] { + /// Interprets `self` as an integer, and outputs the little-endian + /// bit-wise decomposition of that integer. + fn to_bits_le(&self) -> Result>, SynthesisError> { + let bits = self.iter().flat_map(|b| &b.bits).cloned().collect(); + Ok(bits) + } +} + +/*****************************************************************************************/ +/********************************* Conversions to bytes. *********************************/ +/*****************************************************************************************/ + impl ToBytesGadget for UInt { #[tracing::instrument(target = "r1cs", skip(self))] fn to_bytes(&self) -> Result>, SynthesisError> { Ok(self - .to_bits_le() + .to_bits_le()? .chunks(8) .map(UInt8::from_bits_le) .collect()) } } + +impl ToBytesGadget for [UInt] { + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut bytes = Vec::with_capacity(self.len() * (N / 8)); + for elem in self { + bytes.extend_from_slice(&elem.to_bytes()?); + } + Ok(bytes) + } +} + +impl ToBytesGadget for Vec> { + fn to_bytes(&self) -> Result>, SynthesisError> { + self.as_slice().to_bytes() + } +} + +impl<'a, const N: usize, T: PrimInt + Debug, F: Field> ToBytesGadget for &'a [UInt] { + fn to_bytes(&self) -> Result>, SynthesisError> { + (*self).to_bytes() + } +} diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index d54fd985..005cf0df 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -114,7 +114,7 @@ mod test { let byte_val = 0b01110001; let byte = UInt8::new_witness(ark_relations::ns!(cs, "alloc value"), || Ok(byte_val)).unwrap(); - let bits = byte.to_bits_le(); + let bits = byte.to_bits_le()?; for (i, bit) in bits.iter().enumerate() { assert_eq!(bit.value()?, (byte_val >> i) & 1 == 1) } @@ -129,7 +129,7 @@ mod test { UInt8::new_input_vec(ark_relations::ns!(cs, "alloc value"), &byte_vals).unwrap(); dbg!(bytes.value())?; for (native, variable) in byte_vals.into_iter().zip(bytes) { - let bits = variable.to_bits_le(); + let bits = variable.to_bits_le()?; for (i, bit) in bits.iter().enumerate() { assert_eq!( bit.value()?, @@ -162,7 +162,7 @@ mod test { } } - let expected_to_be_same = val.to_bits_le(); + let expected_to_be_same = val.to_bits_le()?; for x in v.iter().zip(expected_to_be_same.iter()) { match x { diff --git a/src/poly/domain/mod.rs b/src/poly/domain/mod.rs index 8c450702..0b479451 100644 --- a/src/poly/domain/mod.rs +++ b/src/poly/domain/mod.rs @@ -149,12 +149,14 @@ mod tests { dbg!(UInt32::new_witness(cs.clone(), || Ok(coset_index)) .unwrap() .to_bits_le() + .unwrap() .iter() .map(|x| x.value().unwrap() as u8) .collect::>()); let coset_index_var = UInt32::new_witness(cs.clone(), || Ok(coset_index)) .unwrap() .to_bits_le() + .unwrap() .into_iter() .take(COSET_DIM as usize) .collect::>(); From 5c8087f360d7a2d40c6723848a015d9ce4d05c15 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 7 Jun 2023 13:55:38 -0700 Subject: [PATCH 17/29] Clean up --- src/bits/boolean/allocated.rs | 14 +- src/bits/boolean/and.rs | 117 ++++++++++- src/bits/boolean/cmp.rs | 94 +++++++++ src/bits/boolean/convert.rs | 21 ++ src/bits/boolean/mod.rs | 193 +----------------- src/bits/boolean/or.rs | 45 ++-- src/bits/boolean/select.rs | 4 +- src/bits/mod.rs | 89 -------- src/bits/uint/add.rs | 38 ---- src/bits/uint/cmp.rs | 18 +- src/bits/uint/convert.rs | 1 + src/bits/uint/mod.rs | 38 ---- src/bits/uint/or.rs | 20 +- src/bits/uint/select.rs | 2 +- src/bits/uint8.rs | 6 +- src/cmp.rs | 21 ++ src/convert.rs | 95 +++++++++ src/fields/cubic_extension.rs | 3 +- src/fields/fp/cmp.rs | 2 +- src/fields/fp/mod.rs | 3 +- src/fields/mod.rs | 3 +- src/fields/nonnative/allocated_field_var.rs | 6 +- src/fields/nonnative/field_var.rs | 3 +- src/fields/quadratic_extension.rs | 3 +- .../curves/short_weierstrass/mnt4/mod.rs | 1 + .../curves/short_weierstrass/mnt6/mod.rs | 1 + src/groups/curves/short_weierstrass/mod.rs | 7 +- src/groups/curves/twisted_edwards/mod.rs | 6 +- src/groups/mod.rs | 9 +- src/lib.rs | 35 ++-- src/pairing/mod.rs | 9 +- src/poly/domain/mod.rs | 5 +- 32 files changed, 471 insertions(+), 441 deletions(-) create mode 100644 src/bits/boolean/convert.rs create mode 100644 src/cmp.rs create mode 100644 src/convert.rs diff --git a/src/bits/boolean/allocated.rs b/src/bits/boolean/allocated.rs index 4849ddce..9397f12d 100644 --- a/src/bits/boolean/allocated.rs +++ b/src/bits/boolean/allocated.rs @@ -1,6 +1,6 @@ use core::borrow::Borrow; -use ark_ff::Field; +use ark_ff::{Field, PrimeField}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError, Variable}; use crate::{ @@ -9,7 +9,7 @@ use crate::{ Assignment, }; -use super::{bool_to_field, Boolean}; +use super::Boolean; /// Represents a variable in the constraint system which is guaranteed /// to be either zero or one. @@ -24,6 +24,10 @@ pub struct AllocatedBool { pub(super) cs: ConstraintSystemRef, } +pub(crate) fn bool_to_field(val: impl Borrow) -> F { + F::from(*val.borrow()) +} + impl AllocatedBool { /// Get the assigned value for `self`. pub fn value(&self) -> Result { @@ -122,8 +126,8 @@ impl AllocatedBool { Ok(self.value()? | b.value()?) })?; - // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff - // a and b are both false, and otherwise c is 0. + // Constrain (1 - a) * (1 - b) = (1 - c), ensuring c is 0 iff + // a and b are both false, and otherwise c is 1. self.cs.enforce_constraint( lc!() + Variable::One - self.variable, lc!() + Variable::One - b.variable, @@ -208,7 +212,7 @@ impl AllocVar for AllocatedBool { } } -impl CondSelectGadget for AllocatedBool { +impl CondSelectGadget for AllocatedBool { #[tracing::instrument(target = "r1cs")] fn conditionally_select( cond: &Boolean, diff --git a/src/bits/boolean/and.rs b/src/bits/boolean/and.rs index 61cc70ff..0b846180 100644 --- a/src/bits/boolean/and.rs +++ b/src/bits/boolean/and.rs @@ -17,6 +17,11 @@ impl Boolean { (Var(ref x), Var(ref y)) => Ok(Var(x.and(y)?)), } } + + /// Outputs `!(self & other)`. + pub fn nand(&self, other: &Self) -> Result { + self._and(other).map(|x| !x) + } } impl Boolean { @@ -91,13 +96,13 @@ impl Boolean { Ok(!Self::kary_and(bits)?) } - /// Enforces that `Self::kary_nand(bits).is_eq(&Boolean::TRUE)`. + /// Enforces that `!(bits[0] & bits[1] & ... ) == Boolean::TRUE`. /// /// Informally, this means that at least one element in `bits` must be /// `false`. #[tracing::instrument(target = "r1cs")] pub fn enforce_kary_nand(bits: &[Self]) -> Result<(), SynthesisError> { - Self::kary_nand(bits)?.enforce_equal(&Boolean::TRUE) + Self::kary_and(bits)?.enforce_equal(&Boolean::FALSE) } } @@ -190,6 +195,7 @@ mod tests { prelude::EqGadget, R1CSVar, }; + use ark_relations::r1cs::ConstraintSystem; use ark_test_curves::bls12_381::Fr; #[test] @@ -214,4 +220,111 @@ mod tests { }) .unwrap() } + + #[test] + fn nand() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.nand(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(!(a.value()? & b.value()?)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn enforce_nand() -> Result<(), SynthesisError> { + { + let cs = ConstraintSystem::::new_ref(); + + assert!( + Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), false)?]).is_ok() + ); + assert!( + Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), true)?]).is_err() + ); + } + + for i in 1..5 { + // with every possible assignment for them + for mut b in 0..(1 << i) { + // with every possible negation + for mut n in 0..(1 << i) { + let cs = ConstraintSystem::::new_ref(); + + let mut expected = true; + + let mut bits = vec![]; + for _ in 0..i { + expected &= b & 1 == 1; + + let bit = if n & 1 == 1 { + Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))? + } else { + !Boolean::new_witness(cs.clone(), || Ok(b & 1 == 0))? + }; + bits.push(bit); + + b >>= 1; + n >>= 1; + } + + let expected = !expected; + + Boolean::enforce_kary_nand(&bits)?; + + if expected { + assert!(cs.is_satisfied().unwrap()); + } else { + assert!(!cs.is_satisfied().unwrap()); + } + } + } + } + Ok(()) + } + + #[test] + fn kary_and() -> Result<(), SynthesisError> { + // test different numbers of operands + for i in 1..15 { + // with every possible assignment for them + for mut b in 0..(1 << i) { + let cs = ConstraintSystem::::new_ref(); + + let mut expected = true; + + let mut bits = vec![]; + for _ in 0..i { + expected &= b & 1 == 1; + bits.push(Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))?); + b >>= 1; + } + + let r = Boolean::kary_and(&bits)?; + + assert!(cs.is_satisfied().unwrap()); + + if let Boolean::Var(ref r) = r { + assert_eq!(r.value()?, expected); + } + } + } + Ok(()) + } } diff --git a/src/bits/boolean/cmp.rs b/src/bits/boolean/cmp.rs index 8b137891..a8c133a8 100644 --- a/src/bits/boolean/cmp.rs +++ b/src/bits/boolean/cmp.rs @@ -1 +1,95 @@ +use crate::cmp::CmpGadget; +use super::*; +use ark_ff::PrimeField; + +impl CmpGadget for Boolean { + fn is_ge(&self, other: &Self) -> Result, SynthesisError> { + // a | b | (a | !b) | a >= b + // --|---|--------|-------- + // 0 | 0 | 1 | 1 + // 1 | 0 | 1 | 1 + // 0 | 1 | 0 | 0 + // 1 | 1 | 1 | 1 + Ok(self | &(!other)) + } +} + +impl Boolean { + /// Enforces that `bits`, when interpreted as a integer, is less than + /// `F::characteristic()`, That is, interpret bits as a little-endian + /// integer, and enforce that this integer is "in the field Z_p", where + /// `p = F::characteristic()` . + #[tracing::instrument(target = "r1cs")] + pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> { + // `bits` < F::characteristic() <==> `bits` <= F::characteristic() -1 + let mut b = F::characteristic().to_vec(); + assert_eq!(b[0] % 2, 1); + b[0] -= 1; // This works, because the LSB is one, so there's no borrows. + let run = Self::enforce_smaller_or_equal_than_le(bits, b)?; + + // We should always end in a "run" of zeros, because + // the characteristic is an odd prime. So, this should + // be empty. + assert!(run.is_empty()); + + Ok(()) + } + + /// Enforces that `bits` is less than or equal to `element`, + /// when both are interpreted as (little-endian) integers. + #[tracing::instrument(target = "r1cs", skip(element))] + pub fn enforce_smaller_or_equal_than_le( + bits: &[Self], + element: impl AsRef<[u64]>, + ) -> Result, SynthesisError> { + let b: &[u64] = element.as_ref(); + + let mut bits_iter = bits.iter().rev(); // Iterate in big-endian + + // Runs of ones in r + let mut last_run = Boolean::constant(true); + let mut current_run = vec![]; + + let mut element_num_bits = 0; + for _ in BitIteratorBE::without_leading_zeros(b) { + element_num_bits += 1; + } + + if bits.len() > element_num_bits { + let mut or_result = Boolean::constant(false); + for should_be_zero in &bits[element_num_bits..] { + or_result |= should_be_zero; + let _ = bits_iter.next().unwrap(); + } + or_result.enforce_equal(&Boolean::constant(false))?; + } + + for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) { + if b { + // This is part of a run of ones. + current_run.push(a.clone()); + } else { + if !current_run.is_empty() { + // This is the start of a run of zeros, but we need + // to k-ary AND against `last_run` first. + + current_run.push(last_run.clone()); + last_run = Self::kary_and(¤t_run)?; + current_run.truncate(0); + } + + // If `last_run` is true, `a` must be false, or it would + // not be in the field. + // + // If `last_run` is false, `a` can be true or false. + // + // Ergo, at least one of `last_run` and `a` must be false. + Self::enforce_kary_nand(&[last_run.clone(), a.clone()])?; + } + } + assert!(bits_iter.next().is_none()); + + Ok(current_run) + } +} diff --git a/src/bits/boolean/convert.rs b/src/bits/boolean/convert.rs new file mode 100644 index 00000000..21e9fc09 --- /dev/null +++ b/src/bits/boolean/convert.rs @@ -0,0 +1,21 @@ +use super::*; +use crate::convert::{ToBytesGadget, ToConstraintFieldGadget}; + +impl ToBytesGadget for Boolean { + /// Outputs `1u8` if `self` is true, and `0u8` otherwise. + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let value = self.value().map(u8::from).ok(); + let mut bits = [Boolean::FALSE; 8]; + bits[0] = self.clone(); + Ok(vec![UInt8 { bits, value }]) + } +} + +impl ToConstraintFieldGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let var = From::from(self.clone()); + Ok(vec![var]) + } +} diff --git a/src/bits/boolean/mod.rs b/src/bits/boolean/mod.rs index 148cc43f..401c799d 100644 --- a/src/bits/boolean/mod.rs +++ b/src/bits/boolean/mod.rs @@ -1,6 +1,6 @@ use ark_ff::{BitIteratorBE, Field, PrimeField}; -use crate::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget, Vec}; +use crate::{fields::fp::FpVar, prelude::*, Vec}; use ark_relations::r1cs::{ ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, }; @@ -9,6 +9,7 @@ use core::borrow::Borrow; mod allocated; mod and; mod cmp; +mod convert; mod eq; mod not; mod or; @@ -20,10 +21,6 @@ pub use allocated::AllocatedBool; #[cfg(test)] mod test_utils; -pub(crate) fn bool_to_field(val: impl Borrow) -> F { - F::from(*val.borrow()) -} - /// Represents a boolean value in the constraint system which is guaranteed /// to be either zero or one. #[derive(Clone, Debug, Eq, PartialEq)] @@ -128,6 +125,8 @@ impl Boolean { /// Convert a little-endian bitwise representation of a field element to /// `FpVar` + /// + /// Wraps around if the bit representation is larger than the field modulus. #[tracing::instrument(target = "r1cs", skip(bits))] pub fn le_bits_to_fp(bits: &[Self]) -> Result, SynthesisError> where @@ -178,89 +177,6 @@ impl Boolean { Ok(combined) } } - - /// Enforces that `bits`, when interpreted as a integer, is less than - /// `F::characteristic()`, That is, interpret bits as a little-endian - /// integer, and enforce that this integer is "in the field Z_p", where - /// `p = F::characteristic()` . - #[tracing::instrument(target = "r1cs")] - pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> - where - F: PrimeField, - { - // `bits` < F::characteristic() <==> `bits` <= F::characteristic() -1 - let mut b = F::characteristic().to_vec(); - assert_eq!(b[0] % 2, 1); - b[0] -= 1; // This works, because the LSB is one, so there's no borrows. - let run = Self::enforce_smaller_or_equal_than_le(bits, b)?; - - // We should always end in a "run" of zeros, because - // the characteristic is an odd prime. So, this should - // be empty. - assert!(run.is_empty()); - - Ok(()) - } - - /// Enforces that `bits` is less than or equal to `element`, - /// when both are interpreted as (little-endian) integers. - #[tracing::instrument(target = "r1cs", skip(element))] - pub fn enforce_smaller_or_equal_than_le<'a>( - bits: &[Self], - element: impl AsRef<[u64]>, - ) -> Result, SynthesisError> - where - F: PrimeField, - { - let b: &[u64] = element.as_ref(); - - let mut bits_iter = bits.iter().rev(); // Iterate in big-endian - - // Runs of ones in r - let mut last_run = Boolean::constant(true); - let mut current_run = vec![]; - - let mut element_num_bits = 0; - for _ in BitIteratorBE::without_leading_zeros(b) { - element_num_bits += 1; - } - - if bits.len() > element_num_bits { - let mut or_result = Boolean::constant(false); - for should_be_zero in &bits[element_num_bits..] { - or_result |= should_be_zero; - let _ = bits_iter.next().unwrap(); - } - or_result.enforce_equal(&Boolean::constant(false))?; - } - - for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) { - if b { - // This is part of a run of ones. - current_run.push(a.clone()); - } else { - if !current_run.is_empty() { - // This is the start of a run of zeros, but we need - // to k-ary AND against `last_run` first. - - current_run.push(last_run.clone()); - last_run = Self::kary_and(¤t_run)?; - current_run.truncate(0); - } - - // If `last_run` is true, `a` must be false, or it would - // not be in the field. - // - // If `last_run` is false, `a` can be true or false. - // - // Ergo, at least one of `last_run` and `a` must be false. - Self::enforce_kary_nand(&[last_run.clone(), a.clone()])?; - } - } - assert!(bits_iter.next().is_none()); - - Ok(current_run) - } } impl From> for Boolean { @@ -283,28 +199,10 @@ impl AllocVar for Boolean { } } -impl ToBytesGadget for Boolean { - /// Outputs `1u8` if `self` is true, and `0u8` otherwise. - #[tracing::instrument(target = "r1cs")] - fn to_bytes(&self) -> Result>, SynthesisError> { - let value = self.value().map(u8::from).ok(); - let mut bits = [Boolean::FALSE; 8]; - bits[0] = self.clone(); - Ok(vec![UInt8 { bits, value }]) - } -} - -impl ToConstraintFieldGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn to_constraint_field(&self) -> Result>, SynthesisError> { - let var = From::from(self.clone()); - Ok(vec![var]) - } -} - #[cfg(test)] mod test { use super::Boolean; + use crate::convert::ToBytesGadget; use crate::prelude::*; use ark_ff::{BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand}; use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; @@ -402,87 +300,6 @@ mod test { Ok(()) } - #[test] - fn test_enforce_nand() -> Result<(), SynthesisError> { - { - let cs = ConstraintSystem::::new_ref(); - - assert!( - Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), false)?]).is_ok() - ); - assert!( - Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), true)?]).is_err() - ); - } - - for i in 1..5 { - // with every possible assignment for them - for mut b in 0..(1 << i) { - // with every possible negation - for mut n in 0..(1 << i) { - let cs = ConstraintSystem::::new_ref(); - - let mut expected = true; - - let mut bits = vec![]; - for _ in 0..i { - expected &= b & 1 == 1; - - let bit = if n & 1 == 1 { - Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))? - } else { - !Boolean::new_witness(cs.clone(), || Ok(b & 1 == 0))? - }; - bits.push(bit); - - b >>= 1; - n >>= 1; - } - - let expected = !expected; - - Boolean::enforce_kary_nand(&bits)?; - - if expected { - assert!(cs.is_satisfied().unwrap()); - } else { - assert!(!cs.is_satisfied().unwrap()); - } - } - } - } - Ok(()) - } - - #[test] - fn test_kary_and() -> Result<(), SynthesisError> { - // test different numbers of operands - for i in 1..15 { - // with every possible assignment for them - for mut b in 0..(1 << i) { - let cs = ConstraintSystem::::new_ref(); - - let mut expected = true; - - let mut bits = vec![]; - for _ in 0..i { - expected &= b & 1 == 1; - bits.push(Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))?); - b >>= 1; - } - - let r = Boolean::kary_and(&bits)?; - - assert!(cs.is_satisfied().unwrap()); - - if let Boolean::Var(ref r) = r { - assert_eq!(r.value()?, expected); - } - } - } - Ok(()) - } - #[test] fn test_bits_to_fp() -> Result<(), SynthesisError> { use AllocationMode::*; diff --git a/src/bits/boolean/or.rs b/src/bits/boolean/or.rs index a107579c..a54bfc71 100644 --- a/src/bits/boolean/or.rs +++ b/src/bits/boolean/or.rs @@ -1,10 +1,15 @@ -use ark_ff::Field; +use ark_ff::PrimeField; use ark_relations::r1cs::SynthesisError; use ark_std::{ops::BitOr, ops::BitOrAssign}; +use crate::{ + eq::EqGadget, + fields::{fp::FpVar, FieldVar}, +}; + use super::Boolean; -impl Boolean { +impl Boolean { fn _or(&self, other: &Self) -> Result { use Boolean::*; match (self, other) { @@ -40,20 +45,26 @@ impl Boolean { #[tracing::instrument(target = "r1cs")] pub fn kary_or(bits: &[Self]) -> Result { assert!(!bits.is_empty()); - let mut cur: Option = None; - for next in bits { - cur = if let Some(b) = cur { - Some(b | next) - } else { - Some(next.clone()) - }; - } + if bits.len() <= 3 { + let mut cur: Option = None; + for next in bits { + cur = if let Some(b) = cur { + Some(b | next) + } else { + Some(next.clone()) + }; + } - Ok(cur.expect("should not be 0")) + Ok(cur.expect("should not be 0")) + } else { + // b0 & b1 & ... & bN == 1 if and only if sum(b0, b1, ..., bN) == N + let sum_bits: FpVar<_> = bits.iter().map(|b| FpVar::from(b.clone())).sum(); + sum_bits.is_neq(&FpVar::zero()) + } } } -impl<'a, F: Field> BitOr for &'a Boolean { +impl<'a, F: PrimeField> BitOr for &'a Boolean { type Output = Boolean; /// Outputs `self | other`. @@ -89,7 +100,7 @@ impl<'a, F: Field> BitOr for &'a Boolean { } } -impl<'a, F: Field> BitOr<&'a Self> for Boolean { +impl<'a, F: PrimeField> BitOr<&'a Self> for Boolean { type Output = Boolean; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -98,7 +109,7 @@ impl<'a, F: Field> BitOr<&'a Self> for Boolean { } } -impl<'a, F: Field> BitOr> for &'a Boolean { +impl<'a, F: PrimeField> BitOr> for &'a Boolean { type Output = Boolean; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -107,7 +118,7 @@ impl<'a, F: Field> BitOr> for &'a Boolean { } } -impl BitOr for Boolean { +impl BitOr for Boolean { type Output = Self; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -116,7 +127,7 @@ impl BitOr for Boolean { } } -impl BitOrAssign for Boolean { +impl BitOrAssign for Boolean { /// Sets `self = self | other`. #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitor_assign(&mut self, other: Self) { @@ -125,7 +136,7 @@ impl BitOrAssign for Boolean { } } -impl<'a, F: Field> BitOrAssign<&'a Self> for Boolean { +impl<'a, F: PrimeField> BitOrAssign<&'a Self> for Boolean { /// Sets `self = self | other`. #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitor_assign(&mut self, other: &'a Self) { diff --git a/src/bits/boolean/select.rs b/src/bits/boolean/select.rs index b09955e8..78eb0448 100644 --- a/src/bits/boolean/select.rs +++ b/src/bits/boolean/select.rs @@ -1,6 +1,6 @@ use super::*; -impl Boolean { +impl Boolean { /// Conditionally selects one of `first` and `second` based on the value of /// `self`: /// @@ -36,7 +36,7 @@ impl Boolean { T::conditionally_select(&self, first, second) } } -impl CondSelectGadget for Boolean { +impl CondSelectGadget for Boolean { #[tracing::instrument(target = "r1cs")] fn conditionally_select( cond: &Boolean, diff --git a/src/bits/mod.rs b/src/bits/mod.rs index f44a221f..c57085c1 100644 --- a/src/bits/mod.rs +++ b/src/bits/mod.rs @@ -1,10 +1,3 @@ -use crate::{ - bits::{boolean::Boolean, uint8::UInt8}, - Vec, -}; -use ark_ff::Field; -use ark_relations::r1cs::SynthesisError; - /// This module contains `Boolean`, a R1CS equivalent of the `bool` type. pub mod boolean; /// This module contains `UInt8`, a R1CS equivalent of the `u8` type. @@ -26,85 +19,3 @@ pub mod uint64 { pub mod uint128 { pub type UInt128 = super::uint::UInt<128, u128, F>; } - -/// Specifies constraints for conversion to a little-endian bit representation -/// of `self`. -pub trait ToBitsGadget { - /// Outputs the canonical little-endian bit-wise representation of `self`. - /// - /// This is the correct default for 99% of use cases. - fn to_bits_le(&self) -> Result>, SynthesisError>; - - /// Outputs a possibly non-unique little-endian bit-wise representation of - /// `self`. - /// - /// If you're not absolutely certain that your usecase can get away with a - /// non-canonical representation, please use `self.to_bits()` instead. - fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { - self.to_bits_le() - } - - /// Outputs the canonical big-endian bit-wise representation of `self`. - fn to_bits_be(&self) -> Result>, SynthesisError> { - let mut res = self.to_bits_le()?; - res.reverse(); - Ok(res) - } - - /// Outputs a possibly non-unique big-endian bit-wise representation of - /// `self`. - fn to_non_unique_bits_be(&self) -> Result>, SynthesisError> { - let mut res = self.to_non_unique_bits_le()?; - res.reverse(); - Ok(res) - } -} - -impl ToBitsGadget for Boolean { - fn to_bits_le(&self) -> Result>, SynthesisError> { - Ok(vec![self.clone()]) - } -} - -impl ToBitsGadget for [Boolean] { - /// Outputs `self`. - fn to_bits_le(&self) -> Result>, SynthesisError> { - Ok(self.to_vec()) - } -} - -impl ToBitsGadget for Vec -where - [T]: ToBitsGadget, -{ - fn to_bits_le(&self) -> Result>, SynthesisError> { - self.as_slice().to_bits_le().map(|v| v.to_vec()) - } - - fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { - self.as_slice().to_non_unique_bits_le().map(|v| v.to_vec()) - } -} - -/// Specifies constraints for conversion to a little-endian byte representation -/// of `self`. -pub trait ToBytesGadget { - /// Outputs a canonical, little-endian, byte decomposition of `self`. - /// - /// This is the correct default for 99% of use cases. - fn to_bytes(&self) -> Result>, SynthesisError>; - - /// Outputs a possibly non-unique byte decomposition of `self`. - /// - /// If you're not absolutely certain that your usecase can get away with a - /// non-canonical representation, please use `self.to_bytes(cs)` instead. - fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { - self.to_bytes() - } -} - -impl<'a, F: Field, T: 'a + ToBytesGadget> ToBytesGadget for &'a T { - fn to_bytes(&self) -> Result>, SynthesisError> { - (*self).to_bytes() - } -} diff --git a/src/bits/uint/add.rs b/src/bits/uint/add.rs index b6565bd1..5d1577e5 100644 --- a/src/bits/uint/add.rs +++ b/src/bits/uint/add.rs @@ -64,44 +64,6 @@ impl UInt Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let cs = ConstraintSystem::::new_ref(); - -// let a: $native = rng.gen(); -// let b: $native = rng.gen(); -// let c: $native = rng.gen(); -// let d: $native = rng.gen(); - -// let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - -// let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; -// let b_bit = $name::constant(b); -// let c_bit = $name::constant(c); -// let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; - -// let r = a_bit.xor(&b_bit).unwrap(); -// let r = $name::add_many(&[r, c_bit, d_bit]).unwrap(); - -// assert!(cs.is_satisfied().unwrap()); -// assert!(r.value == Some(expected)); - -// for b in r.bits.iter() { -// match b { -// Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), -// Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), -// Boolean::Constant(_) => unreachable!(), -// } - -// expected >>= 1; -// } -// } -// Ok(()) -// } - #[cfg(test)] mod tests { use super::*; diff --git a/src/bits/uint/cmp.rs b/src/bits/uint/cmp.rs index 864c85b8..977dde1e 100644 --- a/src/bits/uint/cmp.rs +++ b/src/bits/uint/cmp.rs @@ -1,11 +1,9 @@ -use super::*; +use crate::cmp::CmpGadget; -impl> UInt { - pub fn is_gt(&self, other: &Self) -> Result, SynthesisError> { - other.is_lt(self) - } +use super::*; - pub fn is_ge(&self, other: &Self) -> Result, SynthesisError> { +impl> CmpGadget for UInt { + fn is_ge(&self, other: &Self) -> Result, SynthesisError> { if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) { let a = self.to_fp()?; let b = other.to_fp()?; @@ -16,14 +14,6 @@ impl> UInt unimplemented!("bit sizes larger than modulus size not yet supported") } } - - pub fn is_lt(&self, other: &Self) -> Result, SynthesisError> { - Ok(!self.is_ge(other)?) - } - - pub fn is_le(&self, other: &Self) -> Result, SynthesisError> { - other.is_ge(self) - } } #[cfg(test)] diff --git a/src/bits/uint/convert.rs b/src/bits/uint/convert.rs index bb618c87..2cc8d247 100644 --- a/src/bits/uint/convert.rs +++ b/src/bits/uint/convert.rs @@ -1,3 +1,4 @@ +use crate::convert::*; use crate::fields::fp::FpVar; use super::*; diff --git a/src/bits/uint/mod.rs b/src/bits/uint/mod.rs index a2effea1..03c8ac19 100644 --- a/src/bits/uint/mod.rs +++ b/src/bits/uint/mod.rs @@ -190,44 +190,6 @@ impl AllocVar Result<(), SynthesisError> { -// use Boolean::*; -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let cs = ConstraintSystem::::new_ref(); - -// let a: $native = rng.gen(); -// let b: $native = rng.gen(); -// let c: $native = rng.gen(); - -// let mut expected = a ^ b ^ c; - -// let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; -// let b_bit = $name::constant(b); -// let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; - -// let r = a_bit.xor(&b_bit).unwrap(); -// let r = r.xor(&c_bit).unwrap(); - -// assert!(cs.is_satisfied().unwrap()); - -// assert!(r.value == Some(expected)); - -// for b in r.bits.iter() { -// match b { -// Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), -// Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), -// Constant(b) => assert_eq!(*b, (expected & 1 == 1)), -// } - -// expected >>= 1; -// } -// } -// Ok(()) -// } - // #[test] // fn test_add_many_constants() -> Result<(), SynthesisError> { // let mut rng = ark_std::test_rng(); diff --git a/src/bits/uint/or.rs b/src/bits/uint/or.rs index 3d0ae90e..d7c0cc0f 100644 --- a/src/bits/uint/or.rs +++ b/src/bits/uint/or.rs @@ -1,11 +1,11 @@ -use ark_ff::Field; +use ark_ff::PrimeField; use ark_relations::r1cs::SynthesisError; use ark_std::{fmt::Debug, ops::BitOr, ops::BitOrAssign}; use num_traits::PrimInt; use super::UInt; -impl UInt { +impl UInt { fn _or(&self, other: &Self) -> Result { let mut result = self.clone(); for (a, b) in result.bits.iter_mut().zip(&other.bits) { @@ -16,7 +16,7 @@ impl UInt { } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr for &'a UInt { +impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr for &'a UInt { type Output = UInt; /// Outputs `self ^ other`. /// @@ -46,7 +46,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr for &'a UInt< } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr<&'a Self> for UInt { +impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr<&'a Self> for UInt { type Output = UInt; /// Outputs `self ^ other`. /// @@ -76,7 +76,9 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr<&'a Self> for UInt< } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr> for &'a UInt { +impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr> + for &'a UInt +{ type Output = UInt; /// Outputs `self ^ other`. /// @@ -106,7 +108,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitOr> for } } -impl BitOr for UInt { +impl BitOr for UInt { type Output = Self; /// Outputs `self ^ other`. /// @@ -136,7 +138,7 @@ impl BitOr for UInt } } -impl BitOrAssign for UInt { +impl BitOrAssign for UInt { /// Sets `self = self ^ other`. /// /// If at least one of `self` and `other` are constants, then this method @@ -166,7 +168,9 @@ impl BitOrAssign for UInt BitOrAssign<&'a Self> for UInt { +impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOrAssign<&'a Self> + for UInt +{ /// Sets `self = self ^ other`. /// /// If at least one of `self` and `other` are constants, then this method diff --git a/src/bits/uint/select.rs b/src/bits/uint/select.rs index e50f0730..d891ab34 100644 --- a/src/bits/uint/select.rs +++ b/src/bits/uint/select.rs @@ -1,7 +1,7 @@ use super::*; use crate::select::CondSelectGadget; -impl CondSelectGadget +impl CondSelectGadget for UInt { #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index 005cf0df..0797cdbc 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -3,9 +3,10 @@ use ark_ff::{Field, PrimeField, ToConstraintField}; use ark_relations::r1cs::{Namespace, SynthesisError}; use crate::{ + convert::{ToBitsGadget, ToConstraintFieldGadget}, fields::fp::{AllocatedFp, FpVar}, prelude::*, - ToConstraintFieldGadget, Vec, + Vec, }; pub type UInt8 = super::uint::UInt<8, u8, F>; @@ -96,12 +97,13 @@ impl ToConstraintFieldGadget for Vec { + fn is_gt(&self, other: &Self) -> Result, SynthesisError> { + other.is_lt(self) + } + + fn is_ge(&self, other: &Self) -> Result, SynthesisError>; + + fn is_lt(&self, other: &Self) -> Result, SynthesisError> { + Ok(!self.is_ge(other)?) + } + + fn is_le(&self, other: &Self) -> Result, SynthesisError> { + other.is_ge(self) + } +} diff --git a/src/convert.rs b/src/convert.rs new file mode 100644 index 00000000..0f1e3d31 --- /dev/null +++ b/src/convert.rs @@ -0,0 +1,95 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; + +use crate::{boolean::Boolean, uint8::UInt8}; + +/// Specifies constraints for conversion to a little-endian bit representation +/// of `self`. +pub trait ToBitsGadget { + /// Outputs the canonical little-endian bit-wise representation of `self`. + /// + /// This is the correct default for 99% of use cases. + fn to_bits_le(&self) -> Result>, SynthesisError>; + + /// Outputs a possibly non-unique little-endian bit-wise representation of + /// `self`. + /// + /// If you're not absolutely certain that your usecase can get away with a + /// non-canonical representation, please use `self.to_bits()` instead. + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + self.to_bits_le() + } + + /// Outputs the canonical big-endian bit-wise representation of `self`. + fn to_bits_be(&self) -> Result>, SynthesisError> { + let mut res = self.to_bits_le()?; + res.reverse(); + Ok(res) + } + + /// Outputs a possibly non-unique big-endian bit-wise representation of + /// `self`. + fn to_non_unique_bits_be(&self) -> Result>, SynthesisError> { + let mut res = self.to_non_unique_bits_le()?; + res.reverse(); + Ok(res) + } +} + +impl ToBitsGadget for Boolean { + fn to_bits_le(&self) -> Result>, SynthesisError> { + Ok(vec![self.clone()]) + } +} + +impl ToBitsGadget for [Boolean] { + /// Outputs `self`. + fn to_bits_le(&self) -> Result>, SynthesisError> { + Ok(self.to_vec()) + } +} + +impl ToBitsGadget for Vec +where + [T]: ToBitsGadget, +{ + fn to_bits_le(&self) -> Result>, SynthesisError> { + self.as_slice().to_bits_le().map(|v| v.to_vec()) + } + + fn to_non_unique_bits_le(&self) -> Result>, SynthesisError> { + self.as_slice().to_non_unique_bits_le().map(|v| v.to_vec()) + } +} + +/// Specifies constraints for conversion to a little-endian byte representation +/// of `self`. +pub trait ToBytesGadget { + /// Outputs a canonical, little-endian, byte decomposition of `self`. + /// + /// This is the correct default for 99% of use cases. + fn to_bytes(&self) -> Result>, SynthesisError>; + + /// Outputs a possibly non-unique byte decomposition of `self`. + /// + /// If you're not absolutely certain that your usecase can get away with a + /// non-canonical representation, please use `self.to_bytes(cs)` instead. + fn to_non_unique_bytes(&self) -> Result>, SynthesisError> { + self.to_bytes() + } +} + +impl<'a, F: Field, T: 'a + ToBytesGadget> ToBytesGadget for &'a T { + fn to_bytes(&self) -> Result>, SynthesisError> { + (*self).to_bytes() + } +} + +/// Specifies how to convert a variable of type `Self` to variables of +/// type `FpVar` +pub trait ToConstraintFieldGadget { + /// Converts `self` to `FpVar` variables. + fn to_constraint_field( + &self, + ) -> Result>, ark_relations::r1cs::SynthesisError>; +} diff --git a/src/fields/cubic_extension.rs b/src/fields/cubic_extension.rs index 12f330c5..20c040f4 100644 --- a/src/fields/cubic_extension.rs +++ b/src/fields/cubic_extension.rs @@ -6,9 +6,10 @@ use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use core::{borrow::Borrow, marker::PhantomData}; use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{fp::FpVar, FieldOpsBounds, FieldVar}, prelude::*, - ToConstraintFieldGadget, Vec, + Vec, }; /// This struct is the `R1CS` equivalent of the cubic extension field type diff --git a/src/fields/fp/cmp.rs b/src/fields/fp/cmp.rs index 4612d18f..f3304e4a 100644 --- a/src/fields/fp/cmp.rs +++ b/src/fields/fp/cmp.rs @@ -1,8 +1,8 @@ use crate::{ boolean::Boolean, + convert::ToBitsGadget, fields::{fp::FpVar, FieldVar}, prelude::*, - ToBitsGadget, }; use ark_ff::PrimeField; use ark_relations::r1cs::{SynthesisError, Variable}; diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index 1527cace..ac2d1615 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -6,9 +6,10 @@ use ark_relations::r1cs::{ use core::borrow::Borrow; use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{FieldOpsBounds, FieldVar}, prelude::*, - Assignment, ToConstraintFieldGadget, Vec, + Assignment, Vec, }; use ark_std::iter::Sum; diff --git a/src/fields/mod.rs b/src/fields/mod.rs index 0c342998..5e40351a 100644 --- a/src/fields/mod.rs +++ b/src/fields/mod.rs @@ -5,6 +5,7 @@ use core::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; +use crate::convert::{ToBitsGadget, ToBytesGadget}; use crate::prelude::*; /// This module contains a generic implementation of cubic extension field @@ -65,7 +66,7 @@ pub trait FieldOpsBounds<'a, F, T: 'a>: } /// A variable representing a field. Corresponds to the native type `F`. -pub trait FieldVar: +pub trait FieldVar: 'static + Clone + From> diff --git a/src/fields/nonnative/allocated_field_var.rs b/src/fields/nonnative/allocated_field_var.rs index aadbe1a3..85eeeef5 100644 --- a/src/fields/nonnative/allocated_field_var.rs +++ b/src/fields/nonnative/allocated_field_var.rs @@ -3,7 +3,11 @@ use super::{ reduce::{bigint_to_basefield, limbs_to_bigint, Reducer}, AllocatedNonNativeFieldMulResultVar, }; -use crate::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, + fields::fp::FpVar, + prelude::*, +}; use ark_ff::{BigInteger, PrimeField}; use ark_relations::{ ns, diff --git a/src/fields/nonnative/field_var.rs b/src/fields/nonnative/field_var.rs index e37b6c84..3e00a5be 100644 --- a/src/fields/nonnative/field_var.rs +++ b/src/fields/nonnative/field_var.rs @@ -1,9 +1,10 @@ use super::{params::OptimizationType, AllocatedNonNativeFieldVar, NonNativeFieldMulResultVar}; use crate::{ boolean::Boolean, + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{fp::FpVar, FieldVar}, prelude::*, - R1CSVar, ToConstraintFieldGadget, + R1CSVar, }; use ark_ff::{BigInteger, PrimeField}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, Result as R1CSResult, SynthesisError}; diff --git a/src/fields/quadratic_extension.rs b/src/fields/quadratic_extension.rs index 54949a10..da944cfd 100644 --- a/src/fields/quadratic_extension.rs +++ b/src/fields/quadratic_extension.rs @@ -6,9 +6,10 @@ use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use core::{borrow::Borrow, marker::PhantomData}; use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{fp::FpVar, FieldOpsBounds, FieldVar}, prelude::*, - ToConstraintFieldGadget, Vec, + Vec, }; /// This struct is the `R1CS` equivalent of the quadratic extension field type diff --git a/src/groups/curves/short_weierstrass/mnt4/mod.rs b/src/groups/curves/short_weierstrass/mnt4/mod.rs index 1bd768e0..7908852e 100644 --- a/src/groups/curves/short_weierstrass/mnt4/mod.rs +++ b/src/groups/curves/short_weierstrass/mnt4/mod.rs @@ -6,6 +6,7 @@ use ark_ff::Field; use ark_relations::r1cs::{Namespace, SynthesisError}; use crate::{ + convert::ToBytesGadget, fields::{fp::FpVar, fp2::Fp2Var, FieldVar}, groups::curves::short_weierstrass::ProjectiveVar, pairing::mnt4::PairingVar, diff --git a/src/groups/curves/short_weierstrass/mnt6/mod.rs b/src/groups/curves/short_weierstrass/mnt6/mod.rs index 6d216e14..9e534298 100644 --- a/src/groups/curves/short_weierstrass/mnt6/mod.rs +++ b/src/groups/curves/short_weierstrass/mnt6/mod.rs @@ -6,6 +6,7 @@ use ark_ff::Field; use ark_relations::r1cs::{Namespace, SynthesisError}; use crate::{ + convert::ToBytesGadget, fields::{fp::FpVar, fp3::Fp3Var, FieldVar}, groups::curves::short_weierstrass::ProjectiveVar, pairing::mnt6::PairingVar, diff --git a/src/groups/curves/short_weierstrass/mod.rs b/src/groups/curves/short_weierstrass/mod.rs index ede30af0..00e4b854 100644 --- a/src/groups/curves/short_weierstrass/mod.rs +++ b/src/groups/curves/short_weierstrass/mod.rs @@ -9,7 +9,12 @@ use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul}; use non_zero_affine::NonZeroAffineVar; -use crate::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget, Vec}; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, + fields::fp::FpVar, + prelude::*, + Vec, +}; /// This module provides a generic implementation of G1 and G2 for /// the [\[BLS12]\]() family of bilinear groups. diff --git a/src/groups/curves/twisted_edwards/mod.rs b/src/groups/curves/twisted_edwards/mod.rs index 28b65ba1..1144cd2b 100644 --- a/src/groups/curves/twisted_edwards/mod.rs +++ b/src/groups/curves/twisted_edwards/mod.rs @@ -8,7 +8,11 @@ use ark_ec::{ use ark_ff::{BigInteger, BitIteratorBE, Field, One, PrimeField, Zero}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; -use crate::{prelude::*, ToConstraintFieldGadget, Vec}; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, + prelude::*, + Vec, +}; use crate::fields::fp::FpVar; use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul}; diff --git a/src/groups/mod.rs b/src/groups/mod.rs index fb8c7461..375d1f86 100644 --- a/src/groups/mod.rs +++ b/src/groups/mod.rs @@ -1,5 +1,8 @@ -use crate::prelude::*; -use ark_ff::Field; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget}, + prelude::*, +}; +use ark_ff::PrimeField; use ark_relations::r1cs::{Namespace, SynthesisError}; use core::ops::{Add, AddAssign, Sub, SubAssign}; @@ -25,7 +28,7 @@ pub trait GroupOpsBounds<'a, F, T: 'a>: /// A variable that represents a curve point for /// the curve `C`. -pub trait CurveVar: +pub trait CurveVar: 'static + Sized + Clone diff --git a/src/lib.rs b/src/lib.rs index 9964c8c9..ef16f206 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,29 +36,31 @@ use ark_ff::Field; pub mod bits; pub use self::bits::*; -/// This module implements gadgets related to field arithmetic. +/// Finite field arithmetic. pub mod fields; -/// This module implements gadgets related to group arithmetic, and specifically -/// elliptic curve arithmetic. +/// Implementations of elliptic curve group arithmetic for popular curve models. pub mod groups; -/// This module implements gadgets related to computing pairings in bilinear -/// groups. +/// Gadgets for computing pairings in bilinear groups. pub mod pairing; -/// This module describes a trait for allocating new variables in a constraint -/// system. +/// Utilities for allocating new variables in a constraint system. pub mod alloc; -/// This module describes a trait for checking equality of variables. +/// Utilities for comparing variables. +pub mod cmp; + +/// Utilities for converting variables to other kinds of variables. +pub mod convert; + +/// Utilities for checking equality of variables. pub mod eq; -/// This module implements functions for manipulating polynomial variables over -/// finite fields. +/// Definitions of polynomial variables over finite fields. pub mod poly; -/// This module describes traits for conditionally selecting a variable from a +/// Contains traits for conditionally selecting a variable from a /// list of variables. pub mod select; @@ -69,7 +71,7 @@ pub(crate) mod test_utils; pub mod prelude { pub use crate::{ alloc::*, - bits::{boolean::Boolean, uint32::UInt32, uint8::UInt8, ToBitsGadget, ToBytesGadget}, + bits::{boolean::Boolean, uint32::UInt32, uint8::UInt8}, eq::*, fields::{FieldOpsBounds, FieldVar}, groups::{CurveVar, GroupOpsBounds}, @@ -145,12 +147,3 @@ impl Assignment for Option { self.ok_or(ark_relations::r1cs::SynthesisError::AssignmentMissing) } } - -/// Specifies how to convert a variable of type `Self` to variables of -/// type `FpVar` -pub trait ToConstraintFieldGadget { - /// Converts `self` to `FpVar` variables. - fn to_constraint_field( - &self, - ) -> Result>, ark_relations::r1cs::SynthesisError>; -} diff --git a/src/pairing/mod.rs b/src/pairing/mod.rs index 958134e1..84fe7486 100644 --- a/src/pairing/mod.rs +++ b/src/pairing/mod.rs @@ -1,7 +1,7 @@ -use crate::prelude::*; +use crate::{convert::ToBytesGadget, prelude::*}; use ark_ec::pairing::Pairing; use ark_ec::CurveGroup; -use ark_ff::Field; +use ark_ff::PrimeField; use ark_relations::r1cs::SynthesisError; use core::fmt::Debug; @@ -14,7 +14,10 @@ pub mod mnt6; /// Specifies the constraints for computing a pairing in the yybilinear group /// `E`. -pub trait PairingVar::G1 as CurveGroup>::BaseField> +pub trait PairingVar< + E: Pairing, + ConstraintF: PrimeField = <::G1 as CurveGroup>::BaseField, +> { /// An variable representing an element of `G1`. /// This is the R1CS equivalent of `E::G1Projective`. diff --git a/src/poly/domain/mod.rs b/src/poly/domain/mod.rs index 0b479451..92efb3e1 100644 --- a/src/poly/domain/mod.rs +++ b/src/poly/domain/mod.rs @@ -129,7 +129,10 @@ mod tests { use ark_relations::r1cs::ConstraintSystem; use ark_std::{rand::Rng, test_rng}; - use crate::{alloc::AllocVar, fields::fp::FpVar, poly::domain::Radix2DomainVar, R1CSVar}; + use crate::{ + alloc::AllocVar, convert::ToBitsGadget, fields::fp::FpVar, poly::domain::Radix2DomainVar, + R1CSVar, + }; fn test_query_coset_template() { const COSET_DIM: u64 = 7; From 19bec7566f7c55a5558ac7ce22e0bef5764ac653 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 7 Jun 2023 14:13:55 -0700 Subject: [PATCH 18/29] Make `boolean` and `uint` top level modules and remove `bits` --- src/bits/mod.rs | 21 ------------------ src/{bits => }/boolean/allocated.rs | 0 src/{bits => }/boolean/and.rs | 0 src/{bits => }/boolean/cmp.rs | 0 src/{bits => }/boolean/convert.rs | 0 src/{bits => }/boolean/eq.rs | 0 src/{bits => }/boolean/mod.rs | 0 src/{bits => }/boolean/not.rs | 0 src/{bits => }/boolean/or.rs | 0 src/{bits => }/boolean/select.rs | 0 src/{bits => }/boolean/test_utils.rs | 0 src/{bits => }/boolean/xor.rs | 0 src/lib.rs | 33 +++++++++++++++++++++++----- src/{bits => }/uint/add.rs | 0 src/{bits => }/uint/and.rs | 0 src/{bits => }/uint/cmp.rs | 0 src/{bits => }/uint/convert.rs | 0 src/{bits => }/uint/eq.rs | 0 src/{bits => }/uint/mod.rs | 0 src/{bits => }/uint/not.rs | 0 src/{bits => }/uint/or.rs | 0 src/{bits => }/uint/rotate.rs | 0 src/{bits => }/uint/select.rs | 0 src/{bits => }/uint/test_utils.rs | 0 src/{bits => }/uint/xor.rs | 0 src/{bits => }/uint8.rs | 0 26 files changed, 28 insertions(+), 26 deletions(-) delete mode 100644 src/bits/mod.rs rename src/{bits => }/boolean/allocated.rs (100%) rename src/{bits => }/boolean/and.rs (100%) rename src/{bits => }/boolean/cmp.rs (100%) rename src/{bits => }/boolean/convert.rs (100%) rename src/{bits => }/boolean/eq.rs (100%) rename src/{bits => }/boolean/mod.rs (100%) rename src/{bits => }/boolean/not.rs (100%) rename src/{bits => }/boolean/or.rs (100%) rename src/{bits => }/boolean/select.rs (100%) rename src/{bits => }/boolean/test_utils.rs (100%) rename src/{bits => }/boolean/xor.rs (100%) rename src/{bits => }/uint/add.rs (100%) rename src/{bits => }/uint/and.rs (100%) rename src/{bits => }/uint/cmp.rs (100%) rename src/{bits => }/uint/convert.rs (100%) rename src/{bits => }/uint/eq.rs (100%) rename src/{bits => }/uint/mod.rs (100%) rename src/{bits => }/uint/not.rs (100%) rename src/{bits => }/uint/or.rs (100%) rename src/{bits => }/uint/rotate.rs (100%) rename src/{bits => }/uint/select.rs (100%) rename src/{bits => }/uint/test_utils.rs (100%) rename src/{bits => }/uint/xor.rs (100%) rename src/{bits => }/uint8.rs (100%) diff --git a/src/bits/mod.rs b/src/bits/mod.rs deleted file mode 100644 index c57085c1..00000000 --- a/src/bits/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -/// This module contains `Boolean`, a R1CS equivalent of the `bool` type. -pub mod boolean; -/// This module contains `UInt8`, a R1CS equivalent of the `u8` type. -pub mod uint8; -/// This module contains a macro for generating `UIntN` types, which are R1CS -/// equivalents of `N`-bit unsigned integers. -#[macro_use] -pub mod uint; - -pub mod uint16 { - pub type UInt16 = super::uint::UInt<16, u16, F>; -} -pub mod uint32 { - pub type UInt32 = super::uint::UInt<32, u32, F>; -} -pub mod uint64 { - pub type UInt64 = super::uint::UInt<64, u64, F>; -} -pub mod uint128 { - pub type UInt128 = super::uint::UInt<128, u128, F>; -} diff --git a/src/bits/boolean/allocated.rs b/src/boolean/allocated.rs similarity index 100% rename from src/bits/boolean/allocated.rs rename to src/boolean/allocated.rs diff --git a/src/bits/boolean/and.rs b/src/boolean/and.rs similarity index 100% rename from src/bits/boolean/and.rs rename to src/boolean/and.rs diff --git a/src/bits/boolean/cmp.rs b/src/boolean/cmp.rs similarity index 100% rename from src/bits/boolean/cmp.rs rename to src/boolean/cmp.rs diff --git a/src/bits/boolean/convert.rs b/src/boolean/convert.rs similarity index 100% rename from src/bits/boolean/convert.rs rename to src/boolean/convert.rs diff --git a/src/bits/boolean/eq.rs b/src/boolean/eq.rs similarity index 100% rename from src/bits/boolean/eq.rs rename to src/boolean/eq.rs diff --git a/src/bits/boolean/mod.rs b/src/boolean/mod.rs similarity index 100% rename from src/bits/boolean/mod.rs rename to src/boolean/mod.rs diff --git a/src/bits/boolean/not.rs b/src/boolean/not.rs similarity index 100% rename from src/bits/boolean/not.rs rename to src/boolean/not.rs diff --git a/src/bits/boolean/or.rs b/src/boolean/or.rs similarity index 100% rename from src/bits/boolean/or.rs rename to src/boolean/or.rs diff --git a/src/bits/boolean/select.rs b/src/boolean/select.rs similarity index 100% rename from src/bits/boolean/select.rs rename to src/boolean/select.rs diff --git a/src/bits/boolean/test_utils.rs b/src/boolean/test_utils.rs similarity index 100% rename from src/bits/boolean/test_utils.rs rename to src/boolean/test_utils.rs diff --git a/src/bits/boolean/xor.rs b/src/boolean/xor.rs similarity index 100% rename from src/bits/boolean/xor.rs rename to src/boolean/xor.rs diff --git a/src/lib.rs b/src/lib.rs index ef16f206..74c9dc21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,10 +31,8 @@ pub(crate) use ark_std::vec::Vec; use ark_ff::Field; -/// This module implements gadgets related to bit manipulation, such as -/// `Boolean` and `UInt`s. -pub mod bits; -pub use self::bits::*; +/// This module contains `Boolean`, an R1CS equivalent of the `bool` type. +pub mod boolean; /// Finite field arithmetic. pub mod fields; @@ -67,16 +65,41 @@ pub mod select; #[cfg(test)] pub(crate) mod test_utils; +/// This module contains `UInt8`, a R1CS equivalent of the `u8` type. +pub mod uint8; +/// This module contains a macro for generating `UIntN` types, which are R1CS +/// equivalents of `N`-bit unsigned integers. +#[macro_use] +pub mod uint; + +pub mod uint16 { + pub type UInt16 = super::uint::UInt<16, u16, F>; +} +pub mod uint32 { + pub type UInt32 = super::uint::UInt<32, u32, F>; +} +pub mod uint64 { + pub type UInt64 = super::uint::UInt<64, u64, F>; +} +pub mod uint128 { + pub type UInt128 = super::uint::UInt<128, u128, F>; +} + #[allow(missing_docs)] pub mod prelude { pub use crate::{ alloc::*, - bits::{boolean::Boolean, uint32::UInt32, uint8::UInt8}, + boolean::Boolean, eq::*, fields::{FieldOpsBounds, FieldVar}, groups::{CurveVar, GroupOpsBounds}, pairing::PairingVar, select::*, + uint128::UInt128, + uint16::UInt16, + uint32::UInt32, + uint64::UInt64, + uint8::UInt8, R1CSVar, }; } diff --git a/src/bits/uint/add.rs b/src/uint/add.rs similarity index 100% rename from src/bits/uint/add.rs rename to src/uint/add.rs diff --git a/src/bits/uint/and.rs b/src/uint/and.rs similarity index 100% rename from src/bits/uint/and.rs rename to src/uint/and.rs diff --git a/src/bits/uint/cmp.rs b/src/uint/cmp.rs similarity index 100% rename from src/bits/uint/cmp.rs rename to src/uint/cmp.rs diff --git a/src/bits/uint/convert.rs b/src/uint/convert.rs similarity index 100% rename from src/bits/uint/convert.rs rename to src/uint/convert.rs diff --git a/src/bits/uint/eq.rs b/src/uint/eq.rs similarity index 100% rename from src/bits/uint/eq.rs rename to src/uint/eq.rs diff --git a/src/bits/uint/mod.rs b/src/uint/mod.rs similarity index 100% rename from src/bits/uint/mod.rs rename to src/uint/mod.rs diff --git a/src/bits/uint/not.rs b/src/uint/not.rs similarity index 100% rename from src/bits/uint/not.rs rename to src/uint/not.rs diff --git a/src/bits/uint/or.rs b/src/uint/or.rs similarity index 100% rename from src/bits/uint/or.rs rename to src/uint/or.rs diff --git a/src/bits/uint/rotate.rs b/src/uint/rotate.rs similarity index 100% rename from src/bits/uint/rotate.rs rename to src/uint/rotate.rs diff --git a/src/bits/uint/select.rs b/src/uint/select.rs similarity index 100% rename from src/bits/uint/select.rs rename to src/uint/select.rs diff --git a/src/bits/uint/test_utils.rs b/src/uint/test_utils.rs similarity index 100% rename from src/bits/uint/test_utils.rs rename to src/uint/test_utils.rs diff --git a/src/bits/uint/xor.rs b/src/uint/xor.rs similarity index 100% rename from src/bits/uint/xor.rs rename to src/uint/xor.rs diff --git a/src/bits/uint8.rs b/src/uint8.rs similarity index 100% rename from src/bits/uint8.rs rename to src/uint8.rs From 2dd5f325e42c87191a46e52b32ebbb1e2a2511ce Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 7 Jun 2023 14:14:32 -0700 Subject: [PATCH 19/29] Update `to_constraint_field_test` --- tests/to_constraint_field_test.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/to_constraint_field_test.rs b/tests/to_constraint_field_test.rs index d0a5ac02..ad268f89 100644 --- a/tests/to_constraint_field_test.rs +++ b/tests/to_constraint_field_test.rs @@ -1,5 +1,6 @@ use ark_r1cs_std::{ - alloc::AllocVar, fields::nonnative::NonNativeFieldVar, R1CSVar, ToConstraintFieldGadget, + alloc::AllocVar, convert::ToConstraintFieldGadget, fields::nonnative::NonNativeFieldVar, + R1CSVar, }; use ark_relations::r1cs::ConstraintSystem; From 4f00cdbacbf40229f07c17dec82536d8e4438940 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 7 Jun 2023 15:29:13 -0700 Subject: [PATCH 20/29] Fix doc tests --- src/boolean/and.rs | 9 ++-- src/boolean/not.rs | 8 ++-- src/boolean/or.rs | 11 ++--- src/boolean/xor.rs | 8 ++-- src/uint/or.rs | 107 +++++---------------------------------------- src/uint/rotate.rs | 4 +- src/uint/xor.rs | 98 +++-------------------------------------- 7 files changed, 40 insertions(+), 205 deletions(-) diff --git a/src/boolean/and.rs b/src/boolean/and.rs index 0b846180..20251f3f 100644 --- a/src/boolean/and.rs +++ b/src/boolean/and.rs @@ -62,6 +62,7 @@ impl Boolean { Ok(cur.expect("should not be 0")) } else { + // b0 & b1 & ... & bN == 1 if and only if sum(b0, b1, ..., bN) == N let sum_bits: FpVar<_> = bits.iter().map(|b| FpVar::from(b.clone())).sum(); let num_bits = FpVar::Constant(F::from(bits.len() as u64)); sum_bits.is_eq(&num_bits) @@ -125,11 +126,11 @@ impl<'a, F: Field> BitAnd for &'a Boolean { /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; /// - /// a.and(&a)?.enforce_equal(&Boolean::TRUE)?; + /// (&a & &a).enforce_equal(&Boolean::TRUE)?; /// - /// a.and(&b)?.enforce_equal(&Boolean::FALSE)?; - /// b.and(&a)?.enforce_equal(&Boolean::FALSE)?; - /// b.and(&b)?.enforce_equal(&Boolean::FALSE)?; + /// (&a & &b).enforce_equal(&Boolean::FALSE)?; + /// (&b & &a).enforce_equal(&Boolean::FALSE)?; + /// (&b & &b).enforce_equal(&Boolean::FALSE)?; /// /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) diff --git a/src/boolean/not.rs b/src/boolean/not.rs index ebbce715..1c7de2e6 100644 --- a/src/boolean/not.rs +++ b/src/boolean/not.rs @@ -30,11 +30,11 @@ impl<'a, F: Field> Not for &'a Boolean { /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; /// - /// a.not().enforce_equal(&b)?; - /// b.not().enforce_equal(&a)?; + /// (!&a).enforce_equal(&b)?; + /// (!&b).enforce_equal(&a)?; /// - /// a.not().enforce_equal(&Boolean::FALSE)?; - /// b.not().enforce_equal(&Boolean::TRUE)?; + /// (!&a).enforce_equal(&Boolean::FALSE)?; + /// (!&b).enforce_equal(&Boolean::TRUE)?; /// /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) diff --git a/src/boolean/or.rs b/src/boolean/or.rs index a54bfc71..8f8b41c1 100644 --- a/src/boolean/or.rs +++ b/src/boolean/or.rs @@ -57,7 +57,8 @@ impl Boolean { Ok(cur.expect("should not be 0")) } else { - // b0 & b1 & ... & bN == 1 if and only if sum(b0, b1, ..., bN) == N + // b0 | b1 | ... | bN == 1 if and only if not all of b0, b1, ..., bN are 0. + // We can enforce this by requiring that the sum of b0, b1, ..., bN is not 0. let sum_bits: FpVar<_> = bits.iter().map(|b| FpVar::from(b.clone())).sum(); sum_bits.is_neq(&FpVar::zero()) } @@ -84,11 +85,11 @@ impl<'a, F: PrimeField> BitOr for &'a Boolean { /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; /// - /// a.or(&b)?.enforce_equal(&Boolean::TRUE)?; - /// b.or(&a)?.enforce_equal(&Boolean::TRUE)?; + /// (&a | &b).enforce_equal(&Boolean::TRUE)?; + /// (&b | &a).enforce_equal(&Boolean::TRUE)?; /// - /// a.or(&a)?.enforce_equal(&Boolean::TRUE)?; - /// b.or(&b)?.enforce_equal(&Boolean::FALSE)?; + /// (&a | &a).enforce_equal(&Boolean::TRUE)?; + /// (&b | &b).enforce_equal(&Boolean::FALSE)?; /// /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) diff --git a/src/boolean/xor.rs b/src/boolean/xor.rs index ff880460..67e45b36 100644 --- a/src/boolean/xor.rs +++ b/src/boolean/xor.rs @@ -35,11 +35,11 @@ impl<'a, F: Field> BitXor for &'a Boolean { /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; /// - /// a.xor(&b)?.enforce_equal(&Boolean::TRUE)?; - /// b.xor(&a)?.enforce_equal(&Boolean::TRUE)?; + /// (&a ^ &b).enforce_equal(&Boolean::TRUE)?; + /// (&b ^ &a).enforce_equal(&Boolean::TRUE)?; /// - /// a.xor(&a)?.enforce_equal(&Boolean::FALSE)?; - /// b.xor(&b)?.enforce_equal(&Boolean::FALSE)?; + /// (&a ^ &a).enforce_equal(&Boolean::FALSE)?; + /// (&b ^ &b).enforce_equal(&Boolean::FALSE)?; /// /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) diff --git a/src/uint/or.rs b/src/uint/or.rs index d7c0cc0f..823617ec 100644 --- a/src/uint/or.rs +++ b/src/uint/or.rs @@ -18,7 +18,8 @@ impl UInt { impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr for &'a UInt { type Output = UInt; - /// Outputs `self ^ other`. + + /// Output `self | other`. /// /// If at least one of `self` and `other` are constants, then this method /// *does not* create any constraints or variables. @@ -33,9 +34,9 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr for &'a /// let cs = ConstraintSystem::::new_ref(); /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 | 17))?; /// - /// a.or(&b)?.enforce_equal(&c)?; + /// (a | b).enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -48,28 +49,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr for &'a impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr<&'a Self> for UInt { type Output = UInt; - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.or(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitor(self, other: &Self) -> Self::Output { self._or(&other).unwrap() @@ -80,28 +60,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr> for &'a UInt { type Output = UInt; - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.or(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitor(self, other: UInt) -> Self::Output { self._or(&other).unwrap() @@ -110,28 +69,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr> impl BitOr for UInt { type Output = Self; - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.or(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitor(self, other: Self) -> Self::Output { self._or(&other).unwrap() @@ -139,7 +77,7 @@ impl BitOr for UInt BitOrAssign for UInt { - /// Sets `self = self ^ other`. + /// Sets `self = self | other`. /// /// If at least one of `self` and `other` are constants, then this method /// *does not* create any constraints or variables. @@ -152,11 +90,12 @@ impl BitOrAssign for UI /// use ark_r1cs_std::prelude::*; /// /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 | 17))?; /// - /// a.or(&b)?.enforce_equal(&c)?; + /// a |= b; + /// a.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -171,28 +110,6 @@ impl BitOrAssign for UI impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOrAssign<&'a Self> for UInt { - /// Sets `self = self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.or(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitor_assign(&mut self, other: &'a Self) { let result = self._or(other).unwrap(); diff --git a/src/uint/rotate.rs b/src/uint/rotate.rs index 0e4333ec..394c7b62 100644 --- a/src/uint/rotate.rs +++ b/src/uint/rotate.rs @@ -15,7 +15,7 @@ impl UInt UInt BitXor for &'a UInt /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; /// - /// a.xor(&b)?.enforce_equal(&c)?; + /// (a ^ &b).enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -48,28 +48,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor for &'a UInt impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor<&'a Self> for UInt { type Output = UInt; - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.xor(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitxor(self, other: &Self) -> Self::Output { self._xor(&other).unwrap() @@ -78,28 +57,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor<&'a Self> for UInt impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor> for &'a UInt { type Output = UInt; - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.xor(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitxor(self, other: UInt) -> Self::Output { self._xor(&other).unwrap() @@ -108,28 +66,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor> for impl BitXor for UInt { type Output = Self; - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.xor(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitxor(self, other: Self) -> Self::Output { self._xor(&other).unwrap() @@ -150,11 +87,12 @@ impl BitXorAssign for UInt::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; /// - /// a.xor(&b)?.enforce_equal(&c)?; + /// a ^= b; + /// a.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } @@ -167,28 +105,6 @@ impl BitXorAssign for UInt BitXorAssign<&'a Self> for UInt { - /// Sets `self = self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.xor(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitxor_assign(&mut self, other: &'a Self) { let result = self._xor(other).unwrap(); From 006f090047845405ecf176d36546169d5bc7a4a3 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Wed, 7 Jun 2023 19:39:30 -0700 Subject: [PATCH 21/29] Introduce `PrimUInt` trait and implement `saturating_add` --- src/poly/domain/mod.rs | 7 -- src/uint/add/mod.rs | 50 ++++++++++ src/uint/add/saturating.rs | 117 +++++++++++++++++++++++ src/uint/{add.rs => add/wrapping.rs} | 54 +++-------- src/uint/and.rs | 21 ++--- src/uint/cmp.rs | 10 +- src/uint/convert.rs | 14 +-- src/uint/eq.rs | 12 +-- src/uint/mod.rs | 93 +++--------------- src/uint/not.rs | 13 ++- src/uint/or.rs | 25 ++--- src/uint/prim_uint.rs | 136 +++++++++++++++++++++++++++ src/uint/rotate.rs | 6 +- src/uint/select.rs | 4 +- src/uint/test_utils.rs | 20 ++-- src/uint/xor.rs | 21 ++--- src/uint8.rs | 6 -- 17 files changed, 397 insertions(+), 212 deletions(-) create mode 100644 src/uint/add/mod.rs create mode 100644 src/uint/add/saturating.rs rename src/uint/{add.rs => add/wrapping.rs} (63%) create mode 100644 src/uint/prim_uint.rs diff --git a/src/poly/domain/mod.rs b/src/poly/domain/mod.rs index 92efb3e1..c30bd03e 100644 --- a/src/poly/domain/mod.rs +++ b/src/poly/domain/mod.rs @@ -149,13 +149,6 @@ mod tests { let coset_index = rng.gen_range(0..num_cosets); println!("{:0b}", coset_index); - dbg!(UInt32::new_witness(cs.clone(), || Ok(coset_index)) - .unwrap() - .to_bits_le() - .unwrap() - .iter() - .map(|x| x.value().unwrap() as u8) - .collect::>()); let coset_index_var = UInt32::new_witness(cs.clone(), || Ok(coset_index)) .unwrap() .to_bits_le() diff --git a/src/uint/add/mod.rs b/src/uint/add/mod.rs new file mode 100644 index 00000000..d5eb7563 --- /dev/null +++ b/src/uint/add/mod.rs @@ -0,0 +1,50 @@ +use crate::fields::fp::FpVar; + +use super::*; + +mod saturating; +mod wrapping; + +impl UInt { + /// Adds up `operands`, returning the bit decomposition of the result, along with + /// the value of the result. If all the operands are constant, then the bit decomposition + /// is empty, and the value is the constant value of the result. + /// + /// # Panics + /// + /// This method panics if the result of addition could possibly exceed the field size. + #[tracing::instrument(target = "r1cs", skip(operands, adder))] + fn add_many_helper( + operands: &[Self], + adder: impl Fn(T, T) -> T, + ) -> Result<(Vec>, Option), SynthesisError> { + // Bounds on `N` to avoid overflows + + assert!(operands.len() >= 1); + let max_value_size = N as u32 + ark_std::log2(operands.len()); + assert!(max_value_size <= F::MODULUS_BIT_SIZE); + + if operands.len() == 1 { + return Ok((operands[0].bits.to_vec(), operands[0].value)); + } + + // Compute the value of the result. + let mut value = Some(T::zero()); + for op in operands { + value = value.and_then(|v| Some(adder(v, op.value?))); + } + if operands.is_constant() { + // If all operands are constant, then the result is also constant. + // In this case, we can return early. + return Ok((Vec::new(), value)); + } + + // Compute the full (non-wrapped) sum of the operands. + let result = operands + .iter() + .map(|op| Boolean::le_bits_to_fp(&op.bits).unwrap()) + .sum::>(); + let (result, _) = result.to_bits_le_with_top_bits_zero(max_value_size as usize)?; + Ok((result, value)) + } +} diff --git a/src/uint/add/saturating.rs b/src/uint/add/saturating.rs new file mode 100644 index 00000000..62c393eb --- /dev/null +++ b/src/uint/add/saturating.rs @@ -0,0 +1,117 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; + +use crate::uint::*; +use crate::{boolean::Boolean, R1CSVar}; + +impl UInt { + /// Compute `*self = self.wrapping_add(other)`. + pub fn saturating_add_in_place(&mut self, other: &Self) { + let result = Self::saturating_add_many(&[self.clone(), other.clone()]).unwrap(); + *self = result; + } + + /// Compute `self.wrapping_add(other)`. + pub fn saturating_add(&self, other: &Self) -> Self { + let mut result = self.clone(); + result.saturating_add_in_place(other); + result + } + + /// Perform wrapping addition of `operands`. + /// Computes `operands[0].wrapping_add(operands[1]).wrapping_add(operands[2])...`. + /// + /// The user must ensure that overflow does not occur. + #[tracing::instrument(target = "r1cs", skip(operands))] + pub fn saturating_add_many(operands: &[Self]) -> Result + where + F: PrimeField, + { + let (sum_bits, value) = Self::add_many_helper(operands, |a, b| a.saturating_add(b))?; + if operands.is_constant() { + // If all operands are constant, then the result is also constant. + // In this case, we can return early. + Ok(UInt::constant(value.unwrap())) + } else if sum_bits.len() == N { + // No overflow occurred. + Ok(UInt::from_bits_le(&sum_bits)) + } else { + // Split the sum into the bottom `N` bits and the top bits. + let (bottom_bits, top_bits) = sum_bits.split_at(N); + + // Construct a candidate result assuming that no overflow occurred. + let bits = TryFrom::try_from(bottom_bits.to_vec()).unwrap(); + let candidate_result = UInt { bits, value }; + + // Check if any of the top bits is set. + // If any of them is set, then overflow occurred. + let overflow_occurred = Boolean::kary_or(&top_bits)?; + + // If overflow occurred, return the maximum value. + overflow_occurred.select(&Self::MAX, &candidate_result) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_saturating_add( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.saturating_add(&b); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::new_variable( + cs.clone(), + || Ok(a.value()?.saturating_add(b.value()?)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_saturating_add() { + run_binary_exhaustive(uint_saturating_add::).unwrap() + } + + #[test] + fn u16_saturating_add() { + run_binary_random::<1000, 16, _, _>(uint_saturating_add::).unwrap() + } + + #[test] + fn u32_saturating_add() { + run_binary_random::<1000, 32, _, _>(uint_saturating_add::).unwrap() + } + + #[test] + fn u64_saturating_add() { + run_binary_random::<1000, 64, _, _>(uint_saturating_add::).unwrap() + } + + #[test] + fn u128_saturating_add() { + run_binary_random::<1000, 128, _, _>(uint_saturating_add::).unwrap() + } +} diff --git a/src/uint/add.rs b/src/uint/add/wrapping.rs similarity index 63% rename from src/uint/add.rs rename to src/uint/add/wrapping.rs index 5d1577e5..5dfe4496 100644 --- a/src/uint/add.rs +++ b/src/uint/add/wrapping.rs @@ -1,14 +1,10 @@ use ark_ff::PrimeField; use ark_relations::r1cs::SynthesisError; -use ark_std::fmt::Debug; -use num_traits::{PrimInt, WrappingAdd}; +use crate::uint::*; +use crate::R1CSVar; -use crate::{boolean::Boolean, fields::fp::FpVar, R1CSVar}; - -use super::UInt; - -impl UInt { +impl UInt { /// Compute `*self = self.wrapping_add(other)`. pub fn wrapping_add_in_place(&mut self, other: &Self) { let result = Self::wrapping_add_many(&[self.clone(), other.clone()]).unwrap(); @@ -31,36 +27,18 @@ impl UInt= 1); - let max_value_size = N as u32 + ark_std::log2(operands.len()); - assert!(max_value_size <= F::MODULUS_BIT_SIZE); - - if operands.len() == 1 { - return Ok(operands[0].clone()); - } - - // Compute the value of the result. - let mut value = Some(T::zero()); - for op in operands { - value = value.and_then(|v| Some(v.wrapping_add(&op.value?))); - } + let (mut sum_bits, value) = Self::add_many_helper(operands, |a, b| a.wrapping_add(&b))?; if operands.is_constant() { - return Ok(UInt::constant(value.unwrap())); + // If all operands are constant, then the result is also constant. + // In this case, we can return early. + Ok(UInt::constant(value.unwrap())) + } else { + sum_bits.truncate(N); + Ok(UInt { + bits: sum_bits.try_into().unwrap(), + value, + }) } - - // Compute the full (non-wrapped) sum of the operands. - let result = operands - .iter() - .map(|op| Boolean::le_bits_to_fp(&op.bits).unwrap()) - .sum::>(); - let (mut result_bits, _) = result.to_bits_le_with_top_bits_zero(max_value_size as usize)?; - // Discard any carry bits, since these will get discarded by wrapping. - result_bits.truncate(N); - let bits = TryFrom::try_from(result_bits).unwrap(); - - Ok(UInt { bits, value }) } } @@ -76,17 +54,13 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_wrapping_add( + fn uint_wrapping_add( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { let cs = a.cs().or(b.cs()); let both_constant = a.is_constant() && b.is_constant(); let computed = a.wrapping_add(&b); - let _ = dbg!(a.value()); - let _ = dbg!(b.value()); - dbg!(a.is_constant()); - dbg!(b.is_constant()); let expected_mode = if both_constant { AllocationMode::Constant } else { diff --git a/src/uint/and.rs b/src/uint/and.rs index 3466f36c..4fb1b5b6 100644 --- a/src/uint/and.rs +++ b/src/uint/and.rs @@ -1,11 +1,10 @@ use ark_ff::Field; use ark_relations::r1cs::SynthesisError; -use ark_std::{fmt::Debug, ops::BitAnd, ops::BitAndAssign}; -use num_traits::PrimInt; +use ark_std::{ops::BitAnd, ops::BitAndAssign}; -use super::UInt; +use super::*; -impl UInt { +impl UInt { fn _and(&self, other: &Self) -> Result { let mut result = self.clone(); for (a, b) in result.bits.iter_mut().zip(&other.bits) { @@ -16,7 +15,7 @@ impl UInt { } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd for &'a UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd for &'a UInt { type Output = UInt; /// Outputs `self & other`. /// @@ -46,7 +45,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd for &'a UInt } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd<&'a Self> for UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a Self> for UInt { type Output = UInt; /// Outputs `self & other`. /// @@ -76,7 +75,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd<&'a Self> for UInt } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd> for &'a UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd> for &'a UInt { type Output = UInt; /// Outputs `self & other`. @@ -107,7 +106,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitAnd> for } } -impl BitAnd for UInt { +impl BitAnd for UInt { type Output = Self; /// Outputs `self & other`. @@ -138,7 +137,7 @@ impl BitAnd for UInt BitAndAssign for UInt { +impl BitAndAssign for UInt { /// Sets `self = self & other`. /// /// If at least one of `self` and `other` are constants, then this method @@ -169,7 +168,7 @@ impl BitAndAssign for UInt BitAndAssign<&'a Self> for UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAndAssign<&'a Self> for UInt { /// Sets `self = self & other`. /// /// If at least one of `self` and `other` are constants, then this method @@ -212,7 +211,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_and( + fn uint_and( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { diff --git a/src/uint/cmp.rs b/src/uint/cmp.rs index 977dde1e..5a01b169 100644 --- a/src/uint/cmp.rs +++ b/src/uint/cmp.rs @@ -2,7 +2,7 @@ use crate::cmp::CmpGadget; use super::*; -impl> CmpGadget for UInt { +impl> CmpGadget for UInt { fn is_ge(&self, other: &Self) -> Result, SynthesisError> { if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) { let a = self.to_fp()?; @@ -28,7 +28,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_gt>( + fn uint_gt>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { @@ -50,7 +50,7 @@ mod tests { Ok(()) } - fn uint_lt>( + fn uint_lt>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { @@ -72,7 +72,7 @@ mod tests { Ok(()) } - fn uint_ge>( + fn uint_ge>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { @@ -94,7 +94,7 @@ mod tests { Ok(()) } - fn uint_le>( + fn uint_le>( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { diff --git a/src/uint/convert.rs b/src/uint/convert.rs index 2cc8d247..45ceb3e5 100644 --- a/src/uint/convert.rs +++ b/src/uint/convert.rs @@ -3,7 +3,7 @@ use crate::fields::fp::FpVar; use super::*; -impl UInt { +impl UInt { /// Converts `self` into a field element. The elements comprising `self` are /// interpreted as a little-endian bit order representation of a field element. /// @@ -74,13 +74,13 @@ impl UInt { } } -impl ToBitsGadget for UInt { +impl ToBitsGadget for UInt { fn to_bits_le(&self) -> Result>, SynthesisError> { Ok(self.bits.to_vec()) } } -impl ToBitsGadget for [UInt] { +impl ToBitsGadget for [UInt] { /// Interprets `self` as an integer, and outputs the little-endian /// bit-wise decomposition of that integer. fn to_bits_le(&self) -> Result>, SynthesisError> { @@ -93,7 +93,7 @@ impl ToBitsGadget for [UInt ToBytesGadget +impl ToBytesGadget for UInt { #[tracing::instrument(target = "r1cs", skip(self))] @@ -106,7 +106,7 @@ impl ToBytesGadget ToBytesGadget for [UInt] { +impl ToBytesGadget for [UInt] { fn to_bytes(&self) -> Result>, SynthesisError> { let mut bytes = Vec::with_capacity(self.len() * (N / 8)); for elem in self { @@ -116,13 +116,13 @@ impl ToBytesGadget for [UInt ToBytesGadget for Vec> { +impl ToBytesGadget for Vec> { fn to_bytes(&self) -> Result>, SynthesisError> { self.as_slice().to_bytes() } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> ToBytesGadget for &'a [UInt] { +impl<'a, const N: usize, T: PrimUInt, F: Field> ToBytesGadget for &'a [UInt] { fn to_bytes(&self) -> Result>, SynthesisError> { (*self).to_bytes() } diff --git a/src/uint/eq.rs b/src/uint/eq.rs index 472bcef4..1b7c386f 100644 --- a/src/uint/eq.rs +++ b/src/uint/eq.rs @@ -1,15 +1,13 @@ use ark_ff::PrimeField; use ark_relations::r1cs::SynthesisError; -use ark_std::{fmt::Debug, vec::Vec}; - -use num_traits::PrimInt; +use ark_std::vec::Vec; use crate::boolean::Boolean; use crate::eq::EqGadget; -use super::UInt; +use super::*; -impl EqGadget +impl EqGadget for UInt { #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -79,7 +77,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_eq( + fn uint_eq( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { @@ -101,7 +99,7 @@ mod tests { Ok(()) } - fn uint_neq( + fn uint_neq( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { diff --git a/src/uint/mod.rs b/src/uint/mod.rs index 03c8ac19..26598675 100644 --- a/src/uint/mod.rs +++ b/src/uint/mod.rs @@ -1,6 +1,5 @@ use ark_ff::{Field, PrimeField}; use core::{borrow::Borrow, convert::TryFrom, fmt::Debug}; -use num_traits::PrimInt; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; @@ -17,19 +16,23 @@ mod rotate; mod select; mod xor; +#[doc(hidden)] +pub mod prim_uint; +pub use prim_uint::*; + #[cfg(test)] pub(crate) mod test_utils; /// This struct represent an unsigned `N` bit integer as a sequence of `N` [`Boolean`]s. #[derive(Clone, Debug)] -pub struct UInt { +pub struct UInt { #[doc(hidden)] pub bits: [Boolean; N], #[doc(hidden)] pub value: Option, } -impl R1CSVar for UInt { +impl R1CSVar for UInt { type Value = T; fn cs(&self) -> ConstraintSystemRef { @@ -46,7 +49,12 @@ impl R1CSVar for UInt } } -impl UInt { +impl UInt { + pub const MAX: Self = Self { + bits: [Boolean::TRUE; N], + value: Some(T::MAX), + }; + /// Construct a constant [`UInt`] from the native unsigned integer type. /// /// This *does not* create new variables or constraints. @@ -121,7 +129,7 @@ impl UInt { } } -impl AllocVar +impl AllocVar for UInt { fn new_variable>( @@ -148,78 +156,3 @@ impl AllocVar Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let v = (0..$size) -// .map(|_| Boolean::constant(rng.gen())) -// .collect::>>(); - -// let b = UInt::from_bits_le(&v); - -// for (i, bit) in b.bits.iter().enumerate() { -// match bit { -// &Boolean::Constant(bit) => { -// assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); -// }, -// _ => unreachable!(), -// } -// } - -// let expected_to_be_same = b.to_bits_le(); - -// for x in v.iter().zip(expected_to_be_same.iter()) { -// match x { -// (&Boolean::TRUE, &Boolean::TRUE) => {}, -// (&Boolean::FALSE, &Boolean::FALSE) => {}, -// _ => unreachable!(), -// } -// } -// } -// Ok(()) -// } - -// #[test] -// fn test_add_many_constants() -> Result<(), SynthesisError> { -// let mut rng = ark_std::test_rng(); - -// for _ in 0..1000 { -// let cs = ConstraintSystem::::new_ref(); - -// let a: $native = rng.gen(); -// let b: $native = rng.gen(); -// let c: $native = rng.gen(); - -// let a_bit = $name::new_constant(cs.clone(), a)?; -// let b_bit = $name::new_constant(cs.clone(), b)?; -// let c_bit = $name::new_constant(cs.clone(), c)?; - -// let mut expected = a.wrapping_add(b).wrapping_add(c); - -// let r = $name::add_many(&[a_bit, b_bit, c_bit]).unwrap(); - -// assert!(r.value == Some(expected)); - -// for b in r.bits.iter() { -// match b { -// Boolean::Is(_) => unreachable!(), -// Boolean::Not(_) => unreachable!(), -// Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), -// } - -// expected >>= 1; -// } -// } -// Ok(()) -// } diff --git a/src/uint/not.rs b/src/uint/not.rs index db74cdb6..1bb883d5 100644 --- a/src/uint/not.rs +++ b/src/uint/not.rs @@ -1,11 +1,10 @@ use ark_ff::Field; use ark_relations::r1cs::SynthesisError; -use ark_std::{fmt::Debug, ops::Not}; -use num_traits::PrimInt; +use ark_std::ops::Not; -use super::UInt; +use super::*; -impl UInt { +impl UInt { fn _not(&self) -> Result { let mut result = self.clone(); for a in &mut result.bits { @@ -16,7 +15,7 @@ impl UInt { } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> Not for &'a UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> Not for &'a UInt { type Output = UInt; /// Outputs `!self`. /// @@ -44,7 +43,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> Not for &'a UInt } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> Not for UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> Not for UInt { type Output = UInt; /// Outputs `!self`. @@ -85,7 +84,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_not( + fn uint_not( a: UInt, ) -> Result<(), SynthesisError> { let cs = a.cs(); diff --git a/src/uint/or.rs b/src/uint/or.rs index 823617ec..c69fc8db 100644 --- a/src/uint/or.rs +++ b/src/uint/or.rs @@ -1,11 +1,10 @@ use ark_ff::PrimeField; use ark_relations::r1cs::SynthesisError; -use ark_std::{fmt::Debug, ops::BitOr, ops::BitOrAssign}; -use num_traits::PrimInt; +use ark_std::{ops::BitOr, ops::BitOrAssign}; -use super::UInt; +use super::{PrimUInt, UInt}; -impl UInt { +impl UInt { fn _or(&self, other: &Self) -> Result { let mut result = self.clone(); for (a, b) in result.bits.iter_mut().zip(&other.bits) { @@ -16,7 +15,7 @@ impl UInt { } } -impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr for &'a UInt { +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr for &'a UInt { type Output = UInt; /// Output `self | other`. @@ -47,7 +46,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr for &'a } } -impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr<&'a Self> for UInt { +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a Self> for UInt { type Output = UInt; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -56,9 +55,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr<&'a Self> for } } -impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr> - for &'a UInt -{ +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr> for &'a UInt { type Output = UInt; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -67,7 +64,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOr> } } -impl BitOr for UInt { +impl BitOr for UInt { type Output = Self; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -76,7 +73,7 @@ impl BitOr for UInt BitOrAssign for UInt { +impl BitOrAssign for UInt { /// Sets `self = self | other`. /// /// If at least one of `self` and `other` are constants, then this method @@ -107,9 +104,7 @@ impl BitOrAssign for UI } } -impl<'a, const N: usize, T: PrimInt + Debug, F: PrimeField> BitOrAssign<&'a Self> - for UInt -{ +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a Self> for UInt { #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitor_assign(&mut self, other: &'a Self) { let result = self._or(other).unwrap(); @@ -129,7 +124,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_or( + fn uint_or( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { diff --git a/src/uint/prim_uint.rs b/src/uint/prim_uint.rs new file mode 100644 index 00000000..43aaa55e --- /dev/null +++ b/src/uint/prim_uint.rs @@ -0,0 +1,136 @@ +#[doc(hidden)] +// Adapted from +pub trait PrimUInt: + core::fmt::Debug + + num_traits::PrimInt + + num_traits::WrappingAdd + + num_traits::SaturatingAdd + + ark_std::UniformRand +{ + type Bytes: NumBytes; + const MAX: Self; + #[doc(hidden)] + const MAX_VALUE_BIT_DECOMP: &'static [bool]; + + /// Return the memory representation of this number as a byte array in little-endian byte order. + /// + /// # Examples + /// + /// ``` + /// use ark_r1cs_std::uint::PrimUInt; + /// + /// let bytes = ToBytes::to_le_bytes(&0x12345678u32); + /// assert_eq!(bytes, [0x78, 0x56, 0x34, 0x12]); + /// ``` + fn to_le_bytes(&self) -> Self::Bytes; + + /// Return the memory representation of this number as a byte array in big-endian byte order. + /// + /// # Examples + /// + /// ``` + /// use ark_r1cs_std::uint::PrimUInt; + /// + /// let bytes = ToBytes::to_be_bytes(&0x12345678u32); + /// assert_eq!(bytes, [0x12, 0x34, 0x56, 0x78]); + /// ``` + fn to_be_bytes(&self) -> Self::Bytes; +} + +impl PrimUInt for u8 { + const MAX: Self = u8::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 8]; + type Bytes = [u8; 1]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u8::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u8::to_be_bytes(*self) + } +} + +impl PrimUInt for u16 { + const MAX: Self = u16::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 16]; + type Bytes = [u8; 2]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u16::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u16::to_be_bytes(*self) + } +} + +impl PrimUInt for u32 { + const MAX: Self = u32::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 32]; + type Bytes = [u8; 4]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u32::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u32::to_be_bytes(*self) + } +} + +impl PrimUInt for u64 { + const MAX: Self = u64::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 64]; + type Bytes = [u8; 8]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u64::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u64::to_be_bytes(*self) + } +} + +impl PrimUInt for u128 { + const MAX: Self = u128::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 128]; + type Bytes = [u8; 16]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u128::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u128::to_be_bytes(*self) + } +} + +#[doc(hidden)] +pub trait NumBytes: + core::fmt::Debug + + AsRef<[u8]> + + AsMut<[u8]> + + PartialEq + + Eq + + PartialOrd + + Ord + + core::hash::Hash + + core::borrow::Borrow<[u8]> + + core::borrow::BorrowMut<[u8]> +{ +} + +#[doc(hidden)] +impl NumBytes for [u8; N] {} diff --git a/src/uint/rotate.rs b/src/uint/rotate.rs index 394c7b62..f2d50d5d 100644 --- a/src/uint/rotate.rs +++ b/src/uint/rotate.rs @@ -1,6 +1,6 @@ use super::*; -impl UInt { +impl UInt { /// Rotates `self` to the right by `by` steps, wrapping around. /// /// # Examples @@ -72,7 +72,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_rotate_left( + fn uint_rotate_left( a: UInt, ) -> Result<(), SynthesisError> { let cs = a.cs(); @@ -97,7 +97,7 @@ mod tests { Ok(()) } - fn uint_rotate_right( + fn uint_rotate_right( a: UInt, ) -> Result<(), SynthesisError> { let cs = a.cs(); diff --git a/src/uint/select.rs b/src/uint/select.rs index d891ab34..6c061fe2 100644 --- a/src/uint/select.rs +++ b/src/uint/select.rs @@ -1,7 +1,7 @@ use super::*; use crate::select::CondSelectGadget; -impl CondSelectGadget +impl CondSelectGadget for UInt { #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] @@ -42,7 +42,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_select( + fn uint_select( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { diff --git a/src/uint/test_utils.rs b/src/uint/test_utils.rs index 6686dc19..13716833 100644 --- a/src/uint/test_utils.rs +++ b/src/uint/test_utils.rs @@ -1,13 +1,11 @@ +use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; +use std::ops::RangeInclusive; + use crate::test_utils::{self, modes}; use super::*; -use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; -use ark_std::UniformRand; -use num_traits::PrimInt; - -use std::ops::RangeInclusive; -pub(crate) fn test_unary_op( +pub(crate) fn test_unary_op( a: T, mode: AllocationMode, test: impl FnOnce(UInt) -> Result<(), SynthesisError>, @@ -17,7 +15,7 @@ pub(crate) fn test_unary_op( test(a) } -pub(crate) fn test_binary_op( +pub(crate) fn test_binary_op( a: T, b: T, mode_a: AllocationMode, @@ -34,7 +32,7 @@ pub(crate) fn run_binary_random( test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where - T: PrimInt + Debug + UniformRand, + T: PrimUInt, F: PrimeField, { let mut rng = ark_std::test_rng(); @@ -55,7 +53,7 @@ pub(crate) fn run_binary_exhaustive( test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where - T: PrimInt + Debug + UniformRand, + T: PrimUInt, F: PrimeField, RangeInclusive: Iterator, { @@ -71,7 +69,7 @@ pub(crate) fn run_unary_random( test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where - T: PrimInt + Debug + UniformRand, + T: PrimUInt, F: PrimeField, { let mut rng = ark_std::test_rng(); @@ -89,7 +87,7 @@ pub(crate) fn run_unary_exhaustive( test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> where - T: PrimInt + Debug + UniformRand, + T: PrimUInt, F: PrimeField, RangeInclusive: Iterator, { diff --git a/src/uint/xor.rs b/src/uint/xor.rs index 49e26a08..22f52398 100644 --- a/src/uint/xor.rs +++ b/src/uint/xor.rs @@ -1,11 +1,10 @@ use ark_ff::Field; use ark_relations::r1cs::SynthesisError; -use ark_std::{fmt::Debug, ops::BitXor, ops::BitXorAssign}; -use num_traits::PrimInt; +use ark_std::{ops::BitXor, ops::BitXorAssign}; -use super::UInt; +use super::*; -impl UInt { +impl UInt { fn _xor(&self, other: &Self) -> Result { let mut result = self.clone(); for (a, b) in result.bits.iter_mut().zip(&other.bits) { @@ -16,7 +15,7 @@ impl UInt { } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor for &'a UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor for &'a UInt { type Output = UInt; /// Outputs `self ^ other`. /// @@ -46,7 +45,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor for &'a UInt } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor<&'a Self> for UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a Self> for UInt { type Output = UInt; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -55,7 +54,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor<&'a Self> for UInt } } -impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor> for &'a UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor> for &'a UInt { type Output = UInt; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -64,7 +63,7 @@ impl<'a, const N: usize, T: PrimInt + Debug, F: Field> BitXor> for } } -impl BitXor for UInt { +impl BitXor for UInt { type Output = Self; #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -73,7 +72,7 @@ impl BitXor for UInt BitXorAssign for UInt { +impl BitXorAssign for UInt { /// Sets `self = self ^ other`. /// /// If at least one of `self` and `other` are constants, then this method @@ -104,7 +103,7 @@ impl BitXorAssign for UInt BitXorAssign<&'a Self> for UInt { +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a Self> for UInt { #[tracing::instrument(target = "r1cs", skip(self, other))] fn bitxor_assign(&mut self, other: &'a Self) { let result = self._xor(other).unwrap(); @@ -124,7 +123,7 @@ mod tests { use ark_ff::PrimeField; use ark_test_curves::bls12_381::Fr; - fn uint_xor( + fn uint_xor( a: UInt, b: UInt, ) -> Result<(), SynthesisError> { diff --git a/src/uint8.rs b/src/uint8.rs index 0797cdbc..2d2fdfdc 100644 --- a/src/uint8.rs +++ b/src/uint8.rs @@ -129,7 +129,6 @@ mod test { let byte_vals = (64u8..128u8).collect::>(); let bytes = UInt8::new_input_vec(ark_relations::ns!(cs, "alloc value"), &byte_vals).unwrap(); - dbg!(bytes.value())?; for (native, variable) in byte_vals.into_iter().zip(bytes) { let bits = variable.to_bits_le()?; for (i, bit) in bits.iter().enumerate() { @@ -193,17 +192,12 @@ mod test { let a_bit = UInt8::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a)).unwrap(); let b_bit = UInt8::constant(b); let c_bit = UInt8::new_witness(ark_relations::ns!(cs, "c_bit"), || Ok(c)).unwrap(); - dbg!(a_bit.value()?); - dbg!(b_bit.value()?); - dbg!(c_bit.value()?); let mut r = a_bit ^ b_bit; r ^= &c_bit; assert!(cs.is_satisfied().unwrap()); - dbg!(expected); - dbg!(r.value()?); assert_eq!(r.value, Some(expected)); for b in r.bits.iter() { From 4cee8fd938bc4d42f5b660a0b1cf8f43dcfc920f Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 15:25:07 -0500 Subject: [PATCH 22/29] Fix --- src/boolean/mod.rs | 4 +++- src/groups/curves/short_weierstrass/mod.rs | 4 ++-- src/groups/curves/twisted_edwards/mod.rs | 2 +- src/groups/mod.rs | 2 +- tests/to_constraint_field_test.rs | 3 +-- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/boolean/mod.rs b/src/boolean/mod.rs index 401c799d..0193f58b 100644 --- a/src/boolean/mod.rs +++ b/src/boolean/mod.rs @@ -204,7 +204,9 @@ mod test { use super::Boolean; use crate::convert::ToBytesGadget; use crate::prelude::*; - use ark_ff::{BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand}; + use ark_ff::{ + AdditiveGroup, BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand, + }; use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; use ark_test_curves::bls12_381::Fr; diff --git a/src/groups/curves/short_weierstrass/mod.rs b/src/groups/curves/short_weierstrass/mod.rs index 23e0802d..82743058 100644 --- a/src/groups/curves/short_weierstrass/mod.rs +++ b/src/groups/curves/short_weierstrass/mod.rs @@ -9,7 +9,7 @@ use non_zero_affine::NonZeroAffineVar; use crate::{ convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, - fields::{fp::FpVar, emulated_fp::EmulatedFpVar}, + fields::{emulated_fp::EmulatedFpVar, fp::FpVar}, prelude::*, Vec, }; @@ -979,10 +979,10 @@ where mod test_sw_curve { use crate::{ alloc::AllocVar, + convert::ToBitsGadget, eq::EqGadget, fields::{emulated_fp::EmulatedFpVar, fp::FpVar}, groups::{curves::short_weierstrass::ProjectiveVar, CurveVar}, - ToBitsGadget, }; use ark_ec::{ short_weierstrass::{Projective, SWCurveConfig}, diff --git a/src/groups/curves/twisted_edwards/mod.rs b/src/groups/curves/twisted_edwards/mod.rs index 82413c3f..f88431d9 100644 --- a/src/groups/curves/twisted_edwards/mod.rs +++ b/src/groups/curves/twisted_edwards/mod.rs @@ -10,8 +10,8 @@ use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use crate::{ convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, - prelude::*, fields::emulated_fp::EmulatedFpVar, + prelude::*, Vec, }; diff --git a/src/groups/mod.rs b/src/groups/mod.rs index e3c53344..08edbd57 100644 --- a/src/groups/mod.rs +++ b/src/groups/mod.rs @@ -1,6 +1,6 @@ use crate::{ convert::{ToBitsGadget, ToBytesGadget}, - fields::emulated_fp::EmulatedFpVar, + fields::emulated_fp::EmulatedFpVar, prelude::*, }; use ark_ff::PrimeField; diff --git a/tests/to_constraint_field_test.rs b/tests/to_constraint_field_test.rs index 17829195..75c79c34 100644 --- a/tests/to_constraint_field_test.rs +++ b/tests/to_constraint_field_test.rs @@ -1,6 +1,5 @@ use ark_r1cs_std::{ - alloc::AllocVar, convert::ToConstraintFieldGadget, fields::emulated_fp::EmulatedFpVar, - R1CSVar, + alloc::AllocVar, convert::ToConstraintFieldGadget, fields::emulated_fp::EmulatedFpVar, R1CSVar, }; use ark_relations::r1cs::ConstraintSystem; From fc7dab5fae0051c4213e060140687ea1d2cfb4e6 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 15:52:07 -0500 Subject: [PATCH 23/29] `no-std` fix --- src/convert.rs | 1 + src/uint8.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/convert.rs b/src/convert.rs index 0f1e3d31..e04abc48 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -1,5 +1,6 @@ use ark_ff::Field; use ark_relations::r1cs::SynthesisError; +use ark_std::vec::Vec; use crate::{boolean::Boolean, uint8::UInt8}; diff --git a/src/uint8.rs b/src/uint8.rs index 2d2fdfdc..1f952de4 100644 --- a/src/uint8.rs +++ b/src/uint8.rs @@ -75,7 +75,7 @@ impl UInt8 { /// `ConstraintF::MODULUS_BIT_SIZE - 1` chunks and converts each chunk, which is /// assumed to be little-endian, to its `FpVar` representation. /// This is the gadget counterpart to the `[u8]` implementation of -/// [ToConstraintField](ark_ff::ToConstraintField). +/// [`ToConstraintField``]. impl ToConstraintFieldGadget for [UInt8] { #[tracing::instrument(target = "r1cs")] fn to_constraint_field(&self) -> Result>, SynthesisError> { From 45544589edc44f6863e929dd2ef39891e050a964 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 15:56:11 -0500 Subject: [PATCH 24/29] doc fixes --- src/uint/prim_uint.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/uint/prim_uint.rs b/src/uint/prim_uint.rs index 43aaa55e..963db0f7 100644 --- a/src/uint/prim_uint.rs +++ b/src/uint/prim_uint.rs @@ -19,7 +19,7 @@ pub trait PrimUInt: /// ``` /// use ark_r1cs_std::uint::PrimUInt; /// - /// let bytes = ToBytes::to_le_bytes(&0x12345678u32); + /// let bytes = PrimUInt::to_le_bytes(&0x12345678u32); /// assert_eq!(bytes, [0x78, 0x56, 0x34, 0x12]); /// ``` fn to_le_bytes(&self) -> Self::Bytes; @@ -31,7 +31,7 @@ pub trait PrimUInt: /// ``` /// use ark_r1cs_std::uint::PrimUInt; /// - /// let bytes = ToBytes::to_be_bytes(&0x12345678u32); + /// let bytes = PrimUInt::to_be_bytes(&0x12345678u32); /// assert_eq!(bytes, [0x12, 0x34, 0x56, 0x78]); /// ``` fn to_be_bytes(&self) -> Self::Bytes; From e8f8b3b6addf8d131dc1e911ff64e6e8ac14d7bd Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 16:23:49 -0500 Subject: [PATCH 25/29] Temp tweak for CI --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b4f5893a..1d86ac3a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -155,6 +155,7 @@ jobs: uses: actions/checkout@v4 with: repository: arkworks-rs/algebra + ref: curve-constraint-test-fix - name: Checkout r1cs-std uses: actions/checkout@v4 From fd169969f7aa57e02022056d4308ff1de400000f Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 17:23:31 -0500 Subject: [PATCH 26/29] Add `Shl` and `Shr` --- src/uint/mod.rs | 4 +- src/uint/prim_uint.rs | 40 +++++++++++ src/uint/shl.rs | 160 +++++++++++++++++++++++++++++++++++++++++ src/uint/shr.rs | 160 +++++++++++++++++++++++++++++++++++++++++ src/uint/test_utils.rs | 48 +++++++++++++ 5 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 src/uint/shl.rs create mode 100644 src/uint/shr.rs diff --git a/src/uint/mod.rs b/src/uint/mod.rs index 26598675..11c5efac 100644 --- a/src/uint/mod.rs +++ b/src/uint/mod.rs @@ -12,6 +12,8 @@ mod convert; mod eq; mod not; mod or; +mod shl; +mod shr; mod rotate; mod select; mod xor; @@ -81,7 +83,7 @@ impl UInt { for i in 0..N { bits[i] = Boolean::constant((bit_values & T::one()) == T::one()); - bit_values = bit_values >> 1; + bit_values = bit_values >> 1u8; } Self { diff --git a/src/uint/prim_uint.rs b/src/uint/prim_uint.rs index 963db0f7..7f8c21ca 100644 --- a/src/uint/prim_uint.rs +++ b/src/uint/prim_uint.rs @@ -1,3 +1,6 @@ +use core::usize; +use core::ops::{Shl, ShlAssign, Shr, ShrAssign}; + #[doc(hidden)] // Adapted from pub trait PrimUInt: @@ -5,6 +8,32 @@ pub trait PrimUInt: + num_traits::PrimInt + num_traits::WrappingAdd + num_traits::SaturatingAdd + + Shl + + Shl + + Shl + + Shl + + Shl + + Shl + + Shr + + Shr + + Shr + + Shr + + Shr + + Shr + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + Into + + _private::Sealed + ark_std::UniformRand { type Bytes: NumBytes; @@ -134,3 +163,14 @@ pub trait NumBytes: #[doc(hidden)] impl NumBytes for [u8; N] {} + + +mod _private { + pub trait Sealed {} + + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for u128 {} +} \ No newline at end of file diff --git a/src/uint/shl.rs b/src/uint/shl.rs new file mode 100644 index 00000000..2aaf474c --- /dev/null +++ b/src/uint/shl.rs @@ -0,0 +1,160 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::Shl, ops::ShlAssign}; + +use crate::boolean::Boolean; + +use super::{PrimUInt, UInt}; + +impl UInt { + fn _shl_u128(&self, other: u128) -> Result { + if other < N as u128 { + let mut bits = [Boolean::FALSE; N]; + for (a, b) in bits[other as usize..].iter_mut().zip(&self.bits) { + *a = b.clone(); + } + + let value = self.value.and_then(|a| Some(a << other)); + Ok(Self { + bits, + value, + }) + } else { + panic!("attempt to shift left with overflow") + } + } +} + +impl Shl for UInt { + type Output = Self; + + /// Output `self << other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; + /// + /// (a << 1).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl(self, other: T2) -> Self::Output { + self._shl_u128(other.into()).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shl for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl(self, other: T2) -> Self::Output { + self._shl_u128(other.into()).unwrap() + } +} + +impl ShlAssign for UInt { + /// Sets `self = self << other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; + /// + /// a <<= b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl_assign(&mut self, other: T2) { + let result = self._shl_u128(other.into()).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_shl( + a: UInt, + b: T, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let b = b.into() % (N as u128); + let computed = &a << b; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()? << b), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_shl() { + run_binary_exhaustive_with_native(uint_shl::).unwrap() + } + + #[test] + fn u16_shl() { + run_binary_random_with_native::<1000, 16, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u32_shl() { + run_binary_random_with_native::<1000, 32, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u64_shl() { + run_binary_random_with_native::<1000, 64, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u128_shl() { + run_binary_random_with_native::<1000, 128, _, _>(uint_shl::).unwrap() + } +} diff --git a/src/uint/shr.rs b/src/uint/shr.rs new file mode 100644 index 00000000..dbcdc8d7 --- /dev/null +++ b/src/uint/shr.rs @@ -0,0 +1,160 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::Shr, ops::ShrAssign}; + +use crate::boolean::Boolean; + +use super::{PrimUInt, UInt}; + +impl UInt { + fn _shr_u128(&self, other: u128) -> Result { + if other < N as u128 { + let mut bits = [Boolean::FALSE; N]; + for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) { + *a = b.clone(); + } + + let value = self.value.and_then(|a| Some(a >> other)); + Ok(Self { + bits, + value, + }) + } else { + panic!("attempt to shift right with overflow") + } + } +} + +impl Shr for UInt { + type Output = Self; + + /// Output `self >> other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; + /// + /// (a >> 1).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr(self, other: T2) -> Self::Output { + self._shr_u128(other.into()).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr(self, other: T2) -> Self::Output { + self._shr_u128(other.into()).unwrap() + } +} + +impl ShrAssign for UInt { + /// Sets `self = self >> other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; + /// + /// a >> b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr_assign(&mut self, other: T2) { + let result = self._shr_u128(other.into()).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_shr( + a: UInt, + b: T, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let b = b.into() % (N as u128); + let computed = &a >> b; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()? >> b), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_shr() { + run_binary_exhaustive_with_native(uint_shr::).unwrap() + } + + #[test] + fn u16_shr() { + run_binary_random_with_native::<1000, 16, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u32_shr() { + run_binary_random_with_native::<1000, 32, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u64_shr() { + run_binary_random_with_native::<1000, 64, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u128_shr() { + run_binary_random_with_native::<1000, 128, _, _>(uint_shr::).unwrap() + } +} diff --git a/src/uint/test_utils.rs b/src/uint/test_utils.rs index 13716833..5b3503e8 100644 --- a/src/uint/test_utils.rs +++ b/src/uint/test_utils.rs @@ -28,6 +28,17 @@ pub(crate) fn test_binary_op( test(a, b) } +pub(crate) fn test_binary_op_with_native( + a: T, + b: T, + mode_a: AllocationMode, + test: impl FnOnce(UInt, T) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = UInt::::new_variable(cs.clone(), || Ok(a), mode_a)?; + test(a, b) +} + pub(crate) fn run_binary_random( test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> @@ -65,6 +76,43 @@ where Ok(()) } +pub(crate) fn run_binary_random_with_native( + test: impl Fn(UInt, T) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, +{ + let mut rng = ark_std::test_rng(); + + for _ in 0..ITERATIONS { + for mode_a in modes() { + let a = T::rand(&mut rng); + let b = T::rand(&mut rng); + test_binary_op_with_native(a, b, mode_a, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_binary_exhaustive_with_native( + test: impl Fn(UInt, T) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, + RangeInclusive: Iterator, +{ + for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) { + for b in T::min_value()..=T::max_value() { + test_binary_op_with_native(a, b, mode_a, test)?; + } + } + Ok(()) +} + + + pub(crate) fn run_unary_random( test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> From 276979d51b97b6c056c02a15842fda5446347e9a Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 17:25:08 -0500 Subject: [PATCH 27/29] Format --- src/uint/mod.rs | 4 ++-- src/uint/prim_uint.rs | 5 ++--- src/uint/shl.rs | 18 ++++++------------ src/uint/shr.rs | 20 +++++++------------- src/uint/test_utils.rs | 2 -- 5 files changed, 17 insertions(+), 32 deletions(-) diff --git a/src/uint/mod.rs b/src/uint/mod.rs index 11c5efac..1544f24a 100644 --- a/src/uint/mod.rs +++ b/src/uint/mod.rs @@ -12,10 +12,10 @@ mod convert; mod eq; mod not; mod or; -mod shl; -mod shr; mod rotate; mod select; +mod shl; +mod shr; mod xor; #[doc(hidden)] diff --git a/src/uint/prim_uint.rs b/src/uint/prim_uint.rs index 7f8c21ca..3b3fff81 100644 --- a/src/uint/prim_uint.rs +++ b/src/uint/prim_uint.rs @@ -1,5 +1,5 @@ -use core::usize; use core::ops::{Shl, ShlAssign, Shr, ShrAssign}; +use core::usize; #[doc(hidden)] // Adapted from @@ -164,7 +164,6 @@ pub trait NumBytes: #[doc(hidden)] impl NumBytes for [u8; N] {} - mod _private { pub trait Sealed {} @@ -173,4 +172,4 @@ mod _private { impl Sealed for u32 {} impl Sealed for u64 {} impl Sealed for u128 {} -} \ No newline at end of file +} diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 2aaf474c..3ee5d90d 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -13,12 +13,9 @@ impl UInt { for (a, b) in bits[other as usize..].iter_mut().zip(&self.bits) { *a = b.clone(); } - + let value = self.value.and_then(|a| Some(a << other)); - Ok(Self { - bits, - value, - }) + Ok(Self { bits, value }) } else { panic!("attempt to shift left with overflow") } @@ -42,7 +39,7 @@ impl Shl for UInt< /// /// let cs = ConstraintSystem::::new_ref(); /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = 1; + /// let b = 1u8; /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; /// /// (a << 1).enforce_equal(&c)?; @@ -80,7 +77,7 @@ impl ShlAssign for /// /// let cs = ConstraintSystem::::new_ref(); /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = 1; + /// let b = 1u8; /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; /// /// a <<= b; @@ -120,11 +117,8 @@ mod tests { } else { AllocationMode::Witness }; - let expected = UInt::::new_variable( - cs.clone(), - || Ok(a.value()? << b), - expected_mode, - )?; + let expected = + UInt::::new_variable(cs.clone(), || Ok(a.value()? << b), expected_mode)?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&computed)?; if !a.is_constant() { diff --git a/src/uint/shr.rs b/src/uint/shr.rs index dbcdc8d7..4e9ebe35 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -13,12 +13,9 @@ impl UInt { for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) { *a = b.clone(); } - + let value = self.value.and_then(|a| Some(a >> other)); - Ok(Self { - bits, - value, - }) + Ok(Self { bits, value }) } else { panic!("attempt to shift right with overflow") } @@ -42,7 +39,7 @@ impl Shr for UInt< /// /// let cs = ConstraintSystem::::new_ref(); /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = 1; + /// let b = 1u8; /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; /// /// (a >> 1).enforce_equal(&c)?; @@ -80,10 +77,10 @@ impl ShrAssign for /// /// let cs = ConstraintSystem::::new_ref(); /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = 1; + /// let b = 1u8; /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; /// - /// a >> b; + /// a >>= b; /// a.enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) @@ -120,11 +117,8 @@ mod tests { } else { AllocationMode::Witness }; - let expected = UInt::::new_variable( - cs.clone(), - || Ok(a.value()? >> b), - expected_mode, - )?; + let expected = + UInt::::new_variable(cs.clone(), || Ok(a.value()? >> b), expected_mode)?; assert_eq!(expected.value(), computed.value()); expected.enforce_equal(&computed)?; if !a.is_constant() { diff --git a/src/uint/test_utils.rs b/src/uint/test_utils.rs index 5b3503e8..0600bbeb 100644 --- a/src/uint/test_utils.rs +++ b/src/uint/test_utils.rs @@ -111,8 +111,6 @@ where Ok(()) } - - pub(crate) fn run_unary_random( test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> From b04936b0d2a664a233264dafb0bce2f09ec2f218 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 17:58:58 -0500 Subject: [PATCH 28/29] Fix --- src/uint/shl.rs | 2 +- src/uint/shr.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 3ee5d90d..645a07a4 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -42,7 +42,7 @@ impl Shl for UInt< /// let b = 1u8; /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; /// - /// (a << 1).enforce_equal(&c)?; + /// (a << b).enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 4e9ebe35..7630855c 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -42,7 +42,7 @@ impl Shr for UInt< /// let b = 1u8; /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; /// - /// (a >> 1).enforce_equal(&c)?; + /// (a >> b).enforce_equal(&c)?; /// assert!(cs.is_satisfied().unwrap()); /// # Ok(()) /// # } From 061fc3001420cede834872431732a2cea08bad36 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 28 Dec 2023 18:07:09 -0500 Subject: [PATCH 29/29] Revert CI change --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1d86ac3a..b4f5893a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -155,7 +155,6 @@ jobs: uses: actions/checkout@v4 with: repository: arkworks-rs/algebra - ref: curve-constraint-test-fix - name: Checkout r1cs-std uses: actions/checkout@v4