diff --git a/co-circom/co-groth16/Cargo.toml b/co-circom/co-groth16/Cargo.toml index 3b53a1858..06858e915 100644 --- a/co-circom/co-groth16/Cargo.toml +++ b/co-circom/co-groth16/Cargo.toml @@ -25,15 +25,17 @@ ark-groth16 = { version = "=0.4.0", default-features = false, features = [ ], optional = true } ark-poly = { workspace = true } ark-relations = { workspace = true } +ark-serialize = { workspace = true } circom-types = { version = "0.5.0", path = "../circom-types" } co-circom-snarks = { version = "0.1.2", path = "../co-circom-snarks" } eyre = { workspace = true } -itertools = { workspace = true } mpc-core = { version = "0.5.0", path = "../../mpc-core" } mpc-net = { version = "0.1.2", path = "../../mpc-net" } num-traits = { workspace = true } rand = { workspace = true } +rayon = { workspace = true } tracing = { workspace = true } +serde_json = { workspace = true } [dev-dependencies] serde_json = { workspace = true } diff --git a/co-circom/co-groth16/src/groth16.rs b/co-circom/co-groth16/src/groth16.rs index 028d57959..0299dc1b1 100644 --- a/co-circom/co-groth16/src/groth16.rs +++ b/co-circom/co-groth16/src/groth16.rs @@ -1,24 +1,37 @@ //! A Groth16 proof protocol that uses a collaborative MPC protocol to generate the proof. use ark_ec::pairing::Pairing; +use ark_ec::scalar_mul::variable_base::VariableBaseMSM; use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::{FftField, PrimeField}; use ark_poly::{EvaluationDomain, GeneralEvaluationDomain}; -use ark_relations::r1cs::{ConstraintMatrices, SynthesisError}; +use ark_relations::r1cs::{ConstraintMatrices, Matrix, SynthesisError}; use circom_types::groth16::{Groth16Proof, ZKey}; use circom_types::traits::{CircomArkworksPairingBridge, CircomArkworksPrimeFieldBridge}; use co_circom_snarks::SharedWitness; use eyre::Result; -use itertools::izip; -use mpc_core::protocols::plain::PlainDriver; -use mpc_core::traits::{EcMpcProtocol, MSMProvider}; -use mpc_core::{ - protocols::rep3::{network::Rep3MpcNet, Rep3Protocol}, - traits::{FFTProvider, PairingEcMpcProtocol, PrimeFieldMpcProtocol}, -}; +use mpc_core::protocols::rep3::network::{IoContext, Rep3MpcNet}; +use mpc_core::protocols::shamir::network::ShamirMpcNet; +use mpc_core::protocols::shamir::{ShamirPreprocessing, ShamirProtocol}; use mpc_net::config::NetworkConfig; use num_traits::identities::One; use num_traits::ToPrimitive; +use rayon::prelude::*; use std::marker::PhantomData; +use std::sync::Arc; +use std::time::Instant; +use tracing::instrument; + +use crate::mpc::plain::PlainGroth16Driver; +use crate::mpc::rep3::Rep3Groth16Driver; +use crate::mpc::shamir::ShamirGroth16Driver; +use crate::mpc::CircomGroth16Prover; + +macro_rules! rayon_join { + ($t1: expr, $t2: expr, $t3: expr) => {{ + let ((x, y), z) = rayon::join(|| rayon::join(|| $t1, || $t2), || $t3); + (x, y, z) + }}; +} /// The plain [`Groth16`] type. /// @@ -30,17 +43,12 @@ use std::marker::PhantomData; /// /// More interesting is the [`Groth16::verify`] method. You can verify any circom Groth16 proof, be it /// from snarkjs or one created by this project. Under the hood we use the arkwork Groth16 project for verifying. -pub type Groth16

= CoGroth16::ScalarField>, P>; +pub type Groth16

= CoGroth16; /// A type alias for a [CoGroth16] protocol using replicated secret sharing. -pub type Rep3CoGroth16

= CoGroth16::ScalarField, Rep3MpcNet>, P>; - -type FieldShare = ::ScalarField>>::FieldShare; -type FieldShareVec = ::ScalarField>>::FieldShareVec; -type PointShare = >::PointShare; -type CurveFieldShareVec = ::Affine as AffineRepr>::ScalarField, ->>::FieldShareVec; +pub type Rep3CoGroth16 = CoGroth16>; +/// A type alias for a [CoGroth16] protocol using shamir secret sharing. +pub type ShamirCoGroth16 = CoGroth16::ScalarField, N>>; /* old way of computing root of unity, does not work for bls12_381: let root_of_unity = { @@ -54,6 +62,7 @@ calculate smallest quadratic non residue q (by checking q^((p-1)/2)=-1 mod p) al use g=q^t (this is a 2^s-th root of unity) as (some kind of) generator and compute another domain by repeatedly squaring g, should get to 1 in the s+1-th step. then if log2(domain_size) equals s we take as root of unity q^2, and else we take the log2(domain_size) + 1-th element of the domain created above */ +#[instrument(level = "debug", name = "root of unity", skip_all)] fn root_of_unity_for_groth16( pow: usize, domain: &mut GeneralEvaluationDomain, @@ -77,26 +86,13 @@ fn root_of_unity_for_groth16( } /// A Groth16 proof protocol that uses a collaborative MPC protocol to generate the proof. -pub struct CoGroth16 -where - T: PrimeFieldMpcProtocol - + PairingEcMpcProtocol

