Skip to content

Commit

Permalink
feat!: added run_and_get_network to CircomRep3VmWitnessExtension, cha…
Browse files Browse the repository at this point in the history
…nged run and run_with_flat back to consume self

BREAKING CHANGE: run and run_with_flat methods on WitnessExtension now consume self again
  • Loading branch information
fabian1409 authored and 0xThemis committed Oct 30, 2024
1 parent 8c02810 commit b362504
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 5 deletions.
2 changes: 1 addition & 1 deletion co-circom/circom-mpc-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ mod tests {
)
.unwrap();

let mut plain_vm = parsed.to_plain_vm(VMConfig::default());
let plain_vm = parsed.to_plain_vm(VMConfig::default());
let finalized_witness = plain_vm
.run_with_flat(
to_field_vec!(vec![
Expand Down
4 changes: 4 additions & 0 deletions co-circom/circom-mpc-vm/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl<F: PrimeField, N: Rep3Network> CircomRep3VmWitnessExtension<F, N> {
})
}

pub fn get_network(self) -> N {
self.io_context0.network
}

/// Normally F is split into positive and negative numbers in the range [0, p/2] and [p/2 + 1, p)
/// However, for comparisons, we want the negative numbers to be "lower" than the positive ones.
/// Therefore we shift the input by p/2 + 1 to the left, which results in a mapping of [negative, 0, positive] into F.
Expand Down
37 changes: 35 additions & 2 deletions co-circom/circom-mpc-vm/src/mpc_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use eyre::{bail, eyre, Result};
use itertools::{izip, Itertools};
use mpc_core::protocols::rep3::conversion::A2BType;
use mpc_core::protocols::rep3::network::{Rep3MpcNet, Rep3Network};
use mpc_core::protocols::rep3::Rep3PrimeFieldShare;
use mpc_net::config::NetworkConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand Down Expand Up @@ -897,7 +898,7 @@ impl<F: PrimeField, C: VmCircomWitnessExtension<F>> WitnessExtension<F, C> {
///
/// Panics if any of the [`CodeBlocks`](CodeBlock) are corrupted.
pub fn run(
&mut self,
mut self,
input_signals: SharedInput<F, C::ArithmeticShare>,
) -> Result<FinalizedWitnessExtension<F, C>> {
self.driver.compare_vm_config(&self.config)?;
Expand Down Expand Up @@ -928,7 +929,7 @@ impl<F: PrimeField, C: VmCircomWitnessExtension<F>> WitnessExtension<F, C> {
///
/// Panics if any of the [`CodeBlocks`](CodeBlock) are corrupted.
pub fn run_with_flat(
&mut self,
mut self,
input_signals: Vec<C::VmType>,
amount_public_inputs: usize,
) -> Result<FinalizedWitnessExtension<F, C>> {
Expand Down Expand Up @@ -1064,4 +1065,36 @@ impl<F: PrimeField> Rep3WitnessExtension<F, Rep3MpcNet> {
let network = Rep3MpcNet::new(network_config)?;
Self::from_network(parser, network, mpc_accelerator, config)
}

/// Starts the execution of the MPC-VM with the provided [SharedInput], consumes `self` and returns the [`Rep3MpcNet`].
///
/// Use this method over [`run_with_flat()`](WitnessExtension::run) when ever possible.
/// # Arguments
///
/// * `input_signals` - The [SharedInput] distributed over the parties.
///
/// # Returns
///
/// * `Ok(([SharedWitness], Rep3MpcNet))` - The secret-shared witness, distributed over the parties.
/// * `Err([eyre::Result])` - An error result.
///
/// # Panics
///
/// Panics if any of the [`CodeBlocks`](CodeBlock) are corrupted.
#[allow(clippy::type_complexity)]
pub fn run_and_get_network(
mut self,
input_signals: SharedInput<F, Rep3PrimeFieldShare<F>>,
) -> Result<(
FinalizedWitnessExtension<F, CircomRep3VmWitnessExtension<F, Rep3MpcNet>>,
Rep3MpcNet,
)> {
self.driver.compare_vm_config(&self.config)?;
let amount_public_inputs = self.set_input_signals(input_signals)?;
self.call_main_component()?;
Ok((
self.post_processing(amount_public_inputs)?,
self.driver.get_network(),
))
}
}
2 changes: 1 addition & 1 deletion co-circom/co-circom/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ where
let id = usize::from(net.get_id());

// init MPC protocol
let mut rep3_vm = parsed_circom_circuit
let rep3_vm = parsed_circom_circuit
.to_rep3_vm_with_network(net, config.vm)
.context("while constructing MPC VM")?;

Expand Down
2 changes: 1 addition & 1 deletion tests/tests/circom/witness_extension_tests/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ macro_rules! run_test {
compiler_config
.link_library
.push("../test_vectors/WitnessExtension/tests/libs/".into());
let mut witness_extension =
let witness_extension =
CoCircomCompiler::<Bn254>::parse($file.to_owned(), compiler_config)
.unwrap()
.to_rep3_vm_with_network(net, VMConfig::default())
Expand Down

0 comments on commit b362504

Please sign in to comment.