Skip to content

Commit

Permalink
Merge branch 'master' into tf/optimize-brillig-radix
Browse files Browse the repository at this point in the history
* master:
  chore: avoid u128s in brillig memory (#7363)
  chore: update docs about integer overflows (#7370)
  fix!: Only decrement the counter of an array if its address has not changed (#7297)
  fix: let LSP read `noirfmt.toml` for formatting files (#7355)
  chore: deprecate keccak256 (#7361)
  • Loading branch information
TomAFrench committed Feb 13, 2025
2 parents dc79f46 + c3deb6a commit 6ca45a6
Show file tree
Hide file tree
Showing 37 changed files with 866 additions and 543 deletions.
260 changes: 174 additions & 86 deletions acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use acir::brillig::{BinaryFieldOp, BinaryIntOp, IntegerBitSize};
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};

use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize};
use acir::AcirField;
use num_bigint::BigUint;
use num_traits::{AsPrimitive, PrimInt, WrappingAdd, WrappingMul, WrappingSub};
use num_traits::{CheckedDiv, WrappingAdd, WrappingMul, WrappingSub, Zero};

use crate::memory::{MemoryTypeError, MemoryValue};

Expand All @@ -21,24 +23,20 @@ pub(crate) fn evaluate_binary_field_op<F: AcirField>(
lhs: MemoryValue<F>,
rhs: MemoryValue<F>,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
let a = match lhs {
MemoryValue::Field(a) => a,
MemoryValue::Integer(_, bit_size) => {
return Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: bit_size.into(),
op_bit_size: F::max_num_bits(),
});
let a = *lhs.expect_field().map_err(|err| {
let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err;
BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
}
};
let b = match rhs {
MemoryValue::Field(b) => b,
MemoryValue::Integer(_, bit_size) => {
return Err(BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: bit_size.into(),
op_bit_size: F::max_num_bits(),
});
})?;
let b = *rhs.expect_field().map_err(|err| {
let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err;
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
}
};
})?;

Ok(match op {
// Perform addition, subtraction, multiplication, and division based on the BinaryOp variant.
Expand Down Expand Up @@ -70,46 +68,120 @@ pub(crate) fn evaluate_binary_int_op<F: AcirField>(
rhs: MemoryValue<F>,
bit_size: IntegerBitSize,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
let lhs = lhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err {
MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => {
BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
match op {
BinaryIntOp::Add
| BinaryIntOp::Sub
| BinaryIntOp::Mul
| BinaryIntOp::Div
| BinaryIntOp::And
| BinaryIntOp::Or
| BinaryIntOp::Xor => match (lhs, rhs, bit_size) {
(MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1)
}
}
})?;

let rhs_bit_size = if op == &BinaryIntOp::Shl || op == &BinaryIntOp::Shr {
IntegerBitSize::U8
} else {
bit_size
};
(MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8)
}
(MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16)
}
(MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32)
}
(MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64)
}
(MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128)
}
(lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: lhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
(_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: rhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
_ => unreachable!("Invalid arguments are covered by the two arms above."),
},

let rhs = rhs.expect_integer_with_bit_size(rhs_bit_size).map_err(|err| match err {
MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => {
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
match (lhs, rhs, bit_size) {
(MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: lhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
(_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: rhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
_ => unreachable!("Invalid arguments are covered by the two arms above."),
}
}
})?;

// `lhs` and `rhs` are asserted to fit within their given types when being read from memory so this is safe.
let result = match bit_size {
IntegerBitSize::U1 => evaluate_binary_int_op_u1(op, lhs != 0, rhs != 0)?.into(),
IntegerBitSize::U8 => evaluate_binary_int_op_num(op, lhs as u8, rhs as u8, 8)?.into(),
IntegerBitSize::U16 => evaluate_binary_int_op_num(op, lhs as u16, rhs as u16, 16)?.into(),
IntegerBitSize::U32 => evaluate_binary_int_op_num(op, lhs as u32, rhs as u32, 32)?.into(),
IntegerBitSize::U64 => evaluate_binary_int_op_num(op, lhs as u64, rhs as u64, 64)?.into(),
IntegerBitSize::U128 => evaluate_binary_int_op_num(op, lhs, rhs, 128)?,
};

Ok(match op {
BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
MemoryValue::new_integer(result, IntegerBitSize::U1)
BinaryIntOp::Shl | BinaryIntOp::Shr => {
let rhs = rhs.expect_u8().map_err(
|MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size }| {
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
}
},
)?;

match (lhs, bit_size) {
(MemoryValue::U1(lhs), IntegerBitSize::U1) => {
let result = if rhs == 0 { lhs } else { false };
Ok(MemoryValue::U1(result))
}
(MemoryValue::U8(lhs), IntegerBitSize::U8) => {
Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U16(lhs), IntegerBitSize::U16) => {
Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U32(lhs), IntegerBitSize::U32) => {
Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U64(lhs), IntegerBitSize::U64) => {
Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U128(lhs), IntegerBitSize::U128) => {
Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
_ => Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: lhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
}),
}
}
_ => MemoryValue::new_integer(result, bit_size),
})
}
}