- + FFTProvider - + MSMProvider - + MSMProvider, -{ +pub struct CoGroth16> { pub(crate) driver: T, phantom_data: PhantomData

, } -impl CoGroth16 +impl> CoGroth16 where - T: PrimeFieldMpcProtocol - + PairingEcMpcProtocol

- + FFTProvider - + MSMProvider - + MSMProvider, - P: CircomArkworksPairingBridge, P::BaseField: CircomArkworksPrimeFieldBridge, P::ScalarField: CircomArkworksPrimeFieldBridge, { @@ -110,34 +106,54 @@ where /// Execute the Groth16 prover using the internal MPC driver. /// This version takes the Circom-generated constraint matrices as input and does not re-calculate them. + #[instrument(level = "debug", name = "Groth16 - Proof", skip_all)] pub fn prove( - &mut self, + mut self, zkey: &ZKey

, - private_witness: SharedWitness, + private_witness: SharedWitness, ) -> Result> { + let id = self.driver.get_party_id(); + tracing::info!("Party {}: starting proof generation..", id); + let start = Instant::now(); let matrices = &zkey.matrices; let num_inputs = matrices.num_instance_variables; let num_constraints = matrices.num_constraints; - let public_inputs = &private_witness.public_inputs; - let private_witness = &private_witness.witness; - tracing::debug!("calling witness map from matrices..."); + let public_inputs = Arc::new(private_witness.public_inputs); + let private_witness = Arc::new(private_witness.witness); let h = self.witness_map_from_matrices( zkey.pow, matrices, num_constraints, num_inputs, - public_inputs, - private_witness, + &public_inputs, + &private_witness, )?; - tracing::debug!("done!"); - tracing::debug!("getting r and s..."); - let r = self.driver.rand()?; - let s = self.driver.rand()?; - tracing::debug!("done!"); - tracing::debug!("calling create_proof_with_assignment..."); - self.create_proof_with_assignment(zkey, r, s, &h, &public_inputs[1..], private_witness) + let (r, s) = (self.driver.rand()?, self.driver.rand()?); + let proof = + self.create_proof_with_assignment(zkey, r, s, h, public_inputs, private_witness)?; + + let duration_ms = start.elapsed().as_micros() as f64 / 1000.; + tracing::info!("Party {}: Proof generation took {} ms", id, duration_ms); + Ok(proof) } + fn evaluate_constraint( + party_id: T::PartyID, + domain_size: usize, + matrix: &Matrix, + public_inputs: &[P::ScalarField], + private_witness: &[T::ArithmeticShare], + ) -> Vec { + let mut result = matrix + .par_iter() + .with_min_len(32) + .map(|x| T::evaluate_constraint(party_id, x, public_inputs, private_witness)) + .collect::>(); + result.resize(domain_size, T::ArithmeticShare::default()); + result + } + + #[instrument(level = "debug", name = "witness map from matrices", skip_all)] fn witness_map_from_matrices( &mut self, power: usize, @@ -145,176 +161,305 @@ where num_constraints: usize, num_inputs: usize, public_inputs: &[P::ScalarField], - private_witness: &FieldShareVec, - ) -> Result> { + private_witness: &[T::ArithmeticShare], + ) -> Result> { let mut domain = GeneralEvaluationDomain::::new(num_constraints + num_inputs) .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; - let root_of_unity = root_of_unity_for_groth16(power, &mut domain); - let domain_size = domain.size(); - let mut a = vec![FieldShare::::default(); domain_size]; - let mut b = vec![FieldShare::::default(); domain_size]; - tracing::debug!("evaluating constraints.."); - for (a, b, at_i, bt_i) in izip!(&mut a, &mut b, &matrices.a, &matrices.b) { - *a = self - .driver - .evaluate_constraint(at_i, public_inputs, private_witness); - *b = self - .driver - .evaluate_constraint(bt_i, public_inputs, private_witness); - } - tracing::debug!("done!"); - let mut a = FieldShareVec::::from(a); - let promoted_public = self.driver.promote_to_trivial_shares(public_inputs); - self.driver - .clone_from_slice(&mut a, &promoted_public, num_constraints, 0, num_inputs); - - let mut b = FieldShareVec::::from(b); - let mut c = self.driver.mul_vec(&a, &b)?; - self.driver.ifft_in_place(&mut a, &domain); - self.driver.ifft_in_place(&mut b, &domain); - self.driver.distribute_powers_and_mul_by_const( - &mut a, - root_of_unity, - P::ScalarField::one(), + let party_id = self.driver.get_party_id(); + let eval_constraint_span = + tracing::debug_span!("evaluate constraints + root of unity computation").entered(); + let (roots_to_power_domain, a, b) = rayon_join!( + { + let root_of_unity_span = + tracing::debug_span!("root of unity computation").entered(); + let root_of_unity = root_of_unity_for_groth16(power, &mut domain); + let mut roots = Vec::with_capacity(domain_size); + let mut c = P::ScalarField::one(); + for _ in 0..domain_size { + roots.push(c); + c *= root_of_unity; + } + root_of_unity_span.exit(); + Arc::new(roots) + }, + { + let eval_constraint_span_a = + tracing::debug_span!("evaluate constraints - a").entered(); + let mut result = Self::evaluate_constraint( + party_id, + domain_size, + &matrices.a, + public_inputs, + private_witness, + ); + let promoted_public = T::promote_to_trivial_shares(party_id, public_inputs); + result[num_constraints..num_constraints + num_inputs] + .clone_from_slice(&promoted_public[..num_inputs]); + eval_constraint_span_a.exit(); + result + }, + { + let eval_constraint_span_b = + tracing::debug_span!("evaluate constraints - a").entered(); + let result = Self::evaluate_constraint( + party_id, + domain_size, + &matrices.b, + public_inputs, + private_witness, + ); + eval_constraint_span_b.exit(); + result + } ); - self.driver.distribute_powers_and_mul_by_const( - &mut b, - root_of_unity, - P::ScalarField::one(), - ); - self.driver.fft_in_place(&mut a, &domain); - self.driver.fft_in_place(&mut b, &domain); - //this can be in-place so that we do not have to allocate memory - let mut ab = self.driver.mul_vec(&a, &b)?; - std::mem::drop(a); - std::mem::drop(b); - - self.driver.ifft_in_place(&mut c, &domain); - self.driver.distribute_powers_and_mul_by_const( - &mut c, - root_of_unity, - P::ScalarField::one(), - ); - self.driver.fft_in_place(&mut c, &domain); - self.driver.sub_assign_vec(&mut ab, &c); + let domain = Arc::new(domain); + eval_constraint_span.exit(); + + let (a_tx, a_rx) = std::sync::mpsc::channel(); + let (b_tx, b_rx) = std::sync::mpsc::channel(); + let (c_tx, c_rx) = std::sync::mpsc::channel(); + let a_domain = Arc::clone(&domain); + let b_domain = Arc::clone(&domain); + let c_domain = Arc::clone(&domain); + let mut a_result = a.clone(); + let mut b_result = b.clone(); + let a_roots = Arc::clone(&roots_to_power_domain); + let b_roots = Arc::clone(&roots_to_power_domain); + let c_roots = Arc::clone(&roots_to_power_domain); + rayon::spawn(move || { + let a_span = tracing::debug_span!("distribute powers mul a (fft/ifft)").entered(); + a_domain.ifft_in_place(&mut a_result); + T::distribute_powers_and_mul_by_const(&mut a_result, &a_roots); + a_domain.fft_in_place(&mut a_result); + a_tx.send(a_result).expect("channel not droped"); + a_span.exit(); + }); + + rayon::spawn(move || { + let b_span = tracing::debug_span!("distribute powers mul b (fft/ifft)").entered(); + b_domain.ifft_in_place(&mut b_result); + T::distribute_powers_and_mul_by_const(&mut b_result, &b_roots); + b_domain.fft_in_place(&mut b_result); + b_tx.send(b_result).expect("channel not droped"); + b_span.exit(); + }); + + let local_mul_vec_span = tracing::debug_span!("c: local_mul_vec").entered(); + let mut ab = self.driver.local_mul_vec(a, b); + local_mul_vec_span.exit(); + rayon::spawn(move || { + let ifft_span = tracing::debug_span!("c: ifft in dist pows").entered(); + c_domain.ifft_in_place(&mut ab); + ifft_span.exit(); + let dist_pows_span = tracing::debug_span!("c: dist pows").entered(); + #[allow(unused_mut)] + ab.par_iter_mut() + .zip_eq(c_roots.par_iter()) + .with_min_len(512) + .for_each(|(share, pow)| { + *share *= pow; + }); + dist_pows_span.exit(); + let fft_span = tracing::debug_span!("c: fft in dist pows").entered(); + c_domain.fft_in_place(&mut ab); + fft_span.exit(); + c_tx.send(ab).expect("channel not dropped"); + }); + + let a = a_rx.recv()?; + let b = b_rx.recv()?; + + let compute_ab_span = tracing::debug_span!("compute ab").entered(); + let local_ab_span = tracing::debug_span!("local part (mul and sub)").entered(); + // same as above. No IO task is run at the moment. + let mut ab = self.driver.local_mul_vec(a, b); + let c = c_rx.recv()?; + ab.par_iter_mut() + .zip_eq(c.par_iter()) + .with_min_len(512) + .for_each(|(a, b)| { + *a -= b; + }); + local_ab_span.exit(); + compute_ab_span.exit(); Ok(ab) } - fn calculate_coeff( - &mut self, - initial: PointShare, - query: &[C::Affine], - vk_param: C::Affine, - input_assignment: &[C::ScalarField], - aux_assignment: &CurveFieldShareVec, - ) -> PointShare - where - T: EcMpcProtocol, - T: MSMProvider, - { - tracing::debug!("calculate coeffs.."); + fn calculate_coeff_g1( + id: T::PartyID, + initial: T::PointShareG1, + query: &[P::G1Affine], + vk_param: P::G1Affine, + input_assignment: &[P::ScalarField], + aux_assignment: &[T::ArithmeticShare], + ) -> T::PointShareG1 { let pub_len = input_assignment.len(); - let pub_acc = C::msm_unchecked(&query[1..=pub_len], input_assignment); - let priv_acc = MSMProvider::::msm_public_points( - &mut self.driver, - &query[1 + pub_len..], - aux_assignment, + + // we block this thread of the runtime here. + // It should not matter too much, as we have multithreaded + // runtime. + let (pub_acc, priv_acc) = rayon::join( + || P::G1::msm_unchecked(&query[1..=pub_len], input_assignment), + || T::msm_public_points_g1(&query[1 + pub_len..], aux_assignment), ); let mut res = initial; - EcMpcProtocol::::add_assign_points_public_affine(&mut self.driver, &mut res, &query[0]); - EcMpcProtocol::::add_assign_points_public_affine(&mut self.driver, &mut res, &vk_param); - EcMpcProtocol::::add_assign_points_public(&mut self.driver, &mut res, &pub_acc); - EcMpcProtocol::::add_assign_points(&mut self.driver, &mut res, &priv_acc); + T::add_assign_points_public_g1(id, &mut res, &query[0].into_group()); + T::add_assign_points_public_g1(id, &mut res, &vk_param.into_group()); + T::add_assign_points_public_g1(id, &mut res, &pub_acc); + T::add_assign_points_g1(&mut res, &priv_acc); + res + } + + fn calculate_coeff_g2( + id: T::PartyID, + initial: T::PointShareG2, + query: &[P::G2Affine], + vk_param: P::G2Affine, + input_assignment: &[P::ScalarField], + aux_assignment: &[T::ArithmeticShare], + ) -> T::PointShareG2 { + let pub_len = input_assignment.len(); + + // we block this thread of the runtime here. + // It should not matter too much, as we have multithreaded + // runtime. + let (pub_acc, priv_acc) = rayon::join( + || P::G2::msm_unchecked(&query[1..=pub_len], input_assignment), + || T::msm_public_points_g2(&query[1 + pub_len..], aux_assignment), + ); - tracing::debug!("done.."); + let mut res = initial; + T::add_assign_points_public_g2(id, &mut res, &query[0].into_group()); + T::add_assign_points_public_g2(id, &mut res, &vk_param.into_group()); + T::add_assign_points_public_g2(id, &mut res, &pub_acc); + T::add_assign_points_g2(&mut res, &priv_acc); res } + #[instrument(level = "debug", name = "create proof with assignment", skip_all)] fn create_proof_with_assignment( - &mut self, + mut self, zkey: &ZKey

, - r: FieldShare, - s: FieldShare, - h: &FieldShareVec, - input_assignment: &[P::ScalarField], - aux_assignment: &FieldShareVec, + r: T::ArithmeticShare, + s: T::ArithmeticShare, + h: Vec, + input_assignment: Arc>, + aux_assignment: Arc>, ) -> Result> { - tracing::debug!("create proof with assignment..."); - //let c_acc_time = start_timer!(|| "Compute C"); - let h_acc = MSMProvider::::msm_public_points(&mut self.driver, &zkey.h_query, h); - - // Compute C - let l_aux_acc = MSMProvider::::msm_public_points( - &mut self.driver, - &zkey.l_query, - aux_assignment, - ); - let delta_g1 = zkey.delta_g1.into_group(); - let rs = self.driver.mul(&r, &s)?; - let r_s_delta_g1 = self.driver.scalar_mul_public_point(&delta_g1, &rs); + let (l_acc_tx, l_acc_rx) = std::sync::mpsc::channel(); + let h_query = Arc::clone(&zkey.h_query); + let l_query = Arc::clone(&zkey.l_query); + + let party_id = self.driver.get_party_id(); + let (r_g1_tx, r_g1_rx) = std::sync::mpsc::channel(); + let (s_g1_tx, s_g1_rx) = std::sync::mpsc::channel(); + let (s_g2_tx, s_g2_rx) = std::sync::mpsc::channel(); + let a_query = Arc::clone(&zkey.a_query); + let b_g1_query = Arc::clone(&zkey.b_g1_query); + let b_g2_query = Arc::clone(&zkey.b_g2_query); + let input_assignment1 = Arc::clone(&input_assignment); + let input_assignment2 = Arc::clone(&input_assignment); + let input_assignment3 = Arc::clone(&input_assignment); + let aux_assignment1 = Arc::clone(&aux_assignment); + let aux_assignment2 = Arc::clone(&aux_assignment); + let aux_assignment3 = Arc::clone(&aux_assignment); + let aux_assignment4 = Arc::clone(&aux_assignment); + let alpha_g1 = zkey.vk.alpha_g1; + let beta_g1 = zkey.beta_g1; + let beta_g2 = zkey.vk.beta_g2; + let delta_g2 = zkey.vk.delta_g2.into_group(); + + rayon::spawn(move || { + let compute_a = + tracing::debug_span!("compute A in create proof with assignment").entered(); + // Compute A + let r_g1 = T::scalar_mul_public_point_g1(&delta_g1, r); + let r_g1 = Self::calculate_coeff_g1( + party_id, + r_g1, + &a_query, + alpha_g1, + &input_assignment1[1..], + &aux_assignment1, + ); + r_g1_tx.send(r_g1).expect("not dropped"); + compute_a.exit(); + }); + + rayon::spawn(move || { + let compute_b = + tracing::debug_span!("compute B/G1 in create proof with assignment").entered(); + // Compute B in G1 + // In original implementation this is skipped if r==0, however r is shared in our case + let s_g1 = T::scalar_mul_public_point_g1(&delta_g1, s); + let s_g1 = Self::calculate_coeff_g1( + party_id, + s_g1, + &b_g1_query, + beta_g1, + &input_assignment2[1..], + &aux_assignment2, + ); + s_g1_tx.send(s_g1).expect("not dropped"); + compute_b.exit(); + }); + + rayon::spawn(move || { + let compute_b = + tracing::debug_span!("compute B/G2 in create proof with assignment").entered(); + // Compute B in G2 + let s_g2 = T::scalar_mul_public_point_g2(&delta_g2, s); + let s_g2 = Self::calculate_coeff_g2( + party_id, + s_g2, + &b_g2_query, + beta_g2, + &input_assignment3[1..], + &aux_assignment3, + ); + s_g2_tx.send(s_g2).expect("not dropped"); + compute_b.exit(); + }); - //end_timer!(c_acc_time); + rayon::spawn(move || { + let msm_l_query = tracing::debug_span!("msm l_query").entered(); + let result = T::msm_public_points_g1(l_query.as_ref(), &aux_assignment4); + l_acc_tx.send(result).expect("channel not dropped"); + msm_l_query.exit(); + }); - // Compute A - // let a_acc_time = start_timer!(|| "Compute A"); - let r_g1 = self.driver.scalar_mul_public_point(&delta_g1, &r); + // TODO we can remove the networking round in msm_and_mul because we only have linear operations afterwards + let (h_acc, rs) = self.driver.msm_and_mul(h, h_query, r, s)?; + let r_s_delta_g1 = T::scalar_mul_public_point_g1(&delta_g1, rs); - let g_a = self.calculate_coeff::( - r_g1, - &zkey.a_query, - zkey.vk.alpha_g1, - input_assignment, - aux_assignment, - ); + let l_aux_acc = l_acc_rx.recv().expect("channel not dropped"); - // Open here since g_a is part of proof - let g_a_opened = EcMpcProtocol::::open_point(&mut self.driver, &g_a)?; - let s_g_a = self.driver.scalar_mul_public_point(&g_a_opened, &s); - // end_timer!(a_acc_time); - - // Compute B in G1 - // In original implementation this is skipped if r==0, however r is shared in our case - // let b_g1_acc_time = start_timer!(|| "Compute B in G1"); - let s_g1 = self.driver.scalar_mul_public_point(&delta_g1, &s); - let g1_b = self.calculate_coeff::( - s_g1, - &zkey.b_g1_query, - zkey.beta_g1, - input_assignment, - aux_assignment, - ); - let r_g1_b = EcMpcProtocol::::scalar_mul(&mut self.driver, &g1_b, &r)?; - // end_timer!(b_g1_acc_time); + let calculate_coeff_span = tracing::debug_span!("calculate coeff").entered(); + let g_a = r_g1_rx.recv()?; + let g1_b = s_g1_rx.recv()?; + calculate_coeff_span.exit(); - // Compute B in G2 - let delta_g2 = zkey.vk.delta_g2.into_group(); - // let b_g2_acc_time = start_timer!(|| "Compute B in G2"); - let s_g2 = self.driver.scalar_mul_public_point(&delta_g2, &s); - let g2_b = self.calculate_coeff::( - s_g2, - &zkey.b_g2_query, - zkey.vk.beta_g2, - input_assignment, - aux_assignment, - ); - // end_timer!(b_g2_acc_time); + let network_round = tracing::debug_span!("network round after calc coeff").entered(); + let (g_a_opened, r_g1_b) = self.driver.open_point_and_scalar_mul(&g_a, &g1_b, r)?; + network_round.exit(); + + let last_round = tracing::debug_span!("finish open two points and some adds").entered(); + let s_g_a = T::scalar_mul_public_point_g1(&g_a_opened, s); - // let c_time = start_timer!(|| "Finish C"); let mut g_c = s_g_a; - EcMpcProtocol::::add_assign_points(&mut self.driver, &mut g_c, &r_g1_b); - EcMpcProtocol::::sub_assign_points(&mut self.driver, &mut g_c, &r_s_delta_g1); - EcMpcProtocol::::add_assign_points(&mut self.driver, &mut g_c, &l_aux_acc); - EcMpcProtocol::::add_assign_points(&mut self.driver, &mut g_c, &h_acc); - // end_timer!(c_time); + T::add_assign_points_g1(&mut g_c, &r_g1_b); + T::sub_assign_points_g1(&mut g_c, &r_s_delta_g1); + T::add_assign_points_g1(&mut g_c, &l_aux_acc); + T::add_assign_points_g1(&mut g_c, &h_acc); - tracing::debug!("almost done..."); - let (g_c_opened, g2_b_opened) = - PairingEcMpcProtocol::

::open_two_points(&mut self.driver, &g_c, &g2_b)?; + let g2_b = s_g2_rx.recv()?; + let (g_c_opened, g2_b_opened) = self.driver.open_two_points(g_c, g2_b)?; + last_round.exit(); Ok(Groth16Proof { pi_a: g_a_opened.into_affine(), @@ -326,7 +471,7 @@ where } } -impl Rep3CoGroth16

+impl Rep3CoGroth16 where P: CircomArkworksPairingBridge, P::BaseField: CircomArkworksPrimeFieldBridge, @@ -335,8 +480,42 @@ where /// Create a new [Rep3CoGroth16] protocol with a given network configuration. pub fn with_network_config(config: NetworkConfig) -> Result { let mpc_net = Rep3MpcNet::new(config)?; - let driver = Rep3Protocol::::new(mpc_net)?; - Ok(CoGroth16::new(driver)) + let mut io_context0 = IoContext::init(mpc_net)?; + let io_context1 = io_context0.fork()?; + let driver = Rep3Groth16Driver::new(io_context0, io_context1); + Ok(CoGroth16 { + driver, + phantom_data: PhantomData, + }) + } +} + +impl ShamirCoGroth16 +where + P: CircomArkworksPairingBridge, + P::BaseField: CircomArkworksPrimeFieldBridge, + P::ScalarField: CircomArkworksPrimeFieldBridge, +{ + /// Create a new [ShamirCoGroth16] protocol with a given network configuration. + pub fn with_network_config( + threshold: usize, + config: NetworkConfig, + zkey: &ZKey

, + ) -> Result { + let domain_size = 2usize.pow(u32::try_from(zkey.pow).expect("pow fits into u32")); + // we need domain_size + 2 + 1 number of corr rand pairs in witness_map_from_matrices (degree_reduce_vec + r and s + 1 for fork) + let num_pairs = domain_size + 2 + 1; + let mpc_net = ShamirMpcNet::new(config)?; + let preprocessing = ShamirPreprocessing::new(threshold, mpc_net, num_pairs)?; + let mut protocol0 = ShamirProtocol::from(preprocessing); + // the protocol1 is only used for scalar_mul and a field_mul which need 1 pair each (ergo 2 + // pairs) + let protocol1 = protocol0.fork_with_pairs(2)?; + let driver = ShamirGroth16Driver::new(protocol0, protocol1); + Ok(CoGroth16 { + driver, + phantom_data: PhantomData, + }) } } @@ -352,69 +531,12 @@ where /// DOES NOT PERFORM ANY MPC. For a plain prover checkout the [Groth16 implementation of arkworks](https://docs.rs/ark-groth16/latest/ark_groth16/). pub fn plain_prove( zkey: &ZKey

, - private_witness: SharedWitness, P>, + private_witness: SharedWitness, ) -> Result> { - let mut prover = Self { - driver: PlainDriver::default(), + let prover = Self { + driver: PlainGroth16Driver, phantom_data: PhantomData, }; prover.prove(zkey, private_witness) } } - -#[cfg(test)] -mod test { - use std::fs::File; - - use ark_bn254::Bn254; - use circom_types::R1CS; - use mpc_core::protocols::{ - rep3::{network::Rep3MpcNet, Rep3Protocol}, - shamir::{network::ShamirMpcNet, ShamirProtocol}, - }; - use rand::thread_rng; - - use super::SharedWitness; - use circom_types::Witness; - - #[ignore] - #[test] - fn test_rep3() { - let witness_file = File::open("../../test_vectors/bn254/multiplier2/witness.wtns").unwrap(); - let witness = Witness::::from_reader(witness_file).unwrap(); - let r1cs_file = - File::open("../../test_vectors/bn254/multiplier2/multiplier2.r1cs").unwrap(); - let r1cs = R1CS::::from_reader(r1cs_file).unwrap(); - let mut rng = thread_rng(); - let [s1, _, _] = - SharedWitness::, Bn254>::share_rep3( - witness, - r1cs.num_inputs, - &mut rng, - ); - println!("{}", serde_json::to_string(&s1).unwrap()); - } - - fn test_shamir_inner(num_parties: usize, threshold: usize) { - let witness_file = File::open("../../test_vectors/bn254/multiplier2/witness.wtns").unwrap(); - let witness = Witness::::from_reader(witness_file).unwrap(); - let r1cs_file = - File::open("../../test_vectors/bn254/multiplier2/multiplier2.r1cs").unwrap(); - let r1cs = R1CS::::from_reader(r1cs_file).unwrap(); - let mut rng = thread_rng(); - let s1 = SharedWitness::, Bn254>::share_shamir( - witness, - r1cs.num_inputs, - threshold, - num_parties, - &mut rng, - ); - println!("{}", serde_json::to_string(&s1[0]).unwrap()); - } - - #[ignore] - #[test] - fn test_shamir() { - test_shamir_inner(3, 1); - } -} diff --git a/co-circom/co-groth16/src/lib.rs b/co-circom/co-groth16/src/lib.rs index 3a9ab3cf4..d7aa3d8fe 100644 --- a/co-circom/co-groth16/src/lib.rs +++ b/co-circom/co-groth16/src/lib.rs @@ -1,12 +1,15 @@ //! A library for creating and verifying Groth16 proofs in a collaborative fashion using MPC. #![warn(missing_docs)] mod groth16; +/// This module contains the Groth16 prover trait +pub mod mpc; #[cfg(feature = "verifier")] mod verifier; pub use groth16::CoGroth16; pub use groth16::Groth16; pub use groth16::Rep3CoGroth16; +pub use groth16::ShamirCoGroth16; #[cfg(test)] #[cfg(feature = "verifier")] @@ -18,7 +21,6 @@ mod tests { Witness, }; use co_circom_snarks::SharedWitness; - use mpc_core::protocols::plain::PlainDriver; use std::fs::{self, File}; use crate::groth16::Groth16; @@ -33,19 +35,15 @@ mod tests { File::open("../../test_vectors/Groth16/bn254/multiplier2/verification_key.json") .unwrap(); - let driver = PlainDriver::::default(); let witness = Witness::::from_reader(witness_file).unwrap(); let zkey = ZKey::::from_reader(zkey_file).unwrap(); let vk: JsonVerificationKey = serde_json::from_reader(vk_file).unwrap(); let public_input = witness.values[..=zkey.n_public].to_vec(); - let witness = SharedWitness::, Bn254> { + let witness = SharedWitness { public_inputs: public_input.clone(), witness: witness.values[zkey.n_public + 1..].to_vec(), }; - let mut groth16 = Groth16::::new(driver); - let proof = groth16 - .prove(&zkey, witness) - .expect("proof generation works"); + let proof = Groth16::::plain_prove(&zkey, witness).expect("proof generation works"); let ser_proof = serde_json::to_string(&proof).unwrap(); let der_proof = serde_json::from_str::>(&ser_proof).unwrap(); let verified = Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); @@ -78,7 +76,6 @@ mod tests { File::open("../../test_vectors/Groth16/bn254/poseidon/circuit.zkey").unwrap(); let witness_file = File::open("../../test_vectors/Groth16/bn254/poseidon/witness.wtns").unwrap(); - let driver = PlainDriver::::default(); let vk_file = File::open("../../test_vectors/Groth16/bn254/poseidon/verification_key.json").unwrap(); @@ -86,14 +83,11 @@ mod tests { let zkey = ZKey::::from_reader(zkey_file).unwrap(); let vk: JsonVerificationKey = serde_json::from_reader(vk_file).unwrap(); let public_input = witness.values[..=zkey.n_public].to_vec(); - let witness = SharedWitness::, Bn254> { + let witness = SharedWitness { public_inputs: public_input.clone(), witness: witness.values[zkey.n_public + 1..].to_vec(), }; - let mut groth16 = Groth16::::new(driver); - let proof = groth16 - .prove(&zkey, witness) - .expect("proof generation works"); + let proof = Groth16::::plain_prove(&zkey, witness).expect("proof generation works"); let ser_proof = serde_json::to_string(&proof).unwrap(); let der_proof = serde_json::from_str::>(&ser_proof).unwrap(); let verified = Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); @@ -152,16 +146,13 @@ mod tests { let zkey = ZKey::::from_reader(zkey_file).unwrap(); let vk: JsonVerificationKey = serde_json::from_reader(vk_file).unwrap(); let public_input = witness.values[..=zkey.n_public].to_vec(); - let witness = SharedWitness::, Bls12_381> { + let witness = SharedWitness { public_inputs: public_input.clone(), witness: witness.values[zkey.n_public + 1..].to_vec(), }; - let driver = PlainDriver::::default(); - let mut groth16 = Groth16::::new(driver); - let proof = groth16 - .prove(&zkey, witness) - .expect("proof generation works"); + let proof = + Groth16::::plain_prove(&zkey, witness).expect("proof generation works"); let verified = Groth16::::verify(&vk, &proof, &public_input[1..]).expect("can verify"); assert!(verified); @@ -185,16 +176,12 @@ mod tests { let zkey = ZKey::::from_reader(zkey_file).unwrap(); let vk: JsonVerificationKey = serde_json::from_reader(vk_file).unwrap(); let public_input = witness.values[..=zkey.n_public].to_vec(); - let witness = SharedWitness::, Bn254> { + let witness = SharedWitness { public_inputs: public_input.clone(), witness: witness.values[zkey.n_public + 1..].to_vec(), }; - let driver = PlainDriver::::default(); - let mut groth16 = Groth16::::new(driver); - let proof = groth16 - .prove(&zkey, witness) - .expect("proof generation works"); + let proof = Groth16::::plain_prove(&zkey, witness).expect("proof generation works"); let verified = Groth16::::verify(&vk, &proof, &public_input[1..]).expect("can verify"); assert!(verified); diff --git a/co-circom/co-groth16/src/mpc.rs b/co-circom/co-groth16/src/mpc.rs new file mode 100644 index 000000000..465068f84 --- /dev/null +++ b/co-circom/co-groth16/src/mpc.rs @@ -0,0 +1,140 @@ +use core::fmt; +use std::{fmt::Debug, sync::Arc}; + +use ark_ec::pairing::Pairing; +use ark_poly::domain::DomainCoeff; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; + +pub(crate) mod plain; +pub(crate) mod rep3; +pub(crate) mod shamir; + +pub use plain::PlainGroth16Driver; +pub use rep3::Rep3Groth16Driver; +pub use shamir::ShamirGroth16Driver; + +type IoResult = std::io::Result; + +/// This trait represents the operations used during Groth16 proof generation +pub trait CircomGroth16Prover: Send + Sized { + /// The arithemitc share type + type ArithmeticShare: CanonicalSerialize + + CanonicalDeserialize + + Copy + + Clone + + Default + + Send + + Debug + + DomainCoeff + + 'static; + /// The G1 point share type + type PointShareG1: Debug + Send + 'static; + /// The G2 point share type + type PointShareG2: Debug + Send + 'static; + /// The party id type + type PartyID: Send + Sync + Copy + fmt::Display + 'static; + + /// Generate a random arithemitc share + fn rand(&mut self) -> IoResult; + + /// Get the party id + fn get_party_id(&self) -> Self::PartyID; + + /// Each value of lhs consists of a coefficient c and an index i. This function computes the sum of the coefficients times the corresponding public input or private witness. In other words, an accumulator a is initialized to 0, and for each (c, i) in lhs, a += c * public_inputs\[i\] is computed if i corresponds to a public input, or c * private_witness[i - public_inputs.len()] if i corresponds to a private witness. + fn evaluate_constraint( + party_id: Self::PartyID, + lhs: &[(P::ScalarField, usize)], + public_inputs: &[P::ScalarField], + private_witness: &[Self::ArithmeticShare], + ) -> Self::ArithmeticShare; + + /// Elementwise transformation of a vector of public values into a vector of shared values: \[a_i\] = a_i. + fn promote_to_trivial_shares( + id: Self::PartyID, + public_values: &[P::ScalarField], + ) -> Vec; + + /// Performs element-wise multiplication of two vectors of shared values. + /// Does not perform any networking. + /// + /// # Security + /// You must *NOT* perform additional non-linear operations on the result of this function. + fn local_mul_vec( + &mut self, + a: Vec, + b: Vec, + ) -> Vec; + + /// Compute the msm of `h` and `h_query` and multiplication `r` * `s`. + fn msm_and_mul( + &mut self, + h: Vec<

