Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
guipublic committed Feb 5, 2025
1 parent e10a64b commit 1db3d61
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 126 deletions.
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<'f> FunctionInserter<'f> {
instruction.map_values_mut(|id| self.resolve(id));
self.function.dfg.set_instruction(id, instruction);
}

/// Maps a terminator in place, replacing any ValueId in the terminator with the
/// resolved version of that value id from this FunctionInserter's internal value mapping.
pub(crate) fn map_terminator_in_place(&mut self, block: BasicBlockId) {
Expand Down
250 changes: 124 additions & 126 deletions compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::collections::{HashSet, VecDeque};

use acvm::AcirField;
use fxhash::FxHashMap;
use fxhash::FxHashMap as HashMap;
use iter_extended::vecmap;

use crate::ssa::{
ir::{
basic_block::BasicBlockId,
cfg::ControlFlowGraph,
dfg::DataFlowGraph,
function::{Function, FunctionId, RuntimeType},
function::{Function, FunctionId},
function_inserter::FunctionInserter,
instruction::{BinaryOp, Instruction, TerminatorInstruction},
post_order::PostOrder,
Expand All @@ -31,7 +31,7 @@ impl Ssa {
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn flatten_basic_conditionals(mut self) -> Ssa {
// Retrieve the 'no_predicates' attribute of the functions in a map, to avoid problems with borrowing
let mut no_predicates = FxHashMap::default();
let mut no_predicates = HashMap::default();
for function in self.functions.values() {
no_predicates.insert(function.id(), function.is_no_predicates());
}
Expand All @@ -50,104 +50,102 @@ fn is_conditional(
function: &Function,
) -> Option<BasicConditional> {
// jump overhead is the cost for doing the conditional and jump around the blocks
// I use 10 as a rough estimate, real cost is less.
// We use 10 as a rough estimate, the real cost is less.
let jump_overhead = 10;
let mut successors = cfg.successors(block);
let mut result = None;
// a conditional must have 2 branches
if successors.len() == 2 {
let left = successors.next().unwrap();
let right = successors.next().unwrap();
let mut left_successors = cfg.successors(left);
let mut right_successors = cfg.successors(right);
let left_successors_len = left_successors.len();
let right_successors_len = right_successors.len();
let next_left = left_successors.next();
let next_right = right_successors.next();
if next_left == Some(block) || next_right == Some(block) {
// this is a loop, not a conditional
return None;
}
if left_successors_len == 1 && right_successors_len == 1 && next_left == next_right {
// The branches join on one block so it is a non-nested conditional
let cost_left = block_cost(left, &function.dfg);
let cost_right = block_cost(right, &function.dfg);
// For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches,
// including an overhead to take into account the jumps between the blocks.
let cost = cost_right.saturating_add(cost_left);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
result = Some(BasicConditional {
block_entry: block,
block_then: Some(*then_destination),
block_else: Some(*else_destination),
block_exit: next_left.unwrap(),
});
}
if successors.len() != 2 {
return None;
}
let left = successors.next().unwrap();
let right = successors.next().unwrap();
let mut left_successors = cfg.successors(left);
let mut right_successors = cfg.successors(right);
let left_successors_len = left_successors.len();
let right_successors_len = right_successors.len();
let next_left = left_successors.next();
let next_right = right_successors.next();
if next_left == Some(block) || next_right == Some(block) {
// this is a loop, not a conditional
return None;
}
if left_successors_len == 1 && right_successors_len == 1 && next_left == next_right {
// The branches join on one block so it is a non-nested conditional
let cost_left = block_cost(left, &function.dfg);
let cost_right = block_cost(right, &function.dfg);
// For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches,
// including an overhead to take into account the jumps between the blocks.
let cost = cost_right.saturating_add(cost_left);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
result = Some(BasicConditional {
block_entry: block,
block_then: Some(*then_destination),
block_else: Some(*else_destination),
block_exit: next_left.unwrap(),
});
}
} else if left_successors_len == 1 && next_left == Some(right) {
// Left branch joins the right branch, it is a if/then statement with no else
// I am not sure whether this case can happen, but it is not difficult to handle it
let cost = block_cost(left, &function.dfg);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
if left == *then_destination {
result = Some(BasicConditional {
block_entry: block,
block_then: Some(left),
block_else: None,
block_exit: right,
});
} else if left == *else_destination {
result = Some(BasicConditional {
block_entry: block,
block_then: None,
block_else: Some(left),
block_exit: right,
});
}
}
}
} else if left_successors_len == 1 && next_left == Some(right) {
// Left branch joins the right branch, it is a if/then statement with no else
// I am not sure whether this case can happen, but it is not difficult to handle it
let cost = block_cost(left, &function.dfg);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
let (block_then, block_else) = if left == *then_destination {
(Some(left), None)
} else if left == *else_destination {
(None, Some(left))
} else {
return None;
};

result = Some(BasicConditional {
block_entry: block,
block_then,
block_else,
block_exit: right,
});
}
} else if right_successors_len == 1 && next_right == Some(left) {
// Right branch joins the right branch, it is a if/else statement with no then
// I am not sure whether this case can happen, but it is not difficult to handle it
let cost = block_cost(right, &function.dfg);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
if right == *else_destination {
result = Some(BasicConditional {
block_entry: block,
block_then: None,
block_else: Some(right),
block_exit: right,
});
} else if right == *then_destination {
result = Some(BasicConditional {
block_entry: block,
block_then: Some(right),
block_else: None,
block_exit: right,
});
}
}
}
} else if right_successors_len == 1 && next_right == Some(left) {
// Right branch joins the right branch, it is a if/else statement with no then
// I am not sure whether this case can happen, but it is not difficult to handle it
let cost = block_cost(right, &function.dfg);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
let (block_then, block_else) = if right == *then_destination {
(Some(right), None)
} else if right == *else_destination {
(None, Some(right))
} else {
return None;
};
result = Some(BasicConditional {
block_entry: block,
block_then,
block_else,
block_exit: right,
});
}
}
}
Expand All @@ -157,73 +155,74 @@ fn is_conditional(

/// Computes a cost estimate of a basic block
/// returns u32::MAX if the block has side-effect instructions
/// WARNING: the estimates are estimate of the runtime cost of each instructions,
/// 1 being the cost of the simplest instruction. These numbers can be improved.
fn block_cost(block: BasicBlockId, dfg: &DataFlowGraph) -> u32 {
let mut cost: u32 = 0;

for instruction in dfg[block].instructions() {
let instruction_cost = match &dfg[*instruction] {
Instruction::Binary(binary) => {
match binary.operator {
BinaryOp::Add { unchecked }
| BinaryOp::Sub { unchecked }
| BinaryOp::Mul { unchecked } => if unchecked { 1 } else { return u32::MAX },
| BinaryOp::Mul { unchecked } => if unchecked { 3 } else { return u32::MAX },
BinaryOp::Div
| BinaryOp::Mod => return u32::MAX,
BinaryOp::Eq => 2,
BinaryOp::Eq => 1,
BinaryOp::Lt => 5,
BinaryOp::And
| BinaryOp::Or
| BinaryOp::Xor => 1, //todo
| BinaryOp::Xor => 1,
BinaryOp::Shl
| BinaryOp::Shr => return u32::MAX,
}
},
Instruction::Cast(_, _) => 1,//TODO check if this instruction can fail
// A Cast can be either simplified, or lead to a truncate
Instruction::Cast(_, _) => 3,
Instruction::Not(_) => 1,
Instruction::Truncate { .. } => 7,

Instruction::Constrain(_,_,_)
| Instruction::ConstrainNotEqual(_,_,_)
| Instruction::RangeCheck { .. }
| Instruction::Call { .. } //TODO support calls with no-predicate set to true
// Calls with no-predicate set to true could be supported, but
// they are likely to be too costly anyways. Simple calls would
// have been inlined already.
| Instruction::Call { .. }
| Instruction::Load { .. }
| Instruction::Store { .. }
| Instruction::ArraySet { .. } => return u32::MAX,

Instruction::ArrayGet { array, index } => {
// A get can fail because of out-of-bound index
let mut get_cost = u32::MAX;
// check if index is in bound
let index = dfg.get_numeric_constant(*index);
let mut ok_bound = false;
if let Some(index) = index {
let len = dfg.try_get_array_length(*array).unwrap();
if let (Some(index), Some(len)) = (dfg.get_numeric_constant(*index), dfg.try_get_array_length(*array)) {
// The index is in-bounds
if index.to_u128() < len as u128 {
ok_bound = true;
get_cost = 1;
}
}
if ok_bound {
1
} else {
return u32::MAX;
}
get_cost
},

Instruction::Allocate => 0,
Instruction::EnableSideEffectsIf { .. } => 0,
Instruction::IncrementRc { .. } => 0,
Instruction::DecrementRc { .. } => 0,
Instruction::Allocate
| Instruction::EnableSideEffectsIf { .. }
| Instruction::IncrementRc { .. }
| Instruction::DecrementRc { .. }
| Instruction::MakeArray { .. }
| Instruction::Noop => 0,
Instruction::IfElse { .. } => 1,
Instruction::MakeArray { .. } => 0,
Instruction::Noop => 0,
};
cost += instruction_cost;
}
cost
}

/// Identifies all simple conditionals in the function and flattens them
fn flatten_function(function: &mut Function, no_predicates: &mut FxHashMap<FunctionId, bool>) {
fn flatten_function(function: &mut Function, no_predicates: &mut HashMap<FunctionId, bool>) {
// This pass is dedicated to brillig functions
if !matches!(function.runtime(), RuntimeType::Brillig(_)) {
if !function.runtime().is_brillig() {
return;
}
let cfg = ControlFlowGraph::with_function(function);
Expand Down Expand Up @@ -260,7 +259,7 @@ impl<'f> Context<'f> {
fn flatten_single_conditional(
&mut self,
conditional: &BasicConditional,
no_predicates: &mut FxHashMap<FunctionId, bool>,
no_predicates: &mut HashMap<FunctionId, bool>,
) {
// Manually inline 'then', 'else' and 'exit' into the entry block
//0. initialize the context for flattening a 'single conditional'
Expand Down Expand Up @@ -335,7 +334,7 @@ impl<'f> Context<'f> {
}

fn map_block_with_mapping(
mapping: FxHashMap<ValueId, ValueId>,
mapping: HashMap<ValueId, ValueId>,
func: &mut Function,
block: BasicBlockId,
) {
Expand All @@ -352,16 +351,16 @@ impl<'f> Context<'f> {
fn flatten_multiple(
conditionals: &Vec<BasicConditional>,
function: &mut Function,
no_predicates: &mut FxHashMap<FunctionId, bool>,
no_predicates: &mut HashMap<FunctionId, bool>,
) {
// 1. process each basic conditional, using a new context per conditional
let post_order = PostOrder::with_function(function);

let mut mapping = FxHashMap::default();
let mut mapping = HashMap::default();
for conditional in conditionals {
let cfg = ControlFlowGraph::with_function(function);
let cfg_root = function.entry_block();
let mut branch_ends = FxHashMap::default();
let mut branch_ends = HashMap::default();
branch_ends.insert(conditional.block_entry, conditional.block_exit);
let mut context = Context::new(function, cfg, branch_ends, cfg_root);
context.flatten_single_conditional(conditional, no_predicates);
Expand All @@ -370,8 +369,7 @@ impl<'f> Context<'f> {
}
// 2. re-map the full program for values that may been simplified.
if !mapping.is_empty() {
let po = post_order.as_slice();
for block in po {
for block in post_order.as_slice() {
Context::map_block_with_mapping(mapping.clone(), function, *block);
}
}
Expand Down

0 comments on commit 1db3d61

Please sign in to comment.