fn evaluate_binary_int_op_u1(
Expand All @@ -118,67 +190,83 @@ fn evaluate_binary_int_op_u1(
rhs: bool,
) -> Result<bool, BrilligArithmeticError> {
let result = match op {
BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
BinaryIntOp::Mul => lhs & rhs,
BinaryIntOp::Equals => lhs == rhs,
BinaryIntOp::LessThan => !lhs & rhs,
BinaryIntOp::LessThanEquals => lhs <= rhs,
BinaryIntOp::And | BinaryIntOp::Mul => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
BinaryIntOp::Div => {
if !rhs {
return Err(BrilligArithmeticError::DivisionByZero);
} else {
lhs
}
}
_ => unreachable!("Operator not handled by this function: {op:?}"),
};
Ok(result)
}

fn evaluate_binary_int_op_cmp<T: Ord + PartialEq>(op: &BinaryIntOp, lhs: T, rhs: T) -> bool {
match op {
BinaryIntOp::Equals => lhs == rhs,
BinaryIntOp::LessThan => !lhs & rhs,
BinaryIntOp::LessThan => lhs < rhs,
BinaryIntOp::LessThanEquals => lhs <= rhs,
BinaryIntOp::And => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor => lhs ^ rhs,
BinaryIntOp::Shl | BinaryIntOp::Shr => {
if rhs {
false
_ => unreachable!("Operator not handled by this function: {op:?}"),
}
}

fn evaluate_binary_int_op_shifts<T: From<u8> + Zero + Shl<Output = T> + Shr<Output = T>>(
op: &BinaryIntOp,
lhs: T,
rhs: u8,
) -> T {
match op {
BinaryIntOp::Shl => {
let rhs_usize: usize = rhs as usize;
#[allow(unused_qualifications)]
if rhs_usize >= 8 * std::mem::size_of::<T>() {
T::zero()
} else {
lhs
lhs << rhs.into()
}
}
};
Ok(result)
BinaryIntOp::Shr => {
let rhs_usize: usize = rhs as usize;
#[allow(unused_qualifications)]
if rhs_usize >= 8 * std::mem::size_of::<T>() {
T::zero()
} else {
lhs >> rhs.into()
}
}
_ => unreachable!("Operator not handled by this function: {op:?}"),
}
}

fn evaluate_binary_int_op_num<
T: PrimInt + AsPrimitive<usize> + From<bool> + WrappingAdd + WrappingSub + WrappingMul,
fn evaluate_binary_int_op_arith<
T: WrappingAdd
+ WrappingSub
+ WrappingMul
+ CheckedDiv
+ BitAnd<Output = T>
+ BitOr<Output = T>
+ BitXor<Output = T>,
>(
op: &BinaryIntOp,
lhs: T,
rhs: T,
num_bits: usize,
) -> Result<T, BrilligArithmeticError> {
let result = match op {
BinaryIntOp::Add => lhs.wrapping_add(&rhs),
BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?,
BinaryIntOp::Equals => (lhs == rhs).into(),
BinaryIntOp::LessThan => (lhs < rhs).into(),
BinaryIntOp::LessThanEquals => (lhs <= rhs).into(),
BinaryIntOp::And => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor => lhs ^ rhs,
BinaryIntOp::Shl => {
let rhs_usize = rhs.as_();
if rhs_usize >= num_bits {
T::zero()
} else {
lhs << rhs_usize
}
}
BinaryIntOp::Shr => {
let rhs_usize = rhs.as_();
if rhs_usize >= num_bits {
T::zero()
} else {
lhs >> rhs_usize
}
}
_ => unreachable!("Operator not handled by this function: {op:?}"),
};
Ok(result)
}
Expand Down
Loading

0 comments on commit 6ca45a6

Please sign in to comment.