Skip to content

Commit

Permalink
chore: refactor ACIR function IDs from raw integers to struct (#5748)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #5700

## Summary\*

This PR includes an `AcirFunctionId` struct in place of raw integer
identifier in the `Call` opcode to facilitate type safety.

## Additional Context

- The `AcirFunctionId` struct is transparent for compatibility with
serialization.

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
shak58 authored Aug 19, 2024
1 parent cb0d490 commit 20dc8a4
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 31 deletions.
6 changes: 5 additions & 1 deletion acvm-repo/acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ use super::{
brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
directives::Directive,
};

pub mod function_id;
pub use function_id::AcirFunctionId;

use crate::native_types::{Expression, Witness};
use acir_field::AcirField;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -125,7 +129,7 @@ pub enum Opcode<F> {
Call {
/// Id for the function being called. It is the responsibility of the executor
/// to fetch the appropriate circuit from this id.
id: u32,
id: AcirFunctionId,
/// Inputs to the function call
inputs: Vec<Witness>,
/// Outputs of the function call
Expand Down
17 changes: 17 additions & 0 deletions acvm-repo/acir/src/circuit/opcodes/function_id.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use serde::{Deserialize, Serialize};

#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize, Hash)]
#[serde(transparent)]
pub struct AcirFunctionId(pub u32);

impl AcirFunctionId {
pub fn as_usize(&self) -> usize {
self.0 as usize
}
}

impl std::fmt::Display for AcirFunctionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
8 changes: 4 additions & 4 deletions acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::collections::BTreeSet;
use acir::{
circuit::{
brillig::{BrilligBytecode, BrilligFunctionId, BrilligInputs, BrilligOutputs},
opcodes::{BlackBoxFuncCall, BlockId, FunctionInput, MemOp},
opcodes::{AcirFunctionId, BlackBoxFuncCall, BlockId, FunctionInput, MemOp},
Circuit, Opcode, Program, PublicInputs,
},
native_types::{Expression, Witness},
Expand Down Expand Up @@ -381,13 +381,13 @@ fn nested_acir_call_circuit() {
// x
// }
let nested_call = Opcode::Call {
id: 1,
id: AcirFunctionId(1),
inputs: vec![Witness(0), Witness(1)],
outputs: vec![Witness(2)],
predicate: None,
};
let nested_call_two = Opcode::Call {
id: 1,
id: AcirFunctionId(1),
inputs: vec![Witness(0), Witness(1)],
outputs: vec![Witness(3)],
predicate: None,
Expand Down Expand Up @@ -419,7 +419,7 @@ fn nested_acir_call_circuit() {
q_c: FieldElement::one() + FieldElement::one(),
});
let call = Opcode::Call {
id: 2,
id: AcirFunctionId(2),
inputs: vec![Witness(2), Witness(1)],
outputs: vec![Witness(3)],
predicate: None,
Expand Down
6 changes: 3 additions & 3 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use acir::{
brillig::ForeignCallResult,
circuit::{
brillig::{BrilligBytecode, BrilligFunctionId},
opcodes::{BlockId, ConstantOrWitnessEnum, FunctionInput},
opcodes::{AcirFunctionId, BlockId, ConstantOrWitnessEnum, FunctionInput},
AssertionPayload, ErrorSelector, ExpressionOrMemory, Opcode, OpcodeLocation,
RawAssertionPayload, ResolvedAssertionPayload, STRING_ERROR_SELECTOR,
},
Expand Down Expand Up @@ -575,7 +575,7 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> ACVM<'a, F, B> {
else {
unreachable!("Not executing a Call opcode");
};
if *id == 0 {
if *id == AcirFunctionId(0) {
return Err(OpcodeResolutionError::AcirMainCallAttempted {
opcode_location: ErrorLocation::Resolved(OpcodeLocation::Acir(
self.instruction_pointer(),
Expand Down Expand Up @@ -716,7 +716,7 @@ pub(crate) fn is_predicate_false<F: AcirField>(
#[derive(Debug, Clone, PartialEq)]
pub struct AcirCallWaitInfo<F> {
/// Index in the list of ACIR function's that should be called
pub id: u32,
pub id: AcirFunctionId,
/// Initial witness for the given circuit to be called
pub initial_witness: WitnessMap<F>,
}
4 changes: 2 additions & 2 deletions acvm-repo/acvm_js/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ impl<'a, B: BlackBoxFunctionSolver<FieldElement>> ProgramExecutor<'a, B> {
acvm.resolve_pending_foreign_call(result);
}
ACVMStatus::RequiresAcirCall(call_info) => {
let acir_to_call = &self.functions[call_info.id as usize];
let acir_to_call = &self.functions[call_info.id.as_usize()];
let initial_witness = call_info.initial_witness;
let call_solved_witness = self
.execute_circuit(acir_to_call, initial_witness, witness_stack)
Expand All @@ -267,7 +267,7 @@ impl<'a, B: BlackBoxFunctionSolver<FieldElement>> ProgramExecutor<'a, B> {
}
}
acvm.resolve_pending_acir_call(call_resolved_outputs);
witness_stack.push(call_info.id, call_solved_witness.clone());
witness_stack.push(call_info.id.0, call_solved_witness.clone());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::ssa::ir::dfg::CallStack;
use crate::ssa::ir::types::Type as SsaType;
use crate::ssa::ir::{instruction::Endian, types::NumericType};
use acvm::acir::circuit::brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs};
use acvm::acir::circuit::opcodes::{BlockId, BlockType, MemOp};
use acvm::acir::circuit::opcodes::{AcirFunctionId, BlockId, BlockType, MemOp};
use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, ExpressionWidth, Opcode};
use acvm::blackbox_solver;
use acvm::brillig_vm::{MemoryValue, VMStatus, VM};
Expand Down Expand Up @@ -1959,7 +1959,7 @@ impl<F: AcirField> AcirContext<F> {

pub(crate) fn call_acir_function(
&mut self,
id: u32,
id: AcirFunctionId,
inputs: Vec<AcirValue>,
output_count: usize,
predicate: AcirVar,
Expand Down
56 changes: 45 additions & 11 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::brillig::brillig_ir::BrilligContext;
use crate::brillig::{brillig_gen::brillig_fn::FunctionContext as BrilligFunctionContext, Brillig};
use crate::errors::{InternalError, InternalWarning, RuntimeError, SsaReport};
pub(crate) use acir_ir::generated_acir::GeneratedAcir;
use acvm::acir::circuit::opcodes::BlockType;
use acvm::acir::circuit::opcodes::{AcirFunctionId, BlockType};
use noirc_frontend::monomorphization::ast::InlineType;

use acvm::acir::circuit::brillig::{BrilligBytecode, BrilligFunctionId};
Expand Down Expand Up @@ -775,7 +775,7 @@ impl<'a> Context<'a> {
.get(id)
.expect("ICE: should have an associated final index");
let output_vars = self.acir_context.call_acir_function(
*acir_function_id,
AcirFunctionId(*acir_function_id),
inputs,
output_count,
self.current_side_effects_enabled_var,
Expand Down Expand Up @@ -2867,9 +2867,13 @@ fn can_omit_element_sizes_array(array_typ: &Type) -> bool {

#[cfg(test)]
mod test {

use acvm::{
acir::{
circuit::{brillig::BrilligFunctionId, ExpressionWidth, Opcode, OpcodeLocation},
circuit::{
brillig::BrilligFunctionId, opcodes::AcirFunctionId, ExpressionWidth, Opcode,
OpcodeLocation,
},
native_types::Witness,
},
FieldElement,
Expand Down Expand Up @@ -3020,8 +3024,18 @@ mod test {
let main_opcodes = main_acir.opcodes();
assert_eq!(main_opcodes.len(), 3, "Should have two calls to `foo`");

check_call_opcode(&main_opcodes[0], 1, vec![Witness(0), Witness(1)], vec![Witness(2)]);
check_call_opcode(&main_opcodes[1], 1, vec![Witness(0), Witness(1)], vec![Witness(3)]);
check_call_opcode(
&main_opcodes[0],
AcirFunctionId(1),
vec![Witness(0), Witness(1)],
vec![Witness(2)],
);
check_call_opcode(
&main_opcodes[1],
AcirFunctionId(1),
vec![Witness(0), Witness(1)],
vec![Witness(3)],
);

if let Opcode::AssertZero(expr) = &main_opcodes[2] {
assert_eq!(expr.linear_combinations[0].0, FieldElement::from(1u128));
Expand Down Expand Up @@ -3076,9 +3090,19 @@ mod test {
let main_opcodes = main_acir.opcodes();
assert_eq!(main_opcodes.len(), 3, "Should have two calls to `foo` and an assert");

check_call_opcode(&main_opcodes[0], 1, vec![Witness(0), Witness(1)], vec![Witness(2)]);
check_call_opcode(
&main_opcodes[0],
AcirFunctionId(1),
vec![Witness(0), Witness(1)],
vec![Witness(2)],
);
// The output of the first call should be the input of the second call
check_call_opcode(&main_opcodes[1], 1, vec![Witness(2), Witness(1)], vec![Witness(3)]);
check_call_opcode(
&main_opcodes[1],
AcirFunctionId(1),
vec![Witness(2), Witness(1)],
vec![Witness(3)],
);
}

fn basic_nested_call(inline_type: InlineType) {
Expand Down Expand Up @@ -3167,9 +3191,19 @@ mod test {
assert_eq!(main_opcodes.len(), 3, "Should have two calls to `foo` and an assert");

// Both of these should call func_with_nested_foo_call f1
check_call_opcode(&main_opcodes[0], 1, vec![Witness(0), Witness(1)], vec![Witness(2)]);
check_call_opcode(
&main_opcodes[0],
AcirFunctionId(1),
vec![Witness(0), Witness(1)],
vec![Witness(2)],
);
// The output of the first call should be the input of the second call
check_call_opcode(&main_opcodes[1], 1, vec![Witness(0), Witness(1)], vec![Witness(3)]);
check_call_opcode(
&main_opcodes[1],
AcirFunctionId(1),
vec![Witness(0), Witness(1)],
vec![Witness(3)],
);

let func_with_nested_call_acir = &acir_functions[1];
let func_with_nested_call_opcodes = func_with_nested_call_acir.opcodes();
Expand All @@ -3182,15 +3216,15 @@ mod test {
// Should call foo f2
check_call_opcode(
&func_with_nested_call_opcodes[1],
2,
AcirFunctionId(2),
vec![Witness(3), Witness(1)],
vec![Witness(4)],
);
}

fn check_call_opcode(
opcode: &Opcode<FieldElement>,
expected_id: u32,
expected_id: AcirFunctionId,
expected_inputs: Vec<Witness>,
expected_outputs: Vec<Witness>,
) {
Expand Down
15 changes: 10 additions & 5 deletions tooling/debugger/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ impl<'a, B: BlackBoxFunctionSolver<FieldElement>> DebugContext<'a, B> {
&mut self,
call_info: AcirCallWaitInfo<FieldElement>,
) -> DebugCommandResult {
let callee_circuit = &self.circuits[call_info.id as usize];
let callee_circuit = &self.circuits[call_info.id.as_usize()];
let callee_witness_map = call_info.initial_witness;
let callee_acvm = ACVM::new(
self.backend,
Expand All @@ -578,7 +578,7 @@ impl<'a, B: BlackBoxFunctionSolver<FieldElement>> DebugContext<'a, B> {
let caller_acvm = std::mem::replace(&mut self.acvm, callee_acvm);
self.acvm_stack
.push(ExecutionFrame { circuit_id: self.current_circuit_id, acvm: caller_acvm });
self.current_circuit_id = call_info.id;
self.current_circuit_id = call_info.id.0;

// Explicitly handling the new ACVM status here handles two edge cases:
// 1. there is a breakpoint set at the beginning of a circuit
Expand All @@ -596,7 +596,7 @@ impl<'a, B: BlackBoxFunctionSolver<FieldElement>> DebugContext<'a, B> {
let ACVMStatus::RequiresAcirCall(call_info) = self.acvm.get_status() else {
unreachable!("Resolving an ACIR call, the caller is in an invalid state");
};
let acir_to_call = &self.circuits[call_info.id as usize];
let acir_to_call = &self.circuits[call_info.id.as_usize()];

let mut call_resolved_outputs = Vec::new();
for return_witness_index in acir_to_call.return_values.indices() {
Expand Down Expand Up @@ -946,7 +946,7 @@ mod tests {
brillig::IntegerBitSize,
circuit::{
brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs},
opcodes::{BlockId, BlockType},
opcodes::{AcirFunctionId, BlockId, BlockType},
},
native_types::Expression,
AcirField,
Expand Down Expand Up @@ -1210,7 +1210,12 @@ mod tests {
outputs: vec![],
predicate: None,
},
Opcode::Call { id: 1, inputs: vec![], outputs: vec![], predicate: None },
Opcode::Call {
id: AcirFunctionId(1),
inputs: vec![],
outputs: vec![],
predicate: None,
},
Opcode::AssertZero(Expression::default()),
],
..Circuit::default()
Expand Down
6 changes: 3 additions & 3 deletions tooling/nargo/src/ops/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>, E: ForeignCallExecutor<F>>
});

// Set current function to the circuit we are about to execute
self.current_function_index = call_info.id as usize;
self.current_function_index = call_info.id.as_usize();
// Execute the ACIR call
let acir_to_call = &self.functions[call_info.id as usize];
let acir_to_call = &self.functions[call_info.id.as_usize()];
let initial_witness = call_info.initial_witness;
let call_solved_witness = self.execute_circuit(initial_witness)?;

Expand All @@ -163,7 +163,7 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>, E: ForeignCallExecutor<F>>
}
}
acvm.resolve_pending_acir_call(call_resolved_outputs);
self.witness_stack.push(call_info.id, call_solved_witness);
self.witness_stack.push(call_info.id.0, call_solved_witness);
}
}
}
Expand Down

0 comments on commit 20dc8a4

Please sign in to comment.