diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index bd42c672d81..c1fab452b6e 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -61,63 +61,68 @@ impl MergeExpressionsOptimizer { if !matches!(opcode, Opcode::AssertZero(_)) { continue; } - let opcode = self.modified_gates.get(&i).unwrap_or(opcode).clone(); - let input_witnesses = self.witness_inputs(&opcode); - for w in input_witnesses { - let Some(gates_using_w) = used_witness.get(&w) else { - continue; - }; - // We only consider witness which are used in exactly two arithmetic gates - if gates_using_w.len() == 2 { - let first = *gates_using_w.first().expect("gates_using_w.len == 2"); - let second = *gates_using_w.last().expect("gates_using_w.len == 2"); - let b = if second == i { - first - } else { - // sanity check - assert!(i == first); - second + if let Some(opcode) = self.get_opcode(i, circuit) { + let input_witnesses = self.witness_inputs(&opcode); + for w in input_witnesses { + let Some(gates_using_w) = used_witness.get(&w) else { + continue; }; - // Merge source opcode into target opcode - // by updating modified_gates/deleted_gates/used_witness - let mut merge_opcodes = |target, source| -> bool { - assert!(source < target); - let source_opcode = self.get_opcode(source, circuit); - let target_opcode = self.get_opcode(target, circuit); - if let ( - Some(Opcode::AssertZero(expr_use)), - Some(Opcode::AssertZero(expr_define)), - ) = (target_opcode, source_opcode) - { - if let Some(expr) = Self::merge_expression(&expr_use, &expr_define, w) { - self.modified_gates.insert(source, Opcode::AssertZero(expr)); - self.deleted_gates.insert(target); - // Update the 'used_witness' map to account for the merge. - for w2 in CircuitSimulator::expr_wit(&expr_use) { - if !circuit_inputs.contains(&w2) { - let mut v = used_witness[&w2].clone(); - v.insert(i); - v.remove(&b); - used_witness.insert(w2, v); + // We only consider witness which are used in exactly two arithmetic gates + if gates_using_w.len() == 2 { + let first = *gates_using_w.first().expect("gates_using_w.len == 2"); + let second = *gates_using_w.last().expect("gates_using_w.len == 2"); + let b = if second == i { + first + } else { + // sanity check + assert!(i == first); + second + }; + // Merge source opcode into target opcode + // by updating modified_gates/deleted_gates/used_witness + let mut merge_opcodes = |target, source| -> bool { + assert!(source < target); + let source_opcode = self.get_opcode(source, circuit); + let target_opcode = self.get_opcode(target, circuit); + if let ( + Some(Opcode::AssertZero(expr_use)), + Some(Opcode::AssertZero(expr_define)), + ) = (target_opcode, source_opcode) + { + if let Some(expr) = + Self::merge_expression(&expr_use, &expr_define, w) + { + self.modified_gates.insert(target, Opcode::AssertZero(expr)); + self.deleted_gates.insert(source); + // Update the 'used_witness' map to account for the merge. + let mut witness_list = CircuitSimulator::expr_wit(&expr_use); + witness_list.extend(CircuitSimulator::expr_wit(&expr_define)); + for w2 in witness_list { + if !circuit_inputs.contains(&w2) { + let mut v = used_witness[&w2].clone(); + v.insert(b); + v.remove(&i); + used_witness.insert(w2, v); + } } + return true; } - return true; } - } - false - }; - if i < b { - if merge_opcodes(b, i) { - // We need to stop here and continue with the next opcode - // because the merge invalidates the current opcode. - break; - } - } else if i != b { - // Merge b into i - if merge_opcodes(i, b) { - // We need to stop here and continue with the next opcode - // because the merge invalidates the current opcode. - break; + false + }; + if i < b { + if merge_opcodes(b, i) { + // We need to stop here and continue with the next opcode + // because the merge invalidates the current opcode. + break; + } + } else if i != b { + // Merge b into i + if merge_opcodes(i, b) { + // We need to stop here and continue with the next opcode + // because the merge invalidates the current opcode. + break; + } } } }