Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(perf): Expand RC optimization pass to search return block for inc_rcs #6116

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ pub(crate) fn optimize_into_acir(
)?
.try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
.run_pass(Ssa::simplify_cfg, "After Simplifying:")
// The RC optimization pass only checks the entry and exit blocks of a function.
// We want it to be after simplify_cfg in case a block can be simplified to be part of the exit block.
.run_pass(Ssa::remove_paired_rc, "After Removing Paired rc_inc & rc_decs:")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
// Run mem2reg once more with the flattened CFG to catch any remaining loads/stores
Expand Down
241 changes: 213 additions & 28 deletions compiler/noirc_evaluator/src/ssa/opt/rc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::ssa::{
ir::{
Expand Down Expand Up @@ -28,16 +28,27 @@ impl Ssa {
}
}

#[derive(Default)]
struct Context {
struct Context<'f> {
function: &'f Function,

last_block: BasicBlockId,
// All inc_rc instructions encountered without a corresponding dec_rc.
// These are only searched for in the first block of a function.
// These are only searched for in the first and exit block of a function.
//
// The type of the array being operated on is recorded.
// If an array_set to that array type is encountered, that is also recorded.
inc_rcs: HashMap<Type, Vec<IncRc>>,
}

impl<'f> Context<'f> {
fn new(function: &'f Function) -> Self {
let last_block = Self::find_last_block(function);
// let all_block_params =
Context { function, last_block, inc_rcs: HashMap::default() }
}
}

#[derive(Clone, Debug)]
struct IncRc {
id: InstructionId,
array: ValueId,
Expand All @@ -58,11 +69,11 @@ fn remove_paired_rc(function: &mut Function) {
return;
}

let mut context = Context::default();
let mut context = Context::new(function);

context.find_rcs_in_entry_block(function);
context.scan_for_array_sets(function);
let to_remove = context.find_rcs_to_remove(function);
context.find_rcs_in_entry_and_exit_block();
context.scan_for_array_sets();
let to_remove = context.find_rcs_to_remove();
remove_instructions(to_remove, function);
}

Expand All @@ -71,13 +82,17 @@ fn contains_array_parameter(function: &mut Function) -> bool {
parameters.any(|parameter| function.dfg.type_of_value(*parameter).contains_an_array())
}

impl Context {
fn find_rcs_in_entry_block(&mut self, function: &Function) {
let entry = function.entry_block();
impl<'f> Context<'f> {
fn find_rcs_in_entry_and_exit_block(&mut self) {
let entry = self.function.entry_block();
self.find_rcs_in_block(entry);
self.find_rcs_in_block(self.last_block);
}

for instruction in function.dfg[entry].instructions() {
if let Instruction::IncrementRc { value } = &function.dfg[*instruction] {
let typ = function.dfg.type_of_value(*value);
fn find_rcs_in_block(&mut self, block_id: BasicBlockId) {
for instruction in self.function.dfg[block_id].instructions() {
if let Instruction::IncrementRc { value } = &self.function.dfg[*instruction] {
let typ = self.function.dfg.type_of_value(*value);

// We assume arrays aren't mutated until we find an array_set
let inc_rc = IncRc { id: *instruction, array: *value, possibly_mutated: false };
Expand All @@ -88,14 +103,28 @@ impl Context {

/// Find each array_set instruction in the function and mark any arrays used
/// by the inc_rc instructions as possibly mutated if they're the same type.
fn scan_for_array_sets(&mut self, function: &Function) {
for block in function.reachable_blocks() {
for instruction in function.dfg[block].instructions() {
if let Instruction::ArraySet { array, .. } = function.dfg[*instruction] {
let typ = function.dfg.type_of_value(array);
fn scan_for_array_sets(&mut self) {
// Block parameters could be passed to from function parameters.
// Thus, any inc rcs from block parameters with matching array sets need to marked possibly mutated.
let mut per_func_block_params: HashSet<ValueId> = HashSet::default();

for block in self.function.reachable_blocks() {
let block_params = self.function.dfg.block_parameters(block);
per_func_block_params.extend(block_params.iter());
}

for block in self.function.reachable_blocks() {
for instruction in self.function.dfg[block].instructions() {
if let Instruction::ArraySet { array, .. } = self.function.dfg[*instruction] {
let typ = self.function.dfg.type_of_value(array);
if let Some(inc_rcs) = self.inc_rcs.get_mut(&typ) {
for inc_rc in inc_rcs {
inc_rc.possibly_mutated = true;
if inc_rc.array == array
|| self.function.parameters().contains(&inc_rc.array)
|| per_func_block_params.contains(&inc_rc.array)
{
inc_rc.possibly_mutated = true;
}
}
}
}
Expand All @@ -105,13 +134,12 @@ impl Context {

/// Find each dec_rc instruction and if the most recent inc_rc instruction for the same value
/// is not possibly mutated, then we can remove them both. Returns each such pair.
fn find_rcs_to_remove(&mut self, function: &Function) -> HashSet<InstructionId> {
let last_block = function.find_last_block();
let mut to_remove = HashSet::new();
fn find_rcs_to_remove(&mut self) -> HashSet<InstructionId> {
let mut to_remove = HashSet::default();

for instruction in function.dfg[last_block].instructions() {
if let Instruction::DecrementRc { value } = &function.dfg[*instruction] {
if let Some(inc_rc) = self.pop_rc_for(*value, function) {
for instruction in self.function.dfg[self.last_block].instructions() {
if let Instruction::DecrementRc { value } = &self.function.dfg[*instruction] {
if let Some(inc_rc) = self.pop_rc_for(*value) {
if !inc_rc.possibly_mutated {
to_remove.insert(inc_rc.id);
to_remove.insert(*instruction);
Expand All @@ -124,8 +152,8 @@ impl Context {
}

/// Finds and pops the IncRc for the given array value if possible.
fn pop_rc_for(&mut self, value: ValueId, function: &Function) -> Option<IncRc> {
let typ = function.dfg.type_of_value(value);
fn pop_rc_for(&mut self, value: ValueId) -> Option<IncRc> {
let typ = self.function.dfg.type_of_value(value);

let rcs = self.inc_rcs.get_mut(&typ)?;
let position = rcs.iter().position(|inc_rc| inc_rc.array == value)?;
Expand Down Expand Up @@ -250,6 +278,7 @@ mod test {
builder.terminate_with_return(vec![]);

let ssa = builder.finish().remove_paired_rc();
println!("{}", ssa);
let main = ssa.main();
let entry = main.entry_block();

Expand Down Expand Up @@ -310,4 +339,160 @@ mod test {
assert_eq!(count_inc_rcs(entry, &main.dfg), 1);
assert_eq!(count_dec_rcs(entry, &main.dfg), 1);
}

#[test]
fn separate_entry_and_exit_block_fn_return_array() {
// brillig fn foo f0 {
// b0(v0: [Field; 2]):
// jmp b1(v0)
// b1():
// inc_rc v0
// inc_rc v0
// dec_rc v0
// return [v0]
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("foo".into(), main_id);
builder.set_runtime(RuntimeType::Brillig);

let inner_array_type = Type::Array(Arc::new(vec![Type::field()]), 2);
let v0 = builder.add_parameter(inner_array_type.clone());

let b1 = builder.insert_block();
builder.terminate_with_jmp(b1, vec![v0]);

builder.switch_to_block(b1);
builder.insert_inc_rc(v0);
builder.insert_inc_rc(v0);
builder.insert_dec_rc(v0);

let outer_array_type = Type::Array(Arc::new(vec![inner_array_type]), 1);
let array = builder.array_constant(vec![v0].into(), outer_array_type);
builder.terminate_with_return(vec![array]);

// Expected result:
//
// brillig fn foo f0 {
// b0(v0: [Field; 2]):
// jmp b1(v0)
// b1():
// inc_rc v0
// return [v0]
// }
let ssa = builder.finish().remove_paired_rc();
let main = ssa.main();

assert_eq!(count_inc_rcs(b1, &main.dfg), 1);
assert_eq!(count_dec_rcs(b1, &main.dfg), 0);
}

#[test]
fn exit_block_single_mutation() {
// fn mutator(mut array: [Field; 2]) {
// array[0] = 5;
// }
//
// acir(inline) fn mutator f0 {
// b0(v0: [Field; 2]):
// jmp b1(v0)
// b1(v1: [Field; 2]):
// v2 = allocate
// store v1 at v2
// inc_rc v1
// v3 = load v2
// v6 = array_set v3, index u64 0, value Field 5
// store v6 at v2
// dec_rc v1
// return
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("mutator".into(), main_id);

let array_type = Type::Array(Arc::new(vec![Type::field()]), 2);
let v0 = builder.add_parameter(array_type.clone());

let b1 = builder.insert_block();
builder.terminate_with_jmp(b1, vec![v0]);

builder.switch_to_block(b1);
// We want to make sure we go through the block parameter
let v1 = builder.add_block_parameter(b1, array_type.clone());

let v2 = builder.insert_allocate(array_type.clone());
builder.insert_store(v2, v1);
builder.insert_inc_rc(v1);
let v3 = builder.insert_load(v2, array_type);

let zero = builder.numeric_constant(0u128, Type::unsigned(64));
let five = builder.field_constant(5u128);
let v8 = builder.insert_array_set(v3, zero, five);

builder.insert_store(v2, v8);
builder.insert_dec_rc(v1);
builder.terminate_with_return(vec![]);

let ssa = builder.finish().remove_paired_rc();
let main = ssa.main();

// No changes, the array is possibly mutated
assert_eq!(count_inc_rcs(b1, &main.dfg), 1);
assert_eq!(count_dec_rcs(b1, &main.dfg), 1);
}

#[test]
fn exit_block_mutation_through_reference() {
// fn mutator2(array: &mut [Field; 2]) {
// array[0] = 5;
// }
// acir(inline) fn mutator2 f0 {
// b0(v0: &mut [Field; 2]):
// jmp b1(v0)
// b1(v1: &mut [Field; 2]):
// v2 = load v1
// inc_rc v1
// store v2 at v1
// v3 = load v2
// v6 = array_set v3, index u64 0, value Field 5
// store v6 at v1
// v7 = load v1
// dec_rc v7
// store v7 at v1
// return
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("mutator2".into(), main_id);

let array_type = Type::Array(Arc::new(vec![Type::field()]), 2);
let reference_type = Type::Reference(Arc::new(array_type.clone()));

let v0 = builder.add_parameter(reference_type.clone());

let b1 = builder.insert_block();
builder.terminate_with_jmp(b1, vec![v0]);

builder.switch_to_block(b1);
let v1 = builder.add_block_parameter(b1, reference_type);

let v2 = builder.insert_load(v1, array_type.clone());
builder.insert_inc_rc(v1);
builder.insert_store(v1, v2);

let v3 = builder.insert_load(v2, array_type.clone());
let zero = builder.numeric_constant(0u128, Type::unsigned(64));
let five = builder.field_constant(5u128);
let v6 = builder.insert_array_set(v3, zero, five);

builder.insert_store(v1, v6);
let v7 = builder.insert_load(v1, array_type);
builder.insert_dec_rc(v7);
builder.insert_store(v1, v7);
builder.terminate_with_return(vec![]);

let ssa = builder.finish().remove_paired_rc();
let main = ssa.main();

// No changes, the array is possibly mutated
assert_eq!(count_inc_rcs(b1, &main.dfg), 1);
assert_eq!(count_dec_rcs(b1, &main.dfg), 1);
}
}
Loading