Skip to content

Commit

Permalink
refactor!: make pointshare in Groth16 MPC trait generic over the curve
Browse files Browse the repository at this point in the history
This de-duplicates a bit of code.

BREAKING CHANGE: the public interface of the Groth16MPCProver trait has changed.
  • Loading branch information
dkales authored and 0xThemis committed Oct 15, 2024
1 parent 7210881 commit dc5acd2
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 179 deletions.
48 changes: 24 additions & 24 deletions co-circom/co-groth16/src/groth16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,53 +291,53 @@ where

fn calculate_coeff_g1(
id: T::PartyID,
initial: T::PointShareG1,
initial: T::PointShare<P::G1>,
query: &[P::G1Affine],
vk_param: P::G1Affine,
input_assignment: &[P::ScalarField],
aux_assignment: &[T::ArithmeticShare],
) -> T::PointShareG1 {
) -> T::PointShare<P::G1> {
let pub_len = input_assignment.len();

// we block this thread of the runtime here.
// It should not matter too much, as we have multithreaded
// runtime.
let (pub_acc, priv_acc) = rayon::join(
|| P::G1::msm_unchecked(&query[1..=pub_len], input_assignment),
|| T::msm_public_points_g1(&query[1 + pub_len..], aux_assignment),
|| T::msm_public_points(&query[1 + pub_len..], aux_assignment),
);

let mut res = initial;
T::add_assign_points_public_g1(id, &mut res, &query[0].into_group());
T::add_assign_points_public_g1(id, &mut res, &vk_param.into_group());
T::add_assign_points_public_g1(id, &mut res, &pub_acc);
T::add_assign_points_g1(&mut res, &priv_acc);
T::add_assign_points_public(id, &mut res, &query[0].into_group());
T::add_assign_points_public(id, &mut res, &vk_param.into_group());
T::add_assign_points_public(id, &mut res, &pub_acc);
T::add_assign_points(&mut res, &priv_acc);
res
}

fn calculate_coeff_g2(
id: T::PartyID,
initial: T::PointShareG2,
initial: T::PointShare<P::G2>,
query: &[P::G2Affine],
vk_param: P::G2Affine,
input_assignment: &[P::ScalarField],
aux_assignment: &[T::ArithmeticShare],
) -> T::PointShareG2 {
) -> T::PointShare<P::G2> {
let pub_len = input_assignment.len();

// we block this thread of the runtime here.
// It should not matter too much, as we have multithreaded
// runtime.
let (pub_acc, priv_acc) = rayon::join(
|| P::G2::msm_unchecked(&query[1..=pub_len], input_assignment),
|| T::msm_public_points_g2(&query[1 + pub_len..], aux_assignment),
|| T::msm_public_points(&query[1 + pub_len..], aux_assignment),
);

let mut res = initial;
T::add_assign_points_public_g2(id, &mut res, &query[0].into_group());
T::add_assign_points_public_g2(id, &mut res, &vk_param.into_group());
T::add_assign_points_public_g2(id, &mut res, &pub_acc);
T::add_assign_points_g2(&mut res, &priv_acc);
T::add_assign_points_public(id, &mut res, &query[0].into_group());
T::add_assign_points_public(id, &mut res, &vk_param.into_group());
T::add_assign_points_public(id, &mut res, &pub_acc);
T::add_assign_points(&mut res, &priv_acc);
res
}

