Skip to content

Commit

Permalink
refactor!: removed acvm_impl. now uses old driver for ACVM
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis authored and dkales committed Sep 2, 2024
1 parent 9d1f37a commit d37c5bb
Show file tree
Hide file tree
Showing 13 changed files with 350 additions and 577 deletions.
2 changes: 2 additions & 0 deletions co-noir/co-acvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ rust-version.workspace = true
[dependencies]
acir.workspace = true
acvm.workspace = true
ark-bn254.workspace = true
ark-ff.workspace = true
eyre.workspace = true
intmap.workspace = true
mpc-core.workspace = true
Expand Down
80 changes: 50 additions & 30 deletions co-noir/co-acvm/src/solver.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use acir::{
acir_field::GenericFieldElement,
circuit::{Circuit, ExpressionWidth, Opcode},
native_types::{WitnessMap, WitnessStack},
AcirField, FieldElement,
};
use ark_ff::PrimeField;
use intmap::IntMap;
use mpc_core::{
protocols::{
Expand All @@ -13,9 +14,7 @@ use mpc_core::{
};
use noirc_abi::{input_parser::Format, Abi, MAIN_RETURN_NAME};
use noirc_artifacts::program::ProgramArtifact;
use num_bigint::BigUint;
use num_traits::One;
use std::{io, path::PathBuf};
use std::{collections::BTreeMap, io, path::PathBuf};
/// The default expression width defined used by the ACVM.
pub(crate) const CO_EXPRESSION_WIDTH: ExpressionWidth = ExpressionWidth::Bounded { width: 4 };

Expand All @@ -37,11 +36,11 @@ pub enum CoAcvmError {
pub struct CoSolver<T, F>
where
T: NoirWitnessExtensionProtocol<F>,
F: AcirField,
F: PrimeField,
{
driver: T,
abi: Abi,
functions: Vec<Circuit<F>>,
functions: Vec<Circuit<GenericFieldElement<F>>>,
// maybe this can be an array. lets see..
witness_map: Vec<WitnessMap<T::AcvmType>>,
// there will a more fields added as we add functionality
Expand All @@ -50,11 +49,11 @@ where
memory_access: IntMap<T::LUT>,
}

impl<T> CoSolver<T, FieldElement>
impl<T> CoSolver<T, ark_bn254::Fr>
where
T: NoirWitnessExtensionProtocol<FieldElement>,
T: NoirWitnessExtensionProtocol<ark_bn254::Fr>,
{
pub fn read_abi<P>(path: P, abi: &Abi) -> eyre::Result<WitnessMap<T::AcvmType>>
pub fn read_abi_bn254<P>(path: P, abi: &Abi) -> eyre::Result<WitnessMap<T::AcvmType>>
where
PathBuf: From<P>,
{
Expand All @@ -70,13 +69,14 @@ where
let initial_witness = abi.encode(&input_map, return_value.clone())?;
let mut witnesses = WitnessMap::<T::AcvmType>::default();
for (witness, v) in initial_witness.into_iter() {
witnesses.insert(witness, T::AcvmType::from(v));
witnesses.insert(witness, T::AcvmType::from(v.into_repr())); //TODO this can be
//private for some
}
Ok(witnesses)
}
}

pub fn new<P>(
pub fn new_bn254<P>(
driver: T,
compiled_program: ProgramArtifact,
prover_path: P,
Expand All @@ -86,7 +86,7 @@ where
{
let mut witness_map =
vec![WitnessMap::default(); compiled_program.bytecode.functions.len()];
witness_map[0] = Self::read_abi(prover_path, &compiled_program.abi)?;
witness_map[0] = Self::read_abi_bn254(prover_path, &compiled_program.abi)?;
Ok(Self {
driver,
abi: compiled_program.abi,
Expand All @@ -104,7 +104,7 @@ where
}
}

impl<N: Rep3Network> Rep3CoSolver<FieldElement, N> {
impl<N: Rep3Network> Rep3CoSolver<ark_bn254::Fr, N> {
pub fn from_network<P>(
network: N,
compiled_program: ProgramArtifact,
Expand All @@ -113,34 +113,52 @@ impl<N: Rep3Network> Rep3CoSolver<FieldElement, N> {
where
PathBuf: From<P>,
{
Self::new(Rep3Protocol::new(network)?, compiled_program, prover_path)
Self::new_bn254(Rep3Protocol::new(network)?, compiled_program, prover_path)
}
}

impl PlainCoSolver<FieldElement> {
impl<F: PrimeField> PlainCoSolver<F> {
pub fn convert_to_plain_acvm_witness(
mut shared_witness: WitnessStack<F>,
) -> WitnessStack<GenericFieldElement<F>> {
let length = shared_witness.length();
let mut vec = Vec::with_capacity(length);
for _ in 0..length {
let stack_item = shared_witness.pop().unwrap();
vec.push((
stack_item.index,
stack_item
.witness
.into_iter()
.map(|(k, v)| (k, GenericFieldElement::from_repr(v)))
.collect::<BTreeMap<_, _>>(),
))
}
let mut witness = WitnessStack::default();
//push again in reverse order
for (index, witness_map) in vec.into_iter().rev() {
witness.push(index, WitnessMap::from(witness_map));
}
witness
}
}

impl PlainCoSolver<ark_bn254::Fr> {
pub fn init_plain_driver<P>(
compiled_program: ProgramArtifact,
prover_path: P,
) -> eyre::Result<Self>
where
PathBuf: From<P>,
{
let modulus = FieldElement::modulus();
let one = BigUint::one();
let two = BigUint::from(2u64);
// FIXME find something better
let negative_one = FieldElement::from_be_bytes_reduce(&(modulus / two + one).to_bytes_be());
Self::new(
PlainDriver::new(negative_one),
compiled_program,
prover_path,
)
Self::new_bn254(PlainDriver::default(), compiled_program, prover_path)
}

pub fn solve_and_print_output(self) {
let abi = self.abi.clone();
let result = self.solve().unwrap();
let main_witness = result.peek().unwrap();
let mut result = Self::convert_to_plain_acvm_witness(result);
let main_witness = result.pop().unwrap();
let (_, ret_val) = abi.decode(&main_witness.witness).unwrap();
if let Some(ret_val) = ret_val {
println!("circuit produced: {ret_val:?}");
Expand All @@ -153,17 +171,18 @@ impl PlainCoSolver<FieldElement> {
impl<T, F> CoSolver<T, F>
where
T: NoirWitnessExtensionProtocol<F>,
F: AcirField,
F: PrimeField,
{
#[inline(always)]
fn witness(&mut self) -> &mut WitnessMap<T::AcvmType> {
&mut self.witness_map[self.function_index]
}
}

impl<T> CoSolver<T, FieldElement>
impl<T, F> CoSolver<T, F>
where
T: NoirWitnessExtensionProtocol<FieldElement>,
T: NoirWitnessExtensionProtocol<F>,
F: PrimeField,
{
pub fn solve(mut self) -> CoAcvmResult<WitnessStack<T::AcvmType>> {
let functions = std::mem::take(&mut self.functions);
Expand All @@ -190,7 +209,8 @@ where
//} => todo!(),
}
}
// TODO this is most likely not correct.
// TODO this is most likely not correct. We just reverse the order of the Vec. Maybe this
// is fine?
// We'll see what happens here.
let mut witness_stack = WitnessStack::default();
for (idx, witness) in self.witness_map.into_iter().rev().enumerate() {
Expand Down
57 changes: 43 additions & 14 deletions co-noir/co-acvm/src/solver/assert_zero_solver.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use acir::{native_types::Expression, AcirField};
use acir::{acir_field::GenericFieldElement, native_types::Expression, AcirField};
use ark_ff::PrimeField;
use mpc_core::traits::NoirWitnessExtensionProtocol;

use super::{CoAcvmResult, CoSolver};

impl<T, F> CoSolver<T, F>
where
T: NoirWitnessExtensionProtocol<F>,
F: AcirField,
F: PrimeField,
{
fn evaluate_mul_terms(
&mut self,
expr: &Expression<F>,
expr: &Expression<GenericFieldElement<F>>,
acc: &mut Expression<T::AcvmType>,
) -> CoAcvmResult<()> {
tracing::trace!("evaluating mul terms for simplification");
Expand All @@ -32,16 +33,17 @@ where
// should solve this without batching
(Some(lhs), Some(rhs)) => {
tracing::trace!("solving mul term...");
self.driver.solve_mul_term(*c, lhs, rhs, &mut acc.q_c)?;
self.driver
.solve_mul_term(c.into_repr(), lhs, rhs, &mut acc.q_c)?;
}
(Some(lhs), None) => {
tracing::trace!("partially solving mul term...");
let partly_solved = self.driver.acvm_mul_with_public(*c, lhs)?;
let partly_solved = self.driver.acvm_mul_with_public(c.into_repr(), lhs)?;
acc.linear_combinations.push((partly_solved, *rhs));
}
(None, Some(rhs)) => {
tracing::trace!("partially solving mul term...");
let partly_solved = self.driver.acvm_mul_with_public(*c, rhs)?;
let partly_solved = self.driver.acvm_mul_with_public(c.into_repr(), rhs)?;
acc.linear_combinations.push((partly_solved, *lhs));
}
(None, None) => Err(eyre::eyre!(
Expand All @@ -57,24 +59,29 @@ where
}
}

fn evaluate_linear_terms(&mut self, expr: &Expression<F>, acc: &mut Expression<T::AcvmType>) {
fn evaluate_linear_terms(
&mut self,
expr: &Expression<GenericFieldElement<F>>,
acc: &mut Expression<T::AcvmType>,
) {
for term in expr.linear_combinations.iter() {
let (q_l, w_l) = term;
tracing::trace!("looking at linear term: {q_l} * _{}..", w_l.0);
if let Some(w_l) = self.witness().get(w_l).cloned() {
tracing::trace!("is known! reduce it");
self.driver.solve_linear_term(*q_l, w_l, &mut acc.q_c);
self.driver
.solve_linear_term(q_l.into_repr(), w_l, &mut acc.q_c);
} else {
tracing::trace!("is unknown!");
acc.linear_combinations
.push((T::AcvmType::from(*q_l), *w_l))
.push((T::AcvmType::from(q_l.into_repr()), *w_l))
}
}
}

pub(crate) fn simplify_expression(
&mut self,
expr: &Expression<F>,
expr: &Expression<GenericFieldElement<F>>,
) -> CoAcvmResult<Expression<T::AcvmType>> {
tracing::trace!("simplifying expression...");
// default implementation not exposed if we not have AcirField trait bound
Expand All @@ -89,14 +96,17 @@ where
self.evaluate_linear_terms(expr, &mut simplified);
// add constants
self.driver
.acvm_add_assign_with_public(expr.q_c, &mut simplified.q_c);
.acvm_add_assign_with_public(expr.q_c.into_repr(), &mut simplified.q_c);

Ok(simplified)
}

pub(super) fn solve_assert_zero(&mut self, expr: &Expression<F>) -> CoAcvmResult<()> {
pub(super) fn solve_assert_zero(
&mut self,
expr: &Expression<GenericFieldElement<F>>,
) -> CoAcvmResult<()> {
//first evaluate the already existing terms
tracing::trace!("solving assert zero: {}", expr);
tracing::trace!("solving assert zero: {}", Self::expr_to_string(expr));
let simplified = self.simplify_expression(expr)?;
tracing::trace!("simplified expr: {:?}", simplified);
// if we are here, we do not have any mul terms
Expand Down Expand Up @@ -124,7 +134,7 @@ where

pub(crate) fn evaluate_expression(
&mut self,
expr: &Expression<F>,
expr: &Expression<GenericFieldElement<F>>,
) -> CoAcvmResult<T::AcvmType> {
Ok(self
.simplify_expression(expr)?
Expand All @@ -134,4 +144,23 @@ where
"cannot evaluate expression to const - has unknown"
))?)
}

pub(crate) fn expr_to_string(expr: &Expression<GenericFieldElement<F>>) -> String {
let mut str = String::with_capacity(128);
str.push_str("EXPR [");
let mul_terms = expr
.mul_terms
.iter()
.map(|(q_m, w_l, w_r)| format!("({q_m} * _{w_l:?} * _{w_r:?})"))
.collect::<Vec<String>>()
.join(" + ");
let linear_terms = expr
.linear_combinations
.iter()
.map(|(coef, w)| format!("({coef} * _{w:?})"))
.collect::<Vec<String>>()
.join(" + ");
str.push_str(&format!("({mul_terms}) + ({linear_terms}) + {}", expr.q_c));
str
}
}
20 changes: 11 additions & 9 deletions co-noir/co-acvm/src/solver/memory_solver.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use acir::{
acir_field::GenericFieldElement,
circuit::opcodes::{BlockId, MemOp},
native_types::{Expression, Witness},
AcirField,
};
use ark_ff::PrimeField;
use mpc_core::traits::NoirWitnessExtensionProtocol;

use super::{CoAcvmResult, CoSolver};

impl<T, F> CoSolver<T, F>
where
T: NoirWitnessExtensionProtocol<F>,
F: AcirField,
F: PrimeField,
{
pub(super) fn solve_memory_init_block(
&mut self,
Expand All @@ -37,16 +38,16 @@ where
.ok_or(eyre::eyre!(
"tried to write not initialized witness to memory - this is a bug"
))?;
let lut = self.driver.init_lut(init);
let lut = self.driver.init_lut_by_acvm_type(init);
self.memory_access.insert(block_id.0.into(), lut);
Ok(())
}

pub(super) fn solve_memory_op(
&mut self,
block_id: BlockId,
op: &MemOp<F>,
_predicate: Option<Expression<F>>,
op: &MemOp<GenericFieldElement<F>>,
_predicate: Option<Expression<GenericFieldElement<F>>>,
) -> CoAcvmResult<()> {
tracing::trace!("solving memory op {:?}", op);
let index = self.evaluate_expression(&op.index)?;
Expand All @@ -68,8 +69,9 @@ where
"value for mem op must be a degree one univariate polynomial"
))
}?;
let read_write = op.operation.q_c.into_repr();
//TODO CHECK PREDICATE - do we need to cmux here?
if op.operation.q_c.is_zero() {
if read_write.is_zero() {
// read the value from the LUT
tracing::trace!("reading value from LUT");
let lut = self
Expand All @@ -79,9 +81,9 @@ where
"tried to access block {} but not present",
block_id.0
))?;
let value = self.driver.get_from_lut(&index, lut);
let value = self.driver.read_lut_by_acvm_type(&index, lut);
self.witness().insert(witness, value);
} else if op.operation.q_c.is_one() {
} else if read_write.is_one() {
// write value to LUT
tracing::trace!("writing value to LUT");
let value = self
Expand All @@ -96,7 +98,7 @@ where
"tried to access block {} but not present",
block_id.0
))?;
self.driver.write_to_lut(index, value, lut);
self.driver.write_lut_by_acvm_type(index, value, lut);
} else {
Err(eyre::eyre!(
"Got unknown operation {} for mem op - this is a bug",
Expand Down
Loading

0 comments on commit d37c5bb

Please sign in to comment.