Skip to content

Commit 31996c6

Browse files
committed
fix: shift right overflow in ACIR with unknown var now returns zero
1 parent ef51d8a commit 31996c6

File tree

4 files changed

+61
-0
lines changed

4 files changed

+61
-0
lines changed

compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs

+51
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ impl Context<'_> {
167167
let lhs_typ = self.function.dfg.type_of_value(lhs).unwrap_numeric();
168168
let base = self.field_constant(FieldElement::from(2_u128));
169169
let pow = self.pow(base, rhs);
170+
let pow = self.pow_or_max_for_bit_size(pow, rhs, bit_size, lhs_typ);
170171
let pow = self.insert_cast(pow, lhs_typ);
171172
if lhs_typ.is_unsigned() {
172173
// unsigned right bit shift is just a normal division
@@ -205,6 +206,56 @@ impl Context<'_> {
205206
}
206207
}
207208

209+
/// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
210+
fn pow_or_max_for_bit_size(
211+
&mut self,
212+
pow: ValueId,
213+
rhs: ValueId,
214+
bit_size: u32,
215+
typ: NumericType,
216+
) -> ValueId {
217+
let max = if typ.is_unsigned() {
218+
if bit_size == 128 {
219+
u128::MAX
220+
} else {
221+
(1_u128 << bit_size) - 1
222+
}
223+
} else {
224+
1_u128 << (bit_size - 1)
225+
};
226+
let max = self.field_constant(FieldElement::from(max));
227+
228+
// Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
229+
// Then we do:
230+
//
231+
// rhs_is_less_than_bit_size = lt rhs, bit_size
232+
// rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
233+
// pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
234+
// pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
235+
// pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
236+
//
237+
// All operations here are unchecked because they work on field types.
238+
let bit_size_as_field = self.numeric_constant(bit_size as u128, typ);
239+
let rhs_is_less_than_bit_size = self.insert_binary(rhs, BinaryOp::Lt, bit_size_as_field);
240+
let rhs_is_not_less_than_bit_size = self.insert_not(rhs_is_less_than_bit_size);
241+
let rhs_is_less_than_bit_size =
242+
self.insert_cast(rhs_is_less_than_bit_size, NumericType::NativeField);
243+
let rhs_is_not_less_than_bit_size =
244+
self.insert_cast(rhs_is_not_less_than_bit_size, NumericType::NativeField);
245+
let pow_when_is_less_than_bit_size =
246+
self.insert_binary(rhs_is_less_than_bit_size, BinaryOp::Mul { unchecked: true }, pow);
247+
let pow_when_is_not_less_than_bit_size = self.insert_binary(
248+
rhs_is_not_less_than_bit_size,
249+
BinaryOp::Mul { unchecked: true },
250+
max,
251+
);
252+
self.insert_binary(
253+
pow_when_is_less_than_bit_size,
254+
BinaryOp::Add { unchecked: true },
255+
pow_when_is_not_less_than_bit_size,
256+
)
257+
}
258+
208259
/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
209260
/// Pseudo-code of the computation:
210261
/// let mut r = 1;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[package]
2+
name = "shift_right_overflow"
3+
type = "bin"
4+
authors = [""]
5+
[dependencies]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
x = 9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
fn main(x: u8) -> pub u8 {
2+
// This would previously overflow in ACIR. Now it returns zero.
3+
1 >> x
4+
}

0 commit comments

Comments
 (0)