::ScalarField>, + h_query: Arc>, + r: Self::ArithmeticShare, + s: Self::ArithmeticShare, + ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)>; + + /// Computes the \[coeffs_i\] *= c * g^i for the coefficients in 0 <= i < coeff.len() + fn distribute_powers_and_mul_by_const( + coeffs: &mut [Self::ArithmeticShare], + roots: &[P::ScalarField], + ); + + /// Perform msm between G1 `points` and `scalars` + fn msm_public_points_g1( + points: &[P::G1Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG1; + + /// Perform msm between G2 `points` and `scalars` + fn msm_public_points_g2( + points: &[P::G2Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG2; + + /// Multiplies a public point B to the shared point A in place: \[A\] *= B + fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1; + + /// Add a shared point B in place to the shared point A: \[A\] += \[B\] + fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1); + + /// Add a public point B in place to the shared point A + fn add_assign_points_public_g1(id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1); + + /// Reconstructs a shared point: A = Open(\[A\]). + fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult; + + /// Multiplies a share b to the shared point A: \[A\] *= \[b\]. Requires network communication. + fn scalar_mul_g1( + &mut self, + a: &Self::PointShareG1, + b: Self::ArithmeticShare, + ) -> IoResult; + + /// Subtract a shared point B in place from the shared point A: \[A\] -= \[B\] + fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1); + + /// Perform scalar multiplication of point A with a shared scalar b + fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2; + + /// Add a shared point B in place to the shared point A: \[A\] += \[B\] + fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2); + + /// Add a public point B in place to the shared point A + fn add_assign_points_public_g2(id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2); + + /// Reconstructs a shared points: A = Open(\[A\]), B = Open(\[B\]). + fn open_two_points( + &mut self, + a: Self::PointShareG1, + b: Self::PointShareG2, + ) -> std::io::Result<(P::G1, P::G2)>; + + /// Reconstruct point G_a and perform scalar multiplication of G1_b and r concurrently + fn open_point_and_scalar_mul( + &mut self, + g_a: &Self::PointShareG1, + g1_b: &Self::PointShareG1, + r: Self::ArithmeticShare, + ) -> std::io::Result<(P::G1, Self::PointShareG1)>; +} diff --git a/co-circom/co-groth16/src/mpc/plain.rs b/co-circom/co-groth16/src/mpc/plain.rs new file mode 100644 index 000000000..bc7a07530 --- /dev/null +++ b/co-circom/co-groth16/src/mpc/plain.rs @@ -0,0 +1,156 @@ +use std::sync::Arc; + +use ark_ec::pairing::Pairing; +use ark_ec::scalar_mul::variable_base::VariableBaseMSM; +use ark_ff::UniformRand; +use rand::thread_rng; + +use super::CircomGroth16Prover; + +type IoResult = std::io::Result; + +/// A plain Groth16 driver +pub struct PlainGroth16Driver; + +impl CircomGroth16Prover

