Skip to content

Commit

Permalink
fix(ssa): Accurately mark binary ops for hoisting and check Div/Mod a…
Browse files Browse the repository at this point in the history
…gainst induction variable lower bound (#7396)
  • Loading branch information
vezenovm authored Feb 18, 2025
1 parent 31cc6a1 commit 64890c0
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 20 deletions.
130 changes: 112 additions & 18 deletions compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ impl<'f> LoopInvariantContext<'f> {

let can_be_deduplicated = instruction.can_be_deduplicated(self.inserter.function, false)
|| matches!(instruction, Instruction::MakeArray { .. })
|| matches!(instruction, Instruction::Binary(_))
|| self.can_be_deduplicated_from_loop_bound(&instruction);

is_loop_invariant && can_be_deduplicated
Expand Down Expand Up @@ -313,13 +312,6 @@ impl<'f> LoopInvariantContext<'f> {
binary: &Binary,
induction_vars: &HashMap<ValueId, (FieldElement, FieldElement)>,
) -> bool {
if !matches!(
binary.operator,
BinaryOp::Add { .. } | BinaryOp::Mul { .. } | BinaryOp::Sub { .. }
) {
return false;
}

let operand_type = self.inserter.function.dfg.type_of_value(binary.lhs).unwrap_numeric();

let lhs_const = self.inserter.function.dfg.get_numeric_constant_with_type(binary.lhs);
Expand All @@ -330,7 +322,15 @@ impl<'f> LoopInvariantContext<'f> {
induction_vars.get(&binary.lhs),
induction_vars.get(&binary.rhs),
) {
(Some((lhs, _)), None, None, Some((_, upper_bound))) => (lhs, *upper_bound),
(Some((lhs, _)), None, None, Some((lower_bound, upper_bound))) => {
if matches!(binary.operator, BinaryOp::Div | BinaryOp::Mod) {
// If we have a Div/Mod operation we want to make sure that the
// lower bound is not zero.
(lhs, *lower_bound)
} else {
(lhs, *upper_bound)
}
}
(None, Some((rhs, _)), Some((lower_bound, upper_bound)), None) => {
if matches!(binary.operator, BinaryOp::Sub { .. }) {
// If we are subtracting and the induction variable is on the lhs,
Expand All @@ -343,7 +343,8 @@ impl<'f> LoopInvariantContext<'f> {
_ => return false,
};

// We evaluate this expression using the upper bounds of its inputs to check whether it will ever overflow.
// We evaluate this expression using the upper bounds (or lower in the case of div/mod)
// of its inputs to check whether it will ever overflow.
// If so, this will cause `eval_constant_binary_op` to return `None`.
// Therefore a `Some` value shows that this operation is safe.
eval_constant_binary_op(lhs, rhs, binary.operator, operand_type).is_some()
Expand Down Expand Up @@ -870,8 +871,6 @@ mod test {
b2():
return
b3():
v6 = mul v0, v1
constrain v6 == u32 6
v8 = sub v2, u32 1
jmp b1(v8)
}
Expand All @@ -883,21 +882,116 @@ mod test {
let expected = "
brillig(inline) fn main f0 {
b0(v0: u32, v1: u32):
v3 = mul v0, v1
jmp b1(u32 1)
b1(v2: u32):
v6 = lt v2, u32 4
jmpif v6 then: b3, else: b2
v5 = lt v2, u32 4
jmpif v5 then: b3, else: b2
b2():
return
b3():
constrain v3 == u32 6
v8 = unchecked_sub v2, u32 1
jmp b1(v8)
v6 = unchecked_sub v2, u32 1
jmp b1(v6)
}
";

let ssa = ssa.loop_invariant_code_motion();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn do_not_hoist_unsafe_div() {
// This test is similar to `nested_loop_invariant_code_motion`, the operation
// in question we are trying to hoist is `v9 = div i32 10, v0`.
// Check that the lower bound of the outer loop it checked and that we not
// hoist an operation that can potentially error with a division by zero.
let src = "
brillig(inline) fn main f0 {
b0():
jmp b1(i32 0)
b1(v0: i32):
v4 = lt v0, i32 4
jmpif v4 then: b3, else: b2
b2():
return
b3():
jmp b4(i32 0)
b4(v1: i32):
v5 = lt v1, i32 4
jmpif v5 then: b6, else: b5
b5():
v7 = unchecked_add v0, i32 1
jmp b1(v7)
b6():
v9 = div i32 10, v0
constrain v9 == i32 6
v11 = unchecked_add v1, i32 1
jmp b4(v11)
}
";

let ssa = Ssa::from_str(src).unwrap();

let ssa = ssa.loop_invariant_code_motion();
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn hoist_safe_div() {
// This test is identical to `do_not_hoist_unsafe_div`, except the loop
// in this test starts with a lower bound of `1`.
let src = "
brillig(inline) fn main f0 {
b0():
jmp b1(i32 1)
b1(v0: i32):
v4 = lt v0, i32 4
jmpif v4 then: b3, else: b2
b2():
return
b3():
jmp b4(i32 0)
b4(v1: i32):
v5 = lt v1, i32 4
jmpif v5 then: b6, else: b5
b5():
v7 = unchecked_add v0, i32 1
jmp b1(v7)
b6():
v9 = div i32 10, v0
constrain v9 == i32 6
v11 = unchecked_add v1, i32 1
jmp b4(v11)
}
";

let ssa = Ssa::from_str(src).unwrap();

let ssa = ssa.loop_invariant_code_motion();
let expected = "
brillig(inline) fn main f0 {
b0():
jmp b1(i32 1)
b1(v0: i32):
v4 = lt v0, i32 4
jmpif v4 then: b3, else: b2
b2():
return
b3():
v6 = div i32 10, v0
jmp b4(i32 0)
b4(v1: i32):
v8 = lt v1, i32 4
jmpif v8 then: b6, else: b5
b5():
v9 = unchecked_add v0, i32 1
jmp b1(v9)
b6():
constrain v6 == i32 6
v11 = unchecked_add v1, i32 1
jmp b4(v11)
}
";

assert_normalized_ssa_equals(ssa, expected);
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
// Tests a simple loop where we expect loop invariant instructions
// to be hoisted to the loop's pre-header block.
global U32_MAX: u32 = 4294967295;

fn main(x: u32, y: u32) {
loop_(4, x, y);
simple_loop(4, x, y);
loop_with_predicate(4, x, y);
array_read_loop(4, x);
}

fn loop_(upper_bound: u32, x: u32, y: u32) {
fn simple_loop(upper_bound: u32, x: u32, y: u32) {
for _ in 0..upper_bound {
let mut z = x * y;
z = z * x;
assert_eq(z, 12);
}
}

fn loop_with_predicate(upper_bound: u32, x: u32, y: u32) {
for _ in 0..upper_bound {
if x == 5 {
let mut z = U32_MAX * y;
assert_eq(z, 12);
}
}
}

fn array_read_loop(upper_bound: u32, x: u32) {
let arr = [2; 5];
for i in 0..upper_bound {
Expand Down

1 comment on commit 64890c0

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Test Suite Duration'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.20.

Benchmark suite Current: 64890c0 Previous: 31cc6a1 Ratio
AztecProtocol_aztec-packages_noir-projects_noir-protocol-circuits_crates_blob 66 s 51 s 1.29

This comment was automatically generated by workflow using github-action-benchmark.

CC: @TomAFrench

Please sign in to comment.