Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: shift right overflow in ACIR with unknown var now returns zero #7509

Merged
merged 5 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ impl Context<'_> {
let lhs_typ = self.function.dfg.type_of_value(lhs).unwrap_numeric();
let base = self.field_constant(FieldElement::from(2_u128));
let pow = self.pow(base, rhs);
let pow = self.pow_or_max_for_bit_size(pow, rhs, bit_size, lhs_typ);
let pow = self.insert_cast(pow, lhs_typ);
if lhs_typ.is_unsigned() {
// unsigned right bit shift is just a normal division
Expand Down Expand Up @@ -205,6 +206,53 @@ impl Context<'_> {
}
}

/// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
fn pow_or_max_for_bit_size(
&mut self,
pow: ValueId,
rhs: ValueId,
bit_size: u32,
typ: NumericType,
) -> ValueId {
let max = if typ.is_unsigned() {
if bit_size == 128 { u128::MAX } else { (1_u128 << bit_size) - 1 }
} else {
1_u128 << (bit_size - 1)
};
let max = self.field_constant(FieldElement::from(max));

// Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
// Then we do:
//
// rhs_is_less_than_bit_size = lt rhs, bit_size
// rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
// pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
// pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
// pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
//
// All operations here are unchecked because they work on field types.
let rhs_typ = self.function.dfg.type_of_value(rhs).unwrap_numeric();
let bit_size = self.numeric_constant(bit_size as u128, rhs_typ);
let rhs_is_less_than_bit_size = self.insert_binary(rhs, BinaryOp::Lt, bit_size);
let rhs_is_not_less_than_bit_size = self.insert_not(rhs_is_less_than_bit_size);
let rhs_is_less_than_bit_size =
self.insert_cast(rhs_is_less_than_bit_size, NumericType::NativeField);
let rhs_is_not_less_than_bit_size =
self.insert_cast(rhs_is_not_less_than_bit_size, NumericType::NativeField);
let pow_when_is_less_than_bit_size =
self.insert_binary(rhs_is_less_than_bit_size, BinaryOp::Mul { unchecked: true }, pow);
let pow_when_is_not_less_than_bit_size = self.insert_binary(
rhs_is_not_less_than_bit_size,
BinaryOp::Mul { unchecked: true },
max,
);
self.insert_binary(
pow_when_is_less_than_bit_size,
BinaryOp::Add { unchecked: true },
pow_when_is_not_less_than_bit_size,
)
}

/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
/// Pseudo-code of the computation:
/// let mut r = 1;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
name = "shift_right_overflow"
type = "bin"
authors = [""]
[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = 9
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main(x: u8) {
// This would previously overflow in ACIR. Now it returns zero.
let value = 1 >> x;
assert_eq(value, 0);
}
Loading