Skip to content

Commit

Permalink
merge from master
Browse files Browse the repository at this point in the history
  • Loading branch information
guipublic committed Dec 2, 2024
1 parent 3b46b9d commit 4b3ab94
Showing 1 changed file with 117 additions and 77 deletions.
194 changes: 117 additions & 77 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::collections::{BTreeMap, BTreeSet, HashMap};
use acir::{
circuit::{
brillig::{BrilligInputs, BrilligOutputs},
directives::Directive,
opcodes::BlockId,
Circuit, Opcode,
},
Expand Down Expand Up @@ -59,70 +58,66 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {

// For each opcode, try to get a target opcode to merge with
for (i, opcode) in circuit.opcodes.iter().enumerate() {
if !matches!(opcode, Opcode::AssertZero(_)) {
continue;
}
let opcode = self.modified_gates.get(&i).unwrap_or(opcode).clone();
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses.clone() {
let empty_gates = BTreeSet::new();
let gates_using_w = used_witness.get(&w).unwrap_or(&empty_gates);
for w in input_witnesses {
let Some(gates_using_w) = used_witness.get(&w) else {
continue;
};
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let gates_using_w: Vec<_> = gates_using_w
.iter()
.filter(|g| **g != i && !self.deleted_gates.contains(g))
.collect();
if gates_using_w.len() == 2 {
// i is the current opcode (from the iteration)
// b is the (only) other opcode using w
let mut b = *gates_using_w[1];
if b == i {
b = *gates_using_w[0];
} else {
// sanity check
assert!(i == *gates_using_w[0]);
}
// Merge source opcode into target opcode
// by updating modified_gates/deleted_gates/used_witness
let mut merge_opcodes = |target, source| -> bool {
assert!(source < target);
let source_opcode = self.get_opcode(source, circuit);
let target_opcode = self.get_opcode(target, circuit);
if let (
Some(Opcode::AssertZero(expr_use)),
Some(Opcode::AssertZero(expr_define)),
) = (target_opcode, source_opcode)
{
if let Some(expr) =
Self::merge_expression(&expr_use, &expr_define, w)
{
self.modified_gates.insert(source, Opcode::AssertZero(expr));
self.deleted_gates.insert(target);
// Update the 'used_witness' map to account for the merge.
for w2 in CircuitSimulator::expr_wit(&expr_use) {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(i);
v.remove(&b);
used_witness.insert(w2, v);
}
let first = *gates_using_w.first().expect("gates_using_w.len == 2");
let second = *gates_using_w.last().expect("gates_using_w.len == 2");
let b = if second == i {
first
} else {
// sanity check
assert!(i == first);
second
};
// Merge source opcode into target opcode
// by updating modified_gates/deleted_gates/used_witness
let mut merge_opcodes = |target, source| -> bool {
assert!(source < target);
let source_opcode = self.get_opcode(source, circuit);
let target_opcode = self.get_opcode(target, circuit);
if let (
Some(Opcode::AssertZero(expr_use)),
Some(Opcode::AssertZero(expr_define)),
) = (target_opcode, source_opcode)
{
if let Some(expr) = Self::merge_expression(&expr_use, &expr_define, w) {
self.modified_gates.insert(source, Opcode::AssertZero(expr));
self.deleted_gates.insert(target);
// Update the 'used_witness' map to account for the merge.
for w2 in CircuitSimulator::expr_wit(&expr_use) {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(i);
v.remove(&b);
used_witness.insert(w2, v);
}
return true;
}
return true;
}
false
};
if i < b {
if merge_opcodes(b, i) {
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
} else if i != b {
// Merge b into i
if merge_opcodes(i, b) {
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
}
false
};
if i < b {
if merge_opcodes(b, i) {
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
} else if i != b {
// Merge b into i
if merge_opcodes(i, b) {
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
}
}
Expand All @@ -132,11 +127,11 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
// Construct the new circuit from modified/deleted gates
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();
#[allow(clippy::needless_range_loop)]
for i in 0..circuit.opcodes.len() {

for (i, opcode_position) in acir_opcode_positions.iter().enumerate() {
if let Some(op) = self.get_opcode(i, circuit) {
new_circuit.push(op);
new_acir_opcode_positions.push(acir_opcode_positions[i]);
new_acir_opcode_positions.push(*opcode_position);
}
}
(new_circuit, new_acir_opcode_positions)
Expand Down Expand Up @@ -176,15 +171,17 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {

// Returns the input witnesses used by the opcode
fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
let mut witnesses = BTreeSet::new();
match opcode {
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Opcode::BlackBoxFuncCall(bb_func) => bb_func.get_input_witnesses(),
Opcode::Directive(Directive::ToLeRadix { a, .. }) => CircuitSimulator::expr_wit(a),
Opcode::BlackBoxFuncCall(bb_func) => {
let mut witnesses = bb_func.get_input_witnesses();
witnesses.extend(bb_func.get_outputs_vec());

witnesses
}
Opcode::MemoryOp { block_id: _, op, predicate } => {
//index et value, et predicate
let mut witnesses = BTreeSet::new();
witnesses.extend(CircuitSimulator::expr_wit(&op.index));
let mut witnesses = CircuitSimulator::expr_wit(&op.index);
witnesses.extend(CircuitSimulator::expr_wit(&op.value));
if let Some(p) = predicate {
witnesses.extend(CircuitSimulator::expr_wit(p));
Expand All @@ -196,6 +193,7 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
init.iter().cloned().collect()
}
Opcode::BrilligCall { inputs, outputs, .. } => {
let mut witnesses = BTreeSet::new();
for i in inputs {
witnesses.extend(self.brillig_input_wit(i));
}
Expand All @@ -205,12 +203,9 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
witnesses
}
Opcode::Call { id: _, inputs, outputs, predicate } => {
for i in inputs {
witnesses.insert(*i);
}
for i in outputs {
witnesses.insert(*i);
}
let mut witnesses: BTreeSet<Witness> = BTreeSet::from_iter(inputs.iter().copied());
witnesses.extend(outputs);

if let Some(p) = predicate {
witnesses.extend(CircuitSimulator::expr_wit(p));
}
Expand Down Expand Up @@ -254,7 +249,7 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
if self.deleted_gates.contains(&g) {
return None;
}
Some(self.modified_gates.get(&g).unwrap_or(&circuit.opcodes[g]).clone())
Some(self.modified_gates.get(&g).cloned().unwrap_or(circuit.opcodes[g].clone()))
}
}

Expand All @@ -265,15 +260,15 @@ mod tests {
acir_field::AcirField,
circuit::{
brillig::{BrilligFunctionId, BrilligOutputs},
opcodes::FunctionInput,
opcodes::{BlackBoxFuncCall, FunctionInput},
Circuit, ExpressionWidth, Opcode, PublicInputs,
},
native_types::{Expression, Witness},
FieldElement,
};
use std::collections::BTreeSet;

fn check_circuit(circuit: Circuit<FieldElement>) {
fn check_circuit(circuit: Circuit<FieldElement>) -> Circuit<FieldElement> {
assert!(CircuitSimulator::default().check_circuit(&circuit));
let mut merge_optimizer = MergeExpressionsOptimizer::new();
let acir_opcode_positions = vec![0; 20];
Expand All @@ -283,6 +278,7 @@ mod tests {
optimized_circuit.opcodes = opcodes;
// check that the circuit is still valid after optimization
assert!(CircuitSimulator::default().check_circuit(&optimized_circuit));
optimized_circuit
}

#[test]
Expand Down Expand Up @@ -325,7 +321,6 @@ mod tests {
public_parameters: PublicInputs::default(),
return_values: PublicInputs::default(),
assert_messages: Default::default(),
recursive: false,
};
check_circuit(circuit);
}
Expand Down Expand Up @@ -378,8 +373,53 @@ mod tests {
public_parameters: PublicInputs::default(),
return_values: PublicInputs::default(),
assert_messages: Default::default(),
recursive: false,
};
check_circuit(circuit);
}

#[test]
fn takes_blackbox_opcode_outputs_into_account() {
// Regression test for https://github.com/noir-lang/noir/issues/6527
// Previously we would not track the usage of witness 4 in the output of the blackbox function.
// We would then merge the final two opcodes losing the check that the brillig call must match
// with `_0 ^ _1`.

let circuit: Circuit<FieldElement> = Circuit {
current_witness_index: 7,
opcodes: vec![
Opcode::BrilligCall {
id: BrilligFunctionId(0),
inputs: Vec::new(),
outputs: vec![BrilligOutputs::Simple(Witness(3))],
predicate: None,
},
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND {
lhs: FunctionInput::witness(Witness(0), 8),
rhs: FunctionInput::witness(Witness(1), 8),
output: Witness(4),
}),
Opcode::AssertZero(Expression {
linear_combinations: vec![
(FieldElement::one(), Witness(3)),
(-FieldElement::one(), Witness(4)),
],
..Default::default()
}),
Opcode::AssertZero(Expression {
linear_combinations: vec![
(-FieldElement::one(), Witness(2)),
(FieldElement::one(), Witness(4)),
],
..Default::default()
}),
],
expression_width: ExpressionWidth::Bounded { width: 4 },
private_parameters: BTreeSet::from([Witness(0), Witness(1)]),
return_values: PublicInputs(BTreeSet::from([Witness(2)])),
..Default::default()
};

let new_circuit = check_circuit(circuit.clone());
assert_eq!(circuit, new_circuit);
}
}

0 comments on commit 4b3ab94

Please sign in to comment.