Expand Down Expand Up @@ -380,7 +380,7 @@ where
let compute_a =
tracing::debug_span!("compute A in create proof with assignment").entered();
// Compute A
let r_g1 = T::scalar_mul_public_point_g1(&delta_g1, r);
let r_g1 = T::scalar_mul_public_point(&delta_g1, r);
let r_g1 = Self::calculate_coeff_g1(
party_id,
r_g1,
Expand All @@ -398,7 +398,7 @@ where
tracing::debug_span!("compute B/G1 in create proof with assignment").entered();
// Compute B in G1
// In original implementation this is skipped if r==0, however r is shared in our case
let s_g1 = T::scalar_mul_public_point_g1(&delta_g1, s);
let s_g1 = T::scalar_mul_public_point(&delta_g1, s);
let s_g1 = Self::calculate_coeff_g1(
party_id,
s_g1,
Expand All @@ -415,7 +415,7 @@ where
let compute_b =
tracing::debug_span!("compute B/G2 in create proof with assignment").entered();
// Compute B in G2
let s_g2 = T::scalar_mul_public_point_g2(&delta_g2, s);
let s_g2 = T::scalar_mul_public_point(&delta_g2, s);
let s_g2 = Self::calculate_coeff_g2(
party_id,
s_g2,
Expand All @@ -430,7 +430,7 @@ where

rayon::spawn(move || {
let msm_l_query = tracing::debug_span!("msm l_query").entered();
let result = T::msm_public_points_g1(l_query.as_ref(), &aux_assignment4);
let result = T::msm_public_points(l_query.as_ref(), &aux_assignment4);
l_acc_tx.send(result).expect("channel not dropped");
msm_l_query.exit();
});
Expand All @@ -442,7 +442,7 @@ where
});

let rs = self.driver.mul(r, s)?;
let r_s_delta_g1 = T::scalar_mul_public_point_g1(&delta_g1, rs);
let r_s_delta_g1 = T::scalar_mul_public_point(&delta_g1, rs);

let l_aux_acc = l_acc_rx.blocking_recv().expect("channel not dropped");

Expand All @@ -456,15 +456,15 @@ where
network_round.exit();

let last_round = tracing::debug_span!("finish open two points and some adds").entered();
let s_g_a = T::scalar_mul_public_point_g1(&g_a_opened, s);
let s_g_a = T::scalar_mul_public_point(&g_a_opened, s);

let mut g_c = s_g_a;
T::add_assign_points_g1(&mut g_c, &r_g1_b);
T::sub_assign_points_g1(&mut g_c, &r_s_delta_g1);
T::add_assign_points_g1(&mut g_c, &l_aux_acc);
T::add_assign_points(&mut g_c, &r_g1_b);
T::sub_assign_points(&mut g_c, &r_s_delta_g1);
T::add_assign_points(&mut g_c, &l_aux_acc);

let h_acc = h_acc_rx.blocking_recv()?;
let g_c = T::add_points_g1_half_share(g_c, &h_acc);
let g_c = T::add_points_half_share(g_c, &h_acc);

let g2_b = s_g2_rx.blocking_recv()?;
let (g_c_opened, g2_b_opened) = self.driver.open_two_points(g_c, g2_b)?;
Expand Down
74 changes: 36 additions & 38 deletions co-circom/co-groth16/src/mpc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use core::fmt;
use std::fmt::Debug;

use ark_ec::pairing::Pairing;
use ark_ec::{pairing::Pairing, CurveGroup};
use ark_poly::domain::DomainCoeff;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};

Expand All @@ -28,9 +28,9 @@ pub trait CircomGroth16Prover<P: Pairing>: Send + Sized {
+ DomainCoeff<P::ScalarField>
+ 'static;
/// The G1 point share type
type PointShareG1: Debug + Send + 'static;
/// The G2 point share type
type PointShareG2: Debug + Send + 'static;
type PointShare<C>: Debug + Send + 'static
where
C: CurveGroup;
/// The party id type
type PartyID: Send + Sync + Copy + fmt::Display + 'static;

Expand Down Expand Up @@ -78,64 +78,62 @@ pub trait CircomGroth16Prover<P: Pairing>: Send + Sized {
roots: &[P::ScalarField],
);

/// Perform msm between G1 `points` and `scalars`
fn msm_public_points_g1(
points: &[P::G1Affine],
/// Perform msm between `points` and `scalars`
fn msm_public_points<C>(
points: &[C::Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG1;

/// Perform msm between G2 `points` and `scalars`
fn msm_public_points_g2(
points: &[P::G2Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG2;
) -> Self::PointShare<C>
where
C: CurveGroup<ScalarField = P::ScalarField>;

/// Multiplies a public point B to the shared point A in place: \[A\] *= B
fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1;
fn scalar_mul_public_point<C>(a: &C, b: Self::ArithmeticShare) -> Self::PointShare<C>
where
C: CurveGroup<ScalarField = P::ScalarField>;

/// Add a shared point B in place to the shared point A: \[A\] += \[B\]
fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1);
fn add_assign_points<C: CurveGroup>(a: &mut Self::PointShare<C>, b: &Self::PointShare<C>);

/// Subtract a shared point B in place from the shared point A: \[A\] -= \[B\]
fn sub_assign_points<C: CurveGroup>(a: &mut Self::PointShare<C>, b: &Self::PointShare<C>);

/// Add a shared point B in place to the shared point A: \[A\] += \[B\]
fn add_points_g1_half_share(a: Self::PointShareG1, b: &P::G1) -> P::G1;
fn add_points_half_share<C: CurveGroup>(a: Self::PointShare<C>, b: &C) -> C;

/// Add a public point B in place to the shared point A
fn add_assign_points_public_g1(id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1);
fn add_assign_points_public<C: CurveGroup>(
id: Self::PartyID,
a: &mut Self::PointShare<C>,
b: &C,
);

/// Reconstructs a shared point: A = Open(\[A\]).
fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1>;
fn open_point<C>(&mut self, a: &Self::PointShare<C>) -> IoResult<C>
where
C: CurveGroup<ScalarField = P::ScalarField>;

/// Multiplies a share b to the shared point A: \[A\] *= \[b\]. Requires network communication.
fn scalar_mul_g1(
fn scalar_mul<C>(
&mut self,
a: &Self::PointShareG1,
a: &Self::PointShare<C>,
b: Self::ArithmeticShare,
) -> IoResult<Self::PointShareG1>;

/// Subtract a shared point B in place from the shared point A: \[A\] -= \[B\]
fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1);

/// Perform scalar multiplication of point A with a shared scalar b
fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2;

/// Add a shared point B in place to the shared point A: \[A\] += \[B\]
fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2);

/// Add a public point B in place to the shared point A
fn add_assign_points_public_g2(id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2);
) -> IoResult<Self::PointShare<C>>
where
C: CurveGroup<ScalarField = P::ScalarField>;

/// Reconstructs a shared points: A = Open(\[A\]), B = Open(\[B\]).
fn open_two_points(
&mut self,
a: P::G1,
b: Self::PointShareG2,
b: Self::PointShare<P::G2>,
) -> std::io::Result<(P::G1, P::G2)>;

/// Reconstruct point G_a and perform scalar multiplication of G1_b and r concurrently
#[allow(clippy::type_complexity)]
fn open_point_and_scalar_mul(
&mut self,
g_a: &Self::PointShareG1,
g1_b: &Self::PointShareG1,
g_a: &Self::PointShare<P::G1>,
g1_b: &Self::PointShare<P::G1>,
r: Self::ArithmeticShare,
) -> std::io::Result<(P::G1, Self::PointShareG1)>;
) -> std::io::Result<(P::G1, Self::PointShare<P::G1>)>;
}
77 changes: 36 additions & 41 deletions co-circom/co-groth16/src/mpc/plain.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ark_ec::pairing::Pairing;
use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
use ark_ec::CurveGroup;
use ark_ff::UniformRand;
use rand::thread_rng;

