Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ark-bls12-381 #16

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ license-file = "LICENSE"
keywords = ["zkSNARKs", "cryptography", "proofs"]

[dependencies]
bellpepper-core = "0.2.1"
ff = { version = "0.13.0", features = ["derive"] }
digest = "0.10"
sha3 = "0.10"
rayon = "1.7"
Expand All @@ -27,14 +25,23 @@ num-traits = "0.2"
num-integer = "0.1"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3"
flate2 = "1.0"
bitvec = "1.0"
byteorder = "1.4.3"
thiserror = "1.0"
halo2curves = { version = "0.4.0", features = ["derive_serde"] }
group = "0.13.0"
once_cell = "1.18.0"

# Arkworks
ark-ec = { version = "0.5.0", default-features = false }
ark-ff = { version = "0.5.0", default-features = false }
ark-std = { version = "0.5.0", default-features = false }
ark-serialize = { version = "0.5.0", default-features = false, features = ["derive"] }
ark-bls12-381 = { version = "0.5.0", default-features = false, features = ["curve"] }
ark-relations = { version = "0.5.1", default-features = false }
zeroize = { version = "1.8.1", default-features = false }
serde_with = "3.12.0"
sha2 = "0.10.7"
ark-r1cs-std = "0.5.0"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { version = "0.1.4" }

Expand All @@ -52,7 +59,11 @@ sha2 = "0.10.7"
proptest = "1.2.0"

[features]
default = []
default = [
"ark-ec/parallel",
"ark-ff/parallel",
"ark-std/parallel",
]
# Compiles in portable mode, w/o ISA extensions => binary can be executed on all systems.
portable = ["pasta-msm/portable"]
flamegraph = ["pprof/flamegraph", "pprof/criterion"]
Expand Down
10 changes: 5 additions & 5 deletions benches/bench.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
extern crate core;

use ark_bls12_381::Fr;
use ark_ff::UniformRand;
use criterion::{criterion_group, criterion_main, Criterion};
use ff::Field;
use pasta_curves::Fp;

use spartan2::spartan::polys::eq::EqPolynomial;

