From 1db3d6105c2613116fad8d38066f9a71a6d7fae7 Mon Sep 17 00:00:00 2001 From: guipublic Date: Wed, 5 Feb 2025 17:40:48 +0000 Subject: [PATCH] code review --- .../src/ssa/ir/function_inserter.rs | 1 + .../src/ssa/opt/basic_conditional.rs | 250 +++++++++--------- 2 files changed, 125 insertions(+), 126 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs b/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs index 553f96a651b..13b5ead5eb6 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs @@ -85,6 +85,7 @@ impl<'f> FunctionInserter<'f> { instruction.map_values_mut(|id| self.resolve(id)); self.function.dfg.set_instruction(id, instruction); } + /// Maps a terminator in place, replacing any ValueId in the terminator with the /// resolved version of that value id from this FunctionInserter's internal value mapping. pub(crate) fn map_terminator_in_place(&mut self, block: BasicBlockId) { diff --git a/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs b/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs index 35dad3067f5..a08ac1fbf0d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs @@ -1,7 +1,7 @@ use std::collections::{HashSet, VecDeque}; use acvm::AcirField; -use fxhash::FxHashMap; +use fxhash::FxHashMap as HashMap; use iter_extended::vecmap; use crate::ssa::{ @@ -9,7 +9,7 @@ use crate::ssa::{ basic_block::BasicBlockId, cfg::ControlFlowGraph, dfg::DataFlowGraph, - function::{Function, FunctionId, RuntimeType}, + function::{Function, FunctionId}, function_inserter::FunctionInserter, instruction::{BinaryOp, Instruction, TerminatorInstruction}, post_order::PostOrder, @@ -31,7 +31,7 @@ impl Ssa { #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn flatten_basic_conditionals(mut self) -> Ssa { // Retrieve the 'no_predicates' attribute of the functions in a map, to avoid problems with borrowing - let mut no_predicates = FxHashMap::default(); + let mut no_predicates = HashMap::default(); for function in self.functions.values() { no_predicates.insert(function.id(), function.is_no_predicates()); } @@ -50,104 +50,102 @@ fn is_conditional( function: &Function, ) -> Option { // jump overhead is the cost for doing the conditional and jump around the blocks - // I use 10 as a rough estimate, real cost is less. + // We use 10 as a rough estimate, the real cost is less. let jump_overhead = 10; let mut successors = cfg.successors(block); let mut result = None; // a conditional must have 2 branches - if successors.len() == 2 { - let left = successors.next().unwrap(); - let right = successors.next().unwrap(); - let mut left_successors = cfg.successors(left); - let mut right_successors = cfg.successors(right); - let left_successors_len = left_successors.len(); - let right_successors_len = right_successors.len(); - let next_left = left_successors.next(); - let next_right = right_successors.next(); - if next_left == Some(block) || next_right == Some(block) { - // this is a loop, not a conditional - return None; - } - if left_successors_len == 1 && right_successors_len == 1 && next_left == next_right { - // The branches join on one block so it is a non-nested conditional - let cost_left = block_cost(left, &function.dfg); - let cost_right = block_cost(right, &function.dfg); - // For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches, - // including an overhead to take into account the jumps between the blocks. - let cost = cost_right.saturating_add(cost_left); - if cost < cost / 2 + jump_overhead { - if let Some(TerminatorInstruction::JmpIf { - condition: _, - then_destination, - else_destination, - call_stack: _, - }) = function.dfg[block].terminator() - { - result = Some(BasicConditional { - block_entry: block, - block_then: Some(*then_destination), - block_else: Some(*else_destination), - block_exit: next_left.unwrap(), - }); - } + if successors.len() != 2 { + return None; + } + let left = successors.next().unwrap(); + let right = successors.next().unwrap(); + let mut left_successors = cfg.successors(left); + let mut right_successors = cfg.successors(right); + let left_successors_len = left_successors.len(); + let right_successors_len = right_successors.len(); + let next_left = left_successors.next(); + let next_right = right_successors.next(); + if next_left == Some(block) || next_right == Some(block) { + // this is a loop, not a conditional + return None; + } + if left_successors_len == 1 && right_successors_len == 1 && next_left == next_right { + // The branches join on one block so it is a non-nested conditional + let cost_left = block_cost(left, &function.dfg); + let cost_right = block_cost(right, &function.dfg); + // For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches, + // including an overhead to take into account the jumps between the blocks. + let cost = cost_right.saturating_add(cost_left); + if cost < cost / 2 + jump_overhead { + if let Some(TerminatorInstruction::JmpIf { + condition: _, + then_destination, + else_destination, + call_stack: _, + }) = function.dfg[block].terminator() + { + result = Some(BasicConditional { + block_entry: block, + block_then: Some(*then_destination), + block_else: Some(*else_destination), + block_exit: next_left.unwrap(), + }); } - } else if left_successors_len == 1 && next_left == Some(right) { - // Left branch joins the right branch, it is a if/then statement with no else - // I am not sure whether this case can happen, but it is not difficult to handle it - let cost = block_cost(left, &function.dfg); - if cost < cost / 2 + jump_overhead { - if let Some(TerminatorInstruction::JmpIf { - condition: _, - then_destination, - else_destination, - call_stack: _, - }) = function.dfg[block].terminator() - { - if left == *then_destination { - result = Some(BasicConditional { - block_entry: block, - block_then: Some(left), - block_else: None, - block_exit: right, - }); - } else if left == *else_destination { - result = Some(BasicConditional { - block_entry: block, - block_then: None, - block_else: Some(left), - block_exit: right, - }); - } - } + } + } else if left_successors_len == 1 && next_left == Some(right) { + // Left branch joins the right branch, it is a if/then statement with no else + // I am not sure whether this case can happen, but it is not difficult to handle it + let cost = block_cost(left, &function.dfg); + if cost < cost / 2 + jump_overhead { + if let Some(TerminatorInstruction::JmpIf { + condition: _, + then_destination, + else_destination, + call_stack: _, + }) = function.dfg[block].terminator() + { + let (block_then, block_else) = if left == *then_destination { + (Some(left), None) + } else if left == *else_destination { + (None, Some(left)) + } else { + return None; + }; + + result = Some(BasicConditional { + block_entry: block, + block_then, + block_else, + block_exit: right, + }); } - } else if right_successors_len == 1 && next_right == Some(left) { - // Right branch joins the right branch, it is a if/else statement with no then - // I am not sure whether this case can happen, but it is not difficult to handle it - let cost = block_cost(right, &function.dfg); - if cost < cost / 2 + jump_overhead { - if let Some(TerminatorInstruction::JmpIf { - condition: _, - then_destination, - else_destination, - call_stack: _, - }) = function.dfg[block].terminator() - { - if right == *else_destination { - result = Some(BasicConditional { - block_entry: block, - block_then: None, - block_else: Some(right), - block_exit: right, - }); - } else if right == *then_destination { - result = Some(BasicConditional { - block_entry: block, - block_then: Some(right), - block_else: None, - block_exit: right, - }); - } - } + } + } else if right_successors_len == 1 && next_right == Some(left) { + // Right branch joins the right branch, it is a if/else statement with no then + // I am not sure whether this case can happen, but it is not difficult to handle it + let cost = block_cost(right, &function.dfg); + if cost < cost / 2 + jump_overhead { + if let Some(TerminatorInstruction::JmpIf { + condition: _, + then_destination, + else_destination, + call_stack: _, + }) = function.dfg[block].terminator() + { + let (block_then, block_else) = if right == *then_destination { + (Some(right), None) + } else if right == *else_destination { + (None, Some(right)) + } else { + return None; + }; + result = Some(BasicConditional { + block_entry: block, + block_then, + block_else, + block_exit: right, + }); } } } @@ -157,63 +155,64 @@ fn is_conditional( /// Computes a cost estimate of a basic block /// returns u32::MAX if the block has side-effect instructions +/// WARNING: the estimates are estimate of the runtime cost of each instructions, +/// 1 being the cost of the simplest instruction. These numbers can be improved. fn block_cost(block: BasicBlockId, dfg: &DataFlowGraph) -> u32 { let mut cost: u32 = 0; - for instruction in dfg[block].instructions() { let instruction_cost = match &dfg[*instruction] { Instruction::Binary(binary) => { match binary.operator { BinaryOp::Add { unchecked } | BinaryOp::Sub { unchecked } - | BinaryOp::Mul { unchecked } => if unchecked { 1 } else { return u32::MAX }, + | BinaryOp::Mul { unchecked } => if unchecked { 3 } else { return u32::MAX }, BinaryOp::Div | BinaryOp::Mod => return u32::MAX, - BinaryOp::Eq => 2, + BinaryOp::Eq => 1, BinaryOp::Lt => 5, BinaryOp::And | BinaryOp::Or - | BinaryOp::Xor => 1, //todo + | BinaryOp::Xor => 1, BinaryOp::Shl | BinaryOp::Shr => return u32::MAX, } }, - Instruction::Cast(_, _) => 1,//TODO check if this instruction can fail + // A Cast can be either simplified, or lead to a truncate + Instruction::Cast(_, _) => 3, Instruction::Not(_) => 1, Instruction::Truncate { .. } => 7, Instruction::Constrain(_,_,_) | Instruction::ConstrainNotEqual(_,_,_) | Instruction::RangeCheck { .. } - | Instruction::Call { .. } //TODO support calls with no-predicate set to true + // Calls with no-predicate set to true could be supported, but + // they are likely to be too costly anyways. Simple calls would + // have been inlined already. + | Instruction::Call { .. } | Instruction::Load { .. } | Instruction::Store { .. } | Instruction::ArraySet { .. } => return u32::MAX, Instruction::ArrayGet { array, index } => { + // A get can fail because of out-of-bound index + let mut get_cost = u32::MAX; // check if index is in bound - let index = dfg.get_numeric_constant(*index); - let mut ok_bound = false; - if let Some(index) = index { - let len = dfg.try_get_array_length(*array).unwrap(); + if let (Some(index), Some(len)) = (dfg.get_numeric_constant(*index), dfg.try_get_array_length(*array)) { + // The index is in-bounds if index.to_u128() < len as u128 { - ok_bound = true; + get_cost = 1; } } - if ok_bound { - 1 - } else { - return u32::MAX; - } + get_cost }, - Instruction::Allocate => 0, - Instruction::EnableSideEffectsIf { .. } => 0, - Instruction::IncrementRc { .. } => 0, - Instruction::DecrementRc { .. } => 0, + Instruction::Allocate + | Instruction::EnableSideEffectsIf { .. } + | Instruction::IncrementRc { .. } + | Instruction::DecrementRc { .. } + | Instruction::MakeArray { .. } + | Instruction::Noop => 0, Instruction::IfElse { .. } => 1, - Instruction::MakeArray { .. } => 0, - Instruction::Noop => 0, }; cost += instruction_cost; } @@ -221,9 +220,9 @@ fn block_cost(block: BasicBlockId, dfg: &DataFlowGraph) -> u32 { } /// Identifies all simple conditionals in the function and flattens them -fn flatten_function(function: &mut Function, no_predicates: &mut FxHashMap) { +fn flatten_function(function: &mut Function, no_predicates: &mut HashMap) { // This pass is dedicated to brillig functions - if !matches!(function.runtime(), RuntimeType::Brillig(_)) { + if !function.runtime().is_brillig() { return; } let cfg = ControlFlowGraph::with_function(function); @@ -260,7 +259,7 @@ impl<'f> Context<'f> { fn flatten_single_conditional( &mut self, conditional: &BasicConditional, - no_predicates: &mut FxHashMap, + no_predicates: &mut HashMap, ) { // Manually inline 'then', 'else' and 'exit' into the entry block //0. initialize the context for flattening a 'single conditional' @@ -335,7 +334,7 @@ impl<'f> Context<'f> { } fn map_block_with_mapping( - mapping: FxHashMap, + mapping: HashMap, func: &mut Function, block: BasicBlockId, ) { @@ -352,16 +351,16 @@ impl<'f> Context<'f> { fn flatten_multiple( conditionals: &Vec, function: &mut Function, - no_predicates: &mut FxHashMap, + no_predicates: &mut HashMap, ) { // 1. process each basic conditional, using a new context per conditional let post_order = PostOrder::with_function(function); - let mut mapping = FxHashMap::default(); + let mut mapping = HashMap::default(); for conditional in conditionals { let cfg = ControlFlowGraph::with_function(function); let cfg_root = function.entry_block(); - let mut branch_ends = FxHashMap::default(); + let mut branch_ends = HashMap::default(); branch_ends.insert(conditional.block_entry, conditional.block_exit); let mut context = Context::new(function, cfg, branch_ends, cfg_root); context.flatten_single_conditional(conditional, no_predicates); @@ -370,8 +369,7 @@ impl<'f> Context<'f> { } // 2. re-map the full program for values that may been simplified. if !mapping.is_empty() { - let po = post_order.as_slice(); - for block in po { + for block in post_order.as_slice() { Context::map_block_with_mapping(mapping.clone(), function, *block); } }