From 3287aa29c1e85dd89d5ae9f73bffa9406cc47d08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Wed, 10 Apr 2024 14:15:51 +0200 Subject: [PATCH 1/2] feat: Brillig heterogeneous memory cells (#5608) Since memory in brillig is typed, we can store data in a different way depending on the bit size. This PR makes brillig store memory cells as field only if they are field elements and as BigUints otherwise, avoiding conversions when doing integer operations to improve performance. --------- Co-authored-by: TomAFrench Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> Co-authored-by: Maxim Vezenov --- .../acvm-repo/acvm/src/pwg/brillig.rs | 6 +- .../acvm-repo/brillig_vm/src/arithmetic.rs | 68 ++--- .../acvm-repo/brillig_vm/src/black_box.rs | 14 +- .../noir-repo/acvm-repo/brillig_vm/src/lib.rs | 71 +++--- .../acvm-repo/brillig_vm/src/memory.rs | 234 ++++++++++++++---- .../brillig/brillig_gen/brillig_directive.rs | 152 ++++++++---- .../brillig/brillig_gen/brillig_slice_ops.rs | 8 +- .../src/brillig/brillig_ir/entry_point.rs | 4 +- .../src/ssa/acir_gen/acir_ir/acir_variable.rs | 4 +- .../noir-repo/tooling/debugger/src/context.rs | 6 +- noir/noir-repo/tooling/debugger/src/repl.rs | 2 +- 11 files changed, 388 insertions(+), 181 deletions(-) diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs index bcf736cd926..81e752d5656 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs @@ -206,13 +206,13 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { for output in brillig.outputs.iter() { match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, memory[current_ret_data_idx].value, witness_map)?; + insert_value(witness, memory[current_ret_data_idx].to_field(), witness_map)?; current_ret_data_idx += 1; } BrilligOutputs::Array(witness_arr) => { for witness in witness_arr.iter() { - let value = memory[current_ret_data_idx]; - insert_value(witness, value.value, witness_map)?; + let value = &memory[current_ret_data_idx]; + insert_value(witness, value.to_field(), witness_map)?; current_ret_data_idx += 1; } } diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs index 3d77982ffb1..2107d10c093 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs @@ -1,9 +1,10 @@ use acir::brillig::{BinaryFieldOp, BinaryIntOp}; use acir::FieldElement; use num_bigint::BigUint; -use num_traits::{One, ToPrimitive, Zero}; +use num_traits::ToPrimitive; +use num_traits::{One, Zero}; -use crate::memory::MemoryValue; +use crate::memory::{MemoryTypeError, MemoryValue}; #[derive(Debug, thiserror::Error)] pub(crate) enum BrilligArithmeticError { @@ -11,6 +12,8 @@ pub(crate) enum BrilligArithmeticError { MismatchedLhsBitSize { lhs_bit_size: u32, op_bit_size: u32 }, #[error("Bit size for rhs {rhs_bit_size} does not match op bit size {op_bit_size}")] MismatchedRhsBitSize { rhs_bit_size: u32, op_bit_size: u32 }, + #[error("Integer operation BinaryIntOp::{op:?} is not supported on FieldElement")] + IntegerOperationOnField { op: BinaryIntOp }, #[error("Shift with bit size {op_bit_size} is invalid")] InvalidShift { op_bit_size: u32 }, } @@ -21,21 +24,19 @@ pub(crate) fn evaluate_binary_field_op( lhs: MemoryValue, rhs: MemoryValue, ) -> Result { - if lhs.bit_size != FieldElement::max_num_bits() { + let MemoryValue::Field(a) = lhs else { return Err(BrilligArithmeticError::MismatchedLhsBitSize { - lhs_bit_size: lhs.bit_size, + lhs_bit_size: lhs.bit_size(), op_bit_size: FieldElement::max_num_bits(), }); - } - if rhs.bit_size != FieldElement::max_num_bits() { - return Err(BrilligArithmeticError::MismatchedRhsBitSize { - rhs_bit_size: rhs.bit_size, + }; + let MemoryValue::Field(b) = rhs else { + return Err(BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: rhs.bit_size(), op_bit_size: FieldElement::max_num_bits(), }); - } + }; - let a = lhs.value; - let b = rhs.value; Ok(match op { // Perform addition, subtraction, multiplication, and division based on the BinaryOp variant. BinaryFieldOp::Add => (a + b).into(), @@ -62,21 +63,26 @@ pub(crate) fn evaluate_binary_int_op( rhs: MemoryValue, bit_size: u32, ) -> Result { - if lhs.bit_size != bit_size { - return Err(BrilligArithmeticError::MismatchedLhsBitSize { - lhs_bit_size: lhs.bit_size, - op_bit_size: bit_size, - }); - } - if rhs.bit_size != bit_size { - return Err(BrilligArithmeticError::MismatchedRhsBitSize { - rhs_bit_size: rhs.bit_size, - op_bit_size: bit_size, - }); - } + let lhs = lhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err { + MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => { + BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: value_bit_size, + op_bit_size: expected_bit_size, + } + } + })?; + let rhs = rhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err { + MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => { + BrilligArithmeticError::MismatchedRhsBitSize { + rhs_bit_size: value_bit_size, + op_bit_size: expected_bit_size, + } + } + })?; - let lhs = BigUint::from_bytes_be(&lhs.value.to_be_bytes()); - let rhs = BigUint::from_bytes_be(&rhs.value.to_be_bytes()); + if bit_size == FieldElement::max_num_bits() { + return Err(BrilligArithmeticError::IntegerOperationOnField { op: *op }); + } let bit_modulo = &(BigUint::one() << bit_size); let result = match op { @@ -136,13 +142,11 @@ pub(crate) fn evaluate_binary_int_op( } }; - let result_as_field = FieldElement::from_be_bytes_reduce(&result.to_bytes_be()); - Ok(match op { BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => { - MemoryValue::new(result_as_field, 1) + MemoryValue::new_integer(result, 1) } - _ => MemoryValue::new(result_as_field, bit_size), + _ => MemoryValue::new_integer(result, bit_size), }) } @@ -159,13 +163,13 @@ mod tests { fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: u32) -> u128 { let result_value = evaluate_binary_int_op( op, - MemoryValue::new(a.into(), bit_size), - MemoryValue::new(b.into(), bit_size), + MemoryValue::new_integer(a.into(), bit_size), + MemoryValue::new_integer(b.into(), bit_size), bit_size, ) .unwrap(); // Convert back to u128 - result_value.value.to_u128() + result_value.to_field().to_u128() } fn to_negative(a: u128, bit_size: u32) -> u128 { diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs index bd33b5ee8fc..73981fb0625 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs @@ -20,7 +20,7 @@ fn read_heap_array<'a>(memory: &'a Memory, array: &HeapArray) -> &'a [MemoryValu /// Extracts the last byte of every value fn to_u8_vec(inputs: &[MemoryValue]) -> Vec { let mut result = Vec::with_capacity(inputs.len()); - for &input in inputs { + for input in inputs { result.push(input.try_into().unwrap()); } result @@ -63,7 +63,7 @@ pub(crate) fn evaluate_black_box( BlackBoxOp::Keccakf1600 { message, output } => { let state_vec: Vec = read_heap_vector(memory, message) .iter() - .map(|&memory_value| memory_value.try_into().unwrap()) + .map(|memory_value| memory_value.try_into().unwrap()) .collect(); let state: [u64; 25] = state_vec.try_into().unwrap(); @@ -151,7 +151,7 @@ pub(crate) fn evaluate_black_box( } BlackBoxOp::PedersenCommitment { inputs, domain_separator, output } => { let inputs: Vec = - read_heap_vector(memory, inputs).iter().map(|&x| x.try_into().unwrap()).collect(); + read_heap_vector(memory, inputs).iter().map(|x| x.try_into().unwrap()).collect(); let domain_separator: u32 = memory.read(*domain_separator).try_into().map_err(|_| { BlackBoxResolutionError::Failed( @@ -165,7 +165,7 @@ pub(crate) fn evaluate_black_box( } BlackBoxOp::PedersenHash { inputs, domain_separator, output } => { let inputs: Vec = - read_heap_vector(memory, inputs).iter().map(|&x| x.try_into().unwrap()).collect(); + read_heap_vector(memory, inputs).iter().map(|x| x.try_into().unwrap()).collect(); let domain_separator: u32 = memory.read(*domain_separator).try_into().map_err(|_| { BlackBoxResolutionError::Failed( @@ -185,7 +185,7 @@ pub(crate) fn evaluate_black_box( BlackBoxOp::BigIntToLeBytes { .. } => todo!(), BlackBoxOp::Poseidon2Permutation { message, output, len } => { let input = read_heap_vector(memory, message); - let input: Vec = input.iter().map(|&x| x.try_into().unwrap()).collect(); + let input: Vec = input.iter().map(|x| x.try_into().unwrap()).collect(); let len = memory.read(*len).try_into().unwrap(); let result = solver.poseidon2_permutation(&input, len)?; let mut values = Vec::new(); @@ -204,7 +204,7 @@ pub(crate) fn evaluate_black_box( format!("Expected 16 inputs but encountered {}", &inputs.len()), )); } - for (i, &input) in inputs.iter().enumerate() { + for (i, input) in inputs.iter().enumerate() { message[i] = input.try_into().unwrap(); } let mut state = [0; 8]; @@ -215,7 +215,7 @@ pub(crate) fn evaluate_black_box( format!("Expected 8 values but encountered {}", &values.len()), )); } - for (i, &value) in values.iter().enumerate() { + for (i, value) in values.iter().enumerate() { state[i] = value.try_into().unwrap(); } diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs index 65654e24720..26d5da67576 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs @@ -289,8 +289,8 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { // Convert our source_pointer to an address let source = self.memory.read_ref(*source_pointer); // Use our usize source index to lookup the value in memory - let value = &self.memory.read(source); - self.memory.write(*destination_address, *value); + let value = self.memory.read(source); + self.memory.write(*destination_address, value); self.increment_program_counter() } Opcode::Store { destination_pointer, source: source_address } => { @@ -307,7 +307,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { } Opcode::Const { destination, value, bit_size } => { // Consts are not checked in runtime to fit in the bit size, since they can safely be checked statically. - self.memory.write(*destination, MemoryValue::new(*value, *bit_size)); + self.memory.write(*destination, MemoryValue::new_from_field(*value, *bit_size)); self.increment_program_counter() } Opcode::BlackBox(black_box_op) => { @@ -348,7 +348,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { ) -> ForeignCallParam { match (input, value_type) { (ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(_)) => { - self.memory.read(value_index).value.into() + self.memory.read(value_index).to_field().into() } ( ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }), @@ -357,7 +357,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { let start = self.memory.read_ref(pointer_index); self.read_slice_of_values_from_memory(start, size, value_types) .into_iter() - .map(|mem_value| mem_value.value) + .map(|mem_value| mem_value.to_field()) .collect::>() .into() } @@ -369,7 +369,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { let size = self.memory.read(size_index).to_usize(); self.read_slice_of_values_from_memory(start, size, value_types) .into_iter() - .map(|mem_value| mem_value.value) + .map(|mem_value| mem_value.to_field()) .collect::>() .into() } @@ -584,12 +584,9 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { /// Casts a value to a different bit size. fn cast(&self, bit_size: u32, source_value: MemoryValue) -> MemoryValue { - let lhs_big = BigUint::from_bytes_be(&source_value.value.to_be_bytes()); + let lhs_big = source_value.to_integer(); let mask = BigUint::from(2_u32).pow(bit_size) - 1_u32; - MemoryValue { - value: FieldElement::from_be_bytes_reduce(&(lhs_big & mask).to_bytes_be()), - bit_size, - } + MemoryValue::new_from_integer(lhs_big & mask, bit_size) } } @@ -627,7 +624,7 @@ mod tests { let VM { memory, .. } = vm; let output_value = memory.read(MemoryAddress::from(0)); - assert_eq!(output_value.value, FieldElement::from(27u128)); + assert_eq!(output_value.to_field(), FieldElement::from(27u128)); } #[test] @@ -666,7 +663,7 @@ mod tests { assert_eq!(status, VMStatus::InProgress); let output_cmp_value = vm.memory.read(destination); - assert_eq!(output_cmp_value.value, true.into()); + assert_eq!(output_cmp_value.to_field(), true.into()); let status = vm.process_opcode(); assert_eq!(status, VMStatus::InProgress); @@ -725,7 +722,7 @@ mod tests { assert_eq!(status, VMStatus::InProgress); let output_cmp_value = vm.memory.read(MemoryAddress::from(2)); - assert_eq!(output_cmp_value.value, false.into()); + assert_eq!(output_cmp_value.to_field(), false.into()); let status = vm.process_opcode(); assert_eq!(status, VMStatus::InProgress); @@ -742,7 +739,7 @@ mod tests { // The address at index `2` should have not changed as we jumped over the add opcode let VM { memory, .. } = vm; let output_value = memory.read(MemoryAddress::from(2)); - assert_eq!(output_value.value, false.into()); + assert_eq!(output_value.to_field(), false.into()); } #[test] @@ -776,7 +773,7 @@ mod tests { let VM { memory, .. } = vm; let casted_value = memory.read(MemoryAddress::from(1)); - assert_eq!(casted_value.value, (2_u128.pow(8) - 1).into()); + assert_eq!(casted_value.to_field(), (2_u128.pow(8) - 1).into()); } #[test] @@ -804,10 +801,10 @@ mod tests { let VM { memory, .. } = vm; let destination_value = memory.read(MemoryAddress::from(2)); - assert_eq!(destination_value.value, (1u128).into()); + assert_eq!(destination_value.to_field(), (1u128).into()); let source_value = memory.read(MemoryAddress::from(0)); - assert_eq!(source_value.value, (1u128).into()); + assert_eq!(source_value.to_field(), (1u128).into()); } #[test] @@ -869,10 +866,10 @@ mod tests { let VM { memory, .. } = vm; let destination_value = memory.read(MemoryAddress::from(4)); - assert_eq!(destination_value.value, (3_u128).into()); + assert_eq!(destination_value.to_field(), (3_u128).into()); let source_value = memory.read(MemoryAddress::from(5)); - assert_eq!(source_value.value, (2_u128).into()); + assert_eq!(source_value.to_field(), (2_u128).into()); } #[test] @@ -1120,7 +1117,7 @@ mod tests { let opcodes = [&start[..], &loop_body[..]].concat(); let vm = brillig_execute_and_get_vm(memory, &opcodes); - vm.memory.read(r_sum).value + vm.memory.read(r_sum).to_field() } assert_eq!( @@ -1359,7 +1356,7 @@ mod tests { // Check result in memory let result_values = vm.memory.read_slice(MemoryAddress(2), 4).to_vec(); assert_eq!( - result_values.into_iter().map(|mem_value| mem_value.value).collect::>(), + result_values.into_iter().map(|mem_value| mem_value.to_field()).collect::>(), expected_result ); @@ -1459,7 +1456,7 @@ mod tests { .memory .read_slice(MemoryAddress(4 + input_string.len()), output_string.len()) .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.clone().to_field()) .collect(); assert_eq!(result_values, output_string); @@ -1532,13 +1529,21 @@ mod tests { assert_eq!(vm.status, VMStatus::Finished { return_data_offset: 0, return_data_size: 0 }); // Check initial memory still in place - let initial_values: Vec<_> = - vm.memory.read_slice(MemoryAddress(2), 4).iter().map(|mem_val| mem_val.value).collect(); + let initial_values: Vec<_> = vm + .memory + .read_slice(MemoryAddress(2), 4) + .iter() + .map(|mem_val| mem_val.clone().to_field()) + .collect(); assert_eq!(initial_values, initial_matrix); // Check result in memory - let result_values: Vec<_> = - vm.memory.read_slice(MemoryAddress(6), 4).iter().map(|mem_val| mem_val.value).collect(); + let result_values: Vec<_> = vm + .memory + .read_slice(MemoryAddress(6), 4) + .iter() + .map(|mem_val| mem_val.clone().to_field()) + .collect(); assert_eq!(result_values, expected_result); // Ensure the foreign call counter has been incremented @@ -1622,8 +1627,12 @@ mod tests { assert_eq!(vm.status, VMStatus::Finished { return_data_offset: 0, return_data_size: 0 }); // Check result in memory - let result_values: Vec<_> = - vm.memory.read_slice(MemoryAddress(0), 4).iter().map(|mem_val| mem_val.value).collect(); + let result_values: Vec<_> = vm + .memory + .read_slice(MemoryAddress(0), 4) + .iter() + .map(|mem_val| mem_val.clone().to_field()) + .collect(); assert_eq!(result_values, expected_result); // Ensure the foreign call counter has been incremented @@ -1698,7 +1707,7 @@ mod tests { .chain(memory.iter().enumerate().map(|(index, mem_value)| Opcode::Cast { destination: MemoryAddress(index), source: MemoryAddress(index), - bit_size: mem_value.bit_size, + bit_size: mem_value.bit_size(), })) .chain(vec![ // input = 0 @@ -1721,7 +1730,7 @@ mod tests { .collect(); let mut vm = brillig_execute_and_get_vm( - memory.into_iter().map(|mem_value| mem_value.value).collect(), + memory.into_iter().map(|mem_value| mem_value.to_field()).collect(), &program, ); diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs index d563e13be2e..feeb3706bde 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs @@ -1,11 +1,13 @@ use acir::{brillig::MemoryAddress, FieldElement}; +use num_bigint::BigUint; +use num_traits::{One, Zero}; pub const MEMORY_ADDRESSING_BIT_SIZE: u32 = 64; -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct MemoryValue { - pub value: FieldElement, - pub bit_size: u32, +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum MemoryValue { + Field(FieldElement), + Integer(BigUint, u32), } #[derive(Debug, thiserror::Error)] @@ -15,53 +17,147 @@ pub enum MemoryTypeError { } impl MemoryValue { - pub fn new(value: FieldElement, bit_size: u32) -> Self { - MemoryValue { value, bit_size } + /// Builds a memory value from a field element. + pub fn new_from_field(value: FieldElement, bit_size: u32) -> Self { + if bit_size == FieldElement::max_num_bits() { + MemoryValue::new_field(value) + } else { + MemoryValue::new_integer(BigUint::from_bytes_be(&value.to_be_bytes()), bit_size) + } + } + + /// Builds a memory value from an integer + pub fn new_from_integer(value: BigUint, bit_size: u32) -> Self { + if bit_size == FieldElement::max_num_bits() { + MemoryValue::new_field(FieldElement::from_be_bytes_reduce(&value.to_bytes_be())) + } else { + MemoryValue::new_integer(value, bit_size) + } } + /// Builds a memory value from a field element, checking that the value is within the bit size. pub fn new_checked(value: FieldElement, bit_size: u32) -> Option { - if value.num_bits() > bit_size { + if bit_size < FieldElement::max_num_bits() && value.num_bits() > bit_size { return None; } - Some(MemoryValue::new(value, bit_size)) + Some(MemoryValue::new_from_field(value, bit_size)) } + /// Builds a field-typed memory value. pub fn new_field(value: FieldElement) -> Self { - MemoryValue { value, bit_size: FieldElement::max_num_bits() } + MemoryValue::Field(value) + } + + /// Builds an integer-typed memory value. + pub fn new_integer(value: BigUint, bit_size: u32) -> Self { + assert!( + bit_size != FieldElement::max_num_bits(), + "Tried to build a field memory value via new_integer" + ); + MemoryValue::Integer(value, bit_size) + } + + /// Extracts the field element from the memory value, if it is typed as field element. + pub fn extract_field(&self) -> Option<&FieldElement> { + match self { + MemoryValue::Field(value) => Some(value), + _ => None, + } + } + + /// Extracts the integer from the memory value, if it is typed as integer. + pub fn extract_integer(&self) -> Option<(&BigUint, u32)> { + match self { + MemoryValue::Integer(value, bit_size) => Some((value, *bit_size)), + _ => None, + } + } + + /// Converts the memory value to a field element, independent of its type. + pub fn to_field(&self) -> FieldElement { + match self { + MemoryValue::Field(value) => *value, + MemoryValue::Integer(value, _) => { + FieldElement::from_be_bytes_reduce(&value.to_bytes_be()) + } + } + } + + /// Converts the memory value to an integer, independent of its type. + pub fn to_integer(self) -> BigUint { + match self { + MemoryValue::Field(value) => BigUint::from_bytes_be(&value.to_be_bytes()), + MemoryValue::Integer(value, _) => value, + } + } + + pub fn bit_size(&self) -> u32 { + match self { + MemoryValue::Field(_) => FieldElement::max_num_bits(), + MemoryValue::Integer(_, bit_size) => *bit_size, + } } pub fn to_usize(&self) -> usize { - assert!(self.bit_size == MEMORY_ADDRESSING_BIT_SIZE, "value is not typed as brillig usize"); - self.value.to_u128() as usize + assert!( + self.bit_size() == MEMORY_ADDRESSING_BIT_SIZE, + "value is not typed as brillig usize" + ); + self.extract_integer().unwrap().0.try_into().unwrap() } - pub fn expect_bit_size(&self, expected_bit_size: u32) -> Result<(), MemoryTypeError> { - if self.bit_size != expected_bit_size { - return Err(MemoryTypeError::MismatchedBitSize { - value_bit_size: self.bit_size, + pub fn expect_field(&self) -> Result<&FieldElement, MemoryTypeError> { + match self { + MemoryValue::Integer(_, bit_size) => Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: *bit_size, + expected_bit_size: FieldElement::max_num_bits(), + }), + MemoryValue::Field(field) => Ok(field), + } + } + + pub fn expect_integer_with_bit_size( + &self, + expected_bit_size: u32, + ) -> Result<&BigUint, MemoryTypeError> { + match self { + MemoryValue::Integer(value, bit_size) => { + if *bit_size != expected_bit_size { + return Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: *bit_size, + expected_bit_size, + }); + } + Ok(value) + } + MemoryValue::Field(_) => Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: FieldElement::max_num_bits(), expected_bit_size, - }); + }), } - Ok(()) } } impl std::fmt::Display for MemoryValue { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> { - let typ = match self.bit_size { - 0 => "null".to_string(), - 1 => "bool".to_string(), - _ if self.bit_size == FieldElement::max_num_bits() => "field".to_string(), - _ => format!("u{}", self.bit_size), - }; - f.write_str(format!("{}: {}", self.value, typ).as_str()) + match self { + MemoryValue::Field(value) => write!(f, "{}: field", value), + MemoryValue::Integer(value, bit_size) => { + let typ = match bit_size { + 0 => "null".to_string(), + 1 => "bool".to_string(), + _ => format!("u{}", bit_size), + }; + write!(f, "{}: {}", value, typ) + } + } } } impl Default for MemoryValue { fn default() -> Self { - MemoryValue::new(FieldElement::zero(), 0) + MemoryValue::new_integer(BigUint::zero(), 0) } } @@ -73,31 +169,32 @@ impl From for MemoryValue { impl From for MemoryValue { fn from(value: usize) -> Self { - MemoryValue::new(value.into(), MEMORY_ADDRESSING_BIT_SIZE) + MemoryValue::new_integer(value.into(), MEMORY_ADDRESSING_BIT_SIZE) } } impl From for MemoryValue { fn from(value: u64) -> Self { - MemoryValue::new((value as u128).into(), 64) + MemoryValue::new_integer(value.into(), 64) } } impl From for MemoryValue { fn from(value: u32) -> Self { - MemoryValue::new((value as u128).into(), 32) + MemoryValue::new_integer(value.into(), 32) } } impl From for MemoryValue { fn from(value: u8) -> Self { - MemoryValue::new((value as u128).into(), 8) + MemoryValue::new_integer(value.into(), 8) } } impl From for MemoryValue { fn from(value: bool) -> Self { - MemoryValue::new(value.into(), 1) + let value = if value { BigUint::one() } else { BigUint::zero() }; + MemoryValue::new_integer(value, 1) } } @@ -105,8 +202,7 @@ impl TryFrom for FieldElement { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(FieldElement::max_num_bits())?; - Ok(memory_value.value) + memory_value.expect_field().copied() } } @@ -114,8 +210,7 @@ impl TryFrom for u64 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(64)?; - Ok(memory_value.value.to_u128() as u64) + memory_value.expect_integer_with_bit_size(64).map(|value| value.try_into().unwrap()) } } @@ -123,8 +218,7 @@ impl TryFrom for u32 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(32)?; - Ok(memory_value.value.to_u128() as u32) + memory_value.expect_integer_with_bit_size(32).map(|value| value.try_into().unwrap()) } } @@ -132,9 +226,7 @@ impl TryFrom for u8 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(8)?; - - Ok(memory_value.value.to_u128() as u8) + memory_value.expect_integer_with_bit_size(8).map(|value| value.try_into().unwrap()) } } @@ -142,11 +234,65 @@ impl TryFrom for bool { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(1)?; + let as_integer = memory_value.expect_integer_with_bit_size(1)?; + + if as_integer.is_zero() { + Ok(false) + } else if as_integer.is_one() { + Ok(true) + } else { + unreachable!("value typed as bool is greater than one") + } + } +} + +impl TryFrom<&MemoryValue> for FieldElement { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_field().copied() + } +} + +impl TryFrom<&MemoryValue> for u64 { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_integer_with_bit_size(64).map(|value| { + value.try_into().expect("memory_value has been asserted to contain a 64 bit integer") + }) + } +} + +impl TryFrom<&MemoryValue> for u32 { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_integer_with_bit_size(32).map(|value| { + value.try_into().expect("memory_value has been asserted to contain a 32 bit integer") + }) + } +} + +impl TryFrom<&MemoryValue> for u8 { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_integer_with_bit_size(8).map(|value| { + value.try_into().expect("memory_value has been asserted to contain an 8 bit integer") + }) + } +} + +impl TryFrom<&MemoryValue> for bool { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + let as_integer = memory_value.expect_integer_with_bit_size(1)?; - if memory_value.value == FieldElement::zero() { + if as_integer.is_zero() { Ok(false) - } else if memory_value.value == FieldElement::one() { + } else if as_integer.is_one() { Ok(true) } else { unreachable!("value typed as bool is greater than one") @@ -164,7 +310,7 @@ pub struct Memory { impl Memory { /// Gets the value at pointer pub fn read(&self, ptr: MemoryAddress) -> MemoryValue { - self.inner.get(ptr.to_usize()).copied().unwrap_or_default() + self.inner.get(ptr.to_usize()).cloned().unwrap_or_default() } pub fn read_ref(&self, ptr: MemoryAddress) -> MemoryAddress { @@ -191,7 +337,7 @@ impl Memory { /// Sets the values after pointer `ptr` to `values` pub fn write_slice(&mut self, ptr: MemoryAddress, values: &[MemoryValue]) { self.resize_to_fit(ptr.to_usize() + values.len()); - self.inner[ptr.to_usize()..(ptr.to_usize() + values.len())].copy_from_slice(values); + self.inner[ptr.to_usize()..(ptr.to_usize() + values.len())].clone_from_slice(values); } /// Returns the values of the memory diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs index 15a2a531e78..4b97a61491d 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs @@ -67,61 +67,105 @@ pub(crate) fn directive_invert() -> GeneratedBrillig { /// (a/b, a-a/b*b) /// } /// ``` -pub(crate) fn directive_quotient(mut bit_size: u32) -> GeneratedBrillig { +pub(crate) fn directive_quotient(bit_size: u32) -> GeneratedBrillig { // `a` is (0) (i.e register index 0) // `b` is (1) - if bit_size > FieldElement::max_num_bits() { - bit_size = FieldElement::max_num_bits(); - } - GeneratedBrillig { - byte_code: vec![ - BrilligOpcode::CalldataCopy { - destination_address: MemoryAddress::from(0), - size: 2, - offset: 0, - }, - BrilligOpcode::Cast { - destination: MemoryAddress(0), - source: MemoryAddress(0), - bit_size, - }, - BrilligOpcode::Cast { - destination: MemoryAddress(1), - source: MemoryAddress(1), - bit_size, - }, - //q = a/b is set into register (2) - BrilligOpcode::BinaryIntOp { - op: BinaryIntOp::Div, - lhs: MemoryAddress::from(0), - rhs: MemoryAddress::from(1), - destination: MemoryAddress::from(2), - bit_size, - }, - //(1)= q*b - BrilligOpcode::BinaryIntOp { - op: BinaryIntOp::Mul, - lhs: MemoryAddress::from(2), - rhs: MemoryAddress::from(1), - destination: MemoryAddress::from(1), - bit_size, - }, - //(1) = a-q*b - BrilligOpcode::BinaryIntOp { - op: BinaryIntOp::Sub, - lhs: MemoryAddress::from(0), - rhs: MemoryAddress::from(1), - destination: MemoryAddress::from(1), - bit_size, - }, - //(0) = q - BrilligOpcode::Mov { - destination: MemoryAddress::from(0), - source: MemoryAddress::from(2), - }, - BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 2 }, - ], - assert_messages: Default::default(), - locations: Default::default(), + + // TODO: The only difference between these implementations is the integer version will truncate the input to the `bit_size` via cast. + // Once we deduplicate brillig functions then we can modify this so that fields and integers share the same quotient function. + if bit_size >= FieldElement::max_num_bits() { + // Field version + GeneratedBrillig { + byte_code: vec![ + BrilligOpcode::CalldataCopy { + destination_address: MemoryAddress::from(0), + size: 2, + offset: 0, + }, + // No cast, since calldata is typed as field by default + //q = a/b is set into register (2) + BrilligOpcode::BinaryFieldOp { + op: BinaryFieldOp::IntegerDiv, // We want integer division, not field division! + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(2), + }, + //(1)= q*b + BrilligOpcode::BinaryFieldOp { + op: BinaryFieldOp::Mul, + lhs: MemoryAddress::from(2), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + }, + //(1) = a-q*b + BrilligOpcode::BinaryFieldOp { + op: BinaryFieldOp::Sub, + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + }, + //(0) = q + BrilligOpcode::Mov { + destination: MemoryAddress::from(0), + source: MemoryAddress::from(2), + }, + BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 2 }, + ], + assert_messages: Default::default(), + locations: Default::default(), + } + } else { + // Integer version + GeneratedBrillig { + byte_code: vec![ + BrilligOpcode::CalldataCopy { + destination_address: MemoryAddress::from(0), + size: 2, + offset: 0, + }, + BrilligOpcode::Cast { + destination: MemoryAddress(0), + source: MemoryAddress(0), + bit_size, + }, + BrilligOpcode::Cast { + destination: MemoryAddress(1), + source: MemoryAddress(1), + bit_size, + }, + //q = a/b is set into register (2) + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Div, + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(2), + bit_size, + }, + //(1)= q*b + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Mul, + lhs: MemoryAddress::from(2), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + bit_size, + }, + //(1) = a-q*b + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Sub, + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + bit_size, + }, + //(0) = q + BrilligOpcode::Mov { + destination: MemoryAddress::from(0), + source: MemoryAddress::from(2), + }, + BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 2 }, + ], + assert_messages: Default::default(), + locations: Default::default(), + } } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs index b93693d9c79..3e1515b1eed 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs @@ -465,7 +465,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); @@ -590,7 +590,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); @@ -686,7 +686,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); @@ -838,7 +838,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs index db872487fcc..fe3c5e0bb9c 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs @@ -527,7 +527,7 @@ mod tests { let (vm, return_data_offset, return_data_size) = create_and_run_vm(calldata.clone(), &bytecode); assert_eq!(return_data_size, 1, "Return data size is incorrect"); - assert_eq!(vm.get_memory()[return_data_offset].value, FieldElement::from(1_usize)); + assert_eq!(vm.get_memory()[return_data_offset].to_field(), FieldElement::from(1_usize)); } #[test] @@ -569,7 +569,7 @@ mod tests { assert_eq!( memory[return_data_pointer..(return_data_pointer + flattened_array.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), flattened_array ); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 53d9e2530cc..775571f4a41 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -1623,7 +1623,7 @@ impl AcirContext { let outputs_var = vecmap(outputs_types.iter(), |output| match output { AcirType::NumericType(_) => { let var = self.add_data(AcirVarData::Const( - memory.next().expect("Missing return data").value, + memory.next().expect("Missing return data").to_field(), )); AcirValue::Var(var, output.clone()) } @@ -1657,7 +1657,7 @@ impl AcirContext { AcirType::NumericType(_) => { let memory_value = memory_iter.next().expect("ICE: Unexpected end of memory"); - let var = self.add_data(AcirVarData::Const(memory_value.value)); + let var = self.add_data(AcirVarData::Const(memory_value.to_field())); array_values.push_back(AcirValue::Var(var, element_type.clone())); } } diff --git a/noir/noir-repo/tooling/debugger/src/context.rs b/noir/noir-repo/tooling/debugger/src/context.rs index 1acd581b2be..b211832518d 100644 --- a/noir/noir-repo/tooling/debugger/src/context.rs +++ b/noir/noir-repo/tooling/debugger/src/context.rs @@ -513,7 +513,11 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { pub(super) fn write_brillig_memory(&mut self, ptr: usize, value: FieldElement, bit_size: u32) { if let Some(solver) = self.brillig_solver.as_mut() { - solver.write_memory_at(ptr, MemoryValue::new(value, bit_size)); + solver.write_memory_at( + ptr, + MemoryValue::new_checked(value, bit_size) + .expect("Invalid value for the given bit size"), + ); } } diff --git a/noir/noir-repo/tooling/debugger/src/repl.rs b/noir/noir-repo/tooling/debugger/src/repl.rs index 1c077c6ee9b..e30d519b62e 100644 --- a/noir/noir-repo/tooling/debugger/src/repl.rs +++ b/noir/noir-repo/tooling/debugger/src/repl.rs @@ -319,7 +319,7 @@ impl<'a, B: BlackBoxFunctionSolver> ReplDebugger<'a, B> { return; }; - for (index, value) in memory.iter().enumerate().filter(|(_, value)| value.bit_size > 0) { + for (index, value) in memory.iter().enumerate().filter(|(_, value)| value.bit_size() > 0) { println!("{index} = {}", value); } } From d007d79c7014261d9c663e28c948600d92e85759 Mon Sep 17 00:00:00 2001 From: Ilyas Ridhuan Date: Wed, 10 Apr 2024 13:52:07 +0100 Subject: [PATCH 2/2] feat(avm): enable contract testing with bb binary (#5584) Please read [contributing guidelines](CONTRIBUTING.md) and remove this line. --------- Co-authored-by: Facundo --- .../vm/avm_trace/avm_mem_trace.hpp | 3 +- .../simulator/src/public/avm_executor.test.ts | 38 +++++++------------ 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_mem_trace.hpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_mem_trace.hpp index 8c1d649f78b..14c1d8accc5 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_mem_trace.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_mem_trace.hpp @@ -8,7 +8,6 @@ namespace bb::avm_trace { class AvmMemTraceBuilder { public: - static const size_t MEM_SIZE = 1024; static const uint32_t SUB_CLK_IND_LOAD_A = 0; static const uint32_t SUB_CLK_IND_LOAD_B = 1; static const uint32_t SUB_CLK_IND_LOAD_C = 2; @@ -124,4 +123,4 @@ class AvmMemTraceBuilder { AvmMemoryTag r_in_tag, AvmMemoryTag w_in_tag); }; -} // namespace bb::avm_trace \ No newline at end of file +} // namespace bb::avm_trace diff --git a/yarn-project/simulator/src/public/avm_executor.test.ts b/yarn-project/simulator/src/public/avm_executor.test.ts index 969f1c730f1..6a8e13a51f3 100644 --- a/yarn-project/simulator/src/public/avm_executor.test.ts +++ b/yarn-project/simulator/src/public/avm_executor.test.ts @@ -6,6 +6,7 @@ import { AvmTestContractArtifact } from '@aztec/noir-contracts.js'; import { type MockProxy, mock } from 'jest-mock-extended'; +import { initContext, initExecutionEnvironment } from '../avm/fixtures/index.js'; import { type CommitmentsDB, type PublicContractsDB, type PublicStateDB } from './db.js'; import { type PublicExecution } from './execution.js'; import { PublicExecutor } from './executor.js'; @@ -35,34 +36,23 @@ describe('AVM WitGen and Proof Generation', () => { header = makeHeader(randomInt(1000000)); }, 10000); - it('Should prove valid execution of bytecode that performs addition', async () => { - const args: Fr[] = [new Fr(1), new Fr(2)]; - // Bytecode for the following contract is encoded: - // const bytecode = encodeToBytecode([ - // new CalldataCopy(/*indirect=*/ 0, /*cdOffset=*/ 0, /*copySize=*/ 2, /*dstOffset=*/ 0), - // new Add(/*indirect=*/ 0, TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 2), - // new Return(/*indirect=*/ 0, /*returnOffset=*/ 2, /*copySize=*/ 1), - // ]); - const bytecode: Buffer = Buffer.from('IAAAAAAAAAAAAgAAAAAAAAYAAAAAAAAAAQAAAAI5AAAAAAIAAAAB', 'base64'); - publicContracts.getBytecode.mockResolvedValue(bytecode); - const executor = new PublicExecutor(publicState, publicContracts, commitmentsDb, header); - const functionData = FunctionData.empty(); - const execution: PublicExecution = { contractAddress, functionData, args, callContext }; - const [proof, vk] = await executor.getAvmProof(execution); - const valid = await executor.verifyAvmProof(vk, proof); - expect(valid).toBe(true); - }); - - // This is skipped as we require MOV to be implemented in the AVM - it.skip('Should prove valid execution contract function that performs addition', async () => { - const args: Fr[] = [new Fr(1), new Fr(2)]; - + it('Should prove valid execution contract function that performs addition', async () => { const addArtifact = AvmTestContractArtifact.functions.find(f => f.name === 'add_args_return')!; const bytecode = addArtifact.bytecode; publicContracts.getBytecode.mockResolvedValue(bytecode); - const functionData = FunctionData.fromAbi(addArtifact); - const execution: PublicExecution = { contractAddress, functionData, args, callContext }; + const functionData = FunctionData.fromAbi(addArtifact); + const args: Fr[] = [new Fr(99), new Fr(12)]; + // We call initContext here to load up a AvmExecutionEnvironment that prepends the calldata with the function selector + // and the args hash. In reality, we should simulate here and get this from the output of the simulation call. + // For now, the interfaces for the PublicExecutor don't quite line up, so we are doing this. + const context = initContext({ env: initExecutionEnvironment({ calldata: args }) }); + const execution: PublicExecution = { + contractAddress, + functionData, + args: context.environment.calldata, + callContext, + }; const executor = new PublicExecutor(publicState, publicContracts, commitmentsDb, header); const [proof, vk] = await executor.getAvmProof(execution); const valid = await executor.verifyAvmProof(vk, proof);