Expand All @@ -13,9 +13,7 @@ pub struct PlainGroth16Driver;
impl<P: Pairing> CircomGroth16Prover<P> for PlainGroth16Driver {
type ArithmeticShare = P::ScalarField;

type PointShareG1 = P::G1;

type PointShareG2 = P::G2;
type PointShare<C> = C where C: CurveGroup;

type PartyID = usize;

Expand Down Expand Up @@ -79,78 +77,75 @@ impl<P: Pairing> CircomGroth16Prover<P> for PlainGroth16Driver {
}
}

fn msm_public_points_g1(
points: &[P::G1Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG1 {
P::G1::msm_unchecked(points, scalars)
}

fn msm_public_points_g2(
points: &[P::G2Affine],
fn msm_public_points<C>(
points: &[C::Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG2 {
P::G2::msm_unchecked(points, scalars)
) -> Self::PointShare<C>
where
C: CurveGroup<ScalarField = P::ScalarField>,
{
C::msm_unchecked(points, scalars)
}

fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 {
fn scalar_mul_public_point<C>(a: &C, b: Self::ArithmeticShare) -> Self::PointShare<C>
where
C: CurveGroup<ScalarField = P::ScalarField>,
{
*a * b
}

fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
fn add_assign_points<C: CurveGroup>(a: &mut Self::PointShare<C>, b: &Self::PointShare<C>) {
*a += b;
}

fn add_points_g1_half_share(a: Self::PointShareG1, b: &P::G1) -> P::G1 {
fn add_points_half_share<C: CurveGroup>(a: Self::PointShare<C>, b: &C) -> C {
a + b
}

fn add_assign_points_public_g1(_: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) {
fn add_assign_points_public<C: CurveGroup>(
_: Self::PartyID,
a: &mut Self::PointShare<C>,
b: &C,
) {
*a += b;
}

fn open_point_g1(&mut self, a: &Self::PointShareG1) -> super::IoResult<P::G1> {
fn open_point<C>(&mut self, a: &Self::PointShare<C>) -> super::IoResult<C>
where
C: CurveGroup<ScalarField = P::ScalarField>,
{
Ok(*a)
}

fn scalar_mul_g1(
fn scalar_mul<C>(
&mut self,
a: &Self::PointShareG1,
a: &Self::PointShare<C>,
b: Self::ArithmeticShare,
) -> super::IoResult<Self::PointShareG1> {
) -> super::IoResult<Self::PointShare<C>>
where
C: CurveGroup<ScalarField = P::ScalarField>,
{
Ok(*a * b)
}

fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
fn sub_assign_points<C: CurveGroup>(a: &mut Self::PointShare<C>, b: &Self::PointShare<C>) {
*a -= b;
}

fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 {
*a * b
}

fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) {
*a += b;
}

fn add_assign_points_public_g2(_: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) {
*a += b;
}

fn open_two_points(
&mut self,
a: Self::PointShareG1,
b: Self::PointShareG2,
a: Self::PointShare<P::G1>,
b: Self::PointShare<P::G2>,
) -> std::io::Result<(P::G1, P::G2)> {
Ok((a, b))
}

fn open_point_and_scalar_mul(
&mut self,
g_a: &Self::PointShareG1,
g1_b: &Self::PointShareG1,
g_a: &Self::PointShare<P::G1>,
g1_b: &Self::PointShare<P::G1>,
r: Self::ArithmeticShare,
) -> super::IoResult<(P::G1, Self::PointShareG1)> {
) -> super::IoResult<(P::G1, Self::PointShare<P::G1>)> {
Ok((*g_a, *g1_b * r))
}
}
Loading

0 comments on commit dc5acd2

Please sign in to comment.