fn benchmarks_evaluate_incremental(c: &mut Criterion) {
let mut group = c.benchmark_group("evaluate_incremental");
(1..=20).step_by(2).for_each(|i| {
let random_point: Vec<Fp> = (0..2usize.pow(i))
.map(|_| Fp::random(&mut rand::thread_rng()))
let random_point: Vec<Fr> = (0..2usize.pow(i))
.map(|_| Fr::rand(&mut rand::thread_rng()))
.collect();
let random_polynomial = EqPolynomial::new(
(0..2usize.pow(i))
.map(|_| Fp::random(&mut rand::thread_rng()))
.map(|_| Fr::rand(&mut rand::thread_rng()))
.collect(),
);
group.bench_with_input(format!("2^{}", i), &i, |b, &_i| {
Expand Down
179 changes: 98 additions & 81 deletions examples/less_than.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,71 @@
use bellpepper_core::{
boolean::AllocatedBit, num::AllocatedNum, Circuit, ConstraintSystem, LinearCombination,
SynthesisError,
use ark_bls12_381::Fr;
use ark_ff::{BigInteger, PrimeField};
use ark_r1cs_std::alloc::AllocVar;
use ark_r1cs_std::boolean::AllocatedBool;
use ark_r1cs_std::fields::fp::AllocatedFp;
use ark_relations::r1cs::{
ConstraintSynthesizer, ConstraintSystemRef, Namespace, SynthesisError, Variable,
};
use ff::{PrimeField, PrimeFieldBits};
use pasta_curves::Fq;
use ark_relations::{lc, ns};
use num_traits::One;
use spartan2::{
errors::SpartanError,
traits::{snark::RelaxedR1CSSNARKTrait, Group},
SNARK,
};

fn num_to_bits_le_bounded<F: PrimeField + PrimeFieldBits, CS: ConstraintSystem<F>>(
cs: &mut CS,
n: AllocatedNum<F>,
fn num_to_bits_le_bounded<F: PrimeField>(
cs: ConstraintSystemRef<F>,
n: AllocatedFp<F>,
num_bits: u8,
) -> Result<Vec<AllocatedBit>, SynthesisError> {
let opt_bits = match n.get_value() {
Some(v) => v
.to_le_bits()
.into_iter()
.take(num_bits as usize)
.map(Some)
.collect::<Vec<Option<bool>>>(),
None => vec![None; num_bits as usize],
};
) -> Result<Vec<AllocatedBool<F>>, SynthesisError> {
let opt_bits = n
.value()?
.into_bigint()
.to_bits_le()
.into_iter()
.take(num_bits as usize)
.map(Some)
.collect::<Vec<Option<bool>>>();

// Add one witness per input bit in little-endian bit order
let bits_circuit = opt_bits.into_iter()
.enumerate()
// AllocateBit enforces the value to be 0 or 1 at the constraint level
.map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("b_{}", i)), b).unwrap())
.collect::<Vec<AllocatedBit>>();

let mut weighted_sum_lc = LinearCombination::zero();
// AllocatedBool enforces the value to be 0 or 1 at the constraint level
.map(|(_i, b)| {
// TODO: Why do I need namespaces here?
// TODO: Namespace can't use string ids, only const ids
// let namespaced_cs = Namespace::from(cs.clone());
AllocatedBool::<F>::new_witness(cs.clone(), || b.ok_or(SynthesisError::AssignmentMissing))
})
.collect::<Result<Vec<AllocatedBool<F>>, SynthesisError>>()?;

let mut weighted_sum_lc = lc!();
let mut pow2 = F::ONE;

for bit in bits_circuit.iter() {
weighted_sum_lc = weighted_sum_lc + (pow2, bit.get_variable());
weighted_sum_lc = weighted_sum_lc + (pow2, bit.variable());
pow2 = pow2.double();
}

cs.enforce(
|| "bit decomposition check",
|lc| lc + &weighted_sum_lc,
|lc| lc + CS::one(),
|lc| lc + n.get_variable(),
);
// weighted_sum_lc == n
let constraint_lc = weighted_sum_lc - n.variable;

// Enforce constraint_lc == 0
let one_lc = lc!() + Variable::One;
cs.enforce_constraint(constraint_lc, one_lc, lc!())?;

Ok(bits_circuit)
}

fn get_msb_index<F: PrimeField + PrimeFieldBits>(n: F) -> u8 {
n.to_le_bits()
fn get_msb_index<F: PrimeField>(n: F) -> u8 {
n.into_bigint()
.to_bits_le()
.into_iter()
.enumerate()
.rev()
.find(|(_, b)| *b)
.unwrap()
.expect("Index not found")
.0 as u8
}

Expand All @@ -69,12 +78,12 @@ fn get_msb_index<F: PrimeField + PrimeFieldBits>(n: F) -> u8 {
// a safe version.
#[derive(Clone, Debug)]
struct LessThanCircuitUnsafe<F: PrimeField> {
bound: F, // Will be a constant in the constraits, not a variable
bound: F, // Will be a constant in the constraints, not a variable
input: F, // Will be an input/output variable
num_bits: u8,
}

impl<F: PrimeField + PrimeFieldBits> LessThanCircuitUnsafe<F> {
impl<F: PrimeField> LessThanCircuitUnsafe<F> {
fn new(bound: F, input: F, num_bits: u8) -> Self {
assert!(get_msb_index(bound) < num_bits);
Self {
Expand All @@ -85,32 +94,37 @@ impl<F: PrimeField + PrimeFieldBits> LessThanCircuitUnsafe<F> {
}
}

impl<F: PrimeField + PrimeFieldBits> Circuit<F> for LessThanCircuitUnsafe<F> {
fn synthesize<CS: ConstraintSystem<F>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
assert!(F::NUM_BITS > self.num_bits as u32 + 1);
impl<F: PrimeField> ConstraintSynthesizer<F> for LessThanCircuitUnsafe<F> {
fn generate_constraints(self, cs: ConstraintSystemRef<F>) -> Result<(), SynthesisError> {
assert!(F::MODULUS_BIT_SIZE > self.num_bits as u32 + 1);

let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?;
let input_ns = ns!(cs.clone(), "input");
let input = AllocatedFp::<F>::new_witness(input_ns, || Ok(self.input))?;

let shifted_diff = AllocatedNum::alloc(cs.namespace(|| "shifted_diff"), || {
let shifted_ns = ns!(cs.clone(), "shifted_diff");
let shifted_diff = AllocatedFp::<F>::new_witness(shifted_ns, || {
Ok(self.input + F::from(1 << self.num_bits) - self.bound)
})?;

cs.enforce(
|| "shifted_diff_computation",
|lc| lc + input.get_variable() + (F::from(1 << self.num_bits) - self.bound, CS::one()),
|lc: LinearCombination<F>| lc + CS::one(),
|lc| lc + shifted_diff.get_variable(),
);
let shifted_diff_lc =
lc!() + (F::ONE, input.variable) + (F::from(1 << self.num_bits) - self.bound, Variable::One)
- (F::ONE, shifted_diff.variable);

// Enforce shifted_diff_lc == 0
cs.enforce_constraint(shifted_diff_lc, lc!() + (F::ONE, Variable::One), lc!())?;

let shifted_diff_bits =
num_to_bits_le_bounded::<F>(cs.clone(), shifted_diff, self.num_bits + 1)?;

let shifted_diff_bits = num_to_bits_le_bounded::<F, CS>(cs, shifted_diff, self.num_bits + 1)?;
// Check that the most significant bit is 0
let msb_var = shifted_diff_bits[self.num_bits as usize].variable();

// Check that the last (i.e. most sifnificant) bit is 0
cs.enforce(
|| "bound_check",
|lc| lc + shifted_diff_bits[self.num_bits as usize].get_variable(),
|lc| lc + CS::one(),
|lc| lc + (F::ZERO, CS::one()),
);
// Enforce the constraint that the most significant bit is 0
cs.enforce_constraint(
lc!() + (F::ONE, msb_var),
lc!() + (F::ONE, Variable::One),
lc!(),
)?;

Ok(())
}
Expand All @@ -122,13 +136,13 @@ impl<F: PrimeField + PrimeFieldBits> Circuit<F> for LessThanCircuitUnsafe<F> {
// Furthermore, the input must fit into `num_bits`, which is enforced at the
// constraint level.
#[derive(Clone, Debug)]
struct LessThanCircuitSafe<F: PrimeField + PrimeFieldBits> {
struct LessThanCircuitSafe<F: PrimeField> {
bound: F,
input: F,
num_bits: u8,
}

impl<F: PrimeField + PrimeFieldBits> LessThanCircuitSafe<F> {
impl<F: PrimeField> LessThanCircuitSafe<F> {
fn new(bound: F, input: F, num_bits: u8) -> Self {
assert!(get_msb_index(bound) < num_bits);
Self {
Expand All @@ -139,23 +153,26 @@ impl<F: PrimeField + PrimeFieldBits> LessThanCircuitSafe<F> {
}
}

impl<F: PrimeField + PrimeFieldBits> Circuit<F> for LessThanCircuitSafe<F> {
fn synthesize<CS: ConstraintSystem<F>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?;
impl<F: PrimeField> ConstraintSynthesizer<F> for LessThanCircuitSafe<F> {
fn generate_constraints(self, cs: ConstraintSystemRef<F>) -> Result<(), SynthesisError> {
// TODO: Do we need to use a namespace here?
let input_ns = Namespace::from(cs.clone());
let input = AllocatedFp::<F>::new_witness(input_ns, || Ok(self.input))?;

// Perform the input bit decomposition check
num_to_bits_le_bounded::<F, CS>(cs, input, self.num_bits)?;
num_to_bits_le_bounded::<F>(cs.clone(), input, self.num_bits)?;

// TODO: Not sure how/why to do this in Arkworks
// Entering a new namespace to prefix variables in the
// LessThanCircuitUnsafe, thus avoiding name clashes
cs.push_namespace(|| "less_than_safe");
// cs.push_namespace(|| "less_than_safe");

LessThanCircuitUnsafe {
bound: self.bound,
input: self.input,
num_bits: self.num_bits,
}
.synthesize(cs)
.generate_constraints(cs)
}
}

Expand All @@ -167,10 +184,10 @@ fn verify_circuit_unsafe<G: Group, S: RelaxedR1CSSNARKTrait<G>>(
let circuit = LessThanCircuitUnsafe::new(bound, input, num_bits);

// produce keys
let (pk, vk) = SNARK::<G, S, LessThanCircuitUnsafe<_>>::setup(circuit.clone()).unwrap();
let (pk, vk) = SNARK::<G, S, LessThanCircuitUnsafe<_>>::setup(circuit.clone())?;

// produce a SNARK
let snark = SNARK::prove(&pk, circuit).unwrap();
let snark = SNARK::prove(&pk, circuit)?;

// verify the SNARK
snark.verify(&vk, &[])
Expand All @@ -184,53 +201,53 @@ fn verify_circuit_safe<G: Group, S: RelaxedR1CSSNARKTrait<G>>(
let circuit = LessThanCircuitSafe::new(bound, input, num_bits);

// produce keys
let (pk, vk) = SNARK::<G, S, LessThanCircuitSafe<_>>::setup(circuit.clone()).unwrap();
let (pk, vk) = SNARK::<G, S, LessThanCircuitSafe<_>>::setup(circuit.clone())?;

// produce a SNARK
let snark = SNARK::prove(&pk, circuit).unwrap();
let snark = SNARK::prove(&pk, circuit)?;

// verify the SNARK
snark.verify(&vk, &[])
}

fn main() {
type G = pasta_curves::pallas::Point;
type G = ark_bls12_381::G1Projective;
type EE = spartan2::provider::ipa_pc::EvaluationEngine<G>;
type S = spartan2::spartan::snark::RelaxedR1CSSNARK<G, EE>;

println!("Executing unsafe circuit...");
//Typical example, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(17), Fq::from(9), 10).is_ok());
assert!(verify_circuit_unsafe::<G, S>(Fr::from(17), Fr::from(9), 10).is_ok());
// Typical example, err
assert!(verify_circuit_unsafe::<G, S>(Fq::from(17), Fq::from(20), 10).is_err());
assert!(verify_circuit_unsafe::<G, S>(Fr::from(17), Fr::from(20), 10).is_err());
// Edge case, err
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(4), 10).is_err());
assert!(verify_circuit_unsafe::<G, S>(Fr::from(4), Fr::from(4), 10).is_err());
// Edge case, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(3), 10).is_ok());
assert!(verify_circuit_unsafe::<G, S>(Fr::from(4), Fr::from(3), 10).is_ok());
// Minimum number of bits for the bound, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(3), 3).is_ok());
assert!(verify_circuit_unsafe::<G, S>(Fr::from(4), Fr::from(3), 3).is_ok());
// Insufficient number of bits for the input, but this is not detected by the
// unsafety of the circuit (compare with the last example below)
// Note that -Fq::one() is corresponds to q - 1 > bound
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), -Fq::one(), 3).is_ok());
// Note that -Fr::one() is corresponds to q - 1 > bound
assert!(verify_circuit_unsafe::<G, S>(Fr::from(4), -Fr::one(), 3).is_ok());

println!("Unsafe circuit OK");

println!("Executing safe circuit...");
// Typical example, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(17), Fq::from(9), 10).is_ok());
assert!(verify_circuit_safe::<G, S>(Fr::from(17), Fr::from(9), 10).is_ok());
// Typical example, err
assert!(verify_circuit_safe::<G, S>(Fq::from(17), Fq::from(20), 10).is_err());
assert!(verify_circuit_safe::<G, S>(Fr::from(17), Fr::from(20), 10).is_err());
// Edge case, err
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(4), 10).is_err());
assert!(verify_circuit_safe::<G, S>(Fr::from(4), Fr::from(4), 10).is_err());
// Edge case, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(3), 10).is_ok());
assert!(verify_circuit_safe::<G, S>(Fr::from(4), Fr::from(3), 10).is_ok());
// Minimum number of bits for the bound, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(3), 3).is_ok());
assert!(verify_circuit_safe::<G, S>(Fr::from(4), Fr::from(3), 3).is_ok());
// Insufficient number of bits for the input, err (compare with the last example
// above).
// Note that -Fq::one() is corresponds to q - 1 > bound
assert!(verify_circuit_safe::<G, S>(Fq::from(4), -Fq::one(), 3).is_err());
// Note that -Fr::one() is corresponds to q - 1 > bound
assert!(verify_circuit_safe::<G, S>(Fr::from(4), -Fr::one(), 3).is_err());

println!("Safe circuit OK");
}
Loading
Loading