diff --git a/co-circom/co-groth16/Cargo.toml b/co-circom/co-groth16/Cargo.toml
index 3b53a1858..06858e915 100644
--- a/co-circom/co-groth16/Cargo.toml
+++ b/co-circom/co-groth16/Cargo.toml
@@ -25,15 +25,17 @@ ark-groth16 = { version = "=0.4.0", default-features = false, features = [
], optional = true }
ark-poly = { workspace = true }
ark-relations = { workspace = true }
+ark-serialize = { workspace = true }
circom-types = { version = "0.5.0", path = "../circom-types" }
co-circom-snarks = { version = "0.1.2", path = "../co-circom-snarks" }
eyre = { workspace = true }
-itertools = { workspace = true }
mpc-core = { version = "0.5.0", path = "../../mpc-core" }
mpc-net = { version = "0.1.2", path = "../../mpc-net" }
num-traits = { workspace = true }
rand = { workspace = true }
+rayon = { workspace = true }
tracing = { workspace = true }
+serde_json = { workspace = true }
[dev-dependencies]
serde_json = { workspace = true }
diff --git a/co-circom/co-groth16/src/groth16.rs b/co-circom/co-groth16/src/groth16.rs
index 028d57959..0299dc1b1 100644
--- a/co-circom/co-groth16/src/groth16.rs
+++ b/co-circom/co-groth16/src/groth16.rs
@@ -1,24 +1,37 @@
//! A Groth16 proof protocol that uses a collaborative MPC protocol to generate the proof.
use ark_ec::pairing::Pairing;
+use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
use ark_ec::{AffineRepr, CurveGroup};
use ark_ff::{FftField, PrimeField};
use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
-use ark_relations::r1cs::{ConstraintMatrices, SynthesisError};
+use ark_relations::r1cs::{ConstraintMatrices, Matrix, SynthesisError};
use circom_types::groth16::{Groth16Proof, ZKey};
use circom_types::traits::{CircomArkworksPairingBridge, CircomArkworksPrimeFieldBridge};
use co_circom_snarks::SharedWitness;
use eyre::Result;
-use itertools::izip;
-use mpc_core::protocols::plain::PlainDriver;
-use mpc_core::traits::{EcMpcProtocol, MSMProvider};
-use mpc_core::{
- protocols::rep3::{network::Rep3MpcNet, Rep3Protocol},
- traits::{FFTProvider, PairingEcMpcProtocol, PrimeFieldMpcProtocol},
-};
+use mpc_core::protocols::rep3::network::{IoContext, Rep3MpcNet};
+use mpc_core::protocols::shamir::network::ShamirMpcNet;
+use mpc_core::protocols::shamir::{ShamirPreprocessing, ShamirProtocol};
use mpc_net::config::NetworkConfig;
use num_traits::identities::One;
use num_traits::ToPrimitive;
+use rayon::prelude::*;
use std::marker::PhantomData;
+use std::sync::Arc;
+use std::time::Instant;
+use tracing::instrument;
+
+use crate::mpc::plain::PlainGroth16Driver;
+use crate::mpc::rep3::Rep3Groth16Driver;
+use crate::mpc::shamir::ShamirGroth16Driver;
+use crate::mpc::CircomGroth16Prover;
+
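+// Runs three independent closures on the rayon thread pool via nested `rayon::join` and
+// returns their results as a flat tuple, e.g. (illustrative):
+//     let (x, y, z) = rayon_join!(task_x(), task_y(), task_z());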
+macro_rules! rayon_join {
+ ($t1: expr, $t2: expr, $t3: expr) => {{
+ let ((x, y), z) = rayon::join(|| rayon::join(|| $t1, || $t2), || $t3);
+ (x, y, z)
+ }};
+}
/// The plain [`Groth16`] type.
///
@@ -30,17 +43,12 @@ use std::marker::PhantomData;
///
/// More interesting is the [`Groth16::verify`] method. You can verify any circom Groth16 proof, be it
/// from snarkjs or one created by this project. Under the hood we use the arkwork Groth16 project for verifying.
-pub type Groth16<P> = CoGroth16<PlainDriver<<P as Pairing>::ScalarField>, P>;
+pub type Groth16<P> = CoGroth16<P, PlainGroth16Driver>;
/// A type alias for a [CoGroth16] protocol using replicated secret sharing.
-pub type Rep3CoGroth16<P> = CoGroth16<Rep3Protocol<<P as Pairing>::ScalarField, Rep3MpcNet>, P>;
-
-type FieldShare<T, P> = <T as PrimeFieldMpcProtocol<<P as Pairing>::ScalarField>>::FieldShare;
-type FieldShareVec<T, P> = <T as PrimeFieldMpcProtocol<<P as Pairing>::ScalarField>>::FieldShareVec;
-type PointShare<T, C> = <T as EcMpcProtocol<C>>::PointShare;
-type CurveFieldShareVec<T, C> = <T as PrimeFieldMpcProtocol<
-    <<C as CurveGroup>::Affine as AffineRepr>::ScalarField,
->>::FieldShareVec;
+pub type Rep3CoGroth16<P, N> = CoGroth16<P, Rep3Groth16Driver<N>>;
+/// A type alias for a [CoGroth16] protocol using Shamir secret sharing.
+pub type ShamirCoGroth16<P, N> =
+    CoGroth16<P, ShamirGroth16Driver<<P as Pairing>::ScalarField, N>>;
/* old way of computing root of unity, does not work for bls12_381:
let root_of_unity = {
@@ -54,6 +62,7 @@ calculate smallest quadratic non residue q (by checking q^((p-1)/2)=-1 mod p) al
use g=q^t (this is a 2^s-th root of unity) as (some kind of) generator and compute another domain by repeatedly squaring g, should get to 1 in the s+1-th step.
then if log2(domain_size) equals s we take as root of unity q^2, and else we take the log2(domain_size) + 1-th element of the domain created above
*/
+#[instrument(level = "debug", name = "root of unity", skip_all)]
 fn root_of_unity_for_groth16<F: PrimeField + FftField>(
     pow: usize,
     domain: &mut GeneralEvaluationDomain<F>,
@@ -77,26 +86,13 @@ fn root_of_unity_for_groth16(
}
/// A Groth16 proof protocol that uses a collaborative MPC protocol to generate the proof.
-pub struct CoGroth16<T, P: Pairing>
-where
-    T: PrimeFieldMpcProtocol<P::ScalarField>
-        + PairingEcMpcProtocol<P>
-        + FFTProvider<P::ScalarField>
-        + MSMProvider<P::G1>
-        + MSMProvider<P::G2>,
-{
+pub struct CoGroth16<P: Pairing, T: CircomGroth16Prover<P>> {
     pub(crate) driver: T,
     phantom_data: PhantomData<P>,
}
-impl<T, P: Pairing> CoGroth16<T, P>
+impl<P: Pairing + CircomArkworksPairingBridge, T: CircomGroth16Prover<P>> CoGroth16<P, T>
 where
-    T: PrimeFieldMpcProtocol<P::ScalarField>
-        + PairingEcMpcProtocol<P>
-        + FFTProvider<P::ScalarField>
-        + MSMProvider<P::G1>
-        + MSMProvider<P::G2>,
-    P: CircomArkworksPairingBridge,
P::BaseField: CircomArkworksPrimeFieldBridge,
P::ScalarField: CircomArkworksPrimeFieldBridge,
{
@@ -110,34 +106,54 @@ where
/// Execute the Groth16 prover using the internal MPC driver.
/// This version takes the Circom-generated constraint matrices as input and does not re-calculate them.
+ #[instrument(level = "debug", name = "Groth16 - Proof", skip_all)]
pub fn prove(
- &mut self,
+ mut self,
        zkey: &ZKey<P>,
-        private_witness: SharedWitness<T, P>,
+        private_witness: SharedWitness<P::ScalarField, T::ArithmeticShare>,
    ) -> Result<Groth16Proof<P>> {
+ let id = self.driver.get_party_id();
+ tracing::info!("Party {}: starting proof generation..", id);
+ let start = Instant::now();
let matrices = &zkey.matrices;
let num_inputs = matrices.num_instance_variables;
let num_constraints = matrices.num_constraints;
- let public_inputs = &private_witness.public_inputs;
- let private_witness = &private_witness.witness;
- tracing::debug!("calling witness map from matrices...");
+ let public_inputs = Arc::new(private_witness.public_inputs);
+ let private_witness = Arc::new(private_witness.witness);
let h = self.witness_map_from_matrices(
zkey.pow,
matrices,
num_constraints,
num_inputs,
- public_inputs,
- private_witness,
+ &public_inputs,
+ &private_witness,
)?;
- tracing::debug!("done!");
- tracing::debug!("getting r and s...");
- let r = self.driver.rand()?;
- let s = self.driver.rand()?;
- tracing::debug!("done!");
- tracing::debug!("calling create_proof_with_assignment...");
- self.create_proof_with_assignment(zkey, r, s, &h, &public_inputs[1..], private_witness)
+ let (r, s) = (self.driver.rand()?, self.driver.rand()?);
+ let proof =
+ self.create_proof_with_assignment(zkey, r, s, h, public_inputs, private_witness)?;
+
+ let duration_ms = start.elapsed().as_micros() as f64 / 1000.;
+ tracing::info!("Party {}: Proof generation took {} ms", id, duration_ms);
+ Ok(proof)
}
+ fn evaluate_constraint(
+ party_id: T::PartyID,
+ domain_size: usize,
+        matrix: &Matrix<P::ScalarField>,
+        public_inputs: &[P::ScalarField],
+        private_witness: &[T::ArithmeticShare],
+    ) -> Vec<T::ArithmeticShare> {
+        let mut result = matrix
+            .par_iter()
+            .with_min_len(32)
+            .map(|x| T::evaluate_constraint(party_id, x, public_inputs, private_witness))
+            .collect::<Vec<_>>();
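+        // Pad with default (zero) shares so the evaluations match the FFT domain size.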
+ result.resize(domain_size, T::ArithmeticShare::default());
+ result
+ }
+
+ #[instrument(level = "debug", name = "witness map from matrices", skip_all)]
fn witness_map_from_matrices(
&mut self,
power: usize,
@@ -145,176 +161,305 @@ where
num_constraints: usize,
num_inputs: usize,
public_inputs: &[P::ScalarField],
-        private_witness: &FieldShareVec<T, P>,
-    ) -> Result<FieldShareVec<T, P>> {
+        private_witness: &[T::ArithmeticShare],
+    ) -> Result<Vec<T::ArithmeticShare>> {
let mut domain =
            GeneralEvaluationDomain::<P::ScalarField>::new(num_constraints + num_inputs)
.ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
- let root_of_unity = root_of_unity_for_groth16(power, &mut domain);
-
let domain_size = domain.size();
-        let mut a = vec![FieldShare::<T, P>::default(); domain_size];
-        let mut b = vec![FieldShare::<T, P>::default(); domain_size];
- tracing::debug!("evaluating constraints..");
- for (a, b, at_i, bt_i) in izip!(&mut a, &mut b, &matrices.a, &matrices.b) {
- *a = self
- .driver
- .evaluate_constraint(at_i, public_inputs, private_witness);
- *b = self
- .driver
- .evaluate_constraint(bt_i, public_inputs, private_witness);
- }
- tracing::debug!("done!");
-        let mut a = FieldShareVec::<T, P>::from(a);
- let promoted_public = self.driver.promote_to_trivial_shares(public_inputs);
- self.driver
- .clone_from_slice(&mut a, &promoted_public, num_constraints, 0, num_inputs);
-
-        let mut b = FieldShareVec::<T, P>::from(b);
- let mut c = self.driver.mul_vec(&a, &b)?;
- self.driver.ifft_in_place(&mut a, &domain);
- self.driver.ifft_in_place(&mut b, &domain);
- self.driver.distribute_powers_and_mul_by_const(
- &mut a,
- root_of_unity,
- P::ScalarField::one(),
+ let party_id = self.driver.get_party_id();
+ let eval_constraint_span =
+ tracing::debug_span!("evaluate constraints + root of unity computation").entered();
+ let (roots_to_power_domain, a, b) = rayon_join!(
+ {
+ let root_of_unity_span =
+ tracing::debug_span!("root of unity computation").entered();
+ let root_of_unity = root_of_unity_for_groth16(power, &mut domain);
+ let mut roots = Vec::with_capacity(domain_size);
+ let mut c = P::ScalarField::one();
+ for _ in 0..domain_size {
+ roots.push(c);
+ c *= root_of_unity;
+ }
+ root_of_unity_span.exit();
+ Arc::new(roots)
+ },
+ {
+ let eval_constraint_span_a =
+ tracing::debug_span!("evaluate constraints - a").entered();
+ let mut result = Self::evaluate_constraint(
+ party_id,
+ domain_size,
+ &matrices.a,
+ public_inputs,
+ private_witness,
+ );
+ let promoted_public = T::promote_to_trivial_shares(party_id, public_inputs);
+ result[num_constraints..num_constraints + num_inputs]
+ .clone_from_slice(&promoted_public[..num_inputs]);
+ eval_constraint_span_a.exit();
+ result
+ },
+ {
+ let eval_constraint_span_b =
+                tracing::debug_span!("evaluate constraints - b").entered();
+ let result = Self::evaluate_constraint(
+ party_id,
+ domain_size,
+ &matrices.b,
+ public_inputs,
+ private_witness,
+ );
+ eval_constraint_span_b.exit();
+ result
+ }
);
- self.driver.distribute_powers_and_mul_by_const(
- &mut b,
- root_of_unity,
- P::ScalarField::one(),
- );
- self.driver.fft_in_place(&mut a, &domain);
- self.driver.fft_in_place(&mut b, &domain);
- //this can be in-place so that we do not have to allocate memory
- let mut ab = self.driver.mul_vec(&a, &b)?;
- std::mem::drop(a);
- std::mem::drop(b);
-
- self.driver.ifft_in_place(&mut c, &domain);
- self.driver.distribute_powers_and_mul_by_const(
- &mut c,
- root_of_unity,
- P::ScalarField::one(),
- );
- self.driver.fft_in_place(&mut c, &domain);
- self.driver.sub_assign_vec(&mut ab, &c);
+ let domain = Arc::new(domain);
+ eval_constraint_span.exit();
+
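+        // Run the iFFT -> multiply-by-root-powers -> FFT pipeline for a, b and c on separate
+        // rayon tasks; the results are collected below through the mpsc channels.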
+ let (a_tx, a_rx) = std::sync::mpsc::channel();
+ let (b_tx, b_rx) = std::sync::mpsc::channel();
+ let (c_tx, c_rx) = std::sync::mpsc::channel();
+ let a_domain = Arc::clone(&domain);
+ let b_domain = Arc::clone(&domain);
+ let c_domain = Arc::clone(&domain);
+ let mut a_result = a.clone();
+ let mut b_result = b.clone();
+ let a_roots = Arc::clone(&roots_to_power_domain);
+ let b_roots = Arc::clone(&roots_to_power_domain);
+ let c_roots = Arc::clone(&roots_to_power_domain);
+ rayon::spawn(move || {
+ let a_span = tracing::debug_span!("distribute powers mul a (fft/ifft)").entered();
+ a_domain.ifft_in_place(&mut a_result);
+ T::distribute_powers_and_mul_by_const(&mut a_result, &a_roots);
+ a_domain.fft_in_place(&mut a_result);
+            a_tx.send(a_result).expect("channel not dropped");
+ a_span.exit();
+ });
+
+ rayon::spawn(move || {
+ let b_span = tracing::debug_span!("distribute powers mul b (fft/ifft)").entered();
+ b_domain.ifft_in_place(&mut b_result);
+ T::distribute_powers_and_mul_by_const(&mut b_result, &b_roots);
+ b_domain.fft_in_place(&mut b_result);
+            b_tx.send(b_result).expect("channel not dropped");
+ b_span.exit();
+ });
+
+ let local_mul_vec_span = tracing::debug_span!("c: local_mul_vec").entered();
+ let mut ab = self.driver.local_mul_vec(a, b);
+ local_mul_vec_span.exit();
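+        // `ab` only holds locally multiplied shares at this point; the following FFT steps are
+        // linear, and the sharing is re-randomized/degree-reduced later by the driver (see
+        // `msm_and_mul`), so no non-linear operation touches it in between.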
+ rayon::spawn(move || {
+ let ifft_span = tracing::debug_span!("c: ifft in dist pows").entered();
+ c_domain.ifft_in_place(&mut ab);
+ ifft_span.exit();
+ let dist_pows_span = tracing::debug_span!("c: dist pows").entered();
+ #[allow(unused_mut)]
+ ab.par_iter_mut()
+ .zip_eq(c_roots.par_iter())
+ .with_min_len(512)
+ .for_each(|(share, pow)| {
+ *share *= pow;
+ });
+ dist_pows_span.exit();
+ let fft_span = tracing::debug_span!("c: fft in dist pows").entered();
+ c_domain.fft_in_place(&mut ab);
+ fft_span.exit();
+ c_tx.send(ab).expect("channel not dropped");
+ });
+
+ let a = a_rx.recv()?;
+ let b = b_rx.recv()?;
+
+ let compute_ab_span = tracing::debug_span!("compute ab").entered();
+ let local_ab_span = tracing::debug_span!("local part (mul and sub)").entered();
+        // Same as above: a purely local multiplication; no I/O task is running at the moment.
+ let mut ab = self.driver.local_mul_vec(a, b);
+ let c = c_rx.recv()?;
+ ab.par_iter_mut()
+ .zip_eq(c.par_iter())
+ .with_min_len(512)
+ .for_each(|(a, b)| {
+ *a -= b;
+ });
+ local_ab_span.exit();
+ compute_ab_span.exit();
Ok(ab)
}
-    fn calculate_coeff<C: CurveGroup>(
-        &mut self,
-        initial: PointShare<T, C>,
-        query: &[C::Affine],
-        vk_param: C::Affine,
-        input_assignment: &[C::ScalarField],
-        aux_assignment: &CurveFieldShareVec<T, C>,
-    ) -> PointShare<T, C>
-    where
-        T: EcMpcProtocol<C>,
-        T: MSMProvider<C>,
- {
- tracing::debug!("calculate coeffs..");
+ fn calculate_coeff_g1(
+ id: T::PartyID,
+ initial: T::PointShareG1,
+ query: &[P::G1Affine],
+ vk_param: P::G1Affine,
+ input_assignment: &[P::ScalarField],
+ aux_assignment: &[T::ArithmeticShare],
+ ) -> T::PointShareG1 {
let pub_len = input_assignment.len();
- let pub_acc = C::msm_unchecked(&query[1..=pub_len], input_assignment);
-        let priv_acc = MSMProvider::<C>::msm_public_points(
- &mut self.driver,
- &query[1 + pub_len..],
- aux_assignment,
+
+ // we block this thread of the runtime here.
+        // It should not matter too much, as we have a multithreaded
+        // runtime.
+ let (pub_acc, priv_acc) = rayon::join(
+ || P::G1::msm_unchecked(&query[1..=pub_len], input_assignment),
+ || T::msm_public_points_g1(&query[1 + pub_len..], aux_assignment),
);
let mut res = initial;
-        EcMpcProtocol::<C>::add_assign_points_public_affine(&mut self.driver, &mut res, &query[0]);
-        EcMpcProtocol::<C>::add_assign_points_public_affine(&mut self.driver, &mut res, &vk_param);
-        EcMpcProtocol::<C>::add_assign_points_public(&mut self.driver, &mut res, &pub_acc);
-        EcMpcProtocol::<C>::add_assign_points(&mut self.driver, &mut res, &priv_acc);
+ T::add_assign_points_public_g1(id, &mut res, &query[0].into_group());
+ T::add_assign_points_public_g1(id, &mut res, &vk_param.into_group());
+ T::add_assign_points_public_g1(id, &mut res, &pub_acc);
+ T::add_assign_points_g1(&mut res, &priv_acc);
+ res
+ }
+
+ fn calculate_coeff_g2(
+ id: T::PartyID,
+ initial: T::PointShareG2,
+ query: &[P::G2Affine],
+ vk_param: P::G2Affine,
+ input_assignment: &[P::ScalarField],
+ aux_assignment: &[T::ArithmeticShare],
+ ) -> T::PointShareG2 {
+ let pub_len = input_assignment.len();
+
+ // we block this thread of the runtime here.
+        // It should not matter too much, as we have a multithreaded
+        // runtime.
+ let (pub_acc, priv_acc) = rayon::join(
+ || P::G2::msm_unchecked(&query[1..=pub_len], input_assignment),
+ || T::msm_public_points_g2(&query[1 + pub_len..], aux_assignment),
+ );
- tracing::debug!("done..");
+ let mut res = initial;
+ T::add_assign_points_public_g2(id, &mut res, &query[0].into_group());
+ T::add_assign_points_public_g2(id, &mut res, &vk_param.into_group());
+ T::add_assign_points_public_g2(id, &mut res, &pub_acc);
+ T::add_assign_points_g2(&mut res, &priv_acc);
res
}
+ #[instrument(level = "debug", name = "create proof with assignment", skip_all)]
fn create_proof_with_assignment(
- &mut self,
+ mut self,
        zkey: &ZKey<P>,
-        r: FieldShare<T, P>,
-        s: FieldShare<T, P>,
-        h: &FieldShareVec<T, P>,
-        input_assignment: &[P::ScalarField],
-        aux_assignment: &FieldShareVec<T, P>,
+        r: T::ArithmeticShare,
+        s: T::ArithmeticShare,
+        h: Vec<T::ArithmeticShare>,
+        input_assignment: Arc<Vec<P::ScalarField>>,
+        aux_assignment: Arc<Vec<T::ArithmeticShare>>,
    ) -> Result<Groth16Proof<P>> {
- tracing::debug!("create proof with assignment...");
- //let c_acc_time = start_timer!(|| "Compute C");
-        let h_acc = MSMProvider::<P::G1>::msm_public_points(&mut self.driver, &zkey.h_query, h);
-
- // Compute C
-        let l_aux_acc = MSMProvider::<P::G1>::msm_public_points(
- &mut self.driver,
- &zkey.l_query,
- aux_assignment,
- );
-
let delta_g1 = zkey.delta_g1.into_group();
- let rs = self.driver.mul(&r, &s)?;
- let r_s_delta_g1 = self.driver.scalar_mul_public_point(&delta_g1, &rs);
+ let (l_acc_tx, l_acc_rx) = std::sync::mpsc::channel();
+ let h_query = Arc::clone(&zkey.h_query);
+ let l_query = Arc::clone(&zkey.l_query);
+
+ let party_id = self.driver.get_party_id();
+ let (r_g1_tx, r_g1_rx) = std::sync::mpsc::channel();
+ let (s_g1_tx, s_g1_rx) = std::sync::mpsc::channel();
+ let (s_g2_tx, s_g2_rx) = std::sync::mpsc::channel();
+ let a_query = Arc::clone(&zkey.a_query);
+ let b_g1_query = Arc::clone(&zkey.b_g1_query);
+ let b_g2_query = Arc::clone(&zkey.b_g2_query);
+ let input_assignment1 = Arc::clone(&input_assignment);
+ let input_assignment2 = Arc::clone(&input_assignment);
+ let input_assignment3 = Arc::clone(&input_assignment);
+ let aux_assignment1 = Arc::clone(&aux_assignment);
+ let aux_assignment2 = Arc::clone(&aux_assignment);
+ let aux_assignment3 = Arc::clone(&aux_assignment);
+ let aux_assignment4 = Arc::clone(&aux_assignment);
+ let alpha_g1 = zkey.vk.alpha_g1;
+ let beta_g1 = zkey.beta_g1;
+ let beta_g2 = zkey.vk.beta_g2;
+ let delta_g2 = zkey.vk.delta_g2.into_group();
+
+ rayon::spawn(move || {
+ let compute_a =
+ tracing::debug_span!("compute A in create proof with assignment").entered();
+ // Compute A
+ let r_g1 = T::scalar_mul_public_point_g1(&delta_g1, r);
+ let r_g1 = Self::calculate_coeff_g1(
+ party_id,
+ r_g1,
+ &a_query,
+ alpha_g1,
+ &input_assignment1[1..],
+ &aux_assignment1,
+ );
+ r_g1_tx.send(r_g1).expect("not dropped");
+ compute_a.exit();
+ });
+
+ rayon::spawn(move || {
+ let compute_b =
+ tracing::debug_span!("compute B/G1 in create proof with assignment").entered();
+ // Compute B in G1
+            // In the original implementation this is skipped if r == 0; however, r is shared in our case
+ let s_g1 = T::scalar_mul_public_point_g1(&delta_g1, s);
+ let s_g1 = Self::calculate_coeff_g1(
+ party_id,
+ s_g1,
+ &b_g1_query,
+ beta_g1,
+ &input_assignment2[1..],
+ &aux_assignment2,
+ );
+ s_g1_tx.send(s_g1).expect("not dropped");
+ compute_b.exit();
+ });
+
+ rayon::spawn(move || {
+ let compute_b =
+ tracing::debug_span!("compute B/G2 in create proof with assignment").entered();
+ // Compute B in G2
+ let s_g2 = T::scalar_mul_public_point_g2(&delta_g2, s);
+ let s_g2 = Self::calculate_coeff_g2(
+ party_id,
+ s_g2,
+ &b_g2_query,
+ beta_g2,
+ &input_assignment3[1..],
+ &aux_assignment3,
+ );
+ s_g2_tx.send(s_g2).expect("not dropped");
+ compute_b.exit();
+ });
- //end_timer!(c_acc_time);
+ rayon::spawn(move || {
+ let msm_l_query = tracing::debug_span!("msm l_query").entered();
+ let result = T::msm_public_points_g1(l_query.as_ref(), &aux_assignment4);
+ l_acc_tx.send(result).expect("channel not dropped");
+ msm_l_query.exit();
+ });
- // Compute A
- // let a_acc_time = start_timer!(|| "Compute A");
- let r_g1 = self.driver.scalar_mul_public_point(&delta_g1, &r);
+ // TODO we can remove the networking round in msm_and_mul because we only have linear operations afterwards
+ let (h_acc, rs) = self.driver.msm_and_mul(h, h_query, r, s)?;
+ let r_s_delta_g1 = T::scalar_mul_public_point_g1(&delta_g1, rs);
-        let g_a = self.calculate_coeff::<P::G1>(
- r_g1,
- &zkey.a_query,
- zkey.vk.alpha_g1,
- input_assignment,
- aux_assignment,
- );
+ let l_aux_acc = l_acc_rx.recv().expect("channel not dropped");
- // Open here since g_a is part of proof
-        let g_a_opened = EcMpcProtocol::<P::G1>::open_point(&mut self.driver, &g_a)?;
- let s_g_a = self.driver.scalar_mul_public_point(&g_a_opened, &s);
- // end_timer!(a_acc_time);
-
- // Compute B in G1
- // In original implementation this is skipped if r==0, however r is shared in our case
- // let b_g1_acc_time = start_timer!(|| "Compute B in G1");
- let s_g1 = self.driver.scalar_mul_public_point(&delta_g1, &s);
-        let g1_b = self.calculate_coeff::<P::G1>(
- s_g1,
- &zkey.b_g1_query,
- zkey.beta_g1,
- input_assignment,
- aux_assignment,
- );
-        let r_g1_b = EcMpcProtocol::<P::G1>::scalar_mul(&mut self.driver, &g1_b, &r)?;
- // end_timer!(b_g1_acc_time);
+ let calculate_coeff_span = tracing::debug_span!("calculate coeff").entered();
+ let g_a = r_g1_rx.recv()?;
+ let g1_b = s_g1_rx.recv()?;
+ calculate_coeff_span.exit();
- // Compute B in G2
- let delta_g2 = zkey.vk.delta_g2.into_group();
- // let b_g2_acc_time = start_timer!(|| "Compute B in G2");
- let s_g2 = self.driver.scalar_mul_public_point(&delta_g2, &s);
-        let g2_b = self.calculate_coeff::<P::G2>(
- s_g2,
- &zkey.b_g2_query,
- zkey.vk.beta_g2,
- input_assignment,
- aux_assignment,
- );
- // end_timer!(b_g2_acc_time);
+ let network_round = tracing::debug_span!("network round after calc coeff").entered();
+ let (g_a_opened, r_g1_b) = self.driver.open_point_and_scalar_mul(&g_a, &g1_b, r)?;
+ network_round.exit();
+
+ let last_round = tracing::debug_span!("finish open two points and some adds").entered();
+ let s_g_a = T::scalar_mul_public_point_g1(&g_a_opened, s);
- // let c_time = start_timer!(|| "Finish C");
let mut g_c = s_g_a;
-        EcMpcProtocol::<P::G1>::add_assign_points(&mut self.driver, &mut g_c, &r_g1_b);
-        EcMpcProtocol::<P::G1>::sub_assign_points(&mut self.driver, &mut g_c, &r_s_delta_g1);
-        EcMpcProtocol::<P::G1>::add_assign_points(&mut self.driver, &mut g_c, &l_aux_acc);
-        EcMpcProtocol::<P::G1>::add_assign_points(&mut self.driver, &mut g_c, &h_acc);
- // end_timer!(c_time);
+ T::add_assign_points_g1(&mut g_c, &r_g1_b);
+ T::sub_assign_points_g1(&mut g_c, &r_s_delta_g1);
+ T::add_assign_points_g1(&mut g_c, &l_aux_acc);
+ T::add_assign_points_g1(&mut g_c, &h_acc);
- tracing::debug!("almost done...");
- let (g_c_opened, g2_b_opened) =
-            PairingEcMpcProtocol::<P>::open_two_points(&mut self.driver, &g_c, &g2_b)?;
+ let g2_b = s_g2_rx.recv()?;
+ let (g_c_opened, g2_b_opened) = self.driver.open_two_points(g_c, g2_b)?;
+ last_round.exit();
Ok(Groth16Proof {
pi_a: g_a_opened.into_affine(),
@@ -326,7 +471,7 @@ where
}
}
-impl<P: Pairing> Rep3CoGroth16<P>
+impl<P: Pairing> Rep3CoGroth16<P, Rep3MpcNet>
where
P: CircomArkworksPairingBridge,
P::BaseField: CircomArkworksPrimeFieldBridge,
@@ -335,8 +480,42 @@ where
/// Create a new [Rep3CoGroth16] protocol with a given network configuration.
pub fn with_network_config(config: NetworkConfig) -> Result {
let mpc_net = Rep3MpcNet::new(config)?;
-        let driver = Rep3Protocol::<P::ScalarField, Rep3MpcNet>::new(mpc_net)?;
- Ok(CoGroth16::new(driver))
+ let mut io_context0 = IoContext::init(mpc_net)?;
+ let io_context1 = io_context0.fork()?;
+ let driver = Rep3Groth16Driver::new(io_context0, io_context1);
+ Ok(CoGroth16 {
+ driver,
+ phantom_data: PhantomData,
+ })
+ }
+}
+
+impl<P: Pairing> ShamirCoGroth16<P, ShamirMpcNet>
+where
+ P: CircomArkworksPairingBridge,
+ P::BaseField: CircomArkworksPrimeFieldBridge,
+ P::ScalarField: CircomArkworksPrimeFieldBridge,
+{
+ /// Create a new [ShamirCoGroth16] protocol with a given network configuration.
+ pub fn with_network_config(
+ threshold: usize,
+ config: NetworkConfig,
+        zkey: &ZKey<P>,
+ ) -> Result {
+ let domain_size = 2usize.pow(u32::try_from(zkey.pow).expect("pow fits into u32"));
+ // we need domain_size + 2 + 1 number of corr rand pairs in witness_map_from_matrices (degree_reduce_vec + r and s + 1 for fork)
+ let num_pairs = domain_size + 2 + 1;
+ let mpc_net = ShamirMpcNet::new(config)?;
+ let preprocessing = ShamirPreprocessing::new(threshold, mpc_net, num_pairs)?;
+ let mut protocol0 = ShamirProtocol::from(preprocessing);
+ // the protocol1 is only used for scalar_mul and a field_mul which need 1 pair each (ergo 2
+ // pairs)
+ let protocol1 = protocol0.fork_with_pairs(2)?;
+ let driver = ShamirGroth16Driver::new(protocol0, protocol1);
+ Ok(CoGroth16 {
+ driver,
+ phantom_data: PhantomData,
+ })
}
}
@@ -352,69 +531,12 @@ where
/// DOES NOT PERFORM ANY MPC. For a plain prover checkout the [Groth16 implementation of arkworks](https://docs.rs/ark-groth16/latest/ark_groth16/).
pub fn plain_prove(
zkey: &ZKey,
-        private_witness: SharedWitness<PlainDriver<P::ScalarField>, P>,
+        private_witness: SharedWitness<P::ScalarField, P::ScalarField>,
    ) -> Result<Groth16Proof<P>> {
- let mut prover = Self {
- driver: PlainDriver::default(),
+ let prover = Self {
+ driver: PlainGroth16Driver,
phantom_data: PhantomData,
};
prover.prove(zkey, private_witness)
}
}
-
-#[cfg(test)]
-mod test {
- use std::fs::File;
-
- use ark_bn254::Bn254;
- use circom_types::R1CS;
- use mpc_core::protocols::{
- rep3::{network::Rep3MpcNet, Rep3Protocol},
- shamir::{network::ShamirMpcNet, ShamirProtocol},
- };
- use rand::thread_rng;
-
- use super::SharedWitness;
- use circom_types::Witness;
-
- #[ignore]
- #[test]
- fn test_rep3() {
- let witness_file = File::open("../../test_vectors/bn254/multiplier2/witness.wtns").unwrap();
-        let witness = Witness::<ark_bn254::Fr>::from_reader(witness_file).unwrap();
-        let r1cs_file =
-            File::open("../../test_vectors/bn254/multiplier2/multiplier2.r1cs").unwrap();
-        let r1cs = R1CS::<Bn254>::from_reader(r1cs_file).unwrap();
-        let mut rng = thread_rng();
-        let [s1, _, _] =
-            SharedWitness::<Rep3Protocol<ark_bn254::Fr, Rep3MpcNet>, Bn254>::share_rep3(
- witness,
- r1cs.num_inputs,
- &mut rng,
- );
- println!("{}", serde_json::to_string(&s1).unwrap());
- }
-
- fn test_shamir_inner(num_parties: usize, threshold: usize) {
- let witness_file = File::open("../../test_vectors/bn254/multiplier2/witness.wtns").unwrap();
-        let witness = Witness::<ark_bn254::Fr>::from_reader(witness_file).unwrap();
-        let r1cs_file =
-            File::open("../../test_vectors/bn254/multiplier2/multiplier2.r1cs").unwrap();
-        let r1cs = R1CS::<Bn254>::from_reader(r1cs_file).unwrap();
-        let mut rng = thread_rng();
-        let s1 = SharedWitness::<ShamirProtocol<ark_bn254::Fr, ShamirMpcNet>, Bn254>::share_shamir(
- witness,
- r1cs.num_inputs,
- threshold,
- num_parties,
- &mut rng,
- );
- println!("{}", serde_json::to_string(&s1[0]).unwrap());
- }
-
- #[ignore]
- #[test]
- fn test_shamir() {
- test_shamir_inner(3, 1);
- }
-}
diff --git a/co-circom/co-groth16/src/lib.rs b/co-circom/co-groth16/src/lib.rs
index 3a9ab3cf4..d7aa3d8fe 100644
--- a/co-circom/co-groth16/src/lib.rs
+++ b/co-circom/co-groth16/src/lib.rs
@@ -1,12 +1,15 @@
//! A library for creating and verifying Groth16 proofs in a collaborative fashion using MPC.
#![warn(missing_docs)]
mod groth16;
+/// This module contains the Groth16 prover trait
+pub mod mpc;
#[cfg(feature = "verifier")]
mod verifier;
pub use groth16::CoGroth16;
pub use groth16::Groth16;
pub use groth16::Rep3CoGroth16;
+pub use groth16::ShamirCoGroth16;
#[cfg(test)]
#[cfg(feature = "verifier")]
@@ -18,7 +21,6 @@ mod tests {
Witness,
};
use co_circom_snarks::SharedWitness;
- use mpc_core::protocols::plain::PlainDriver;
use std::fs::{self, File};
use crate::groth16::Groth16;
@@ -33,19 +35,15 @@ mod tests {
File::open("../../test_vectors/Groth16/bn254/multiplier2/verification_key.json")
.unwrap();
-        let driver = PlainDriver::<ark_bn254::Fr>::default();
        let witness = Witness::<ark_bn254::Fr>::from_reader(witness_file).unwrap();
        let zkey = ZKey::<Bn254>::from_reader(zkey_file).unwrap();
        let vk: JsonVerificationKey<Bn254> = serde_json::from_reader(vk_file).unwrap();
        let public_input = witness.values[..=zkey.n_public].to_vec();
-        let witness = SharedWitness::<PlainDriver<ark_bn254::Fr>, Bn254> {
+ let witness = SharedWitness {
public_inputs: public_input.clone(),
witness: witness.values[zkey.n_public + 1..].to_vec(),
};
-        let mut groth16 = Groth16::<Bn254>::new(driver);
-        let proof = groth16
-            .prove(&zkey, witness)
-            .expect("proof generation works");
+        let proof = Groth16::<Bn254>::plain_prove(&zkey, witness).expect("proof generation works");
        let ser_proof = serde_json::to_string(&proof).unwrap();
        let der_proof = serde_json::from_str::<Groth16Proof<Bn254>>(&ser_proof).unwrap();
let verified = Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify");
@@ -78,7 +76,6 @@ mod tests {
File::open("../../test_vectors/Groth16/bn254/poseidon/circuit.zkey").unwrap();
let witness_file =
File::open("../../test_vectors/Groth16/bn254/poseidon/witness.wtns").unwrap();
-        let driver = PlainDriver::<ark_bn254::Fr>::default();
let vk_file =
File::open("../../test_vectors/Groth16/bn254/poseidon/verification_key.json").unwrap();
@@ -86,14 +83,11 @@ mod tests {
        let zkey = ZKey::<Bn254>::from_reader(zkey_file).unwrap();
        let vk: JsonVerificationKey<Bn254> = serde_json::from_reader(vk_file).unwrap();
        let public_input = witness.values[..=zkey.n_public].to_vec();
-        let witness = SharedWitness::<PlainDriver<ark_bn254::Fr>, Bn254> {
+ let witness = SharedWitness {
public_inputs: public_input.clone(),
witness: witness.values[zkey.n_public + 1..].to_vec(),
};
-        let mut groth16 = Groth16::<Bn254>::new(driver);
-        let proof = groth16
-            .prove(&zkey, witness)
-            .expect("proof generation works");
+        let proof = Groth16::<Bn254>::plain_prove(&zkey, witness).expect("proof generation works");
        let ser_proof = serde_json::to_string(&proof).unwrap();
        let der_proof = serde_json::from_str::<Groth16Proof<Bn254>>(&ser_proof).unwrap();
let verified = Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify");
@@ -152,16 +146,13 @@ mod tests {
        let zkey = ZKey::<Bls12_381>::from_reader(zkey_file).unwrap();
        let vk: JsonVerificationKey<Bls12_381> = serde_json::from_reader(vk_file).unwrap();
        let public_input = witness.values[..=zkey.n_public].to_vec();
-        let witness = SharedWitness::<PlainDriver<ark_bls12_381::Fr>, Bls12_381> {
+ let witness = SharedWitness {
public_inputs: public_input.clone(),
witness: witness.values[zkey.n_public + 1..].to_vec(),
};
-        let driver = PlainDriver::<ark_bls12_381::Fr>::default();
-        let mut groth16 = Groth16::<Bls12_381>::new(driver);
-        let proof = groth16
-            .prove(&zkey, witness)
-            .expect("proof generation works");
+        let proof =
+            Groth16::<Bls12_381>::plain_prove(&zkey, witness).expect("proof generation works");
        let verified =
            Groth16::<Bls12_381>::verify(&vk, &proof, &public_input[1..]).expect("can verify");
assert!(verified);
@@ -185,16 +176,12 @@ mod tests {
        let zkey = ZKey::<Bn254>::from_reader(zkey_file).unwrap();
        let vk: JsonVerificationKey<Bn254> = serde_json::from_reader(vk_file).unwrap();
        let public_input = witness.values[..=zkey.n_public].to_vec();
-        let witness = SharedWitness::<PlainDriver<ark_bn254::Fr>, Bn254> {
+ let witness = SharedWitness {
public_inputs: public_input.clone(),
witness: witness.values[zkey.n_public + 1..].to_vec(),
};
-        let driver = PlainDriver::<ark_bn254::Fr>::default();
-        let mut groth16 = Groth16::<Bn254>::new(driver);
-        let proof = groth16
-            .prove(&zkey, witness)
-            .expect("proof generation works");
+        let proof = Groth16::<Bn254>::plain_prove(&zkey, witness).expect("proof generation works");
        let verified =
            Groth16::<Bn254>::verify(&vk, &proof, &public_input[1..]).expect("can verify");
assert!(verified);
diff --git a/co-circom/co-groth16/src/mpc.rs b/co-circom/co-groth16/src/mpc.rs
new file mode 100644
index 000000000..465068f84
--- /dev/null
+++ b/co-circom/co-groth16/src/mpc.rs
@@ -0,0 +1,140 @@
+use core::fmt;
+use std::{fmt::Debug, sync::Arc};
+
+use ark_ec::pairing::Pairing;
+use ark_poly::domain::DomainCoeff;
+use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
+
+pub(crate) mod plain;
+pub(crate) mod rep3;
+pub(crate) mod shamir;
+
+pub use plain::PlainGroth16Driver;
+pub use rep3::Rep3Groth16Driver;
+pub use shamir::ShamirGroth16Driver;
+
+type IoResult<T> = std::io::Result<T>;
+
+/// This trait represents the operations used during Groth16 proof generation
+pub trait CircomGroth16Prover<P: Pairing>: Send + Sized {
+    /// The arithmetic share type
+ type ArithmeticShare: CanonicalSerialize
+ + CanonicalDeserialize
+ + Copy
+ + Clone
+ + Default
+ + Send
+ + Debug
+        + DomainCoeff<P::ScalarField>
+ + 'static;
+ /// The G1 point share type
+ type PointShareG1: Debug + Send + 'static;
+ /// The G2 point share type
+ type PointShareG2: Debug + Send + 'static;
+ /// The party id type
+ type PartyID: Send + Sync + Copy + fmt::Display + 'static;
+
+    /// Generate a random arithmetic share
+    fn rand(&mut self) -> IoResult<Self::ArithmeticShare>;
+
+ /// Get the party id
+ fn get_party_id(&self) -> Self::PartyID;
+
+    /// Each entry of `lhs` consists of a coefficient `c` and an index `i`. This function computes
+    /// the sum of the coefficients times the corresponding public input or private witness. In
+    /// other words, an accumulator `a` is initialized to 0, and for each `(c, i)` in `lhs`,
+    /// `a += c * public_inputs[i]` is computed if `i` refers to a public input, and
+    /// `a += c * private_witness[i - public_inputs.len()]` otherwise.
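+    ///
+    /// For example (illustrative): with `lhs = [(c0, 0), (c1, 5)]` and two public inputs, the
+    /// accumulated share is `c0 * public_inputs[0] + c1 * private_witness[3]`.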
+ fn evaluate_constraint(
+ party_id: Self::PartyID,
+ lhs: &[(P::ScalarField, usize)],
+ public_inputs: &[P::ScalarField],
+ private_witness: &[Self::ArithmeticShare],
+ ) -> Self::ArithmeticShare;
+
+ /// Elementwise transformation of a vector of public values into a vector of shared values: \[a_i\] = a_i.
+ fn promote_to_trivial_shares(
+ id: Self::PartyID,
+ public_values: &[P::ScalarField],
+    ) -> Vec<Self::ArithmeticShare>;
+
+ /// Performs element-wise multiplication of two vectors of shared values.
+ /// Does not perform any networking.
+ ///
+ /// # Security
+ /// You must *NOT* perform additional non-linear operations on the result of this function.
+ fn local_mul_vec(
+ &mut self,
+        a: Vec<Self::ArithmeticShare>,
+        b: Vec<Self::ArithmeticShare>,
+    ) -> Vec<P::ScalarField>;
+
+ /// Compute the msm of `h` and `h_query` and multiplication `r` * `s`.
+ fn msm_and_mul(
+ &mut self,
+        h: Vec<<P as Pairing>::ScalarField>,
+        h_query: Arc<Vec<P::G1Affine>>,
+ r: Self::ArithmeticShare,
+ s: Self::ArithmeticShare,
+ ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)>;
+
+    /// Computes `coeffs[i] *= roots[i]` for 0 <= i < coeffs.len(), i.e. multiplies each
+    /// coefficient by the corresponding precomputed power of the root of unity.
+ fn distribute_powers_and_mul_by_const(
+ coeffs: &mut [Self::ArithmeticShare],
+ roots: &[P::ScalarField],
+ );
+
+ /// Perform msm between G1 `points` and `scalars`
+ fn msm_public_points_g1(
+ points: &[P::G1Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG1;
+
+ /// Perform msm between G2 `points` and `scalars`
+ fn msm_public_points_g2(
+ points: &[P::G2Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG2;
+
+    /// Perform scalar multiplication of the public point a with the shared scalar b: \[a * b\]
+ fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1;
+
+ /// Add a shared point B in place to the shared point A: \[A\] += \[B\]
+ fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1);
+
+ /// Add a public point B in place to the shared point A
+ fn add_assign_points_public_g1(id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1);
+
+ /// Reconstructs a shared point: A = Open(\[A\]).
+    fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1>;
+
+ /// Multiplies a share b to the shared point A: \[A\] *= \[b\]. Requires network communication.
+ fn scalar_mul_g1(
+ &mut self,
+ a: &Self::PointShareG1,
+ b: Self::ArithmeticShare,
+    ) -> IoResult<Self::PointShareG1>;
+
+ /// Subtract a shared point B in place from the shared point A: \[A\] -= \[B\]
+ fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1);
+
+ /// Perform scalar multiplication of point A with a shared scalar b
+ fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2;
+
+ /// Add a shared point B in place to the shared point A: \[A\] += \[B\]
+ fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2);
+
+ /// Add a public point B in place to the shared point A
+ fn add_assign_points_public_g2(id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2);
+
+    /// Reconstructs two shared points: A = Open(\[A\]), B = Open(\[B\]).
+ fn open_two_points(
+ &mut self,
+ a: Self::PointShareG1,
+ b: Self::PointShareG2,
+ ) -> std::io::Result<(P::G1, P::G2)>;
+
+ /// Reconstruct point G_a and perform scalar multiplication of G1_b and r concurrently
+ fn open_point_and_scalar_mul(
+ &mut self,
+ g_a: &Self::PointShareG1,
+ g1_b: &Self::PointShareG1,
+ r: Self::ArithmeticShare,
+ ) -> std::io::Result<(P::G1, Self::PointShareG1)>;
+}
diff --git a/co-circom/co-groth16/src/mpc/plain.rs b/co-circom/co-groth16/src/mpc/plain.rs
new file mode 100644
index 000000000..bc7a07530
--- /dev/null
+++ b/co-circom/co-groth16/src/mpc/plain.rs
@@ -0,0 +1,156 @@
+use std::sync::Arc;
+
+use ark_ec::pairing::Pairing;
+use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
+use ark_ff::UniformRand;
+use rand::thread_rng;
+
+use super::CircomGroth16Prover;
+
+type IoResult<T> = std::io::Result<T>;
+
+/// A plain Groth16 driver
+pub struct PlainGroth16Driver;
+
+impl<P: Pairing> CircomGroth16Prover<P> for PlainGroth16Driver {
+ type ArithmeticShare = P::ScalarField;
+
+ type PointShareG1 = P::G1;
+
+ type PointShareG2 = P::G2;
+
+ type PartyID = usize;
+
+    fn rand(&mut self) -> IoResult<Self::ArithmeticShare> {
+ let mut rng = thread_rng();
+ Ok(Self::ArithmeticShare::rand(&mut rng))
+ }
+
+ fn get_party_id(&self) -> Self::PartyID {
+        // doesn't matter
+ 0
+ }
+
+ fn evaluate_constraint(
+ _: Self::PartyID,
+ lhs: &[(P::ScalarField, usize)],
+ public_inputs: &[P::ScalarField],
+ private_witness: &[Self::ArithmeticShare],
+ ) -> Self::ArithmeticShare {
+ let mut acc = P::ScalarField::default();
+ for (coeff, index) in lhs {
+ if index < &public_inputs.len() {
+ acc += *coeff * public_inputs[*index];
+ } else {
+ acc += *coeff * private_witness[*index - public_inputs.len()];
+ }
+ }
+ acc
+ }
+
+ fn promote_to_trivial_shares(
+ _: Self::PartyID,
+ public_values: &[P::ScalarField],
+ ) -> Vec {
+ public_values.to_vec()
+ }
+
+ fn local_mul_vec(
+ &mut self,
+        a: Vec<Self::ArithmeticShare>,
+        b: Vec<Self::ArithmeticShare>,
+    ) -> Vec<P::ScalarField> {
+ a.iter().zip(b.iter()).map(|(a, b)| *a * b).collect()
+ }
+
+ fn msm_and_mul(
+ &mut self,
+        h: Vec<<P as Pairing>::ScalarField>,
+        h_query: Arc<Vec<P::G1Affine>>,
+ r: Self::ArithmeticShare,
+ s: Self::ArithmeticShare,
+ ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)> {
+ Ok((P::G1::msm_unchecked(h_query.as_ref(), &h), r * s))
+ }
+
+ fn distribute_powers_and_mul_by_const(
+ coeffs: &mut [Self::ArithmeticShare],
+ roots: &[P::ScalarField],
+ ) {
+ #[allow(unused_mut)]
+ for (mut c, pow) in coeffs.iter_mut().zip(roots) {
+ *c *= pow;
+ }
+ }
+
+ fn msm_public_points_g1(
+ points: &[P::G1Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG1 {
+ P::G1::msm_unchecked(points, scalars)
+ }
+
+ fn msm_public_points_g2(
+ points: &[P::G2Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG2 {
+ P::G2::msm_unchecked(points, scalars)
+ }
+
+ fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 {
+ *a * b
+ }
+
+ fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
+ *a += b;
+ }
+
+ fn add_assign_points_public_g1(_: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) {
+ *a += b;
+ }
+
+    fn open_point_g1(&mut self, a: &Self::PointShareG1) -> super::IoResult<P::G1> {
+ Ok(*a)
+ }
+
+ fn scalar_mul_g1(
+ &mut self,
+ a: &Self::PointShareG1,
+ b: Self::ArithmeticShare,
+    ) -> super::IoResult<Self::PointShareG1> {
+ Ok(*a * b)
+ }
+
+ fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
+ *a -= b;
+ }
+
+ fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 {
+ *a * b
+ }
+
+ fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) {
+ *a += b;
+ }
+
+ fn add_assign_points_public_g2(_: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) {
+ *a += b;
+ }
+
+ fn open_two_points(
+ &mut self,
+ a: Self::PointShareG1,
+ b: Self::PointShareG2,
+ ) -> std::io::Result<(P::G1, P::G2)> {
+ Ok((a, b))
+ }
+
+ fn open_point_and_scalar_mul(
+ &mut self,
+ g_a: &Self::PointShareG1,
+ g1_b: &Self::PointShareG1,
+ r: Self::ArithmeticShare,
+ ) -> super::IoResult<(P::G1, Self::PointShareG1)> {
+ Ok((*g_a, *g1_b * r))
+ }
+}
diff --git a/co-circom/co-groth16/src/mpc/rep3.rs b/co-circom/co-groth16/src/mpc/rep3.rs
new file mode 100644
index 000000000..36709ad35
--- /dev/null
+++ b/co-circom/co-groth16/src/mpc/rep3.rs
@@ -0,0 +1,202 @@
+use std::sync::Arc;
+
+use ark_ec::pairing::Pairing;
+use mpc_core::protocols::rep3::{
+ arithmetic,
+ id::PartyID,
+ network::{IoContext, Rep3Network},
+ pointshare, Rep3PointShare, Rep3PrimeFieldShare,
+};
+use rayon::prelude::*;
+
+use super::{CircomGroth16Prover, IoResult};
+
+/// A Groth16 driver for REP3 secret sharing
+///
+/// Contains two [`IoContext`]s, `io_context0` for the main execution and `io_context1` for parts that can run concurrently.
+pub struct Rep3Groth16Driver<N: Rep3Network> {
+    io_context0: IoContext<N>,
+    io_context1: IoContext<N>,
+}
+
+impl<N: Rep3Network> Rep3Groth16Driver<N> {
+    /// Create a new [`Rep3Groth16Driver`] with two [`IoContext`]s
+    pub fn new(io_context0: IoContext<N>, io_context1: IoContext<N>) -> Self {
+ Self {
+ io_context0,
+ io_context1,
+ }
+ }
+}
+
+impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>
+where
+    N: 'static,
+{
+    type ArithmeticShare = Rep3PrimeFieldShare<P::ScalarField>;
+    type PointShareG1 = Rep3PointShare<P::G1>;
+    type PointShareG2 = Rep3PointShare<P::G2>;
+
+ type PartyID = PartyID;
+
+    fn rand(&mut self) -> IoResult<Self::ArithmeticShare> {
+ Ok(Self::ArithmeticShare::rand(&mut self.io_context0))
+ }
+
+ fn get_party_id(&self) -> Self::PartyID {
+ self.io_context0.id
+ }
+
+ fn evaluate_constraint(
+ party_id: Self::PartyID,
+ lhs: &[(P::ScalarField, usize)],
+ public_inputs: &[P::ScalarField],
+ private_witness: &[Self::ArithmeticShare],
+ ) -> Self::ArithmeticShare {
+ let mut acc = Self::ArithmeticShare::default();
+ for (coeff, index) in lhs {
+ if index < &public_inputs.len() {
+ let val = public_inputs[*index];
+ let mul_result = val * coeff;
+ arithmetic::add_assign_public(&mut acc, mul_result, party_id);
+ } else {
+ let current_witness = private_witness[*index - public_inputs.len()];
+ arithmetic::add_assign(&mut acc, arithmetic::mul_public(current_witness, *coeff));
+ }
+ }
+ acc
+ }
+
+ fn promote_to_trivial_shares(
+ id: Self::PartyID,
+ public_values: &[P::ScalarField],
+ ) -> Vec {
+ public_values
+ .par_iter()
+ .with_min_len(1024)
+ .map(|value| Self::ArithmeticShare::promote_from_trivial(value, id))
+ .collect()
+ }
+
+ fn local_mul_vec(
+ &mut self,
+        a: Vec<Self::ArithmeticShare>,
+        b: Vec<Self::ArithmeticShare>,
+    ) -> Vec<P::ScalarField> {
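+        // Purely local multiplication using the protocol's correlated randomness; the result is
+        // an additive sharing and must not be fed into further non-linear operations (see the
+        // trait documentation of `local_mul_vec`).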
+ arithmetic::local_mul_vec(&a, &b, &mut self.io_context0.rngs)
+ }
+
+ fn msm_and_mul(
+ &mut self,
+        h: Vec<<P as Pairing>::ScalarField>,
+        h_query: Arc<Vec<P::G1Affine>>,
+ r: Self::ArithmeticShare,
+ s: Self::ArithmeticShare,
+ ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)> {
+ std::thread::scope(|scope| {
+ let h_acc = scope.spawn(|| {
+ let msm_h_query = tracing::debug_span!("msm h_query").entered();
+ let h = arithmetic::io_mul_vec(h, &mut self.io_context0)?;
+ let result = pointshare::msm_public_points(h_query.as_ref(), &h);
+ msm_h_query.exit();
+ Ok::<_, std::io::Error>(result)
+ });
+ let mul = arithmetic::mul(r, s, &mut self.io_context1)?;
+ Ok((h_acc.join().expect("can join")?, mul))
+ })
+ }
+
+ fn distribute_powers_and_mul_by_const(
+ coeffs: &mut [Self::ArithmeticShare],
+ roots: &[P::ScalarField],
+ ) {
+ coeffs
+ .par_iter_mut()
+ .zip_eq(roots.par_iter())
+ .with_min_len(512)
+ .for_each(|(c, pow)| {
+ arithmetic::mul_assign_public(c, *pow);
+ })
+ }
+
+ fn msm_public_points_g1(
+ points: &[P::G1Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG1 {
+ pointshare::msm_public_points(points, scalars)
+ }
+
+ fn msm_public_points_g2(
+ points: &[P::G2Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG2 {
+ pointshare::msm_public_points(points, scalars)
+ }
+
+ fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 {
+ pointshare::scalar_mul_public_point(a, b)
+ }
+
+ /// Add a shared point B in place to the shared point A: \[A\] += \[B\]
+ fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
+ pointshare::add_assign(a, b)
+ }
+
+ fn add_assign_points_public_g1(id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) {
+ pointshare::add_assign_public(a, b, id)
+ }
+
+    fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1> {
+ pointshare::open_point(a, &mut self.io_context0)
+ }
+
+ fn scalar_mul_g1(
+ &mut self,
+ a: &Self::PointShareG1,
+ b: Self::ArithmeticShare,
+    ) -> IoResult<Self::PointShareG1> {
+ pointshare::scalar_mul(a, b, &mut self.io_context0)
+ }
+
+ fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
+ pointshare::sub_assign(a, b);
+ }
+
+ fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 {
+ pointshare::scalar_mul_public_point(a, b)
+ }
+
+ fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) {
+ pointshare::add_assign(a, b)
+ }
+
+ fn add_assign_points_public_g2(id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) {
+ pointshare::add_assign_public(a, b, id)
+ }
+
+ fn open_two_points(
+ &mut self,
+ a: Self::PointShareG1,
+ b: Self::PointShareG2,
+ ) -> std::io::Result<(P::G1, P::G2)> {
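+        // Each party reshares the `b` components of its point shares; adding the received share
+        // to the local `a + b` reconstructs both opened points.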
+ let s1 = a.b;
+ let s2 = b.b;
+ let (mut r1, mut r2) = self.io_context0.network.reshare((s1, s2))?;
+ r1 += a.a + a.b;
+ r2 += b.a + b.b;
+ Ok((r1, r2))
+ }
+
+ fn open_point_and_scalar_mul(
+ &mut self,
+ g_a: &Self::PointShareG1,
+ g1_b: &Self::PointShareG1,
+ r: Self::ArithmeticShare,
+    ) -> std::io::Result<(<P as Pairing>::G1, Self::PointShareG1)> {
+ std::thread::scope(|s| {
+ let opened = s.spawn(|| pointshare::open_point(g_a, &mut self.io_context0));
+ let mul_result = pointshare::scalar_mul(g1_b, r, &mut self.io_context1)?;
+ Ok((opened.join().expect("can join")?, mul_result))
+ })
+ }
+}
diff --git a/co-circom/co-groth16/src/mpc/shamir.rs b/co-circom/co-groth16/src/mpc/shamir.rs
new file mode 100644
index 000000000..5e97a3102
--- /dev/null
+++ b/co-circom/co-groth16/src/mpc/shamir.rs
@@ -0,0 +1,201 @@
+use std::sync::Arc;
+
+use super::{CircomGroth16Prover, IoResult};
+use ark_ec::pairing::Pairing;
+use ark_ff::PrimeField;
+use mpc_core::protocols::shamir::{
+ arithmetic, core, network::ShamirNetwork, pointshare, ShamirPointShare, ShamirPrimeFieldShare,
+ ShamirProtocol,
+};
+use rayon::prelude::*;
+
+/// A Groth16 driver using Shamir secret sharing
+///
+/// Contains two [`ShamirProtocol`]s, `protocol0` for the main execution and `protocol1` for parts that can run concurrently.
+pub struct ShamirGroth16Driver<F: PrimeField, N: ShamirNetwork> {
+    protocol0: ShamirProtocol<F, N>,
+    protocol1: ShamirProtocol<F, N>,
+}
+
+impl<F: PrimeField, N: ShamirNetwork> ShamirGroth16Driver<F, N> {
+    /// Create a new [`ShamirGroth16Driver`] with two [`ShamirProtocol`]s
+    pub fn new(protocol0: ShamirProtocol<F, N>, protocol1: ShamirProtocol<F, N>) -> Self {
+ Self {
+ protocol0,
+ protocol1,
+ }
+ }
+}
+
+impl<P: Pairing, N: ShamirNetwork> CircomGroth16Prover<P>
+    for ShamirGroth16Driver<P::ScalarField, N>
+{
+    type ArithmeticShare = ShamirPrimeFieldShare<P::ScalarField>;
+    type PointShareG1 = ShamirPointShare<P::G1>;
+    type PointShareG2 = ShamirPointShare<P::G2>;
+
+ type PartyID = usize;
+
+    fn rand(&mut self) -> IoResult<Self::ArithmeticShare> {
+ self.protocol0.rand()
+ }
+
+ fn get_party_id(&self) -> Self::PartyID {
+ self.protocol0.network.get_id()
+ }
+
+ fn evaluate_constraint(
+ _party_id: Self::PartyID,
+ lhs: &[(P::ScalarField, usize)],
+ public_inputs: &[P::ScalarField],
+ private_witness: &[Self::ArithmeticShare],
+ ) -> Self::ArithmeticShare {
+ let mut acc = Self::ArithmeticShare::default();
+ for (coeff, index) in lhs {
+ if index < &public_inputs.len() {
+ let val = public_inputs[*index];
+ let mul_result = val * coeff;
+ arithmetic::add_assign_public(&mut acc, mul_result);
+ } else {
+ let current_witness = private_witness[*index - public_inputs.len()];
+ arithmetic::add_assign(&mut acc, arithmetic::mul_public(current_witness, *coeff));
+ }
+ }
+ acc
+ }
+
+ fn promote_to_trivial_shares(
+ _id: Self::PartyID,
+ public_values: &[P::ScalarField],
+ ) -> Vec {
+ arithmetic::promote_to_trivial_shares(public_values)
+ }
+
+ fn local_mul_vec(
+ &mut self,
+        a: Vec<Self::ArithmeticShare>,
+        b: Vec<Self::ArithmeticShare>,
+    ) -> Vec<P::ScalarField> {
+ arithmetic::local_mul_vec(&a, &b)
+ }
+
+ fn msm_and_mul(
+ &mut self,
+        h: Vec<<P as Pairing>::ScalarField>,
+        h_query: Arc<Vec<P::G1Affine>>,
+ r: Self::ArithmeticShare,
+ s: Self::ArithmeticShare,
+ ) -> IoResult<(Self::PointShareG1, Self::ArithmeticShare)> {
+ std::thread::scope(|scope| {
+ let h_acc = scope.spawn(|| {
+ let msm_h_query = tracing::debug_span!("msm h_query").entered();
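+                // Degree reduction turns the degree-2t sharing produced by `local_mul_vec` back
+                // into a degree-t sharing before the public-point MSM.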
+ let h = self.protocol0.degree_reduce_vec(h)?;
+ let result = pointshare::msm_public_points(h_query.as_ref(), &h);
+ msm_h_query.exit();
+ Ok::<_, std::io::Error>(result)
+ });
+ let mul = arithmetic::mul(r, s, &mut self.protocol1)?;
+ Ok((h_acc.join().expect("can join")?, mul))
+ })
+ }
+
+ fn distribute_powers_and_mul_by_const(
+ coeffs: &mut [Self::ArithmeticShare],
+ roots: &[P::ScalarField],
+ ) {
+ coeffs
+ .par_iter_mut()
+ .zip_eq(roots.par_iter())
+ .with_min_len(512)
+ .for_each(|(c, pow)| {
+ arithmetic::mul_assign_public(c, *pow);
+ })
+ }
+
+ fn msm_public_points_g1(
+ points: &[P::G1Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG1 {
+ pointshare::msm_public_points(points, scalars)
+ }
+
+ fn msm_public_points_g2(
+ points: &[P::G2Affine],
+ scalars: &[Self::ArithmeticShare],
+ ) -> Self::PointShareG2 {
+ pointshare::msm_public_points(points, scalars)
+ }
+
+ fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 {
+ pointshare::scalar_mul_public_point(b, a)
+ }
+
+ fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
+ pointshare::add_assign(a, b)
+ }
+
+ fn add_assign_points_public_g1(_id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) {
+ pointshare::add_assign_public(a, b)
+ }
+
+    fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1> {
+ pointshare::open_point(a, &mut self.protocol0)
+ }
+
+ fn scalar_mul_g1(
+ &mut self,
+ a: &Self::PointShareG1,
+ b: Self::ArithmeticShare,
+    ) -> IoResult<Self::PointShareG1> {
+ pointshare::scalar_mul(a, b, &mut self.protocol0)
+ }
+
+ fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
+ pointshare::sub_assign(a, b);
+ }
+
+ fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 {
+ pointshare::scalar_mul_public_point(b, a)
+ }
+
+ fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) {
+ pointshare::add_assign(a, b)
+ }
+
+ fn add_assign_points_public_g2(_id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) {
+ pointshare::add_assign_public(a, b)
+ }
+
+ fn open_two_points(
+ &mut self,
+ a: Self::PointShareG1,
+ b: Self::PointShareG2,
+ ) -> std::io::Result<(P::G1, P::G2)> {
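+        // Exchange shares with threshold + 1 parties and reconstruct both points via Lagrange
+        // interpolation with the precomputed coefficients.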
+ let s1 = a.a;
+ let s2 = b.a;
+
+ let rcv: Vec<(P::G1, P::G2)> = self
+ .protocol0
+ .network
+ .broadcast_next((s1, s2), self.protocol0.threshold + 1)?;
+        let (r1, r2): (Vec<P::G1>, Vec<P::G2>) = rcv.into_iter().unzip();
+
+ let r1 = core::reconstruct_point(&r1, &self.protocol0.open_lagrange_t);
+ let r2 = core::reconstruct_point(&r2, &self.protocol0.open_lagrange_t);
+
+ Ok((r1, r2))
+ }
+
+ fn open_point_and_scalar_mul(
+ &mut self,
+ g_a: &Self::PointShareG1,
+ g1_b: &Self::PointShareG1,
+ r: Self::ArithmeticShare,
+ ) -> super::IoResult<(P::G1, Self::PointShareG1)> {
+ std::thread::scope(|s| {
+ let opened = s.spawn(|| pointshare::open_point(g_a, &mut self.protocol0));
+ let mul_result = pointshare::scalar_mul(g1_b, r, &mut self.protocol1)?;
+ Ok((opened.join().expect("can join")?, mul_result))
+ })
+ }
+}