for PlainGroth16Driver { + type ArithmeticShare = P::ScalarField; + + type PointShareG1 = P::G1; + + type PointShareG2 = P::G2; + + type PartyID = usize; + + fn rand(&mut self) -> IoResult { + let mut rng = thread_rng(); + Ok(Self::ArithmeticShare::rand(&mut rng)) + } + + fn get_party_id(&self) -> Self::PartyID { + //does't matter + 0 + } + + fn evaluate_constraint( + _: Self::PartyID, + lhs: &[(P::ScalarField, usize)], + public_inputs: &[P::ScalarField], + private_witness: &[Self::ArithmeticShare], + ) -> Self::ArithmeticShare { + let mut acc = P::ScalarField::default(); + for (coeff, index) in lhs { + if index < &public_inputs.len() { + acc += *coeff * public_inputs[*index]; + } else { + acc += *coeff * private_witness[*index - public_inputs.len()]; + } + } + acc + } + + fn promote_to_trivial_shares( + _: Self::PartyID, + public_values: &[P::ScalarField], + ) -> Vec { + public_values.to_vec() + } + + fn local_mul_vec( + &mut self, + a: Vec, + b: Vec, + ) -> Vec { + a.iter().zip(b.iter()).map(|(a, b)| *a * b).collect() + } + + fn msm_and_mul( + &mut self, + h: Vec<

::ScalarField>, + h_query: Arc>, + r: Self::ArithmeticShare, + s: Self::ArithmeticShare, + ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)> { + Ok((P::G1::msm_unchecked(h_query.as_ref(), &h), r * s)) + } + + fn distribute_powers_and_mul_by_const( + coeffs: &mut [Self::ArithmeticShare], + roots: &[P::ScalarField], + ) { + #[allow(unused_mut)] + for (mut c, pow) in coeffs.iter_mut().zip(roots) { + *c *= pow; + } + } + + fn msm_public_points_g1( + points: &[P::G1Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG1 { + P::G1::msm_unchecked(points, scalars) + } + + fn msm_public_points_g2( + points: &[P::G2Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG2 { + P::G2::msm_unchecked(points, scalars) + } + + fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 { + *a * b + } + + fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) { + *a += b; + } + + fn add_assign_points_public_g1(_: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) { + *a += b; + } + + fn open_point_g1(&mut self, a: &Self::PointShareG1) -> super::IoResult { + Ok(*a) + } + + fn scalar_mul_g1( + &mut self, + a: &Self::PointShareG1, + b: Self::ArithmeticShare, + ) -> super::IoResult { + Ok(*a * b) + } + + fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) { + *a -= b; + } + + fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 { + *a * b + } + + fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) { + *a += b; + } + + fn add_assign_points_public_g2(_: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) { + *a += b; + } + + fn open_two_points( + &mut self, + a: Self::PointShareG1, + b: Self::PointShareG2, + ) -> std::io::Result<(P::G1, P::G2)> { + Ok((a, b)) + } + + fn open_point_and_scalar_mul( + &mut self, + g_a: &Self::PointShareG1, + g1_b: &Self::PointShareG1, + r: Self::ArithmeticShare, + ) -> super::IoResult<(P::G1, Self::PointShareG1)> { + Ok((*g_a, *g1_b * r)) + } +} diff --git a/co-circom/co-groth16/src/mpc/rep3.rs b/co-circom/co-groth16/src/mpc/rep3.rs new file mode 100644 index 000000000..36709ad35 --- /dev/null +++ b/co-circom/co-groth16/src/mpc/rep3.rs @@ -0,0 +1,202 @@ +use std::sync::Arc; + +use ark_ec::pairing::Pairing; +use mpc_core::protocols::rep3::{ + arithmetic, + id::PartyID, + network::{IoContext, Rep3Network}, + pointshare, Rep3PointShare, Rep3PrimeFieldShare, +}; +use rayon::prelude::*; + +use super::{CircomGroth16Prover, IoResult}; + +/// A Groth16 driver for REP3 secret sharing +/// +/// Contains two [`IoContext`]s, `io_context0` for the main execution and `io_context1` for parts that can run concurrently. +pub struct Rep3Groth16Driver { + io_context0: IoContext, + io_context1: IoContext, +} + +impl Rep3Groth16Driver { + /// Create a new [`Rep3Groth16Driver`] with two [`IoContext`]s + pub fn new(io_context0: IoContext, io_context1: IoContext) -> Self { + Self { + io_context0, + io_context1, + } + } +} + +impl CircomGroth16Prover

for Rep3Groth16Driver +where + N: 'static, +{ + type ArithmeticShare = Rep3PrimeFieldShare; + type PointShareG1 = Rep3PointShare; + type PointShareG2 = Rep3PointShare; + + type PartyID = PartyID; + + fn rand(&mut self) -> IoResult { + Ok(Self::ArithmeticShare::rand(&mut self.io_context0)) + } + + fn get_party_id(&self) -> Self::PartyID { + self.io_context0.id + } + + fn evaluate_constraint( + party_id: Self::PartyID, + lhs: &[(P::ScalarField, usize)], + public_inputs: &[P::ScalarField], + private_witness: &[Self::ArithmeticShare], + ) -> Self::ArithmeticShare { + let mut acc = Self::ArithmeticShare::default(); + for (coeff, index) in lhs { + if index < &public_inputs.len() { + let val = public_inputs[*index]; + let mul_result = val * coeff; + arithmetic::add_assign_public(&mut acc, mul_result, party_id); + } else { + let current_witness = private_witness[*index - public_inputs.len()]; + arithmetic::add_assign(&mut acc, arithmetic::mul_public(current_witness, *coeff)); + } + } + acc + } + + fn promote_to_trivial_shares( + id: Self::PartyID, + public_values: &[P::ScalarField], + ) -> Vec { + public_values + .par_iter() + .with_min_len(1024) + .map(|value| Self::ArithmeticShare::promote_from_trivial(value, id)) + .collect() + } + + fn local_mul_vec( + &mut self, + a: Vec, + b: Vec, + ) -> Vec { + arithmetic::local_mul_vec(&a, &b, &mut self.io_context0.rngs) + } + + fn msm_and_mul( + &mut self, + h: Vec<

::ScalarField>, + h_query: Arc>, + r: Self::ArithmeticShare, + s: Self::ArithmeticShare, + ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)> { + std::thread::scope(|scope| { + let h_acc = scope.spawn(|| { + let msm_h_query = tracing::debug_span!("msm h_query").entered(); + let h = arithmetic::io_mul_vec(h, &mut self.io_context0)?; + let result = pointshare::msm_public_points(h_query.as_ref(), &h); + msm_h_query.exit(); + Ok::<_, std::io::Error>(result) + }); + let mul = arithmetic::mul(r, s, &mut self.io_context1)?; + Ok((h_acc.join().expect("can join")?, mul)) + }) + } + + fn distribute_powers_and_mul_by_const( + coeffs: &mut [Self::ArithmeticShare], + roots: &[P::ScalarField], + ) { + coeffs + .par_iter_mut() + .zip_eq(roots.par_iter()) + .with_min_len(512) + .for_each(|(c, pow)| { + arithmetic::mul_assign_public(c, *pow); + }) + } + + fn msm_public_points_g1( + points: &[P::G1Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG1 { + pointshare::msm_public_points(points, scalars) + } + + fn msm_public_points_g2( + points: &[P::G2Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG2 { + pointshare::msm_public_points(points, scalars) + } + + fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 { + pointshare::scalar_mul_public_point(a, b) + } + + /// Add a shared point B in place to the shared point A: \[A\] += \[B\] + fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) { + pointshare::add_assign(a, b) + } + + fn add_assign_points_public_g1(id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) { + pointshare::add_assign_public(a, b, id) + } + + fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult { + pointshare::open_point(a, &mut self.io_context0) + } + + fn scalar_mul_g1( + &mut self, + a: &Self::PointShareG1, + b: Self::ArithmeticShare, + ) -> IoResult { + pointshare::scalar_mul(a, b, &mut self.io_context0) + } + + fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) { + pointshare::sub_assign(a, b); + } + + fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 { + pointshare::scalar_mul_public_point(a, b) + } + + fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) { + pointshare::add_assign(a, b) + } + + fn add_assign_points_public_g2(id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) { + pointshare::add_assign_public(a, b, id) + } + + fn open_two_points( + &mut self, + a: Self::PointShareG1, + b: Self::PointShareG2, + ) -> std::io::Result<(P::G1, P::G2)> { + let s1 = a.b; + let s2 = b.b; + let (mut r1, mut r2) = self.io_context0.network.reshare((s1, s2))?; + r1 += a.a + a.b; + r2 += b.a + b.b; + Ok((r1, r2)) + } + + fn open_point_and_scalar_mul( + &mut self, + g_a: &Self::PointShareG1, + g1_b: &Self::PointShareG1, + r: Self::ArithmeticShare, + ) -> std::io::Result<(

::G1, Self::PointShareG1)> { + std::thread::scope(|s| { + let opened = s.spawn(|| pointshare::open_point(g_a, &mut self.io_context0)); + let mul_result = pointshare::scalar_mul(g1_b, r, &mut self.io_context1)?; + Ok((opened.join().expect("can join")?, mul_result)) + }) + } +} diff --git a/co-circom/co-groth16/src/mpc/shamir.rs b/co-circom/co-groth16/src/mpc/shamir.rs new file mode 100644 index 000000000..5e97a3102 --- /dev/null +++ b/co-circom/co-groth16/src/mpc/shamir.rs @@ -0,0 +1,201 @@ +use std::sync::Arc; + +use super::{CircomGroth16Prover, IoResult}; +use ark_ec::pairing::Pairing; +use ark_ff::PrimeField; +use mpc_core::protocols::shamir::{ + arithmetic, core, network::ShamirNetwork, pointshare, ShamirPointShare, ShamirPrimeFieldShare, + ShamirProtocol, +}; +use rayon::prelude::*; + +/// A Groth16 dirver unsing shamir secret sharing +/// +/// Contains two [`ShamirProtocol`]s, `protocol0` for the main execution and `protocol0` for parts that can run concurrently. +pub struct ShamirGroth16Driver { + protocol0: ShamirProtocol, + protocol1: ShamirProtocol, +} + +impl ShamirGroth16Driver { + /// Create a new [`ShamirGroth16Driver`] with two [`ShamirProtocol`]s + pub fn new(protocol0: ShamirProtocol, protocol1: ShamirProtocol) -> Self { + Self { + protocol0, + protocol1, + } + } +} + +impl CircomGroth16Prover

+ for ShamirGroth16Driver +{ + type ArithmeticShare = ShamirPrimeFieldShare; + type PointShareG1 = ShamirPointShare; + type PointShareG2 = ShamirPointShare; + + type PartyID = usize; + + fn rand(&mut self) -> IoResult { + self.protocol0.rand() + } + + fn get_party_id(&self) -> Self::PartyID { + self.protocol0.network.get_id() + } + + fn evaluate_constraint( + _party_id: Self::PartyID, + lhs: &[(P::ScalarField, usize)], + public_inputs: &[P::ScalarField], + private_witness: &[Self::ArithmeticShare], + ) -> Self::ArithmeticShare { + let mut acc = Self::ArithmeticShare::default(); + for (coeff, index) in lhs { + if index < &public_inputs.len() { + let val = public_inputs[*index]; + let mul_result = val * coeff; + arithmetic::add_assign_public(&mut acc, mul_result); + } else { + let current_witness = private_witness[*index - public_inputs.len()]; + arithmetic::add_assign(&mut acc, arithmetic::mul_public(current_witness, *coeff)); + } + } + acc + } + + fn promote_to_trivial_shares( + _id: Self::PartyID, + public_values: &[P::ScalarField], + ) -> Vec { + arithmetic::promote_to_trivial_shares(public_values) + } + + fn local_mul_vec( + &mut self, + a: Vec, + b: Vec, + ) -> Vec { + arithmetic::local_mul_vec(&a, &b) + } + + fn msm_and_mul( + &mut self, + h: Vec<

::ScalarField>, + h_query: Arc>, + r: Self::ArithmeticShare, + s: Self::ArithmeticShare, + ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)> { + std::thread::scope(|scope| { + let h_acc = scope.spawn(|| { + let msm_h_query = tracing::debug_span!("msm h_query").entered(); + let h = self.protocol0.degree_reduce_vec(h)?; + let result = pointshare::msm_public_points(h_query.as_ref(), &h); + msm_h_query.exit(); + Ok::<_, std::io::Error>(result) + }); + let mul = arithmetic::mul(r, s, &mut self.protocol1)?; + Ok((h_acc.join().expect("can join")?, mul)) + }) + } + + fn distribute_powers_and_mul_by_const( + coeffs: &mut [Self::ArithmeticShare], + roots: &[P::ScalarField], + ) { + coeffs + .par_iter_mut() + .zip_eq(roots.par_iter()) + .with_min_len(512) + .for_each(|(c, pow)| { + arithmetic::mul_assign_public(c, *pow); + }) + } + + fn msm_public_points_g1( + points: &[P::G1Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG1 { + pointshare::msm_public_points(points, scalars) + } + + fn msm_public_points_g2( + points: &[P::G2Affine], + scalars: &[Self::ArithmeticShare], + ) -> Self::PointShareG2 { + pointshare::msm_public_points(points, scalars) + } + + fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 { + pointshare::scalar_mul_public_point(b, a) + } + + fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) { + pointshare::add_assign(a, b) + } + + fn add_assign_points_public_g1(_id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) { + pointshare::add_assign_public(a, b) + } + + fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult { + pointshare::open_point(a, &mut self.protocol0) + } + + fn scalar_mul_g1( + &mut self, + a: &Self::PointShareG1, + b: Self::ArithmeticShare, + ) -> IoResult { + pointshare::scalar_mul(a, b, &mut self.protocol0) + } + + fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) { + pointshare::sub_assign(a, b); + } + + fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 { + pointshare::scalar_mul_public_point(b, a) + } + + fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) { + pointshare::add_assign(a, b) + } + + fn add_assign_points_public_g2(_id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) { + pointshare::add_assign_public(a, b) + } + + fn open_two_points( + &mut self, + a: Self::PointShareG1, + b: Self::PointShareG2, + ) -> std::io::Result<(P::G1, P::G2)> { + let s1 = a.a; + let s2 = b.a; + + let rcv: Vec<(P::G1, P::G2)> = self + .protocol0 + .network + .broadcast_next((s1, s2), self.protocol0.threshold + 1)?; + let (r1, r2): (Vec, Vec) = rcv.into_iter().unzip(); + + let r1 = core::reconstruct_point(&r1, &self.protocol0.open_lagrange_t); + let r2 = core::reconstruct_point(&r2, &self.protocol0.open_lagrange_t); + + Ok((r1, r2)) + } + + fn open_point_and_scalar_mul( + &mut self, + g_a: &Self::PointShareG1, + g1_b: &Self::PointShareG1, + r: Self::ArithmeticShare, + ) -> super::IoResult<(P::G1, Self::PointShareG1)> { + std::thread::scope(|s| { + let opened = s.spawn(|| pointshare::open_point(g_a, &mut self.protocol0)); + let mul_result = pointshare::scalar_mul(g1_b, r, &mut self.protocol1)?; + Ok((opened.join().expect("can join")?, mul_result)) + }) + } +}