From 413a4e0e4e22c0dcf86a2d3de6b75c98ae1d67d4 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Wed, 10 Apr 2024 16:22:36 +0100 Subject: [PATCH] feat(simulator): Fetch return values at circuit execution (#5642) We were deserializing a `Program` twice in the simulator. Once to execute the circuit and again to fetch the return witness. Every call to `executeCircuitWithBlackBoxSolver` is followed by a call to `getReturnWitness` in the simulator. We should just pass the return witness right away as to not cross the WASM boundary a second time and as to avoid a second deserialization. I settled on a new ACVM method rather than using abi decoding as the return witnesses are stripped from the ABI. We can have a method that returns both the fully solved witness and the return witness based upon the circuit. This allows us to both avoid storing duplicate return witness information and an unnecessary deserialization. --- .../acvm-repo/acvm_js/src/execute.rs | 56 ++++++++++++++++++- .../acvm-repo/acvm_js/src/js_witness_map.rs | 38 ++++++++++++- noir/noir-repo/acvm-repo/acvm_js/src/lib.rs | 3 +- .../acvm-repo/acvm_js/src/public_witness.rs | 9 ++- yarn-project/simulator/src/acvm/acvm.ts | 11 ++-- .../simulator/src/acvm/deserialize.ts | 14 ++--- .../simulator/src/client/private_execution.ts | 8 +-- .../src/client/unconstrained_execution.ts | 7 ++- yarn-project/simulator/src/public/executor.ts | 8 ++- 9 files changed, 125 insertions(+), 29 deletions(-) diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs b/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs index 0e58ccf039c..c97b8ea1a66 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs @@ -13,7 +13,8 @@ use wasm_bindgen::prelude::wasm_bindgen; use crate::{ foreign_call::{resolve_brillig, ForeignCallHandler}, - JsExecutionError, JsWitnessMap, JsWitnessStack, + public_witness::extract_indices, + JsExecutionError, JsSolvedAndReturnWitness, JsWitnessMap, JsWitnessStack, }; #[wasm_bindgen] @@ -58,6 +59,44 @@ pub async fn execute_circuit( Ok(witness_map.into()) } +/// Executes an ACIR circuit to generate the solved witness from the initial witness. +/// This method also extracts the public return values from the solved witness into its own return witness. +/// +/// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver. +/// @param {Uint8Array} circuit - A serialized representation of an ACIR circuit +/// @param {WitnessMap} initial_witness - The initial witness map defining all of the inputs to `circuit`.. +/// @param {ForeignCallHandler} foreign_call_handler - A callback to process any foreign calls from the circuit. +/// @returns {SolvedAndReturnWitness} The solved witness calculated by executing the circuit on the provided inputs, as well as the return witness indices as specified by the circuit. +#[wasm_bindgen(js_name = executeCircuitWithReturnWitness, skip_jsdoc)] +pub async fn execute_circuit_with_return_witness( + solver: &WasmBlackBoxFunctionSolver, + program: Vec, + initial_witness: JsWitnessMap, + foreign_call_handler: ForeignCallHandler, +) -> Result { + console_error_panic_hook::set_once(); + + let program: Program = Program::deserialize_program(&program) + .map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?; + + let mut witness_stack = execute_program_with_native_program_and_return( + solver, + &program, + initial_witness, + &foreign_call_handler, + ) + .await?; + let solved_witness = + witness_stack.pop().expect("Should have at least one witness on the stack").witness; + + let main_circuit = &program.functions[0]; + let return_witness = + extract_indices(&solved_witness, main_circuit.return_values.0.iter().copied().collect()) + .map_err(|err| JsExecutionError::new(err, None))?; + + Ok((solved_witness, return_witness).into()) +} + /// Executes an ACIR circuit to generate the solved witness from the initial witness. /// /// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver. @@ -127,6 +166,21 @@ async fn execute_program_with_native_type_return( let program: Program = Program::deserialize_program(&program) .map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?; + execute_program_with_native_program_and_return( + solver, + &program, + initial_witness, + foreign_call_executor, + ) + .await +} + +async fn execute_program_with_native_program_and_return( + solver: &WasmBlackBoxFunctionSolver, + program: &Program, + initial_witness: JsWitnessMap, + foreign_call_executor: &ForeignCallHandler, +) -> Result { let executor = ProgramExecutor::new(&program.functions, &solver.0, foreign_call_executor); let witness_stack = executor.execute(initial_witness.into()).await?; diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs b/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs index 481b8caaa2d..c4482c4a234 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs @@ -2,13 +2,23 @@ use acvm::{ acir::native_types::{Witness, WitnessMap}, FieldElement, }; -use js_sys::{JsString, Map}; +use js_sys::{JsString, Map, Object}; use wasm_bindgen::prelude::{wasm_bindgen, JsValue}; #[wasm_bindgen(typescript_custom_section)] const WITNESS_MAP: &'static str = r#" // Map from witness index to hex string value of witness. export type WitnessMap = Map; + +/** + * An execution result containing two witnesses. + * 1. The full solved witness of the execution. + * 2. The return witness which contains the given public return values within the full witness. + */ +export type SolvedAndReturnWitness = { + solvedWitness: WitnessMap; + returnWitness: WitnessMap; +} "#; // WitnessMap @@ -21,6 +31,12 @@ extern "C" { #[wasm_bindgen(constructor, js_class = "Map")] pub fn new() -> JsWitnessMap; + #[wasm_bindgen(extends = Object, js_name = "SolvedAndReturnWitness", typescript_type = "SolvedAndReturnWitness")] + #[derive(Clone, Debug, PartialEq, Eq)] + pub type JsSolvedAndReturnWitness; + + #[wasm_bindgen(constructor, js_class = "Object")] + pub fn new() -> JsSolvedAndReturnWitness; } impl Default for JsWitnessMap { @@ -29,6 +45,12 @@ impl Default for JsWitnessMap { } } +impl Default for JsSolvedAndReturnWitness { + fn default() -> Self { + Self::new() + } +} + impl From for JsWitnessMap { fn from(witness_map: WitnessMap) -> Self { let js_map = JsWitnessMap::new(); @@ -54,6 +76,20 @@ impl From for WitnessMap { } } +impl From<(WitnessMap, WitnessMap)> for JsSolvedAndReturnWitness { + fn from(witness_maps: (WitnessMap, WitnessMap)) -> Self { + let js_solved_witness = JsWitnessMap::from(witness_maps.0); + let js_return_witness = JsWitnessMap::from(witness_maps.1); + + let entry_map = Map::new(); + entry_map.set(&JsValue::from_str("solvedWitness"), &js_solved_witness); + entry_map.set(&JsValue::from_str("returnWitness"), &js_return_witness); + + let solved_and_return_witness = Object::from_entries(&entry_map).unwrap(); + JsSolvedAndReturnWitness { obj: solved_and_return_witness } + } +} + pub(crate) fn js_value_to_field_element(js_value: JsValue) -> Result { let hex_str = js_value.as_string().ok_or("failed to parse field element from non-string")?; diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs b/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs index d7ecc0ae192..66a4388b132 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs @@ -22,9 +22,10 @@ pub use compression::{ }; pub use execute::{ create_black_box_solver, execute_circuit, execute_circuit_with_black_box_solver, - execute_program, execute_program_with_black_box_solver, + execute_circuit_with_return_witness, execute_program, execute_program_with_black_box_solver, }; pub use js_execution_error::JsExecutionError; +pub use js_witness_map::JsSolvedAndReturnWitness; pub use js_witness_map::JsWitnessMap; pub use js_witness_stack::JsWitnessStack; pub use logging::init_log_level; diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs b/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs index a0d5b5f8be2..4ba054732d4 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs @@ -7,7 +7,10 @@ use wasm_bindgen::prelude::wasm_bindgen; use crate::JsWitnessMap; -fn extract_indices(witness_map: &WitnessMap, indices: Vec) -> Result { +pub(crate) fn extract_indices( + witness_map: &WitnessMap, + indices: Vec, +) -> Result { let mut extracted_witness_map = WitnessMap::new(); for witness in indices { let witness_value = witness_map.get(&witness).ok_or(format!( @@ -44,7 +47,7 @@ pub fn get_return_witness( let witness_map = WitnessMap::from(witness_map); let return_witness = - extract_indices(&witness_map, circuit.return_values.0.clone().into_iter().collect())?; + extract_indices(&witness_map, circuit.return_values.0.iter().copied().collect())?; Ok(JsWitnessMap::from(return_witness)) } @@ -71,7 +74,7 @@ pub fn get_public_parameters_witness( let witness_map = WitnessMap::from(solved_witness); let public_params_witness = - extract_indices(&witness_map, circuit.public_parameters.0.clone().into_iter().collect())?; + extract_indices(&witness_map, circuit.public_parameters.0.iter().copied().collect())?; Ok(JsWitnessMap::from(public_params_witness)) } diff --git a/yarn-project/simulator/src/acvm/acvm.ts b/yarn-project/simulator/src/acvm/acvm.ts index 6d7101ba64e..d166b5d16c3 100644 --- a/yarn-project/simulator/src/acvm/acvm.ts +++ b/yarn-project/simulator/src/acvm/acvm.ts @@ -7,7 +7,7 @@ import { type ForeignCallInput, type ForeignCallOutput, type WasmBlackBoxFunctionSolver, - executeCircuitWithBlackBoxSolver, + executeCircuitWithReturnWitness, } from '@noir-lang/acvm_js'; import { traverseCauseChain } from '../common/errors.js'; @@ -27,9 +27,12 @@ type ACIRCallback = Record< */ export interface ACIRExecutionResult { /** - * The partial witness of the execution. + * An execution result contains two witnesses. + * 1. The partial witness of the execution. + * 2. The return witness which contains the given public return values within the full witness. */ partialWitness: ACVMWitness; + returnWitness: ACVMWitness; } /** @@ -89,7 +92,7 @@ export async function acvm( ): Promise { const logger = createDebugLogger('aztec:simulator:acvm'); - const partialWitness = await executeCircuitWithBlackBoxSolver( + const solvedAndReturnWitness = await executeCircuitWithReturnWitness( solver, acir, initialWitness, @@ -127,7 +130,7 @@ export async function acvm( throw err; }); - return { partialWitness }; + return { partialWitness: solvedAndReturnWitness.solvedWitness, returnWitness: solvedAndReturnWitness.returnWitness }; } /** diff --git a/yarn-project/simulator/src/acvm/deserialize.ts b/yarn-project/simulator/src/acvm/deserialize.ts index 74701582330..5936d381a37 100644 --- a/yarn-project/simulator/src/acvm/deserialize.ts +++ b/yarn-project/simulator/src/acvm/deserialize.ts @@ -1,7 +1,5 @@ import { Fr } from '@aztec/foundation/fields'; -import { getReturnWitness } from '@noir-lang/acvm_js'; - import { type ACVMField, type ACVMWitness } from './acvm_types.js'; /** @@ -32,13 +30,11 @@ export function frToBoolean(fr: Fr): boolean { } /** - * Extracts the return fields of a given partial witness. - * @param acir - The bytecode of the function. - * @param partialWitness - The witness to extract from. + * Transforms a witness map to its field elements. + * @param witness - The witness to extract from. * @returns The return values. */ -export function extractReturnWitness(acir: Buffer, partialWitness: ACVMWitness): Fr[] { - const returnWitness = getReturnWitness(acir, partialWitness); - const sortedKeys = [...returnWitness.keys()].sort((a, b) => a - b); - return sortedKeys.map(key => returnWitness.get(key)!).map(fromACVMField); +export function witnessMapToFields(witness: ACVMWitness): Fr[] { + const sortedKeys = [...witness.keys()].sort((a, b) => a - b); + return sortedKeys.map(key => witness.get(key)!).map(fromACVMField); } diff --git a/yarn-project/simulator/src/client/private_execution.ts b/yarn-project/simulator/src/client/private_execution.ts index 964e51cf95f..7789711b296 100644 --- a/yarn-project/simulator/src/client/private_execution.ts +++ b/yarn-project/simulator/src/client/private_execution.ts @@ -4,7 +4,7 @@ import { type AztecAddress } from '@aztec/foundation/aztec-address'; import { Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; -import { extractReturnWitness } from '../acvm/deserialize.js'; +import { witnessMapToFields } from '../acvm/deserialize.js'; import { Oracle, acvm, extractCallStack } from '../acvm/index.js'; import { ExecutionError } from '../common/errors.js'; import { type ClientExecutionContext } from './client_execution_context.js'; @@ -26,7 +26,7 @@ export async function executePrivateFunction( const acir = artifact.bytecode; const initialWitness = context.getInitialWitness(artifact); const acvmCallback = new Oracle(context); - const { partialWitness } = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback).catch( + const acirExecutionResult = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback).catch( (err: Error) => { throw new ExecutionError( err.message, @@ -39,8 +39,8 @@ export async function executePrivateFunction( ); }, ); - - const returnWitness = extractReturnWitness(acir, partialWitness); + const partialWitness = acirExecutionResult.partialWitness; + const returnWitness = witnessMapToFields(acirExecutionResult.returnWitness); const publicInputs = PrivateCircuitPublicInputs.fromFields(returnWitness); const encryptedLogs = context.getEncryptedLogs(); diff --git a/yarn-project/simulator/src/client/unconstrained_execution.ts b/yarn-project/simulator/src/client/unconstrained_execution.ts index 559a8cf1f48..d821ca9fea9 100644 --- a/yarn-project/simulator/src/client/unconstrained_execution.ts +++ b/yarn-project/simulator/src/client/unconstrained_execution.ts @@ -4,7 +4,7 @@ import { type AztecAddress } from '@aztec/foundation/aztec-address'; import { type Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; -import { extractReturnWitness } from '../acvm/deserialize.js'; +import { witnessMapToFields } from '../acvm/deserialize.js'; import { Oracle, acvm, extractCallStack, toACVMWitness } from '../acvm/index.js'; import { ExecutionError } from '../common/errors.js'; import { AcirSimulator } from './simulator.js'; @@ -27,7 +27,7 @@ export async function executeUnconstrainedFunction( const acir = artifact.bytecode; const initialWitness = toACVMWitness(0, args); - const { partialWitness } = await acvm( + const acirExecutionResult = await acvm( await AcirSimulator.getSolver(), acir, initialWitness, @@ -44,6 +44,7 @@ export async function executeUnconstrainedFunction( ); }); - return decodeReturnValues(artifact, extractReturnWitness(acir, partialWitness)); + const returnWitness = witnessMapToFields(acirExecutionResult.returnWitness); + return decodeReturnValues(artifact, returnWitness); } // docs:end:execute_unconstrained_function diff --git a/yarn-project/simulator/src/public/executor.ts b/yarn-project/simulator/src/public/executor.ts index 854223ae3cd..08e534b5da2 100644 --- a/yarn-project/simulator/src/public/executor.ts +++ b/yarn-project/simulator/src/public/executor.ts @@ -6,7 +6,7 @@ import { spawn } from 'child_process'; import fs from 'fs/promises'; import path from 'path'; -import { Oracle, acvm, extractCallStack, extractReturnWitness } from '../acvm/index.js'; +import { Oracle, acvm, extractCallStack, witnessMapToFields } from '../acvm/index.js'; import { AvmContext } from '../avm/avm_context.js'; import { AvmMachineState } from '../avm/avm_machine_state.js'; import { AvmSimulator } from '../avm/avm_simulator.js'; @@ -97,11 +97,12 @@ async function executePublicFunctionAcvm( const initialWitness = context.getInitialWitness(); const acvmCallback = new Oracle(context); - const { partialWitness, reverted, revertReason } = await (async () => { + const { partialWitness, returnWitnessMap, reverted, revertReason } = await (async () => { try { const result = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback); return { partialWitness: result.partialWitness, + returnWitnessMap: result.returnWitness, reverted: false, revertReason: undefined, }; @@ -123,6 +124,7 @@ async function executePublicFunctionAcvm( } else { return { partialWitness: undefined, + returnWitnessMap: undefined, reverted: true, revertReason: createSimulationError(ee), }; @@ -159,7 +161,7 @@ async function executePublicFunctionAcvm( throw new Error('No partial witness returned from ACVM'); } - const returnWitness = extractReturnWitness(acir, partialWitness); + const returnWitness = witnessMapToFields(returnWitnessMap); const { returnValues, nullifierReadRequests: nullifierReadRequestsPadded,